[
  {
    "path": ".ai/AGENTS.md",
    "content": "# Diffusers — Agent Guide\n\n## Coding style\n\nStrive to write code as simple and explicit as possible.\n\n- Minimize small helper/utility functions — inline the logic instead. A reader should be able to follow the full flow without jumping between functions.\n- No defensive code or unused code paths — do not add fallback paths, safety checks, or configuration options \"just in case\". When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.\n- Do not guess user intent and silently correct behavior. Make the expected inputs clear in the docstring, and raise a concise error for unsupported cases rather than adding complex fallback logic.\n\n---\n\n### Dependencies\n- No new mandatory dependency without discussion (e.g. `einops`)\n- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`\n\n## Code formatting\n- `make style` and `make fix-copies` should be run as the final step before opening a PR\n\n### Copied Code\n- Many classes are kept in sync with a source via a `# Copied from ...` header comment\n- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source\n- Remove the header to intentionally break the link\n\n### Models\n- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.\n- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.\n- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.\n\n## Skills\n\nTask-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.\nAvailable skills: **model-integration** (adding/converting pipelines), 
**parity-testing** (debugging numerical parity).\n"
  },
  {
    "path": ".ai/skills/model-integration/SKILL.md",
    "content": "---\nname: integrating-models\ndescription: >\n  Use when adding a new model or pipeline to diffusers, setting up file\n  structure for a new model, converting a pipeline to modular format, or\n  converting weights for a new version of an already-supported model.\n---\n\n## Goal\n\nIntegrate a new model into diffusers end-to-end. The overall flow:\n\n1. **Gather info** — ask the user for the reference repo, setup guide, a runnable inference script, and other objectives such as standard vs modular.\n2. **Confirm the plan** — once you have everything, tell the user exactly what you'll do: e.g. \"I'll integrate model X with pipeline Y into diffusers based on your script. I'll run parity tests (model-level and pipeline-level) using the `parity-testing` skill to verify numerical correctness against the reference.\"\n3. **Implement** — write the diffusers code (model, pipeline, scheduler if needed), convert weights, register in `__init__.py`.\n4. **Parity test** — use the `parity-testing` skill to verify component and e2e parity against the reference implementation.\n5. **Deliver a unit test** — provide a self-contained test script that runs the diffusers implementation, checks numerical output (np allclose), and saves an image/video for visual verification. This is what the user runs to confirm everything works.\n\nWork one workflow at a time — get it to full parity before moving on.\n\n## Setup — gather before starting\n\nBefore writing any code, gather info in this order:\n\n1. **Reference repo** — ask for the github link. If they've already set it up locally, ask for the path. Otherwise, ask what setup steps are needed (install deps, download checkpoints, set env vars, etc.) and run through them before proceeding.\n2. **Inference script** — ask for a runnable end-to-end script for a basic workflow first (e.g. T2V). Then ask what other workflows they want to support (I2V, V2V, etc.) and agree on the full implementation order together.\n3. 
**Standard vs modular** — standard pipelines, modular, or both?\n\nUse `AskUserQuestion` with structured choices for step 3 when the options are known.\n\n## Standard Pipeline Integration\n\n### File structure for a new model\n\n```\nsrc/diffusers/\n  models/transformers/transformer_<model>.py     # The core model\n  schedulers/scheduling_<model>.py               # If model needs a custom scheduler\n  pipelines/<model>/\n    __init__.py\n    pipeline_<model>.py                          # Main pipeline\n    pipeline_<model>_<variant>.py                # Variant pipelines (e.g. pyramid, distilled)\n    pipeline_output.py                           # Output dataclass\n  loaders/lora_pipeline.py                       # LoRA mixin (add to existing file)\n\ntests/\n  models/transformers/test_models_transformer_<model>.py\n  pipelines/<model>/test_<model>.py\n  lora/test_lora_layers_<model>.py\n\ndocs/source/en/api/\n  pipelines/<model>.md\n  models/<model>_transformer3d.md                # or appropriate name\n```\n\n### Integration checklist\n\n- [ ] Implement transformer model with `from_pretrained` support\n- [ ] Implement or reuse scheduler\n- [ ] Implement pipeline(s) with `__call__` method\n- [ ] Add LoRA support if applicable\n- [ ] Register all classes in `__init__.py` files (lazy imports)\n- [ ] Write unit tests (model, pipeline, LoRA)\n- [ ] Write docs\n- [ ] Run `make style` and `make quality`\n- [ ] Test parity with reference implementation (see `parity-testing` skill)\n\n### Attention pattern\n\nAttention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. 
The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.\n\n```python\n# transformer_mymodel.py\n\nclass MyModelAttnProcessor:\n    _attention_backend = None\n    _parallel_config = None\n\n    def __call__(self, attn, hidden_states, attention_mask=None, ...):\n        query = attn.to_q(hidden_states)\n        key = attn.to_k(hidden_states)\n        value = attn.to_v(hidden_states)\n        # reshape, apply rope, etc.\n        hidden_states = dispatch_attention_fn(\n            query, key, value,\n            attn_mask=attention_mask,\n            backend=self._attention_backend,\n            parallel_config=self._parallel_config,\n        )\n        hidden_states = hidden_states.flatten(2, 3)\n        return attn.to_out[0](hidden_states)\n\n\nclass MyModelAttention(nn.Module, AttentionModuleMixin):\n    _default_processor_cls = MyModelAttnProcessor\n    _available_processors = [MyModelAttnProcessor]\n\n    def __init__(self, query_dim, heads=8, dim_head=64, ...):\n        super().__init__()\n        self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)\n        self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)\n        self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)\n        self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])\n        self.set_processor(MyModelAttnProcessor())\n\n    def forward(self, hidden_states, attention_mask=None, **kwargs):\n        return self.processor(self, hidden_states, attention_mask, **kwargs)\n```\n\nConsult the implementations in `src/diffusers/models/transformers/` if you need further references.\n\n### Implementation rules\n\n1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also \"improve\" the algorithm, refactor computation order, or rename internal variables for aesthetics. 
Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.\n2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` if you need references.\n3. **Don't subclass an existing pipeline for a variant.** Do not subclass an existing pipeline class (e.g., `FluxPipeline`) to implement another pipeline (e.g., `FluxImg2ImgPipeline`) that will be part of the core codebase (`src`).\n\n### Test setup\n\n- Slow tests are gated with `@slow` and `RUN_SLOW=1`\n- All model-level tests must initially be written with the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes. Any additional tests should be added after discussion with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.\n\n### Common diffusers conventions\n\n- Pipelines inherit from `DiffusionPipeline`\n- Models use `ModelMixin` with `register_to_config` for config serialization\n- Schedulers use `SchedulerMixin` with `ConfigMixin`\n- Use `@torch.no_grad()` on pipeline `__call__`\n- Support `output_type=\"latent\"` for skipping VAE decode\n- Support `generator` parameter for reproducibility\n- Use `self.progress_bar(timesteps)` for progress tracking\n\n## Gotchas\n\n1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.\n\n2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. 
If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.\n\n3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.\n\n4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.\n\n5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Without it, autograd keeps every intermediate activation alive for a backward pass that never runs, which can cause GPU OOM during inference.\n\n6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.\n\n7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.\n\n8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.\n\n---\n\n## Modular Pipeline Conversion\n\nSee [modular-conversion.md](modular-conversion.md) for the full guide on converting standard pipelines to modular format, including block types, build order, guider abstraction, and conversion checklist.\n\n---\n\n## Weight Conversion Tips\n\n<!-- TODO: Add concrete examples as we encounter them. 
Common patterns to watch for:\n  - Fused QKV weights that need splitting into separate Q, K, V\n  - Scale/shift ordering differences (reference stores [shift, scale], diffusers expects [scale, shift])\n  - Weight transpositions (linear stored as transposed conv, or vice versa)\n  - Interleaved head dimensions that need reshaping\n  - Bias terms absorbed into different layers\n  Add each with a before/after code snippet showing the conversion. -->\n"
  },
  {
    "path": ".ai/skills/model-integration/modular-conversion.md",
    "content": "# Modular Pipeline Conversion Reference\n\n## When to use\n\nModular pipelines break a monolithic `__call__` into composable blocks. Convert when:\n- The model supports multiple workflows (T2V, I2V, V2V, etc.)\n- Users need to swap guidance strategies (CFG, CFG-Zero*, PAG)\n- You want to share blocks across pipeline variants\n\n## File structure\n\n```\nsrc/diffusers/modular_pipelines/<model>/\n  __init__.py                          # Lazy imports\n  modular_pipeline.py                  # Pipeline class (tiny, mostly config)\n  encoders.py                          # Text encoder + image/video VAE encoder blocks\n  before_denoise.py                    # Pre-denoise setup blocks\n  denoise.py                           # The denoising loop blocks\n  decoders.py                          # VAE decode block\n  modular_blocks_<model>.py            # Block assembly (AutoBlocks)\n```\n\n## Block types decision tree\n\n```\nIs this a single operation?\n  YES -> ModularPipelineBlocks (leaf block)\n\nDoes it run multiple blocks in sequence?\n  YES -> SequentialPipelineBlocks\n    Does it iterate (e.g. chunk loop)?\n      YES -> LoopSequentialPipelineBlocks\n\nDoes it choose ONE block based on which input is present?\n  Is the selection 1:1 with trigger inputs?\n    YES -> AutoPipelineBlocks (simple trigger mapping)\n    NO  -> ConditionalPipelineBlocks (custom select_block method)\n```\n\n## Build order (easiest first)\n\n1. `decoders.py` -- Takes latents, runs VAE decode, returns images/videos\n2. `encoders.py` -- Takes prompt, returns prompt_embeds. Add image/video VAE encoder if needed\n3. `before_denoise.py` -- Timesteps, latent prep, noise setup. Each logical operation = one block\n4. `denoise.py` -- The hardest. 
Convert guidance to guider abstraction\n\n## Key pattern: Guider abstraction\n\nThe original pipeline has guidance baked in:\n```python\nfor i, t in enumerate(timesteps):\n    noise_pred = self.transformer(latents, prompt_embeds, ...)\n    if self.do_classifier_free_guidance:\n        noise_uncond = self.transformer(latents, negative_prompt_embeds, ...)\n        noise_pred = noise_uncond + scale * (noise_pred - noise_uncond)\n    latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n```\n\nThe modular pipeline separates concerns:\n```python\n# Each entry maps a transformer kwarg to its (conditional, unconditional) pair.\nguider_inputs = {\n    \"encoder_hidden_states\": (prompt_embeds, negative_prompt_embeds),\n}\n\n# num_steps, shared_kwargs, and generator are assumed to be set up before the loop.\nfor i, t in enumerate(timesteps):\n    components.guider.set_state(step=i, num_inference_steps=num_steps, timestep=t)\n    guider_state = components.guider.prepare_inputs(guider_inputs)\n\n    # One forward pass per guidance batch (e.g. conditional, unconditional).\n    for batch in guider_state:\n        components.guider.prepare_models(components.transformer)\n        cond_kwargs = {k: getattr(batch, k) for k in guider_inputs}\n        context_name = getattr(batch, components.guider._identifier_key)\n        with components.transformer.cache_context(context_name):\n            batch.noise_pred = components.transformer(\n                hidden_states=latents, timestep=t,\n                return_dict=False, **cond_kwargs, **shared_kwargs,\n            )[0]\n        components.guider.cleanup_models(components.transformer)\n\n    # The guider combines the per-batch predictions into one noise prediction.\n    noise_pred = components.guider(guider_state)[0]\n    latents = components.scheduler.step(noise_pred, t, latents, generator=generator)[0]\n```\n\n## Key pattern: Chunk loops for video models\n\nUse `LoopSequentialPipelineBlocks` for the outer loop:\n```python\nclass ChunkDenoiseStep(LoopSequentialPipelineBlocks):\n    block_classes = [PrepareChunkStep, NoiseGenStep, DenoiseInnerStep, UpdateStep]\n```\n\nNote: blocks inside `LoopSequentialPipelineBlocks` receive `(components, block_state, k)` where `k` is the loop iteration index.\n\n## Key pattern: Workflow 
selection\n\n```python\nclass AutoDenoise(ConditionalPipelineBlocks):\n    block_classes = [V2VDenoiseStep, I2VDenoiseStep, T2VDenoiseStep]\n    block_trigger_inputs = [\"video_latents\", \"image_latents\"]\n    default_block_name = \"text2video\"\n```\n\n## Standard InputParam/OutputParam templates\n\n```python\n# Inputs\nInputParam.template(\"prompt\")              # str, required\nInputParam.template(\"negative_prompt\")     # str, optional\nInputParam.template(\"image\")               # PIL.Image, optional\nInputParam.template(\"generator\")           # torch.Generator, optional\nInputParam.template(\"num_inference_steps\") # int, default=50\nInputParam.template(\"latents\")             # torch.Tensor, optional\n\n# Outputs\nOutputParam.template(\"prompt_embeds\")\nOutputParam.template(\"negative_prompt_embeds\")\nOutputParam.template(\"image_latents\")\nOutputParam.template(\"latents\")\nOutputParam.template(\"videos\")\nOutputParam.template(\"images\")\n```\n\n## ComponentSpec patterns\n\n```python\n# Heavy models - loaded from pretrained\nComponentSpec(\"transformer\", YourTransformerModel)\nComponentSpec(\"vae\", AutoencoderKL)\n\n# Lightweight objects - created inline from config\nComponentSpec(\n    \"guider\",\n    ClassifierFreeGuidance,\n    config=FrozenDict({\"guidance_scale\": 7.5}),\n    default_creation_method=\"from_config\"\n)\n```\n\n## Conversion checklist\n\n- [ ] Read original pipeline's `__call__` end-to-end, map stages\n- [ ] Write test scripts (reference + target) with identical seeds\n- [ ] Create file structure under `modular_pipelines/<model>/`\n- [ ] Write decoder block (simplest)\n- [ ] Write encoder blocks (text, image, video)\n- [ ] Write before_denoise blocks (timesteps, latent prep, noise)\n- [ ] Write denoise block with guider abstraction (hardest)\n- [ ] Create pipeline class with `default_blocks_name`\n- [ ] Assemble blocks in `modular_blocks_<model>.py`\n- [ ] Wire up `__init__.py` with lazy imports\n- [ ] Run `make style` 
and `make quality`\n- [ ] Test all workflows for parity with reference\n"
  },
  {
    "path": ".ai/skills/parity-testing/SKILL.md",
    "content": "---\nname: testing-parity\ndescription: >\n  Use when debugging or verifying numerical parity between pipeline\n  implementations (e.g., research repo vs diffusers, standard vs modular).\n  Also relevant when outputs look wrong — washed out, pixelated, or have\n  visual artifacts — as these are usually parity bugs.\n---\n\n## Setup — gather before starting\n\nBefore writing any test code, gather:\n\n1. **Which two implementations** are being compared (e.g. research repo → diffusers, standard → modular, or research → modular). Use `AskUserQuestion` with structured choices if not already clear.\n2. **Two equivalent runnable scripts** — one for each implementation, both expected to produce identical output given the same inputs. These scripts define what \"parity\" means concretely.\n\nWhen invoked from the `model-integration` skill, you already have context: the reference script comes from step 2 of setup, and the diffusers script is the one you just wrote. You just need to make sure both scripts are runnable and use the same inputs/seed/params.\n\n## Test strategy\n\n**Component parity (CPU/float32) -- always run, as you build.**\nTest each component before assembling the pipeline. This is the foundation -- if individual pieces are wrong, the pipeline can't be right. Each component in isolation, strict max_diff < 1e-3.\n\nTest freshly converted checkpoints and saved checkpoints.\n- **Fresh**: convert from checkpoint weights, compare against reference (catches conversion bugs)\n- **Saved**: load from saved model on disk, compare against reference (catches stale saves)\n\nKeep component test scripts around -- you will need to re-run them during pipeline debugging with different inputs or config values.\n\nTemplate -- one self-contained script per component, reference and diffusers side-by-side:\n```python\n@torch.inference_mode()\ndef test_my_component(mode=\"fresh\", model_path=None):\n    # 1. 
Deterministic input\n    gen = torch.Generator().manual_seed(42)\n    x = torch.randn(1, 3, 64, 64, generator=gen, dtype=torch.float32)\n\n    # 2. Reference: load from checkpoint, run, free\n    ref_model = ReferenceModel.from_config(config)\n    ref_model.load_state_dict(load_weights(\"prefix\"), strict=True)\n    ref_model = ref_model.float().eval()\n    ref_out = ref_model(x).clone()\n    del ref_model\n\n    # 3. Diffusers: fresh (convert weights) or saved (from_pretrained)\n    if mode == \"fresh\":\n        diff_model = convert_my_component(load_weights(\"prefix\"))\n    else:\n        diff_model = DiffusersModel.from_pretrained(model_path, torch_dtype=torch.float32)\n    diff_model = diff_model.float().eval()\n    diff_out = diff_model(x)\n    del diff_model\n\n    # 4. Compare in same script -- no saving to disk\n    max_diff = (ref_out - diff_out).abs().max().item()\n    assert max_diff < 1e-3, f\"FAIL: max_diff={max_diff:.2e}\"\n```\nKey points: (a) both reference and diffusers component in one script -- never split into separate scripts that save/load intermediates, (b) deterministic input via seeded generator, (c) load one model at a time to fit in CPU RAM, (d) `.clone()` the reference output before deleting the model.\n\n**E2E visual (GPU/bfloat16) -- once the pipeline is assembled.**\nBoth pipelines generate independently with identical seeds/params. Save outputs and compare visually. If outputs look identical, you're done -- no need for deeper testing.\n\n**Pipeline stage tests -- only if E2E fails and you need to isolate the bug.**\nIf the user already suspects where divergence is, start there. Otherwise, work through stages in order.\n\nFirst, **match noise generation**: the way initial noise/latents are constructed (seed handling, generator, randn call order) often differs between the two scripts. If the noise doesn't match, nothing downstream will match. 
Check how noise is initialized in the diffusers script — if it doesn't match the reference, temporarily change it to match. Note what you changed so it can be reverted after parity is confirmed.\n\nFor small models, run on CPU/float32 for strict comparison. For large models (e.g. 22B params), CPU/float32 is impractical -- use GPU/bfloat16 with `enable_model_cpu_offload()` and relax tolerances (max_diff < 1e-1 for bfloat16 is typical for passing tests; cosine similarity > 0.9999 is a good secondary check).\n\nTest encode and decode stages first -- they're simpler and bugs there are easier to fix. Only debug the denoising loop if encode and decode both pass.\n\nThe challenge: pipelines are monolithic `__call__` methods -- you can't just call \"the encode part\". See [checkpoint-mechanism.md](checkpoint-mechanism.md) for the checkpoint class that lets you stop, save, or inject tensors at named locations inside the pipeline.\n\n**Stage test order — encode, decode, then denoise:**\n\n- **`encode`** (test first): Stop both pipelines at `\"preloop\"`. Compare **every single variable** that will be consumed by the denoising loop -- not just latents and sigmas, but also prompt embeddings, attention masks, positional coordinates, connector outputs, and any conditioning inputs.\n- **`decode`** (test second, before denoise): Run the reference pipeline fully -- checkpoint the post-loop latents AND let it finish to get the decoded output. Then feed those same post-loop latents through the diffusers pipeline's decode path. Compare both numerically AND visually.\n- **`denoise`** (test last): Run both pipelines with realistic `num_steps` (e.g. 30) so the scheduler computes correct sigmas/timesteps, but stop after 2 loop iterations using `after_step_1`. 
Don't set `num_steps=2` -- that produces unrealistic sigma schedules.\n\n```python\n# Encode stage -- stop before the loop, compare ALL inputs:\nref_ckpts = {\"preloop\": Checkpoint(save=True, stop=True)}\nrun_reference_pipeline(ref_ckpts)\nref_data = ref_ckpts[\"preloop\"].data\n\ndiff_ckpts = {\"preloop\": Checkpoint(save=True, stop=True)}\nrun_diffusers_pipeline(diff_ckpts)\ndiff_data = diff_ckpts[\"preloop\"].data\n\n# Compare EVERY variable consumed by the denoise loop:\ncompare_tensors(\"latents\", ref_data[\"latents\"], diff_data[\"latents\"])\ncompare_tensors(\"sigmas\", ref_data[\"sigmas\"], diff_data[\"sigmas\"])\ncompare_tensors(\"prompt_embeds\", ref_data[\"prompt_embeds\"], diff_data[\"prompt_embeds\"])\n# ... every single tensor the transformer forward() will receive\n```\n\n**E2E-injected visual test**: Once you've identified a suspected root cause using stage tests, confirm it with an e2e-injected run -- inject the known-good tensor from reference and generate a full video. If the output looks identical to reference, you've confirmed the root cause.\n\n## Debugging technique: Injection for root-cause isolation\n\nWhen stage tests show divergence, **inject a known-good tensor from one pipeline into the other** to test whether the remaining code is correct.\n\nThe principle: if you suspect input X is the root cause of divergence in stage S:\n1. Run the reference pipeline and capture X\n2. Run the diffusers pipeline but **replace** its X with the reference's X (via checkpoint load)\n3. Compare outputs of stage S\n\nIf outputs now match: X was the root cause. If they still diverge: the bug is in the stage logic itself, not in X.\n\n| What you're testing | What you inject | Where you inject |\n|---|---|---|\n| Is the decode stage correct? | Post-loop latents from reference | Before decode |\n| Is the denoise loop correct? | Pre-loop latents from reference | Before the loop |\n| Is step N correct? 
| Post-step-(N-1) latents from reference | Before step N |\n\n**Per-step accumulation tracing**: When injection confirms the loop is correct but you want to understand *how* a small initial difference compounds, capture `after_step_{i}` for every step and plot the max_diff curve. A healthy curve stays bounded; an exponential blowup in later steps points to an amplification mechanism (see Pitfall #13 in [pitfalls.md](pitfalls.md)).\n\n## Debugging technique: Visual comparison via frame extraction\n\nFor video pipelines, numerical metrics alone can be misleading. Extract and view individual frames:\n\n```python\nimport numpy as np\nfrom PIL import Image\n\ndef extract_frames(video_np, frame_indices):\n    \"\"\"video_np: (frames, H, W, 3) float array in [0, 1]\"\"\"\n    for idx in frame_indices:\n        frame = (video_np[idx] * 255).clip(0, 255).astype(np.uint8)\n        img = Image.fromarray(frame)\n        img.save(f\"frame_{idx}.png\")\n\n# Compare specific frames from both pipelines\nextract_frames(ref_video, [0, 60, 120])\nextract_frames(diff_video, [0, 60, 120])\n```\n\n## Testing rules\n\n1. **Never use reference code in the diffusers test path.** Each side must use only its own code.\n2. **Never monkey-patch model internals in tests.** Do not replace `model.forward` or patch internal methods.\n3. **Debugging instrumentation must be non-destructive.** Checkpoint captures for debugging are fine, but must not alter control flow or outputs.\n4. **Prefer CPU/float32 for numerical comparison when practical.** Float32 avoids bfloat16 precision noise that obscures real bugs. But for large models (22B+), GPU/bfloat16 with `enable_model_cpu_offload()` is necessary -- use relaxed tolerances and cosine similarity as a secondary metric.\n5. **Test both fresh conversion AND saved model.** Fresh catches conversion logic bugs; saved catches stale/corrupted weights from previous runs.\n6. 
**Diff configs before debugging.** Before investigating any divergence, dump and compare all config values. A 30-second config diff prevents hours of debugging based on wrong assumptions.\n7. **Never modify cached/downloaded model configs directly.** Don't edit files in `~/.cache/huggingface/`. Instead, save to a local directory or open a PR on the upstream repo.\n8. **Compare ALL loop inputs in the encode test.** The preloop checkpoint must capture every single tensor the transformer's `forward()` will receive.\n\n## Comparison utilities\n\n```python\nimport torch\n\n\ndef compare_tensors(name: str, a: torch.Tensor, b: torch.Tensor, tol: float = 1e-3) -> bool:\n    # A shape mismatch is an immediate failure; no numeric comparison is possible.\n    if a.shape != b.shape:\n        print(f\"  FAIL {name}: shape mismatch {a.shape} vs {b.shape}\")\n        return False\n    # Upcast to float32 so both tensors are compared in the same precision.\n    diff = (a.float() - b.float()).abs()\n    max_diff = diff.max().item()\n    mean_diff = diff.mean().item()\n    # Cosine similarity over the flattened tensors, reported as a secondary signal.\n    cos = torch.nn.functional.cosine_similarity(\n        a.float().flatten().unsqueeze(0), b.float().flatten().unsqueeze(0)\n    ).item()\n    # Only max_diff decides pass/fail.\n    passed = max_diff < tol\n    print(f\"  {'PASS' if passed else 'FAIL'} {name}: max={max_diff:.2e}, mean={mean_diff:.2e}, cos={cos:.5f}\")\n    return passed\n```\nCosine similarity is especially useful for GPU/bfloat16 tests where max_diff can be noisy -- `cos > 0.9999` is a strong signal even when max_diff exceeds tolerance.\n\n## Gotchas\n\nSee [pitfalls.md](pitfalls.md) for the full list of gotchas to watch for during parity testing.\n"
  },
  {
    "path": ".ai/skills/parity-testing/checkpoint-mechanism.md",
    "content": "# Checkpoint Mechanism for Stage Testing\n\n## Overview\n\nPipelines are monolithic `__call__` methods -- you can't just call \"the encode part\". The checkpoint mechanism lets you stop, save, or inject tensors at named locations inside the pipeline.\n\n## The Checkpoint class\n\nAdd a `_checkpoints` argument to both the diffusers pipeline and the reference implementation.\n\n```python\n@dataclass\nclass Checkpoint:\n    save: bool = False   # capture variables into ckpt.data\n    stop: bool = False   # halt pipeline after this point\n    load: bool = False   # inject ckpt.data into local variables\n    data: dict = field(default_factory=dict)\n```\n\n## Pipeline instrumentation\n\nThe pipeline accepts an optional `dict[str, Checkpoint]`. Place checkpoint calls at boundaries between pipeline stages -- after each encoder, before the denoising loop (capture all loop inputs), after each loop iteration, after the loop (capture final latents before decode).\n\n```python\ndef __call__(self, prompt, ..., _checkpoints=None):\n    # --- text encoding ---\n    prompt_embeds = self.text_encoder(prompt)\n    _maybe_checkpoint(_checkpoints, \"text_encoding\", {\n        \"prompt_embeds\": prompt_embeds,\n    })\n\n    # --- prepare latents, sigmas, positions ---\n    latents = self.prepare_latents(...)\n    sigmas = self.scheduler.sigmas\n    # ...\n\n    _maybe_checkpoint(_checkpoints, \"preloop\", {\n        \"latents\": latents,\n        \"sigmas\": sigmas,\n        \"prompt_embeds\": prompt_embeds,\n        \"prompt_attention_mask\": prompt_attention_mask,\n        \"video_coords\": video_coords,\n        # capture EVERYTHING the loop needs -- every tensor the transformer\n        # forward() receives. 
Missing even one variable here means you can't\n        # tell if it's the source of divergence during denoise debugging.\n    })\n\n    # --- denoising loop ---\n    for i, t in enumerate(timesteps):\n        noise_pred = self.transformer(latents, t, prompt_embeds, ...)\n        latents = self.scheduler.step(noise_pred, t, latents)[0]\n\n        _maybe_checkpoint(_checkpoints, f\"after_step_{i}\", {\n            \"latents\": latents,\n        })\n\n    _maybe_checkpoint(_checkpoints, \"post_loop\", {\n        \"latents\": latents,\n    })\n\n    # --- decode ---\n    video = self.vae.decode(latents)\n    return video\n```\n\n## The helper function\n\nEach `_maybe_checkpoint` call acts on two of the Checkpoint's flags: `save` captures the local variables into `ckpt.data`, and `stop` halts execution (raises an exception caught at the top level). The `load` flag, which injects pre-populated `ckpt.data` back into local variables, is handled at the call site because a helper cannot rebind the caller's locals (see Injection support below).\n\n```python\nclass PipelineStop(Exception):\n    \"\"\"Raised to halt the pipeline at a checkpoint; catch it in __call__.\"\"\"\n\n\ndef _maybe_checkpoint(checkpoints, name, data):\n    # No checkpoint dict, or no entry for this location: do nothing.\n    if not checkpoints:\n        return\n    ckpt = checkpoints.get(name)\n    if ckpt is None:\n        return\n    if ckpt.save:\n        ckpt.data.update(data)\n    if ckpt.stop:\n        raise PipelineStop  # caught at __call__ level, returns None\n```\n\n## Injection support\n\nAdd `load` support at each checkpoint where you might want to inject:\n\n```python\n_maybe_checkpoint(_checkpoints, \"preloop\", {\"latents\": latents, ...})\n\n# Load support: replace local variables with injected data\nif _checkpoints:\n    ckpt = _checkpoints.get(\"preloop\")\n    if ckpt is not None and ckpt.load:\n        latents = ckpt.data[\"latents\"].to(device=device, dtype=latents.dtype)\n```\n\n## Key insight\n\nThe checkpoint dict is passed into the pipeline and mutated in-place. After the pipeline returns (or stops early), you read back `ckpt.data` to get the captured tensors. Both pipelines save under their own key names, so the test maps between them (e.g. 
reference `\"video_state.latent\"` -> diffusers `\"latents\"`).\n\n## Memory management for large models\n\nFor large models, free the source pipeline's GPU memory before loading the target pipeline. Clone injected tensors to CPU, delete everything else, then run the target with `enable_model_cpu_offload()`.\n"
  },
  {
    "path": ".ai/skills/parity-testing/pitfalls.md",
    "content": "# Complete Pitfalls Reference\n\n## 1. Global CPU RNG\n`MultivariateNormal.sample()` uses the global CPU RNG, not `torch.Generator`. Must call `torch.manual_seed(seed)` before each pipeline run. A `generator=` kwarg won't help.\n\n## 2. Timestep dtype\nMany transformers expect `int64` timesteps. `get_timestep_embedding` casts to float, so `745.3` and `745` produce different embeddings. Match the reference's casting.\n\n## 3. Guidance parameter mapping\nParameter names may differ: reference `zero_steps=1` (meaning `i <= 1`, 2 steps) vs target `zero_init_steps=2` (meaning `step < 2`, same thing). Check exact semantics.\n\n## 4. `patch_size` in noise generation\nIf noise generation depends on `patch_size` (e.g. `sample_block_noise`), it must be passed through. Missing it changes noise spatial structure.\n\n## 5. Variable shadowing in nested loops\nNested loops (stages -> chunks -> timesteps) can shadow variable names. If outer loop uses `latents` and inner loop also assigns to `latents`, scoping must match the reference.\n\n## 6. Float precision differences -- don't dismiss them\nTarget may compute in float32 where reference used bfloat16. Small per-element diffs (1e-3 to 1e-2) *look* harmless but can compound catastrophically over iterative processes like denoising loops (see Pitfalls #11 and #13). Before dismissing a precision difference: (a) check whether it feeds into an iterative process, (b) if so, trace the accumulation curve over all iterations to see if it stays bounded or grows exponentially. Only truly non-iterative precision diffs (e.g. in a single-pass encoder) are safe to accept.\n\n## 7. Scheduler state reset between stages\nSome schedulers accumulate state (e.g. `model_outputs` in UniPC) that must be cleared between stages.\n\n## 8. Component access\nStandard: `self.transformer`. Modular: `components.transformer`. Missing this causes AttributeError.\n\n## 9. 
Guider state across stages\nIn multi-stage denoising, the guider's internal state (e.g. `zero_init_steps`) may need save/restore between stages.\n\n## 10. Model storage location\nNEVER store converted models in `/tmp/` -- temporary directories get wiped on restart. Always save converted checkpoints under a persistent path in the project repo (e.g. `models/ltx23-diffusers/`).\n\n## 11. Noise dtype mismatch (causes washed-out output)\n\nReference code often generates noise in float32 then casts to model dtype (bfloat16) before storing:\n\n```python\nnoise = torch.randn(..., dtype=torch.float32, generator=gen)\nnoise = noise.to(dtype=model_dtype)  # bfloat16 -- values get quantized\n```\n\nDiffusers pipelines may keep latents in float32 throughout the loop. The per-element difference is only ~1.5e-02, but this compounds over 30 denoising steps via 1/sigma amplification (Pitfall #13) and produces completely washed-out output.\n\n**Fix**: Match the reference -- generate noise in the model's working dtype:\n```python\nlatent_dtype = self.transformer.dtype  # e.g. bfloat16\nlatents = self.prepare_latents(..., dtype=latent_dtype, ...)\n```\n\n**Detection**: Encode stage test shows initial latent max_diff of exactly ~1.5e-02. This specific magnitude is the signature of float32->bfloat16 quantization error.\n\n## 12. RoPE position dtype\n\nRoPE cosine/sine values are sensitive to position coordinate dtype. If reference uses bfloat16 positions but diffusers uses float32, the RoPE output diverges significantly (max_diff up to 2.0). Different modalities may use different position dtypes (e.g. video bfloat16, audio float32) -- check the reference carefully.\n\n## 13. 1/sigma error amplification in Euler denoising\n\nIn Euler/flow-matching, the velocity formula divides by sigma: `v = (latents - pred_x0) / sigma`. As sigma shrinks from ~1.0 (step 0) to ~0.001 (step 29), errors are amplified up to 1000x. 
A 1.5e-02 init difference grows linearly through mid-steps, then exponentially in final steps, reaching max_diff ~6.0. This is why dtype mismatches (Pitfalls #11, #12) that seem tiny at init produce visually broken output. Use per-step accumulation tracing to diagnose.\n\n## 14. Config value assumptions -- always diff, never assume\n\nWhen debugging parity, don't assume config values match code defaults. The published model checkpoint may override defaults with different values. A wrong assumption about a single config field can send you down hours of debugging in the wrong direction.\n\n**The pattern that goes wrong:**\n1. You see `param_x` has default `1` in the code\n2. The reference code also uses `param_x` with a default of `1`\n3. You assume both sides use `1` and apply a \"fix\" based on that\n4. But the actual checkpoint config has `param_x: 1000`, and so does the published diffusers config\n5. Your \"fix\" now *creates* divergence instead of fixing it\n\n**Prevention -- config diff first:**\n```python\n# Reference: read from checkpoint metadata (no model loading needed)\nfrom safetensors import safe_open\nimport json\nref_config = json.loads(safe_open(checkpoint_path, framework=\"pt\").metadata()[\"config\"])\n\n# Diffusers: read from model config\nfrom diffusers import MyModel\ndiff_model = MyModel.from_pretrained(model_path, subfolder=\"transformer\")\ndiff_config = dict(diff_model.config)\n\n# Compare all values\nfor key in sorted(set(list(ref_config.get(\"transformer\", {}).keys()) + list(diff_config.keys()))):\n    ref_val = ref_config.get(\"transformer\", {}).get(key, \"MISSING\")\n    diff_val = diff_config.get(key, \"MISSING\")\n    if ref_val != diff_val:\n        print(f\"  DIFF {key}: ref={ref_val}, diff={diff_val}\")\n```\n\nRun this **before** writing any hooks, analysis code, or fixes. 
It takes 30 seconds and catches wrong assumptions immediately.\n\n**When debugging divergence -- trace values, don't reason about them:**\nIf two implementations diverge, hook the actual intermediate values at the point of divergence rather than reading code to figure out what the values \"should\" be. Code analysis builds on assumptions; value tracing reveals facts.\n\n## 15. Decoder config mismatch (causes pixelated artifacts)\n\nThe upstream model config may have wrong values for decoder-specific parameters (e.g. `upsample_residual`, `upsample_type`). These control whether the decoder uses skip connections in upsampling -- getting them wrong produces severe pixelation or blocky artifacts.\n\n**Detection**: Feed identical post-loop latents through both decoders. If max pixel diff is large (PSNR < 40 dB) on CPU/float32, it's a real bug, not precision noise. Trace through decoder blocks (conv_in -> mid_block -> up_blocks) to find where divergence starts.\n\n**Fix**: Correct the config value. Don't edit cached files in `~/.cache/huggingface/` -- either save to a local model directory or open a PR on the upstream repo (see Testing Rule #7).\n\n## 16. Incomplete injection tests -- inject ALL variables or the test is invalid\n\nWhen doing injection tests (feeding reference tensors into the diffusers pipeline), you must inject **every** divergent input, including sigmas/timesteps. A common mistake: the preloop checkpoint saves sigmas but the injection code only loads latents and embeddings. The test then runs with different sigma schedules, making it impossible to isolate the real cause.\n\n**Prevention**: After writing injection code, verify by listing every variable the injected stage consumes and checking each one is either (a) injected from reference, or (b) confirmed identical between pipelines.\n\n## 17. bf16 connector/encoder divergence -- don't chase it\n\nWhen running on GPU/bfloat16, multi-layer encoders (e.g. 
8-layer connector transformers) accumulate bf16 rounding noise that looks alarming (max_diff 0.3-2.7). Before investigating, re-run the component test on CPU/float32. If it passes (max_diff < 1e-4), the divergence is pure precision noise, not a code bug. Don't spend hours tracing through layers -- confirm on CPU/float32 and move on.\n\n## 18. Stale test fixtures\n\nWhen using saved tensors for cross-pipeline comparison, always ensure both sets of tensors were captured from the same run configuration (same seed, same config, same code version). Mixing fixtures from different runs (e.g. reference tensors from yesterday, diffusers tensors from today after a code change) creates phantom divergence that wastes debugging time. Regenerate both sides in a single test script execution.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.yml",
    "content": "name: \"\\U0001F41B Bug Report\"\ndescription: Report a bug on Diffusers\nlabels: [ \"bug\" ]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Thanks a lot for taking the time to file this issue 🤗.\n        Issues do not only help to improve the library, but also publicly document common problems, questions, workflows for the whole community!\n        Thus, issues are of the same importance as pull requests when contributing to this library ❤️.\n        In order to make your issue as **useful for the community as possible**, let's try to stick to some simple guidelines:\n        - 1. Please try to be as precise and concise as possible.\n             *Give your issue a fitting title. Assume that someone which very limited knowledge of Diffusers can understand your issue. Add links to the source code, documentation other issues, pull requests etc...*\n        - 2. If your issue is about something not working, **always** provide a reproducible code snippet. The reader should be able to reproduce your issue by **only copy-pasting your code snippet into a Python shell**.\n             *The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*\n        - 3. Add the **minimum** amount of code / context that is needed to understand, reproduce your issue.\n             *Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*\n        - 4. 
For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.\n  - type: markdown\n    attributes:\n      value: |\n        For more in-detail information on how to write good issues you can have a look [here](https://huggingface.co/course/chapter8/5?fw=pt).\n  - type: textarea\n    id: bug-description\n    attributes:\n      label: Describe the bug\n      description: A clear and concise description of what the bug is. If you intend to submit a pull request for this issue, tell us in the description. Thanks!\n      placeholder: Bug description\n    validations:\n      required: true\n  - type: textarea\n    id: reproduction\n    attributes:\n      label: Reproduction\n      description: Please provide a minimal reproducible code which we can copy/paste and reproduce the issue.\n      placeholder: Reproduction\n    validations:\n      required: true\n  - type: textarea\n    id: logs\n    attributes:\n      label: Logs\n      description: \"Please include the Python logs if you can.\"\n      render: shell\n  - type: textarea\n    id: system-info\n    attributes:\n      label: System Info\n      description: Please share your system info with us. 
You can run the command `diffusers-cli env` and copy-paste its output below.\n      placeholder: Diffusers version, platform, Python version, ...\n    validations:\n      required: true\n  - type: textarea\n    id: who-can-help\n    attributes:\n      label: Who can help?\n      description: |\n        Your issue will be replied to more quickly if you can figure out the right person to tag with @.\n        If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.\n\n        All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and\n        a core maintainer will ping the right person.\n\n        Please tag a maximum of 2 people.\n\n        Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...): @sayakpaul @DN6\n\n        Questions on pipelines:\n        - Stable Diffusion @yiyixuxu @asomoza\n        - Stable Diffusion XL @yiyixuxu @sayakpaul @DN6\n        - Stable Diffusion 3: @yiyixuxu @sayakpaul @DN6 @asomoza\n        - Kandinsky @yiyixuxu\n        - ControlNet @sayakpaul @yiyixuxu @DN6\n        - T2I Adapter @sayakpaul @yiyixuxu @DN6\n        - IF @DN6\n        - Text-to-Video / Video-to-Video @DN6 @a-r-r-o-w\n        - Wuerstchen @DN6\n        - Other: @yiyixuxu @DN6\n        - Improving generation quality: @asomoza\n\n        Questions on models:\n        - UNet @DN6 @yiyixuxu @sayakpaul\n        - VAE @sayakpaul @DN6 @yiyixuxu\n        - Transformers/Attention @DN6 @yiyixuxu @sayakpaul\n\n        Questions on single file checkpoints: @DN6\n\n        Questions on Schedulers: @yiyixuxu\n\n        Questions on LoRA: @sayakpaul\n\n        Questions on Textual Inversion: @sayakpaul\n\n        Questions on Training:\n        - DreamBooth @sayakpaul\n        - Text-to-Image Fine-tuning @sayakpaul\n        - Textual Inversion @sayakpaul\n        - ControlNet @sayakpaul\n\n        Questions on Tests: @DN6 @sayakpaul @yiyixuxu\n\n        Questions 
on Documentation: @stevhliu\n\n        Questions on JAX- and MPS-related things: @pcuenca\n\n        Questions on audio pipelines: @sanchit-gandhi\n\n\n\n      placeholder: \"@Username ...\"\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "contact_links:\n  - name: Questions / Discussions\n    url: https://github.com/huggingface/diffusers/discussions\n    about: General usage questions and community discussions\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: \"\\U0001F680 Feature Request\"\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...].\n\n**Describe the solution you'd like.**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered.**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context.**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feedback.md",
    "content": "---\nname: \"💬 Feedback about API Design\"\nabout: Give feedback about the current API design\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**What API design would you like to have changed or added to the library? Why?**\n\n**What use case would this enable or better enable? Can you give us a code example?**\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/new-model-addition.yml",
    "content": "name: \"\\U0001F31F New Model/Pipeline/Scheduler Addition\"\ndescription: Submit a proposal/request to implement a new diffusion model/pipeline/scheduler\nlabels: [ \"New model/pipeline/scheduler\" ]\n\nbody:\n  - type: textarea\n    id: description-request\n    validations:\n      required: true\n    attributes:\n      label: Model/Pipeline/Scheduler description\n      description: |\n        Put any and all important information relative to the model/pipeline/scheduler\n\n  - type: checkboxes\n    id: information-tasks\n    attributes:\n      label: Open source status\n      description: |\n          Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `diffusers`.\n      options:\n        - label: \"The model implementation is available.\"\n        - label: \"The model weights are available (Only relevant if addition is not a scheduler).\"\n\n  - type: textarea\n    id: additional-info\n    attributes:\n      label: Provide useful links for the implementation\n      description: |\n        Please provide information regarding the implementation, the weights, and the authors.\n        Please mention the authors by @gh-username if you're aware of their usernames.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/remote-vae-pilot-feedback.yml",
    "content": "name: \"\\U0001F31F Remote VAE\"\ndescription: Feedback for remote VAE pilot\nlabels: [ \"Remote VAE\" ]\n\nbody:\n  - type: textarea\n    id: positive\n    validations:\n      required: true\n    attributes:\n      label: Did you like the remote VAE solution?\n      description: |\n        If you liked it, we would appreciate it if you could elaborate what you liked.\n\n  - type: textarea\n    id: feedback\n    validations:\n      required: true\n    attributes:\n      label: What can be improved about the current solution?\n      description: |\n        Let us know the things you would like to see improved. Note that we will work optimizing the solution once the pilot is over and we have usage.\n\n  - type: textarea\n    id: others\n    validations:\n      required: true\n    attributes:\n      label: What other VAEs you would like to see if the pilot goes well?\n      description: |\n        Provide a list of the VAEs you would like to see in the future if the pilot goes well.\n\n  - type: textarea\n    id: additional-info\n    attributes:\n      label: Notify the members of the team\n      description: |\n        Tag the following folks when submitting this feedback: @hlky @sayakpaul\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/translate.md",
    "content": "---\nname: 🌐 Translating a New Language?\nabout: Start a new translation effort in your language\ntitle: '[<languageCode>] Translating docs to <languageName>'\nlabels: WIP\nassignees: ''\n\n---\n\n<!--\nNote: Please search to see if an issue already exists for the language you are trying to translate.\n-->\n\nHi!\n\nLet's bring the documentation to all the <languageName>-speaking community 🌐.\n\nWho would want to translate? Please follow the 🤗 [TRANSLATING guide](https://github.com/huggingface/diffusers/blob/main/docs/TRANSLATING.md). Here is a list of the files ready for translation. Let us know in this issue if you'd like to translate any, and we'll add your name to the list.\n\nSome notes:\n\n* Please translate using an informal tone (imagine you are talking with a friend about Diffusers 🤗).\n* Please translate in a gender-neutral way.\n* Add your translations to the folder called `<languageCode>` inside the [source folder](https://github.com/huggingface/diffusers/tree/main/docs/source).\n* Register your translation in `<languageCode>/_toctree.yml`; please follow the order of the [English version](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml).\n* Once you're finished, open a pull request and tag this issue by including #issue-number in the description, where issue-number is the number of this issue. Please ping @stevhliu for review.\n* 🙋 If you'd like others to help you with the translation, you can also post in the 🤗 [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63).\n\nThank you so much for your help! 🤗\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "# What does this PR do?\n\n<!--\nCongratulations! You've made it this far! You're not quite done yet though.\n\nOnce merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.\n\nThen, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.\n\nOnce you're done, someone will review your PR shortly (see the section \"Who can review?\" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost.\n-->\n\n<!-- Remove if not applicable -->\n\nFixes # (issue)\n\n\n## Before submitting\n- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).\n- [ ] Did you read the [contributor guideline](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md)?\n- [ ] Did you read our [philosophy doc](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) (important for complex PRs)?\n- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)? Please add a link to it if that's the case.\n- [ ] Did you make sure to update the documentation with your changes? Here are the\n      [documentation guidelines](https://github.com/huggingface/diffusers/tree/main/docs), and\n      [here are tips on formatting docstrings](https://github.com/huggingface/diffusers/tree/main/docs#writing-source-documentation).\n- [ ] Did you write any new necessary tests?\n\n\n## Who can review?\n\nAnyone in the community is free to review the PR once the tests have passed. 
Feel free to tag\nmembers/contributors who may be interested in your PR.\n\n<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @.\n\n If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.\n Please tag fewer than 3 people.\n\nCore library:\n\n- Schedulers: @yiyixuxu\n- Pipelines and pipeline callbacks: @yiyixuxu and @asomoza\n- Training examples: @sayakpaul\n- Docs: @stevhliu and @sayakpaul\n- JAX and MPS: @pcuenca\n- Audio: @sanchit-gandhi\n- General functionalities: @sayakpaul @yiyixuxu @DN6\n\nIntegrations:\n\n- deepspeed: HF Trainer/Accelerate: @SunMarc\n- PEFT: @sayakpaul @BenjaminBossan\n\nHF projects:\n\n- accelerate: [different repo](https://github.com/huggingface/accelerate)\n- datasets: [different repo](https://github.com/huggingface/datasets)\n- transformers: [different repo](https://github.com/huggingface/transformers)\n- safetensors: [different repo](https://github.com/huggingface/safetensors)\n\n-->\n"
  },
  {
    "path": ".github/actions/setup-miniconda/action.yml",
    "content": "name: Set up conda environment for testing\n\ndescription: Sets up miniconda in your ${RUNNER_TEMP} environment and gives you the ${CONDA_RUN} environment variable so you don't have to worry about polluting non-empeheral runners anymore\n\ninputs:\n  python-version:\n    description: If set to any value, don't use sudo to clean the workspace\n    required: false\n    type: string\n    default: \"3.9\"\n  miniconda-version:\n    description: Miniconda version to install\n    required: false\n    type: string\n    default: \"4.12.0\"\n  environment-file:\n    description: Environment file to install dependencies from\n    required: false\n    type: string\n    default: \"\"\n\nruns:\n  using: composite\n  steps:\n      # Use the same trick from https://github.com/marketplace/actions/setup-miniconda\n      # to refresh the cache daily. This is kind of optional though\n      - name: Get date\n        id: get-date\n        shell: bash\n        run: echo \"today=$(/bin/date -u '+%Y%m%d')d\" >> $GITHUB_OUTPUT\n      - name: Setup miniconda cache\n        id: miniconda-cache\n        uses: actions/cache@v2\n        with:\n          path: ${{ runner.temp }}/miniconda\n          key: miniconda-${{ runner.os }}-${{ runner.arch }}-${{ inputs.python-version }}-${{ steps.get-date.outputs.today }}\n      - name: Install miniconda (${{ inputs.miniconda-version }})\n        if: steps.miniconda-cache.outputs.cache-hit != 'true'\n        env:\n          MINICONDA_VERSION: ${{ inputs.miniconda-version }}\n        shell: bash -l {0}\n        run: |\n          MINICONDA_INSTALL_PATH=\"${RUNNER_TEMP}/miniconda\"\n          mkdir -p \"${MINICONDA_INSTALL_PATH}\"\n          case ${RUNNER_OS}-${RUNNER_ARCH} in\n            Linux-X64)\n              MINICONDA_ARCH=\"Linux-x86_64\"\n              ;;\n            macOS-ARM64)\n              MINICONDA_ARCH=\"MacOSX-arm64\"\n              ;;\n            macOS-X64)\n              MINICONDA_ARCH=\"MacOSX-x86_64\"\n              
;;\n            *)\n            echo \"::error::Platform ${RUNNER_OS}-${RUNNER_ARCH} currently unsupported using this action\"\n              exit 1\n              ;;\n          esac\n          MINICONDA_URL=\"https://repo.anaconda.com/miniconda/Miniconda3-py39_${MINICONDA_VERSION}-${MINICONDA_ARCH}.sh\"\n          curl -fsSL \"${MINICONDA_URL}\" -o \"${MINICONDA_INSTALL_PATH}/miniconda.sh\"\n          bash \"${MINICONDA_INSTALL_PATH}/miniconda.sh\" -b -u -p \"${MINICONDA_INSTALL_PATH}\"\n          rm -rf \"${MINICONDA_INSTALL_PATH}/miniconda.sh\"\n      - name: Update GitHub path to include miniconda install\n        shell: bash\n        run: |\n          MINICONDA_INSTALL_PATH=\"${RUNNER_TEMP}/miniconda\"\n          echo \"${MINICONDA_INSTALL_PATH}/bin\" >> $GITHUB_PATH\n      - name: Setup miniconda env cache (with env file)\n        id: miniconda-env-cache-env-file\n        if: ${{ runner.os }} == 'macOS' && ${{ inputs.environment-file }} != ''\n        uses: actions/cache@v2\n        with:\n          path: ${{ runner.temp }}/conda-python-${{ inputs.python-version }}\n          key: miniconda-env-${{ runner.os }}-${{ runner.arch }}-${{ inputs.python-version }}-${{ steps.get-date.outputs.today }}-${{ hashFiles(inputs.environment-file) }}\n      - name: Setup miniconda env cache (without env file)\n        id: miniconda-env-cache\n        if: ${{ runner.os }} == 'macOS' && ${{ inputs.environment-file }} == ''\n        uses: actions/cache@v2\n        with:\n          path: ${{ runner.temp }}/conda-python-${{ inputs.python-version }}\n          key: miniconda-env-${{ runner.os }}-${{ runner.arch }}-${{ inputs.python-version }}-${{ steps.get-date.outputs.today }}\n      - name: Setup conda environment with python (v${{ inputs.python-version }})\n        if: steps.miniconda-env-cache-env-file.outputs.cache-hit != 'true' && steps.miniconda-env-cache.outputs.cache-hit != 'true'\n        shell: bash\n        env:\n          PYTHON_VERSION: ${{ inputs.python-version }}\n 
         ENV_FILE: ${{ inputs.environment-file }}\n        run: |\n          CONDA_BASE_ENV=\"${RUNNER_TEMP}/conda-python-${PYTHON_VERSION}\"\n          ENV_FILE_FLAG=\"\"\n          if [[ -f \"${ENV_FILE}\" ]]; then\n            ENV_FILE_FLAG=\"--file ${ENV_FILE}\"\n          elif [[ -n \"${ENV_FILE}\" ]]; then\n            echo \"::warning::Specified env file (${ENV_FILE}) not found, not going to include it\"\n          fi\n          conda create \\\n            --yes \\\n            --prefix \"${CONDA_BASE_ENV}\" \\\n            \"python=${PYTHON_VERSION}\" \\\n            ${ENV_FILE_FLAG} \\\n            cmake=3.22 \\\n            conda-build=3.21 \\\n            ninja=1.10 \\\n            pkg-config=0.29 \\\n            wheel=0.37\n      - name: Clone the base conda environment and update GitHub env\n        shell: bash\n        env:\n          PYTHON_VERSION: ${{ inputs.python-version }}\n          CONDA_BASE_ENV: ${{ runner.temp }}/conda-python-${{ inputs.python-version }}\n        run: |\n          CONDA_ENV=\"${RUNNER_TEMP}/conda_environment_${GITHUB_RUN_ID}\"\n          conda create \\\n            --yes \\\n            --prefix \"${CONDA_ENV}\" \\\n            --clone \"${CONDA_BASE_ENV}\"\n          # TODO: conda-build could not be cloned because it hardcodes the path, so it\n          # could not be cached\n          conda install --yes -p ${CONDA_ENV} conda-build=3.21\n          echo \"CONDA_ENV=${CONDA_ENV}\" >> \"${GITHUB_ENV}\"\n          echo \"CONDA_RUN=conda run -p ${CONDA_ENV} --no-capture-output\" >> \"${GITHUB_ENV}\"\n          echo \"CONDA_BUILD=conda run -p ${CONDA_ENV} conda-build\" >> \"${GITHUB_ENV}\"\n          echo \"CONDA_INSTALL=conda install -p ${CONDA_ENV}\" >> \"${GITHUB_ENV}\"\n      - name: Get disk space usage and throw an error for low disk space\n        shell: bash\n        run: |\n          echo \"Print the available disk space for manual inspection\"\n          df -h\n          # Set the minimum requirement space to 4GB\n  
        MINIMUM_AVAILABLE_SPACE_IN_GB=4\n          MINIMUM_AVAILABLE_SPACE_IN_KB=$(($MINIMUM_AVAILABLE_SPACE_IN_GB * 1024 * 1024))\n          # Use KB to avoid floating point warning like 3.1GB\n          df -k | tr -s ' ' | cut -d' ' -f 4,9 | while read -r LINE;\n          do\n            AVAIL=$(echo $LINE | cut -f1 -d' ')\n            MOUNT=$(echo $LINE | cut -f2 -d' ')\n            if [ \"$MOUNT\" = \"/\" ]; then\n              if [ \"$AVAIL\" -lt \"$MINIMUM_AVAILABLE_SPACE_IN_KB\" ]; then\n                echo \"There is only ${AVAIL}KB free space left in $MOUNT, which is less than the minimum requirement of ${MINIMUM_AVAILABLE_SPACE_IN_KB}KB. Please help create an issue to PyTorch Release Engineering via https://github.com/pytorch/test-infra/issues and provide the link to the workflow run.\"\n                exit 1;\n              else\n                echo \"There is ${AVAIL}KB free space left in $MOUNT, continue\"\n              fi\n            fi\n          done\n"
  },
  {
    "path": ".github/workflows/benchmark.yml",
    "content": "name: Benchmarking tests\n\non:\n  workflow_dispatch:\n  schedule:\n    - cron: \"30 1 1,15 * *\" # every 2 weeks on the 1st and the 15th of every month at 1:30 AM\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  HF_XET_HIGH_PERFORMANCE: 1\n  HF_HOME: /mnt/cache\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  BASE_PATH: benchmark_outputs\n\njobs:\n  torch_models_cuda_benchmark_tests:\n    env:\n      SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}\n    name: Torch Core Models CUDA Benchmarking Tests\n    strategy:\n      fail-fast: false\n      max-parallel: 1\n    runs-on:\n      group: aws-g6e-4xlarge\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: NVIDIA-SMI\n        run: |\n          nvidia-smi\n      - name: Install dependencies\n        run: |\n          apt update\n          apt install -y libpq-dev postgresql-client\n          uv pip install -e \".[quality]\"\n          uv pip install -r benchmarks/requirements.txt\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Diffusers Benchmarking\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        run: |\n          cd benchmarks && python run_all.py\n\n      - name: Push results to the Hub\n        env: \n          HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}\n        run: |\n          cd benchmarks && python push_results.py\n          mkdir $BASE_PATH && cp *.csv $BASE_PATH\n\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: benchmark_test_reports\n          path: benchmarks/${{ env.BASE_PATH }}\n\n      - name: Report success status\n        if: ${{ success() }}\n        run: |\n          pip install requests && python 
utils/notify_benchmarking_status.py --status=success\n\n      - name: Report failure status\n        if: ${{ failure() }}\n        run: |\n          pip install requests && python utils/notify_benchmarking_status.py --status=failure"
  },
  {
    "path": ".github/workflows/build_docker_images.yml",
    "content": "name: Test, build, and push Docker images\n\non:\n  pull_request: # During PRs, we just check if the changes Dockerfiles can be successfully built\n    branches:\n      - main\n    paths:\n      - \"docker/**\"\n  workflow_dispatch:\n  schedule:\n    - cron: \"0 0 * * *\" # every day at midnight\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\nenv:\n  REGISTRY: diffusers\n  CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}\n\njobs:\n  test-build-docker-images:\n    runs-on:\n      group: aws-general-8-plus\n    if: github.event_name == 'pull_request'\n    steps:\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Check out code\n        uses: actions/checkout@v6\n\n      - name: Find Changed Dockerfiles\n        id: file_changes\n        uses: jitterbit/get-changed-files@v1\n        with:\n          format: \"space-delimited\"\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Build Changed Docker Images\n        env: \n          CHANGED_FILES: ${{ steps.file_changes.outputs.all }}\n        run: |\n          echo \"$CHANGED_FILES\"\n          ALLOWED_IMAGES=(\n            diffusers-pytorch-cpu\n            diffusers-pytorch-cuda\n            diffusers-pytorch-xformers-cuda\n            diffusers-pytorch-minimum-cuda\n            diffusers-doc-builder\n          )\n\n          declare -A IMAGES_TO_BUILD=()\n\n          for FILE in $CHANGED_FILES; do\n            # skip anything that isn't still on disk\n            if [[ ! 
-e \"$FILE\" ]]; then\n              echo \"Skipping removed file $FILE\"\n              continue\n            fi\n\n            for IMAGE in \"${ALLOWED_IMAGES[@]}\"; do\n              if [[ \"$FILE\" == docker/${IMAGE}/* ]]; then\n                IMAGES_TO_BUILD[\"$IMAGE\"]=1\n              fi\n            done\n          done\n\n          if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then\n            echo \"No relevant Docker changes detected.\"\n            exit 0\n          fi\n\n          for IMAGE in \"${!IMAGES_TO_BUILD[@]}\"; do\n            DOCKER_PATH=\"docker/${IMAGE}\"\n            echo \"Building Docker image for $IMAGE\"\n            docker build -t \"$IMAGE\" \"$DOCKER_PATH\"\n          done\n        if: steps.file_changes.outputs.all != ''\n\n  build-and-push-docker-images:\n    runs-on:\n      group: aws-general-8-plus\n    if: github.event_name != 'pull_request'\n\n    permissions:\n      contents: read\n      packages: write\n\n    strategy:\n      fail-fast: false\n      matrix:\n        image-name:\n          - diffusers-pytorch-cpu\n          - diffusers-pytorch-cuda\n          - diffusers-pytorch-xformers-cuda\n          - diffusers-pytorch-minimum-cuda\n          - diffusers-doc-builder\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v6\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ env.REGISTRY }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n      - name: Build and push\n        uses: docker/build-push-action@v6\n        with:\n          no-cache: true\n          context: ./docker/${{ matrix.image-name }}\n          push: true\n          tags: ${{ env.REGISTRY }}/${{ matrix.image-name }}:latest\n\n      - name: Post to a Slack channel\n        id: slack\n        uses: huggingface/hf-workflows/.github/actions/post-slack@main\n        with:\n          # Slack 
channel id, channel name, or user id to post message.\n          # See also: https://api.slack.com/methods/chat.postMessage#channels\n          slack_channel: ${{ env.CI_SLACK_CHANNEL }}\n          title: \"🤗 Results of the ${{ matrix.image-name }} Docker Image build\"\n          status: ${{ job.status }}\n          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/build_documentation.yml",
    "content": "name: Build documentation\n\non:\n  push:\n    branches:\n      - main\n      - doc-builder*\n      - v*-release\n      - v*-patch\n    paths:\n      - \"src/diffusers/**.py\"\n      - \"examples/**\"\n      - \"docs/**\"\n\njobs:\n  build:\n    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main\n    with:\n      commit_sha: ${{ github.sha }}\n      install_libgl1: true\n      package: diffusers\n      notebook_folder: diffusers_doc\n      languages: en ko zh ja pt\n      custom_container: diffusers/diffusers-doc-builder\n    secrets:\n      token: ${{ secrets.HUGGINGFACE_PUSH }}\n      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}\n"
  },
  {
    "path": ".github/workflows/build_pr_documentation.yml",
    "content": "name: Build PR Documentation\n\non:\n  pull_request:\n    paths:\n      - \"src/diffusers/**.py\"\n      - \"examples/**\"\n      - \"docs/**\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  check-links:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v6\n\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: '3.10'\n\n      - name: Install uv\n        run: |\n          curl -LsSf https://astral.sh/uv/install.sh | sh\n          echo \"$HOME/.cargo/bin\" >> $GITHUB_PATH\n\n      - name: Install doc-builder\n        run: |\n          uv pip install --system git+https://github.com/huggingface/doc-builder.git@main\n\n      - name: Check documentation links\n        run: |\n          uv run doc-builder check-links docs/source/en\n\n  build:\n    needs: check-links\n    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main\n    with:\n      commit_sha: ${{ github.event.pull_request.head.sha }}\n      pr_number: ${{ github.event.number }}\n      install_libgl1: true\n      package: diffusers\n      languages: en ko zh ja pt\n      custom_container: diffusers/diffusers-doc-builder\n"
  },
  {
    "path": ".github/workflows/codeql.yml",
    "content": "---\nname: CodeQL Security Analysis For GitHub Actions\n\non:\n  push:\n    branches: [\"main\"]\n  workflow_dispatch:\n  # pull_request:\n\njobs:\n  codeql:\n    name: CodeQL Analysis\n    uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1\n    permissions:\n      security-events: write\n      packages: read\n      actions: read\n      contents: read\n    with:\n      languages: '[\"actions\",\"python\"]'\n      queries: 'security-extended,security-and-quality'\n      runner: 'ubuntu-latest' # optional; set only if a custom runner is needed\n"
  },
  {
    "path": ".github/workflows/mirror_community_pipeline.yml",
    "content": "name: Mirror Community Pipeline\n\non:\n  # Push changes on the main branch\n  push:\n    branches:\n      - main\n    paths:\n      - 'examples/community/**.py'\n\n    # And on tag creation (e.g. `v0.28.1`)\n    tags:\n      - '*'\n\n  # Manual trigger with ref input\n  workflow_dispatch:\n    inputs:\n      ref:\n        description: \"Either 'main' or a tag ref\"\n        required: true\n        default: 'main'\n\njobs:\n  mirror_community_pipeline:\n    env:\n      SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}\n    runs-on: ubuntu-22.04\n    steps:\n      # Checkout to correct ref\n      #   If workflow dispatch\n      #     If ref is 'main', set:\n      #       CHECKOUT_REF=refs/heads/main\n      #       PATH_IN_REPO=main\n      #     Else it must be a tag. Set:\n      #       CHECKOUT_REF=refs/tags/{tag}\n      #       PATH_IN_REPO={tag}\n      #   If not workflow dispatch\n      #     If ref is 'refs/heads/main' => set 'main'\n      #     Else it must be a tag => set {tag}\n      - name: Set checkout_ref and path_in_repo\n        env:\n          EVENT_NAME: ${{ github.event_name }}\n          EVENT_INPUT_REF: ${{ github.event.inputs.ref }}\n          GITHUB_REF: ${{ github.ref }}\n        run: |\n          if [ \"$EVENT_NAME\" == \"workflow_dispatch\" ]; then\n            if [ -z \"$EVENT_INPUT_REF\" ]; then\n              echo \"Error: Missing ref input\"\n              exit 1\n            elif [ \"$EVENT_INPUT_REF\" == \"main\" ]; then\n              echo \"CHECKOUT_REF=refs/heads/main\" >> $GITHUB_ENV\n              echo \"PATH_IN_REPO=main\" >> $GITHUB_ENV\n            else\n              echo \"CHECKOUT_REF=refs/tags/$EVENT_INPUT_REF\" >> $GITHUB_ENV\n              echo \"PATH_IN_REPO=$EVENT_INPUT_REF\" >> $GITHUB_ENV\n            fi\n          elif [ \"$GITHUB_REF\" == \"refs/heads/main\" ]; then\n            echo \"CHECKOUT_REF=$GITHUB_REF\" >> $GITHUB_ENV\n            echo \"PATH_IN_REPO=main\" >> $GITHUB_ENV\n  
        else\n            # e.g. refs/tags/v0.28.1 -> v0.28.1\n            echo \"CHECKOUT_REF=$GITHUB_REF\" >> $GITHUB_ENV\n            echo \"PATH_IN_REPO=$(echo $GITHUB_REF | sed 's/^refs\\/tags\\///')\" >> $GITHUB_ENV\n          fi\n      - name: Print env vars\n        run: |\n          echo \"CHECKOUT_REF: ${{ env.CHECKOUT_REF }}\"\n          echo \"PATH_IN_REPO: ${{ env.PATH_IN_REPO }}\"\n      - uses: actions/checkout@v6\n        with:\n          ref: ${{ env.CHECKOUT_REF }}\n\n      # Setup + install dependencies\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install --upgrade huggingface_hub\n\n      # Check secret is set\n      - name: whoami\n        run: hf auth whoami\n        env:\n            HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}\n\n      # Push to HF! (under subfolder based on checkout ref)\n      # https://huggingface.co/datasets/diffusers/community-pipelines-mirror\n      - name: Mirror community pipeline to HF\n        run: hf upload diffusers/community-pipelines-mirror ./examples/community ${PATH_IN_REPO} --repo-type dataset\n        env:\n            PATH_IN_REPO: ${{ env.PATH_IN_REPO }}\n            HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}\n\n      - name: Report success status\n        if: ${{ success() }}\n        run: |\n          pip install requests && python utils/notify_community_pipelines_mirror.py --status=success\n\n      - name: Report failure status\n        if: ${{ failure() }}\n        run: |\n          pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure\n"
  },
  {
    "path": ".github/workflows/nightly_tests.yml",
    "content": "name: Nightly and release tests on main/release branch\n\non:\n  workflow_dispatch:\n  schedule:\n    - cron: \"0 0 * * *\" # every day at midnight\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  HF_XET_HIGH_PERFORMANCE: 1\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  PYTEST_TIMEOUT: 600\n  RUN_SLOW: yes\n  RUN_NIGHTLY: yes\n  PIPELINE_USAGE_CUTOFF: 0\n  SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}\n  CONSOLIDATED_REPORT_PATH: consolidated_test_report.md\n\njobs:\n  setup_torch_cuda_pipeline_matrix:\n    name: Setup Torch Pipelines CUDA Slow Tests Matrix\n    runs-on:\n      group: aws-general-8-plus\n    container:\n      image: diffusers/diffusers-pytorch-cpu\n    outputs:\n      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: Install dependencies\n        run: |\n          pip install -e .[test]\n          pip install huggingface_hub\n      - name: Fetch Pipeline Matrix\n        id: fetch_pipeline_matrix\n        run: |\n          matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)\n          echo $matrix\n          echo \"pipeline_test_matrix=$matrix\" >> $GITHUB_OUTPUT\n\n      - name: Pipeline Tests Artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: test-pipelines.json\n          path: reports\n\n  run_nightly_tests_for_torch_pipelines:\n    name: Nightly Torch Pipelines CUDA Tests\n    needs: setup_torch_cuda_pipeline_matrix\n    strategy:\n      fail-fast: false\n      max-parallel: 8\n      matrix:\n        module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    steps:\n      - name: 
Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: NVIDIA-SMI\n        run: nvidia-smi\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n          #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n          uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 \n          uv pip install pytest-reportlog\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Pipeline CUDA Test\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n        run: |\n          pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n             -k \"not Flax and not Onnx\" \\\n            --make-reports=tests_pipeline_${{ matrix.module }}_cuda \\\n            --report-log=tests_pipeline_${{ matrix.module }}_cuda.log \\\n            tests/pipelines/${{ matrix.module }}\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt\n          cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: pipeline_${{ matrix.module }}_test_reports\n          path: reports\n\n  run_nightly_tests_for_other_torch_modules:\n    name: Nightly Torch CUDA Tests\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: 
--shm-size \"16gb\" --ipc host --gpus all\n    defaults:\n      run:\n        shell: bash\n    strategy:\n      fail-fast: false\n      max-parallel: 2\n      matrix:\n        module: [models, schedulers, lora, others, single_file, examples]\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n        uv pip install peft@git+https://github.com/huggingface/peft.git\n        uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n        #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n        uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 \n        uv pip install pytest-reportlog\n    - name: Environment\n      run: python utils/print_env.py\n\n    - name: Run nightly PyTorch CUDA tests for non-pipeline modules\n      if: ${{ matrix.module != 'examples'}}\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n        CUBLAS_WORKSPACE_CONFIG: :16:8\n      run: |\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n          -k \"not Flax and not Onnx\" \\\n          --make-reports=tests_torch_${{ matrix.module }}_cuda \\\n          --report-log=tests_torch_${{ matrix.module }}_cuda.log \\\n          tests/${{ matrix.module }}\n\n    - name: Run nightly example tests with Torch\n      if: ${{ matrix.module == 'examples' }}\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n        CUBLAS_WORKSPACE_CONFIG: :16:8\n      run: |\n        pytest -n 1 --max-worker-restart=0 
--dist=loadfile \\\n          --make-reports=examples_torch_cuda \\\n          --report-log=examples_torch_cuda.log \\\n          examples/\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: |\n        cat reports/tests_torch_${{ matrix.module }}_cuda_stats.txt\n        cat reports/tests_torch_${{ matrix.module }}_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_${{ matrix.module }}_cuda_test_reports\n        path: reports\n\n  run_torch_compile_tests:\n    name: PyTorch Compile CUDA tests\n\n    runs-on:\n      group: aws-g4dn-2xlarge\n\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --gpus all --shm-size \"16gb\" --ipc host\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: NVIDIA-SMI\n      run: |\n        nvidia-smi\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality,training]\"\n        #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n        uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 \n    - name: Environment\n      run: |\n        python utils/print_env.py\n    - name: Run torch compile tests on GPU\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        RUN_COMPILE: yes\n      run: |\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile -k \"compile\" --make-reports=tests_torch_compile_cuda tests/\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_torch_compile_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_compile_test_reports\n      
  path: reports\n\n  run_big_gpu_torch_tests:\n    name: Torch tests on big GPU\n    strategy:\n      fail-fast: false\n      max-parallel: 2\n    runs-on:\n      group: aws-g6e-xlarge-plus\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: NVIDIA-SMI\n        run: nvidia-smi\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip install peft@git+https://github.com/huggingface/peft.git\n          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n          #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n          uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 \n          uv pip install pytest-reportlog\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Selected Torch CUDA Test on big GPU\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n          BIG_GPU_MEMORY: 40\n        run: |\n          pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n            -m \"big_accelerator\" \\\n            --make-reports=tests_big_gpu_torch_cuda \\\n            --report-log=tests_big_gpu_torch_cuda.log \\\n            tests/\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_big_gpu_torch_cuda_stats.txt\n          cat reports/tests_big_gpu_torch_cuda_failures_short.txt\n      - name: Test suite reports artifacts\n        if: ${{ always() 
}}\n        uses: actions/upload-artifact@v6\n        with:\n          name: torch_cuda_big_gpu_test_reports\n          path: reports\n\n  torch_minimum_version_cuda_tests:\n    name: Torch Minimum Version CUDA Tests\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: diffusers/diffusers-pytorch-minimum-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip install peft@git+https://github.com/huggingface/peft.git\n          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n          #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n          uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 \n\n      - name: Environment\n        run: |\n          python utils/print_env.py\n\n      - name: Run PyTorch CUDA tests\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n        run: |\n          pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n            -k \"not Flax and not Onnx\" \\\n            --make-reports=tests_torch_minimum_version_cuda \\\n            tests/models/test_modeling_common.py \\\n            tests/pipelines/test_pipelines_common.py \\\n            tests/pipelines/test_pipeline_utils.py \\\n            tests/pipelines/test_pipelines.py \\\n            tests/pipelines/test_pipelines_auto.py \\\n            tests/schedulers/test_schedulers.py \\\n            
tests/others\n\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_torch_minimum_version_cuda_stats.txt\n          cat reports/tests_torch_minimum_version_cuda_failures_short.txt\n\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: torch_minimum_version_cuda_test_reports\n          path: reports\n\n  run_nightly_quantization_tests:\n    name: Torch quantization nightly tests\n    strategy:\n      fail-fast: false\n      max-parallel: 2\n      matrix:\n        config:\n          - backend: \"bitsandbytes\"\n            test_location: \"bnb\"\n            additional_deps: [\"peft\"]\n          - backend: \"gguf\"\n            test_location: \"gguf\"\n            additional_deps: [\"peft\", \"kernels\"]\n          - backend: \"torchao\"\n            test_location: \"torchao\"\n            additional_deps: []\n          - backend: \"optimum_quanto\"\n            test_location: \"quanto\"\n            additional_deps: []\n          - backend: \"nvidia_modelopt\"\n            test_location: \"modelopt\"\n            additional_deps: []\n    runs-on:\n      group: aws-g6e-xlarge-plus\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"20gb\" --ipc host --gpus all\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: NVIDIA-SMI\n        run: nvidia-smi\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip install -U ${{ matrix.config.backend }}\n          if [ \"${{ join(matrix.config.additional_deps, ' ') }}\" != \"\" ]; then\n              uv pip install ${{ join(matrix.config.additional_deps, ' ') }}\n          fi\n          uv pip install pytest-reportlog\n          #uv pip uninstall transformers huggingface_hub && uv pip install 
--prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n          uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 \n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: ${{ matrix.config.backend }} quantization tests on GPU\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n          BIG_GPU_MEMORY: 40\n        run: |\n          pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n            --make-reports=tests_${{ matrix.config.backend }}_torch_cuda \\\n            --report-log=tests_${{ matrix.config.backend }}_torch_cuda.log \\\n            tests/quantization/${{ matrix.config.test_location }}\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_${{ matrix.config.backend }}_torch_cuda_stats.txt\n          cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: torch_cuda_${{ matrix.config.backend }}_reports\n          path: reports\n          \n  run_nightly_pipeline_level_quantization_tests:\n    name: Torch quantization nightly tests\n    strategy:\n      fail-fast: false\n      max-parallel: 2\n    runs-on:\n      group: aws-g6e-xlarge-plus\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"20gb\" --ipc host --gpus all\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: NVIDIA-SMI\n        run: nvidia-smi\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip install -U 
bitsandbytes optimum_quanto\n          #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n          uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1 \n          uv pip install pytest-reportlog\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Pipeline-level quantization tests on GPU\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n          BIG_GPU_MEMORY: 40\n        run: |\n          pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n            --make-reports=tests_pipeline_level_quant_torch_cuda \\\n            --report-log=tests_pipeline_level_quant_torch_cuda.log \\\n            tests/quantization/test_pipeline_level_quantization.py\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt\n          cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: torch_cuda_pipeline_level_quant_reports\n          path: reports\n\n  generate_consolidated_report:\n    name: Generate Consolidated Test Report\n    needs: [\n      run_nightly_tests_for_torch_pipelines,\n      run_nightly_tests_for_other_torch_modules,\n      run_torch_compile_tests,\n      run_big_gpu_torch_tests,\n      run_nightly_quantization_tests,\n      run_nightly_pipeline_level_quantization_tests,\n      # run_nightly_onnx_tests,\n      torch_minimum_version_cuda_tests,\n      # run_flax_tpu_tests\n    ]\n    if: always()\n    runs-on:\n      group: aws-general-8-plus\n    container:\n      
image: diffusers/diffusers-pytorch-cpu\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n\n      - name: Create reports directory\n        run: mkdir -p combined_reports\n\n      - name: Download all test reports\n        uses: actions/download-artifact@v7\n        with:\n          path: artifacts\n\n      - name: Prepare reports\n        run: |\n          # Move all report files to a single directory for processing\n          find artifacts -name \"*.txt\" -exec cp {} combined_reports/ \\;\n\n      - name: Install dependencies\n        run: |\n          pip install -e .[test]\n          pip install slack_sdk tabulate\n\n      - name: Generate consolidated report\n        run: |\n          python utils/consolidated_test_report.py \\\n            --reports_dir combined_reports \\\n            --output_file $CONSOLIDATED_REPORT_PATH \\\n            --slack_channel_name diffusers-ci-nightly\n\n      - name: Show consolidated report\n        run: |\n          cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY\n\n      - name: Upload consolidated report\n        uses: actions/upload-artifact@v6\n        with:\n          name: consolidated_test_report\n          path: ${{ env.CONSOLIDATED_REPORT_PATH }}\n\n# M1 runner currently not well supported\n# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon\n#  run_nightly_tests_apple_m1:\n#    name: Nightly PyTorch MPS tests on MacOS\n#    runs-on: [ self-hosted, apple-m1 ]\n#    if: github.event_name == 'schedule'\n#\n#    steps:\n#      - name: Checkout diffusers\n#        uses: actions/checkout@v6\n#        with:\n#          fetch-depth: 2\n#\n#      - name: Clean checkout\n#        shell: arch -arch arm64 bash {0}\n#        run: |\n#          git clean -fxd\n#      - name: Setup miniconda\n#        uses: ./.github/actions/setup-miniconda\n#        with:\n#          python-version: 3.9\n#\n#      - name: Install 
dependencies\n#        shell: arch -arch arm64 bash {0}\n#        run: |\n#          ${CONDA_RUN} pip install --upgrade pip uv\n#          ${CONDA_RUN} uv pip install -e \".[quality]\"\n#          ${CONDA_RUN} uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu\n#          ${CONDA_RUN} uv pip install accelerate@git+https://github.com/huggingface/accelerate\n#          ${CONDA_RUN} uv pip install pytest-reportlog\n#      - name: Environment\n#        shell: arch -arch arm64 bash {0}\n#        run: |\n#          ${CONDA_RUN} python utils/print_env.py\n#      - name: Run nightly PyTorch tests on M1 (MPS)\n#        shell: arch -arch arm64 bash {0}\n#        env:\n#          HF_HOME: /System/Volumes/Data/mnt/cache\n#          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n#        run: |\n#          ${CONDA_RUN} pytest -n 1  --make-reports=tests_torch_mps \\\n#            --report-log=tests_torch_mps.log \\\n#            tests/\n#      - name: Failure short reports\n#        if: ${{ failure() }}\n#        run: cat reports/tests_torch_mps_failures_short.txt\n#\n#      - name: Test suite reports artifacts\n#        if: ${{ always() }}\n#        uses: actions/upload-artifact@v6\n#        with:\n#          name: torch_mps_test_reports\n#          path: reports\n#\n#      - name: Generate Report and Notify Channel\n#        if: always()\n#        run: |\n#          pip install slack_sdk tabulate\n#          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY\n"
  },
  {
    "path": ".github/workflows/notify_slack_about_release.yml",
    "content": "name: Notify Slack about a release\n\non:\n  workflow_dispatch:\n  release:\n    types: [published]\n\njobs:\n  build:\n    runs-on: ubuntu-22.04\n\n    steps:\n    - uses: actions/checkout@v6\n\n    - name: Setup Python\n      uses: actions/setup-python@v6\n      with:\n        python-version: '3.10'\n\n    - name: Notify Slack about the release\n      env:\n        SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}\n      run: pip install requests && python utils/notify_slack_about_release.py\n"
  },
  {
    "path": ".github/workflows/pr_dependency_test.yml",
    "content": "name: Run dependency tests\n\non:\n  pull_request:\n    branches:\n      - main\n    paths:\n      - \"src/diffusers/**.py\"\n  push:\n    branches:\n      - main\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  check_dependencies:\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install -e .\n          pip install pytest\n      - name: Check for soft dependencies\n        run: |\n            pytest tests/others/test_dependencies.py\n"
  },
  {
    "path": ".github/workflows/pr_modular_tests.yml",
    "content": "\nname: Fast PR tests for Modular\n\non:\n  pull_request:\n    branches: [main]\n    paths:\n      - \"src/diffusers/modular_pipelines/**.py\"\n      - \"src/diffusers/models/modeling_utils.py\"\n      - \"src/diffusers/models/model_loading_utils.py\"\n      - \"src/diffusers/pipelines/pipeline_utils.py\"\n      - \"src/diffusers/pipeline_loading_utils.py\"\n      - \"src/diffusers/loaders/lora_base.py\"\n      - \"src/diffusers/loaders/lora_pipeline.py\"\n      - \"src/diffusers/loaders/peft.py\"\n      - \"tests/modular_pipelines/**.py\"\n      - \".github/**.yml\"\n      - \"utils/**.py\"\n      - \"setup.py\"\n  push:\n    branches:\n      - ci-*\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  HF_XET_HIGH_PERFORMANCE: 1\n  OMP_NUM_THREADS: 4\n  MKL_NUM_THREADS: 4\n  PYTEST_TIMEOUT: 60\n\njobs:\n  check_code_quality:\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install .[quality]\n      - name: Check quality\n        run: make quality\n      - name: Check if failure\n        if: ${{ failure() }}\n        run: |\n          echo \"Quality check failed. 
Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'\" >> $GITHUB_STEP_SUMMARY\n\n  check_repository_consistency:\n    needs: check_code_quality\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install .[quality]\n      - name: Check repo consistency\n        run: |\n          python utils/check_copies.py\n          python utils/check_dummies.py\n          python utils/check_support_list.py\n          make deps_table_check_updated\n      - name: Check if failure\n        if: ${{ failure() }}\n        run: |\n          echo \"Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'\" >> $GITHUB_STEP_SUMMARY\n  check_auto_docs:\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install .[quality]\n      - name: Check auto docs\n        run: make modular-autodoctrings\n      - name: Check if failure\n        if: ${{ failure() }}\n        run: |\n          echo \"Auto docstring checks failed. 
Please run 'python utils/modular_auto_docstring.py --fix_and_overwrite'.\" >> $GITHUB_STEP_SUMMARY\n\n  run_fast_tests:\n    needs: [check_code_quality, check_repository_consistency, check_auto_docs]\n    name: Fast PyTorch Modular Pipeline CPU tests\n\n    runs-on:\n      group: aws-highmemory-32-plus\n\n    container:\n      image: diffusers/diffusers-pytorch-cpu\n      options: --shm-size \"16gb\" --ipc host -v /mnt/hf_cache:/mnt/cache/\n\n    defaults:\n      run:\n        shell: bash\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n        #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n        uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1\n        uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run fast PyTorch Pipeline CPU tests\n      run: |\n        pytest -n 8 --max-worker-restart=0 --dist=loadfile \\\n          -k \"not Flax and not Onnx\" \\\n          --make-reports=tests_torch_cpu_modular_pipelines \\\n          tests/modular_pipelines\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports\n        path: reports\n"
  },
  {
    "path": ".github/workflows/pr_style_bot.yml",
    "content": "name: PR Style Bot\n\non:\n  issue_comment:\n    types: [created]\n\npermissions:\n  contents: write\n  pull-requests: write\n\njobs:\n  style:\n    uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main\n    with:\n      python_quality_dependencies: \"[quality]\"\n    secrets:\n      bot_token: ${{ secrets.HF_STYLE_BOT_ACTION }}"
  },
  {
    "path": ".github/workflows/pr_test_fetcher.yml",
    "content": "name: Fast tests for PRs - Test Fetcher\n\non: workflow_dispatch\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  OMP_NUM_THREADS: 4\n  MKL_NUM_THREADS: 4\n  PYTEST_TIMEOUT: 60\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  setup_pr_tests:\n    name: Setup PR Tests\n    runs-on:\n      group: aws-general-8-plus\n    container:\n      image: diffusers/diffusers-pytorch-cpu\n      options: --shm-size \"16gb\" --ipc host -v /mnt/hf_cache:/mnt/cache/\n    defaults:\n      run:\n        shell: bash\n    outputs:\n      matrix: ${{ steps.set_matrix.outputs.matrix }}\n      test_map: ${{ steps.set_matrix.outputs.test_map }}\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 0\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n    - name: Environment\n      run: |\n        python utils/print_env.py\n        echo $(git --version)\n    - name: Fetch Tests\n      run: |\n        python utils/tests_fetcher.py | tee test_preparation.txt\n    - name: Report fetched tests\n      uses: actions/upload-artifact@v6\n      with:\n        name: test_fetched\n        path: test_preparation.txt\n    - id: set_matrix\n      name: Create Test Matrix\n      # The `keys` is used as GitHub actions matrix for jobs, i.e. 
`models`, `pipelines`, etc.\n      # The `test_map` is used to get the actual identified test files under each key.\n      # If no test to run (so no `test_map.json` file), create a dummy map (empty matrix will fail)\n      run: |\n        if [ -f test_map.json ]; then\n            keys=$(python3 -c 'import json; fp = open(\"test_map.json\"); test_map = json.load(fp); fp.close(); d = list(test_map.keys()); print(json.dumps(d))')\n            test_map=$(python3 -c 'import json; fp = open(\"test_map.json\"); test_map = json.load(fp); fp.close(); print(json.dumps(test_map))')\n        else\n            keys=$(python3 -c 'keys = [\"dummy\"]; print(keys)')\n            test_map=$(python3 -c 'test_map = {\"dummy\": []}; print(test_map)')\n        fi\n        echo $keys\n        echo $test_map\n        echo \"matrix=$keys\" >> $GITHUB_OUTPUT\n        echo \"test_map=$test_map\" >> $GITHUB_OUTPUT\n\n  run_pr_tests:\n    name: Run PR Tests\n    needs: setup_pr_tests\n    if: contains(fromJson(needs.setup_pr_tests.outputs.matrix), 'dummy') != true\n    strategy:\n      fail-fast: false\n      max-parallel: 2\n      matrix:\n        modules: ${{ fromJson(needs.setup_pr_tests.outputs.matrix) }}\n    runs-on:\n      group: aws-general-8-plus\n    container:\n      image: diffusers/diffusers-pytorch-cpu\n      options: --shm-size \"16gb\" --ipc host -v /mnt/hf_cache:/mnt/cache/\n    defaults:\n      run:\n        shell: bash\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n        uv pip install accelerate\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run all selected tests on CPU\n      run: |\n        pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}\n\n    - name: Failure short 
reports\n      if: ${{ failure() }}\n      continue-on-error: true\n      run: |\n        cat reports/${{ matrix.modules }}_tests_cpu_stats.txt\n        cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n          name: ${{ matrix.modules }}_test_reports\n          path: reports\n\n  run_staging_tests:\n    strategy:\n      fail-fast: false\n      matrix:\n        config:\n          - name: Hub tests for models, schedulers, and pipelines\n            framework: hub_tests_pytorch\n            runner: aws-general-8-plus\n            image: diffusers/diffusers-pytorch-cpu\n            report: torch_hub\n\n    name: ${{ matrix.config.name }}\n    runs-on:\n      group: ${{ matrix.config.runner }}\n    container:\n      image: ${{ matrix.config.image }}\n      options: --shm-size \"16gb\" --ipc host -v /mnt/hf_cache:/mnt/cache/\n\n    defaults:\n      run:\n        shell: bash\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        pip install -e \".[quality]\"\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run Hub tests for models, schedulers, and pipelines on a staging env\n      if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}\n      run: |\n        HUGGINGFACE_CO_STAGING=true pytest \\\n          -m \"is_staging_test\" \\\n          --make-reports=tests_${{ matrix.config.report }} \\\n          tests\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: pr_${{ matrix.config.report }}_test_reports\n        path: reports\n"
  },
  {
    "path": ".github/workflows/pr_tests.yml",
    "content": "name: Fast tests for PRs\n\non:\n  pull_request:\n    branches: [main]\n    paths:\n      - \"src/diffusers/**.py\"\n      - \"benchmarks/**.py\"\n      - \"examples/**.py\"\n      - \"scripts/**.py\"\n      - \"tests/**.py\"\n      - \".github/**.yml\"\n      - \"utils/**.py\"\n      - \"setup.py\"\n  push:\n    branches:\n      - ci-*\n\npermissions:\n  contents: read\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  HF_XET_HIGH_PERFORMANCE: 1\n  OMP_NUM_THREADS: 4\n  MKL_NUM_THREADS: 4\n  PYTEST_TIMEOUT: 60\n\njobs:\n  check_code_quality:\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install .[quality]\n      - name: Check quality\n        run: make quality\n      - name: Check if failure\n        if: ${{ failure() }}\n        run: |\n          echo \"Quality check failed. 
Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'\" >> $GITHUB_STEP_SUMMARY\n\n  check_repository_consistency:\n    needs: check_code_quality\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install .[quality]\n      - name: Check repo consistency\n        run: |\n          python utils/check_copies.py\n          python utils/check_dummies.py\n          python utils/check_support_list.py\n          make deps_table_check_updated\n      - name: Check if failure\n        if: ${{ failure() }}\n        run: |\n          echo \"Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'\" >> $GITHUB_STEP_SUMMARY\n\n  run_fast_tests:\n    needs: [check_code_quality, check_repository_consistency]\n    strategy:\n      fail-fast: false\n      matrix:\n        config:\n          - name: Fast PyTorch Pipeline CPU tests\n            framework: pytorch_pipelines\n            runner: aws-highmemory-32-plus\n            image: diffusers/diffusers-pytorch-cpu\n            report: torch_cpu_pipelines\n          - name: Fast PyTorch Models & Schedulers CPU tests\n            framework: pytorch_models\n            runner: aws-general-8-plus\n            image: diffusers/diffusers-pytorch-cpu\n            report: torch_cpu_models_schedulers\n          - name: PyTorch Example CPU tests\n            framework: pytorch_examples\n            runner: aws-general-8-plus\n            image: diffusers/diffusers-pytorch-cpu\n            report: torch_example_cpu\n    name: ${{ matrix.config.name }}\n\n    runs-on:\n      group: ${{ matrix.config.runner }}\n\n    container:\n      image: 
${{ matrix.config.image }}\n      options: --shm-size \"16gb\" --ipc host -v /mnt/hf_cache:/mnt/cache/\n\n    defaults:\n      run:\n        shell: bash\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n        uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run fast PyTorch Pipeline CPU tests\n      if: ${{ matrix.config.framework == 'pytorch_pipelines' }}\n      run: |\n        pytest -n 8 --max-worker-restart=0 --dist=loadfile \\\n          -k \"not Flax and not Onnx\" \\\n          --make-reports=tests_${{ matrix.config.report }} \\\n          tests/pipelines\n\n    - name: Run fast PyTorch Model Scheduler CPU tests\n      if: ${{ matrix.config.framework == 'pytorch_models' }}\n      run: |\n        pytest -n 4 --max-worker-restart=0 --dist=loadfile \\\n          -k \"not Flax and not Onnx and not Dependency\" \\\n          --make-reports=tests_${{ matrix.config.report }} \\\n          tests/models tests/schedulers tests/others\n\n    - name: Run example PyTorch CPU tests\n      if: ${{ matrix.config.framework == 'pytorch_examples' }}\n      run: |\n        uv pip install \".[training]\"\n        pytest -n 4 --max-worker-restart=0 --dist=loadfile \\\n          --make-reports=tests_${{ matrix.config.report }} \\\n          examples\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      
with:\n        name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports\n        path: reports\n\n  run_staging_tests:\n    needs: [check_code_quality, check_repository_consistency]\n    strategy:\n      fail-fast: false\n      matrix:\n        config:\n          - name: Hub tests for models, schedulers, and pipelines\n            framework: hub_tests_pytorch\n            runner:\n              group: aws-general-8-plus\n            image: diffusers/diffusers-pytorch-cpu\n            report: torch_hub\n\n    name: ${{ matrix.config.name }}\n\n    runs-on: ${{ matrix.config.runner }}\n\n    container:\n      image: ${{ matrix.config.image }}\n      options: --shm-size \"16gb\" --ipc host -v /mnt/hf_cache:/mnt/cache/\n\n    defaults:\n      run:\n        shell: bash\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run Hub tests for models, schedulers, and pipelines on a staging env\n      if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}\n      run: |\n        HUGGINGFACE_CO_STAGING=true pytest \\\n          -m \"is_staging_test\" \\\n          --make-reports=tests_${{ matrix.config.report }} \\\n          tests\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: pr_${{ matrix.config.report }}_test_reports\n        path: reports\n\n  run_lora_tests:\n    needs: [check_code_quality, check_repository_consistency]\n\n    name: LoRA tests with PEFT main\n\n    runs-on:\n      group: aws-general-8-plus\n\n    container:\n      image: diffusers/diffusers-pytorch-cpu\n     
 options: --shm-size \"16gb\" --ipc host -v /mnt/hf_cache:/mnt/cache/\n\n    defaults:\n      run:\n        shell: bash\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n        # TODO (sayakpaul, DN6): revisit `--no-deps`\n        uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps\n        uv pip install -U tokenizers\n        uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n        \n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run fast PyTorch LoRA tests with PEFT\n      run: |\n        pytest -n 4 --max-worker-restart=0 --dist=loadfile \\\n          \\\n          --make-reports=tests_peft_main \\\n          tests/lora/\n        pytest -n 4 --max-worker-restart=0 --dist=loadfile \\\n          \\\n          --make-reports=tests_models_lora_peft_main \\\n          tests/models/ -k \"lora\"\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: |\n        cat reports/tests_peft_main_failures_short.txt\n        cat reports/tests_models_lora_peft_main_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: pr_lora_test_reports\n        path: reports\n\n"
  },
  {
    "path": ".github/workflows/pr_tests_gpu.yml",
    "content": "name: Fast GPU Tests on PR\n\npermissions:\n  contents: read\n\non:\n  pull_request:\n    branches: main\n    paths:\n      - \"src/diffusers/models/modeling_utils.py\"\n      - \"src/diffusers/models/model_loading_utils.py\"\n      - \"src/diffusers/pipelines/pipeline_utils.py\"\n      - \"src/diffusers/pipeline_loading_utils.py\"\n      - \"src/diffusers/loaders/lora_base.py\"\n      - \"src/diffusers/loaders/lora_pipeline.py\"\n      - \"src/diffusers/loaders/peft.py\"\n      - \"tests/pipelines/test_pipelines_common.py\"\n      - \"tests/models/test_modeling_common.py\"\n      - \"examples/**/*.py\"\n  workflow_dispatch:\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  HF_XET_HIGH_PERFORMANCE: 1\n  PYTEST_TIMEOUT: 600\n  PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run\n\njobs:\n  check_code_quality:\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install .[quality]\n      - name: Check quality\n        run: make quality\n      - name: Check if failure\n        if: ${{ failure() }}\n        run: |\n          echo \"Quality check failed. 
Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'\" >> $GITHUB_STEP_SUMMARY\n\n  check_repository_consistency:\n    needs: check_code_quality\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install .[quality]\n      - name: Check repo consistency\n        run: |\n          python utils/check_copies.py\n          python utils/check_dummies.py\n          python utils/check_support_list.py\n          make deps_table_check_updated\n      - name: Check if failure\n        if: ${{ failure() }}\n        run: |\n          echo \"Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'\" >> $GITHUB_STEP_SUMMARY\n\n  setup_torch_cuda_pipeline_matrix:\n    needs: [check_code_quality, check_repository_consistency]\n    name: Setup Torch Pipelines CUDA Slow Tests Matrix\n    runs-on:\n      group: aws-general-8-plus\n    container:\n      image: diffusers/diffusers-pytorch-cpu\n    outputs:\n      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Fetch Pipeline Matrix\n        id: fetch_pipeline_matrix\n        run: |\n          matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)\n          echo $matrix\n          echo \"pipeline_test_matrix=$matrix\" >> $GITHUB_OUTPUT\n      - name: Pipeline Tests Artifacts\n     
   if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: test-pipelines.json\n          path: reports\n\n  torch_pipelines_cuda_tests:\n    name: Torch Pipelines CUDA Tests\n    needs: setup_torch_cuda_pipeline_matrix\n    strategy:\n      fail-fast: false\n      max-parallel: 8\n      matrix:\n        module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n\n      - name: NVIDIA-SMI\n        run: |\n          nvidia-smi\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Extract tests\n        id: extract_tests\n        run: |\n          pattern=$(python utils/extract_tests_from_mixin.py --type pipeline)\n          echo \"$pattern\" > /tmp/test_pattern.txt\n          echo \"pattern_file=/tmp/test_pattern.txt\" >> $GITHUB_OUTPUT\n\n      - name: PyTorch CUDA checkpoint tests on Ubuntu\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n        run: |\n          if [ \"${{ matrix.module }}\" = \"ip_adapters\" ]; then\n              pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n              -k \"not Flax and not 
Onnx\" \\\n              --make-reports=tests_pipeline_${{ matrix.module }}_cuda \\\n              tests/pipelines/${{ matrix.module }}\n          else\n              pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})\n              pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n              -k \"not Flax and not Onnx and $pattern\" \\\n              --make-reports=tests_pipeline_${{ matrix.module }}_cuda \\\n              tests/pipelines/${{ matrix.module }}\n          fi\n\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt\n          cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: pipeline_${{ matrix.module }}_test_reports\n          path: reports\n\n  torch_cuda_tests:\n    name: Torch CUDA Tests\n    needs: [check_code_quality, check_repository_consistency]\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    defaults:\n      run:\n        shell: bash\n    strategy:\n      fail-fast: false\n      max-parallel: 4\n      matrix:\n        module: [models, schedulers, lora, others]\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n        uv pip install peft@git+https://github.com/huggingface/peft.git\n        uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n\n    - name: Environment\n 
     run: |\n        python utils/print_env.py\n\n    - name: Extract tests\n      id: extract_tests\n      run: |\n        pattern=$(python utils/extract_tests_from_mixin.py --type ${{ matrix.module }})\n        echo \"$pattern\" > /tmp/test_pattern.txt\n        echo \"pattern_file=/tmp/test_pattern.txt\" >> $GITHUB_OUTPUT\n\n    - name: Run PyTorch CUDA tests\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n        CUBLAS_WORKSPACE_CONFIG: :16:8\n      run: |\n        pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})\n        if [ -z \"$pattern\" ]; then\n          pytest -n 1  --max-worker-restart=0 --dist=loadfile -k \"not Flax and not Onnx\" tests/${{ matrix.module }} \\\n          --make-reports=tests_torch_cuda_${{ matrix.module }}\n        else\n          pytest -n 1  --max-worker-restart=0 --dist=loadfile -k \"not Flax and not Onnx and $pattern\" tests/${{ matrix.module }} \\\n          --make-reports=tests_torch_cuda_${{ matrix.module }}\n        fi\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: |\n        cat reports/tests_torch_cuda_${{ matrix.module }}_stats.txt\n        cat reports/tests_torch_cuda_${{ matrix.module }}_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_cuda_test_reports_${{ matrix.module }}\n        path: reports\n\n  run_examples_tests:\n    name: Examples PyTorch CUDA tests on Ubuntu\n    needs: [check_code_quality, check_repository_consistency]\n    runs-on:\n      group: aws-g4dn-2xlarge\n\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --gpus all --shm-size \"16gb\" --ipc host\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: NVIDIA-SMI\n      
run: |\n        nvidia-smi\n    - name: Install dependencies\n      run: |\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n        uv pip install -e \".[quality,training]\"\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run example tests on GPU\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n      run: |\n        uv pip install \".[training]\"\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: |\n        cat reports/examples_torch_cuda_stats.txt\n        cat reports/examples_torch_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: examples_test_reports\n        path: reports\n\n"
  },
  {
    "path": ".github/workflows/pr_torch_dependency_test.yml",
    "content": "name: Run Torch dependency tests\n\non:\n  pull_request:\n    branches:\n      - main\n    paths:\n      - \"src/diffusers/**.py\"\n  push:\n    branches:\n      - main\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  check_torch_dependencies:\n    runs-on: ubuntu-22.04\n    steps:\n      - uses: actions/checkout@v6\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n      - name: Install dependencies\n        run: |\n          pip install -e .\n          pip install torch torchvision torchaudio pytest\n      - name: Check for soft dependencies\n        run: |\n            pytest tests/others/test_dependencies.py\n"
  },
  {
    "path": ".github/workflows/push_tests.yml",
    "content": "name: Fast GPU Tests on main\n\non:\n  workflow_dispatch:\n  push:\n    branches:\n      - main\n    paths:\n      - \"src/diffusers/**.py\"\n      - \"examples/**.py\"\n      - \"tests/**.py\"\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  HF_XET_HIGH_PERFORMANCE: 1\n  PYTEST_TIMEOUT: 600\n  PIPELINE_USAGE_CUTOFF: 50000\n\njobs:\n  setup_torch_cuda_pipeline_matrix:\n    name: Setup Torch Pipelines CUDA Slow Tests Matrix\n    runs-on:\n      group: aws-general-8-plus\n    container:\n      image: diffusers/diffusers-pytorch-cpu\n    outputs:\n      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Fetch Pipeline Matrix\n        id: fetch_pipeline_matrix\n        run: |\n          matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)\n          echo $matrix\n          echo \"pipeline_test_matrix=$matrix\" >> $GITHUB_OUTPUT\n      - name: Pipeline Tests Artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: test-pipelines.json\n          path: reports\n\n  torch_pipelines_cuda_tests:\n    name: Torch Pipelines CUDA Tests\n    needs: setup_torch_cuda_pipeline_matrix\n    strategy:\n      fail-fast: false\n      max-parallel: 8\n      matrix:\n        module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          
fetch-depth: 2\n      - name: NVIDIA-SMI\n        run: |\n          nvidia-smi\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: PyTorch CUDA checkpoint tests on Ubuntu\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n        run: |\n          pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n            -k \"not Flax and not Onnx\" \\\n            --make-reports=tests_pipeline_${{ matrix.module }}_cuda \\\n            tests/pipelines/${{ matrix.module }}\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt\n          cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: pipeline_${{ matrix.module }}_test_reports\n          path: reports\n\n  torch_cuda_tests:\n    name: Torch CUDA Tests\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    defaults:\n      run:\n        shell: bash\n    strategy:\n      fail-fast: false\n      max-parallel: 2\n      matrix:\n        module: [models, schedulers, lora, others, single_file]\n    steps:\n    - name: Checkout diffusers\n      uses: 
actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n        uv pip install peft@git+https://github.com/huggingface/peft.git\n        uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run PyTorch CUDA tests\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n        CUBLAS_WORKSPACE_CONFIG: :16:8\n      run: |\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n          -k \"not Flax and not Onnx\" \\\n          --make-reports=tests_torch_cuda_${{ matrix.module }} \\\n          tests/${{ matrix.module }}\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: |\n        cat reports/tests_torch_cuda_${{ matrix.module }}_stats.txt\n        cat reports/tests_torch_cuda_${{ matrix.module }}_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_cuda_test_reports_${{ matrix.module }}\n        path: reports\n\n  run_torch_compile_tests:\n    name: PyTorch Compile CUDA tests\n\n    runs-on:\n      group: aws-g4dn-2xlarge\n\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --gpus all --shm-size \"16gb\" --ipc host\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: NVIDIA-SMI\n      run: |\n        nvidia-smi\n    - name: Install dependencies\n      run: |\n        uv pip install -e 
\".[quality,training]\"\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n    - name: Environment\n      run: |\n        python utils/print_env.py\n    - name: Run torch compile tests on GPU\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        RUN_COMPILE: yes\n      run: |\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile -k \"compile\" --make-reports=tests_torch_compile_cuda tests/\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_torch_compile_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_compile_test_reports\n        path: reports\n\n  run_xformers_tests:\n    name: PyTorch xformers CUDA tests\n\n    runs-on:\n      group: aws-g4dn-2xlarge\n\n    container:\n      image: diffusers/diffusers-pytorch-xformers-cuda\n      options: --gpus all --shm-size \"16gb\" --ipc host\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: NVIDIA-SMI\n      run: |\n        nvidia-smi\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality,training]\"\n    - name: Environment\n      run: |\n        python utils/print_env.py\n    - name: Run xformers tests on GPU\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n      run: |\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile -k \"xformers\" --make-reports=tests_torch_xformers_cuda tests/\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_torch_xformers_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_xformers_test_reports\n      
  path: reports\n\n  run_examples_tests:\n    name: Examples PyTorch CUDA tests on Ubuntu\n\n    runs-on:\n      group: aws-g4dn-2xlarge\n\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --gpus all --shm-size \"16gb\" --ipc host\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: NVIDIA-SMI\n      run: |\n        nvidia-smi\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality,training]\"\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run example tests on GPU\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n      run: |\n        uv pip install \".[training]\"\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: |\n        cat reports/examples_torch_cuda_stats.txt\n        cat reports/examples_torch_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: examples_test_reports\n        path: reports\n"
  },
  {
    "path": ".github/workflows/push_tests_fast.yml",
    "content": "name: Fast tests on main\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"src/diffusers/**.py\"\n      - \"examples/**.py\"\n      - \"tests/**.py\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  HF_HOME: /mnt/cache\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  HF_XET_HIGH_PERFORMANCE: 1\n  PYTEST_TIMEOUT: 600\n  RUN_SLOW: no\n\njobs:\n  run_fast_tests:\n    strategy:\n      fail-fast: false\n      matrix:\n        config:\n          - name: Fast PyTorch CPU tests on Ubuntu\n            framework: pytorch\n            runner: aws-general-8-plus\n            image: diffusers/diffusers-pytorch-cpu\n            report: torch_cpu\n          - name: PyTorch Example CPU tests on Ubuntu\n            framework: pytorch_examples\n            runner: aws-general-8-plus\n            image: diffusers/diffusers-pytorch-cpu\n            report: torch_example_cpu\n\n    name: ${{ matrix.config.name }}\n\n    runs-on:\n      group: ${{ matrix.config.runner }}\n\n    container:\n      image: ${{ matrix.config.image }}\n      options: --shm-size \"16gb\" --ipc host -v /mnt/hf_cache:/mnt/cache/\n\n    defaults:\n      run:\n        shell: bash\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run fast PyTorch CPU tests\n      if: ${{ matrix.config.framework == 'pytorch' }}\n      run: |\n        pytest -n 4 --max-worker-restart=0 --dist=loadfile \\\n          -k \"not Flax and not Onnx\" \\\n          --make-reports=tests_${{ matrix.config.report }} \\\n          tests/\n\n    - name: Run example PyTorch CPU tests\n      if: ${{ matrix.config.framework == 'pytorch_examples' }}\n      run: |\n        
uv pip install \".[training]\"\n        pytest -n 4 --max-worker-restart=0 --dist=loadfile \\\n          --make-reports=tests_${{ matrix.config.report }} \\\n          examples\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: pr_${{ matrix.config.report }}_test_reports\n        path: reports\n"
  },
  {
    "path": ".github/workflows/push_tests_mps.yml",
    "content": "name: Fast mps tests on main\n\non:\n  workflow_dispatch:\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  HF_HOME: /mnt/cache\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  HF_XET_HIGH_PERFORMANCE: 1\n  PYTEST_TIMEOUT: 600\n  RUN_SLOW: no\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  run_fast_tests_apple_m1:\n    name: Fast PyTorch MPS tests on MacOS\n    runs-on: macos-13-xlarge\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Clean checkout\n      shell: arch -arch arm64 bash {0}\n      run: |\n        git clean -fxd\n\n    - name: Setup miniconda\n      uses: ./.github/actions/setup-miniconda\n      with:\n        python-version: 3.9\n\n    - name: Install dependencies\n      shell: arch -arch arm64 bash {0}\n      run: |\n        ${CONDA_RUN} python -m pip install --upgrade pip uv\n        ${CONDA_RUN} python -m uv pip install -e \".[quality]\"\n        ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio\n        ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git\n        ${CONDA_RUN} python -m uv pip install transformers --upgrade\n\n    - name: Environment\n      shell: arch -arch arm64 bash {0}\n      run: |\n        ${CONDA_RUN} python utils/print_env.py\n\n    - name: Run fast PyTorch tests on M1 (MPS)\n      shell: arch -arch arm64 bash {0}\n      env:\n        HF_HOME: /System/Volumes/Data/mnt/cache\n        HF_TOKEN: ${{ secrets.HF_TOKEN }}\n      run: |\n        ${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_torch_mps_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: pr_torch_mps_test_reports\n        
path: reports\n"
  },
  {
    "path": ".github/workflows/pypi_publish.yaml",
    "content": "# Adapted from https://blog.deepjyoti30.dev/pypi-release-github-action\n\nname: PyPI release\n\non:\n  workflow_dispatch:\n  push:\n    tags:\n      - \"*\"\n\njobs:\n  find-and-checkout-latest-branch:\n    runs-on: ubuntu-22.04\n    outputs:\n      latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}\n    steps:\n      - name: Checkout Repo\n        uses: actions/checkout@v6\n\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: '3.10'\n\n      - name: Fetch latest branch\n        id: fetch_latest_branch\n        run: |\n          pip install -U requests packaging\n          LATEST_BRANCH=$(python utils/fetch_latest_release_branch.py)\n          echo \"Latest branch: $LATEST_BRANCH\"\n          echo \"latest_branch=$LATEST_BRANCH\" >> $GITHUB_ENV\n\n      - name: Set latest branch output\n        id: set_latest_branch\n        run: echo \"latest_branch=${{ env.latest_branch }}\" >> $GITHUB_OUTPUT\n\n  release:\n    needs: find-and-checkout-latest-branch\n    runs-on: ubuntu-22.04\n\n    steps:\n      - name: Checkout Repo\n        uses: actions/checkout@v6\n        with:\n          ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}\n\n      - name: Setup Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install -U setuptools wheel twine\n          pip install -U torch --index-url https://download.pytorch.org/whl/cpu\n\n      - name: Build the dist files\n        run: python setup.py bdist_wheel && python setup.py sdist\n\n      - name: Publish to the test PyPI\n        env:\n          TWINE_USERNAME: ${{ secrets.TEST_PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.TEST_PYPI_PASSWORD }}\n        run: twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/\n\n      - 
name: Test installing diffusers and importing\n        run: |\n          pip install diffusers && pip uninstall diffusers -y\n          pip install -i https://test.pypi.org/simple/ diffusers\n          pip install -U transformers\n          python utils/print_env.py\n          python -c \"from diffusers import __version__; print(__version__)\"\n          python -c \"from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()\"\n          python -c \"from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')\"\n          python -c \"from diffusers import *\"\n\n      - name: Publish to PyPI\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: twine upload dist/* -r pypi\n"
  },
  {
    "path": ".github/workflows/release_tests_fast.yml",
    "content": "# Duplicate workflow to push_tests.yml that is meant to run on release/patch branches as a final check\n# Creating a duplicate workflow here is simpler than adding complex path/branch parsing logic to push_tests.yml\n# Needs to be updated if push_tests.yml updated\nname: (Release) Fast GPU Tests on main\n\non:\n  workflow_dispatch:\n  push:\n    branches:\n      - \"v*.*.*-release\"\n      - \"v*.*.*-patch\"\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  PYTEST_TIMEOUT: 600\n  PIPELINE_USAGE_CUTOFF: 50000\n\njobs:\n  setup_torch_cuda_pipeline_matrix:\n    name: Setup Torch Pipelines CUDA Slow Tests Matrix\n    runs-on:\n      group: aws-general-8-plus\n    container:\n      image: diffusers/diffusers-pytorch-cpu\n    outputs:\n      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Fetch Pipeline Matrix\n        id: fetch_pipeline_matrix\n        run: |\n          matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)\n          echo $matrix\n          echo \"pipeline_test_matrix=$matrix\" >> $GITHUB_OUTPUT\n      - name: Pipeline Tests Artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: test-pipelines.json\n          path: reports\n\n  torch_pipelines_cuda_tests:\n    name: Torch Pipelines CUDA Tests\n    needs: setup_torch_cuda_pipeline_matrix\n    strategy:\n      fail-fast: false\n      max-parallel: 8\n      matrix:\n        module: ${{ 
fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - name: NVIDIA-SMI\n        run: |\n          nvidia-smi\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n      - name: Environment\n        run: |\n          python utils/print_env.py\n      - name: Slow PyTorch CUDA checkpoint tests on Ubuntu\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n        run: |\n          pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n            -k \"not Flax and not Onnx\" \\\n            --make-reports=tests_pipeline_${{ matrix.module }}_cuda \\\n            tests/pipelines/${{ matrix.module }}\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt\n          cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: pipeline_${{ matrix.module }}_test_reports\n          path: reports\n\n  torch_cuda_tests:\n    name: Torch CUDA Tests\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      
image: diffusers/diffusers-pytorch-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    defaults:\n      run:\n        shell: bash\n    strategy:\n      fail-fast: false\n      max-parallel: 2\n      matrix:\n        module: [models, schedulers, lora, others, single_file]\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality]\"\n        uv pip install peft@git+https://github.com/huggingface/peft.git\n        uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run PyTorch CUDA tests\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n        CUBLAS_WORKSPACE_CONFIG: :16:8\n      run: |\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n          -k \"not Flax and not Onnx\" \\\n          --make-reports=tests_torch_${{ matrix.module }}_cuda \\\n          tests/${{ matrix.module }}\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: |\n        cat reports/tests_torch_${{ matrix.module }}_cuda_stats.txt\n        cat reports/tests_torch_${{ matrix.module }}_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_cuda_${{ matrix.module }}_test_reports\n        path: reports\n\n  torch_minimum_version_cuda_tests:\n    name: Torch Minimum Version CUDA Tests\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: 
diffusers/diffusers-pytorch-minimum-cuda\n      options: --shm-size \"16gb\" --ipc host --gpus all\n    defaults:\n      run:\n        shell: bash\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n\n      - name: Install dependencies\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip install peft@git+https://github.com/huggingface/peft.git\n          uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git\n          uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n\n      - name: Environment\n        run: |\n          python utils/print_env.py\n\n      - name: Run PyTorch CUDA tests\n        env:\n          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n          # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms\n          CUBLAS_WORKSPACE_CONFIG: :16:8\n        run: |\n          pytest -n 1 --max-worker-restart=0 --dist=loadfile \\\n            -k \"not Flax and not Onnx\" \\\n            --make-reports=tests_torch_minimum_version_cuda \\\n            tests/models/test_modeling_common.py \\\n            tests/pipelines/test_pipelines_common.py \\\n            tests/pipelines/test_pipeline_utils.py \\\n            tests/pipelines/test_pipelines.py \\\n            tests/pipelines/test_pipelines_auto.py \\\n            tests/schedulers/test_schedulers.py \\\n            tests/others\n\n      - name: Failure short reports\n        if: ${{ failure() }}\n        run: |\n          cat reports/tests_torch_minimum_version_cuda_stats.txt\n          cat reports/tests_torch_minimum_version_cuda_failures_short.txt\n\n      - name: Test suite reports artifacts\n        if: ${{ always() }}\n        uses: actions/upload-artifact@v6\n        with:\n          name: 
torch_minimum_version_cuda_test_reports\n          path: reports\n\n  run_torch_compile_tests:\n    name: PyTorch Compile CUDA tests\n\n    runs-on:\n      group: aws-g4dn-2xlarge\n\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --gpus all --shm-size \"16gb\" --ipc host\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: NVIDIA-SMI\n      run: |\n        nvidia-smi\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality,training]\"\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n    - name: Environment\n      run: |\n        python utils/print_env.py\n    - name: Run torch compile tests on GPU\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n        RUN_COMPILE: yes\n      run: |\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile -k \"compile\" --make-reports=tests_torch_compile_cuda tests/\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_torch_compile_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_compile_test_reports\n        path: reports\n\n  run_xformers_tests:\n    name: PyTorch xformers CUDA tests\n\n    runs-on:\n      group: aws-g4dn-2xlarge\n\n    container:\n      image: diffusers/diffusers-pytorch-xformers-cuda\n      options: --gpus all --shm-size \"16gb\" --ipc host\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: NVIDIA-SMI\n      run: |\n        nvidia-smi\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality,training]\"\n        uv pip uninstall transformers huggingface_hub && uv pip 
install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n    - name: Environment\n      run: |\n        python utils/print_env.py\n    - name: Run xformers tests on GPU\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n      run: |\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile -k \"xformers\" --make-reports=tests_torch_xformers_cuda tests/\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: cat reports/tests_torch_xformers_cuda_failures_short.txt\n\n    - name: Test suite reports artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: torch_xformers_test_reports\n        path: reports\n\n  run_examples_tests:\n    name: Examples PyTorch CUDA tests on Ubuntu\n\n    runs-on:\n      group: aws-g4dn-2xlarge\n\n    container:\n      image: diffusers/diffusers-pytorch-cuda\n      options: --gpus all --shm-size \"16gb\" --ipc host\n\n    steps:\n    - name: Checkout diffusers\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 2\n\n    - name: NVIDIA-SMI\n      run: |\n        nvidia-smi\n\n    - name: Install dependencies\n      run: |\n        uv pip install -e \".[quality,training]\"\n        uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git\n\n    - name: Environment\n      run: |\n        python utils/print_env.py\n\n    - name: Run example tests on GPU\n      env:\n        HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}\n      run: |\n        uv pip install \".[training]\"\n        pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/\n\n    - name: Failure short reports\n      if: ${{ failure() }}\n      run: |\n        cat reports/examples_torch_cuda_stats.txt\n        cat reports/examples_torch_cuda_failures_short.txt\n\n    - name: Test suite reports 
artifacts\n      if: ${{ always() }}\n      uses: actions/upload-artifact@v6\n      with:\n        name: examples_test_reports\n        path: reports\n"
  },
  {
    "path": ".github/workflows/run_tests_from_a_pr.yml",
    "content": "name: Check running SLOW tests from a PR (only GPU)\n\non:\n  workflow_dispatch:\n    inputs:\n      docker_image:\n        default: 'diffusers/diffusers-pytorch-cuda'\n        description: 'Name of the Docker image'\n        required: true\n      pr_number:\n        description: 'PR number to test on'\n        required: true\n      test:\n        description: 'Tests to run (e.g.: `tests/models`).'\n        required: true\n\nenv:\n  DIFFUSERS_IS_CI: yes\n  IS_GITHUB_CI: \"1\"\n  HF_HOME: /mnt/cache\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  PYTEST_TIMEOUT: 600\n  RUN_SLOW: yes\n\njobs:\n  run_tests:\n    name: \"Run a test on our runner from a PR\"\n    runs-on:\n      group: aws-g4dn-2xlarge\n    container:\n      image: ${{ github.event.inputs.docker_image }}\n      options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/\n\n    steps:\n      - name: Validate test files input\n        id: validate_test_files\n        env:\n          PY_TEST: ${{ github.event.inputs.test }}\n        run: |\n          if [[ ! \"$PY_TEST\" =~ ^tests/ ]]; then\n            echo \"Error: The input string must start with 'tests/'.\"\n            exit 1\n          fi\n\n          if [[ ! 
\"$PY_TEST\" =~ ^tests/(models|pipelines|lora) ]]; then\n            echo \"Error: The input string must contain either 'models', 'pipelines', or 'lora' after 'tests/'.\"\n            exit 1\n          fi\n\n          if [[ \"$PY_TEST\" == *\";\"* ]]; then\n            echo \"Error: The input string must not contain ';'.\"\n            exit 1\n          fi\n          echo \"$PY_TEST\"\n        \n        shell: bash -e {0}\n\n      - name: Checkout PR branch\n        uses: actions/checkout@v6\n        with:\n          ref: refs/pull/${{ inputs.pr_number }}/head\n\n      - name: Install pytest\n        run: |\n          uv pip install -e \".[quality]\"\n          uv pip install peft\n\n      - name: Run tests\n        env:\n            PY_TEST: ${{ github.event.inputs.test }}\n        run: |\n          pytest \"$PY_TEST\"\n"
  },
  {
    "path": ".github/workflows/ssh-pr-runner.yml",
    "content": "name: SSH into PR runners\n\non:\n  workflow_dispatch:\n    inputs:\n      docker_image:\n        description: 'Name of the Docker image'\n        required: true\n\nenv:\n  IS_GITHUB_CI: \"1\"\n  HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}\n  HF_HOME: /mnt/cache\n  DIFFUSERS_IS_CI: yes\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  RUN_SLOW: yes\n\njobs:\n  ssh_runner:\n    name: \"SSH\"\n    runs-on:\n      group: aws-highmemory-32-plus\n    container:\n      image: ${{ github.event.inputs.docker_image }}\n      options: --shm-size \"16gb\" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --privileged\n\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n\n      - name: Tailscale # In order to be able to SSH when a test fails\n        uses: huggingface/tailscale-action@main\n        with:\n          authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}\n          slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}\n          slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}\n          waitForSSH: true\n"
  },
  {
    "path": ".github/workflows/ssh-runner.yml",
    "content": "name: SSH into GPU runners\n\non:\n  workflow_dispatch:\n    inputs:\n      runner_type:\n        description: 'Type of runner to test (aws-g6-4xlarge-plus: a10, aws-g4dn-2xlarge: t4, aws-g6e-xlarge-plus: L40)'\n        type: choice\n        required: true\n        options:\n          - aws-g6-4xlarge-plus\n          - aws-g4dn-2xlarge\n          - aws-g6e-xlarge-plus\n      docker_image:\n        description: 'Name of the Docker image'\n        required: true\n\nenv:\n  IS_GITHUB_CI: \"1\"\n  HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}\n  HF_HOME: /mnt/cache\n  DIFFUSERS_IS_CI: yes\n  OMP_NUM_THREADS: 8\n  MKL_NUM_THREADS: 8\n  RUN_SLOW: yes\n\njobs:\n  ssh_runner:\n    name: \"SSH\"\n    runs-on:\n      group: \"${{ github.event.inputs.runner_type }}\"\n    container:\n      image: ${{ github.event.inputs.docker_image }}\n      options: --shm-size \"16gb\" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus all --privileged\n\n    steps:\n      - name: Checkout diffusers\n        uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n\n      - name: NVIDIA-SMI\n        run: |\n          nvidia-smi\n\n      - name: Tailscale # In order to be able to SSH when a test fails\n        uses: huggingface/tailscale-action@main\n        with:\n          authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}\n          slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}\n          slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}\n          waitForSSH: true\n"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "name: Stale Bot\n\non:\n  schedule:\n    - cron: \"0 15 * * *\"\n\njobs:\n  close_stale_issues:\n    name: Close Stale Issues\n    if: github.repository == 'huggingface/diffusers'\n    runs-on: ubuntu-22.04\n    permissions:\n      issues: write\n      pull-requests: write\n    env:\n      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n    steps:\n    - uses: actions/checkout@v6\n\n    - name: Setup Python\n      uses: actions/setup-python@v6\n      with:\n        python-version: \"3.10\"\n\n    - name: Install requirements\n      run: |\n        pip install PyGithub\n    - name: Close stale issues\n      run: |\n        python utils/stale.py\n"
  },
  {
    "path": ".github/workflows/trufflehog.yml",
    "content": "on:\n  push:\n\nname: Secret Leaks\n\njobs:\n  trufflehog:\n    runs-on: ubuntu-22.04\n    steps:\n    - name: Checkout code\n      uses: actions/checkout@v6\n      with:\n        fetch-depth: 0\n    - name: Secret Scanning\n      uses: trufflesecurity/trufflehog@main\n      with:\n        extra_args: --results=verified,unknown\n\n"
  },
  {
    "path": ".github/workflows/typos.yml",
    "content": "name: Check typos\n\non:\n  workflow_dispatch:\n\njobs:\n  build:\n    runs-on: ubuntu-22.04\n\n    steps:\n      - uses: actions/checkout@v6\n\n      - name: typos-action\n        uses: crate-ci/typos@v1.42.1\n"
  },
  {
    "path": ".github/workflows/update_metadata.yml",
    "content": "name: Update Diffusers metadata\n\non:\n  workflow_dispatch:\n  push:\n    branches:\n      - main\n      - update_diffusers_metadata*\n\njobs:\n  update_metadata:\n    runs-on: ubuntu-22.04\n    defaults:\n      run:\n        shell: bash -l {0}\n\n    steps:\n      - uses: actions/checkout@v6\n\n      - name: Setup environment\n        run: |\n          pip install --upgrade pip\n          pip install datasets pandas\n          pip install .[torch]\n\n      - name: Update metadata\n        env:\n          HF_TOKEN: ${{ secrets.SAYAK_HF_TOKEN }}\n        run: |\n          python utils/update_metadata.py --commit_sha ${{ github.sha }}\n"
  },
  {
    "path": ".github/workflows/upload_pr_documentation.yml",
    "content": "name: Upload PR Documentation\n\non:\n  workflow_run:\n    workflows: [\"Build PR Documentation\"]\n    types:\n      - completed\n\njobs:\n  build:\n    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main\n    with:\n      package_name: diffusers\n    secrets:\n      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}\n      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": "# Initially taken from GitHub's Python gitignore file\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# tests and logs\ntests/fixtures/cached_*_text.txt\nlogs/\nlightning_logs/\nlang_code_data/\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a Python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# vscode\n.vs\n.vscode\n\n# Cursor\n.cursor\n\n# Pycharm\n.idea\n\n# TF code\ntensorflow_code\n\n# Models\nproc_data\n\n# examples\nruns\n/runs_old\n/wandb\n/examples/runs\n/examples/**/*.args\n/examples/rag/sweep\n\n# data\n/data\nserialization_dir\n\n# emacs\n*.*~\ndebug.env\n\n# vim\n.*.swp\n\n# ctags\ntags\n\n# pre-commit\n.pre-commit*\n\n# .lock\n*.lock\n\n# DS_Store (MacOS)\n.DS_Store\n\n# RL 
pipelines may produce mp4 outputs\n*.mp4\n\n# dependencies\n/transformers\n\n# ruff\n.ruff_cache\n\n# wandb\nwandb\n\n# AI agent generated symlinks\n/AGENTS.md\n/CLAUDE.md\n/.agents/skills\n/.claude/skills"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\ntitle: 'Diffusers: State-of-the-art diffusion models'\nmessage: >-\n  If you use this software, please cite it using the\n  metadata from this file.\ntype: software\nauthors:\n  - given-names: Patrick\n    family-names: von Platen\n  - given-names: Suraj\n    family-names: Patil\n  - given-names: Anton\n    family-names: Lozhkov\n  - given-names: Pedro\n    family-names: Cuenca\n  - given-names: Nathan\n    family-names: Lambert\n  - given-names: Kashif\n    family-names: Rasul\n  - given-names: Mishig\n    family-names: Davaadorj\n  - given-names: Dhruv\n    family-names: Nair\n  - given-names: Sayak\n    family-names: Paul\n  - given-names: Steven\n    family-names: Liu\n  - given-names: William\n    family-names: Berman\n  - given-names: Yiyi\n    family-names: Xu\n  - given-names: Thomas\n    family-names: Wolf\nrepository-code: 'https://github.com/huggingface/diffusers'\nabstract: >-\n  Diffusers provides pretrained diffusion models across\n  multiple modalities, such as vision and audio, and serves\n  as a modular toolbox for inference and training of\n  diffusion models.\nkeywords:\n  - deep-learning\n  - pytorch\n  - image-generation\n  - hacktoberfest\n  - diffusion\n  - text2image\n  - image2image\n  - score-based-generative-modeling\n  - stable-diffusion\n  - stable-diffusion-diffusers\nlicense: Apache-2.0\nversion: 0.12.1\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "\n# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, caste, color, religion, or sexual identity\nand orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the\n  overall Diffusers community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or\n  advances of any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email\n  address, without their explicit permission\n* Spamming issues or PRs with links to projects unrelated to this library\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have 
the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki edits, issues, and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement at\nfeedback@huggingface.co.\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. Warning\n\n**Community Impact**: A violation through a single incident or series\nof actions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. 
Violating these terms may lead to a temporary or\npermanent ban.\n\n### 3. Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior,  harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within\nthe community.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.1, available at\nhttps://www.contributor-covenant.org/version/2/1/code_of_conduct.html.\n\nCommunity Impact Guidelines were inspired by [Mozilla's code of conduct\nenforcement ladder](https://github.com/mozilla/diversity).\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see the FAQ at\nhttps://www.contributor-covenant.org/faq. Translations are available at\nhttps://www.contributor-covenant.org/translations.\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, Any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include LICENSE\ninclude src/diffusers/utils/model_card_template.md\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples codex claude clean-ai\n\n# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)\nexport PYTHONPATH = src\n\ncheck_dirs := examples scripts src tests utils benchmarks\n\nmodified_only_fixup:\n\t$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))\n\t@if test -n \"$(modified_py_files)\"; then \\\n\t\techo \"Checking/fixing $(modified_py_files)\"; \\\n\t\truff check $(modified_py_files) --fix; \\\n\t\truff format $(modified_py_files);\\\n\telse \\\n\t\techo \"No library .py files were modified\"; \\\n\tfi\n\n# Update src/diffusers/dependency_versions_table.py\n\ndeps_table_update:\n\t@python setup.py deps_table_update\n\ndeps_table_check_updated:\n\t@md5sum src/diffusers/dependency_versions_table.py > md5sum.saved\n\t@python setup.py deps_table_update\n\t@md5sum -c --quiet md5sum.saved || (printf \"\\nError: the version dependency table is outdated.\\nPlease run 'make fixup' or 'make style' and commit the changes.\\n\\n\" && exit 1)\n\t@rm md5sum.saved\n\n# autogenerating code\n\nautogenerate_code: deps_table_update\n\n# Check that the repo is in a good state\n\nrepo-consistency:\n\tpython utils/check_dummies.py\n\tpython utils/check_repo.py\n\tpython utils/check_inits.py\n\n# this target runs checks on all files\n\nquality:\n\truff check $(check_dirs) setup.py\n\truff format --check $(check_dirs) setup.py\n\tdoc-builder style src/diffusers docs/source --max_len 119 --check_only\n\tpython utils/check_doc_toc.py\n\n# Format source code automatically and check if there are any problems left that need manual fixing\n\nextra_style_checks:\n\tpython utils/custom_init_isort.py\n\tpython utils/check_doc_toc.py --fix_and_overwrite\n\n# this target runs checks on all files and potentially modifies some of them\n\nstyle:\n\truff check $(check_dirs) setup.py --fix\n\truff 
format $(check_dirs) setup.py\n\tdoc-builder style src/diffusers docs/source --max_len 119\n\t${MAKE} autogenerate_code\n\t${MAKE} extra_style_checks\n\n# Super fast fix and check target that only works on relevant modified files since the branch was made\n\nfixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency\n\n# Make marked copies of snippets of codes conform to the original\n\nfix-copies:\n\tpython utils/check_copies.py --fix_and_overwrite\n\tpython utils/check_dummies.py --fix_and_overwrite\n\n# Auto docstrings in modular blocks\nmodular-autodoctrings:\n\tpython utils/modular_auto_docstring.py\n\n# Run tests for the library\n\ntest:\n\tpython -m pytest -n auto --dist=loadfile -s -v ./tests/\n\n# Run tests for examples\n\ntest-examples:\n\tpython -m pytest -n auto --dist=loadfile -s -v ./examples/\n\n\n# Release stuff\n\npre-release:\n\tpython utils/release.py\n\npre-patch:\n\tpython utils/release.py --patch\n\npost-release:\n\tpython utils/release.py --post_release\n\npost-patch:\n\tpython utils/release.py --post_release --patch\n\n# AI agent symlinks\n\ncodex:\n\tln -snf .ai/AGENTS.md AGENTS.md\n\tmkdir -p .agents\n\trm -rf .agents/skills\n\tln -snf ../.ai/skills .agents/skills\n\nclaude:\n\tln -snf .ai/AGENTS.md CLAUDE.md\n\tmkdir -p .claude\n\trm -rf .claude/skills\n\tln -snf ../.ai/skills .claude/skills\n\nclean-ai:\n\trm -f AGENTS.md CLAUDE.md\n\trm -rf .agents/skills .claude/skills\n"
  },
  {
    "path": "PHILOSOPHY.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Philosophy\n\n🧨 Diffusers provides **state-of-the-art** pretrained diffusion models across multiple modalities.\nIts purpose is to serve as a **modular toolbox** for both inference and training.\n\nWe aim to build a library that stands the test of time and therefore take API design very seriously.\n\nIn a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefore, most of our design choices are based on [PyTorch's Design Principles](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy). Let's go over the most important ones:\n\n## Usability over Performance\n\n- While Diffusers has many built-in performance-enhancing features (see [Memory and Speed](https://huggingface.co/docs/diffusers/optimization/fp16)), models are always loaded with the highest precision and lowest optimization. Therefore, by default diffusion pipelines are always instantiated on CPU with float32 precision if not otherwise defined by the user. This ensures usability across different platforms and accelerators and means that no complex installations are required to run the library.\n- Diffusers aims to be a **light-weight** package and therefore has very few required dependencies, but many soft dependencies that can improve performance (such as `accelerate`, `safetensors`, `onnx`, etc...). 
We strive to keep the library as lightweight as possible so that it can be added without much concern as a dependency of other packages.\n- Diffusers prefers simple, self-explanatory code over condensed, magic code. This means that short-hand code syntaxes such as lambda functions, and advanced PyTorch operators are often not desired.\n\n## Simple over easy\n\nAs PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:\n- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management.\n- Raising concise error messages is preferred to silently correcting erroneous input. Diffusers aims at teaching the user, rather than making the library as easy to use as possible.\n- Complex model vs. scheduler logic is exposed instead of magically handled inside. Schedulers/Samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation allows for easier debugging and gives the user more control over adapting the denoising process or switching out diffusion models or schedulers.\n- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the UNet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. 
DreamBooth or Textual Inversion training\nis very simple thanks to Diffusers' ability to separate single components of the diffusion pipeline.\n\n## Tweakable, contributor-friendly over abstraction\n\nFor large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).\nIn short, just like Transformers does for modeling files, Diffusers prefers to keep an extremely low level of abstraction and very self-contained code for pipelines and schedulers.\nFunctions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable.\n**However**, this design has proven to be extremely successful for Transformers and makes a lot of sense for community-driven, open-source machine learning libraries because:\n- Machine Learning is an extremely fast-moving field in which paradigms, model architectures, and algorithms are changing rapidly, which therefore makes it very difficult to define long-lasting code abstractions.\n- Machine Learning practitioners like to be able to quickly tweak existing code for ideation and research and therefore prefer self-contained code over one that contains many abstractions.\n- Open-source libraries rely on community contributions and therefore must build a library that is easy to contribute to. The more abstract the code, the more dependencies, the harder to read, and the harder to contribute to. Contributors simply stop contributing to very abstract libraries out of fear of breaking vital functionality. 
If contributing to a library cannot break other fundamental code, not only is it more inviting for potential new contributors, but it is also easier to review and contribute to multiple parts in parallel.\n\nAt Hugging Face, we call this design the **single-file policy** which means that almost all of the code of a certain class should be written in a single, self-contained file. To read more about the philosophy, you can have a look\nat [this blog post](https://huggingface.co/blog/transformers-design-philosophy).\n\nIn Diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such\nas [DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip) and [Imagen](https://imagen.research.google/) all rely on the same diffusion model, the [UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond).\n\nGreat, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗.\nWe try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy or some unlucky design choices. If you have feedback regarding the design, we would ❤️  to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).\n\n## Design Philosophy in Details\n\nNow, let's look a bit into the nitty-gritty details of the design philosophy. 
Diffusers essentially consists of three major classes: [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), and [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).\nLet's walk through more detailed design decisions for each class.\n\n### Pipelines\n\nPipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.\n\nThe following design principles are followed:\n- Pipelines follow the single-file policy. All pipelines can be found in individual directories under `src/diffusers/pipelines`. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as is done for [`src/diffusers/pipelines/stable_diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). 
If pipelines share similar functionality, one can make use of the [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).\n- Pipelines all inherit from [`DiffusionPipeline`].\n- Every pipeline consists of different model and scheduler components, that are documented in the [`model_index.json` file](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same name as attributes of the pipeline and can be shared between pipelines with [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function.\n- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function.\n- Pipelines should be used **only** for inference.\n- Pipelines should be very readable, self-explanatory, and easy to tweak.\n- Pipelines should be designed to build on top of each other and be easy to integrate into higher-level APIs.\n- Pipelines are **not** intended to be feature-complete user interfaces. For feature-complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).\n- Every pipeline should have one and only one way to run it via a `__call__` method. 
The naming of the `__call__` arguments should be shared across all pipelines.\n- Pipelines should be named after the task they are intended to solve.\n- In almost all cases, novel diffusion pipelines shall be implemented in a new pipeline folder/file.\n\n### Models\n\nModels are designed as configurable toolboxes that are natural extensions of [PyTorch's Module class](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). They only partly follow the **single-file policy**.\n\nThe following design principles are followed:\n- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.\n- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py), [`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py), etc...\n- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... 
**Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.\n- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.\n- Models all inherit from `ModelMixin` and `ConfigMixin`.\n- Models can be optimized for performance when the optimization doesn’t demand major code changes, keeps backward compatibility, and gives a significant memory or compute gain.\n- Models should by default have the highest precision and lowest performance setting.\n- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.\n- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and \"foreseeing\" future changes, *e.g.* it is usually better to add `string` \"...type\" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.\n- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. 
For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and\nreadable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n### Schedulers\n\nSchedulers are responsible for guiding the denoising process for inference as well as for defining a noise schedule for training. They are designed as individual classes with loadable configuration files and strongly follow the **single-file policy**.\n\nThe following design principles are followed:\n- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).\n- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.\n- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).\n- If schedulers share similar functionalities, we can make use of the `# Copied from` mechanism.\n- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.\n- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](./docs/source/en/using-diffusers/schedulers.md).\n- Every scheduler has to have a `set_timesteps` and a `step` function. 
`set_timesteps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called.\n- Every scheduler exposes the timesteps to be \"looped over\" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.\n- The `step(...)` function takes a predicted model output and the \"current\" sample (x_t) and returns the \"previous\", slightly more denoised sample (x_t-1).\n- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a \"black box\".\n- In almost all cases, novel schedulers shall be implemented in a new scheduling file.\n"
  },
  {
    "path": "README.md",
    "content": "<!---\nCopyright 2022 - The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/en/imgs/diffusers_library.jpg\" width=\"400\"/>\n    <br>\n</p>\n<p align=\"center\">\n    <a href=\"https://github.com/huggingface/diffusers/blob/main/LICENSE\"><img alt=\"GitHub\" src=\"https://img.shields.io/github/license/huggingface/datasets.svg?color=blue\"></a>\n    <a href=\"https://github.com/huggingface/diffusers/releases\"><img alt=\"GitHub release\" src=\"https://img.shields.io/github/release/huggingface/diffusers.svg\"></a>\n    <a href=\"https://pepy.tech/project/diffusers\"><img alt=\"Downloads\" src=\"https://static.pepy.tech/badge/diffusers/month\"></a>\n    <a href=\"CODE_OF_CONDUCT.md\"><img alt=\"Contributor Covenant\" src=\"https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg\"></a>\n    <a href=\"https://twitter.com/diffuserslib\"><img alt=\"X account\" src=\"https://img.shields.io/twitter/url/https/twitter.com/diffuserslib.svg?style=social&label=Follow%20%40diffuserslib\"></a>\n</p>\n\n🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. 
Our library is designed with a focus on [usability over performance](https://huggingface.co/docs/diffusers/conceptual/philosophy#usability-over-performance), [simple over easy](https://huggingface.co/docs/diffusers/conceptual/philosophy#simple-over-easy), and [customizability over abstractions](https://huggingface.co/docs/diffusers/conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).\n\n🤗 Diffusers offers three core components:\n\n- State-of-the-art [diffusion pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) that can be run in inference with just a few lines of code.\n- Interchangeable noise [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview) for different diffusion speeds and output quality.\n- Pretrained [models](https://huggingface.co/docs/diffusers/api/models/overview) that can be used as building blocks, and combined with schedulers, for creating your own end-to-end diffusion systems.\n\n## Installation\n\nWe recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/), please refer to their official documentation.\n\n### PyTorch\n\nWith `pip` (official package):\n\n```bash\npip install --upgrade diffusers[torch]\n```\n\nWith `conda` (maintained by the community):\n\n```sh\nconda install -c conda-forge diffusers\n```\n\n### Apple Silicon (M1/M2) support\n\nPlease refer to the [How to use Stable Diffusion in Apple Silicon](https://huggingface.co/docs/diffusers/optimization/mps) guide.\n\n## Quickstart\n\nGenerating outputs is super easy with 🤗 Diffusers. 
To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 30,000+ checkpoints):\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\npipeline.to(\"cuda\")\npipeline(\"An image of a squirrel in Picasso style\").images[0]\n```\n\nYou can also dig into the models and schedulers toolbox to build your own diffusion system:\n\n```python\nfrom diffusers import DDPMScheduler, UNet2DModel\nfrom PIL import Image\nimport torch\n\nscheduler = DDPMScheduler.from_pretrained(\"google/ddpm-cat-256\")\nmodel = UNet2DModel.from_pretrained(\"google/ddpm-cat-256\").to(\"cuda\")\nscheduler.set_timesteps(50)\n\nsample_size = model.config.sample_size\nnoise = torch.randn((1, 3, sample_size, sample_size), device=\"cuda\")\ninput = noise\n\nfor t in scheduler.timesteps:\n    with torch.no_grad():\n        noisy_residual = model(input, t).sample\n        prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample\n        input = prev_noisy_sample\n\nimage = (input / 2 + 0.5).clamp(0, 1)\nimage = image.cpu().permute(0, 2, 3, 1).numpy()[0]\nimage = Image.fromarray((image * 255).round().astype(\"uint8\"))\nimage\n```\n\nCheck out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to launch your diffusion journey today!\n\n## How to navigate the documentation\n\n| **Documentation**                                                   | **What can I learn?**                                                                                                                                                                           
|\n|---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| [Tutorial](https://huggingface.co/docs/diffusers/tutorials/tutorial_overview)                                                            | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model.  |\n| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading)                                                             | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers.                                         |\n| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/overview_techniques)                                             | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library.               |\n| [Optimization](https://huggingface.co/docs/diffusers/optimization/fp16)                                                        | Guides for how to optimize your diffusion model to run faster and consume less memory.                                                                                                          |\n| [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques.                                                                                               
|\n\n## Contribution\n\nWe ❤️ contributions from the open-source community!\nIf you want to contribute to this library, please check out our [Contribution guide](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md).\nYou can look out for [issues](https://github.com/huggingface/diffusers/issues) you'd like to tackle to contribute to the library.\n- See [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) for general opportunities to contribute\n- See [New model/pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) to contribute exciting new diffusion models / diffusion pipelines\n- See [New scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) to contribute a new scheduler\n\nAlso, say 👋 in our public Discord channel <a href=\"https://discord.gg/G7tWnz98XR\"><img alt=\"Join us on Discord\" src=\"https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white\"></a>. 
We discuss the hottest trends about diffusion models, help each other with contributions, personal projects or just hang out ☕.\n\n\n## Popular Tasks & Pipelines\n\n<table>\n  <tr>\n    <th>Task</th>\n    <th>Pipeline</th>\n    <th>🤗 Hub</th>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td>Unconditional Image Generation</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/ddpm\"> DDPM </a></td>\n    <td><a href=\"https://huggingface.co/google/ddpm-ema-church-256\"> google/ddpm-ema-church-256 </a></td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td>Text-to-Image</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img\">Stable Diffusion Text-to-Image</a></td>\n      <td><a href=\"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5\"> stable-diffusion-v1-5/stable-diffusion-v1-5 </a></td>\n  </tr>\n  <tr>\n    <td>Text-to-Image</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/unclip\">unCLIP</a></td>\n      <td><a href=\"https://huggingface.co/kakaobrain/karlo-v1-alpha\"> kakaobrain/karlo-v1-alpha </a></td>\n  </tr>\n  <tr>\n    <td>Text-to-Image</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/deepfloyd_if\">DeepFloyd IF</a></td>\n      <td><a href=\"https://huggingface.co/DeepFloyd/IF-I-XL-v1.0\"> DeepFloyd/IF-I-XL-v1.0 </a></td>\n  </tr>\n  <tr>\n    <td>Text-to-Image</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/kandinsky\">Kandinsky</a></td>\n      <td><a href=\"https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder\"> kandinsky-community/kandinsky-2-2-decoder </a></td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td>Text-guided Image-to-Image</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/controlnet\">ControlNet</a></td>\n      <td><a href=\"https://huggingface.co/lllyasviel/sd-controlnet-canny\"> 
lllyasviel/sd-controlnet-canny </a></td>\n  </tr>\n  <tr>\n    <td>Text-guided Image-to-Image</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/pix2pix\">InstructPix2Pix</a></td>\n      <td><a href=\"https://huggingface.co/timbrooks/instruct-pix2pix\"> timbrooks/instruct-pix2pix </a></td>\n  </tr>\n  <tr>\n    <td>Text-guided Image-to-Image</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img\">Stable Diffusion Image-to-Image</a></td>\n      <td><a href=\"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5\"> stable-diffusion-v1-5/stable-diffusion-v1-5 </a></td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td>Text-guided Image Inpainting</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint\">Stable Diffusion Inpainting</a></td>\n      <td><a href=\"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting\"> stable-diffusion-v1-5/stable-diffusion-inpainting </a></td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td>Image Variation</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/image_variation\">Stable Diffusion Image Variation</a></td>\n      <td><a href=\"https://huggingface.co/lambdalabs/sd-image-variations-diffusers\"> lambdalabs/sd-image-variations-diffusers </a></td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td>Super Resolution</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/upscale\">Stable Diffusion Upscale</a></td>\n      <td><a href=\"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler\"> stabilityai/stable-diffusion-x4-upscaler </a></td>\n  </tr>\n  <tr>\n    <td>Super Resolution</td>\n    <td><a href=\"https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_upscale\">Stable Diffusion Latent Upscale</a></td>\n      <td><a 
href=\"https://huggingface.co/stabilityai/sd-x2-latent-upscaler\"> stabilityai/sd-x2-latent-upscaler </a></td>\n  </tr>\n</table>\n\n## Popular libraries using 🧨 Diffusers\n\n- https://github.com/microsoft/TaskMatrix\n- https://github.com/invoke-ai/InvokeAI\n- https://github.com/InstantID/InstantID\n- https://github.com/apple/ml-stable-diffusion\n- https://github.com/Sanster/lama-cleaner\n- https://github.com/IDEA-Research/Grounded-Segment-Anything\n- https://github.com/ashawkey/stable-dreamfusion\n- https://github.com/deep-floyd/IF\n- https://github.com/bentoml/BentoML\n- https://github.com/bmaltais/kohya_ss\n- +14,000 other amazing GitHub repositories 💪\n\nThank you for using us ❤️.\n\n## Credits\n\nThis library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished today:\n\n- @CompVis' latent diffusion models library, available [here](https://github.com/CompVis/latent-diffusion)\n- @hojonathanho original DDPM implementation, available [here](https://github.com/hojonathanho/diffusion) as well as the extremely useful translation into PyTorch by @pesser, available [here](https://github.com/pesser/pytorch_diffusion)\n- @ermongroup's DDIM implementation, available [here](https://github.com/ermongroup/ddim)\n- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch)\n\nWe also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models) as well as @crowsonkb and @rromb for useful discussions and insights.\n\n## Citation\n\n```bibtex\n@misc{von-platen-etal-2022-diffusers,\n  author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and 
Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Dhruv Nair and Sayak Paul and William Berman and Yiyi Xu and Steven Liu and Thomas Wolf},\n  title = {Diffusers: State-of-the-art diffusion models},\n  year = {2022},\n  publisher = {GitHub},\n  journal = {GitHub repository},\n  howpublished = {\\url{https://github.com/huggingface/diffusers}}\n}\n```\n"
  },
  {
    "path": "_typos.toml",
    "content": "# Files for typos\n# Instruction:  https://github.com/marketplace/actions/typos-action#getting-started\n\n[default.extend-identifiers]\n\n[default.extend-words]\nNIN=\"NIN\" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py\nnd=\"np\" # nd may be np (numpy)\nparms=\"parms\" # parms is used in scripts/convert_original_stable_diffusion_to_diffusers.py\n\n\n[files]\nextend-exclude = [\"_typos.toml\"]\n"
  },
  {
    "path": "benchmarks/README.md",
    "content": "# Diffusers Benchmarks\n\nWelcome to Diffusers Benchmarks. These benchmarks are used to obtain latency and memory information of the most popular models across different scenarios such as:\n\n* Base case i.e., when using `torch.bfloat16` and `torch.nn.functional.scaled_dot_product_attention`.\n* Base + `torch.compile()`\n* NF4 quantization\n* Layerwise upcasting\n\nInstead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is tested with the real checkpoints (such as `\"black-forest-labs/FLUX.1-dev\"`).\n\nThe entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, e.g., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run.\n\nThe benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml).\n\n## Running the benchmarks manually\n\nFirst set up `torch` and install `diffusers` from the root of the repository:\n\n```sh\npip install -e \".[quality,test]\"\n```\n\nThen make sure the other dependencies are installed:\n\n```sh\ncd benchmarks/\npip install -r requirements.txt\n```\n\nWe need to be authenticated to access some of the checkpoints used during benchmarking:\n\n```sh\nhf auth login\n```\n\nWe use an L40 GPU with 128GB RAM to run the benchmark CI. As such, the benchmarks are configured to run on NVIDIA GPUs. So, make sure you have access to a similar machine (or modify the benchmarking scripts accordingly).\n\nThen you can either launch the entire benchmarking suite by running:\n\n```sh\npython run_all.py\n```\n\nOr, you can run the individual benchmarks.\n\n## Customizing the benchmarks\n\nWe define \"scenarios\" to cover the most common ways in which these models are used. 
You can\ndefine a new scenario by modifying an existing benchmark file:\n\n```py\nBenchmarkScenario(\n    name=f\"{CKPT_ID}-bnb-8bit\",\n    model_cls=FluxTransformer2DModel,\n    model_init_kwargs={\n        \"pretrained_model_name_or_path\": CKPT_ID,\n        \"torch_dtype\": torch.bfloat16,\n        \"subfolder\": \"transformer\",\n        \"quantization_config\": BitsAndBytesConfig(load_in_8bit=True),\n    },\n    get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n    model_init_fn=model_init_fn,\n)\n```\n\nYou can also configure a new model-level benchmark and add it to the existing suite. To do so, defining a valid benchmarking file like `benchmarking_flux.py` should be enough.\n\nHappy benchmarking 🧨"
  },
  {
    "path": "benchmarks/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/benchmarking_flux.py",
    "content": "from functools import partial\n\nimport torch\nfrom benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn\n\nfrom diffusers import BitsAndBytesConfig, FluxTransformer2DModel\nfrom diffusers.utils.testing_utils import torch_device\n\n\nCKPT_ID = \"black-forest-labs/FLUX.1-dev\"\nRESULT_FILENAME = \"flux.csv\"\n\n\ndef get_input_dict(**device_dtype_kwargs):\n    # resolution: 1024x1024\n    # maximum sequence length 512\n    hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)\n    encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)\n    pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)\n    image_ids = torch.ones(512, 3, **device_dtype_kwargs)\n    text_ids = torch.ones(4096, 3, **device_dtype_kwargs)\n    timestep = torch.tensor([1.0], **device_dtype_kwargs)\n    guidance = torch.tensor([1.0], **device_dtype_kwargs)\n\n    return {\n        \"hidden_states\": hidden_states,\n        \"encoder_hidden_states\": encoder_hidden_states,\n        \"img_ids\": image_ids,\n        \"txt_ids\": text_ids,\n        \"pooled_projections\": pooled_prompt_embeds,\n        \"timestep\": timestep,\n        \"guidance\": guidance,\n    }\n\n\nif __name__ == \"__main__\":\n    scenarios = [\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-bf16\",\n            model_cls=FluxTransformer2DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=model_init_fn,\n            compile_kwargs={\"fullgraph\": True},\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-bnb-nf4\",\n            model_cls=FluxTransformer2DModel,\n            model_init_kwargs={\n                
\"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n                \"quantization_config\": BitsAndBytesConfig(\n                    load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type=\"nf4\"\n                ),\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=model_init_fn,\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-layerwise-upcasting\",\n            model_cls=FluxTransformer2DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-group-offload-leaf\",\n            model_cls=FluxTransformer2DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=partial(\n                model_init_fn,\n                group_offload_kwargs={\n                    \"onload_device\": torch_device,\n                    \"offload_device\": torch.device(\"cpu\"),\n                    \"offload_type\": \"leaf_level\",\n                    \"use_stream\": True,\n                    \"non_blocking\": True,\n                },\n            ),\n        ),\n    ]\n\n    runner = BenchmarkMixin()\n    runner.run_bencmarks_and_collate(scenarios, 
filename=RESULT_FILENAME)\n"
  },
  {
    "path": "benchmarks/benchmarking_ltx.py",
    "content": "from functools import partial\n\nimport torch\nfrom benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn\n\nfrom diffusers import LTXVideoTransformer3DModel\nfrom diffusers.utils.testing_utils import torch_device\n\n\nCKPT_ID = \"Lightricks/LTX-Video-0.9.7-dev\"\nRESULT_FILENAME = \"ltx.csv\"\n\n\ndef get_input_dict(**device_dtype_kwargs):\n    # 512x704 (161 frames)\n    # `max_sequence_length`: 256\n    hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs)\n    encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs)\n    encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs)\n    timestep = torch.tensor([1.0], **device_dtype_kwargs)\n    video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs)\n\n    return {\n        \"hidden_states\": hidden_states,\n        \"encoder_hidden_states\": encoder_hidden_states,\n        \"encoder_attention_mask\": encoder_attention_mask,\n        \"timestep\": timestep,\n        \"video_coords\": video_coords,\n    }\n\n\nif __name__ == \"__main__\":\n    scenarios = [\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-bf16\",\n            model_cls=LTXVideoTransformer3DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=model_init_fn,\n            compile_kwargs={\"fullgraph\": True},\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-layerwise-upcasting\",\n            model_cls=LTXVideoTransformer3DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-group-offload-leaf\",\n            model_cls=LTXVideoTransformer3DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=partial(\n                model_init_fn,\n                group_offload_kwargs={\n                    \"onload_device\": torch_device,\n                    \"offload_device\": torch.device(\"cpu\"),\n                    \"offload_type\": \"leaf_level\",\n                    \"use_stream\": True,\n                    \"non_blocking\": True,\n                },\n            ),\n        ),\n    ]\n\n    runner = BenchmarkMixin()\n    runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)\n"
  },
  {
    "path": "benchmarks/benchmarking_sdxl.py",
    "content": "from functools import partial\n\nimport torch\nfrom benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn\n\nfrom diffusers import UNet2DConditionModel\nfrom diffusers.utils.testing_utils import torch_device\n\n\nCKPT_ID = \"stabilityai/stable-diffusion-xl-base-1.0\"\nRESULT_FILENAME = \"sdxl.csv\"\n\n\ndef get_input_dict(**device_dtype_kwargs):\n    # height: 1024\n    # width: 1024\n    # max_sequence_length: 77\n    hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs)\n    encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs)\n    timestep = torch.tensor([1.0], **device_dtype_kwargs)\n    added_cond_kwargs = {\n        \"text_embeds\": torch.randn(1, 1280, **device_dtype_kwargs),\n        \"time_ids\": torch.ones(1, 6, **device_dtype_kwargs),\n    }\n\n    return {\n        \"sample\": hidden_states,\n        \"encoder_hidden_states\": encoder_hidden_states,\n        \"timestep\": timestep,\n        \"added_cond_kwargs\": added_cond_kwargs,\n    }\n\n\nif __name__ == \"__main__\":\n    scenarios = [\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-bf16\",\n            model_cls=UNet2DConditionModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"unet\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=model_init_fn,\n            compile_kwargs={\"fullgraph\": True},\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-layerwise-upcasting\",\n            model_cls=UNet2DConditionModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"unet\",\n            },\n            get_model_input_dict=partial(get_input_dict, 
device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-group-offload-leaf\",\n            model_cls=UNet2DConditionModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"unet\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=partial(\n                model_init_fn,\n                group_offload_kwargs={\n                    \"onload_device\": torch_device,\n                    \"offload_device\": torch.device(\"cpu\"),\n                    \"offload_type\": \"leaf_level\",\n                    \"use_stream\": True,\n                    \"non_blocking\": True,\n                },\n            ),\n        ),\n    ]\n\n    runner = BenchmarkMixin()\n    runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)\n"
  },
  {
    "path": "benchmarks/benchmarking_utils.py",
    "content": "import gc\nimport inspect\nimport logging\nimport os\nimport queue\nimport threading\nfrom contextlib import nullcontext\nfrom dataclasses import dataclass\nfrom typing import Any, Callable\n\nimport pandas as pd\nimport torch\nimport torch.utils.benchmark as benchmark\n\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.utils.testing_utils import require_torch_gpu, torch_device\n\n\nlogging.basicConfig(level=logging.INFO, format=\"%(asctime)s %(levelname)s %(name)s: %(message)s\")\nlogger = logging.getLogger(__name__)\n\nNUM_WARMUP_ROUNDS = 5\n\n\ndef benchmark_fn(f, *args, **kwargs):\n    t0 = benchmark.Timer(\n        stmt=\"f(*args, **kwargs)\",\n        globals={\"args\": args, \"kwargs\": kwargs, \"f\": f},\n        num_threads=1,\n    )\n    return float(f\"{(t0.blocked_autorange().mean):.3f}\")\n\n\ndef flush():\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.reset_max_memory_allocated()\n    torch.cuda.reset_peak_memory_stats()\n\n\n# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py\ndef calculate_flops(model, input_dict):\n    try:\n        from torchprofile import profile_macs\n    except ModuleNotFoundError:\n        raise\n\n    # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.\n    sig = inspect.signature(model.forward)\n    param_names = [\n        p.name\n        for p in sig.parameters.values()\n        if p.kind\n        in (\n            inspect.Parameter.POSITIONAL_ONLY,\n            inspect.Parameter.POSITIONAL_OR_KEYWORD,\n        )\n        and p.name != \"self\"\n    ]\n    bound = sig.bind_partial(**input_dict)\n    bound.apply_defaults()\n    args = tuple(bound.arguments[name] for name in param_names)\n\n    model.eval()\n    with torch.no_grad():\n        macs = profile_macs(model, args)\n    flops = 2 * macs  # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)\n    return 
flops\n\n\ndef calculate_params(model):\n    return sum(p.numel() for p in model.parameters())\n\n\n# Users can define their own in case this doesn't suffice. For most cases,\n# it should be sufficient.\ndef model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):\n    model = model_cls.from_pretrained(**init_kwargs).eval()\n    if group_offload_kwargs and isinstance(group_offload_kwargs, dict):\n        model.enable_group_offload(**group_offload_kwargs)\n    else:\n        model.to(torch_device)\n    if layerwise_upcasting:\n        model.enable_layerwise_casting(\n            storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get(\"torch_dtype\", torch.bfloat16)\n        )\n    return model\n\n\n@dataclass\nclass BenchmarkScenario:\n    name: str\n    model_cls: ModelMixin\n    model_init_kwargs: dict[str, Any]\n    model_init_fn: Callable\n    get_model_input_dict: Callable\n    compile_kwargs: dict[str, Any] | None = None\n\n\n@require_torch_gpu\nclass BenchmarkMixin:\n    def pre_benchmark(self):\n        flush()\n        torch.compiler.reset()\n\n    def post_benchmark(self, model):\n        model.cpu()\n        flush()\n        torch.compiler.reset()\n\n    @torch.no_grad()\n    def run_benchmark(self, scenario: BenchmarkScenario):\n        # 0) Basic stats\n        logger.info(f\"Running scenario: {scenario.name}.\")\n        try:\n            model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)\n            num_params = round(calculate_params(model) / 1e9, 2)\n            try:\n                flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)\n            except Exception as e:\n                logger.info(f\"Problem in calculating FLOPs:\\n{e}\")\n                flops = None\n            model.cpu()\n            del model\n        except Exception as e:\n            logger.info(f\"Error while initializing the model and calculating FLOPs:\\n{e}\")\n 
           return {}\n        self.pre_benchmark()\n\n        # 1) plain stats\n        results = {}\n        plain = None\n        try:\n            plain = self._run_phase(\n                model_cls=scenario.model_cls,\n                init_fn=scenario.model_init_fn,\n                init_kwargs=scenario.model_init_kwargs,\n                get_input_fn=scenario.get_model_input_dict,\n                compile_kwargs=None,\n            )\n        except Exception as e:\n            logger.info(f\"Benchmark could not be run with the following error:\\n{e}\")\n            return results\n\n        # 2) compiled stats (if any)\n        compiled = {\"time\": None, \"memory\": None}\n        if scenario.compile_kwargs:\n            try:\n                compiled = self._run_phase(\n                    model_cls=scenario.model_cls,\n                    init_fn=scenario.model_init_fn,\n                    init_kwargs=scenario.model_init_kwargs,\n                    get_input_fn=scenario.get_model_input_dict,\n                    compile_kwargs=scenario.compile_kwargs,\n                )\n            except Exception as e:\n                logger.info(f\"Compilation benchmark could not be run with the following error:\\n{e}\")\n                if plain is None:\n                    return results\n\n        # 3) merge\n        result = {\n            \"scenario\": scenario.name,\n            \"model_cls\": scenario.model_cls.__name__,\n            \"num_params_B\": num_params,\n            \"flops_G\": flops,\n            \"time_plain_s\": plain[\"time\"],\n            \"mem_plain_GB\": plain[\"memory\"],\n            \"time_compile_s\": compiled[\"time\"],\n            \"mem_compile_GB\": compiled[\"memory\"],\n        }\n        if scenario.compile_kwargs:\n            result[\"fullgraph\"] = scenario.compile_kwargs.get(\"fullgraph\", False)\n            result[\"mode\"] = scenario.compile_kwargs.get(\"mode\", \"default\")\n        else:\n            
result[\"fullgraph\"], result[\"mode\"] = None, None\n        return result\n\n    def run_bencmarks_and_collate(self, scenarios: BenchmarkScenario | list[BenchmarkScenario], filename: str):\n        if not isinstance(scenarios, list):\n            scenarios = [scenarios]\n        record_queue = queue.Queue()\n        stop_signal = object()\n\n        def _writer_thread():\n            while True:\n                item = record_queue.get()\n                if item is stop_signal:\n                    break\n                df_row = pd.DataFrame([item])\n                write_header = not os.path.exists(filename)\n                df_row.to_csv(filename, mode=\"a\", header=write_header, index=False)\n                record_queue.task_done()\n\n            record_queue.task_done()\n\n        writer = threading.Thread(target=_writer_thread, daemon=True)\n        writer.start()\n\n        for s in scenarios:\n            try:\n                record = self.run_benchmark(s)\n                if record:\n                    record_queue.put(record)\n                else:\n                    logger.info(f\"Record empty from scenario: {s.name}.\")\n            except Exception as e:\n                logger.info(f\"Running scenario ({s.name}) led to error:\\n{e}\")\n        record_queue.put(stop_signal)\n        # Wait for the daemon writer to flush every queued row before reporting success.\n        writer.join()\n        logger.info(f\"Results serialized to {filename=}.\")\n\n    def _run_phase(\n        self,\n        *,\n        model_cls: ModelMixin,\n        init_fn: Callable,\n        init_kwargs: dict[str, Any],\n        get_input_fn: Callable,\n        compile_kwargs: dict[str, Any] | None = None,\n    ) -> dict[str, float]:\n        # setup\n        self.pre_benchmark()\n\n        # init & (optional) compile\n        model = init_fn(model_cls, **init_kwargs)\n        if compile_kwargs:\n            model.compile(**compile_kwargs)\n\n        # build inputs\n        inp = get_input_fn()\n\n        # measure\n        run_ctx = torch._inductor.utils.fresh_inductor_cache() if 
compile_kwargs else nullcontext()\n        with run_ctx:\n            for _ in range(NUM_WARMUP_ROUNDS):\n                _ = model(**inp)\n            time_s = benchmark_fn(lambda m, d: m(**d), model, inp)\n        mem_gb = torch.cuda.max_memory_allocated() / (1024**3)\n        mem_gb = round(mem_gb, 2)\n\n        # teardown\n        self.post_benchmark(model)\n        del model\n        return {\"time\": time_s, \"memory\": mem_gb}\n"
  },
  {
    "path": "benchmarks/benchmarking_wan.py",
    "content": "from functools import partial\n\nimport torch\nfrom benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn\n\nfrom diffusers import WanTransformer3DModel\nfrom diffusers.utils.testing_utils import torch_device\n\n\nCKPT_ID = \"Wan-AI/Wan2.1-T2V-14B-Diffusers\"\nRESULT_FILENAME = \"wan.csv\"\n\n\ndef get_input_dict(**device_dtype_kwargs):\n    # height: 480\n    # width: 832\n    # num_frames: 81\n    # max_sequence_length: 512\n    hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)\n    encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)\n    timestep = torch.tensor([1.0], **device_dtype_kwargs)\n\n    return {\"hidden_states\": hidden_states, \"encoder_hidden_states\": encoder_hidden_states, \"timestep\": timestep}\n\n\nif __name__ == \"__main__\":\n    scenarios = [\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-bf16\",\n            model_cls=WanTransformer3DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=model_init_fn,\n            compile_kwargs={\"fullgraph\": True},\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-layerwise-upcasting\",\n            model_cls=WanTransformer3DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=partial(model_init_fn, layerwise_upcasting=True),\n        ),\n        BenchmarkScenario(\n            name=f\"{CKPT_ID}-group-offload-leaf\",\n     
       model_cls=WanTransformer3DModel,\n            model_init_kwargs={\n                \"pretrained_model_name_or_path\": CKPT_ID,\n                \"torch_dtype\": torch.bfloat16,\n                \"subfolder\": \"transformer\",\n            },\n            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),\n            model_init_fn=partial(\n                model_init_fn,\n                group_offload_kwargs={\n                    \"onload_device\": torch_device,\n                    \"offload_device\": torch.device(\"cpu\"),\n                    \"offload_type\": \"leaf_level\",\n                    \"use_stream\": True,\n                    \"non_blocking\": True,\n                },\n            ),\n        ),\n    ]\n\n    runner = BenchmarkMixin()\n    runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)\n"
  },
  {
    "path": "benchmarks/push_results.py",
    "content": "import os\n\nimport pandas as pd\nfrom huggingface_hub import hf_hub_download, upload_file\nfrom huggingface_hub.utils import EntryNotFoundError\n\n\nREPO_ID = \"diffusers/benchmarks\"\n\n\ndef has_previous_benchmark() -> str:\n    from run_all import FINAL_CSV_FILENAME\n\n    csv_path = None\n    try:\n        csv_path = hf_hub_download(repo_id=REPO_ID, repo_type=\"dataset\", filename=FINAL_CSV_FILENAME)\n    except EntryNotFoundError:\n        csv_path = None\n    return csv_path\n\n\ndef filter_float(value):\n    if isinstance(value, str):\n        return float(value.split()[0])\n    return value\n\n\ndef push_to_hf_dataset():\n    from run_all import FINAL_CSV_FILENAME, GITHUB_SHA\n\n    csv_path = has_previous_benchmark()\n    if csv_path is not None:\n        current_results = pd.read_csv(FINAL_CSV_FILENAME)\n        previous_results = pd.read_csv(csv_path)\n\n        numeric_columns = current_results.select_dtypes(include=[\"float64\", \"int64\"]).columns\n\n        for column in numeric_columns:\n            # get previous values as floats, aligned to current index\n            prev_vals = previous_results[column].map(filter_float).reindex(current_results.index)\n\n            # get current values as floats\n            curr_vals = current_results[column].astype(float)\n\n            # stringify the current values\n            curr_str = curr_vals.map(str)\n\n            # build an appendage only when prev exists and differs\n            append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(\n                lambda x: f\" ({x})\" if pd.notnull(x) else \"\"\n            )\n\n            # combine\n            current_results[column] = curr_str + append_str\n        os.remove(FINAL_CSV_FILENAME)\n        current_results.to_csv(FINAL_CSV_FILENAME, index=False)\n\n    commit_message = f\"upload from sha: {GITHUB_SHA}\" if GITHUB_SHA is not None else \"upload benchmark results\"\n    upload_file(\n        
repo_id=REPO_ID,\n        path_in_repo=FINAL_CSV_FILENAME,\n        path_or_fileobj=FINAL_CSV_FILENAME,\n        repo_type=\"dataset\",\n        commit_message=commit_message,\n    )\n    upload_file(\n        repo_id=\"diffusers/benchmark-analyzer\",\n        path_in_repo=FINAL_CSV_FILENAME,\n        path_or_fileobj=FINAL_CSV_FILENAME,\n        repo_type=\"space\",\n        commit_message=commit_message,\n    )\n\n\nif __name__ == \"__main__\":\n    push_to_hf_dataset()\n"
  },
  {
    "path": "benchmarks/requirements.txt",
    "content": "pandas \npsutil\ngpustat\ntorchprofile\nbitsandbytes\npsycopg2==2.9.9"
  },
  {
    "path": "benchmarks/run_all.py",
    "content": "import glob\nimport logging\nimport os\nimport subprocess\n\nimport pandas as pd\n\n\nlogging.basicConfig(level=logging.INFO, format=\"%(asctime)s %(levelname)s %(name)s: %(message)s\")\nlogger = logging.getLogger(__name__)\n\nPATTERN = \"benchmarking_*.py\"\nFINAL_CSV_FILENAME = \"collated_results.csv\"\nGITHUB_SHA = os.getenv(\"GITHUB_SHA\", None)\n\n\nclass SubprocessCallException(Exception):\n    pass\n\n\ndef run_command(command: list[str], return_stdout=False):\n    try:\n        output = subprocess.check_output(command, stderr=subprocess.STDOUT)\n        if return_stdout and hasattr(output, \"decode\"):\n            return output.decode(\"utf-8\")\n    except subprocess.CalledProcessError as e:\n        raise SubprocessCallException(f\"Command `{' '.join(command)}` failed with:\\n{e.output.decode()}\") from e\n\n\ndef merge_csvs(final_csv: str = \"collated_results.csv\"):\n    all_csvs = glob.glob(\"*.csv\")\n    all_csvs = [f for f in all_csvs if f != final_csv]\n    if not all_csvs:\n        logger.info(\"No result CSVs found to merge.\")\n        return\n\n    df_list = []\n    for f in all_csvs:\n        try:\n            d = pd.read_csv(f)\n        except pd.errors.EmptyDataError:\n            # If a file existed but was zero‐bytes or corrupted, skip it\n            continue\n        df_list.append(d)\n\n    if not df_list:\n        logger.info(\"All result CSVs were empty or invalid; nothing to merge.\")\n        return\n\n    final_df = pd.concat(df_list, ignore_index=True)\n    if GITHUB_SHA is not None:\n        final_df[\"github_sha\"] = GITHUB_SHA\n    final_df.to_csv(final_csv, index=False)\n    logger.info(f\"Merged {len(all_csvs)} partial CSVs → {final_csv}.\")\n\n\ndef run_scripts():\n    python_files = sorted(glob.glob(PATTERN))\n    python_files = [f for f in python_files if f != \"benchmarking_utils.py\"]\n\n    for file in python_files:\n        script_name = file.split(\".py\")[0].split(\"_\")[-1]  # example: 
benchmarking_foo.py -> foo\n        logger.info(f\"\\n****** Running file: {file} ******\")\n\n        partial_csv = f\"{script_name}.csv\"\n        if os.path.exists(partial_csv):\n            logger.info(f\"Found {partial_csv}. Removing it to avoid stale numbers and duplicate rows.\")\n            os.remove(partial_csv)\n\n        command = [\"python\", file]\n        try:\n            run_command(command)\n            logger.info(f\"→ {file} finished normally.\")\n        except SubprocessCallException as e:\n            logger.info(f\"Error running {file}:\\n{e}\")\n        finally:\n            logger.info(f\"→ Merging partial CSVs after {file} …\")\n            merge_csvs(final_csv=FINAL_CSV_FILENAME)\n\n    logger.info(f\"\\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}\")\n\n\nif __name__ == \"__main__\":\n    run_scripts()\n"
  },
  {
    "path": "docker/diffusers-doc-builder/Dockerfile",
    "content": "FROM python:3.10-slim\nENV PYTHONDONTWRITEBYTECODE=1\nLABEL maintainer=\"Hugging Face\"\nLABEL repository=\"diffusers\"\n\nENV DEBIAN_FRONTEND=noninteractive\n\nRUN apt-get -y update && apt-get install -y bash \\\n    build-essential \\\n    git \\\n    git-lfs \\\n    curl \\\n    ca-certificates \\\n    libglib2.0-0 \\\n    libsndfile1-dev \\\n    libgl1 \\\n    zip \\\n    wget\n\nENV UV_PYTHON=/usr/local/bin/python\n\n# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)\nRUN pip install uv\nRUN uv pip install --no-cache-dir \\\n    torch \\\n    torchvision \\\n    torchaudio \\\n    --extra-index-url https://download.pytorch.org/whl/cpu\n\nRUN uv pip install --no-cache-dir \"git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]\"\n\n# Extra dependencies\nRUN uv pip install --no-cache-dir \\\n    accelerate \\\n    numpy==1.26.4 \\\n    hf_xet \\\n    setuptools==69.5.1 \\\n    bitsandbytes \\\n    torchao \\\n    gguf \\\n    optimum-quanto\n\nRUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean\n\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docker/diffusers-onnxruntime-cpu/Dockerfile",
    "content": "FROM ubuntu:20.04\nLABEL maintainer=\"Hugging Face\"\nLABEL repository=\"diffusers\"\n\nENV DEBIAN_FRONTEND=noninteractive\n\nRUN apt-get -y update \\\n    && apt-get install -y software-properties-common \\\n    && add-apt-repository ppa:deadsnakes/ppa\n\nRUN apt install -y bash \\\n                   build-essential \\\n                   git \\\n                   git-lfs \\\n                   curl \\\n                   ca-certificates \\\n                   libsndfile1-dev \\\n                   libgl1 \\\n                   python3.10 \\\n                   python3-pip \\\n                   python3.10-venv && \\\n    rm -rf /var/lib/apt/lists\n\n# make sure to use venv\nRUN python3.10 -m venv /opt/venv\nENV PATH=\"/opt/venv/bin:$PATH\"\n\n# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)\nRUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \\\n    python3 -m uv pip install --no-cache-dir \\\n        torch \\\n        torchvision \\\n        torchaudio\\\n        onnxruntime \\\n        --extra-index-url https://download.pytorch.org/whl/cpu && \\\n    python3 -m uv pip install --no-cache-dir \\\n        accelerate \\\n        datasets \\\n        hf-doc-builder \\\n        huggingface-hub \\\n        Jinja2 \\\n        librosa \\\n        numpy==1.26.4 \\\n        scipy \\\n        tensorboard \\\n        transformers \\\n        hf_xet\n\nCMD [\"/bin/bash\"]"
  },
  {
    "path": "docker/diffusers-onnxruntime-cuda/Dockerfile",
    "content": "FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04\nLABEL maintainer=\"Hugging Face\"\nLABEL repository=\"diffusers\"\n\nENV DEBIAN_FRONTEND=noninteractive\n\nRUN apt-get -y update \\\n    && apt-get install -y software-properties-common \\\n    && add-apt-repository ppa:deadsnakes/ppa\n\nRUN apt install -y bash \\\n                   build-essential \\\n                   git \\\n                   git-lfs \\\n                   curl \\\n                   ca-certificates \\\n                   libsndfile1-dev \\\n                   libgl1 \\\n                   python3.10 \\\n                   python3-pip \\\n                   python3.10-venv && \\\n    rm -rf /var/lib/apt/lists\n\n# make sure to use venv\nRUN python3.10 -m venv /opt/venv\nENV PATH=\"/opt/venv/bin:$PATH\"\n\n# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)\nRUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \\\n    python3.10 -m uv pip install --no-cache-dir \\\n        torch \\\n        torchvision \\\n        torchaudio \\\n        \"onnxruntime-gpu>=1.13.1\" \\\n        --extra-index-url https://download.pytorch.org/whl/cu117 && \\\n    python3.10 -m uv pip install --no-cache-dir \\\n        accelerate \\\n        datasets \\\n        hf-doc-builder \\\n        huggingface-hub \\\n        hf_xet \\\n        Jinja2 \\\n        librosa \\\n        numpy==1.26.4 \\\n        scipy \\\n        tensorboard \\\n        transformers\n\nCMD [\"/bin/bash\"]"
  },
  {
    "path": "docker/diffusers-pytorch-cpu/Dockerfile",
    "content": "FROM python:3.10-slim\nENV PYTHONDONTWRITEBYTECODE=1\nLABEL maintainer=\"Hugging Face\"\nLABEL repository=\"diffusers\"\n\nENV DEBIAN_FRONTEND=noninteractive\n\nRUN apt-get -y update && apt-get install -y bash \\\n    build-essential \\\n    git \\\n    git-lfs \\\n    curl \\\n    ca-certificates \\\n    libglib2.0-0 \\\n    libsndfile1-dev \\\n    libgl1\n\nENV UV_PYTHON=/usr/local/bin/python\n\n# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)\nRUN pip install uv\nRUN uv pip install --no-cache-dir \\\n    torch \\\n    torchvision \\\n    torchaudio \\\n    --extra-index-url https://download.pytorch.org/whl/cpu\n\nRUN uv pip install --no-cache-dir \"git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]\"\n\n# Extra dependencies\nRUN uv pip install --no-cache-dir \\\n    accelerate \\\n    numpy==1.26.4 \\\n    hf_xet\n\nRUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean\n\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docker/diffusers-pytorch-cuda/Dockerfile",
    "content": "FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04\nLABEL maintainer=\"Hugging Face\"\nLABEL repository=\"diffusers\"\n\nARG PYTHON_VERSION=3.10\nENV DEBIAN_FRONTEND=noninteractive\n\nRUN apt-get -y update \\\n    && apt-get install -y software-properties-common \\\n    && add-apt-repository ppa:deadsnakes/ppa && \\\n    apt-get update\n\nRUN apt install -y bash \\\n    build-essential \\\n    git \\\n    git-lfs \\\n    curl \\\n    ca-certificates \\\n    libglib2.0-0 \\\n    libsndfile1-dev \\\n    libgl1 \\\n    python3 \\\n    python3-pip \\\n    && apt-get clean \\\n    && rm -rf /var/lib/apt/lists/*\n\nRUN curl -LsSf https://astral.sh/uv/install.sh | sh\nENV PATH=\"/root/.local/bin:$PATH\"\nENV VIRTUAL_ENV=\"/opt/venv\"\nENV UV_PYTHON_INSTALL_DIR=/opt/uv/python\nRUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}\nENV PATH=\"$VIRTUAL_ENV/bin:$PATH\"\n\n# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)\n# Install torch, torchvision, and torchaudio together to ensure compatibility\nRUN uv pip install --no-cache-dir \\\n    torch \\\n    torchvision \\\n    torchaudio \\\n    --index-url https://download.pytorch.org/whl/cu129\n\n# Install compatible versions of numba/llvmlite for Python 3.10+\nRUN uv pip install --no-cache-dir \\\n    \"llvmlite>=0.40.0\" \\\n    \"numba>=0.57.0\"\n\nRUN uv pip install --no-cache-dir \"git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]\"\n\n# Extra dependencies\nRUN uv pip install --no-cache-dir \\\n    accelerate \\\n    numpy==1.26.4 \\\n    pytorch-lightning \\\n    hf_xet\n\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docker/diffusers-pytorch-minimum-cuda/Dockerfile",
    "content": "FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04\nLABEL maintainer=\"Hugging Face\"\nLABEL repository=\"diffusers\"\n\nARG PYTHON_VERSION=3.10\nENV DEBIAN_FRONTEND=noninteractive\nENV MINIMUM_SUPPORTED_TORCH_VERSION=\"2.1.0\"\nENV MINIMUM_SUPPORTED_TORCHVISION_VERSION=\"0.16.0\"\nENV MINIMUM_SUPPORTED_TORCHAUDIO_VERSION=\"2.1.0\"\n\nRUN apt-get -y update \\\n    && apt-get install -y software-properties-common \\\n    && add-apt-repository ppa:deadsnakes/ppa && \\\n    apt-get update\n\nRUN apt install -y bash \\\n    build-essential \\\n    git \\\n    git-lfs \\\n    curl \\\n    ca-certificates \\\n    libglib2.0-0 \\\n    libsndfile1-dev \\\n    libgl1 \\\n    python3 \\\n    python3-pip \\\n    && apt-get clean \\\n    && rm -rf /var/lib/apt/lists/*\n\nRUN curl -LsSf https://astral.sh/uv/install.sh | sh\nENV PATH=\"/root/.local/bin:$PATH\"\nENV VIRTUAL_ENV=\"/opt/venv\"\nENV UV_PYTHON_INSTALL_DIR=/opt/uv/python\nRUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}\nENV PATH=\"$VIRTUAL_ENV/bin:$PATH\"\n\n# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)\nRUN uv pip install --no-cache-dir \\\n    torch==$MINIMUM_SUPPORTED_TORCH_VERSION \\\n    torchvision==$MINIMUM_SUPPORTED_TORCHVISION_VERSION \\\n    torchaudio==$MINIMUM_SUPPORTED_TORCHAUDIO_VERSION\n\nRUN uv pip install --no-cache-dir \"git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]\"\n\n# Extra dependencies\nRUN uv pip install --no-cache-dir \\\n    accelerate \\\n    numpy==1.26.4 \\\n    pytorch-lightning \\\n    hf_xet\n\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docker/diffusers-pytorch-xformers-cuda/Dockerfile",
    "content": "FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04\nLABEL maintainer=\"Hugging Face\"\nLABEL repository=\"diffusers\"\n\nARG PYTHON_VERSION=3.10\nENV DEBIAN_FRONTEND=noninteractive\n\nRUN apt-get -y update \\\n    && apt-get install -y software-properties-common \\\n    && add-apt-repository ppa:deadsnakes/ppa && \\\n    apt-get update\n\nRUN apt install -y bash \\\n    build-essential \\\n    git \\\n    git-lfs \\\n    curl \\\n    ca-certificates \\\n    libglib2.0-0 \\\n    libsndfile1-dev \\\n    libgl1 \\\n    python3 \\\n    python3-pip \\\n    && apt-get clean \\\n    && rm -rf /var/lib/apt/lists/*\n\nRUN curl -LsSf https://astral.sh/uv/install.sh | sh\nENV PATH=\"/root/.local/bin:$PATH\"\nENV VIRTUAL_ENV=\"/opt/venv\"\nENV UV_PYTHON_INSTALL_DIR=/opt/uv/python\nRUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}\nENV PATH=\"$VIRTUAL_ENV/bin:$PATH\"\n\n# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)\n# Install torch, torchvision, and torchaudio together to ensure compatibility\nRUN uv pip install --no-cache-dir \\\n    torch \\\n    torchvision \\\n    torchaudio \\\n    --index-url https://download.pytorch.org/whl/cu129\n\n# Install compatible versions of numba/llvmlite for Python 3.10+\nRUN uv pip install --no-cache-dir \\\n    \"llvmlite>=0.40.0\" \\\n    \"numba>=0.57.0\"\n\nRUN uv pip install --no-cache-dir \"git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]\"\n\n# Extra dependencies\nRUN uv pip install --no-cache-dir \\\n    accelerate \\\n    numpy==1.26.4 \\\n    pytorch-lightning \\\n    hf_xet \\\n    xformers\n\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docs/README.md",
    "content": "<!---\nCopyright 2024- The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# Generating the documentation\n\nTo generate the documentation, you first have to build it. Several packages are necessary to build the doc,\nyou can install them with the following command, at the root of the code repository:\n\n```bash\npip install -e \".[docs]\"\n```\n\nThen you need to install our open source documentation builder tool:\n\n```bash\npip install git+https://github.com/huggingface/doc-builder\n```\n\n---\n**NOTE**\n\nYou only need to generate the documentation to inspect it locally (if you're planning changes and want to\ncheck how they look before committing for instance). You don't have to commit the built documentation.\n\n---\n\n## Previewing the documentation\n\nTo preview the docs, first install the `watchdog` module with:\n\n```bash\npip install watchdog\n```\n\nThen run the following command:\n\n```bash\ndoc-builder preview {package_name} {path_to_docs}\n```\n\nFor example:\n\n```bash\ndoc-builder preview diffusers docs/source/en\n```\n\nThe docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.\n\n---\n**NOTE**\n\nThe `preview` command only works with existing doc files. 
When you add a completely new file, you need to update `_toctree.yml` and restart the `preview` command (`ctrl-c` to stop it, then call `doc-builder preview ...` again).\n\n---\n\n## Adding a new element to the navigation bar\n\nAccepted files are Markdown (.md).\n\nCreate a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting\nthe filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml) file.\n\n## Renaming section headers and moving sections\n\nIt helps to keep the old links working when renaming a section header and/or moving sections from one document to another. Old links are likely to be used in issues, forums, and social media, and it makes for a much better user experience if users reading those months later can still easily navigate to the originally intended information.\n\nTherefore, we simply keep a little map of moved sections at the end of the document where the original section was. 
The key is to preserve the original anchor.\n\nSo if you renamed a section from: \"Section A\" to \"Section B\", then you can add at the end of the file:\n\n```md\nSections that were moved:\n\n[ <a href=\"#section-b\">Section A</a><a id=\"section-a\"></a> ]\n```\nand of course, if you moved it to another file, then:\n\n```md\nSections that were moved:\n\n[ <a href=\"../new-file#section-b\">Section A</a><a id=\"section-a\"></a> ]\n```\n\nUse the relative style to link to the new file so that the versioned docs continue to work.\n\nFor an example of a rich moved section set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.md).\n\n\n## Writing Documentation - Specification\n\nThe `huggingface/diffusers` documentation follows the\n[Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,\nalthough we can write them directly in Markdown.\n\n### Adding a new tutorial\n\nAdding a new tutorial or section is done in two steps:\n\n- Add a new Markdown (.md) file under `docs/source/<languageCode>`.\n- Link that file in `docs/source/<languageCode>/_toctree.yml` on the correct toc-tree.\n\nMake sure to put your new file under the proper section. 
It's unlikely to go in the first section (*Get Started*), so\ndepending on the intended targets (beginners, more advanced users, or researchers) it should go in sections two, three, or four.\n\n### Adding a new pipeline/scheduler\n\nWhen adding a new pipeline:\n\n- Create a file `xxx.md` under `docs/source/<languageCode>/api/pipelines` (don't hesitate to copy an existing file as a template).\n- Link that file in the (*Diffusers Summary*) section in `docs/source/api/pipelines/overview.md`, along with the link to the paper and a colab notebook (if available).\n- Write a short overview of the diffusion model:\n    - Overview with paper & authors\n    - Paper abstract\n    - Tips and tricks and how to use it best\n    - Possibly an end-to-end example of how to use it\n- Add all the pipeline classes that should be linked in the diffusion model. These classes should be added using our Markdown syntax. By default as follows:\n\n```\n[[autodoc]] XXXPipeline\n    - all\n    - __call__\n```\n\nThis will include every public method of the pipeline that is documented, as well as the `__call__` method, which is not documented by default. If you want to include additional methods that are not documented by default, list them after `all`:\n\n```\n[[autodoc]] XXXPipeline\n    - all\n    - __call__\n    - enable_attention_slicing\n    - disable_attention_slicing\n    - enable_xformers_memory_efficient_attention\n    - disable_xformers_memory_efficient_attention\n```\n\nYou can follow the same process to create a new scheduler under the `docs/source/<languageCode>/api/schedulers` folder.\n\n### Writing source documentation\n\nValues that should be put in `code` should be surrounded by backticks: \\`like so\\`. 
Note that argument names\nand objects like True, None, or any strings should usually be put in `code`.\n\nWhen mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool\nadds a link to its documentation with this syntax: \\[\\`XXXClass\\`\\] or \\[\\`function\\`\\]. This requires the class or\nfunction to be in the main package.\n\nIf you want to create a link to some internal class or function, you need to\nprovide its path. For instance: \\[\\`pipelines.ImagePipelineOutput\\`\\]. This will be converted into a link with\n`pipelines.ImagePipelineOutput` in the description. To get rid of the path and only keep the name of the object you are\nlinking to in the description, add a ~: \\[\\`~pipelines.ImagePipelineOutput\\`\\] will generate a link with `ImagePipelineOutput` in the description.\n\nThe same works for methods so you can either use \\[\\`XXXClass.method\\`\\] or \\[\\`~XXXClass.method\\`\\].\n\n#### Defining arguments in a method\n\nArguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and\nan indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its\ndescription:\n\n```\n    Args:\n        n_layers (`int`): The number of layers of the model.\n```\n\nIf the description is too long to fit in one line, another indentation is necessary before writing the description\nafter the argument.\n\nHere's an example showcasing everything so far:\n\n```\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary.\n\n            Indices can be obtained using [`AlbertTokenizer`]. 
See [`~PreTrainedTokenizer.encode`] and\n            [`~PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n```\n\nFor optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the\nfollowing signature:\n\n```py\ndef my_function(x: str=None, a: float=3.14):\n```\n\nthen its documentation should look like this:\n\n```\n    Args:\n        x (`str`, *optional*):\n            This argument controls ...\n        a (`float`, *optional*, defaults to `3.14`):\n            This argument is used to ...\n```\n\nNote that we always omit the \"defaults to \\`None\\`\" when None is the default for any argument. Also note that even\nif the first line describing your argument type and its default gets long, you can't break it on several lines. You can\nhowever write as many lines as you want in the indented description (see the example above with `input_ids`).\n\n#### Writing a multi-line code block\n\nMulti-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:\n\n\n````\n```\n# first line of code\n# second line\n# etc\n```\n````\n\n#### Writing a return block\n\nThe return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation.\nThe first line should be the type of the return, followed by a line return. 
No need to indent further for the elements\nbuilding the return.\n\nHere's an example of a single value return:\n\n```\n    Returns:\n        `List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.\n```\n\nHere's an example of a tuple return, comprising several objects:\n\n```\n    Returns:\n        `tuple(torch.Tensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:\n        - **loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.Tensor` of shape `(1,)` --\n          Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss.\n        - **prediction_scores** (`torch.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`) --\n          Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).\n```\n\n#### Adding an image\n\nBecause the repository is growing rapidly, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like\nthe ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference\nthem by URL. 
We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images).\nIf you are an external contributor, feel free to add the images to your PR and ask a Hugging Face member to migrate your images\nto this dataset.\n\n## Styling the docstring\n\nWe have an automatic script running with the `make style` command that will make sure that:\n- the docstrings fully take advantage of the line width\n- all code examples are formatted using black, like the code of the Transformers library\n\nThis script may fail in unexpected ways if you made a syntax mistake or if you uncover a bug. Therefore, it's\nrecommended to commit your changes before running `make style`, so you can revert the changes done by that script\neasily.\n"
  },
  {
    "path": "docs/TRANSLATING.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n### Translating the Diffusers documentation into your language\n\nAs part of our mission to democratize machine learning, we'd love to make the Diffusers library available in many more languages! Follow the steps below if you want to help translate the documentation into your language 🙏.\n\n**🗞️ Open an issue**\n\nTo get started, navigate to the [Issues](https://github.com/huggingface/diffusers/issues) page of this repo and check if anyone else has opened an issue for your language. If not, open a new issue by selecting the \"🌐 Translating a New Language?\" from the \"New issue\" button.\n\nOnce an issue exists, post a comment to indicate which chapters you'd like to work on, and we'll add your name to the list.\n\n\n**🍴 Fork the repository**\n\nFirst, you'll need to [fork the Diffusers repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo). You can do this by clicking on the **Fork** button on the top-right corner of this repo's page.\n\nOnce you've forked the repo, you'll want to get the files on your local machine for editing. 
You can do that by cloning the fork with Git as follows:\n\n```bash\ngit clone https://github.com/<YOUR-USERNAME>/diffusers.git\n```\n\n**📋 Copy-paste the English version with a new language code**\n\nThe documentation files are in one leading directory:\n\n- [`docs/source`](https://github.com/huggingface/diffusers/tree/main/docs/source): All the documentation materials are organized here by language.\n\nYou'll only need to copy the files in the [`docs/source/en`](https://github.com/huggingface/diffusers/tree/main/docs/source/en) directory, so first navigate to your fork of the repo and run the following:\n\n```bash\ncd ~/path/to/diffusers/docs\ncp -r source/en source/<LANG-ID>\n```\n\nHere, `<LANG-ID>` should be one of the ISO 639-1 or ISO 639-2 language codes -- see [here](https://www.loc.gov/standards/iso639-2/php/code_list.php) for a handy table.\n\n**✍️ Start translating**\n\nThe fun part comes - translating the text!\n\nThe first thing we recommend is translating the part of the `_toctree.yml` file that corresponds to your doc chapter. This file is used to render the table of contents on the website.\n\n> 🙋 If the `_toctree.yml` file doesn't yet exist for your language, you can create one by copy-pasting from the English version and deleting the sections unrelated to your chapter. Just make sure it exists in the `docs/source/<LANG-ID>/` directory!\n\nThe fields you should add are `local` (with the name of the file containing the translation; e.g. `autoclass_tutorial`), and `title` (with the title of the doc in your language; e.g. `Load pretrained instances with an AutoClass`) -- as a reference, here is the `_toctree.yml` for [English](https://github.com/huggingface/diffusers/blob/main/docs/source/en/_toctree.yml):\n\n```yaml\n- sections:\n  - local: pipeline_tutorial # Do not change this! 
Use the same name for your .md file\n    title: Pipelines for inference # Translate this!\n    ...\n  title: Tutorials # Translate this!\n```\n\nOnce you have translated the `_toctree.yml` file, you can start translating the [MDX](https://mdxjs.com/) files associated with your docs chapter.\n\n> 🙋 If you'd like others to help you with the translation, you should [open an issue](https://github.com/huggingface/diffusers/issues) and tag @patrickvonplaten.\n"
  },
  {
    "path": "docs/source/_config.py",
    "content": "# docstyle-ignore\nINSTALL_CONTENT = \"\"\"\n# Diffusers installation\n! pip install diffusers transformers datasets accelerate\n# To install from source instead of the last release, comment the command above and uncomment the following one.\n# ! pip install git+https://github.com/huggingface/diffusers.git\n\"\"\"\n\nnotebook_first_cells = [{\"type\": \"code\", \"content\": INSTALL_CONTENT}]\n"
  },
  {
    "path": "docs/source/en/_toctree.yml",
    "content": "- sections:\n  - local: index\n    title: Diffusers\n  - local: installation\n    title: Installation\n  - local: quicktour\n    title: Quickstart\n  - local: stable_diffusion\n    title: Basic performance\n  title: Get started\n- isExpanded: false\n  sections:\n  - local: using-diffusers/loading\n    title: DiffusionPipeline\n  - local: tutorials/autopipeline\n    title: AutoPipeline\n  - local: using-diffusers/custom_pipeline_overview\n    title: Community pipelines and components\n  - local: using-diffusers/callback\n    title: Pipeline callbacks\n  - local: using-diffusers/reusing_seeds\n    title: Reproducibility\n  - local: using-diffusers/schedulers\n    title: Schedulers\n  - local: using-diffusers/guiders\n    title: Guiders\n  - local: using-diffusers/automodel\n    title: AutoModel\n  - local: using-diffusers/other-formats\n    title: Model formats\n  - local: using-diffusers/push_to_hub\n    title: Sharing pipelines and models\n  title: Pipelines\n- isExpanded: false\n  sections:\n  - local: tutorials/using_peft_for_inference\n    title: LoRA\n  - local: using-diffusers/ip_adapter\n    title: IP-Adapter\n  - local: using-diffusers/controlnet\n    title: ControlNet\n  - local: using-diffusers/t2i_adapter\n    title: T2I-Adapter\n  - local: using-diffusers/dreambooth\n    title: DreamBooth\n  - local: using-diffusers/textual_inversion_inference\n    title: Textual inversion\n  title: Adapters\n- isExpanded: false\n  sections:\n  - local: using-diffusers/weighted_prompts\n    title: Prompting\n  - local: using-diffusers/create_a_server\n    title: Create a server\n  - local: using-diffusers/batched_inference\n    title: Batch inference\n  - local: training/distributed_inference\n    title: Distributed inference\n  - local: hybrid_inference/overview\n    title: Remote inference\n  title: Inference\n- isExpanded: false\n  sections:\n  - local: optimization/fp16\n    title: Accelerate inference\n  - local: optimization/cache\n    title: 
Caching\n  - local: optimization/attention_backends\n    title: Attention backends\n  - local: optimization/memory\n    title: Reduce memory usage\n  - local: optimization/speed-memory-optims\n    title: Compiling and offloading quantized models\n  - sections:\n    - local: optimization/pruna\n      title: Pruna\n    - local: optimization/xformers\n      title: xFormers\n    - local: optimization/tome\n      title: Token merging\n    - local: optimization/deepcache\n      title: DeepCache\n    - local: optimization/cache_dit\n      title: CacheDiT\n    - local: optimization/tgate\n      title: TGATE\n    - local: optimization/xdit\n      title: xDiT\n    - local: optimization/para_attn\n      title: ParaAttention\n    - local: using-diffusers/image_quality\n      title: FreeU\n    title: Community optimizations\n  title: Inference optimization\n- isExpanded: false\n  sections:\n  - local: modular_diffusers/overview\n    title: Overview\n  - local: modular_diffusers/quickstart\n    title: Quickstart\n  - local: modular_diffusers/modular_diffusers_states\n    title: States\n  - local: modular_diffusers/pipeline_block\n    title: ModularPipelineBlocks\n  - local: modular_diffusers/sequential_pipeline_blocks\n    title: SequentialPipelineBlocks\n  - local: modular_diffusers/loop_sequential_pipeline_blocks\n    title: LoopSequentialPipelineBlocks\n  - local: modular_diffusers/auto_pipeline_blocks\n    title: AutoPipelineBlocks\n  - local: modular_diffusers/modular_pipeline\n    title: ModularPipeline\n  - local: modular_diffusers/components_manager\n    title: ComponentsManager\n  - local: modular_diffusers/custom_blocks\n    title: Building Custom Blocks\n  - local: modular_diffusers/mellon\n    title: Using Custom Blocks with Mellon\n  title: Modular Diffusers\n- isExpanded: false\n  sections:\n  - local: training/overview\n    title: Overview\n  - local: training/create_dataset\n    title: Create a dataset for training\n  - local: training/adapt_a_model\n    title: 
Adapt a model to a new task\n  - local: tutorials/basic_training\n    title: Train a diffusion model\n  - sections:\n    - local: training/unconditional_training\n      title: Unconditional image generation\n    - local: training/text2image\n      title: Text-to-image\n    - local: training/sdxl\n      title: Stable Diffusion XL\n    - local: training/kandinsky\n      title: Kandinsky 2.2\n    - local: training/wuerstchen\n      title: Wuerstchen\n    - local: training/controlnet\n      title: ControlNet\n    - local: training/t2i_adapters\n      title: T2I-Adapters\n    - local: training/instructpix2pix\n      title: InstructPix2Pix\n    - local: training/cogvideox\n      title: CogVideoX\n    title: Models\n  - sections:\n    - local: training/text_inversion\n      title: Textual Inversion\n    - local: training/dreambooth\n      title: DreamBooth\n    - local: training/lora\n      title: LoRA\n    - local: training/custom_diffusion\n      title: Custom Diffusion\n    - local: training/lcm_distill\n      title: Latent Consistency Distillation\n    - local: training/ddpo\n      title: Reinforcement learning training with DDPO\n    title: Methods\n  title: Training\n- isExpanded: false\n  sections:\n  - local: quantization/overview\n    title: Getting started\n  - local: quantization/bitsandbytes\n    title: bitsandbytes\n  - local: quantization/gguf\n    title: gguf\n  - local: quantization/torchao\n    title: torchao\n  - local: quantization/quanto\n    title: quanto\n  - local: quantization/modelopt\n    title: NVIDIA ModelOpt\n  title: Quantization\n- isExpanded: false\n  sections:\n  - local: optimization/onnx\n    title: ONNX\n  - local: optimization/open_vino\n    title: OpenVINO\n  - local: optimization/coreml\n    title: Core ML\n  - local: optimization/mps\n    title: Metal Performance Shaders (MPS)\n  - local: optimization/habana\n    title: Intel Gaudi\n  - local: optimization/neuron\n    title: AWS Neuron\n  title: Model accelerators and hardware\n- 
isExpanded: false\n  sections:\n  - local: using-diffusers/helios\n    title: Helios\n  - local: using-diffusers/consisid\n    title: ConsisID\n  - local: using-diffusers/sdxl\n    title: Stable Diffusion XL\n  - local: using-diffusers/sdxl_turbo\n    title: SDXL Turbo\n  - local: using-diffusers/kandinsky\n    title: Kandinsky\n  - local: using-diffusers/omnigen\n    title: OmniGen\n  - local: using-diffusers/pag\n    title: PAG\n  - local: using-diffusers/inference_with_lcm\n    title: Latent Consistency Model\n  - local: using-diffusers/shap-e\n    title: Shap-E\n  - local: using-diffusers/diffedit\n    title: DiffEdit\n  - local: using-diffusers/inference_with_tcd_lora\n    title: Trajectory Consistency Distillation-LoRA\n  - local: using-diffusers/svd\n    title: Stable Video Diffusion\n  - local: using-diffusers/marigold_usage\n    title: Marigold Computer Vision\n  title: Specific pipeline examples\n- isExpanded: false\n  sections:\n  - sections:\n    - local: using-diffusers/unconditional_image_generation\n      title: Unconditional image generation\n    - local: using-diffusers/conditional_image_generation\n      title: Text-to-image\n    - local: using-diffusers/img2img\n      title: Image-to-image\n    - local: using-diffusers/inpaint\n      title: Inpainting\n    - local: advanced_inference/outpaint\n      title: Outpainting\n    - local: using-diffusers/text-img2vid\n      title: Video generation\n    - local: using-diffusers/depth2img\n      title: Depth-to-image\n    title: Task recipes\n  - local: using-diffusers/write_own_pipeline\n    title: Understanding pipelines, models and schedulers\n  - local: community_projects\n    title: Projects built with Diffusers\n  - local: conceptual/philosophy\n    title: Philosophy\n  - local: using-diffusers/controlling_generation\n    title: Controlled generation\n  - local: conceptual/contribution\n    title: How to contribute?\n  - local: conceptual/ethical_guidelines\n    title: Diffusers' Ethical 
Guidelines\n  - local: conceptual/evaluation\n    title: Evaluating Diffusion Models\n  title: Resources\n- isExpanded: false\n  sections:\n  - sections:\n    - local: api/configuration\n      title: Configuration\n    - local: api/logging\n      title: Logging\n    - local: api/outputs\n      title: Outputs\n    - local: api/quantization\n      title: Quantization\n    - local: hybrid_inference/api_reference\n      title: Remote inference\n    - local: api/parallel\n      title: Parallel inference\n    title: Main Classes\n  - sections:\n    - local: api/modular_diffusers/pipeline\n      title: Pipeline\n    - local: api/modular_diffusers/pipeline_blocks\n      title: Blocks\n    - local: api/modular_diffusers/pipeline_states\n      title: States\n    - local: api/modular_diffusers/pipeline_components\n      title: Components and configs\n    - local: api/modular_diffusers/guiders\n      title: Guiders\n    title: Modular\n  - sections:\n    - local: api/loaders/ip_adapter\n      title: IP-Adapter\n    - local: api/loaders/lora\n      title: LoRA\n    - local: api/loaders/single_file\n      title: Single files\n    - local: api/loaders/textual_inversion\n      title: Textual Inversion\n    - local: api/loaders/unet\n      title: UNet\n    - local: api/loaders/transformer_sd3\n      title: SD3Transformer2D\n    - local: api/loaders/peft\n      title: PEFT\n    title: Loaders\n  - sections:\n    - local: api/models/overview\n      title: Overview\n    - local: api/models/auto_model\n      title: AutoModel\n    - sections:\n      - local: api/models/controlnet\n        title: ControlNetModel\n      - local: api/models/controlnet_union\n        title: ControlNetUnionModel\n      - local: api/models/controlnet_flux\n        title: FluxControlNetModel\n      - local: api/models/controlnet_hunyuandit\n        title: HunyuanDiT2DControlNetModel\n      - local: api/models/controlnet_sana\n        title: SanaControlNetModel\n      - local: api/models/controlnet_sd3\n        
title: SD3ControlNetModel\n      - local: api/models/controlnet_sparsectrl\n        title: SparseControlNetModel\n      title: ControlNets\n    - sections:\n      - local: api/models/allegro_transformer3d\n        title: AllegroTransformer3DModel\n      - local: api/models/aura_flow_transformer2d\n        title: AuraFlowTransformer2DModel\n      - local: api/models/transformer_bria_fibo\n        title: BriaFiboTransformer2DModel\n      - local: api/models/bria_transformer\n        title: BriaTransformer2DModel\n      - local: api/models/chroma_transformer\n        title: ChromaTransformer2DModel\n      - local: api/models/chronoedit_transformer_3d\n        title: ChronoEditTransformer3DModel\n      - local: api/models/cogvideox_transformer3d\n        title: CogVideoXTransformer3DModel\n      - local: api/models/cogview3plus_transformer2d\n        title: CogView3PlusTransformer2DModel\n      - local: api/models/cogview4_transformer2d\n        title: CogView4Transformer2DModel\n      - local: api/models/consisid_transformer3d\n        title: ConsisIDTransformer3DModel\n      - local: api/models/cosmos_transformer3d\n        title: CosmosTransformer3DModel\n      - local: api/models/dit_transformer2d\n        title: DiTTransformer2DModel\n      - local: api/models/easyanimate_transformer3d\n        title: EasyAnimateTransformer3DModel\n      - local: api/models/flux2_transformer\n        title: Flux2Transformer2DModel\n      - local: api/models/flux_transformer\n        title: FluxTransformer2DModel\n      - local: api/models/glm_image_transformer2d\n        title: GlmImageTransformer2DModel\n      - local: api/models/helios_transformer3d\n        title: HeliosTransformer3DModel\n      - local: api/models/hidream_image_transformer\n        title: HiDreamImageTransformer2DModel\n      - local: api/models/hunyuan_transformer2d\n        title: HunyuanDiT2DModel\n      - local: api/models/hunyuanimage_transformer_2d\n        title: HunyuanImageTransformer2DModel\n      - 
local: api/models/hunyuan_video15_transformer_3d\n        title: HunyuanVideo15Transformer3DModel\n      - local: api/models/hunyuan_video_transformer_3d\n        title: HunyuanVideoTransformer3DModel\n      - local: api/models/latte_transformer3d\n        title: LatteTransformer3DModel\n      - local: api/models/longcat_image_transformer2d\n        title: LongCatImageTransformer2DModel\n      - local: api/models/ltx2_video_transformer3d\n        title: LTX2VideoTransformer3DModel\n      - local: api/models/ltx_video_transformer3d\n        title: LTXVideoTransformer3DModel\n      - local: api/models/lumina2_transformer2d\n        title: Lumina2Transformer2DModel\n      - local: api/models/lumina_nextdit2d\n        title: LuminaNextDiT2DModel\n      - local: api/models/mochi_transformer3d\n        title: MochiTransformer3DModel\n      - local: api/models/omnigen_transformer\n        title: OmniGenTransformer2DModel\n      - local: api/models/ovisimage_transformer2d\n        title: OvisImageTransformer2DModel\n      - local: api/models/pixart_transformer2d\n        title: PixArtTransformer2DModel\n      - local: api/models/prior_transformer\n        title: PriorTransformer\n      - local: api/models/qwenimage_transformer2d\n        title: QwenImageTransformer2DModel\n      - local: api/models/sana_transformer2d\n        title: SanaTransformer2DModel\n      - local: api/models/sana_video_transformer3d\n        title: SanaVideoTransformer3DModel\n      - local: api/models/sd3_transformer2d\n        title: SD3Transformer2DModel\n      - local: api/models/skyreels_v2_transformer_3d\n        title: SkyReelsV2Transformer3DModel\n      - local: api/models/stable_audio_transformer\n        title: StableAudioDiTModel\n      - local: api/models/transformer2d\n        title: Transformer2DModel\n      - local: api/models/transformer_temporal\n        title: TransformerTemporalModel\n      - local: api/models/wan_animate_transformer_3d\n        title: 
WanAnimateTransformer3DModel\n      - local: api/models/wan_transformer_3d\n        title: WanTransformer3DModel\n      - local: api/models/z_image_transformer2d\n        title: ZImageTransformer2DModel\n      title: Transformers\n    - sections:\n      - local: api/models/stable_cascade_unet\n        title: StableCascadeUNet\n      - local: api/models/unet\n        title: UNet1DModel\n      - local: api/models/unet2d-cond\n        title: UNet2DConditionModel\n      - local: api/models/unet2d\n        title: UNet2DModel\n      - local: api/models/unet3d-cond\n        title: UNet3DConditionModel\n      - local: api/models/unet-motion\n        title: UNetMotionModel\n      - local: api/models/uvit2d\n        title: UViT2DModel\n      title: UNets\n    - sections:\n      - local: api/models/asymmetricautoencoderkl\n        title: AsymmetricAutoencoderKL\n      - local: api/models/autoencoder_dc\n        title: AutoencoderDC\n      - local: api/models/autoencoderkl\n        title: AutoencoderKL\n      - local: api/models/autoencoderkl_allegro\n        title: AutoencoderKLAllegro\n      - local: api/models/autoencoderkl_cogvideox\n        title: AutoencoderKLCogVideoX\n      - local: api/models/autoencoderkl_cosmos\n        title: AutoencoderKLCosmos\n      - local: api/models/autoencoder_kl_hunyuanimage\n        title: AutoencoderKLHunyuanImage\n      - local: api/models/autoencoder_kl_hunyuanimage_refiner\n        title: AutoencoderKLHunyuanImageRefiner\n      - local: api/models/autoencoder_kl_hunyuan_video\n        title: AutoencoderKLHunyuanVideo\n      - local: api/models/autoencoder_kl_hunyuan_video15\n        title: AutoencoderKLHunyuanVideo15\n      - local: api/models/autoencoderkl_audio_ltx_2\n        title: AutoencoderKLLTX2Audio\n      - local: api/models/autoencoderkl_ltx_2\n        title: AutoencoderKLLTX2Video\n      - local: api/models/autoencoderkl_ltx_video\n        title: AutoencoderKLLTXVideo\n      - local: api/models/autoencoderkl_magvit\n        
title: AutoencoderKLMagvit\n      - local: api/models/autoencoderkl_mochi\n        title: AutoencoderKLMochi\n      - local: api/models/autoencoderkl_qwenimage\n        title: AutoencoderKLQwenImage\n      - local: api/models/autoencoder_kl_wan\n        title: AutoencoderKLWan\n      - local: api/models/autoencoder_rae\n        title: AutoencoderRAE\n      - local: api/models/consistency_decoder_vae\n        title: ConsistencyDecoderVAE\n      - local: api/models/autoencoder_oobleck\n        title: Oobleck AutoEncoder\n      - local: api/models/autoencoder_tiny\n        title: Tiny AutoEncoder\n      - local: api/models/vq\n        title: VQModel\n      title: VAEs\n    title: Models\n  - sections:\n    - local: api/pipelines/overview\n      title: Overview\n    - local: api/pipelines/auto_pipeline\n      title: AutoPipeline\n    - sections:\n      - local: api/pipelines/audioldm\n        title: AudioLDM\n      - local: api/pipelines/audioldm2\n        title: AudioLDM 2\n      - local: api/pipelines/dance_diffusion\n        title: Dance Diffusion\n      - local: api/pipelines/musicldm\n        title: MusicLDM\n      - local: api/pipelines/stable_audio\n        title: Stable Audio\n      title: Audio\n    - sections:\n      - local: api/pipelines/amused\n        title: aMUSEd\n      - local: api/pipelines/animatediff\n        title: AnimateDiff\n      - local: api/pipelines/attend_and_excite\n        title: Attend-and-Excite\n      - local: api/pipelines/aura_flow\n        title: AuraFlow\n      - local: api/pipelines/blip_diffusion\n        title: BLIP-Diffusion\n      - local: api/pipelines/bria_3_2\n        title: Bria 3.2\n      - local: api/pipelines/bria_fibo\n        title: Bria Fibo\n      - local: api/pipelines/bria_fibo_edit\n        title: Bria Fibo Edit\n      - local: api/pipelines/chroma\n        title: Chroma\n      - local: api/pipelines/cogview3\n        title: CogView3\n      - local: api/pipelines/cogview4\n        title: CogView4\n      - local: 
api/pipelines/consistency_models\n        title: Consistency Models\n      - local: api/pipelines/controlnet\n        title: ControlNet\n      - local: api/pipelines/controlnet_flux\n        title: ControlNet with Flux.1\n      - local: api/pipelines/controlnet_hunyuandit\n        title: ControlNet with Hunyuan-DiT\n      - local: api/pipelines/controlnet_sd3\n        title: ControlNet with Stable Diffusion 3\n      - local: api/pipelines/controlnet_sdxl\n        title: ControlNet with Stable Diffusion XL\n      - local: api/pipelines/controlnet_sana\n        title: ControlNet-Sana\n      - local: api/pipelines/controlnetxs\n        title: ControlNet-XS\n      - local: api/pipelines/controlnetxs_sdxl\n        title: ControlNet-XS with Stable Diffusion XL\n      - local: api/pipelines/controlnet_union\n        title: ControlNetUnion\n      - local: api/pipelines/ddim\n        title: DDIM\n      - local: api/pipelines/ddpm\n        title: DDPM\n      - local: api/pipelines/deepfloyd_if\n        title: DeepFloyd IF\n      - local: api/pipelines/diffedit\n        title: DiffEdit\n      - local: api/pipelines/dit\n        title: DiT\n      - local: api/pipelines/easyanimate\n        title: EasyAnimate\n      - local: api/pipelines/flux\n        title: Flux\n      - local: api/pipelines/flux2\n        title: Flux2\n      - local: api/pipelines/control_flux_inpaint\n        title: FluxControlInpaint\n      - local: api/pipelines/glm_image\n        title: GLM-Image\n      - local: api/pipelines/hidream\n        title: HiDream-I1\n      - local: api/pipelines/hunyuandit\n        title: Hunyuan-DiT\n      - local: api/pipelines/hunyuanimage21\n        title: HunyuanImage2.1\n      - local: api/pipelines/pix2pix\n        title: InstructPix2Pix\n      - local: api/pipelines/kandinsky\n        title: Kandinsky 2.1\n      - local: api/pipelines/kandinsky_v22\n        title: Kandinsky 2.2\n      - local: api/pipelines/kandinsky3\n        title: Kandinsky 3\n      - local: 
api/pipelines/kandinsky5_image\n        title: Kandinsky 5.0 Image\n      - local: api/pipelines/kolors\n        title: Kolors\n      - local: api/pipelines/latent_consistency_models\n        title: Latent Consistency Models\n      - local: api/pipelines/latent_diffusion\n        title: Latent Diffusion\n      - local: api/pipelines/ledits_pp\n        title: LEDITS++\n      - local: api/pipelines/longcat_image\n        title: LongCat-Image\n      - local: api/pipelines/lumina2\n        title: Lumina 2.0\n      - local: api/pipelines/lumina\n        title: Lumina-T2X\n      - local: api/pipelines/marigold\n        title: Marigold\n      - local: api/pipelines/panorama\n        title: MultiDiffusion\n      - local: api/pipelines/omnigen\n        title: OmniGen\n      - local: api/pipelines/ovis_image\n        title: Ovis-Image\n      - local: api/pipelines/pag\n        title: PAG\n      - local: api/pipelines/paint_by_example\n        title: Paint by Example\n      - local: api/pipelines/pixart\n        title: PixArt-α\n      - local: api/pipelines/pixart_sigma\n        title: PixArt-Σ\n      - local: api/pipelines/prx\n        title: PRX\n      - local: api/pipelines/qwenimage\n        title: QwenImage\n      - local: api/pipelines/sana\n        title: Sana\n      - local: api/pipelines/sana_sprint\n        title: Sana Sprint\n      - local: api/pipelines/sana_video\n        title: Sana Video\n      - local: api/pipelines/self_attention_guidance\n        title: Self-Attention Guidance\n      - local: api/pipelines/semantic_stable_diffusion\n        title: Semantic Guidance\n      - local: api/pipelines/shap_e\n        title: Shap-E\n      - local: api/pipelines/stable_cascade\n        title: Stable Cascade\n      - sections:\n        - local: api/pipelines/stable_diffusion/overview\n          title: Overview\n        - local: api/pipelines/stable_diffusion/depth2img\n          title: Depth-to-image\n        - local: api/pipelines/stable_diffusion/gligen\n          
title: GLIGEN (Grounded Language-to-Image Generation)\n        - local: api/pipelines/stable_diffusion/image_variation\n          title: Image variation\n        - local: api/pipelines/stable_diffusion/img2img\n          title: Image-to-image\n        - local: api/pipelines/stable_diffusion/inpaint\n          title: Inpainting\n        - local: api/pipelines/stable_diffusion/latent_upscale\n          title: Latent upscaler\n        - local: api/pipelines/stable_diffusion/ldm3d_diffusion\n          title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D\n            Upscaler\n        - local: api/pipelines/stable_diffusion/stable_diffusion_safe\n          title: Safe Stable Diffusion\n        - local: api/pipelines/stable_diffusion/sdxl_turbo\n          title: SDXL Turbo\n        - local: api/pipelines/stable_diffusion/stable_diffusion_2\n          title: Stable Diffusion 2\n        - local: api/pipelines/stable_diffusion/stable_diffusion_3\n          title: Stable Diffusion 3\n        - local: api/pipelines/stable_diffusion/stable_diffusion_xl\n          title: Stable Diffusion XL\n        - local: api/pipelines/stable_diffusion/upscale\n          title: Super-resolution\n        - local: api/pipelines/stable_diffusion/adapter\n          title: T2I-Adapter\n        - local: api/pipelines/stable_diffusion/text2img\n          title: Text-to-image\n        title: Stable Diffusion\n      - local: api/pipelines/stable_unclip\n        title: Stable unCLIP\n      - local: api/pipelines/unclip\n        title: unCLIP\n      - local: api/pipelines/unidiffuser\n        title: UniDiffuser\n      - local: api/pipelines/value_guided_sampling\n        title: Value-guided sampling\n      - local: api/pipelines/visualcloze\n        title: VisualCloze\n      - local: api/pipelines/wuerstchen\n        title: Wuerstchen\n      - local: api/pipelines/z_image\n        title: Z-Image\n      title: Image\n    - sections:\n      - local: api/pipelines/allegro\n        
title: Allegro\n      - local: api/pipelines/chronoedit\n        title: ChronoEdit\n      - local: api/pipelines/cogvideox\n        title: CogVideoX\n      - local: api/pipelines/consisid\n        title: ConsisID\n      - local: api/pipelines/cosmos\n        title: Cosmos\n      - local: api/pipelines/framepack\n        title: Framepack\n      - local: api/pipelines/helios\n        title: Helios\n      - local: api/pipelines/hunyuan_video\n        title: HunyuanVideo\n      - local: api/pipelines/hunyuan_video15\n        title: HunyuanVideo1.5\n      - local: api/pipelines/i2vgenxl\n        title: I2VGen-XL\n      - local: api/pipelines/kandinsky5_video\n        title: Kandinsky 5.0 Video\n      - local: api/pipelines/latte\n        title: Latte\n      - local: api/pipelines/ltx2\n        title: LTX-2\n      - local: api/pipelines/ltx_video\n        title: LTXVideo\n      - local: api/pipelines/mochi\n        title: Mochi\n      - local: api/pipelines/pia\n        title: Personalized Image Animator (PIA)\n      - local: api/pipelines/skyreels_v2\n        title: SkyReels-V2\n      - local: api/pipelines/stable_diffusion/svd\n        title: Stable Video Diffusion\n      - local: api/pipelines/text_to_video\n        title: Text-to-video\n      - local: api/pipelines/text_to_video_zero\n        title: Text2Video-Zero\n      - local: api/pipelines/wan\n        title: Wan\n      title: Video\n    title: Pipelines\n  - sections:\n    - local: api/schedulers/overview\n      title: Overview\n    - local: api/schedulers/cm_stochastic_iterative\n      title: CMStochasticIterativeScheduler\n    - local: api/schedulers/ddim_cogvideox\n      title: CogVideoXDDIMScheduler\n    - local: api/schedulers/multistep_dpm_solver_cogvideox\n      title: CogVideoXDPMScheduler\n    - local: api/schedulers/consistency_decoder\n      title: ConsistencyDecoderScheduler\n    - local: api/schedulers/cosine_dpm\n      title: CosineDPMSolverMultistepScheduler\n    - local: 
api/schedulers/ddim_inverse\n      title: DDIMInverseScheduler\n    - local: api/schedulers/ddim\n      title: DDIMScheduler\n    - local: api/schedulers/ddpm\n      title: DDPMScheduler\n    - local: api/schedulers/deis\n      title: DEISMultistepScheduler\n    - local: api/schedulers/multistep_dpm_solver_inverse\n      title: DPMSolverMultistepInverse\n    - local: api/schedulers/multistep_dpm_solver\n      title: DPMSolverMultistepScheduler\n    - local: api/schedulers/dpm_sde\n      title: DPMSolverSDEScheduler\n    - local: api/schedulers/singlestep_dpm_solver\n      title: DPMSolverSinglestepScheduler\n    - local: api/schedulers/edm_multistep_dpm_solver\n      title: EDMDPMSolverMultistepScheduler\n    - local: api/schedulers/edm_euler\n      title: EDMEulerScheduler\n    - local: api/schedulers/euler_ancestral\n      title: EulerAncestralDiscreteScheduler\n    - local: api/schedulers/euler\n      title: EulerDiscreteScheduler\n    - local: api/schedulers/flow_match_euler_discrete\n      title: FlowMatchEulerDiscreteScheduler\n    - local: api/schedulers/flow_match_heun_discrete\n      title: FlowMatchHeunDiscreteScheduler\n    - local: api/schedulers/helios_dmd\n      title: HeliosDMDScheduler\n    - local: api/schedulers/helios\n      title: HeliosScheduler\n    - local: api/schedulers/heun\n      title: HeunDiscreteScheduler\n    - local: api/schedulers/ipndm\n      title: IPNDMScheduler\n    - local: api/schedulers/stochastic_karras_ve\n      title: KarrasVeScheduler\n    - local: api/schedulers/dpm_discrete_ancestral\n      title: KDPM2AncestralDiscreteScheduler\n    - local: api/schedulers/dpm_discrete\n      title: KDPM2DiscreteScheduler\n    - local: api/schedulers/lcm\n      title: LCMScheduler\n    - local: api/schedulers/lms_discrete\n      title: LMSDiscreteScheduler\n    - local: api/schedulers/pndm\n      title: PNDMScheduler\n    - local: api/schedulers/repaint\n      title: RePaintScheduler\n    - local: api/schedulers/score_sde_ve\n      
title: ScoreSdeVeScheduler\n    - local: api/schedulers/score_sde_vp\n      title: ScoreSdeVpScheduler\n    - local: api/schedulers/tcd\n      title: TCDScheduler\n    - local: api/schedulers/unipc\n      title: UniPCMultistepScheduler\n    - local: api/schedulers/vq_diffusion\n      title: VQDiffusionScheduler\n    title: Schedulers\n  - sections:\n    - local: api/internal_classes_overview\n      title: Overview\n    - local: api/attnprocessor\n      title: Attention Processor\n    - local: api/activations\n      title: Custom activation functions\n    - local: api/cache\n      title: Caching methods\n    - local: api/normalization\n      title: Custom normalization layers\n    - local: api/utilities\n      title: Utilities\n    - local: api/image_processor\n      title: VAE Image Processor\n    - local: api/video_processor\n      title: Video Processor\n    title: Internal classes\n  title: API\n"
  },
  {
    "path": "docs/source/en/advanced_inference/outpaint.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Outpainting\n\nOutpainting extends an image beyond its original boundaries, allowing you to add, replace, or modify visual elements in an image while preserving the original image. Like [inpainting](../using-diffusers/inpaint), you want to fill the white area (in this case, the area outside of the original image) with new visual elements while keeping the original image (represented by a mask of black pixels). 
There are a couple of ways to outpaint, such as with a [ControlNet](https://hf.co/blog/OzzyGT/outpainting-controlnet) or with [Differential Diffusion](https://hf.co/blog/OzzyGT/outpainting-differential-diffusion).\n\nThis guide will show you how to outpaint with an inpainting model, ControlNet, and a ZoeDepth estimator.\n\nBefore you begin, make sure you have the [controlnet_aux](https://github.com/huggingface/controlnet_aux) library installed so you can use the ZoeDepth estimator.\n\n```py\n!pip install -q controlnet_aux\n```\n\n## Image preparation\n\nStart by picking an image to outpaint with and remove the background with a Space like [BRIA-RMBG-1.4](https://hf.co/spaces/briaai/BRIA-RMBG-1.4).\n\n<iframe\n\tsrc=\"https://briaai-bria-rmbg-1-4.hf.space\"\n\tframeborder=\"0\"\n\twidth=\"850\"\n\theight=\"450\"\n></iframe>\n\nFor example, remove the background from this image of a pair of shoes.\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/original-jordan.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">original image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/no-background-jordan.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">background removed</figcaption>\n  </div>\n</div>\n\n[Stable Diffusion XL (SDXL)](../using-diffusers/sdxl) models work best with 1024x1024 images, but you can resize the image to any size as long as your hardware has enough memory to support it. The transparent background in the image should also be replaced with a white background. 
Create a function (like the one below) that scales and pastes the image onto a white background.\n\n```py\nimport random\n\nimport requests\nimport torch\nfrom controlnet_aux import ZoeDetector\nfrom PIL import Image, ImageOps\n\nfrom diffusers import (\n    AutoencoderKL,\n    ControlNetModel,\n    StableDiffusionXLControlNetPipeline,\n    StableDiffusionXLInpaintPipeline,\n)\n\ndef scale_and_paste(original_image):\n    aspect_ratio = original_image.width / original_image.height\n\n    if original_image.width > original_image.height:\n        new_width = 1024\n        new_height = round(new_width / aspect_ratio)\n    else:\n        new_height = 1024\n        new_width = round(new_height * aspect_ratio)\n\n    resized_original = original_image.resize((new_width, new_height), Image.LANCZOS)\n    white_background = Image.new(\"RGBA\", (1024, 1024), \"white\")\n    x = (1024 - new_width) // 2\n    y = (1024 - new_height) // 2\n    white_background.paste(resized_original, (x, y), resized_original)\n\n    return resized_original, white_background\n\noriginal_image = Image.open(\n    requests.get(\n        \"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/no-background-jordan.png\",\n        stream=True,\n    ).raw\n).convert(\"RGBA\")\nresized_img, white_bg_image = scale_and_paste(original_image)\n```\n\nTo avoid adding unwanted extra details, use the ZoeDepth estimator to provide additional guidance during generation and to ensure the shoes remain consistent with the original image.\n\n```py\nzoe = ZoeDetector.from_pretrained(\"lllyasviel/Annotators\")\nimage_zoe = zoe(white_bg_image, detect_resolution=512, image_resolution=1024)\nimage_zoe\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/zoedepth-jordan.png\"/>\n</div>\n\n## Outpaint\n\nOnce your image is ready, you can generate content in the white area around the shoes with 
[controlnet-inpaint-dreamer-sdxl](https://hf.co/destitech/controlnet-inpaint-dreamer-sdxl), an SDXL ControlNet trained for inpainting.\n\nLoad the inpainting ControlNet, the ZoeDepth ControlNet, and the VAE, and pass them to the [`StableDiffusionXLControlNetPipeline`]. Then you can create an optional `generate_image` function (for convenience) to outpaint an initial image.\n\n```py\ncontrolnets = [\n    ControlNetModel.from_pretrained(\n        \"destitech/controlnet-inpaint-dreamer-sdxl\", torch_dtype=torch.float16, variant=\"fp16\"\n    ),\n    ControlNetModel.from_pretrained(\n        \"diffusers/controlnet-zoe-depth-sdxl-1.0\", torch_dtype=torch.float16\n    ),\n]\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16).to(\"cuda\")\npipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n    \"SG161222/RealVisXL_V4.0\", torch_dtype=torch.float16, variant=\"fp16\", controlnet=controlnets, vae=vae\n).to(\"cuda\")\n\ndef generate_image(prompt, negative_prompt, inpaint_image, zoe_image, seed: int = None):\n    if seed is None:\n        seed = random.randint(0, 2**32 - 1)\n\n    generator = torch.Generator(device=\"cpu\").manual_seed(seed)\n\n    image = pipeline(\n        prompt,\n        negative_prompt=negative_prompt,\n        image=[inpaint_image, zoe_image],\n        guidance_scale=6.5,\n        num_inference_steps=25,\n        generator=generator,\n        controlnet_conditioning_scale=[0.5, 0.8],\n        control_guidance_end=[0.9, 0.6],\n    ).images[0]\n\n    return image\n\nprompt = \"nike air jordans on a basketball court\"\nnegative_prompt = \"\"\n\ntemp_image = generate_image(prompt, negative_prompt, white_bg_image, image_zoe, 908097)\n```\n\nPaste the original image over the initial outpainted image. You'll improve the outpainted background in a later step.\n\n```py\nx = (1024 - resized_img.width) // 2\ny = (1024 - resized_img.height) // 2\ntemp_image.paste(resized_img, (x, y), resized_img)\ntemp_image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/initial-outpaint.png\"/>\n</div>\n\n> [!TIP]\n> Now is a good time to free up some memory if you're running low!\n>\n> ```py\n> pipeline = None\n> torch.cuda.empty_cache()\n> ```\n\nNow that you have an initial outpainted image, load the [`StableDiffusionXLInpaintPipeline`] with the [RealVisXL](https://hf.co/SG161222/RealVisXL_V4.0) model to generate a higher-quality final outpainted image.\n\n```py\npipeline = StableDiffusionXLInpaintPipeline.from_pretrained(\n    \"OzzyGT/RealVisXL_V4.0_inpainting\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n    vae=vae,\n).to(\"cuda\")\n```\n\nPrepare a mask for the final outpainted image. To create a more natural transition between the original image and the outpainted background, blur the mask to help it blend better.\n\n```py\nmask = Image.new(\"L\", temp_image.size)\nmask.paste(resized_img.split()[3], (x, y))\nmask = ImageOps.invert(mask)\nfinal_mask = mask.point(lambda p: p > 128 and 255)\nmask_blurred = pipeline.mask_processor.blur(final_mask, blur_factor=20)\nmask_blurred\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/blurred-mask.png\"/>\n</div>\n\nWrite a more detailed prompt and pass it to the `generate_outpaint` function to generate the final outpainted image. Again, paste the original image over the final outpainted background.\n\n```py\ndef generate_outpaint(prompt, negative_prompt, image, mask, seed: int = None):\n    if seed is None:\n        seed = random.randint(0, 2**32 - 1)\n\n    generator = torch.Generator(device=\"cpu\").manual_seed(seed)\n\n    image = pipeline(\n        prompt,\n        negative_prompt=negative_prompt,\n        image=image,\n        mask_image=mask,\n        guidance_scale=10.0,\n        strength=0.8,\n        num_inference_steps=30,\n        generator=generator,\n    ).images[0]\n\n    return image\n\nprompt = \"high quality photo of nike air jordans on a basketball court, highly detailed\"\nnegative_prompt = \"\"\n\nfinal_image = generate_outpaint(prompt, negative_prompt, temp_image, mask_blurred, 7688778)\nx = (1024 - resized_img.width) // 2\ny = (1024 - resized_img.height) // 2\nfinal_image.paste(resized_img, (x, y), resized_img)\nfinal_image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/final-outpaint.png\"/>\n</div>\n"
  },
  {
    "path": "docs/source/en/api/activations.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Activation functions\n\nCustomized activation functions for supporting various models in 🤗 Diffusers.\n\n## GELU\n\n[[autodoc]] models.activations.GELU\n\n## GEGLU\n\n[[autodoc]] models.activations.GEGLU\n\n## ApproximateGELU\n\n[[autodoc]] models.activations.ApproximateGELU\n\n\n## SwiGLU\n\n[[autodoc]] models.activations.SwiGLU\n\n## FP32SiLU\n\n[[autodoc]] models.activations.FP32SiLU\n\n## LinearActivation\n\n[[autodoc]] models.activations.LinearActivation\n"
  },
  {
    "path": "docs/source/en/api/attnprocessor.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Attention Processor\n\nAn attention processor is a class for applying different types of attention mechanisms.\n\n## AttnProcessor\n\n[[autodoc]] models.attention_processor.AttnProcessor\n\n[[autodoc]] models.attention_processor.AttnProcessor2_0\n\n[[autodoc]] models.attention_processor.AttnAddedKVProcessor\n\n[[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0\n\n[[autodoc]] models.attention_processor.AttnProcessorNPU\n\n[[autodoc]] models.attention_processor.FusedAttnProcessor2_0\n\n## Allegro\n\n[[autodoc]] models.attention_processor.AllegroAttnProcessor2_0\n\n## AuraFlow\n\n[[autodoc]] models.attention_processor.AuraFlowAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.FusedAuraFlowAttnProcessor2_0\n\n## CogVideoX\n\n[[autodoc]] models.attention_processor.CogVideoXAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.FusedCogVideoXAttnProcessor2_0\n\n## CrossFrameAttnProcessor\n\n[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor\n\n## Custom Diffusion\n\n[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor\n\n[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor\n\n## Flux\n\n[[autodoc]] models.attention_processor.FluxAttnProcessor2_0\n\n[[autodoc]] 
models.attention_processor.FusedFluxAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.FluxSingleAttnProcessor2_0\n\n## Hunyuan\n\n[[autodoc]] models.attention_processor.HunyuanAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.FusedHunyuanAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.PAGHunyuanAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.PAGCFGHunyuanAttnProcessor2_0\n\n## IdentitySelfAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.PAGIdentitySelfAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0\n\n## IP-Adapter\n\n[[autodoc]] models.attention_processor.IPAdapterAttnProcessor\n\n[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0\n\n## JointAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.JointAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.PAGJointAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.PAGCFGJointAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.FusedJointAttnProcessor2_0\n\n## LoRA\n\n[[autodoc]] models.attention_processor.LoRAAttnProcessor\n\n[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor\n\n[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor\n\n## Lumina-T2X\n\n[[autodoc]] models.attention_processor.LuminaAttnProcessor2_0\n\n## Mochi\n\n[[autodoc]] models.attention_processor.MochiAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.MochiVaeAttnProcessor2_0\n\n## Sana\n\n[[autodoc]] models.attention_processor.SanaLinearAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.SanaMultiscaleAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0\n\n## Stable Audio\n\n[[autodoc]] 
models.attention_processor.StableAudioAttnProcessor2_0\n\n## SlicedAttnProcessor\n\n[[autodoc]] models.attention_processor.SlicedAttnProcessor\n\n[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor\n\n## XFormersAttnProcessor\n\n[[autodoc]] models.attention_processor.XFormersAttnProcessor\n\n[[autodoc]] models.attention_processor.XFormersAttnAddedKVProcessor\n\n## XLAFlashAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0\n\n## XFormersJointAttnProcessor\n\n[[autodoc]] models.attention_processor.XFormersJointAttnProcessor\n\n## IPAdapterXFormersAttnProcessor\n\n[[autodoc]] models.attention_processor.IPAdapterXFormersAttnProcessor\n\n## FluxIPAdapterJointAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.FluxIPAdapterJointAttnProcessor2_0\n\n\n## XLAFluxFlashAttnProcessor2_0\n\n[[autodoc]] models.attention_processor.XLAFluxFlashAttnProcessor2_0"
  },
  {
    "path": "docs/source/en/api/cache.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# Caching methods\n\nCaching methods speed up diffusion transformers by storing and reusing intermediate outputs of specific layers, such as attention and feedforward layers, instead of recalculating them at each inference step.\n\n## CacheMixin\n\n[[autodoc]] CacheMixin\n\n## PyramidAttentionBroadcastConfig\n\n[[autodoc]] PyramidAttentionBroadcastConfig\n\n[[autodoc]] apply_pyramid_attention_broadcast\n\n## FasterCacheConfig\n\n[[autodoc]] FasterCacheConfig\n\n[[autodoc]] apply_faster_cache\n\n## FirstBlockCacheConfig\n\n[[autodoc]] FirstBlockCacheConfig\n\n[[autodoc]] apply_first_block_cache\n\n## TaylorSeerCacheConfig\n\n[[autodoc]] TaylorSeerCacheConfig\n\n[[autodoc]] apply_taylorseer_cache\n"
  },
  {
    "path": "docs/source/en/api/configuration.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Configuration\n\nSchedulers from [`~schedulers.scheduling_utils.SchedulerMixin`] and models from [`ModelMixin`] inherit from [`ConfigMixin`], which stores all the parameters passed to their respective `__init__` methods in a JSON configuration file.\n\n> [!TIP]\n> To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with `hf auth login`.\n\n## ConfigMixin\n\n[[autodoc]] ConfigMixin\n\t- load_config\n\t- from_config\n\t- save_config\n\t- to_json_file\n\t- to_json_string\n"
  },
  {
    "path": "docs/source/en/api/image_processor.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# VAE Image Processor\n\nThe [`VaeImageProcessor`] provides a unified API for [`StableDiffusionPipeline`]s to prepare image inputs for VAE encoding and post-processing outputs once they're decoded. This includes transformations such as resizing, normalization, and conversion between PIL Image, PyTorch, and NumPy arrays.\n\nAll pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or NumPy arrays as image inputs and return outputs based on the `output_type` argument. You can pass encoded image latents directly to the pipeline and return latents from the pipeline as a specific output with the `output_type` argument (for example `output_type=\"latent\"`). This allows you to take the generated latents from one pipeline and pass them to another pipeline as input without leaving the latent space. It also makes it much easier to use multiple pipelines together by passing PyTorch tensors directly between different pipelines.\n\n## VaeImageProcessor\n\n[[autodoc]] image_processor.VaeImageProcessor\n\n## InpaintProcessor\n\nThe [`InpaintProcessor`] accepts `mask` and `image` inputs and processes them together. 
Optionally, it can accept `padding_mask_crop` and apply a mask overlay.\n\n[[autodoc]] image_processor.InpaintProcessor\n\n## VaeImageProcessorLDM3D\n\nThe [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.\n\n[[autodoc]] image_processor.VaeImageProcessorLDM3D\n\n## PixArtImageProcessor\n\n[[autodoc]] image_processor.PixArtImageProcessor\n\n## IPAdapterMaskProcessor\n\n[[autodoc]] image_processor.IPAdapterMaskProcessor\n"
  },
  {
    "path": "docs/source/en/api/internal_classes_overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Overview\n\nThe APIs in this section are more experimental and prone to breaking changes. Most of them are used internally for development, but they may also be useful to you if you're interested in building a diffusion model with some custom parts or if you're interested in some of our helper utilities for working with 🤗 Diffusers.\n"
  },
  {
    "path": "docs/source/en/api/loaders/ip_adapter.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# IP-Adapter\n\n[IP-Adapter](https://hf.co/papers/2308.06721) is a lightweight adapter that enables prompting a diffusion model with an image. This method decouples the cross-attention layers of the image and text features. The image features are generated from an image encoder.\n\n> [!TIP]\n> Learn how to load and use an IP-Adapter checkpoint and image in the [IP-Adapter](../../using-diffusers/ip_adapter) guide.\n\n## IPAdapterMixin\n\n[[autodoc]] loaders.ip_adapter.IPAdapterMixin\n\n## SD3IPAdapterMixin\n\n[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin\n    - all\n    - is_ip_adapter_active\n\n## IPAdapterMaskProcessor\n\n[[autodoc]] image_processor.IPAdapterMaskProcessor"
  },
  {
    "path": "docs/source/en/api/loaders/lora.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LoRA\n\nLoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MB) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the denoiser, text encoder or both. The denoiser usually corresponds to a UNet ([`UNet2DConditionModel`], for example) or a Transformer ([`SD3Transformer2DModel`], for example). There are several classes for loading LoRA weights:\n\n- [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.\n- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion XL (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. 
It can only be used with the SDXL model.\n- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).\n- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).\n- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).\n- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).\n- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).\n- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).\n- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).\n- [`HeliosLoraLoaderMixin`] provides similar functions for [Helios](https://huggingface.co/docs/diffusers/main/en/api/pipelines/helios).\n- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).\n- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).\n- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).\n- [`SkyReelsV2LoraLoaderMixin`] provides similar functions for [SkyReels-V2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/skyreels_v2).\n- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).\n- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].\n- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream 
Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream).\n- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).\n- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).\n- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).\n- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX-2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).\n- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.\n\n> [!TIP]\n> To learn more about how to load LoRA weights, see the [LoRA](../../tutorials/using_peft_for_inference) loading guide.\n\n## LoraBaseMixin\n\n[[autodoc]] loaders.lora_base.LoraBaseMixin\n\n## StableDiffusionLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin\n\n## StableDiffusionXLLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin\n\n## SD3LoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin\n\n## FluxLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin\n\n## Flux2LoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin\n\n## LTX2LoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin\n\n## CogVideoXLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin\n\n## Mochi1LoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin\n\n## AuraFlowLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin\n\n## LTXVideoLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin\n\n## SanaLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin\n\n## HeliosLoraLoaderMixin\n\n[[autodoc]] 
loaders.lora_pipeline.HeliosLoraLoaderMixin\n\n## HunyuanVideoLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin\n\n## Lumina2LoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin\n\n## CogView4LoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin\n\n## WanLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin\n\n## SkyReelsV2LoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.SkyReelsV2LoraLoaderMixin\n\n## AmusedLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin\n\n## HiDreamImageLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin\n\n## QwenImageLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin\n\n## ZImageLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.ZImageLoraLoaderMixin\n\n## KandinskyLoraLoaderMixin\n\n[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin"
  },
  {
    "path": "docs/source/en/api/loaders/peft.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# PEFT\n\nDiffusers supports loading adapters such as [LoRA](../../tutorials/using_peft_for_inference) with the [PEFT](https://huggingface.co/docs/peft/index) library through the [`~loaders.peft.PeftAdapterMixin`] class. This allows modeling classes in Diffusers, like [`UNet2DConditionModel`] and [`SD3Transformer2DModel`], to operate with an adapter.\n\n> [!TIP]\n> Refer to the [Inference with PEFT](../../tutorials/using_peft_for_inference.md) tutorial for an overview of how to use PEFT in Diffusers for inference.\n\n## PeftAdapterMixin\n\n[[autodoc]] loaders.peft.PeftAdapterMixin\n"
  },
  {
    "path": "docs/source/en/api/loaders/single_file.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Single files\n\nThe [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:\n\n* a model stored in a single file, which is useful if you're working with models from the diffusion ecosystem, like Automatic1111, that commonly rely on a single-file layout to store and share models\n* a model stored in its originally distributed layout, which is useful if you're working with models finetuned with other services, and want to load it directly into Diffusers model objects and pipelines\n\n> [!TIP]\n> Read the [Model files and layouts](../../using-diffusers/other-formats) guide to learn more about the Diffusers-multifolder layout versus the single-file layout, and how to load models stored in these different layouts.\n\n## Supported pipelines\n\n- [`StableDiffusionPipeline`]\n- [`StableDiffusionImg2ImgPipeline`]\n- [`StableDiffusionInpaintPipeline`]\n- [`StableDiffusionControlNetPipeline`]\n- [`StableDiffusionControlNetImg2ImgPipeline`]\n- [`StableDiffusionControlNetInpaintPipeline`]\n- [`StableDiffusionUpscalePipeline`]\n- [`StableDiffusionXLPipeline`]\n- [`StableDiffusionXLImg2ImgPipeline`]\n- [`StableDiffusionXLInpaintPipeline`]\n- [`StableDiffusionXLInstructPix2PixPipeline`]\n- [`StableDiffusionXLControlNetPipeline`]\n- [`StableDiffusionXLKDiffusionPipeline`]\n- [`StableDiffusion3Pipeline`]\n- [`LatentConsistencyModelPipeline`]\n- 
[`LatentConsistencyModelImg2ImgPipeline`]\n- [`StableDiffusionControlNetXSPipeline`]\n- [`StableDiffusionXLControlNetXSPipeline`]\n- [`LEditsPPPipelineStableDiffusion`]\n- [`LEditsPPPipelineStableDiffusionXL`]\n- [`PIAPipeline`]\n\n## Supported models\n\n- [`UNet2DConditionModel`]\n- [`StableCascadeUNet`]\n- [`AutoencoderKL`]\n- [`ControlNetModel`]\n- [`SD3Transformer2DModel`]\n- [`FluxTransformer2DModel`]\n\n## FromSingleFileMixin\n\n[[autodoc]] loaders.single_file.FromSingleFileMixin\n\n## FromOriginalModelMixin\n\n[[autodoc]] loaders.single_file_model.FromOriginalModelMixin\n"
  },
  {
    "path": "docs/source/en/api/loaders/textual_inversion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Textual Inversion\n\nTextual Inversion is a training method for personalizing models by learning new text embeddings from a few example images. The file produced from training is extremely small (a few KBs) and the new embeddings can be loaded into the text encoder.\n\n[`TextualInversionLoaderMixin`] provides a function for loading Textual Inversion embeddings from Diffusers and Automatic1111 into the text encoder and loading a special token to activate the embeddings.\n\n> [!TIP]\n> To learn more about how to load Textual Inversion embeddings, see the [Textual Inversion](../../using-diffusers/textual_inversion_inference) loading guide.\n\n## TextualInversionLoaderMixin\n\n[[autodoc]] loaders.textual_inversion.TextualInversionLoaderMixin"
  },
  {
    "path": "docs/source/en/api/loaders/transformer_sd3.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# SD3Transformer2D\n\nThis class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder, or into both the text encoder and the [`SD3Transformer2DModel`], use the [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.\n\nThe [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.\n\n> [!TIP]\n> To learn more about how to load LoRA weights, see the [LoRA](../../tutorials/using_peft_for_inference) loading guide.\n\n## SD3Transformer2DLoadersMixin\n\n[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin\n    - all\n    - _load_ip_adapter_weights"
  },
  {
    "path": "docs/source/en/api/loaders/unet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# UNet\n\nSome training methods - like LoRA and Custom Diffusion - typically target the UNet's attention layers, but these training methods can also target other non-attention layers. Instead of training all of a model's parameters, only a subset of the parameters are trained, which is faster and more efficient. This class is useful if you're *only* loading weights into a UNet. If you need to load weights into the text encoder or a text encoder and UNet, try using the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] function instead.\n\nThe [`UNet2DConditionLoadersMixin`] class provides functions for loading and saving weights, fusing and unfusing LoRAs, disabling and enabling LoRAs, and setting and deleting adapters.\n\n> [!TIP]\n> To learn more about how to load LoRA weights, see the [LoRA](../../tutorials/using_peft_for_inference) guide.\n\n## UNet2DConditionLoadersMixin\n\n[[autodoc]] loaders.unet.UNet2DConditionLoadersMixin"
  },
  {
    "path": "docs/source/en/api/logging.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Logging\n\n🤗 Diffusers has a centralized logging system to easily manage the verbosity of the library. The default verbosity is set to `WARNING`.\n\nTo change the verbosity level, use one of the direct setters. For instance, to change the verbosity to the `INFO` level:\n\n```python\nimport diffusers\n\ndiffusers.logging.set_verbosity_info()\n```\n\nYou can also use the environment variable `DIFFUSERS_VERBOSITY` to override the default verbosity. You can set it\nto one of the following: `debug`, `info`, `warning`, `error`, `critical`. For example:\n\n```bash\nDIFFUSERS_VERBOSITY=error ./myprogram.py\n```\n\nAdditionally, some `warnings` can be disabled by setting the environment variable\n`DIFFUSERS_NO_ADVISORY_WARNINGS` to a true value, like `1`. This disables any warning logged by\n[`logger.warning_advice`]. For example:\n\n```bash\nDIFFUSERS_NO_ADVISORY_WARNINGS=1 ./myprogram.py\n```\n\nHere is an example of how to use the same logger as the library in your own module or script:\n\n```python\nfrom diffusers.utils import logging\n\nlogging.set_verbosity_info()\nlogger = logging.get_logger(\"diffusers\")\nlogger.info(\"INFO\")\nlogger.warning(\"WARN\")\n```\n\n\nAll methods of the logging module are documented below. 
The main methods are\n[`logging.get_verbosity`] to get the current level of verbosity in the logger and\n[`logging.set_verbosity`] to set the verbosity to the level of your choice.\n\nIn order from the least verbose to the most verbose:\n\n|                                                    Method | Integer value |                                         Description |\n|----------------------------------------------------------:|--------------:|----------------------------------------------------:|\n| `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` |            50 |                only report the most critical errors |\n|                                 `diffusers.logging.ERROR` |            40 |                                  only report errors |\n|   `diffusers.logging.WARNING` or `diffusers.logging.WARN` |            30 |           only report errors and warnings (default) |\n|                                  `diffusers.logging.INFO` |            20 | only report errors, warnings, and basic information |\n|                                 `diffusers.logging.DEBUG` |            10 |                              report all information |\n\nBy default, `tqdm` progress bars are displayed during model download. 
[`logging.disable_progress_bar`] and [`logging.enable_progress_bar`] are used to disable or enable this behavior.\n\n## Base setters\n\n[[autodoc]] utils.logging.set_verbosity_error\n\n[[autodoc]] utils.logging.set_verbosity_warning\n\n[[autodoc]] utils.logging.set_verbosity_info\n\n[[autodoc]] utils.logging.set_verbosity_debug\n\n## Other functions\n\n[[autodoc]] utils.logging.get_verbosity\n\n[[autodoc]] utils.logging.set_verbosity\n\n[[autodoc]] utils.logging.get_logger\n\n[[autodoc]] utils.logging.enable_default_handler\n\n[[autodoc]] utils.logging.disable_default_handler\n\n[[autodoc]] utils.logging.enable_explicit_format\n\n[[autodoc]] utils.logging.reset_format\n\n[[autodoc]] utils.logging.enable_progress_bar\n\n[[autodoc]] utils.logging.disable_progress_bar\n"
  },
  {
    "path": "docs/source/en/api/models/allegro_transformer3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AllegroTransformer3DModel\n\nA Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro), introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nimport torch\nfrom diffusers import AllegroTransformer3DModel\n\ntransformer = AllegroTransformer3DModel.from_pretrained(\"rhymes-ai/Allegro\", subfolder=\"transformer\", torch_dtype=torch.bfloat16).to(\"cuda\")\n```\n\n## AllegroTransformer3DModel\n\n[[autodoc]] AllegroTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/asymmetricautoencoderkl.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AsymmetricAutoencoderKL\n\nImproved larger variational autoencoder (VAE) model with KL loss for inpainting task: [Designing a Better Asymmetric VQGAN for StableDiffusion](https://huggingface.co/papers/2306.04632) by Zixin Zhu, Xuelu Feng, Dongdong Chen, Jianmin Bao, Le Wang, Yinpeng Chen, Lu Yuan, Gang Hua.\n\nThe abstract from the paper is:\n\n*StableDiffusion is a revolutionary text-to-image generator that is causing a stir in the world of image generation and editing. Unlike traditional methods that learn a diffusion model in pixel space, StableDiffusion learns a diffusion model in the latent space via a VQGAN, ensuring both efficiency and quality. It not only supports image generation tasks, but also enables image editing for real images, such as image inpainting and local editing. However, we have observed that the vanilla VQGAN used in StableDiffusion leads to significant information loss, causing distortion artifacts even in non-edited image regions. To this end, we propose a new asymmetric VQGAN with two simple designs. Firstly, in addition to the input from the encoder, the decoder contains a conditional branch that incorporates information from task-specific priors, such as the unmasked image region in inpainting. 
Secondly, the decoder is much heavier than the encoder, allowing for more detailed recovery while only slightly increasing the total inference cost. The training cost of our asymmetric VQGAN is cheap, and we only need to retrain a new asymmetric decoder while keeping the vanilla VQGAN encoder and StableDiffusion unchanged. Our asymmetric VQGAN can be widely used in StableDiffusion-based inpainting and local editing methods. Extensive experiments demonstrate that it can significantly improve the inpainting and editing performance, while maintaining the original text-to-image capability. The code is available at https://github.com/buxiangzhiren/Asymmetric_VQGAN*\n\nEvaluation results can be found in section 4.1 of the original paper.\n\n## Available checkpoints\n\n* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5)\n* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2)\n\n## Example Usage\n\n```python\nfrom diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline\nfrom diffusers.utils import load_image, make_image_grid\n\n\nprompt = \"a photo of a person with beard\"\nimg_url = \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png\"\nmask_url = \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png\"\n\noriginal_image = load_image(img_url).resize((512, 512))\nmask_image = load_image(mask_url).resize((512, 512))\n\npipe = StableDiffusionInpaintPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-inpainting\")\npipe.vae = AsymmetricAutoencoderKL.from_pretrained(\"cross-attention/asymmetric-autoencoder-kl-x-1-5\")\npipe.to(\"cuda\")\n\nimage = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]\nmake_image_grid([original_image, 
mask_image, image], rows=1, cols=3)\n```\n\n## AsymmetricAutoencoderKL\n\n[[autodoc]] models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL\n\n## AutoencoderKLOutput\n\n[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/aura_flow_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AuraFlowTransformer2DModel\n\nA Transformer model for image-like data from [AuraFlow](https://blog.fal.ai/auraflow/).\n\n## AuraFlowTransformer2DModel\n\n[[autodoc]] AuraFlowTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/auto_model.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoModel\n\n[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file.\n\n## AutoModel\n\n[[autodoc]] AutoModel\n\t- all\n\t- from_pretrained\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_dc.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderDC\n\nThe 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by authors Junyu Chen\\*, Han Cai\\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, Song Han from MIT HAN Lab.\n\nThe abstract from the paper is:\n\n*We present Deep Compression Autoencoder (DC-AE), a new family of autoencoder models for accelerating high-resolution diffusion models. Existing autoencoder models have demonstrated impressive results at a moderate spatial compression ratio (e.g., 8x), but fail to maintain satisfactory reconstruction accuracy for high spatial compression ratios (e.g., 64x). We address this challenge by introducing two key techniques: (1) Residual Autoencoding, where we design our models to learn residuals based on the space-to-channel transformed features to alleviate the optimization difficulty of high spatial-compression autoencoders; (2) Decoupled High-Resolution Adaptation, an efficient decoupled three-phases training strategy for mitigating the generalization penalty of high spatial-compression autoencoders. With these designs, we improve the autoencoder's spatial compression ratio up to 128 while maintaining the reconstruction quality. Applying our DC-AE to latent diffusion models, we achieve significant speedup without accuracy drop. 
For example, on ImageNet 512x512, our DC-AE provides 19.1x inference speedup and 17.9x training speedup on H100 GPU for UViT-H while achieving a better FID, compared with the widely used SD-VAE-f8 autoencoder. Our code is available at [this https URL](https://github.com/mit-han-lab/efficientvit).*\n\nThe following DCAE models are released and supported in Diffusers.\n\n| Diffusers format | Original format |\n|:----------------:|:---------------:|\n| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0)\n| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0)\n| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0)\n| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0)\n| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0)\n| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0)\n| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0)\n\nThis model was contributed by 
[lawrence-cj](https://github.com/lawrence-cj).\n\nLoad a model in Diffusers format with [`~ModelMixin.from_pretrained`].\n\n```python\nimport torch\nfrom diffusers import AutoencoderDC\n\nae = AutoencoderDC.from_pretrained(\"mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers\", torch_dtype=torch.float32).to(\"cuda\")\n```\n\n## Load a model in Diffusers via `from_single_file`\n\n```python\nfrom diffusers import AutoencoderDC\n\nckpt_path = \"https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors\"\nmodel = AutoencoderDC.from_single_file(ckpt_path)\n```\n\nThe `AutoencoderDC` model has `in` and `mix` single file checkpoint variants that have matching checkpoint keys but use different scaling factors. Diffusers cannot infer the correct config file from the checkpoint alone, so it defaults to the `mix` variant config file. To override the automatically determined config, pass the `config` argument when loading `in` variant checkpoints from a single file.\n\n```python\nfrom diffusers import AutoencoderDC\n\nckpt_path = \"https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors\"\nmodel = AutoencoderDC.from_single_file(ckpt_path, config=\"mit-han-lab/dc-ae-f128c512-in-1.0-diffusers\")\n```\n\n## AutoencoderDC\n\n[[autodoc]] AutoencoderDC\n  - encode\n  - decode\n  - all\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_kl_hunyuan_video.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLHunyuanVideo\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLHunyuanVideo\n\nvae = AutoencoderKLHunyuanVideo.from_pretrained(\"hunyuanvideo-community/HunyuanVideo\", subfolder=\"vae\", torch_dtype=torch.float16)\n```\n\n## AutoencoderKLHunyuanVideo\n\n[[autodoc]] AutoencoderKLHunyuanVideo\n  - decode\n  - all\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLHunyuanVideo15\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5) by Tencent.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLHunyuanVideo15\n\nvae = AutoencoderKLHunyuanVideo15.from_pretrained(\"hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v\", subfolder=\"vae\", torch_dtype=torch.float32)\n\n# make sure to enable tiling to avoid OOM\nvae.enable_tiling()\n```\n\n## AutoencoderKLHunyuanVideo15\n\n[[autodoc]] AutoencoderKLHunyuanVideo15\n  - decode\n  - encode\n  - all\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_kl_hunyuanimage.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLHunyuanImage\n\nThe 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1].\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLHunyuanImage\n\nvae = AutoencoderKLHunyuanImage.from_pretrained(\"hunyuanvideo-community/HunyuanImage-2.1-Diffusers\", subfolder=\"vae\", torch_dtype=torch.bfloat16)\n```\n\n## AutoencoderKLHunyuanImage\n\n[[autodoc]] AutoencoderKLHunyuanImage\n  - decode\n  - all\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_kl_hunyuanimage_refiner.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLHunyuanImageRefiner\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLHunyuanImageRefiner\n\nvae = AutoencoderKLHunyuanImageRefiner.from_pretrained(\"hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers\", subfolder=\"vae\", torch_dtype=torch.bfloat16)\n```\n\n## AutoencoderKLHunyuanImageRefiner\n\n[[autodoc]] AutoencoderKLHunyuanImageRefiner\n  - decode\n  - all\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_kl_wan.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLWan\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLWan\n\nvae = AutoencoderKLWan.from_pretrained(\"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\", subfolder=\"vae\", torch_dtype=torch.float32)\n```\n\n## AutoencoderKLWan\n\n[[autodoc]] AutoencoderKLWan\n  - decode\n  - all\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_oobleck.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoencoderOobleck\n\nThe Oobleck variational autoencoder (VAE) model with KL loss was introduced in [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) and [Stable Audio Open](https://huggingface.co/papers/2407.14358) by Stability AI. The model is used in 🤗 Diffusers to encode audio waveforms into latents and to decode latent representations into audio waveforms.\n\nThe abstract from the paper is:\n\n*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. 
Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.*\n\n## AutoencoderOobleck\n\n[[autodoc]] AutoencoderOobleck\n    - decode\n    - encode\n    - all\n\n## OobleckDecoderOutput\n\n[[autodoc]] models.autoencoders.autoencoder_oobleck.OobleckDecoderOutput\n\n## AutoencoderOobleckOutput\n\n[[autodoc]] models.autoencoders.autoencoder_oobleck.AutoencoderOobleckOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_rae.md",
    "content": "<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoencoderRAE\n\nThe Representation Autoencoder (RAE) model introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie from NYU VISIONx.\n\nRAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. 
In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).\n\nThe following RAE models are released and supported in Diffusers:\n\n| Model | Encoder | Latent shape (224px input) |\n|:------|:--------|:---------------------------|\n| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |\n| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |\n| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |\n| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |\n| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |\n| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |\n\n## Loading a pretrained model\n\n```python\nfrom diffusers import AutoencoderRAE\n\nmodel = AutoencoderRAE.from_pretrained(\n    \"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08\"\n).to(\"cuda\").eval()\n```\n\n## Encoding and decoding a real image\n\n```python\nimport torch\nfrom diffusers import AutoencoderRAE\nfrom diffusers.utils import load_image\nfrom torchvision.transforms.functional import to_tensor, to_pil_image\n\nmodel = AutoencoderRAE.from_pretrained(\n    \"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08\"\n).to(\"cuda\").eval()\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\")\nimage = image.convert(\"RGB\").resize((224, 224))\nx = 
to_tensor(image).unsqueeze(0).to(\"cuda\")  # (1, 3, 224, 224), values in [0, 1]\n\nwith torch.no_grad():\n    latents = model.encode(x).latent        # (1, 768, 16, 16)\n    recon = model.decode(latents).sample     # (1, 3, 256, 256)\n\nrecon_image = to_pil_image(recon[0].clamp(0, 1).cpu())\nrecon_image.save(\"recon.png\")\n```\n\n## Latent normalization\n\nSome pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.\n\n```python\nmodel = AutoencoderRAE.from_pretrained(\n    \"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08\"\n).to(\"cuda\").eval()\n\n# Latent normalization is handled automatically inside encode/decode\n# when the checkpoint config includes latents_mean/latents_std.\nwith torch.no_grad():\n    latents = model.encode(x).latent   # normalized latents\n    recon = model.decode(latents).sample\n```\n\n## AutoencoderRAE\n\n[[autodoc]] AutoencoderRAE\n  - encode\n  - decode\n  - all\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoder_tiny.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Tiny AutoEncoder\n\nTiny AutoEncoder for Stable Diffusion (TAESD) was introduced in [madebyollin/taesd](https://github.com/madebyollin/taesd) by Ollin Boer Bohan. It is a tiny distilled version of Stable Diffusion's VAE that can quickly decode the latents in a [`StableDiffusionPipeline`] or [`StableDiffusionXLPipeline`] almost instantly.\n\nTo use with Stable Diffusion v-2.1:\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline, AutoencoderTiny\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-1-base\", torch_dtype=torch.float16\n)\npipe.vae = AutoencoderTiny.from_pretrained(\"madebyollin/taesd\", torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\n\nprompt = \"slice of delicious New York-style berry cheesecake\"\nimage = pipe(prompt, num_inference_steps=25).images[0]\nimage\n```\n\nTo use with Stable Diffusion XL 1.0\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline, AutoencoderTiny\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n)\npipe.vae = AutoencoderTiny.from_pretrained(\"madebyollin/taesdxl\", torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\n\nprompt = \"slice of delicious New York-style berry cheesecake\"\nimage = pipe(prompt, num_inference_steps=25).images[0]\nimage\n```\n\n## 
AutoencoderTiny\n\n[[autodoc]] AutoencoderTiny\n\n## AutoencoderTinyOutput\n\n[[autodoc]] models.autoencoders.autoencoder_tiny.AutoencoderTinyOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoencoderKL\n\nThe variational autoencoder (VAE) model with KL loss was introduced in [Auto-Encoding Variational Bayes](https://huggingface.co/papers/1312.6114v11) by Diederik P. Kingma and Max Welling. The model is used in 🤗 Diffusers to encode images into latents and to decode latent representations into images.\n\nThe abstract from the paper is:\n\n*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. 
Theoretical advantages are reflected in experimental results.*\n\n## Loading from the original format\n\nBy default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded\nfrom the original format using [`FromOriginalModelMixin.from_single_file`] as follows:\n\n```py\nfrom diffusers import AutoencoderKL\n\nurl = \"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors\"  # can also be a local file\nmodel = AutoencoderKL.from_single_file(url)\n```\n\n## AutoencoderKL\n\n[[autodoc]] AutoencoderKL\n    - decode\n    - encode\n    - all\n\n## AutoencoderKLOutput\n\n[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_allegro.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLAllegro\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLAllegro\n\nvae = AutoencoderKLAllegro.from_pretrained(\"rhymes-ai/Allegro\", subfolder=\"vae\", torch_dtype=torch.float32).to(\"cuda\")\n```\n\n## AutoencoderKLAllegro\n\n[[autodoc]] AutoencoderKLAllegro\n    - decode\n    - encode\n    - all\n\n## AutoencoderKLOutput\n\n[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_audio_ltx_2.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLLTX2Audio\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLLTX2Audio\n\nvae = AutoencoderKLLTX2Audio.from_pretrained(\"Lightricks/LTX-2\", subfolder=\"vae\", torch_dtype=torch.float32).to(\"cuda\")\n```\n\n## AutoencoderKLLTX2Audio\n\n[[autodoc]] AutoencoderKLLTX2Audio\n    - encode\n    - decode\n    - all"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_cogvideox.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLCogVideoX\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLCogVideoX\n\nvae = AutoencoderKLCogVideoX.from_pretrained(\"THUDM/CogVideoX-2b\", subfolder=\"vae\", torch_dtype=torch.float16).to(\"cuda\")\n```\n\n## AutoencoderKLCogVideoX\n\n[[autodoc]] AutoencoderKLCogVideoX\n    - decode\n    - encode\n    - all\n\n## AutoencoderKLOutput\n\n[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_cosmos.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLCosmos\n\n[Cosmos Tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer).\n\nSupported models:\n- [nvidia/Cosmos-1.0-Tokenizer-CV8x8x8](https://huggingface.co/nvidia/Cosmos-1.0-Tokenizer-CV8x8x8)\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLCosmos\n\nvae = AutoencoderKLCosmos.from_pretrained(\"nvidia/Cosmos-1.0-Tokenizer-CV8x8x8\", subfolder=\"vae\")\n```\n\n## AutoencoderKLCosmos\n\n[[autodoc]] AutoencoderKLCosmos\n    - decode\n    - encode\n    - all\n\n## AutoencoderKLOutput\n\n[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_ltx_2.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLLTX2Video\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLLTX2Video\n\nvae = AutoencoderKLLTX2Video.from_pretrained(\"Lightricks/LTX-2\", subfolder=\"vae\", torch_dtype=torch.float32).to(\"cuda\")\n```\n\n## AutoencoderKLLTX2Video\n\n[[autodoc]] AutoencoderKLLTX2Video\n    - decode\n    - encode\n    - all\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_ltx_video.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLLTXVideo\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLLTXVideo\n\nvae = AutoencoderKLLTXVideo.from_pretrained(\"Lightricks/LTX-Video\", subfolder=\"vae\", torch_dtype=torch.float32).to(\"cuda\")\n```\n\n## AutoencoderKLLTXVideo\n\n[[autodoc]] AutoencoderKLLTXVideo\n    - decode\n    - encode\n    - all\n\n## AutoencoderKLOutput\n\n[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_magvit.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLMagvit\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLMagvit\n\nvae = AutoencoderKLMagvit.from_pretrained(\"alibaba-pai/EasyAnimateV5.1-12b-zh\", subfolder=\"vae\", torch_dtype=torch.float16).to(\"cuda\")\n```\n\n## AutoencoderKLMagvit\n\n[[autodoc]] AutoencoderKLMagvit\n    - decode\n    - encode\n    - all\n\n## AutoencoderKLOutput\n\n[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_mochi.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLMochi\n\nThe 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Tsinghua University & ZhipuAI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLMochi\n\nvae = AutoencoderKLMochi.from_pretrained(\"genmo/mochi-1-preview\", subfolder=\"vae\", torch_dtype=torch.float32).to(\"cuda\")\n```\n\n## AutoencoderKLMochi\n\n[[autodoc]] AutoencoderKLMochi\n    - decode\n    - all\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/autoencoderkl_qwenimage.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# AutoencoderKLQwenImage\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import AutoencoderKLQwenImage\n\nvae = AutoencoderKLQwenImage.from_pretrained(\"Qwen/QwenImage-20B\", subfolder=\"vae\")\n```\n\n## AutoencoderKLQwenImage\n\n[[autodoc]] AutoencoderKLQwenImage\n    - decode\n    - encode\n    - all\n\n## AutoencoderKLOutput\n\n[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput\n\n## DecoderOutput\n\n[[autodoc]] models.autoencoders.vae.DecoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/bria_transformer.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# BriaTransformer2DModel\n\nA modified flux Transformer model from [Bria](https://huggingface.co/briaai/BRIA-3.2)\n\n## BriaTransformer2DModel\n\n[[autodoc]] BriaTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/chroma_transformer.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ChromaTransformer2DModel\n\nA modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)\n\n## ChromaTransformer2DModel\n\n[[autodoc]] ChromaTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/chronoedit_transformer_3d.md",
    "content": "<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# ChronoEditTransformer3DModel\n\nA Diffusion Transformer model for 3D video-like data from [ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.\n\n> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import ChronoEditTransformer3DModel\n\ntransformer = ChronoEditTransformer3DModel.from_pretrained(\"nvidia/ChronoEdit-14B-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## ChronoEditTransformer3DModel\n\n[[autodoc]] ChronoEditTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/cogvideox_transformer3d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# CogVideoXTransformer3DModel\n\nA Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import CogVideoXTransformer3DModel\n\ntransformer = CogVideoXTransformer3DModel.from_pretrained(\"THUDM/CogVideoX-2b\", subfolder=\"transformer\", torch_dtype=torch.float16).to(\"cuda\")\n```\n\n## CogVideoXTransformer3DModel\n\n[[autodoc]] CogVideoXTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/cogview3plus_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# CogView3PlusTransformer2DModel\n\nA Diffusion Transformer model for 2D data from [CogView3Plus](https://github.com/THUDM/CogView3) was introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import CogView3PlusTransformer2DModel\n\ntransformer = CogView3PlusTransformer2DModel.from_pretrained(\"THUDM/CogView3Plus-3b\", subfolder=\"transformer\", torch_dtype=torch.bfloat16).to(\"cuda\")\n```\n\n## CogView3PlusTransformer2DModel\n\n[[autodoc]] CogView3PlusTransformer2DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/cogview4_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# CogView4Transformer2DModel\n\nA Diffusion Transformer model for 2D data from [CogView4]()\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import CogView4Transformer2DModel\n\ntransformer = CogView4Transformer2DModel.from_pretrained(\"THUDM/CogView4-6B\", subfolder=\"transformer\", torch_dtype=torch.bfloat16).to(\"cuda\")\n```\n\n## CogView4Transformer2DModel\n\n[[autodoc]] CogView4Transformer2DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/consisid_transformer3d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# ConsisIDTransformer3DModel\n\nA Diffusion Transformer model for 3D data from [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) was introduced in [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://huggingface.co/papers/2411.17440) by Peking University & University of Rochester & etc.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import ConsisIDTransformer3DModel\n\ntransformer = ConsisIDTransformer3DModel.from_pretrained(\"BestWishYsh/ConsisID-preview\", subfolder=\"transformer\", torch_dtype=torch.bfloat16).to(\"cuda\")\n```\n\n## ConsisIDTransformer3DModel\n\n[[autodoc]] ConsisIDTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/consistency_decoder_vae.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Consistency Decoder\n\nConsistency decoder can be used to decode the latents from the denoising UNet in the [`StableDiffusionPipeline`]. This decoder was introduced in the [DALL-E 3 technical report](https://openai.com/dall-e-3).\n\nThe original codebase can be found at [openai/consistencydecoder](https://github.com/openai/consistencydecoder).\n\n> [!WARNING]\n> Inference is only supported for 2 iterations as of now.\n\nThe pipeline could not have been contributed without the help of [madebyollin](https://github.com/madebyollin) and [mrsteyk](https://github.com/mrsteyk) from [this issue](https://github.com/openai/consistencydecoder/issues/1).\n\n## ConsistencyDecoderVAE\n[[autodoc]] ConsistencyDecoderVAE\n    - all\n    - decode\n"
  },
  {
    "path": "docs/source/en/api/models/controlnet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNetModel\n\nThe ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. 
Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\n## Loading from the original format\n\nBy default, the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded\nfrom the original format using [`FromOriginalModelMixin.from_single_file`] as follows:\n\n```py\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel\n\nurl = \"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth\"  # can also be a local path\ncontrolnet = ControlNetModel.from_single_file(url)\n\nurl = \"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors\"  # can also be a local path\npipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)\n```\n\n## Loading from Control LoRA\n\nControl-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) by adding low-rank, parameter-efficient fine-tuning to ControlNet. This approach offers a more efficient and compact method to bring model control to a wider variety of consumer GPUs.\n\n```py\nimport torch\nfrom diffusers import ControlNetModel, UNet2DConditionModel\n\nlora_id = \"stabilityai/control-lora\"\nlora_filename = \"control-LoRAs-rank128/control-lora-canny-rank128.safetensors\"\n\nunet = UNet2DConditionModel.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"unet\", torch_dtype=torch.bfloat16).to(\"cuda\")\ncontrolnet = ControlNetModel.from_unet(unet).to(device=\"cuda\", dtype=torch.bfloat16)\ncontrolnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)\n```\n\n## ControlNetModel\n\n[[autodoc]] ControlNetModel\n\n## ControlNetOutput\n\n[[autodoc]] models.controlnets.controlnet.ControlNetOutput\n"
  },
  {
    "path": "docs/source/en/api/models/controlnet_flux.md",
    "content": "<!--Copyright 2025 The HuggingFace Team and The InstantX Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# FluxControlNetModel\n\nFluxControlNetModel is an implementation of ControlNet for Flux.1.\n\nThe ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. 
We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\n## Loading from the original format\n\nBy default the [`FluxControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`].\n\n```py\nfrom diffusers import FluxControlNetPipeline\nfrom diffusers.models import FluxControlNetModel, FluxMultiControlNetModel\n\ncontrolnet = FluxControlNetModel.from_pretrained(\"InstantX/FLUX.1-dev-Controlnet-Canny\")\npipe = FluxControlNetPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", controlnet=controlnet)\n\ncontrolnet = FluxControlNetModel.from_pretrained(\"InstantX/FLUX.1-dev-Controlnet-Canny\")\ncontrolnet = FluxMultiControlNetModel([controlnet])\npipe = FluxControlNetPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", controlnet=controlnet)\n```\n\n## FluxControlNetModel\n\n[[autodoc]] FluxControlNetModel\n\n## FluxControlNetOutput\n\n[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput"
  },
  {
    "path": "docs/source/en/api/models/controlnet_hunyuandit.md",
    "content": "<!--Copyright 2025 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# HunyuanDiT2DControlNetModel\n\nHunyuanDiT2DControlNetModel is an implementation of ControlNet for [Hunyuan-DiT](https://huggingface.co/papers/2405.08748).\n\nControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.\n\nWith a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. 
We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\nThis code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).\n\n## Example For Loading HunyuanDiT2DControlNetModel\n\n```py\nfrom diffusers import HunyuanDiT2DControlNetModel\nimport torch\ncontrolnet = HunyuanDiT2DControlNetModel.from_pretrained(\"Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose\", torch_dtype=torch.float16)\n```\n\n## HunyuanDiT2DControlNetModel\n\n[[autodoc]] HunyuanDiT2DControlNetModel"
  },
  {
    "path": "docs/source/en/api/models/controlnet_sana.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# SanaControlNetModel\n\nThe ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. 
Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\nThis model was contributed by [ishan24](https://huggingface.co/ishan24). ❤️\nThe original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.\n\n## SanaControlNetModel\n[[autodoc]] SanaControlNetModel\n\n## SanaControlNetOutput\n[[autodoc]] models.controlnets.controlnet_sana.SanaControlNetOutput\n\n"
  },
  {
    "path": "docs/source/en/api/models/controlnet_sd3.md",
    "content": "<!--Copyright 2025 The HuggingFace Team and The InstantX Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# SD3ControlNetModel\n\nSD3ControlNetModel is an implementation of ControlNet for Stable Diffusion 3.\n\nThe ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. 
We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\n## Loading from the original format\n\nBy default the [`SD3ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`].\n\n```py\nfrom diffusers import StableDiffusion3ControlNetPipeline\nfrom diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel\n\ncontrolnet = SD3ControlNetModel.from_pretrained(\"InstantX/SD3-Controlnet-Canny\")\npipe = StableDiffusion3ControlNetPipeline.from_pretrained(\"stabilityai/stable-diffusion-3-medium-diffusers\", controlnet=controlnet)\n```\n\n## SD3ControlNetModel\n\n[[autodoc]] SD3ControlNetModel\n\n## SD3ControlNetOutput\n\n[[autodoc]] models.controlnets.controlnet_sd3.SD3ControlNetOutput\n\n"
  },
  {
    "path": "docs/source/en/api/models/controlnet_sparsectrl.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# SparseControlNetModel\n\nSparseControlNetModel is an implementation of ControlNet for [AnimateDiff](https://huggingface.co/papers/2307.04725).\n\nControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.\n\nThe SparseCtrl version of ControlNet was introduced in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://huggingface.co/papers/2311.16933) for achieving controlled generation in text-to-video diffusion models by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai.\n\nThe abstract from the paper is:\n\n*The development of text-to-video (T2V), i.e., generating videos with a given text prompt, has been significantly advanced in recent years. However, relying solely on text prompts often results in ambiguous frame composition due to spatial uncertainty. The research community thus leverages the dense structure signals, e.g., per-frame depth/edge sequences, to enhance controllability, whose collection accordingly increases the burden of inference. In this work, we present SparseCtrl to enable flexible structure control with temporally sparse signals, requiring only one or a few inputs, as shown in Figure 1. 
It incorporates an additional condition encoder to process these sparse signals while leaving the pre-trained T2V model untouched. The proposed approach is compatible with various modalities, including sketches, depth maps, and RGB images, providing more practical control for video generation and promoting applications such as storyboarding, depth rendering, keyframe animation, and interpolation. Extensive experiments demonstrate the generalization of SparseCtrl on both original and personalized T2V generators. Codes and models will be publicly available at [this https URL](https://guoyww.github.io/projects/SparseCtrl).*\n\n## Example for loading SparseControlNetModel\n\n```python\nimport torch\nfrom diffusers import SparseControlNetModel\n\n# fp32 variant in float16\n# 1. Scribble checkpoint\ncontrolnet = SparseControlNetModel.from_pretrained(\"guoyww/animatediff-sparsectrl-scribble\", torch_dtype=torch.float16)\n\n# 2. RGB checkpoint\ncontrolnet = SparseControlNetModel.from_pretrained(\"guoyww/animatediff-sparsectrl-rgb\", torch_dtype=torch.float16)\n\n# For loading fp16 variant, pass `variant=\"fp16\"` as an additional parameter\n```\n\n## SparseControlNetModel\n\n[[autodoc]] SparseControlNetModel\n\n## SparseControlNetOutput\n\n[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput\n"
  },
  {
    "path": "docs/source/en/api/models/controlnet_union.md",
    "content": "<!--Copyright 2025 The HuggingFace Team and The InstantX Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNetUnionModel\n\nControlNetUnionModel is an implementation of ControlNet for Stable Diffusion XL.\n\nThe ControlNet model was introduced in [ControlNetPlus](https://github.com/xinsir6/ControlNetPlus) by xinsir6. It supports multiple conditioning inputs without increasing computation.\n\n*We design a new architecture that can support 10+ control types in condition text-to-image generation and can generate high resolution images visually comparable with midjourney. The network is based on the original ControlNet architecture, we propose two new modules to: 1 Extend the original ControlNet to support different image conditions using the same network parameter. 
2 Support multiple conditions input without increasing computation offload, which is especially important for designers who want to edit image in detail, different conditions use the same condition encoder, without adding extra computations or parameters.*\n\n## Loading\n\nBy default the [`ControlNetUnionModel`] should be loaded with [`~ModelMixin.from_pretrained`].\n\n```py\nfrom diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel\n\ncontrolnet = ControlNetUnionModel.from_pretrained(\"xinsir/controlnet-union-sdxl-1.0\")\npipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", controlnet=controlnet)\n```\n\n## ControlNetUnionModel\n\n[[autodoc]] ControlNetUnionModel\n\n"
  },
  {
    "path": "docs/source/en/api/models/cosmos_transformer3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# CosmosTransformer3DModel\n\nA Diffusion Transformer model for 3D video-like data was introduced in [Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import CosmosTransformer3DModel\n\ntransformer = CosmosTransformer3DModel.from_pretrained(\"nvidia/Cosmos-1.0-Diffusion-7B-Text2World\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## CosmosTransformer3DModel\n\n[[autodoc]] CosmosTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/dit_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DiTTransformer2DModel\n\nA Transformer model for image-like data from [DiT](https://huggingface.co/papers/2212.09748).\n\n## DiTTransformer2DModel\n\n[[autodoc]] DiTTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/easyanimate_transformer3d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# EasyAnimateTransformer3DModel\n\nA Diffusion Transformer model for 3D data from [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import EasyAnimateTransformer3DModel\n\ntransformer = EasyAnimateTransformer3DModel.from_pretrained(\"alibaba-pai/EasyAnimateV5.1-12b-zh\", subfolder=\"transformer\", torch_dtype=torch.float16).to(\"cuda\")\n```\n\n## EasyAnimateTransformer3DModel\n\n[[autodoc]] EasyAnimateTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/flux2_transformer.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Flux2Transformer2DModel\n\nA Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).\n\n## Flux2Transformer2DModel\n\n[[autodoc]] Flux2Transformer2DModel\n\n## Flux2Transformer2DModelOutput\n\n[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/flux_transformer.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# FluxTransformer2DModel\n\nA Transformer model for image-like data from [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).\n\n## FluxTransformer2DModel\n\n[[autodoc]] FluxTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/glm_image_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# GlmImageTransformer2DModel\n\nA Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel] (TODO).\n\n## GlmImageTransformer2DModel\n\n[[autodoc]] GlmImageTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/helios_transformer3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# HeliosTransformer3DModel\n\nA 14B Real-Time Autogressive Diffusion Transformer model (support T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379) by Peking University & ByteDance & etc.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import HeliosTransformer3DModel\n\n# Best Quality\ntransformer = HeliosTransformer3DModel.from_pretrained(\"BestWishYsh/Helios-Base\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n# Intermediate Weight\ntransformer = HeliosTransformer3DModel.from_pretrained(\"BestWishYsh/Helios-Mid\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n# Best Efficiency\ntransformer = HeliosTransformer3DModel.from_pretrained(\"BestWishYsh/Helios-Distilled\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## HeliosTransformer3DModel\n\n[[autodoc]] HeliosTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/hidream_image_transformer.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# HiDreamImageTransformer2DModel\n\nA Transformer model for image-like data from [HiDream-I1](https://huggingface.co/HiDream-ai).\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import HiDreamImageTransformer2DModel\n\ntransformer = HiDreamImageTransformer2DModel.from_pretrained(\"HiDream-ai/HiDream-I1-Full\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## Loading GGUF quantized checkpoints for HiDream-I1\n\nGGUF checkpoints for the `HiDreamImageTransformer2DModel` can  be loaded using `~FromOriginalModelMixin.from_single_file`\n\n```python\nimport torch\nfrom diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel\n\nckpt_path = \"https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf\"\ntransformer = HiDreamImageTransformer2DModel.from_single_file(\n    ckpt_path,\n    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),\n    torch_dtype=torch.bfloat16\n)\n```\n\n## HiDreamImageTransformer2DModel\n\n[[autodoc]] HiDreamImageTransformer2DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/hunyuan_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# HunyuanDiT2DModel\n\nA Diffusion Transformer model for 2D data from [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT).\n\n## HunyuanDiT2DModel\n\n[[autodoc]] HunyuanDiT2DModel\n\n"
  },
  {
    "path": "docs/source/en/api/models/hunyuan_video15_transformer_3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# HunyuanVideo15Transformer3DModel\n\nA Diffusion Transformer model for 3D video-like data used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5).\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import HunyuanVideo15Transformer3DModel\n\ntransformer = HunyuanVideo15Transformer3DModel.from_pretrained(\"hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v\" subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## HunyuanVideo15Transformer3DModel\n\n[[autodoc]] HunyuanVideo15Transformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/hunyuan_video_transformer_3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# HunyuanVideoTransformer3DModel\n\nA Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import HunyuanVideoTransformer3DModel\n\ntransformer = HunyuanVideoTransformer3DModel.from_pretrained(\"hunyuanvideo-community/HunyuanVideo\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## HunyuanVideoTransformer3DModel\n\n[[autodoc]] HunyuanVideoTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/hunyuanimage_transformer_2d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# HunyuanImageTransformer2DModel\n\nA Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import HunyuanImageTransformer2DModel\n\ntransformer = HunyuanImageTransformer2DModel.from_pretrained(\"hunyuanvideo-community/HunyuanImage-2.1-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## HunyuanImageTransformer2DModel\n\n[[autodoc]] HunyuanImageTransformer2DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/latte_transformer3d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n## LatteTransformer3DModel\n\nA Diffusion Transformer model for 3D data from [Latte](https://github.com/Vchitect/Latte).\n\n## LatteTransformer3DModel\n\n[[autodoc]] LatteTransformer3DModel\n"
  },
  {
    "path": "docs/source/en/api/models/longcat_image_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LongCatImageTransformer2DModel\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import LongCatImageTransformer2DModel\n\ntransformer = LongCatImageTransformer2DModel.from_pretrained(\"meituan-longcat/LongCat-Image \", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## LongCatImageTransformer2DModel\n\n[[autodoc]] LongCatImageTransformer2DModel"
  },
  {
    "path": "docs/source/en/api/models/ltx2_video_transformer3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# LTX2VideoTransformer3DModel\n\nA Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import LTX2VideoTransformer3DModel\n\ntransformer = LTX2VideoTransformer3DModel.from_pretrained(\"Lightricks/LTX-2\", subfolder=\"transformer\", torch_dtype=torch.bfloat16).to(\"cuda\")\n```\n\n## LTX2VideoTransformer3DModel\n\n[[autodoc]] LTX2VideoTransformer3DModel\n"
  },
  {
    "path": "docs/source/en/api/models/ltx_video_transformer3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# LTXVideoTransformer3DModel\n\nA Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import LTXVideoTransformer3DModel\n\ntransformer = LTXVideoTransformer3DModel.from_pretrained(\"Lightricks/LTX-Video\", subfolder=\"transformer\", torch_dtype=torch.bfloat16).to(\"cuda\")\n```\n\n## LTXVideoTransformer3DModel\n\n[[autodoc]] LTXVideoTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/lumina2_transformer2d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# Lumina2Transformer2DModel\n\nA Diffusion Transformer model for 3D video-like data was introduced in [Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import Lumina2Transformer2DModel\n\ntransformer = Lumina2Transformer2DModel.from_pretrained(\"Alpha-VLLM/Lumina-Image-2.0\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## Lumina2Transformer2DModel\n\n[[autodoc]] Lumina2Transformer2DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/lumina_nextdit2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LuminaNextDiT2DModel\n\nA Next Version of Diffusion Transformer model for 2D data from [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X).\n\n## LuminaNextDiT2DModel\n\n[[autodoc]] LuminaNextDiT2DModel\n\n"
  },
  {
    "path": "docs/source/en/api/models/mochi_transformer3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# MochiTransformer3DModel\n\nA Diffusion Transformer model for 3D video-like data was introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import MochiTransformer3DModel\n\ntransformer = MochiTransformer3DModel.from_pretrained(\"genmo/mochi-1-preview\", subfolder=\"transformer\", torch_dtype=torch.float16).to(\"cuda\")\n```\n\n## MochiTransformer3DModel\n\n[[autodoc]] MochiTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/omnigen_transformer.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# OmniGenTransformer2DModel\n\nA Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).\n\nThe abstract from the paper is:\n\n*The emergence of Large Language Models (LLMs) has unified language  generation tasks and revolutionized human-machine interaction.  However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 
3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism.  This work represents the first attempt at a general-purpose image generation model,  and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*\n\n```python\nimport torch\nfrom diffusers import OmniGenTransformer2DModel\n\ntransformer = OmniGenTransformer2DModel.from_pretrained(\"Shitao/OmniGen-v1-diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## OmniGenTransformer2DModel\n\n[[autodoc]] OmniGenTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Models\n\n🤗 Diffusers provides pretrained models for popular algorithms and modules to create custom diffusion systems. The primary function of models is to denoise an input sample as modeled by the distribution  \\\\(p_{\\theta}(x_{t-1}|x_{t})\\\\).\n\nAll models are built from the base [`ModelMixin`] class which is a [`torch.nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) providing basic functionality for saving and loading models, locally and from the Hugging Face Hub.\n\n## ModelMixin\n[[autodoc]] ModelMixin\n\n## PushToHubMixin\n\n[[autodoc]] utils.PushToHubMixin\n"
  },
  {
    "path": "docs/source/en/api/models/ovisimage_transformer2d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# OvisImageTransformer2DModel\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import OvisImageTransformer2DModel\n\ntransformer = OvisImageTransformer2DModel.from_pretrained(\"AIDC-AI/Ovis-Image-7B\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## OvisImageTransformer2DModel\n\n[[autodoc]] OvisImageTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/pixart_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# PixArtTransformer2DModel\n\nA Transformer model for image-like data from [PixArt-Alpha](https://huggingface.co/papers/2310.00426) and [PixArt-Sigma](https://huggingface.co/papers/2403.04692).\n\n## PixArtTransformer2DModel\n\n[[autodoc]] PixArtTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/prior_transformer.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# PriorTransformer\n\nThe Prior Transformer was originally introduced in [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) by Ramesh et al. It is used to predict CLIP image embeddings from CLIP text embeddings; image embeddings are predicted through a denoising diffusion process.\n\nThe abstract from the paper is:\n\n*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. 
We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.*\n\n## PriorTransformer\n\n[[autodoc]] PriorTransformer\n\n## PriorTransformerOutput\n\n[[autodoc]] models.transformers.prior_transformer.PriorTransformerOutput\n"
  },
  {
    "path": "docs/source/en/api/models/qwenimage_transformer2d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# QwenImageTransformer2DModel\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import QwenImageTransformer2DModel\n\ntransformer = QwenImageTransformer2DModel.from_pretrained(\"Qwen/QwenImage-20B\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## QwenImageTransformer2DModel\n\n[[autodoc]] QwenImageTransformer2DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/sana_transformer2d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# SanaTransformer2DModel\n\nA Diffusion Transformer model for 2D data from [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) was introduced from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.\n\nThe abstract from the paper is:\n\n*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. 
(4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*\n\nThe model can be loaded with the following code snippet.\n\n```python\nimport torch\nfrom diffusers import SanaTransformer2DModel\n\ntransformer = SanaTransformer2DModel.from_pretrained(\"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## SanaTransformer2DModel\n\n[[autodoc]] SanaTransformer2DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/sana_video_transformer3d.md",
    "content": "<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# SanaVideoTransformer3DModel\n\nA Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.\n\nThe abstract from the paper is:\n\n*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. 
This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import SanaVideoTransformer3DModel\nimport torch\n\ntransformer = SanaVideoTransformer3DModel.from_pretrained(\"Efficient-Large-Model/SANA-Video_2B_480p_diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## SanaVideoTransformer3DModel\n\n[[autodoc]] SanaVideoTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n\n"
  },
  {
    "path": "docs/source/en/api/models/sd3_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# SD3 Transformer Model\n\nThe Transformer model introduced in [Stable Diffusion 3](https://hf.co/papers/2403.03206). Its novelty lies in the MMDiT transformer block.\n\n## SD3Transformer2DModel\n\n[[autodoc]] SD3Transformer2DModel"
  },
  {
    "path": "docs/source/en/api/models/skyreels_v2_transformer_3d.md",
    "content": "<!-- Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# SkyReelsV2Transformer3DModel\n\nA Diffusion Transformer model for 3D video-like data was introduced in [SkyReels-V2](https://github.com/SkyworkAI/SkyReels-V2) by the Skywork AI.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import SkyReelsV2Transformer3DModel\n\ntransformer = SkyReelsV2Transformer3DModel.from_pretrained(\"Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## SkyReelsV2Transformer3DModel\n\n[[autodoc]] SkyReelsV2Transformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/stable_audio_transformer.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# StableAudioDiTModel\n\nA Transformer model for audio waveforms from [Stable Audio Open](https://huggingface.co/papers/2407.14358).\n\n## StableAudioDiTModel\n\n[[autodoc]] StableAudioDiTModel\n"
  },
  {
    "path": "docs/source/en/api/models/stable_cascade_unet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# StableCascadeUNet\n\nA UNet model from the [Stable Cascade pipeline](../pipelines/stable_cascade.md).\n\n## StableCascadeUNet\n\n[[autodoc]] models.unets.unet_stable_cascade.StableCascadeUNet\n"
  },
  {
    "path": "docs/source/en/api/models/transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Transformer2DModel\n\nA Transformer model for image-like data from [CompVis](https://huggingface.co/CompVis) that is based on the [Vision Transformer](https://huggingface.co/papers/2010.11929) introduced by Dosovitskiy et al. The [`Transformer2DModel`] accepts discrete (classes of vector embeddings) or continuous (actual embeddings) inputs.\n\nWhen the input is **continuous**:\n\n1. Project the input and reshape it to `(batch_size, sequence_length, feature_dimension)`.\n2. Apply the Transformer blocks in the standard way.\n3. Reshape to image.\n\nWhen the input is **discrete**:\n\n> [!TIP]\n> It is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image don't contain a prediction for the masked pixel because the unnoised image cannot be masked.\n\n1. Convert input (classes of latent pixels) to embeddings and apply positional embeddings.\n2. Apply the Transformer blocks in the standard way.\n3. Predict classes of unnoised image.\n\n## Transformer2DModel\n\n[[autodoc]] Transformer2DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/transformer_bria_fibo.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# BriaFiboTransformer2DModel\n\nA modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)\n\n## BriaFiboTransformer2DModel\n\n[[autodoc]] BriaFiboTransformer2DModel\n"
  },
  {
    "path": "docs/source/en/api/models/transformer_temporal.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# TransformerTemporalModel\n\nA Transformer model for video-like data.\n\n## TransformerTemporalModel\n\n[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModel\n\n## TransformerTemporalModelOutput\n\n[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/unet-motion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# UNetMotionModel\n\nThe [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet model.\n\nThe abstract from the paper is:\n\n*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. 
Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*\n\n## UNetMotionModel\n[[autodoc]] UNetMotionModel\n\n## UNet3DConditionOutput\n[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput\n"
  },
  {
    "path": "docs/source/en/api/models/unet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# UNet1DModel\n\nThe [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 1D UNet model.\n\nThe abstract from the paper is:\n\n*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. 
Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*\n\n## UNet1DModel\n[[autodoc]] UNet1DModel\n\n## UNet1DOutput\n[[autodoc]] models.unets.unet_1d.UNet1DOutput\n"
  },
  {
    "path": "docs/source/en/api/models/unet2d-cond.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# UNet2DConditionModel\n\nThe [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet conditional model.\n\nThe abstract from the paper is:\n\n*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. 
Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*\n\n## UNet2DConditionModel\n[[autodoc]] UNet2DConditionModel\n\n## UNet2DConditionOutput\n[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput\n"
  },
  {
    "path": "docs/source/en/api/models/unet2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# UNet2DModel\n\nThe [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet model.\n\nThe abstract from the paper is:\n\n*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. 
Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*\n\n## UNet2DModel\n[[autodoc]] UNet2DModel\n\n## UNet2DOutput\n[[autodoc]] models.unets.unet_2d.UNet2DOutput\n"
  },
  {
    "path": "docs/source/en/api/models/unet3d-cond.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# UNet3DConditionModel\n\nThe [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 3D UNet conditional model.\n\nThe abstract from the paper is:\n\n*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. 
Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*\n\n## UNet3DConditionModel\n[[autodoc]] UNet3DConditionModel\n\n## UNet3DConditionOutput\n[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput\n"
  },
  {
    "path": "docs/source/en/api/models/uvit2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# UVit2DModel\n\nThe [U-ViT](https://hf.co/papers/2301.11093) model is a vision transformer (ViT) based UNet. This model incorporates elements from ViT (considers all inputs such as time, conditions and noisy image patches as tokens) and a UNet (long skip connections between the shallow and deep layers). The skip connection is important for predicting pixel-level features. An additional 3x3 convolutional block is applied prior to the final output to improve image quality.\n\nThe abstract from the paper is:\n\n*Currently, applying diffusion models in pixel space of high resolution images is difficult. Instead, existing approaches focus on diffusion in lower dimensional spaces (latent diffusion), or have multiple super-resolution levels of generation referred to as cascades. The downside is that these approaches add additional complexity to the diffusion framework. This paper aims to improve denoising diffusion for high resolution images while keeping the model as simple as possible. The paper is centered around the research question: How can one train a standard denoising diffusion models on high resolution images, and still obtain performance comparable to these alternate approaches? 
The four main findings are: 1) the noise schedule should be adjusted for high resolution images, 2) It is sufficient to scale only a particular part of the architecture, 3) dropout should be added at specific locations in the architecture, and 4) downsampling is an effective strategy to avoid high resolution feature maps. Combining these simple yet effective techniques, we achieve state-of-the-art on image generation among diffusion models without sampling modifiers on ImageNet.*\n\n## UVit2DModel\n\n[[autodoc]] UVit2DModel\n\n## UVit2DConvEmbed\n\n[[autodoc]] models.unets.uvit_2d.UVit2DConvEmbed\n\n## UVitBlock\n\n[[autodoc]] models.unets.uvit_2d.UVitBlock\n\n## ConvNextBlock\n\n[[autodoc]] models.unets.uvit_2d.ConvNextBlock\n\n## ConvMlmLayer\n\n[[autodoc]] models.unets.uvit_2d.ConvMlmLayer\n"
  },
  {
    "path": "docs/source/en/api/models/vq.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# VQModel\n\nThe VQ-VAE model was introduced in [Neural Discrete Representation Learning](https://huggingface.co/papers/1711.00937) by Aaron van den Oord, Oriol Vinyals and Koray Kavukcuoglu. The model is used in 🤗 Diffusers to decode latent representations into images. Unlike [`AutoencoderKL`], the [`VQModel`] works in a quantized latent space.\n\nThe abstract from the paper is:\n\n*Learning useful representations without supervision remains a key challenge in machine learning. In this paper, we propose a simple yet powerful generative model that learns such discrete representations. Our model, the Vector Quantised-Variational AutoEncoder (VQ-VAE), differs from VAEs in two key ways: the encoder network outputs discrete, rather than continuous, codes; and the prior is learnt rather than static. In order to learn a discrete latent representation, we incorporate ideas from vector quantisation (VQ). Using the VQ method allows the model to circumvent issues of \"posterior collapse\" -- where the latents are ignored when they are paired with a powerful autoregressive decoder -- typically observed in the VAE framework. 
Pairing these representations with an autoregressive prior, the model can generate high quality images, videos, and speech as well as doing high quality speaker conversion and unsupervised learning of phonemes, providing further evidence of the utility of the learnt representations.*\n\n## VQModel\n\n[[autodoc]] VQModel\n\n## VQEncoderOutput\n\n[[autodoc]] models.autoencoders.vq_model.VQEncoderOutput\n"
  },
  {
    "path": "docs/source/en/api/models/wan_animate_transformer_3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# WanAnimateTransformer3DModel\n\nA Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import WanAnimateTransformer3DModel\n\ntransformer = WanAnimateTransformer3DModel.from_pretrained(\"Wan-AI/Wan2.2-Animate-14B-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## WanAnimateTransformer3DModel\n\n[[autodoc]] WanAnimateTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/wan_transformer_3d.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# WanTransformer3DModel\n\nA Diffusion Transformer model for 3D video-like data was introduced in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.\n\nThe model can be loaded with the following code snippet.\n\n```python\nfrom diffusers import WanTransformer3DModel\n\ntransformer = WanTransformer3DModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n```\n\n## WanTransformer3DModel\n\n[[autodoc]] WanTransformer3DModel\n\n## Transformer2DModelOutput\n\n[[autodoc]] models.modeling_outputs.Transformer2DModelOutput\n"
  },
  {
    "path": "docs/source/en/api/models/z_image_transformer2d.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ZImageTransformer2DModel\n\nA Transformer model for image-like data from [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo).\n\n## ZImageTransformer2DModel\n\n[[autodoc]] ZImageTransformer2DModel"
  },
  {
    "path": "docs/source/en/api/modular_diffusers/guiders.md",
    "content": "# Guiders\n\nGuiders are components in Modular Diffusers that control how the diffusion process is guided during generation. They implement various guidance techniques to improve generation quality and control.\n\n## BaseGuidance\n\n[[autodoc]] diffusers.guiders.guider_utils.BaseGuidance\n\n## ClassifierFreeGuidance\n\n[[autodoc]] diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance\n\n## ClassifierFreeZeroStarGuidance\n\n[[autodoc]] diffusers.guiders.classifier_free_zero_star_guidance.ClassifierFreeZeroStarGuidance\n\n## SkipLayerGuidance\n\n[[autodoc]] diffusers.guiders.skip_layer_guidance.SkipLayerGuidance\n\n## SmoothedEnergyGuidance\n\n[[autodoc]] diffusers.guiders.smoothed_energy_guidance.SmoothedEnergyGuidance\n\n## PerturbedAttentionGuidance\n\n[[autodoc]] diffusers.guiders.perturbed_attention_guidance.PerturbedAttentionGuidance\n\n## AdaptiveProjectedGuidance\n\n[[autodoc]] diffusers.guiders.adaptive_projected_guidance.AdaptiveProjectedGuidance\n\n## AutoGuidance\n\n[[autodoc]] diffusers.guiders.auto_guidance.AutoGuidance\n\n## TangentialClassifierFreeGuidance\n\n[[autodoc]] diffusers.guiders.tangential_classifier_free_guidance.TangentialClassifierFreeGuidance\n"
  },
  {
    "path": "docs/source/en/api/modular_diffusers/pipeline.md",
    "content": "# Pipeline\n\n## ModularPipeline\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ModularPipeline\n"
  },
  {
    "path": "docs/source/en/api/modular_diffusers/pipeline_blocks.md",
    "content": "# Pipeline blocks\n\n## ModularPipelineBlocks\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ModularPipelineBlocks\n\n## SequentialPipelineBlocks\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks\n\n## LoopSequentialPipelineBlocks\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.LoopSequentialPipelineBlocks\n\n## AutoPipelineBlocks\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks\n\n## ConditionalPipelineBlocks\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConditionalPipelineBlocks"
  },
  {
    "path": "docs/source/en/api/modular_diffusers/pipeline_components.md",
    "content": "# Components and configs\n\n## ComponentSpec\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ComponentSpec\n\n## ConfigSpec\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConfigSpec\n\n## ComponentsManager\n\n[[autodoc]] diffusers.modular_pipelines.components_manager.ComponentsManager\n\n## InsertableDict\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline_utils.InsertableDict"
  },
  {
    "path": "docs/source/en/api/modular_diffusers/pipeline_states.md",
    "content": "# Pipeline states\n\n## PipelineState\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.PipelineState\n\n## BlockState\n\n[[autodoc]] diffusers.modular_pipelines.modular_pipeline.BlockState "
  },
  {
    "path": "docs/source/en/api/normalization.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Normalization layers\n\nCustomized normalization layers for supporting various models in 🤗 Diffusers.\n\n## AdaLayerNorm\n\n[[autodoc]] models.normalization.AdaLayerNorm\n\n## AdaLayerNormZero\n\n[[autodoc]] models.normalization.AdaLayerNormZero\n\n## AdaLayerNormSingle\n\n[[autodoc]] models.normalization.AdaLayerNormSingle\n\n## AdaGroupNorm\n\n[[autodoc]] models.normalization.AdaGroupNorm\n\n## AdaLayerNormContinuous\n\n[[autodoc]] models.normalization.AdaLayerNormContinuous\n\n## RMSNorm\n\n[[autodoc]] models.normalization.RMSNorm\n\n## GlobalResponseNorm\n\n[[autodoc]] models.normalization.GlobalResponseNorm\n\n\n## LuminaLayerNormContinuous\n[[autodoc]] models.normalization.LuminaLayerNormContinuous\n\n## SD35AdaLayerNormZeroX\n[[autodoc]] models.normalization.SD35AdaLayerNormZeroX\n\n## AdaLayerNormZeroSingle\n[[autodoc]] models.normalization.AdaLayerNormZeroSingle\n\n## LuminaRMSNormZero\n[[autodoc]] models.normalization.LuminaRMSNormZero\n\n## LpNorm\n[[autodoc]] models.normalization.LpNorm\n\n## CogView3PlusAdaLayerNormZeroTextImage\n[[autodoc]] models.normalization.CogView3PlusAdaLayerNormZeroTextImage\n\n## CogVideoXLayerNormZero\n[[autodoc]] models.normalization.CogVideoXLayerNormZero\n\n## MochiRMSNormZero\n[[autodoc]] models.transformers.transformer_mochi.MochiRMSNormZero\n\n## MochiRMSNorm\n[[autodoc]] models.normalization.MochiRMSNorm"
  },
  {
    "path": "docs/source/en/api/outputs.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Outputs\n\nAll model outputs are subclasses of [`~utils.BaseOutput`], data structures containing all the information returned by the model. The outputs can also be used as tuples or dictionaries.\n\nFor example:\n\n```python\nfrom diffusers import DDIMPipeline\n\npipeline = DDIMPipeline.from_pretrained(\"google/ddpm-cifar10-32\")\noutputs = pipeline()\n```\n\nThe `outputs` object is a [`~pipelines.ImagePipelineOutput`] which means it has an image attribute.\n\nYou can access each attribute as you normally would or with a keyword lookup, and if that attribute is not returned by the model, you will get `None`:\n\n```python\noutputs.images\noutputs[\"images\"]\n```\n\nWhen considering the `outputs` object as a tuple, it only considers the attributes that don't have `None` values.\nFor instance, retrieving an image by indexing into it returns the tuple `(outputs.images)`:\n\n```python\noutputs[:1]\n```\n\n> [!TIP]\n> To check a specific pipeline or model output, refer to its corresponding API documentation.\n\n## BaseOutput\n\n[[autodoc]] utils.BaseOutput\n    - to_tuple\n\n## ImagePipelineOutput\n\n[[autodoc]] pipelines.ImagePipelineOutput\n\n## AudioPipelineOutput\n\n[[autodoc]] pipelines.AudioPipelineOutput\n\n## ImageTextPipelineOutput\n\n[[autodoc]] ImageTextPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/parallel.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# Parallelism\n\nParallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times. Refer to the [Distributed inferece](../training/distributed_inference) guide to learn more.\n\n## ParallelConfig\n\n[[autodoc]] ParallelConfig\n\n## ContextParallelConfig\n\n[[autodoc]] ContextParallelConfig\n\n[[autodoc]] hooks.apply_context_parallel\n"
  },
  {
    "path": "docs/source/en/api/pipelines/allegro.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# Allegro\n\n[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, Huan Yang.\n\nThe abstract from the paper is:\n\n*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. 
Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`AllegroPipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AllegroTransformer3DModel, AllegroPipeline\nfrom diffusers.utils import export_to_video\nfrom transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = T5EncoderModel.from_pretrained(\n    \"rhymes-ai/Allegro\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = AllegroTransformer3DModel.from_pretrained(\n    \"rhymes-ai/Allegro\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = AllegroPipeline.from_pretrained(\n    \"rhymes-ai/Allegro\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    
device_map=\"balanced\",\n)\n\nprompt = (\n    \"A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, \"\n    \"the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this \"\n    \"location might be a popular spot for docking fishing boats.\"\n)\nvideo = pipeline(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0]\nexport_to_video(video, \"harbor.mp4\", fps=15)\n```\n\n## AllegroPipeline\n\n[[autodoc]] AllegroPipeline\n  - all\n  - __call__\n\n## AllegroPipelineOutput\n\n[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/amused.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# aMUSEd\n\naMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen.\n\nAmused is a lightweight text to image model based off of the [MUSE](https://huggingface.co/papers/2301.00704) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.\n\nAmused is a vqvae token based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with muse, it uses the smaller text encoder CLIP-L/14 instead of t5-xxl. Due to its small parameter count and few forward pass generation process, amused can generate many images quickly. This benefit is seen particularly at larger batch sizes.\n\nThe abstract from the paper is:\n\n*We present aMUSEd, an open-source, lightweight masked image model (MIM) for text-to-image generation based on MUSE. With 10 percent of MUSE's parameters, aMUSEd is focused on fast image generation. 
We believe MIM is under-explored compared to latent diffusion, the prevailing approach for text-to-image generation. Compared to latent diffusion, MIM requires fewer inference steps and is more interpretable. Additionally, MIM can be fine-tuned to learn additional styles with only a single image. We hope to encourage further exploration of MIM by demonstrating its effectiveness on large-scale text-to-image generation and releasing reproducible training code. We also release checkpoints for two models which directly produce images at 256x256 and 512x512 resolutions.*\n\n| Model | Params |\n|-------|--------|\n| [amused-256](https://huggingface.co/amused/amused-256) | 603M |\n| [amused-512](https://huggingface.co/amused/amused-512) | 608M |\n\n## AmusedPipeline\n\n[[autodoc]] AmusedPipeline\n\t- __call__\n\t- all\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\n[[autodoc]] AmusedImg2ImgPipeline\n\t- __call__\n\t- all\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\n[[autodoc]] AmusedInpaintPipeline\n\t- __call__\n\t- all\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention"
  },
  {
    "path": "docs/source/en/api/pipelines/animatediff.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Text-to-Video Generation with AnimateDiff\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n## Overview\n\n[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://huggingface.co/papers/2307.04725) by Yuwei Guo, Ceyuan Yang, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai.\n\nThe abstract of the paper is the following:\n\n*With the advance of text-to-image models (e.g., Stable Diffusion) and corresponding personalization techniques such as DreamBooth and LoRA, everyone can manifest their imagination into high-quality images at an affordable cost. Subsequently, there is a great demand for image animation techniques to further combine generated static images with motion dynamics. In this report, we propose a practical framework to animate most of the existing personalized text-to-image models once and for all, saving efforts in model-specific tuning. At the core of the proposed framework is to insert a newly initialized motion modeling module into the frozen text-to-image model and train it on video clips to distill reasonable motion priors. 
Once trained, by simply injecting this motion modeling module, all personalized versions derived from the same base T2I readily become text-driven models that produce diverse and personalized animated images. We conduct our evaluation on several public representative personalized text-to-image models across anime pictures and realistic photographs, and demonstrate that our proposed framework helps these models generate temporally smooth animation clips while preserving the domain and diversity of their outputs. Code and pre-trained weights will be publicly available at [this https URL](https://animatediff.github.io/).*\n\n## Available Pipelines\n\n| Pipeline | Tasks |\n|---|---|\n| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |\n| [AnimateDiffControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py) | *Controlled Video-to-Video Generation with AnimateDiff using ControlNet* |\n| [AnimateDiffSparseControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py) | *Controlled Video-to-Video Generation with AnimateDiff using SparseCtrl* |\n| [AnimateDiffSDXLPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py) | *Text-to-Video Generation with AnimateDiff using Stable Diffusion XL* |\n| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |\n| [AnimateDiffVideoToVideoControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py) | *Video-to-Video Generation with AnimateDiff using 
ControlNet* |\n\n## Available checkpoints\n\nMotion Adapter checkpoints can be found under [guoyww](https://huggingface.co/guoyww/). These checkpoints are meant to work with any model based on Stable Diffusion 1.4/1.5.\n\n## Usage example\n\n### AnimateDiffPipeline\n\nAnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet.\n\nThe following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.\n\n```python\nimport torch\nfrom diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter\nfrom diffusers.utils import export_to_gif\n\n# Load the motion adapter\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5-2\", torch_dtype=torch.float16)\n# load SD 1.5 based finetuned model\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\npipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)\nscheduler = DDIMScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\",\n    clip_sample=False,\n    timestep_spacing=\"linspace\",\n    beta_schedule=\"linear\",\n    steps_offset=1,\n)\npipe.scheduler = scheduler\n\n# enable memory savings\npipe.enable_vae_slicing()\npipe.enable_model_cpu_offload()\n\noutput = pipe(\n    prompt=(\n        \"masterpiece, bestquality, highlydetailed, ultradetailed, sunset, \"\n        \"orange sky, warm lighting, fishing boats, ocean waves seagulls, \"\n        \"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, \"\n        \"golden hour, coastal landscape, seaside scenery\"\n    ),\n    negative_prompt=\"bad quality, worse quality\",\n    num_frames=16,\n    guidance_scale=7.5,\n    num_inference_steps=25,\n    
generator=torch.Generator(\"cpu\").manual_seed(42),\n)\nframes = output.frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\nHere are some sample outputs:\n\n<table>\n    <tr>\n        <td><center>\n        masterpiece, bestquality, sunset.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-realistic-doc.gif\"\n            alt=\"masterpiece, bestquality, sunset\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\n> [!TIP]\n> AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the AnimateDiff checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.\n\n### AnimateDiffControlNetPipeline\n\nAnimateDiff can also be used with ControlNets. ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide depth maps, the ControlNet model generates a video that'll preserve the spatial information from the depth maps. 
It is a more flexible and accurate way to control the video generation process.\n\n```python\nimport torch\nfrom diffusers import AnimateDiffControlNetPipeline, AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler\nfrom diffusers.utils import export_to_gif, load_video\n\n# Additionally, you will need to preprocess videos before they can be used with the ControlNet\n# HF maintains just the right package for it: `pip install controlnet_aux`\nfrom controlnet_aux.processor import ZoeDetector\n\n# Download controlnets from https://huggingface.co/lllyasviel/ControlNet-v1-1 to use .from_single_file\n# Download Diffusers-format controlnets, such as https://huggingface.co/lllyasviel/sd-controlnet-depth, to use .from_pretrained()\ncontrolnet = ControlNetModel.from_single_file(\"control_v11f1p_sd15_depth.pth\", torch_dtype=torch.float16)\n\n# We use AnimateLCM for this example but one can use the original motion adapters as well (for example, https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-3)\nmotion_adapter = MotionAdapter.from_pretrained(\"wangfuyun/AnimateLCM\")\n\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16)\npipe: AnimateDiffControlNetPipeline = AnimateDiffControlNetPipeline.from_pretrained(\n    \"SG161222/Realistic_Vision_V5.1_noVAE\",\n    motion_adapter=motion_adapter,\n    controlnet=controlnet,\n    vae=vae,\n).to(device=\"cuda\", dtype=torch.float16)\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule=\"linear\")\npipe.load_lora_weights(\"wangfuyun/AnimateLCM\", weight_name=\"AnimateLCM_sd15_t2v_lora.safetensors\", adapter_name=\"lcm-lora\")\npipe.set_adapters([\"lcm-lora\"], [0.8])\n\ndepth_detector = ZoeDetector.from_pretrained(\"lllyasviel/Annotators\").to(\"cuda\")\nvideo = load_video(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif\")\nconditioning_frames = []\n\nwith 
pipe.progress_bar(total=len(video)) as progress_bar:\n    for frame in video:\n        conditioning_frames.append(depth_detector(frame))\n        progress_bar.update()\n\nprompt = \"a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality\"\nnegative_prompt = \"bad quality, worst quality\"\n\nvideo = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=len(video),\n    num_inference_steps=10,\n    guidance_scale=2.0,\n    conditioning_frames=conditioning_frames,\n    generator=torch.Generator().manual_seed(42),\n).frames[0]\n\nexport_to_gif(video, \"animatediff_controlnet.gif\", fps=8)\n```\n\nHere are some sample outputs:\n\n<table align=\"center\">\n    <tr>\n      <th align=\"center\">Source Video</th>\n      <th align=\"center\">Output Video</th>\n    </tr>\n    <tr>\n        <td align=\"center\">\n          raccoon playing a guitar\n          <br />\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif\" alt=\"racoon playing a guitar\" />\n        </td>\n        <td align=\"center\">\n          a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality\n          <br/>\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-controlnet-output.gif\" alt=\"a panda, playing a guitar, sitting in a pink boat, in the ocean, mountains in background, realistic, high quality\" />\n        </td>\n    </tr>\n</table>\n\n### AnimateDiffSparseControlNetPipeline\n\n[SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://huggingface.co/papers/2311.16933) for achieving controlled generation in text-to-video diffusion models by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai.\n\nThe abstract from the paper is:\n\n*The development of 
text-to-video (T2V), i.e., generating videos with a given text prompt, has been significantly advanced in recent years. However, relying solely on text prompts often results in ambiguous frame composition due to spatial uncertainty. The research community thus leverages the dense structure signals, e.g., per-frame depth/edge sequences, to enhance controllability, whose collection accordingly increases the burden of inference. In this work, we present SparseCtrl to enable flexible structure control with temporally sparse signals, requiring only one or a few inputs, as shown in Figure 1. It incorporates an additional condition encoder to process these sparse signals while leaving the pre-trained T2V model untouched. The proposed approach is compatible with various modalities, including sketches, depth maps, and RGB images, providing more practical control for video generation and promoting applications such as storyboarding, depth rendering, keyframe animation, and interpolation. Extensive experiments demonstrate the generalization of SparseCtrl on both original and personalized T2V generators. 
Codes and models will be publicly available at [this https URL](https://guoyww.github.io/projects/SparseCtrl).*\n\nSparseCtrl introduces the following checkpoints for controlled text-to-video generation:\n\n- [SparseCtrl Scribble](https://huggingface.co/guoyww/animatediff-sparsectrl-scribble)\n- [SparseCtrl RGB](https://huggingface.co/guoyww/animatediff-sparsectrl-rgb)\n\n#### Using SparseCtrl Scribble\n\n```python\nimport torch\n\nfrom diffusers import AnimateDiffSparseControlNetPipeline\nfrom diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel\nfrom diffusers.schedulers import DPMSolverMultistepScheduler\nfrom diffusers.utils import export_to_gif, load_image\n\n\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\nmotion_adapter_id = \"guoyww/animatediff-motion-adapter-v1-5-3\"\ncontrolnet_id = \"guoyww/animatediff-sparsectrl-scribble\"\nlora_adapter_id = \"guoyww/animatediff-motion-lora-v1-5-3\"\nvae_id = \"stabilityai/sd-vae-ft-mse\"\ndevice = \"cuda\"\n\nmotion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)\ncontrolnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)\nvae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)\nscheduler = DPMSolverMultistepScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\",\n    beta_schedule=\"linear\",\n    algorithm_type=\"dpmsolver++\",\n    use_karras_sigmas=True,\n)\npipe = AnimateDiffSparseControlNetPipeline.from_pretrained(\n    model_id,\n    motion_adapter=motion_adapter,\n    controlnet=controlnet,\n    vae=vae,\n    scheduler=scheduler,\n    torch_dtype=torch.float16,\n).to(device)\npipe.load_lora_weights(lora_adapter_id, adapter_name=\"motion_lora\")\npipe.fuse_lora(lora_scale=1.0)\n\nprompt = \"an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality\"\nnegative_prompt = \"low quality, worst quality, 
letterboxed\"\n\nimage_files = [\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png\",\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png\",\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png\"\n]\ncondition_frame_indices = [0, 8, 15]\nconditioning_frames = [load_image(img_file) for img_file in image_files]\n\nvideo = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_inference_steps=25,\n    conditioning_frames=conditioning_frames,\n    controlnet_conditioning_scale=1.0,\n    controlnet_frame_indices=condition_frame_indices,\n    generator=torch.Generator().manual_seed(1337),\n).frames[0]\nexport_to_gif(video, \"output.gif\")\n```\n\nHere are some sample outputs:\n\n<table align=\"center\">\n    <tr>\n        <center>\n          <b>an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality</b>\n        </center>\n    </tr>\n    <tr>\n        <td>\n          <center>\n            <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-1.png\" alt=\"scribble-1\" />\n          </center>\n        </td>\n        <td>\n          <center>\n            <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-2.png\" alt=\"scribble-2\" />\n          </center>\n        </td>\n        <td>\n          <center>\n            <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-scribble-3.png\" alt=\"scribble-3\" />\n          </center>\n        </td>\n    </tr>\n    <tr>\n        <td colspan=3>\n          <center>\n            <img 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-sparsectrl-scribble-results.gif\" alt=\"an aerial view of a cyberpunk city, night time, neon lights, masterpiece, high quality\" />\n          </center>\n        </td>\n    </tr>\n</table>\n\n#### Using SparseCtrl RGB\n\n```python\nimport torch\n\nfrom diffusers import AnimateDiffSparseControlNetPipeline\nfrom diffusers.models import AutoencoderKL, MotionAdapter, SparseControlNetModel\nfrom diffusers.schedulers import DPMSolverMultistepScheduler\nfrom diffusers.utils import export_to_gif, load_image\n\n\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\nmotion_adapter_id = \"guoyww/animatediff-motion-adapter-v1-5-3\"\ncontrolnet_id = \"guoyww/animatediff-sparsectrl-rgb\"\nlora_adapter_id = \"guoyww/animatediff-motion-lora-v1-5-3\"\nvae_id = \"stabilityai/sd-vae-ft-mse\"\ndevice = \"cuda\"\n\nmotion_adapter = MotionAdapter.from_pretrained(motion_adapter_id, torch_dtype=torch.float16).to(device)\ncontrolnet = SparseControlNetModel.from_pretrained(controlnet_id, torch_dtype=torch.float16).to(device)\nvae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16).to(device)\nscheduler = DPMSolverMultistepScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\",\n    beta_schedule=\"linear\",\n    algorithm_type=\"dpmsolver++\",\n    use_karras_sigmas=True,\n)\npipe = AnimateDiffSparseControlNetPipeline.from_pretrained(\n    model_id,\n    motion_adapter=motion_adapter,\n    controlnet=controlnet,\n    vae=vae,\n    scheduler=scheduler,\n    torch_dtype=torch.float16,\n).to(device)\npipe.load_lora_weights(lora_adapter_id, adapter_name=\"motion_lora\")\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-firework.png\")\n\nvideo = pipe(\n    prompt=\"closeup face photo of man in black clothes, night city street, bokeh, fireworks in background\",\n    negative_prompt=\"low 
quality, worst quality\",\n    num_inference_steps=25,\n    conditioning_frames=image,\n    controlnet_frame_indices=[0],\n    controlnet_conditioning_scale=1.0,\n    generator=torch.Generator().manual_seed(42),\n).frames[0]\nexport_to_gif(video, \"output.gif\")\n```\n\nHere are some sample outputs:\n\n<table align=\"center\">\n    <tr>\n        <center>\n          <b>closeup face photo of man in black clothes, night city street, bokeh, fireworks in background</b>\n        </center>\n    </tr>\n    <tr>\n        <td>\n          <center>\n            <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-firework.png\" alt=\"closeup face photo of man in black clothes, night city street, bokeh, fireworks in background\" />\n          </center>\n        </td>\n        <td>\n          <center>\n            <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-sparsectrl-rgb-result.gif\" alt=\"closeup face photo of man in black clothes, night city street, bokeh, fireworks in background\" />\n          </center>\n        </td>\n    </tr>\n</table>\n\n### AnimateDiffSDXLPipeline\n\nAnimateDiff can also be used with SDXL models. 
This is currently an experimental feature as only a beta release of the motion adapter checkpoint is available.\n\n```python\nimport torch\nfrom diffusers.models import MotionAdapter\nfrom diffusers import AnimateDiffSDXLPipeline, DDIMScheduler\nfrom diffusers.utils import export_to_gif\n\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-sdxl-beta\", torch_dtype=torch.float16)\n\nmodel_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\nscheduler = DDIMScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\",\n    clip_sample=False,\n    timestep_spacing=\"linspace\",\n    beta_schedule=\"linear\",\n    steps_offset=1,\n)\npipe = AnimateDiffSDXLPipeline.from_pretrained(\n    model_id,\n    motion_adapter=adapter,\n    scheduler=scheduler,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n).to(\"cuda\")\n\n# enable memory savings\npipe.enable_vae_slicing()\npipe.enable_vae_tiling()\n\noutput = pipe(\n    prompt=\"a panda surfing in the ocean, realistic, high quality\",\n    negative_prompt=\"low quality, worst quality\",\n    num_inference_steps=20,\n    guidance_scale=8,\n    width=1024,\n    height=1024,\n    num_frames=16,\n)\n\nframes = output.frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\n### AnimateDiffVideoToVideoPipeline\n\nAnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities.\n\n```python\nimport imageio\nimport requests\nimport torch\nfrom diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter\nfrom diffusers.utils import export_to_gif\nfrom io import BytesIO\nfrom PIL import Image\n\n# Load the motion adapter\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5-2\", torch_dtype=torch.float16)\n# load SD 1.5 based finetuned model\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\npipe = 
AnimateDiffVideoToVideoPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)\nscheduler = DDIMScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\",\n    clip_sample=False,\n    timestep_spacing=\"linspace\",\n    beta_schedule=\"linear\",\n    steps_offset=1,\n)\npipe.scheduler = scheduler\n\n# enable memory savings\npipe.enable_vae_slicing()\npipe.enable_model_cpu_offload()\n\n# helper function to load videos\ndef load_video(file_path: str):\n    images = []\n\n    if file_path.startswith(('http://', 'https://')):\n        # If the file_path is a URL\n        response = requests.get(file_path)\n        response.raise_for_status()\n        content = BytesIO(response.content)\n        vid = imageio.get_reader(content)\n    else:\n        # Assuming it's a local file path\n        vid = imageio.get_reader(file_path)\n\n    for frame in vid:\n        pil_image = Image.fromarray(frame)\n        images.append(pil_image)\n\n    return images\n\nvideo = load_video(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif\")\n\noutput = pipe(\n    video=video,\n    prompt=\"panda playing a guitar, on a boat, in the ocean, high quality\",\n    negative_prompt=\"bad quality, worse quality\",\n    guidance_scale=7.5,\n    num_inference_steps=25,\n    strength=0.5,\n    generator=torch.Generator(\"cpu\").manual_seed(42),\n)\nframes = output.frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\nHere are some sample outputs:\n\n<table>\n    <tr>\n      <th align=center>Source Video</th>\n      <th align=center>Output Video</th>\n    </tr>\n    <tr>\n        <td align=center>\n          raccoon playing a guitar\n          <br />\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif\"\n              alt=\"raccoon playing a guitar\"\n              style=\"width: 300px;\" />\n  
      </td>\n        <td align=center>\n          panda playing a guitar\n          <br/>\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-output-1.gif\"\n              alt=\"panda playing a guitar\"\n              style=\"width: 300px;\" />\n        </td>\n    </tr>\n    <tr>\n        <td align=center>\n          closeup of margot robbie, fireworks in the background, high quality\n          <br />\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-2.gif\"\n              alt=\"closeup of margot robbie, fireworks in the background, high quality\"\n              style=\"width: 300px;\" />\n        </td>\n        <td align=center>\n          closeup of tony stark, robert downey jr, fireworks\n          <br/>\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-output-2.gif\"\n              alt=\"closeup of tony stark, robert downey jr, fireworks\"\n              style=\"width: 300px;\" />\n        </td>\n    </tr>\n</table>\n\n\n\n### AnimateDiffVideoToVideoControlNetPipeline\n\nAnimateDiff can be used together with ControlNets to enhance video-to-video generation by allowing for precise control over the output. ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala, and allows you to condition Stable Diffusion with an additional control image to ensure that the spatial information is preserved throughout the video. 
\n\nThis pipeline allows you to condition your generation both on the original video and on a sequence of control images.\n\n```python\nimport torch\nfrom PIL import Image\nfrom tqdm.auto import tqdm\n\nfrom controlnet_aux.processor import OpenposeDetector\nfrom diffusers import AnimateDiffVideoToVideoControlNetPipeline\nfrom diffusers.utils import export_to_gif, load_video\nfrom diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler\n\n# Load the ControlNet\ncontrolnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-openpose\", torch_dtype=torch.float16)\n# Load the motion adapter\nmotion_adapter = MotionAdapter.from_pretrained(\"wangfuyun/AnimateLCM\")\n# Load SD 1.5 based finetuned model\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16)\npipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(\n    \"SG161222/Realistic_Vision_V5.1_noVAE\",\n    motion_adapter=motion_adapter,\n    controlnet=controlnet,\n    vae=vae,\n).to(device=\"cuda\", dtype=torch.float16)\n\n# Enable LCM to speed up inference\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule=\"linear\")\npipe.load_lora_weights(\"wangfuyun/AnimateLCM\", weight_name=\"AnimateLCM_sd15_t2v_lora.safetensors\", adapter_name=\"lcm-lora\")\npipe.set_adapters([\"lcm-lora\"], [0.8])\n\nvideo = load_video(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif\")\nvideo = [frame.convert(\"RGB\") for frame in video]\n\nprompt = \"astronaut in space, dancing\"\nnegative_prompt = \"bad quality, worst quality, jpeg artifacts, ugly\"\n\n# Create controlnet preprocessor\nopen_pose = OpenposeDetector.from_pretrained(\"lllyasviel/Annotators\").to(\"cuda\")\n\n# Preprocess controlnet images\nconditioning_frames = []\nfor frame in tqdm(video):\n    conditioning_frames.append(open_pose(frame))\n\nstrength = 0.8\nwith torch.inference_mode():\n    video = pipe(\n        
video=video,\n        prompt=prompt,\n        negative_prompt=negative_prompt,\n        num_inference_steps=10,\n        guidance_scale=2.0,\n        controlnet_conditioning_scale=0.75,\n        conditioning_frames=conditioning_frames,\n        strength=strength,\n        generator=torch.Generator().manual_seed(42),\n    ).frames[0]\n\nvideo = [frame.resize(conditioning_frames[0].size) for frame in video]\nexport_to_gif(video, \"animatediff_vid2vid_controlnet.gif\", fps=8)\n```\n\nHere are some sample outputs:\n\n<table align=\"center\">\n    <tr>\n      <th align=\"center\">Source Video</th>\n      <th align=\"center\">Output Video</th>\n    </tr>\n    <tr>\n        <td align=\"center\">\n          anime girl, dancing\n          <br />\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif\" alt=\"anime girl, dancing\" />\n        </td>\n        <td align=\"center\">\n          astronaut in space, dancing\n          <br/>\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff_vid2vid_controlnet.gif\" alt=\"astronaut in space, dancing\" />\n        </td>\n    </tr>\n</table>\n\n**The lights and composition were transferred from the Source Video.**\n\n## Using Motion LoRAs\n\nMotion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. 
These LoRAs are responsible for adding specific types of motion to the animations.\n\n```python\nimport torch\nfrom diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter\nfrom diffusers.utils import export_to_gif\n\n# Load the motion adapter\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5-2\", torch_dtype=torch.float16)\n# load SD 1.5 based finetuned model\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\npipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)\npipe.load_lora_weights(\n    \"guoyww/animatediff-motion-lora-zoom-out\", adapter_name=\"zoom-out\"\n)\n\nscheduler = DDIMScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\",\n    clip_sample=False,\n    beta_schedule=\"linear\",\n    timestep_spacing=\"linspace\",\n    steps_offset=1,\n)\npipe.scheduler = scheduler\n\n# enable memory savings\npipe.enable_vae_slicing()\npipe.enable_model_cpu_offload()\n\noutput = pipe(\n    prompt=(\n        \"masterpiece, bestquality, highlydetailed, ultradetailed, sunset, \"\n        \"orange sky, warm lighting, fishing boats, ocean waves seagulls, \"\n        \"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, \"\n        \"golden hour, coastal landscape, seaside scenery\"\n    ),\n    negative_prompt=\"bad quality, worse quality\",\n    num_frames=16,\n    guidance_scale=7.5,\n    num_inference_steps=25,\n    generator=torch.Generator(\"cpu\").manual_seed(42),\n)\nframes = output.frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\n<table>\n    <tr>\n        <td><center>\n        masterpiece, bestquality, sunset.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-zoom-out-lora.gif\"\n            alt=\"masterpiece, bestquality, sunset\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\n## Using Motion LoRAs 
with PEFT\n\nYou can also leverage the [PEFT](https://github.com/huggingface/peft) backend to combine Motion LoRAs and create more complex animations.\n\nFirst, install PEFT with:\n\n```shell\npip install peft\n```\n\nThen you can use the following code to combine Motion LoRAs.\n\n```python\nimport torch\nfrom diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter\nfrom diffusers.utils import export_to_gif\n\n# Load the motion adapter\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5-2\", torch_dtype=torch.float16)\n# load SD 1.5 based finetuned model\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\npipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)\n\npipe.load_lora_weights(\n    \"diffusers/animatediff-motion-lora-zoom-out\", adapter_name=\"zoom-out\",\n)\npipe.load_lora_weights(\n    \"diffusers/animatediff-motion-lora-pan-left\", adapter_name=\"pan-left\",\n)\npipe.set_adapters([\"zoom-out\", \"pan-left\"], adapter_weights=[1.0, 1.0])\n\nscheduler = DDIMScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\",\n    clip_sample=False,\n    timestep_spacing=\"linspace\",\n    beta_schedule=\"linear\",\n    steps_offset=1,\n)\npipe.scheduler = scheduler\n\n# enable memory savings\npipe.enable_vae_slicing()\npipe.enable_model_cpu_offload()\n\noutput = pipe(\n    prompt=(\n        \"masterpiece, bestquality, highlydetailed, ultradetailed, sunset, \"\n        \"orange sky, warm lighting, fishing boats, ocean waves seagulls, \"\n        \"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, \"\n        \"golden hour, coastal landscape, seaside scenery\"\n    ),\n    negative_prompt=\"bad quality, worse quality\",\n    num_frames=16,\n    guidance_scale=7.5,\n    num_inference_steps=25,\n    generator=torch.Generator(\"cpu\").manual_seed(42),\n)\nframes = output.frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\n<table>\n    
<tr>\n        <td><center>\n        masterpiece, bestquality, sunset.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-zoom-out-pan-left-lora.gif\"\n            alt=\"masterpiece, bestquality, sunset\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\n## Using FreeInit\n\n[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://huggingface.co/papers/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.\n\nFreeInit is an effective method that improves the temporal consistency and overall quality of videos generated by video diffusion models without any additional training. It can be applied to AnimateDiff, ModelScope, VideoCrafter and various other video generation models seamlessly at inference time, and works by iteratively refining the latent-initialization noise. More details can be found in the paper.\n\nThe following example demonstrates the usage of FreeInit.\n\n```python\nimport torch\nfrom diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler\nfrom diffusers.utils import export_to_gif\n\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5-2\")\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\npipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to(\"cuda\")\npipe.scheduler = DDIMScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\",\n    beta_schedule=\"linear\",\n    clip_sample=False,\n    timestep_spacing=\"linspace\",\n    steps_offset=1\n)\n\n# enable memory savings\npipe.enable_vae_slicing()\npipe.enable_vae_tiling()\n\n# enable FreeInit\n# Refer to the enable_free_init documentation for a full list of configurable parameters\npipe.enable_free_init(method=\"butterworth\", use_fast_sampling=True)\n\n# run inference\noutput = pipe(\n    prompt=\"a panda playing a guitar, on a boat, 
in the ocean, high quality\",\n    negative_prompt=\"bad quality, worse quality\",\n    num_frames=16,\n    guidance_scale=7.5,\n    num_inference_steps=20,\n    generator=torch.Generator(\"cpu\").manual_seed(666),\n)\n\n# disable FreeInit\npipe.disable_free_init()\n\nframes = output.frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\n> [!WARNING]\n> FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n<table>\n    <tr>\n      <th align=center>Without FreeInit enabled</th>\n      <th align=center>With FreeInit enabled</th>\n    </tr>\n    <tr>\n        <td align=center>\n          panda playing a guitar\n          <br />\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-no-freeinit.gif\"\n              alt=\"panda playing a guitar\"\n              style=\"width: 300px;\" />\n        </td>\n        <td align=center>\n          panda playing a guitar\n          <br/>\n          <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-freeinit.gif\"\n              alt=\"panda playing a guitar\"\n              style=\"width: 300px;\" />\n        </td>\n    </tr>\n</table>\n\n## Using 
AnimateLCM\n\n[AnimateLCM](https://animatelcm.github.io/) is a motion module checkpoint and an [LCM LoRA](https://huggingface.co/docs/diffusers/using-diffusers/inference_with_lcm_lora) that have been created using a consistency learning strategy that decouples the distillation of the image generation priors and the motion generation priors.\n\n```python\nimport torch\nfrom diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter\nfrom diffusers.utils import export_to_gif\n\nadapter = MotionAdapter.from_pretrained(\"wangfuyun/AnimateLCM\")\npipe = AnimateDiffPipeline.from_pretrained(\"emilianJR/epiCRealism\", motion_adapter=adapter)\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule=\"linear\")\n\npipe.load_lora_weights(\"wangfuyun/AnimateLCM\", weight_name=\"sd15_lora_beta.safetensors\", adapter_name=\"lcm-lora\")\n\npipe.enable_vae_slicing()\npipe.enable_model_cpu_offload()\n\noutput = pipe(\n    prompt=\"A space rocket with trails of smoke behind it launching into space from the desert, 4k, high resolution\",\n    negative_prompt=\"bad quality, worse quality, low resolution\",\n    num_frames=16,\n    guidance_scale=1.5,\n    num_inference_steps=6,\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n)\nframes = output.frames[0]\nexport_to_gif(frames, \"animatelcm.gif\")\n```\n\n<table>\n    <tr>\n        <td><center>\n        A space rocket, 4K.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatelcm-output.gif\"\n            alt=\"A space rocket, 4K\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\nAnimateLCM is also compatible with existing [Motion LoRAs](https://huggingface.co/collections/dn6/animatediff-motion-loras-654cb8ad732b9e3cf4d3c17e).\n\n```python\nimport torch\nfrom diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter\nfrom diffusers.utils import export_to_gif\n\nadapter = 
MotionAdapter.from_pretrained(\"wangfuyun/AnimateLCM\")\npipe = AnimateDiffPipeline.from_pretrained(\"emilianJR/epiCRealism\", motion_adapter=adapter)\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule=\"linear\")\n\npipe.load_lora_weights(\"wangfuyun/AnimateLCM\", weight_name=\"sd15_lora_beta.safetensors\", adapter_name=\"lcm-lora\")\npipe.load_lora_weights(\"guoyww/animatediff-motion-lora-tilt-up\", adapter_name=\"tilt-up\")\n\npipe.set_adapters([\"lcm-lora\", \"tilt-up\"], [1.0, 0.8])\npipe.enable_vae_slicing()\npipe.enable_model_cpu_offload()\n\noutput = pipe(\n    prompt=\"A space rocket with trails of smoke behind it launching into space from the desert, 4k, high resolution\",\n    negative_prompt=\"bad quality, worse quality, low resolution\",\n    num_frames=16,\n    guidance_scale=1.5,\n    num_inference_steps=6,\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n)\nframes = output.frames[0]\nexport_to_gif(frames, \"animatelcm-motion-lora.gif\")\n```\n\n<table>\n    <tr>\n        <td><center>\n        A space rocket, 4K.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatelcm-motion-lora.gif\"\n            alt=\"A space rocket, 4K\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\n## Using FreeNoise\n\n[FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling](https://huggingface.co/papers/2310.15169) by Haonan Qiu, Menghan Xia, Yong Zhang, Yingqing He, Xintao Wang, Ying Shan, Ziwei Liu.\n\nFreeNoise is a sampling mechanism that can generate longer videos with short-video generation models by employing noise-rescheduling, temporal attention over sliding windows, and weighted averaging of latent frames. It also can be used with multiple prompts to allow for interpolated video generations. 
More details are available in the paper.\n\nThe currently supported AnimateDiff pipelines that can be used with FreeNoise are:\n- [`AnimateDiffPipeline`]\n- [`AnimateDiffControlNetPipeline`]\n- [`AnimateDiffVideoToVideoPipeline`]\n- [`AnimateDiffVideoToVideoControlNetPipeline`]\n\nIn order to use FreeNoise, a single line needs to be added to the inference code after loading your pipelines.\n\n```diff\n+ pipe.enable_free_noise()\n```\n\nAfter this, either a single prompt could be used, or multiple prompts can be passed as a dictionary of integer-string pairs. The integer keys of the dictionary correspond to the frame index at which the influence of that prompt would be maximum. Each frame index should map to a single string prompt. The prompts for intermediate frame indices, that are not passed in the dictionary, are created by interpolating between the frame prompts that are passed. By default, simple linear interpolation is used. However, you can customize this behaviour with a callback to the `prompt_interpolation_callback` parameter when enabling FreeNoise.\n\nFull example:\n\n```python\nimport torch\nfrom diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter\nfrom diffusers.utils import export_to_video, load_image\n\n# Load pipeline\ndtype = torch.float16\nmotion_adapter = MotionAdapter.from_pretrained(\"wangfuyun/AnimateLCM\", torch_dtype=dtype)\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=dtype)\n\npipe = AnimateDiffPipeline.from_pretrained(\"emilianJR/epiCRealism\", motion_adapter=motion_adapter, vae=vae, torch_dtype=dtype)\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule=\"linear\")\n\npipe.load_lora_weights(\n    \"wangfuyun/AnimateLCM\", weight_name=\"AnimateLCM_sd15_t2v_lora.safetensors\", adapter_name=\"lcm_lora\"\n)\npipe.set_adapters([\"lcm_lora\"], [0.8])\n\n# Enable FreeNoise for long prompt generation\npipe.enable_free_noise(context_length=16, 
context_stride=4)\npipe.to(\"cuda\")\n\n# Can be a single prompt, or a dictionary keyed by frame index\nprompt = {\n    0: \"A caterpillar on a leaf, high quality, photorealistic\",\n    40: \"A caterpillar transforming into a cocoon, on a leaf, near flowers, photorealistic\",\n    80: \"A cocoon on a leaf, flowers in the background, photorealistic\",\n    120: \"A cocoon maturing and a butterfly being born, flowers and leaves visible in the background, photorealistic\",\n    160: \"A beautiful butterfly, vibrant colors, sitting on a leaf, flowers in the background, photorealistic\",\n    200: \"A beautiful butterfly, flying away in a forest, photorealistic\",\n    240: \"A cyberpunk butterfly, neon lights, glowing\",\n}\nnegative_prompt = \"bad quality, worst quality, jpeg artifacts\"\n\n# Run inference\noutput = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=256,\n    guidance_scale=2.5,\n    num_inference_steps=10,\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n)\n\n# Save video\nframes = output.frames[0]\nexport_to_video(frames, \"output.mp4\", fps=16)\n```\n\n### FreeNoise memory savings\n\nSince FreeNoise processes multiple frames together, some parts of the model require more memory than is available on typical consumer GPUs. The main memory bottlenecks that we identified are spatial and temporal attention blocks, upsampling and downsampling blocks, resnet blocks and feed-forward layers. Since most of these blocks operate effectively only on the channel/embedding dimension, one can perform chunked inference across the batch dimensions. The batch dimensions in AnimateDiff are either spatial (`[B x F, H x W, C]`) or temporal (`[B x H x W, F, C]`) in nature (this may seem counter-intuitive, but the batch dimensions here are correct, because spatial blocks process across the `B x F` dimension while temporal blocks process across the `B x H x W` dimension). 
We introduce a `SplitInferenceModule` that makes it easier to chunk across any dimension and perform inference. This saves a lot of memory but comes at the cost of requiring more time for inference.\n\n```diff\n# Load pipeline and adapters\n# ...\n+ pipe.enable_free_noise_split_inference()\n+ pipe.unet.enable_forward_chunking(16)\n```\n\nThe `pipe.enable_free_noise_split_inference` method accepts two parameters: `spatial_split_size` (defaults to `256`) and `temporal_split_size` (defaults to `16`). These can be configured based on how much VRAM you have available. A lower split size results in lower memory usage but slower inference, whereas a larger split size results in faster inference at the cost of more memory.\n\n## Using `from_single_file` with the MotionAdapter\n\n`diffusers>=0.30.0` supports loading the AnimateDiff checkpoints into the `MotionAdapter` in their original format via `from_single_file`.\n\n```python\nimport torch\nfrom diffusers import AnimateDiffPipeline, MotionAdapter\n\nckpt_path = \"https://huggingface.co/Lightricks/LongAnimateDiff/blob/main/lt_long_mm_32_frames.ckpt\"\n\nadapter = MotionAdapter.from_single_file(ckpt_path, torch_dtype=torch.float16)\npipe = AnimateDiffPipeline.from_pretrained(\"emilianJR/epiCRealism\", motion_adapter=adapter)\n```\n\n## AnimateDiffPipeline\n\n[[autodoc]] AnimateDiffPipeline\n  - all\n  - __call__\n\n## AnimateDiffControlNetPipeline\n\n[[autodoc]] AnimateDiffControlNetPipeline\n  - all\n  - __call__\n\n## AnimateDiffSparseControlNetPipeline\n\n[[autodoc]] AnimateDiffSparseControlNetPipeline\n  - all\n  - __call__\n\n## AnimateDiffSDXLPipeline\n\n[[autodoc]] AnimateDiffSDXLPipeline\n  - all\n  - __call__\n\n## AnimateDiffVideoToVideoPipeline\n\n[[autodoc]] AnimateDiffVideoToVideoPipeline\n  - all\n  - __call__\n\n## AnimateDiffVideoToVideoControlNetPipeline\n\n[[autodoc]] AnimateDiffVideoToVideoControlNetPipeline\n  - all\n  - __call__\n\n## AnimateDiffPipelineOutput\n\n[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/attend_and_excite.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Attend-and-Excite\n\nAttend-and-Excite for Stable Diffusion was proposed in [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://attendandexcite.github.io/Attend-and-Excite/) and provides textual attention control over image generation.\n\nThe abstract from the paper is:\n\n*Recent text-to-image generative models have demonstrated an unparalleled ability to generate diverse and creative imagery guided by a target text prompt. While revolutionary, current state-of-the-art diffusion models may still fail in generating images that fully convey the semantics in the given text prompt. We analyze the publicly available Stable Diffusion model and assess the existence of catastrophic neglect, where the model fails to generate one or more of the subjects from the input prompt. Moreover, we find that in some cases the model also fails to correctly bind attributes (e.g., colors) to their corresponding subjects. 
To help mitigate these failure cases, we introduce the concept of Generative Semantic Nursing (GSN), where we seek to intervene in the generative process on the fly during inference time to improve the faithfulness of the generated images. Using an attention-based formulation of GSN, dubbed Attend-and-Excite, we guide the model to refine the cross-attention units to attend to all subject tokens in the text prompt and strengthen - or excite - their activations, encouraging the model to generate all subjects described in the text prompt. We compare our approach to alternative approaches and demonstrate that it conveys the desired concepts more faithfully across a range of text prompts.*\n\nYou can find additional information about Attend-and-Excite on the [project page](https://attendandexcite.github.io/Attend-and-Excite/), the [original codebase](https://github.com/AttendAndExcite/Attend-and-Excite), or try it out in a [demo](https://huggingface.co/spaces/AttendAndExcite/Attend-and-Excite).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableDiffusionAttendAndExcitePipeline\n\n[[autodoc]] StableDiffusionAttendAndExcitePipeline\n\t- all\n\t- __call__\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/audioldm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# AudioLDM\n\nAudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://huggingface.co/papers/2301.12503) by Haohe Liu et al. Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM\nis a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap)\nlatents. AudioLDM takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional\nsound effects, human speech and music.\n\nThe abstract from the paper is:\n\n*Text-to-audio (TTA) system has recently gained attention for its ability to synthesize general audio based on text descriptions. However, previous studies in TTA have limited generation quality with high computational costs. In this study, we propose AudioLDM, a TTA system that is built on a latent space to learn the continuous audio representations from contrastive language-audio pretraining (CLAP) latents. 
The pretrained CLAP models enable us to train LDMs with audio embedding while providing text embedding as a condition during sampling. By learning the latent representations of audio signals and their compositions without modeling the cross-modal relationship, AudioLDM is advantageous in both generation quality and computational efficiency. Trained on AudioCaps with a single GPU, AudioLDM achieves state-of-the-art TTA performance measured by both objective and subjective metrics (e.g., frechet distance). Moreover, AudioLDM is the first TTA system that enables various text-guided audio manipulations (e.g., style transfer) in a zero-shot fashion. Our implementation and demos are available at [this https URL](https://audioldm.github.io/).*\n\nThe original codebase can be found at [haoheliu/AudioLDM](https://github.com/haoheliu/AudioLDM).\n\n## Tips\n\nWhen constructing a prompt, keep in mind:\n\n* Descriptive prompt inputs work best; you can use adjectives to describe the sound (for example, \"high quality\" or \"clear\") and make the prompt context specific (for example, \"water stream in a forest\" instead of \"stream\").\n* It's best to use general terms like \"cat\" or \"dog\" instead of specific names or abstract objects the model may not be familiar with.\n\nDuring inference:\n\n* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.\n* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## AudioLDMPipeline\n[[autodoc]] AudioLDMPipeline\n\t- 
all\n\t- __call__\n\n## AudioPipelineOutput\n[[autodoc]] pipelines.AudioPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/audioldm2.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AudioLDM 2\n\nAudioLDM 2 was proposed in [AudioLDM 2: Learning Holistic Audio Generation with Self-supervised Pretraining](https://huggingface.co/papers/2308.05734) by Haohe Liu et al. AudioLDM 2 takes a text prompt as input and predicts the corresponding audio. It can generate text-conditional sound effects, human speech and music.\n\nInspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM 2 is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from text embeddings. Two text encoder models are used to compute the text embeddings from a prompt input: the text-branch of [CLAP](https://huggingface.co/docs/transformers/main/en/model_doc/clap) and the encoder of [Flan-T5](https://huggingface.co/docs/transformers/main/en/model_doc/flan-t5). These text embeddings are then projected to a shared embedding space by an [AudioLDM2ProjectionModel](https://huggingface.co/docs/diffusers/main/api/pipelines/audioldm2#diffusers.AudioLDM2ProjectionModel). A [GPT2](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) _language model (LM)_ is used to auto-regressively predict eight new embedding vectors, conditional on the projected CLAP and Flan-T5 embeddings. 
The generated embedding vectors and Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The [UNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2UNet2DConditionModel) of AudioLDM 2 is unique in the sense that it takes **two** cross-attention embeddings, as opposed to one cross-attention conditioning, as in most other LDMs.\n\nThe abstract of the paper is the following:\n\n*Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called \"language of audio\" (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate state-of-the-art or competitive performance against previous approaches. Our code, pretrained model, and demo are available at [this https URL](https://audioldm.github.io/audioldm2).*\n\nThis pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi) and [Nguyễn Công Tú Anh](https://github.com/tuanh123789). 
The original codebase can be\nfound at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).\n\n## Tips\n\n### Choosing a checkpoint\n\nAudioLDM2 comes in three variants. Two of these checkpoints are applicable to the general task of text-to-audio generation. The third checkpoint is trained exclusively on text-to-music generation.\n\nAll checkpoints share the same model size for the text encoders and VAE. They differ in the size and depth of the UNet.\nSee table below for details on the three checkpoints:\n\n| Checkpoint                                                      | Task          | UNet Model Size | Total Model Size | Training Data / h |\n|-----------------------------------------------------------------|---------------|-----------------|------------------|-------------------|\n| [audioldm2](https://huggingface.co/cvssp/audioldm2)             | Text-to-audio | 350M            | 1.1B             | 1150k             |\n| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M            | 1.5B             | 1150k             |\n| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M            | 1.1B             | 665k              |\n| [audioldm2-gigaspeech](https://huggingface.co/anhnct/audioldm2_gigaspeech) | Text-to-speech | 350M            | 1.1B             |10k              |\n| [audioldm2-ljspeech](https://huggingface.co/anhnct/audioldm2_ljspeech) | Text-to-speech | 350M            | 1.1B             |              |\n\n### Constructing a prompt\n\n* Descriptive prompt inputs work best: use adjectives to describe the sound (e.g. \"high quality\" or \"clear\") and make the prompt context specific (e.g. 
\"water stream in a forest\" instead of \"stream\").\n* It's best to use general terms like \"cat\" or \"dog\" instead of specific names or abstract objects the model may not be familiar with.\n* Using a **negative prompt** can significantly improve the quality of the generated waveform, by guiding the generation away from terms that correspond to poor quality audio. Try using a negative prompt of \"Low quality.\"\n\n### Controlling inference\n\n* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.\n* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.\n\n### Evaluating generated waveforms:\n\n* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation.\n* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. 
The generated waveforms are automatically scored against the prompt text and ranked from best to worst.\n\nThe following example demonstrates how to apply these tips to music and speech generation: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## AudioLDM2Pipeline\n[[autodoc]] AudioLDM2Pipeline\n\t- all\n\t- __call__\n\n## AudioLDM2ProjectionModel\n[[autodoc]] AudioLDM2ProjectionModel\n\t- forward\n\n## AudioLDM2UNet2DConditionModel\n[[autodoc]] AudioLDM2UNet2DConditionModel\n\t- forward\n\n## AudioPipelineOutput\n[[autodoc]] pipelines.AudioPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/aura_flow.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AuraFlow\n\nAuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3) and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.\n\nIt was developed by the Fal team and more details about it can be found in [this blog post](https://blog.fal.ai/auraflow/).\n\n> [!TIP]\n> AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. 
The example below demonstrates how to load a quantized [`AuraFlowPipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, AuraFlowTransformer2DModel, AuraFlowPipeline\nfrom transformers import BitsAndBytesConfig, T5EncoderModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = T5EncoderModel.from_pretrained(\n    \"fal/AuraFlow\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = AuraFlowTransformer2DModel.from_pretrained(\n    \"fal/AuraFlow\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = AuraFlowPipeline.from_pretrained(\n    \"fal/AuraFlow\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nprompt = \"a tiny astronaut hatching from an egg on the moon\"\nimage = pipeline(prompt).images[0]\nimage.save(\"auraflow.png\")\n```\n\nLoading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) is also supported:\n\n```py\nimport torch\nfrom diffusers import (\n    AuraFlowPipeline,\n    GGUFQuantizationConfig,\n    AuraFlowTransformer2DModel,\n)\n\ntransformer = AuraFlowTransformer2DModel.from_single_file(\n    \"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf\",\n    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),\n    torch_dtype=torch.bfloat16,\n)\n\npipeline = AuraFlowPipeline.from_pretrained(\n    \"fal/AuraFlow-v0.3\",\n    transformer=transformer,\n    torch_dtype=torch.bfloat16,\n)\n\nprompt = \"a cute pony in a field of flowers\"\nimage = pipeline(prompt).images[0]\nimage.save(\"auraflow.png\")\n```\n\n## Support for `torch.compile()`\n\nAuraFlow 
can be compiled with `torch.compile()` to reduce inference latency, even across different resolutions. First, install PyTorch nightly following the instructions from [here](https://pytorch.org/). The snippet below shows the changes needed to enable this:\n\n```diff\n+ torch.fx.experimental._config.use_duck_shape = False\n+ pipeline.transformer = torch.compile(\n+     pipeline.transformer, fullgraph=True, dynamic=True\n+ )\n```\n\nSetting `use_duck_shape` to `False` instructs the compiler not to use the same symbolic variable to represent input sizes that happen to be equal. For more details, check out [this comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).\n\nThis yields speed improvements ranging from 100% (at low resolutions) to 30% (at 1536x1536 resolution).\n\nThanks to [AstraliteHeart](https://github.com/huggingface/diffusers/pull/11297/) who helped us rewrite the [`AuraFlowTransformer2DModel`] class so that the above works for different resolutions ([PR](https://github.com/huggingface/diffusers/pull/11297/)).\n\n## AuraFlowPipeline\n\n[[autodoc]] AuraFlowPipeline\n\t- all\n\t- __call__"
  },
  {
    "path": "docs/source/en/api/pipelines/auto_pipeline.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoPipeline\n\nThe `AutoPipeline` is designed to make it easy to load a checkpoint for a task without needing to know the specific pipeline class. Based on the task, the `AutoPipeline` automatically retrieves the correct pipeline class from the checkpoint `model_index.json` file.\n\n> [!TIP]\n> Check out the [AutoPipeline](../../tutorials/autopipeline) tutorial to learn how to use this API!\n\n## AutoPipelineForText2Image\n\n[[autodoc]] AutoPipelineForText2Image\n\t- all\n\t- from_pretrained\n\t- from_pipe\n\n## AutoPipelineForImage2Image\n\n[[autodoc]] AutoPipelineForImage2Image\n\t- all\n\t- from_pretrained\n\t- from_pipe\n\n## AutoPipelineForInpainting\n\n[[autodoc]] AutoPipelineForInpainting\n\t- all\n\t- from_pretrained\n\t- from_pipe\n"
  },
  {
    "path": "docs/source/en/api/pipelines/blip_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# BLIP-Diffusion\n\nBLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://huggingface.co/papers/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.\n\n\nThe abstract from the paper is:\n\n*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. 
Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications. Project page at [this https URL](https://dxli94.github.io/BLIP-Diffusion-website/).*\n\nThe original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP-Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.\n\n`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n\n## BlipDiffusionPipeline\n[[autodoc]] BlipDiffusionPipeline\n    - all\n    - __call__\n\n## BlipDiffusionControlNetPipeline\n[[autodoc]] BlipDiffusionControlNetPipeline\n    - all\n    - __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/bria_3_2.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Bria 3.2\n\nBria 3.2 is the next-generation commercial-ready text-to-image model. With just 4 billion parameters, it provides exceptional aesthetics and text rendering, evaluated to provide on par results to leading open-source models, and outperforming other licensed models.\nIn addition to being built entirely on licensed data, 3.2 provides several advantages for enterprise and commercial use:\n\n- Efficient Compute - the model is X3 smaller than the equivalent models in the market (4B parameters vs 12B parameters other open source models)\n- Architecture Consistency: Same architecture as 3.1—ideal for users looking to upgrade without disruption.\n- Fine-tuning Speedup: 2x faster fine-tuning on L40S and A100.\n\nOriginal model checkpoints for Bria 3.2 can be found [here](https://huggingface.co/briaai/BRIA-3.2).\nGithub repo for Bria 3.2 can be found [here](https://github.com/Bria-AI/BRIA-3.2).\n\nIf you want to learn more about the Bria platform, and get free traril access, please visit [bria.ai](https://bria.ai).\n\n\n## Usage\n\n_As the model is gated, before using it with diffusers you first need to go to the [Bria 3.2 Hugging Face page](https://huggingface.co/briaai/BRIA-3.2), fill in the form and accept the gate. 
Once you are in, you need to log in so that your system knows you’ve accepted the gate._\n\nUse the command below to log in:\n\n```bash\nhf auth login\n```\n\n\n## BriaPipeline\n\n[[autodoc]] BriaPipeline\n\t- all\n\t- __call__\n\n"
  },
  {
    "path": "docs/source/en/api/pipelines/bria_fibo.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Bria Fibo\n\nText-to-image models have mastered imagination - but not control. FIBO changes that.\n\nFIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.\n\nWith only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.\n\nFIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.\nyou can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON)  to convert your freeform text prompt to a structured JSON prompt.\n\n> [!NOTE]\n> Avoid using freeform text prompts directly with FIBO because it does not produce the best results.\n\nRefer to the Bria Fibo Hugging Face [page](https://huggingface.co/briaai/FIBO) to learn more.\n\n\n## Usage\n\n_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. 
Once you are in, you need to log in so that your system knows you’ve accepted the gate._\n\nUse the command below to log in:\n\n```bash\nhf auth login\n```\n\n\n## BriaFiboPipeline\n\n[[autodoc]] BriaFiboPipeline\n\t- all\n\t- __call__"
  },
  {
    "path": "docs/source/en/api/pipelines/bria_fibo_edit.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Bria Fibo Edit\n\nFibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows.\nFeaturing native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.\nIts lightweight architecture is designed for deep customization, empowering researchers to build specialized \"Edit\" models for domain-specific tasks while delivering top-tier aesthetic quality\n\n## Usage\n_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._\n\nUse the command below to log in:\n\n```bash\nhf auth login\n```\n\n\n## BriaFiboEditPipeline\n\n[[autodoc]] BriaFiboEditPipeline\n\t- all\n\t- __call__"
  },
  {
    "path": "docs/source/en/api/pipelines/chroma.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Chroma\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\nChroma is a text to image generation model based on Flux.\n\nOriginal model checkpoints for Chroma can be found here:\n* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)\n* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)\n* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)\n\n> [!TIP]\n> Chroma can use all the same optimizations as Flux.\n\n## Inference\n\n```python\nimport torch\nfrom diffusers import ChromaPipeline\n\npipe = ChromaPipeline.from_pretrained(\"lodestones/Chroma1-HD\", torch_dtype=torch.bfloat16)\npipe.enable_model_cpu_offload()\n\nprompt = [\n    \"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. 
It looks professionally done.\"\n]\nnegative_prompt = [\"low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors\"]\n\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    generator=torch.Generator(\"cpu\").manual_seed(433),\n    num_inference_steps=40,\n    guidance_scale=3.0,\n    num_images_per_prompt=1,\n).images[0]\nimage.save(\"chroma.png\")\n```\n\n## Loading from a single file\n\nTo use updated model checkpoints that are not in the Diffusers format, you can use the `ChromaTransformer2DModel` class to load the model from a single file in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.\n\nThe following example demonstrates how to run Chroma from a single file.\n\n```python\nimport torch\nfrom diffusers import ChromaTransformer2DModel, ChromaPipeline\n\nmodel_id = \"lodestones/Chroma1-HD\"\ndtype = torch.bfloat16\n\ntransformer = ChromaTransformer2DModel.from_single_file(\"https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors\", torch_dtype=dtype)\n\npipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)\npipe.enable_model_cpu_offload()\n\nprompt = [\n    \"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. 
It looks professionally done.\"\n]\nnegative_prompt = [\"low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors\"]\n\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    generator=torch.Generator(\"cpu\").manual_seed(433),\n    num_inference_steps=40,\n    guidance_scale=3.0,\n).images[0]\n\nimage.save(\"chroma-single-file.png\")\n```\n\n## ChromaPipeline\n\n[[autodoc]] ChromaPipeline\n  - all\n  - __call__\n\n## ChromaImg2ImgPipeline\n\n[[autodoc]] ChromaImg2ImgPipeline\n  - all\n  - __call__\n\n## ChromaInpaintPipeline\n\n[[autodoc]] ChromaInpaintPipeline\n  - all\n  - __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/chronoedit.md",
    "content": "<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n<div style=\"float: right;\">\n  <div class=\"flex flex-wrap space-x-1\">\n    <a href=\"https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference\" target=\"_blank\" rel=\"noopener\">\n      <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n    </a>\n  </div>\n</div>\n\n# ChronoEdit\n\n[ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.\n\n> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.\n\n*Recent advances in large generative models have greatly enhanced both image editing and in-context image generation, yet a critical gap remains in ensuring physical consistency, where edited objects must remain coherent. This capability is especially vital for world simulation related tasks. 
In this paper, we present ChronoEdit, a framework that reframes image editing as a video generation problem. First, ChronoEdit treats the input and edited images as the first and last frames of a video, allowing it to leverage large pretrained video generative models that capture not only object appearance but also the implicit physics of motion and interaction through learned temporal consistency. Second, ChronoEdit introduces a temporal reasoning stage that explicitly performs editing at inference time. Under this setting, target frame is jointly denoised with reasoning tokens to imagine a plausible editing trajectory that constrains the solution space to physically viable transformations. The reasoning tokens are then dropped after a few steps to avoid the high computational cost of rendering a full video. To validate ChronoEdit, we introduce PBench-Edit, a new benchmark of image-prompt pairs for contexts that require physical consistency, and demonstrate that ChronoEdit surpasses state-of-the-art baselines in both visual fidelity and physical plausibility. Project page for code and models: [this https URL](https://research.nvidia.com/labs/toronto-ai/chronoedit).*\n\nThe ChronoEdit pipeline is developed by the ChronoEdit Team. 
The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.\n\nAvailable Models/LoRAs:\n- [nvidia/ChronoEdit-14B-Diffusers](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers)\n- [nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora)\n- [nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora)\n\n### Image Editing\n\n```py\nimport torch\nimport numpy as np\nfrom diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import CLIPVisionModel\nfrom PIL import Image\n\nmodel_id = \"nvidia/ChronoEdit-14B-Diffusers\"\nimage_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder=\"image_encoder\", torch_dtype=torch.float32)\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\ntransformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder=\"transformer\", torch_dtype=torch.bfloat16)\npipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\nimage = load_image(\n    \"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png\"\n)\nmax_area = 720 * 1280\naspect_ratio = image.height / image.width\nmod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]\nheight = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\nwidth = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\nprint(\"width\", width, \"height\", height)\nimage = image.resize((width, height))\nprompt = (\n    \"The user wants to transform the image by adding a small, cute 
mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. \"\n    \"The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood.\"\n)\n\noutput = pipe(\n    image=image,\n    prompt=prompt,\n    height=height,\n    width=width,\n    num_frames=5,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    enable_temporal_reasoning=False,\n    num_temporal_reasoning_steps=0,\n).frames[0]\nImage.fromarray((output[-1] * 255).clip(0, 255).astype(\"uint8\")).save(\"output.png\")\n```\n\nOptionally, enable **temporal reasoning** for improved physical consistency:\n```py\noutput = pipe(\n    image=image,\n    prompt=prompt,\n    height=height,\n    width=width,\n    num_frames=29,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    enable_temporal_reasoning=True,\n    num_temporal_reasoning_steps=50,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\nImage.fromarray((output[-1] * 255).clip(0, 255).astype(\"uint8\")).save(\"output.png\")\n```\n\n### Inference with 8-Step Distillation Lora\n\n```py\nimport torch\nimport numpy as np\nfrom diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline\nfrom diffusers.schedulers import UniPCMultistepScheduler\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import CLIPVisionModel\nfrom PIL import Image\n\nmodel_id = \"nvidia/ChronoEdit-14B-Diffusers\"\nimage_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder=\"image_encoder\", 
torch_dtype=torch.float32)\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\ntransformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder=\"transformer\", torch_dtype=torch.bfloat16)\npipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)\npipe.load_lora_weights(\"nvidia/ChronoEdit-14B-Diffusers\", weight_name=\"lora/chronoedit_distill_lora.safetensors\", adapter_name=\"distill\")\npipe.fuse_lora(adapter_names=[\"distill\"], lora_scale=1.0)\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)\npipe.to(\"cuda\")\n\nimage = load_image(\n    \"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png\"\n)\nmax_area = 720 * 1280\naspect_ratio = image.height / image.width\nmod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]\nheight = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\nwidth = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\nprint(\"width\", width, \"height\", height)\nimage = image.resize((width, height))\nprompt = (\n    \"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. \"\n    \"The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. 
The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood.\"\n)\n\noutput = pipe(\n    image=image,\n    prompt=prompt,\n    height=height,\n    width=width,\n    num_frames=5,\n    num_inference_steps=8,\n    guidance_scale=1.0,\n    enable_temporal_reasoning=False,\n    num_temporal_reasoning_steps=0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\nImage.fromarray((output[-1] * 255).clip(0, 255).astype(\"uint8\")).save(\"output.png\")\n```\n\n### Inference with Multiple LoRAs\n\n```py\nimport torch\nimport numpy as np\nfrom diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline\nfrom diffusers.schedulers import UniPCMultistepScheduler\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import CLIPVisionModel\nfrom PIL import Image\n\nmodel_id = \"nvidia/ChronoEdit-14B-Diffusers\"\nimage_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder=\"image_encoder\", torch_dtype=torch.float32)\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\ntransformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder=\"transformer\", torch_dtype=torch.bfloat16)\npipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)\npipe.load_lora_weights(\"nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora\", weight_name=\"paintbrush_lora_diffusers.safetensors\", adapter_name=\"paintbrush\")\npipe.load_lora_weights(\"nvidia/ChronoEdit-14B-Diffusers\", weight_name=\"lora/chronoedit_distill_lora.safetensors\", adapter_name=\"distill\")\npipe.fuse_lora(adapter_names=[\"paintbrush\", \"distill\"], lora_scale=1.0)\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)\npipe.to(\"cuda\")\n\nimage = load_image(\n    
\"https://raw.githubusercontent.com/nv-tlabs/ChronoEdit/refs/heads/main/assets/images/input_paintbrush.png\"\n)\nmax_area = 720 * 1280\naspect_ratio = image.height / image.width\nmod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]\nheight = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\nwidth = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\nprint(\"width\", width, \"height\", height)\nimage = image.resize((width, height))\nprompt = (\n    \"Turn the pencil sketch in the image into an actual object that is consistent with the image’s content. The user wants to change the sketch to a crown and a hat.\"\n)\n\noutput = pipe(\n    image=image,\n    prompt=prompt,\n    height=height,\n    width=width,\n    num_frames=5,\n    num_inference_steps=8,\n    guidance_scale=1.0,\n    enable_temporal_reasoning=False,\n    num_temporal_reasoning_steps=0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\nImage.fromarray((output[-1] * 255).clip(0, 255).astype(\"uint8\")).save(\"output_1.png\")\n```\n\n## ChronoEditPipeline\n\n[[autodoc]] ChronoEditPipeline\n  - all\n  - __call__\n\n## ChronoEditPipelineOutput\n\n[[autodoc]] pipelines.chronoedit.pipeline_output.ChronoEditPipelineOutput"
  },
  {
    "path": "docs/source/en/api/pipelines/cogvideox.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n<div style=\"float: right;\">\n  <div class=\"flex flex-wrap space-x-1\">\n    <a href=\"https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference\" target=\"_blank\" rel=\"noopener\">\n      <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n    </a>\n  </div>\n</div>\n\n# CogVideoX\n\n[CogVideoX](https://huggingface.co/papers/2408.06072) is a large diffusion transformer model - available in 2B and 5B parameters - designed to generate longer and more consistent videos from text. This model uses a 3D causal variational autoencoder to more efficiently process video data by reducing sequence length (and associated training compute) and preventing flickering in generated videos. 
An \"expert\" transformer with adaptive LayerNorm improves alignment between text and video, and 3D full attention helps accurately capture motion and time in generated videos.\n\nYou can find all the original CogVideoX checkpoints under the [CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) collection.\n\n> [!TIP]\n> Click on the CogVideoX models in the right sidebar for more examples of other video generation tasks.\n\nThe example below demonstrates how to generate a video optimized for memory or inference speed.\n\n<hfoptions id=\"usage\">\n<hfoption id=\"memory\">\n\nRefer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.\n\nThe quantized CogVideoX 5B model below requires ~16GB of VRAM.\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline, AutoModel\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom diffusers.utils import export_to_video\n\n# quantize weights to int8 with torchao\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"torchao\",\n    quant_kwargs={\"quant_type\": \"int8wo\"},\n    components_to_quantize=\"transformer\"\n)\n\n# fp8 layerwise weight-casting\ntransformer = AutoModel.from_pretrained(\n    \"THUDM/CogVideoX-5b\",\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16\n)\ntransformer.enable_layerwise_casting(\n    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16\n)\n\npipeline = CogVideoXPipeline.from_pretrained(\n    \"THUDM/CogVideoX-5b\",\n    transformer=transformer,\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16\n)\n\n# model-offloading (moves each component to the GPU only while it is in use)\npipeline.enable_model_cpu_offload()\n\nprompt = \"\"\"\nA detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. 
\nThe ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. \nSurrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, \nwith the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.\n\"\"\"\n\nvideo = pipeline(\n    prompt=prompt,\n    guidance_scale=6,\n    num_inference_steps=50\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\n</hfoption>\n<hfoption id=\"inference speed\">\n\n[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.\n\nThe average inference time with torch.compile on a 80GB A100 is 76.27 seconds compared to 96.89 seconds for an uncompiled model.\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline\nfrom diffusers.utils import export_to_video\n\npipeline = CogVideoXPipeline.from_pretrained(\n    \"THUDM/CogVideoX-2b\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\n# torch.compile\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.transformer = torch.compile(\n    pipeline.transformer, mode=\"max-autotune\", fullgraph=True\n)\n\nprompt = \"\"\"\nA detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. \nThe ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. \nSurrounding the ship are various other toys and children's items, hinting at a playful environment. 
The scene captures the innocence and imagination of childhood, \nwith the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.\n\"\"\"\n\nvideo = pipeline(\n    prompt=prompt,\n    guidance_scale=6,\n    num_inference_steps=50\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\n</hfoption>\n</hfoptions>\n\n## Notes\n\n- CogVideoX supports LoRAs with [`~loaders.CogVideoXLoraLoaderMixin.load_lora_weights`].\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```py\n  import torch\n  from diffusers import CogVideoXPipeline\n  from diffusers.utils import export_to_video\n\n  pipeline = CogVideoXPipeline.from_pretrained(\n      \"THUDM/CogVideoX-5b\",\n      torch_dtype=torch.bfloat16\n  )\n\n  # load LoRA weights\n  pipeline.load_lora_weights(\"finetrainers/CogVideoX-1.5-crush-smol-v0\", adapter_name=\"crush-lora\")\n  pipeline.set_adapters(\"crush-lora\", 0.9)\n\n  # model-offloading (moves each component to the GPU only while it is in use)\n  pipeline.enable_model_cpu_offload()\n\n  prompt = \"\"\"\n  PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.\n  \"\"\"\n  negative_prompt = \"inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs\"\n\n  video = pipeline(\n      prompt=prompt,\n      negative_prompt=negative_prompt,\n      num_frames=81,\n      height=480,\n      width=768,\n      num_inference_steps=50\n  ).frames[0]\n  export_to_video(video, \"output.mp4\", fps=16)\n  ```\n\n  </details>\n\n- The text-to-video (T2V) checkpoints work best with a resolution of 1360x768 because that is the resolution they were pretrained on.\n\n- The image-to-video (I2V) checkpoints work with multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. 
Both height and width must be divisible by 16.\n\n- Both T2V and I2V checkpoints work best with 81 and 161 frames. It is recommended to export the generated video at 16fps.\n\n- Refer to the table below to view memory usage when various memory-saving techniques are enabled.\n\n  | method | memory usage (enabled) | memory usage (disabled) |\n  |---|---|---|\n  | enable_model_cpu_offload | 19GB | 33GB |\n  | enable_sequential_cpu_offload | <4GB | ~33GB (very slow inference speed) |\n  | enable_tiling | 11GB (with enable_model_cpu_offload) | --- |\n \n## CogVideoXPipeline\n\n[[autodoc]] CogVideoXPipeline\n  - all\n  - __call__\n\n## CogVideoXImageToVideoPipeline\n\n[[autodoc]] CogVideoXImageToVideoPipeline\n  - all\n  - __call__\n\n## CogVideoXVideoToVideoPipeline\n\n[[autodoc]] CogVideoXVideoToVideoPipeline\n  - all\n  - __call__\n\n## CogVideoXFunControlPipeline\n\n[[autodoc]] CogVideoXFunControlPipeline\n  - all\n  - __call__\n\n## CogVideoXPipelineOutput\n\n[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/cogview3.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n# CogView3Plus\n\n[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.\n\nThe abstract from the paper is:\n\n*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. 
The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\nThis pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).\n\n## CogView3PlusPipeline\n\n[[autodoc]] CogView3PlusPipeline\n  - all\n  - __call__\n\n## CogView3PipelineOutput\n\n[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/cogview4.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n# CogView4\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\nThis pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).\n\n## CogView4Pipeline\n\n[[autodoc]] CogView4Pipeline\n  - all\n  - __call__\n\n## CogView4PipelineOutput\n\n[[autodoc]] pipelines.cogview4.pipeline_output.CogView4PipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/consisid.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n# ConsisID\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://huggingface.co/papers/2411.17440) from Peking University, University of Rochester, and others, by Shenghai Yuan, Jinfa Huang, Xianyi He, Yunyang Ge, Yujun Shi, Liuhan Chen, Jiebo Luo, Li Yuan.\n\nThe abstract from the paper is:\n\n*Identity-preserving text-to-video (IPT2V) generation aims to create high-fidelity videos with consistent human identity. It is an important task in video generation but remains an open problem for generative models. This paper pushes the technical frontier of IPT2V in two directions that have not been resolved in the literature: (1) A tuning-free pipeline without tedious case-by-case finetuning, and (2) A frequency-aware heuristic identity-preserving Diffusion Transformer (DiT)-based control scheme. To achieve these goals, we propose **ConsisID**, a tuning-free DiT-based controllable IPT2V model to keep human-**id**entity **consis**tent in the generated video. 
Inspired by prior findings in frequency analysis of vision/diffusion transformers, it employs identity-control signals in the frequency domain, where facial features can be decomposed into low-frequency global features (e.g., profile, proportions) and high-frequency intrinsic features (e.g., identity markers that remain unaffected by pose changes). First, from a low-frequency perspective, we introduce a global facial extractor, which encodes the reference image and facial key points into a latent space, generating features enriched with low-frequency information. These features are then integrated into the shallow layers of the network to alleviate training challenges associated with DiT. Second, from a high-frequency perspective, we design a local facial extractor to capture high-frequency details and inject them into the transformer blocks, enhancing the model's ability to preserve fine-grained features. To leverage the frequency information for identity preservation, we propose a hierarchical training strategy, transforming a vanilla pre-trained video generation model into an IPT2V model. Extensive experiments demonstrate that our frequency-aware heuristic scheme provides an optimal control solution for DiT-based models. Thanks to this scheme, our **ConsisID** achieves excellent results in generating high-quality, identity-preserving videos, making strides towards more effective IPT2V. The model weight of ConsID is publicly available at https://github.com/PKU-YuanGroup/ConsisID.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\nThis pipeline was contributed by [SHYuanBest](https://github.com/SHYuanBest). 
The original codebase can be found [here](https://github.com/PKU-YuanGroup/ConsisID). The original weights can be found under [hf.co/BestWishYsh](https://huggingface.co/BestWishYsh).\n\nThere are two official ConsisID checkpoints for identity-preserving text-to-video.\n\n| checkpoints | recommended inference dtype |\n|:---:|:---:|\n| [`BestWishYsh/ConsisID-preview`](https://huggingface.co/BestWishYsh/ConsisID-preview) | torch.bfloat16 |\n| [`BestWishYsh/ConsisID-1.5`](https://huggingface.co/BestWishYsh/ConsisID-1.5) | torch.bfloat16 |\n\n### Memory optimization\n\nConsisID requires about 44 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it impossible to run on consumer GPUs or a free-tier T4 Colab. The following memory optimizations can be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/SHYuanBest/bc4207c36f454f9e969adbb50eaf8258) script.\n\n| Feature (applied on top of the previous) | Max Memory Allocated | Max Memory Reserved |\n| :----------------------------- | :------------------- | :------------------ |\n| -                              | 37 GB                | 44 GB               |\n| enable_model_cpu_offload       | 22 GB                | 25 GB               |\n| enable_sequential_cpu_offload  | 16 GB                | 22 GB               |\n| vae.enable_slicing             | 16 GB                | 22 GB               |\n| vae.enable_tiling              | 5 GB                 | 7 GB                |\n\n## ConsisIDPipeline\n\n[[autodoc]] ConsisIDPipeline\n\n  - all\n  - __call__\n\n## ConsisIDPipelineOutput\n\n[[autodoc]] pipelines.consisid.pipeline_output.ConsisIDPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/consistency_models.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Consistency Models\n\nConsistency Models were proposed in [Consistency Models](https://huggingface.co/papers/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.\n\nThe abstract from the paper is:\n\n*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. 
When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.*\n\nThe original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models), and additional checkpoints are available at [openai](https://huggingface.co/openai).\n\nThe pipeline was contributed by [dg845](https://github.com/dg845) and [ayushtues](https://huggingface.co/ayushtues). ❤️\n\n## Tips\n\nFor an additional speed-up, use `torch.compile` to generate multiple images in <1 second:\n\n```diff\n  import torch\n  from diffusers import ConsistencyModelPipeline\n\n  device = \"cuda\"\n  # Load the cd_bedroom256_lpips checkpoint.\n  model_id_or_path = \"openai/diffusers-cd_bedroom256_lpips\"\n  pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\n  pipe.to(device)\n\n+ pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n\n  # Multistep sampling\n  # Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:\n  # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83\n  for _ in range(10):\n      image = pipe(timesteps=[17, 0]).images[0]\n      image.show()\n```\n\n\n## ConsistencyModelPipeline\n[[autodoc]] ConsistencyModelPipeline\n    - all\n    - __call__\n\n## ImagePipelineOutput\n[[autodoc]] pipelines.ImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/control_flux_inpaint.md",
    "content": "<!--Copyright 2025 The HuggingFace Team, The Black Forest Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# FluxControlInpaint\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nFluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image.\n\nFLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**.\n\n| Control type | Developer | Link |\n| -------- | ---------- | ---- |\n| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |\n| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |\n\n\n> [!TIP]\n> Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. 
Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).\n\n```python\nimport torch\nfrom diffusers import FluxControlInpaintPipeline\nfrom diffusers.models.transformers import FluxTransformer2DModel\nfrom transformers import T5EncoderModel\nfrom diffusers.utils import load_image, make_image_grid\nfrom image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux\nfrom PIL import Image\nimport numpy as np\n\npipe = FluxControlInpaintPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-Depth-dev\",\n    torch_dtype=torch.bfloat16,\n)\n# Optional: use the following lines if GPU memory is constrained\n# (they swap in NF4-quantized weights and enable model CPU offload)\n# ---------------------------------------------------------------\ntransformer = FluxTransformer2DModel.from_pretrained(\n    \"sayakpaul/FLUX.1-Depth-dev-nf4\", subfolder=\"transformer\", torch_dtype=torch.bfloat16\n)\ntext_encoder_2 = T5EncoderModel.from_pretrained(\n    \"sayakpaul/FLUX.1-Depth-dev-nf4\", subfolder=\"text_encoder_2\", torch_dtype=torch.bfloat16\n)\npipe.transformer = transformer\npipe.text_encoder_2 = text_encoder_2\npipe.enable_model_cpu_offload()\n# ---------------------------------------------------------------\n# If model CPU offload is enabled above, consider skipping this call:\n# moving the pipeline to the GPU manually is likely to lose the memory savings.\npipe.to(\"cuda\")\n\nprompt = \"a blue robot singing opera with human-like expressions\"\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png\")\n\n# Build a rectangular mask: white (255) marks the region to inpaint\nhead_mask = np.zeros_like(image)\nhead_mask[65:580,300:642] = 255\nmask_image = Image.fromarray(head_mask)\n\n# Estimate a depth map to use as the control image\nprocessor = DepthPreprocessor.from_pretrained(\"LiheYoung/depth-anything-large-hf\")\ncontrol_image = processor(image)[0].convert(\"RGB\")\n\noutput = pipe(\n    prompt=prompt,\n    image=image,\n    control_image=control_image,\n    
mask_image=mask_image,\n    num_inference_steps=30,\n    strength=0.9,\n    guidance_scale=10.0,\n    generator=torch.Generator().manual_seed(42),\n).images[0]\nmake_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save(\"output.png\")\n```\n\n## FluxControlInpaintPipeline\n[[autodoc]] FluxControlInpaintPipeline\n\t- all\n\t- __call__\n\n\n## FluxPipelineOutput\n[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.\n\nWith a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. 
We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\nThis model was contributed by [takuma104](https://huggingface.co/takuma104). ❤️\n\nThe original codebase can be found at [lllyasviel/ControlNet](https://github.com/lllyasviel/ControlNet), and you can find official ControlNet checkpoints on [lllyasviel's](https://huggingface.co/lllyasviel) Hub profile.\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableDiffusionControlNetPipeline\n[[autodoc]] StableDiffusionControlNetPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_vae_slicing\n\t- disable_vae_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\t- load_textual_inversion\n\n## StableDiffusionControlNetImg2ImgPipeline\n[[autodoc]] StableDiffusionControlNetImg2ImgPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_vae_slicing\n\t- disable_vae_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\t- load_textual_inversion\n\n## StableDiffusionControlNetInpaintPipeline\n[[autodoc]] StableDiffusionControlNetInpaintPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_vae_slicing\n\t- disable_vae_slicing\n\t- 
enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\t- load_textual_inversion\n\n## StableDiffusionPipelineOutput\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnet_flux.md",
    "content": "<!--Copyright 2025 The HuggingFace Team, The InstantX Team, and the XLabs Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet with Flux.1\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nFluxControlNetPipeline is an implementation of ControlNet for Flux.1.\n\nControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.\n\nWith a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. 
The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\nThis ControlNet code was implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for Flux-ControlNet in the table below:\n\n\n| ControlNet type | Developer | Link |\n| -------- | ---------- | ---- |\n| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny) |\n| Depth | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Depth) |\n| Union | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union) |\n\nXLabs ControlNets are also supported. They were contributed by the [XLabs team](https://huggingface.co/XLabs-AI).\n\n| ControlNet type | Developer | Link |\n| -------- | ---------- | ---- |\n| Canny | [The XLabs Team](https://huggingface.co/XLabs-AI) | [Link](https://huggingface.co/XLabs-AI/flux-controlnet-canny-diffusers) |\n| Depth | [The XLabs Team](https://huggingface.co/XLabs-AI) | [Link](https://huggingface.co/XLabs-AI/flux-controlnet-depth-diffusers) |\n| HED | [The XLabs Team](https://huggingface.co/XLabs-AI) | [Link](https://huggingface.co/XLabs-AI/flux-controlnet-hed-diffusers) |\n\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between 
scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## FluxControlNetPipeline\n[[autodoc]] FluxControlNetPipeline\n\t- all\n\t- __call__\n\n\n## FluxPipelineOutput\n[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnet_hunyuandit.md",
    "content": "<!--Copyright 2025 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet with Hunyuan-DiT\n\nHunyuanDiTControlNetPipeline is an implementation of ControlNet for [Hunyuan-DiT](https://huggingface.co/papers/2405.08748).\n\nControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.\n\nWith a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. 
We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\nThis code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## HunyuanDiTControlNetPipeline\n[[autodoc]] HunyuanDiTControlNetPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnet_sana.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.\n\nWith a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. 
We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\nThis pipeline was contributed by [ishan24](https://huggingface.co/ishan24). ❤️\nThe original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.\n\n## SanaControlNetPipeline\n[[autodoc]] SanaControlNetPipeline\n\t- all\n\t- __call__\n\n## SanaPipelineOutput\n[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnet_sd3.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet with Stable Diffusion 3\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nStableDiffusion3ControlNetPipeline is an implementation of ControlNet for Stable Diffusion 3.\n\nControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.\n\nWith a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. 
The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\nThis ControlNet code was mainly implemented by [The InstantX Team](https://huggingface.co/InstantX). The inpainting-related code was developed by [The Alimama Creative Team](https://huggingface.co/alimama-creative). You can find pre-trained checkpoints for SD3-ControlNet in the table below:\n\n\n| ControlNet type | Developer | Link |\n| -------- | ---------- | ---- |\n| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Canny) |\n| Depth | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Depth) |\n| Pose | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Pose) |\n| Tile | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Tile) |\n| Inpainting | [The Alimama Creative Team](https://huggingface.co/alimama-creative) | [Link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |\n\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## 
StableDiffusion3ControlNetPipeline\n[[autodoc]] StableDiffusion3ControlNetPipeline\n\t- all\n\t- __call__\n\n## StableDiffusion3ControlNetInpaintingPipeline\n[[autodoc]] pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetInpaintingPipeline\n\t- all\n\t- __call__\n\n## StableDiffusion3PipelineOutput\n[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnet_sdxl.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet with Stable Diffusion XL\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.\n\nWith a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.\n\nThe abstract from the paper is:\n\n*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with \"zero convolutions\" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. 
We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*\n\nYou can find additional smaller Stable Diffusion XL (SDXL) ControlNet checkpoints from the 🤗 [Diffusers](https://huggingface.co/diffusers) Hub organization, and browse [community-trained](https://huggingface.co/models?other=stable-diffusion-xl&other=controlnet) checkpoints on the Hub.\n\n> [!WARNING]\n> 🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!\n\nIf you don't see a checkpoint you're interested in, you can train your own SDXL ControlNet with our [training script](../../../../../examples/controlnet/README_sdxl).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableDiffusionXLControlNetPipeline\n[[autodoc]] StableDiffusionXLControlNetPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLControlNetImg2ImgPipeline\n[[autodoc]] StableDiffusionXLControlNetImg2ImgPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLControlNetInpaintPipeline\n[[autodoc]] StableDiffusionXLControlNetInpaintPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionPipelineOutput\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnet_union.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNetUnion\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nControlNetUnionModel is an implementation of ControlNet for Stable Diffusion XL.\n\nThe ControlNet model was introduced in [ControlNetPlus](https://github.com/xinsir6/ControlNetPlus) by xinsir6. It supports multiple conditioning inputs without increasing computation.\n\n*We design a new architecture that can support 10+ control types in condition text-to-image generation and can generate high resolution images visually comparable with midjourney. The network is based on the original ControlNet architecture, we propose two new modules to: 1 Extend the original ControlNet to support different image conditions using the same network parameter. 
2 Support multiple conditions input without increasing computation offload, which is especially important for designers who want to edit image in detail, different conditions use the same condition encoder, without adding extra computations or parameters.*\n\n\n## StableDiffusionXLControlNetUnionPipeline\n[[autodoc]] StableDiffusionXLControlNetUnionPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLControlNetUnionImg2ImgPipeline\n[[autodoc]] StableDiffusionXLControlNetUnionImg2ImgPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLControlNetUnionInpaintPipeline\n[[autodoc]] StableDiffusionXLControlNetUnionInpaintPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnetxs.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# ControlNet-XS\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.\n\nLike the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. 
It is a more flexible and accurate way to control the image generation process.\n\nControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb) with StableDiffusion-XL) and uses ~45% less memory.\n\nHere's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/):\n\n*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.*\n\nThis model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableDiffusionControlNetXSPipeline\n[[autodoc]] StableDiffusionControlNetXSPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionPipelineOutput\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/controlnetxs_sdxl.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# ControlNet-XS with Stable Diffusion XL\n\nControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.\n\nLike the original ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. 
It is a more flexible and accurate way to control the image generation process.\n\nControlNet-XS generates images with comparable quality to a regular ControlNet, but it is 20-25% faster ([see benchmark](https://github.com/UmerHA/controlnet-xs-benchmark/blob/main/Speed%20Benchmark.ipynb)) and uses ~45% less memory.\n\nHere's the overview from the [project page](https://vislearn.github.io/ControlNet-XS/):\n\n*With increasing computing capabilities, current model architectures appear to follow the trend of simply upscaling all components without validating the necessity for doing so. In this project we investigate the size and architectural design of ControlNet [Zhang et al., 2023] for controlling the image generation process with stable diffusion-based models. We show that a new architecture with as little as 1% of the parameters of the base model achieves state-of-the art results, considerably better than ControlNet in terms of FID score. Hence we call it ControlNet-XS. We provide the code for controlling StableDiffusion-XL [Podell et al., 2023] (Model B, 48M Parameters) and StableDiffusion 2.1 [Rombach et al. 2022] (Model B, 14M Parameters), all under openrail license.*\n\nThis model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️\n\n> [!WARNING]\n> 🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. 
Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableDiffusionXLControlNetXSPipeline\n[[autodoc]] StableDiffusionXLControlNetXSPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionPipelineOutput\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/cosmos.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# Cosmos\n\n[Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.\n\n*Physical AI needs to be trained digitally first. It needs a digital twin of itself, the policy model, and a digital twin of the world, the world model. In this paper, we present the Cosmos World Foundation Model Platform to help developers build customized world models for their Physical AI setups. We position a world foundation model as a general-purpose world model that can be fine-tuned into customized world models for downstream applications. Our platform covers a video curation pipeline, pre-trained world foundation models, examples of post-training of pre-trained world foundation models, and video tokenizers. 
To help Physical AI builders solve the most critical problems of our society, we make our platform open-source and our models open-weight with permissive licenses available via https://github.com/NVIDIA/Cosmos.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## Basic usage\n\n```python\nimport torch\nfrom diffusers import Cosmos2_5_PredictBasePipeline\nfrom diffusers.utils import export_to_video\n\nmodel_id = \"nvidia/Cosmos-Predict2.5-2B\"\npipe = Cosmos2_5_PredictBasePipeline.from_pretrained(\n    model_id, revision=\"diffusers/base/post-trained\", torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nprompt = \"As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. 
The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor.\"\nnegative_prompt = \"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.\"\n\noutput = pipe(\n    image=None,\n    video=None,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=93,\n    generator=torch.Generator().manual_seed(1),\n).frames[0]\nexport_to_video(output, \"text2world.mp4\", fps=16)\n```\n\n## Cosmos2_5_TransferPipeline\n\n[[autodoc]] Cosmos2_5_TransferPipeline\n  - all\n  - __call__\n\n\n## Cosmos2_5_PredictBasePipeline\n\n[[autodoc]] Cosmos2_5_PredictBasePipeline\n  - all\n  - __call__\n\n\n## CosmosTextToWorldPipeline\n\n[[autodoc]] CosmosTextToWorldPipeline\n  - all\n  - __call__\n\n## CosmosVideoToWorldPipeline\n\n[[autodoc]] CosmosVideoToWorldPipeline\n  - all\n  - __call__\n\n## Cosmos2TextToImagePipeline\n\n[[autodoc]] Cosmos2TextToImagePipeline\n  - all\n  - __call__\n\n## Cosmos2VideoToWorldPipeline\n\n[[autodoc]] Cosmos2VideoToWorldPipeline\n  - all\n  - __call__\n\n## CosmosPipelineOutput\n\n[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput\n\n## CosmosImagePipelineOutput\n\n[[autodoc]] pipelines.cosmos.pipeline_output.CosmosImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/dance_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Dance Diffusion\n\n[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) is by Zach Evans.\n\nDance Diffusion is the first in a suite of generative audio tools for producers and musicians released by [Harmonai](https://github.com/Harmonai-org).\n\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## DanceDiffusionPipeline\n[[autodoc]] DanceDiffusionPipeline\n\t- all\n\t- __call__\n\n## AudioPipelineOutput\n[[autodoc]] pipelines.AudioPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/ddim.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DDIM\n\n[Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.\n\nThe abstract from the paper is:\n\n*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*\n\nThe original codebase can be found at [ermongroup/ddim](https://github.com/ermongroup/ddim).\n\n## DDIMPipeline\n[[autodoc]] DDIMPipeline\n\t- all\n\t- __call__\n\n## ImagePipelineOutput\n[[autodoc]] pipelines.ImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/ddpm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DDPM\n\n[Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2006.11239) (DDPM) by Jonathan Ho, Ajay Jain and Pieter Abbeel proposes a diffusion based model of the same name. In the 🤗 Diffusers library, DDPM refers to the *discrete denoising scheduler* from the paper as well as the pipeline.\n\nThe abstract from the paper is:\n\n*We present high quality image synthesis results using diffusion probabilistic models, a class of latent variable models inspired by considerations from nonequilibrium thermodynamics. Our best results are obtained by training on a weighted variational bound designed according to a novel connection between diffusion probabilistic models and denoising score matching with Langevin dynamics, and our models naturally admit a progressive lossy decompression scheme that can be interpreted as a generalization of autoregressive decoding. On the unconditional CIFAR10 dataset, we obtain an Inception score of 9.46 and a state-of-the-art FID score of 3.17. 
On 256x256 LSUN, we obtain sample quality similar to ProgressiveGAN.*\n\nThe original codebase can be found at [hojonathanho/diffusion](https://github.com/hojonathanho/diffusion).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## DDPMPipeline\n[[autodoc]] DDPMPipeline\n\t- all\n\t- __call__\n\n## ImagePipelineOutput\n[[autodoc]] pipelines.ImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/deepfloyd_if.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DeepFloyd IF\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\n## Overview\n\nDeepFloyd IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding.\nThe model is a modular composed of a frozen text encoder and three cascaded pixel diffusion modules:\n- Stage 1: a base model that generates 64x64 px image based on text prompt,\n- Stage 2: a 64x64 px => 256x256 px super-resolution model, and\n- Stage 3: a 256x256 px => 1024x1024 px super-resolution model\nStage 1 and Stage 2 utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling.\nStage 3 is [Stability AI's x4 Upscaling model](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler).\nThe result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset.\nOur work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.\n\n## 
Usage\n\nBefore you can use IF, you need to accept its usage conditions. To do so:\n1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be logged in.\n2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0). Accepting the license on the stage I model card automatically accepts it for the other IF models.\n3. Make sure to log in locally. Install `huggingface_hub`:\n```sh\npip install huggingface_hub --upgrade\n```\n\nrun the login function in a Python shell:\n\n```py\nfrom huggingface_hub import login\n\nlogin()\n```\n\nand enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens).\n\nNext, install `diffusers` and its dependencies:\n\n```sh\npip install -q diffusers accelerate transformers\n```\n\nThe following sections give more detailed examples of how to use IF. Specifically:\n\n- [Text-to-Image Generation](#text-to-image-generation)\n- [Image-to-Image Generation](#text-guided-image-to-image-generation)\n- [Inpainting](#text-guided-inpainting-generation)\n- [Reusing model weights](#converting-between-different-pipelines)\n- [Speed optimization](#optimizing-for-speed)\n- [Memory optimization](#optimizing-for-memory)\n\n**Available checkpoints**\n- *Stage-1*\n  - [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)\n  - [DeepFloyd/IF-I-L-v1.0](https://huggingface.co/DeepFloyd/IF-I-L-v1.0)\n  - [DeepFloyd/IF-I-M-v1.0](https://huggingface.co/DeepFloyd/IF-I-M-v1.0)\n\n- *Stage-2*\n  - [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)\n  - [DeepFloyd/IF-II-M-v1.0](https://huggingface.co/DeepFloyd/IF-II-M-v1.0)\n\n- *Stage-3*\n  - [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)\n\n\n**Google Colab**\n[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb)\n\n### Text-to-Image Generation\n\nBy default diffusers makes use of [model cpu offloading](../../optimization/memory#model-offloading) to run the whole IF pipeline with as little as 14 GB of VRAM.\n\n```python\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import pt_to_pil, make_image_grid\nimport torch\n\n# stage 1\nstage_1 = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\nstage_1.enable_model_cpu_offload()\n\n# stage 2\nstage_2 = DiffusionPipeline.from_pretrained(\n    \"DeepFloyd/IF-II-L-v1.0\", text_encoder=None, variant=\"fp16\", torch_dtype=torch.float16\n)\nstage_2.enable_model_cpu_offload()\n\n# stage 3\nsafety_modules = {\n    \"feature_extractor\": stage_1.feature_extractor,\n    \"safety_checker\": stage_1.safety_checker,\n    \"watermarker\": stage_1.watermarker,\n}\nstage_3 = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-x4-upscaler\", **safety_modules, torch_dtype=torch.float16\n)\nstage_3.enable_model_cpu_offload()\n\nprompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says \"very deep learning\"'\ngenerator = torch.manual_seed(1)\n\n# text embeds\nprompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)\n\n# stage 1\nstage_1_output = stage_1(\n    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type=\"pt\"\n).images\n#pt_to_pil(stage_1_output)[0].save(\"./if_stage_I.png\")\n\n# stage 2\nstage_2_output = stage_2(\n    image=stage_1_output,\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    generator=generator,\n    
output_type=\"pt\",\n).images\n#pt_to_pil(stage_2_output)[0].save(\"./if_stage_II.png\")\n\n# stage 3\nstage_3_output = stage_3(prompt=prompt, image=stage_2_output, noise_level=100, generator=generator).images\n#stage_3_output[0].save(\"./if_stage_III.png\")\nmake_image_grid([pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=3)\n```\n\n### Text Guided Image-to-Image Generation\n\nThe same IF model weights can be used for text-guided image-to-image translation or image variation.\nIn this case just make sure to load the weights using the [`IFImg2ImgPipeline`] and [`IFImg2ImgSuperResolutionPipeline`] pipelines.\n\n**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines\nwithout loading them twice by making use of the [`~DiffusionPipeline.components`] property as explained [here](#converting-between-different-pipelines).\n\n```python\nfrom diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, DiffusionPipeline\nfrom diffusers.utils import pt_to_pil, load_image, make_image_grid\nimport torch\n\n# download image\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\noriginal_image = load_image(url)\noriginal_image = original_image.resize((768, 512))\n\n# stage 1\nstage_1 = IFImg2ImgPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\nstage_1.enable_model_cpu_offload()\n\n# stage 2\nstage_2 = IFImg2ImgSuperResolutionPipeline.from_pretrained(\n    \"DeepFloyd/IF-II-L-v1.0\", text_encoder=None, variant=\"fp16\", torch_dtype=torch.float16\n)\nstage_2.enable_model_cpu_offload()\n\n# stage 3\nsafety_modules = {\n    \"feature_extractor\": stage_1.feature_extractor,\n    \"safety_checker\": stage_1.safety_checker,\n    \"watermarker\": stage_1.watermarker,\n}\nstage_3 = DiffusionPipeline.from_pretrained(\n    
\"stabilityai/stable-diffusion-x4-upscaler\", **safety_modules, torch_dtype=torch.float16\n)\nstage_3.enable_model_cpu_offload()\n\nprompt = \"A fantasy landscape in style minecraft\"\ngenerator = torch.manual_seed(1)\n\n# text embeds\nprompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)\n\n# stage 1\nstage_1_output = stage_1(\n    image=original_image,\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    generator=generator,\n    output_type=\"pt\",\n).images\n#pt_to_pil(stage_1_output)[0].save(\"./if_stage_I.png\")\n\n# stage 2\nstage_2_output = stage_2(\n    image=stage_1_output,\n    original_image=original_image,\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    generator=generator,\n    output_type=\"pt\",\n).images\n#pt_to_pil(stage_2_output)[0].save(\"./if_stage_II.png\")\n\n# stage 3\nstage_3_output = stage_3(prompt=prompt, image=stage_2_output, generator=generator, noise_level=100).images\n#stage_3_output[0].save(\"./if_stage_III.png\")\nmake_image_grid([original_image, pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=4)\n```\n\n### Text Guided Inpainting Generation\n\nThe same IF model weights can be used for text-guided inpainting.\nIn this case just make sure to load the weights using the [`IFInpaintingPipeline`] and [`IFInpaintingSuperResolutionPipeline`] pipelines.\n\n**Note**: You can also directly move the weights of the text-to-image pipelines to the inpainting pipelines\nwithout loading them twice by making use of the [`~DiffusionPipeline.components`] property as explained [here](#converting-between-different-pipelines).\n\n```python\nfrom diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, DiffusionPipeline\nfrom diffusers.utils import pt_to_pil, load_image, make_image_grid\nimport torch\n\n# download image\nurl = 
\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/person.png\"\noriginal_image = load_image(url)\n\n# download mask\nurl = \"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/if/glasses_mask.png\"\nmask_image = load_image(url)\n\n# stage 1\nstage_1 = IFInpaintingPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\nstage_1.enable_model_cpu_offload()\n\n# stage 2\nstage_2 = IFInpaintingSuperResolutionPipeline.from_pretrained(\n    \"DeepFloyd/IF-II-L-v1.0\", text_encoder=None, variant=\"fp16\", torch_dtype=torch.float16\n)\nstage_2.enable_model_cpu_offload()\n\n# stage 3\nsafety_modules = {\n    \"feature_extractor\": stage_1.feature_extractor,\n    \"safety_checker\": stage_1.safety_checker,\n    \"watermarker\": stage_1.watermarker,\n}\nstage_3 = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-x4-upscaler\", **safety_modules, torch_dtype=torch.float16\n)\nstage_3.enable_model_cpu_offload()\n\nprompt = \"blue sunglasses\"\ngenerator = torch.manual_seed(1)\n\n# text embeds\nprompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)\n\n# stage 1\nstage_1_output = stage_1(\n    image=original_image,\n    mask_image=mask_image,\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    generator=generator,\n    output_type=\"pt\",\n).images\n#pt_to_pil(stage_1_output)[0].save(\"./if_stage_I.png\")\n\n# stage 2\nstage_2_output = stage_2(\n    image=stage_1_output,\n    original_image=original_image,\n    mask_image=mask_image,\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    generator=generator,\n    output_type=\"pt\",\n).images\n#pt_to_pil(stage_2_output)[0].save(\"./if_stage_II.png\")\n\n# stage 3\nstage_3_output = stage_3(prompt=prompt, image=stage_2_output, generator=generator, noise_level=100).images\n#stage_3_output[0].save(\"./if_stage_III.png\")\nmake_image_grid([original_image, mask_image, 
pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0], stage_3_output[0]], rows=1, cols=5)\n```\n\n### Converting between different pipelines\n\nIn addition to being loaded with `from_pretrained`, pipelines can also be loaded directly from each other.\n\n```python\nfrom diffusers import IFPipeline, IFSuperResolutionPipeline\n\npipe_1 = IFPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\")\npipe_2 = IFSuperResolutionPipeline.from_pretrained(\"DeepFloyd/IF-II-L-v1.0\")\n\n\nfrom diffusers import IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline\n\npipe_1 = IFImg2ImgPipeline(**pipe_1.components)\npipe_2 = IFImg2ImgSuperResolutionPipeline(**pipe_2.components)\n\n\nfrom diffusers import IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline\n\npipe_1 = IFInpaintingPipeline(**pipe_1.components)\npipe_2 = IFInpaintingSuperResolutionPipeline(**pipe_2.components)\n```\n\n### Optimizing for speed\n\nThe simplest optimization to run IF faster is to move all model components to the GPU.\n\n```py\npipe = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\npipe.to(\"cuda\")\n```\n\nYou can also run the diffusion process for fewer timesteps.\n\nThis can either be done with the `num_inference_steps` argument:\n\n```py\npipe(\"<prompt>\", num_inference_steps=30)\n```\n\nOr with the `timesteps` argument:\n\n```py\nfrom diffusers.pipelines.deepfloyd_if import fast27_timesteps\n\npipe(\"<prompt>\", timesteps=fast27_timesteps)\n```\n\nWhen doing image variation or inpainting, you can also decrease the number of timesteps\nwith the `strength` argument. 
The strength argument is the amount of noise to add to the input image which also determines how many steps to run in the denoising process.\nA smaller number will vary the image less but run faster.\n\n```py\npipe = IFImg2ImgPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\npipe.to(\"cuda\")\n\nimage = pipe(image=image, prompt=\"<prompt>\", strength=0.3).images\n```\n\nYou can also use [`torch.compile`](../../optimization/fp16#torchcompile). Note that we have not exhaustively tested `torch.compile`\nwith IF and it might not give expected results.\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\npipe.to(\"cuda\")\n\npipe.text_encoder = torch.compile(pipe.text_encoder, mode=\"reduce-overhead\", fullgraph=True)\npipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n### Optimizing for memory\n\nWhen optimizing for GPU memory, we can use the standard diffusers CPU offloading APIs.\n\nEither the model based CPU offloading,\n\n```py\npipe = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n```\n\nor the more aggressive layer based CPU offloading.\n\n```py\npipe = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", variant=\"fp16\", torch_dtype=torch.float16)\npipe.enable_sequential_cpu_offload()\n```\n\nAdditionally, T5 can be loaded in 8bit precision\n\n```py\nfrom transformers import T5EncoderModel\n\ntext_encoder = T5EncoderModel.from_pretrained(\n    \"DeepFloyd/IF-I-XL-v1.0\", subfolder=\"text_encoder\", device_map=\"auto\", load_in_8bit=True, variant=\"8bit\"\n)\n\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\n    \"DeepFloyd/IF-I-XL-v1.0\",\n    text_encoder=text_encoder,  # pass the previously instantiated 8bit 
text encoder\n    unet=None,\n    device_map=\"auto\",\n)\n\nprompt_embeds, negative_embeds = pipe.encode_prompt(\"<prompt>\")\n```\n\nFor CPU RAM constrained machines like Google Colab free tier, where we can't load all model components to the CPU at once, we can manually load the pipeline with only\nthe text encoder or only the UNet when the respective model component is needed.\n\n```py\nfrom diffusers import DiffusionPipeline, IFPipeline, IFSuperResolutionPipeline\nimport torch\nimport gc\nfrom transformers import T5EncoderModel\nfrom diffusers.utils import pt_to_pil, make_image_grid\n\ntext_encoder = T5EncoderModel.from_pretrained(\n    \"DeepFloyd/IF-I-XL-v1.0\", subfolder=\"text_encoder\", device_map=\"auto\", load_in_8bit=True, variant=\"8bit\"\n)\n\n# text to image\npipe = DiffusionPipeline.from_pretrained(\n    \"DeepFloyd/IF-I-XL-v1.0\",\n    text_encoder=text_encoder,  # pass the previously instantiated 8bit text encoder\n    unet=None,\n    device_map=\"auto\",\n)\n\nprompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says \"very deep learning\"'\nprompt_embeds, negative_embeds = pipe.encode_prompt(prompt)\n\n# Remove the pipeline so we can re-load the pipeline with the unet\ndel text_encoder\ndel pipe\ngc.collect()\ntorch.cuda.empty_cache()\n\npipe = IFPipeline.from_pretrained(\n    \"DeepFloyd/IF-I-XL-v1.0\", text_encoder=None, variant=\"fp16\", torch_dtype=torch.float16, device_map=\"auto\"\n)\n\ngenerator = torch.Generator().manual_seed(0)\nstage_1_output = pipe(\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    output_type=\"pt\",\n    generator=generator,\n).images\n\n#pt_to_pil(stage_1_output)[0].save(\"./if_stage_I.png\")\n\n# Remove the pipeline so we can load the super-resolution pipeline\ndel pipe\ngc.collect()\ntorch.cuda.empty_cache()\n\n# First super resolution\n\npipe = IFSuperResolutionPipeline.from_pretrained(\n    \"DeepFloyd/IF-II-L-v1.0\", 
text_encoder=None, variant=\"fp16\", torch_dtype=torch.float16, device_map=\"auto\"\n)\n\ngenerator = torch.Generator().manual_seed(0)\nstage_2_output = pipe(\n    image=stage_1_output,\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    output_type=\"pt\",\n    generator=generator,\n).images\n\n#pt_to_pil(stage_2_output)[0].save(\"./if_stage_II.png\")\nmake_image_grid([pt_to_pil(stage_1_output)[0], pt_to_pil(stage_2_output)[0]], rows=1, cols=2)\n```\n\n## Available Pipelines:\n\n| Pipeline | Tasks | Colab |\n|---|---|:---:|\n| [pipeline_if.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py) | *Text-to-Image Generation* | - |\n| [pipeline_if_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py) | *Text-to-Image Generation* | - |\n| [pipeline_if_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py) | *Image-to-Image Generation* | - |\n| [pipeline_if_img2img_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py) | *Image-to-Image Generation* | - |\n| [pipeline_if_inpainting.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py) | *Image-to-Image Generation* | - |\n| [pipeline_if_inpainting_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py) | *Image-to-Image Generation* | - |\n\n## IFPipeline\n[[autodoc]] IFPipeline\n\t- all\n\t- __call__\n\n## IFSuperResolutionPipeline\n[[autodoc]] IFSuperResolutionPipeline\n\t- all\n\t- __call__\n\n## IFImg2ImgPipeline\n[[autodoc]] IFImg2ImgPipeline\n\t- all\n\t- __call__\n\n## IFImg2ImgSuperResolutionPipeline\n[[autodoc]] 
IFImg2ImgSuperResolutionPipeline\n\t- all\n\t- __call__\n\n## IFInpaintingPipeline\n[[autodoc]] IFInpaintingPipeline\n\t- all\n\t- __call__\n\n## IFInpaintingSuperResolutionPipeline\n[[autodoc]] IFInpaintingSuperResolutionPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/diffedit.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# DiffEdit\n\n[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://huggingface.co/papers/2210.11427) is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord.\n\nThe abstract from the paper is:\n\n*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. 
In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*\n\nThe original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html).\n\nThis pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️\n\n## Tips\n\n* The pipeline can generate masks that can be fed into other inpainting pipelines.\n* In order to generate an image using this pipeline, both an image mask (source and target prompts can be manually specified or generated, and passed to [`~StableDiffusionDiffEditPipeline.generate_mask`])\nand a set of partially inverted latents (generated using [`~StableDiffusionDiffEditPipeline.invert`]) _must_ be provided as arguments when calling the pipeline to generate the final edited image.\n* The function [`~StableDiffusionDiffEditPipeline.generate_mask`] exposes two prompt arguments, `source_prompt` and `target_prompt`\nthat let you control the locations of the semantic edits in the final image to be generated. Let's say,\nyou wanted to translate from \"cat\" to \"dog\". In this case, the edit direction will be \"cat -> dog\". 
To reflect\nthis in the generated mask, you simply have to set the embeddings related to the phrases including \"cat\" to\n`source_prompt` and \"dog\" to `target_prompt`.\n* When generating partially inverted latents using `invert`, assign a caption or text embedding describing the\noverall image to the `prompt` argument to help guide the inverse latent sampling process. In most cases, the\nsource concept is sufficiently descriptive to yield good results, but feel free to explore alternatives.\n* When calling the pipeline to generate the final edited image, assign the source concept to `negative_prompt`\nand the target concept to `prompt`. Taking the above example, you simply have to set the embeddings related to\nthe phrases including \"cat\" to `negative_prompt` and \"dog\" to `prompt`.\n* If you wanted to reverse the direction in the example above, i.e., \"dog -> cat\", then it's recommended to:\n    * Swap the `source_prompt` and `target_prompt` in the arguments to `generate_mask`.\n    * Change the input prompt in [`~StableDiffusionDiffEditPipeline.invert`] to include \"dog\".\n    * Swap the `prompt` and `negative_prompt` in the arguments to call the pipeline to generate the final edited image.\n* The source and target prompts, or their corresponding embeddings, can also be automatically generated. Please refer to the [DiffEdit](../../using-diffusers/diffedit) guide for more details.\n\n## StableDiffusionDiffEditPipeline\n[[autodoc]] StableDiffusionDiffEditPipeline\n    - all\n    - generate_mask\n    - invert\n    - __call__\n\n## StableDiffusionPipelineOutput\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/dit.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DiT\n\n[Scalable Diffusion Models with Transformers](https://huggingface.co/papers/2212.09748) (DiT) is by William Peebles and Saining Xie.\n\nThe abstract from the paper is:\n\n*We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops -- through increased transformer depth/width or increased number of input tokens -- consistently have lower FID. 
In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512x512 and 256x256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.*\n\nThe original codebase can be found at [facebookresearch/dit](https://github.com/facebookresearch/dit).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## DiTPipeline\n[[autodoc]] DiTPipeline\n\t- all\n\t- __call__\n\n## ImagePipelineOutput\n[[autodoc]] pipelines.ImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/easyanimate.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n# EasyAnimate\n[EasyAnimate](https://github.com/aigc-apps/EasyAnimate) by Alibaba PAI.\n\nThe description from it's GitHub page:\n*EasyAnimate is a pipeline based on the transformer architecture, designed for generating AI images and videos, and for training baseline models and Lora models for Diffusion Transformer. We support direct prediction from pre-trained EasyAnimate models, allowing for the generation of videos with various resolutions, approximately 6 seconds in length, at 8fps (EasyAnimateV5.1, 1 to 49 frames). Additionally, users can train their own baseline and Lora models for specific style transformations.*\n\nThis pipeline was contributed by [bubbliiiing](https://github.com/bubbliiiing). The original codebase can be found [here](https://huggingface.co/alibaba-pai). 
The original weights can be found under [hf.co/alibaba-pai](https://huggingface.co/alibaba-pai).\n\nThere are two official EasyAnimate checkpoints for text-to-video and video-to-video.\n\n| checkpoints | recommended inference dtype |\n|:---:|:---:|\n| [`alibaba-pai/EasyAnimateV5.1-12b-zh`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh) | torch.float16 |\n| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |\n\nThere is one official EasyAnimate checkpoint available for image-to-video and video-to-video.\n\n| checkpoints | recommended inference dtype |\n|:---:|:---:|\n| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |\n\nThere are two official EasyAnimate checkpoints available for control-to-video.\n\n| checkpoints | recommended inference dtype |\n|:---:|:---:|\n| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control) | torch.float16 |\n| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera) | torch.float16 |\n\nFor the EasyAnimateV5.1 series:\n- Text-to-video (T2V) and image-to-video (I2V) work at multiple resolutions. The width and height can vary from 256 to 1024.\n- Both T2V and I2V models support generation of 1 to 49 frames and work best within this range. Exporting videos at 8 FPS is recommended.\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. 
The example below demonstrates how to load a quantized [`EasyAnimatePipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline\nfrom diffusers.utils import export_to_video\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = EasyAnimateTransformer3DModel.from_pretrained(\n    \"alibaba-pai/EasyAnimateV5.1-12b-zh\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = EasyAnimatePipeline.from_pretrained(\n    \"alibaba-pai/EasyAnimateV5.1-12b-zh\",\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nprompt = \"A cat walks on the grass, realistic style.\"\nnegative_prompt = \"bad detailed\"\nvideo = pipeline(prompt=prompt, negative_prompt=negative_prompt, num_frames=49, num_inference_steps=30).frames[0]\nexport_to_video(video, \"cat.mp4\", fps=8)\n```\n\n## EasyAnimatePipeline\n\n[[autodoc]] EasyAnimatePipeline\n  - all\n  - __call__\n\n## EasyAnimatePipelineOutput\n\n[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/flux.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Flux\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\nFlux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.\n\nOriginal model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux).\n\n> [!TIP]\n> Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.  
For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).\n>\n> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.\n\nFlux comes in the following variants:\n\n| model type | model id |\n|:----------:|:--------:|\n| Timestep-distilled | [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) |\n| Guidance-distilled | [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev) |\n| Fill Inpainting/Outpainting (Guidance-distilled) | [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) |\n| Canny Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Canny-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |\n| Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |\n| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |\n| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |\n| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |\n| Kontext | [`black-forest-labs/FLUX.1-Kontext-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) |\n\nEach checkpoint is used differently, as detailed below.\n\n### Timestep-distilled\n\n* `max_sequence_length` cannot be more than 256.\n* `guidance_scale` needs to be 0.\n* As this is a timestep-distilled model, it benefits from fewer sampling steps.\n\n```python\nimport torch\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16)\npipe.enable_model_cpu_offload()\n\nprompt = \"A cat holding a 
sign that says hello world\"\nout = pipe(\n    prompt=prompt,\n    guidance_scale=0.,\n    height=768,\n    width=1360,\n    num_inference_steps=4,\n    max_sequence_length=256,\n).images[0]\nout.save(\"image.png\")\n```\n\n### Guidance-distilled\n\n* The guidance-distilled variant takes about 50 sampling steps for good-quality generation.\n* It doesn't have any limitations around the `max_sequence_length`.\n\n```python\nimport torch\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16)\npipe.enable_model_cpu_offload()\n\nprompt = \"a tiny astronaut hatching from an egg on the moon\"\nout = pipe(\n    prompt=prompt,\n    guidance_scale=3.5,\n    height=768,\n    width=1360,\n    num_inference_steps=50,\n).images[0]\nout.save(\"image.png\")\n```\n\n### Fill Inpainting/Outpainting\n\n* Flux Fill pipeline does not require `strength` as an input like regular inpainting pipelines.\n* It supports both inpainting and outpainting.\n\n```python\nimport torch\nfrom diffusers import FluxFillPipeline\nfrom diffusers.utils import load_image\n\nimage = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png\")\nmask = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png\")\n\nrepo_id = \"black-forest-labs/FLUX.1-Fill-dev\"\npipe = FluxFillPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(\"cuda\")\n\nimage = pipe(\n    prompt=\"a white paper cup\",\n    image=image,\n    mask_image=mask,\n    height=1632,\n    width=1232,\n    max_sequence_length=512,\n    generator=torch.Generator(\"cpu\").manual_seed(0)\n).images[0]\nimage.save(\"output.png\")\n```\n\n### Canny Control\n\n**Note:** `black-forest-labs/FLUX.1-Canny-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. 
Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible. \n\n```python\n# !pip install -U controlnet-aux\nimport torch\nfrom controlnet_aux import CannyDetector\nfrom diffusers import FluxControlPipeline\nfrom diffusers.utils import load_image\n\npipe = FluxControlPipeline.from_pretrained(\"black-forest-labs/FLUX.1-Canny-dev\", torch_dtype=torch.bfloat16).to(\"cuda\")\n\nprompt = \"A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts.\"\ncontrol_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png\")\n\nprocessor = CannyDetector()\ncontrol_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)\n\nimage = pipe(\n    prompt=prompt,\n    control_image=control_image,\n    height=1024,\n    width=1024,\n    num_inference_steps=50,\n    guidance_scale=30.0,\n).images[0]\nimage.save(\"output.png\")\n```\n\nCanny Control is also possible with a LoRA variant of this condition. The usage is as follows:\n\n```python\n# !pip install -U controlnet-aux\nimport torch\nfrom controlnet_aux import CannyDetector\nfrom diffusers import FluxControlPipeline\nfrom diffusers.utils import load_image\n\npipe = FluxControlPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16).to(\"cuda\")\npipe.load_lora_weights(\"black-forest-labs/FLUX.1-Canny-dev-lora\")\n\nprompt = \"A robot made of exotic candies and chocolates of different kinds. 
The background is filled with confetti and celebratory gifts.\"\ncontrol_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png\")\n\nprocessor = CannyDetector()\ncontrol_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)\n\nimage = pipe(\n    prompt=prompt,\n    control_image=control_image,\n    height=1024,\n    width=1024,\n    num_inference_steps=50,\n    guidance_scale=30.0,\n).images[0]\nimage.save(\"output.png\")\n```\n\n### Depth Control\n\n**Note:** `black-forest-labs/FLUX.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible.\n\n```python\n# !pip install git+https://github.com/huggingface/image_gen_aux\nimport torch\nfrom diffusers import FluxControlPipeline, FluxTransformer2DModel\nfrom diffusers.utils import load_image\nfrom image_gen_aux import DepthPreprocessor\n\npipe = FluxControlPipeline.from_pretrained(\"black-forest-labs/FLUX.1-Depth-dev\", torch_dtype=torch.bfloat16).to(\"cuda\")\n\nprompt = \"A robot made of exotic candies and chocolates of different kinds. 
The background is filled with confetti and celebratory gifts.\"\ncontrol_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png\")\n\nprocessor = DepthPreprocessor.from_pretrained(\"LiheYoung/depth-anything-large-hf\")\ncontrol_image = processor(control_image)[0].convert(\"RGB\")\n\nimage = pipe(\n    prompt=prompt,\n    control_image=control_image,\n    height=1024,\n    width=1024,\n    num_inference_steps=30,\n    guidance_scale=10.0,\n    generator=torch.Generator().manual_seed(42),\n).images[0]\nimage.save(\"output.png\")\n```\n\nDepth Control is also possible with a LoRA variant of this condition. The usage is as follows:\n\n```python\n# !pip install git+https://github.com/huggingface/image_gen_aux\nimport torch\nfrom diffusers import FluxControlPipeline, FluxTransformer2DModel\nfrom diffusers.utils import load_image\nfrom image_gen_aux import DepthPreprocessor\n\npipe = FluxControlPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16).to(\"cuda\")\npipe.load_lora_weights(\"black-forest-labs/FLUX.1-Depth-dev-lora\")\n\nprompt = \"A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts.\"\ncontrol_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png\")\n\nprocessor = DepthPreprocessor.from_pretrained(\"LiheYoung/depth-anything-large-hf\")\ncontrol_image = processor(control_image)[0].convert(\"RGB\")\n\nimage = pipe(\n    prompt=prompt,\n    control_image=control_image,\n    height=1024,\n    width=1024,\n    num_inference_steps=30,\n    guidance_scale=10.0,\n    generator=torch.Generator().manual_seed(42),\n).images[0]\nimage.save(\"output.png\")\n```\n\n### Redux\n\n* Flux Redux pipeline is an adapter for FLUX.1 base models. 
It can be used with both flux-dev and flux-schnell, for image-to-image generation.\n* You can first use the `FluxPriorReduxPipeline` to get the `prompt_embeds` and `pooled_prompt_embeds`, and then feed them into the `FluxPipeline` for image-to-image generation.\n* When using `FluxPriorReduxPipeline` with a base pipeline, you can set `text_encoder=None` and `text_encoder_2=None` in the base pipeline to save VRAM.\n\n```python\nimport torch\nfrom diffusers import FluxPriorReduxPipeline, FluxPipeline\nfrom diffusers.utils import load_image\ndevice = \"cuda\"\ndtype = torch.bfloat16\n\n\nrepo_redux = \"black-forest-labs/FLUX.1-Redux-dev\"\nrepo_base = \"black-forest-labs/FLUX.1-dev\"\npipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)\npipe = FluxPipeline.from_pretrained(\n    repo_base,\n    text_encoder=None,\n    text_encoder_2=None,\n    torch_dtype=torch.bfloat16\n).to(device)\n\nimage = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png\")\npipe_prior_output = pipe_prior_redux(image)\nimages = pipe(\n    guidance_scale=2.5,\n    num_inference_steps=50,\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n    **pipe_prior_output,\n).images\nimages[0].save(\"flux-redux.png\")\n```\n\n### Kontext\n\nFlux Kontext is a model that allows in-context control of the image generation process, allowing for editing, refinement, relighting, style transfer, character customization, and more.\n\n```python\nimport torch\nfrom diffusers import FluxKontextPipeline\nfrom diffusers.utils import load_image\n\npipe = FluxKontextPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-Kontext-dev\", torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png\").convert(\"RGB\")\nprompt = \"Make Pikachu hold a sign that says 'Black Forest Labs is 
awesome', yarn art style, detailed, vibrant colors\"\nimage = pipe(\n    image=image,\n    prompt=prompt,\n    guidance_scale=2.5,\n    generator=torch.Generator().manual_seed(42),\n).images[0]\nimage.save(\"flux-kontext.png\")\n```\n\nFlux Kontext comes with an integrity safety checker, which should be run after the image generation step. To run the safety checker, install the official repository from [black-forest-labs/flux](https://github.com/black-forest-labs/flux) and add the following code:\n\n```python\nimport numpy as np\nimport torch\nfrom flux.content_filters import PixtralContentFilter\n\n# ... pipeline invocation to generate images\n\nintegrity_checker = PixtralContentFilter(torch.device(\"cuda\"))\n# Scale the PIL image to [-1, 1] and reorder it to a (1, C, H, W) tensor\nimage_ = np.array(image) / 255.0\nimage_ = 2 * image_ - 1\nimage_ = torch.from_numpy(image_).to(\"cuda\", dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)\nif integrity_checker.test_image(image_):\n    raise ValueError(\"Your image has been flagged. Choose another prompt/image or try again.\")\n```\n\n### Kontext Inpainting\n`FluxKontextInpaintPipeline` enables image modification within a fixed mask region. 
It currently supports both text-based conditioning and image-reference conditioning.\n<hfoptions id=\"kontext-inpaint\">\n<hfoption id=\"text-only\">\n\n\n```python\nimport torch\nfrom diffusers import FluxKontextInpaintPipeline\nfrom diffusers.utils import load_image\n\nprompt = \"Change the yellow dinosaur to green one\"\nimg_url = (\n    \"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true\"\n)\nmask_url = (\n    \"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true\"\n)\n\nsource = load_image(img_url)\nmask = load_image(mask_url)\n\npipe = FluxKontextInpaintPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-Kontext-dev\", torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nimage = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]\nimage.save(\"kontext_inpainting_normal.png\")\n```\n</hfoption>\n<hfoption id=\"image conditioning\">\n\n```python\nimport torch\nfrom diffusers import FluxKontextInpaintPipeline\nfrom diffusers.utils import load_image\n\npipe = FluxKontextInpaintPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-Kontext-dev\", torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nprompt = \"Replace this ball\"\nimg_url = \"https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500\"\nmask_url = \"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true\"\nimage_reference_url = \"https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s\"\n\nsource = load_image(img_url)\nmask = load_image(mask_url)\nimage_reference = load_image(image_reference_url)\n\nmask = pipe.mask_processor.blur(mask, blur_factor=12)\nimage = pipe(\n    prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, 
strength=1.0\n).images[0]\nimage.save(\"kontext_inpainting_ref.png\")\n```\n</hfoption>\n</hfoptions>\n\n## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux\n\nFlux Turbo LoRAs can be combined with Flux Control and other pipelines such as Fill and Redux to enable few-step inference. The example below shows how to do this with the Flux Control LoRA for depth and the turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).\n\n```py\nfrom diffusers import FluxControlPipeline\nfrom image_gen_aux import DepthPreprocessor\nfrom diffusers.utils import load_image\nfrom huggingface_hub import hf_hub_download\nimport torch\n\ncontrol_pipe = FluxControlPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16)\ncontrol_pipe.load_lora_weights(\"black-forest-labs/FLUX.1-Depth-dev-lora\", adapter_name=\"depth\")\ncontrol_pipe.load_lora_weights(\n    hf_hub_download(\"ByteDance/Hyper-SD\", \"Hyper-FLUX.1-dev-8steps-lora.safetensors\"), adapter_name=\"hyper-sd\"\n)\ncontrol_pipe.set_adapters([\"depth\", \"hyper-sd\"], adapter_weights=[0.85, 0.125])\ncontrol_pipe.enable_model_cpu_offload()\n\nprompt = \"A robot made of exotic candies and chocolates of different kinds. 
The background is filled with confetti and celebratory gifts.\"\ncontrol_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png\")\n\nprocessor = DepthPreprocessor.from_pretrained(\"LiheYoung/depth-anything-large-hf\")\ncontrol_image = processor(control_image)[0].convert(\"RGB\")\n\nimage = control_pipe(\n    prompt=prompt,\n    control_image=control_image,\n    height=1024,\n    width=1024,\n    num_inference_steps=8,\n    guidance_scale=10.0,\n    generator=torch.Generator().manual_seed(42),\n).images[0]\nimage.save(\"output.png\")\n```\n\n## Note about `unload_lora_weights()` when using Flux LoRAs\n\nWhen unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to_overwritten_params=True)` to reset the `pipe.transformer` completely back to its original form. The resultant pipeline can then be used with methods like [`DiffusionPipeline.from_pipe`]. More details about this argument are available in [this PR](https://github.com/huggingface/diffusers/pull/10397).\n\n## IP-Adapter\n\n> [!TIP]\n> Check out [IP-Adapter](../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.\n\nAn IP-Adapter lets you prompt Flux with images, in addition to the text prompt. 
This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images.\n\n```python\nimport torch\nfrom diffusers import FluxPipeline\nfrom diffusers.utils import load_image\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg\").resize((1024, 1024))\n\npipe.load_ip_adapter(\n    \"XLabs-AI/flux-ip-adapter\",\n    weight_name=\"ip_adapter.safetensors\",\n    image_encoder_pretrained_model_name_or_path=\"openai/clip-vit-large-patch14\"\n)\npipe.set_ip_adapter_scale(1.0)\n\nimage = pipe(\n    width=1024,\n    height=1024,\n    prompt=\"wearing sunglasses\",\n    negative_prompt=\"\",\n    true_cfg_scale=4.0,\n    generator=torch.Generator().manual_seed(4444),\n    ip_adapter_image=image,\n).images[0]\n\nimage.save('flux_ip_adapter_output.jpg')\n```\n\n<div class=\"justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_output.jpg\"/>\n    <figcaption class=\"mt-2 text-sm text-center text-gray-500\">IP-Adapter examples with prompt \"wearing sunglasses\"</figcaption>\n</div>\n\n## Optimize\n\nFlux is a very large model and requires ~50GB of RAM/VRAM to load all the modeling components. Enable some of the optimizations below to lower the memory requirements.\n\n### Group offloading\n\n[Group offloading](../../optimization/memory#group-offloading) lowers VRAM usage by offloading groups of internal layers rather than the whole model or weights. You need to use [`~hooks.apply_group_offloading`] on all the model components of a pipeline. The `offload_type` parameter allows you to toggle between block and leaf-level offloading. 
Setting it to `leaf_level` offloads the lowest leaf-level parameters to the CPU instead of offloading at the module-level.\n\nOn CUDA devices that support asynchronous data streaming, set `use_stream=True` to overlap data transfer and computation to accelerate inference.\n\n> [!TIP]\n> It is possible to mix block and leaf-level offloading for different components in a pipeline.\n\n```py\nimport torch\nfrom diffusers import FluxPipeline\nfrom diffusers.hooks import apply_group_offloading\n\nmodel_id = \"black-forest-labs/FLUX.1-dev\"\ndtype = torch.bfloat16\npipe = FluxPipeline.from_pretrained(\n\tmodel_id,\n\ttorch_dtype=dtype,\n)\n\napply_group_offloading(\n    pipe.transformer,\n    offload_type=\"leaf_level\",\n    offload_device=torch.device(\"cpu\"),\n    onload_device=torch.device(\"cuda\"),\n    use_stream=True,\n)\napply_group_offloading(\n    pipe.text_encoder, \n    offload_device=torch.device(\"cpu\"),\n    onload_device=torch.device(\"cuda\"),\n    offload_type=\"leaf_level\",\n    use_stream=True,\n)\napply_group_offloading(\n    pipe.text_encoder_2, \n    offload_device=torch.device(\"cpu\"),\n    onload_device=torch.device(\"cuda\"),\n    offload_type=\"leaf_level\",\n    use_stream=True,\n)\napply_group_offloading(\n    pipe.vae, \n    offload_device=torch.device(\"cpu\"),\n    onload_device=torch.device(\"cuda\"),\n    offload_type=\"leaf_level\",\n    use_stream=True,\n)\n\nprompt=\"A cat wearing sunglasses and working as a lifeguard at pool.\"\n\ngenerator = torch.Generator().manual_seed(181201)\nimage = pipe(\n    prompt,\n    width=576,\n    height=1024,\n    num_inference_steps=30,\n    generator=generator\n).images[0]\nimage\n```\n\n### Running FP16 inference\n\nFlux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. 
Forcing the text encoders to run in FP32 removes this output difference. See [this comment](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.\n\nFP16 inference code:\n```python\nimport torch\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16) # can replace schnell with dev\n# to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)\npipe.enable_sequential_cpu_offload()\npipe.vae.enable_slicing()\npipe.vae.enable_tiling()\n\npipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once\n\nprompt = \"A cat holding a sign that says hello world\"\nout = pipe(\n    prompt=prompt,\n    guidance_scale=0.,\n    height=768,\n    width=1360,\n    num_inference_steps=4,\n    max_sequence_length=256,\n).images[0]\nout.save(\"image.png\")\n```\n\n### Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may affect image quality to varying degrees depending on the model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. 
The example below demonstrates how to load a quantized [`FluxPipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline\nfrom transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = T5EncoderModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"text_encoder_2\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = FluxTransformer2DModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    text_encoder_2=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nprompt = \"a tiny astronaut hatching from an egg on the moon\"\nimage = pipeline(prompt, guidance_scale=3.5, height=768, width=1360, num_inference_steps=50).images[0]\nimage.save(\"flux.png\")\n```\n\n## Single File Loading for the `FluxTransformer2DModel`\n\nThe `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.\n\n> [!TIP]\n> `FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. 
It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine.\n\nThe following example demonstrates how to run Flux with less than 16GB of VRAM.\n\nFirst install `optimum-quanto`\n\n```shell\npip install optimum-quanto\n```\n\nThen run the following example\n\n```python\nimport torch\nfrom diffusers import FluxTransformer2DModel, FluxPipeline\nfrom transformers import T5EncoderModel, CLIPTextModel\nfrom optimum.quanto import freeze, qfloat8, quantize\n\nbfl_repo = \"black-forest-labs/FLUX.1-dev\"\ndtype = torch.bfloat16\n\ntransformer = FluxTransformer2DModel.from_single_file(\"https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors\", torch_dtype=dtype)\nquantize(transformer, weights=qfloat8)\nfreeze(transformer)\n\ntext_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder=\"text_encoder_2\", torch_dtype=dtype)\nquantize(text_encoder_2, weights=qfloat8)\nfreeze(text_encoder_2)\n\npipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)\npipe.transformer = transformer\npipe.text_encoder_2 = text_encoder_2\n\npipe.enable_model_cpu_offload()\n\nprompt = \"A cat holding a sign that says hello world\"\nimage = pipe(\n    prompt,\n    guidance_scale=3.5,\n    output_type=\"pil\",\n    num_inference_steps=20,\n    generator=torch.Generator(\"cpu\").manual_seed(0)\n).images[0]\n\nimage.save(\"flux-fp8-dev.png\")\n```\n\n## FluxPipeline\n\n[[autodoc]] FluxPipeline\n\t- all\n\t- __call__\n\n## FluxImg2ImgPipeline\n\n[[autodoc]] FluxImg2ImgPipeline\n\t- all\n\t- __call__\n\n## FluxInpaintPipeline\n\n[[autodoc]] FluxInpaintPipeline\n\t- all\n\t- __call__\n\n\n## FluxControlNetInpaintPipeline\n\n[[autodoc]] FluxControlNetInpaintPipeline\n\t- all\n\t- __call__\n\n## FluxControlNetImg2ImgPipeline\n\n[[autodoc]] FluxControlNetImg2ImgPipeline\n\t- all\n\t- __call__\n\n## FluxControlPipeline\n\n[[autodoc]] FluxControlPipeline\n\t- all\n\t- __call__\n\n## 
FluxControlImg2ImgPipeline\n\n[[autodoc]] FluxControlImg2ImgPipeline\n\t- all\n\t- __call__\n\n## FluxPriorReduxPipeline\n\n[[autodoc]] FluxPriorReduxPipeline\n\t- all\n\t- __call__\n\n## FluxFillPipeline\n\n[[autodoc]] FluxFillPipeline\n\t- all\n\t- __call__\n\n## FluxKontextPipeline\n\n[[autodoc]] FluxKontextPipeline\n\t- all\n\t- __call__\n\n## FluxKontextInpaintPipeline\n\n[[autodoc]] FluxKontextInpaintPipeline\n\t- all\n\t- __call__"
  },
  {
    "path": "docs/source/en/api/pipelines/flux2.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Flux2\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\nFlux.2 is a recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!\n\nOriginal model checkpoints for Flux.2 can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2).\n\n> [!TIP]\n> Flux.2 can be quite expensive to run on consumer hardware. However, you can apply a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux.2 can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.\n>\n> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.\n\n## Caption upsampling\n\nFlux.2 can potentially generate better outputs with better prompts. 
We can \"upsample\"\nan input prompt by setting the `caption_upsample_temperature` argument in the pipeline call.\nThe [official implementation](https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L140) recommends a value of 0.15.\n\n## Flux2Pipeline\n\n[[autodoc]] Flux2Pipeline\n\t- all\n\t- __call__\n\n## Flux2KleinPipeline\n\n[[autodoc]] Flux2KleinPipeline\n\t- all\n\t- __call__\n\n## Flux2KleinKVPipeline\n\n[[autodoc]] Flux2KleinKVPipeline\n\t- all\n\t- __call__"
  },
  {
    "path": "docs/source/en/api/pipelines/framepack.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# Framepack\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[Packing Input Frame Context in Next-Frame Prediction Models for Video Generation](https://huggingface.co/papers/2504.12626) by Lvmin Zhang and Maneesh Agrawala.\n\n*We present a neural network structure, FramePack, to train next-frame (or next-frame-section) prediction models for video generation. The FramePack compresses input frames to make the transformer context length a fixed number regardless of the video length. As a result, we are able to process a large number of frames using video diffusion with computation bottleneck similar to image diffusion. This also makes the training video batch sizes significantly higher (batch sizes become comparable to image diffusion training). We also propose an anti-drifting sampling method that generates frames in inverted temporal order with early-established endpoints to avoid exposure bias (error accumulation over iterations). 
Finally, we show that existing video diffusion models can be finetuned with FramePack, and their visual quality may be improved because the next-frame prediction supports more balanced diffusion schedulers with less extreme flow shift timesteps.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## Available models\n\n| Model name | Description |\n|:---|:---|\n| [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | Trained with the \"inverted anti-drifting\" strategy as described in the paper. Inference requires setting `sampling_type=\"inverted_anti_drifting\"` when running the pipeline. |\n| [`lllyasviel/FramePack_F1_I2V_HY_20250503`](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503) | Trained with a novel anti-drifting strategy, but inference is performed with the \"vanilla\" strategy as described in the paper. Inference requires setting `sampling_type=\"vanilla\"` when running the pipeline. |\n\n## Usage\n\nRefer to the pipeline documentation for basic usage examples. 
The following section contains examples of offloading, different sampling methods, quantization, and more.\n\n### First and last frame to video\n\nThe following example shows how to use Framepack with start and end image controls, using the model trained with inverted anti-drifting sampling.\n\n```python\nimport torch\nfrom diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import SiglipImageProcessor, SiglipVisionModel\n\ntransformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(\n    \"lllyasviel/FramePackI2V_HY\", torch_dtype=torch.bfloat16\n)\nfeature_extractor = SiglipImageProcessor.from_pretrained(\n    \"lllyasviel/flux_redux_bfl\", subfolder=\"feature_extractor\"\n)\nimage_encoder = SiglipVisionModel.from_pretrained(\n    \"lllyasviel/flux_redux_bfl\", subfolder=\"image_encoder\", torch_dtype=torch.float16\n)\npipe = HunyuanVideoFramepackPipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo\",\n    transformer=transformer,\n    feature_extractor=feature_extractor,\n    image_encoder=image_encoder,\n    torch_dtype=torch.float16,\n)\n\n# Enable memory optimizations\npipe.enable_model_cpu_offload()\npipe.vae.enable_tiling()\n\nprompt = \"CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. 
The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective.\"\nfirst_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png\"\n)\nlast_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png\"\n)\noutput = pipe(\n    image=first_image,\n    last_image=last_image,\n    prompt=prompt,\n    height=512,\n    width=512,\n    num_frames=91,\n    num_inference_steps=30,\n    guidance_scale=9.0,\n    generator=torch.Generator().manual_seed(0),\n    sampling_type=\"inverted_anti_drifting\",\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=30)\n```\n\n### Vanilla sampling\n\nThe following example shows how to use Framepack with the F1 model trained with vanilla sampling but new regulation approach for anti-drifting.\n\n```python\nimport torch\nfrom diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import SiglipImageProcessor, SiglipVisionModel\n\ntransformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(\n    \"lllyasviel/FramePack_F1_I2V_HY_20250503\", torch_dtype=torch.bfloat16\n)\nfeature_extractor = SiglipImageProcessor.from_pretrained(\n    \"lllyasviel/flux_redux_bfl\", subfolder=\"feature_extractor\"\n)\nimage_encoder = SiglipVisionModel.from_pretrained(\n    \"lllyasviel/flux_redux_bfl\", subfolder=\"image_encoder\", torch_dtype=torch.float16\n)\npipe = HunyuanVideoFramepackPipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo\",\n    transformer=transformer,\n    feature_extractor=feature_extractor,\n    image_encoder=image_encoder,\n    torch_dtype=torch.float16,\n)\n\n# Enable memory optimizations\npipe.enable_model_cpu_offload()\npipe.vae.enable_tiling()\n\nimage = load_image(\n    
\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png\"\n)\noutput = pipe(\n    image=image,\n    prompt=\"A penguin dancing in the snow\",\n    height=832,\n    width=480,\n    num_frames=91,\n    num_inference_steps=30,\n    guidance_scale=9.0,\n    generator=torch.Generator().manual_seed(0),\n    sampling_type=\"vanilla\",\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=30)\n```\n\n### Group offloading\n\nGroup offloading ([`~hooks.apply_group_offloading`]) provides aggressive memory optimizations for offloading internal parts of any model to the CPU, with possibly no additional overhead to generation time. If you have very low VRAM available, this approach may be suitable for you depending on the amount of CPU RAM available.\n\n```python\nimport torch\nfrom diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel\nfrom diffusers.hooks import apply_group_offloading\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import SiglipImageProcessor, SiglipVisionModel\n\ntransformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(\n    \"lllyasviel/FramePack_F1_I2V_HY_20250503\", torch_dtype=torch.bfloat16\n)\nfeature_extractor = SiglipImageProcessor.from_pretrained(\n    \"lllyasviel/flux_redux_bfl\", subfolder=\"feature_extractor\"\n)\nimage_encoder = SiglipVisionModel.from_pretrained(\n    \"lllyasviel/flux_redux_bfl\", subfolder=\"image_encoder\", torch_dtype=torch.float16\n)\npipe = HunyuanVideoFramepackPipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo\",\n    transformer=transformer,\n    feature_extractor=feature_extractor,\n    image_encoder=image_encoder,\n    torch_dtype=torch.float16,\n)\n\n# Enable group offloading\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\nlist(map(\n    lambda x: apply_group_offloading(x, onload_device, offload_device, offload_type=\"leaf_level\", 
use_stream=True, low_cpu_mem_usage=True),\n    [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]\n))\npipe.image_encoder.to(onload_device)\npipe.vae.to(onload_device)\npipe.vae.enable_tiling()\n\nimage = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png\"\n)\noutput = pipe(\n    image=image,\n    prompt=\"A penguin dancing in the snow\",\n    height=832,\n    width=480,\n    num_frames=91,\n    num_inference_steps=30,\n    guidance_scale=9.0,\n    generator=torch.Generator().manual_seed(0),\n    sampling_type=\"vanilla\",\n).frames[0]\nprint(f\"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB\")\nexport_to_video(output, \"output.mp4\", fps=30)\n```\n\n## HunyuanVideoFramepackPipeline\n\n[[autodoc]] HunyuanVideoFramepackPipeline\n  - all\n  - __call__\n\n## HunyuanVideoPipelineOutput\n\n[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput\n\n"
  },
  {
    "path": "docs/source/en/api/pipelines/glm_image.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n# GLM-Image\n\n## Overview\n\nGLM-Image is an image generation model that adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained details. In general image generation quality, it aligns with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.\n\nModel architecture: a hybrid autoregressive + diffusion decoder design.\n\n+ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. The AR model is implemented in the `GlmImageForConditionalGeneration` class of the `transformers` library.\n+ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. 
It is equipped with a Glyph Encoder text module, which significantly improves the accuracy of text rendering within images.\n\nPost-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.\n\n+ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.\n+ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.\n\nGLM-Image supports both text-to-image and image-to-image generation within a single model:\n\n+ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.\n+ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.\n\nThis pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image).\n\n## Usage examples\n\n### Text to Image Generation\n\n```python\nimport torch\nfrom diffusers.pipelines.glm_image import GlmImagePipeline\n\npipe = GlmImagePipeline.from_pretrained(\"zai-org/GLM-Image\", torch_dtype=torch.bfloat16, device_map=\"cuda\")\nprompt = \"A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. 
The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. 
The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy.\"\nimage = pipe(\n    prompt=prompt,\n    height=32 * 32,\n    width=36 * 32,\n    num_inference_steps=30,\n    guidance_scale=1.5,\n    generator=torch.Generator(device=\"cuda\").manual_seed(42),\n).images[0]\n\nimage.save(\"output_t2i.png\")\n```\n\n### Image to Image Generation\n\n```python\nimport torch\nfrom diffusers.pipelines.glm_image import GlmImagePipeline\nfrom PIL import Image\n\npipe = GlmImagePipeline.from_pretrained(\"zai-org/GLM-Image\",torch_dtype=torch.bfloat16,device_map=\"cuda\")\nimage_path = \"cond.jpg\" \nprompt = \"Replace the background of the snow forest with an underground station featuring an automatic escalator.\"\nimage = Image.open(image_path).convert(\"RGB\")\nimage = pipe(\n    prompt=prompt,\n    image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1]\n    height=33 * 32,\n    width=32 * 32,\n    num_inference_steps=30,\n    guidance_scale=1.5,\n    generator=torch.Generator(device=\"cuda\").manual_seed(42),\n).images[0]\n\nimage.save(\"output_i2i.png\")\n```\n\n+ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting do_sample=False, as this may lead to incorrect or degenerate outputs from the AR model.\n\n## GlmImagePipeline\n\n[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline\n  - all\n  - __call__\n\n## GlmImagePipelineOutput\n\n[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/helios.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n<div style=\"float: right;\">\n  <div class=\"flex flex-wrap space-x-1\">\n    <a href=\"https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference\" target=\"_blank\" rel=\"noopener\">\n      <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n    </a>\n  </div>\n</div>\n\n# Helios\n\n[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan.\n\n*  <u>We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality.</u> We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. 
Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available on the [project page](https://pku-yuangroup.github.io/Helios-Page).\n\nThe following Helios models are supported in Diffusers:\n\n- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and custom HeliosScheduler.\n- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler.\n- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosDMDScheduler.\n\n> [!TIP]\n> Click on the Helios models in the right sidebar for more examples of video generation.\n\n### Optimizing Memory and Inference Speed\n\nThe example below demonstrates how to generate a video from text optimized for memory or inference speed.\n\n<hfoptions id=\"optimization\">\n<hfoption id=\"memory\">\n\nRefer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.\n\nThe Helios model below requires ~6GB of VRAM.\n\n```py\nimport torch\nfrom diffusers import 
AutoModel, HeliosPipeline\nfrom diffusers.hooks.group_offloading import apply_group_offloading\nfrom diffusers.utils import export_to_video\n\nvae = AutoModel.from_pretrained(\"BestWishYsh/Helios-Base\", subfolder=\"vae\", torch_dtype=torch.float32)\n\n# group-offloading\npipeline = HeliosPipeline.from_pretrained(\n    \"BestWishYsh/Helios-Base\",\n    vae=vae,\n    torch_dtype=torch.bfloat16\n)\npipeline.enable_group_offload(\n    onload_device=torch.device(\"cuda\"),\n    offload_device=torch.device(\"cpu\"),\n    offload_type=\"leaf_level\",\n    use_stream=True,\n    record_stream=True,\n)\n\nprompt = \"\"\"\nA vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue \nand yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with \na variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, \nallowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades \nof red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and \nthe vivid colors of its surroundings. 
A close-up shot with dynamic movement.\n\"\"\"\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,\nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,\nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=99,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_base_t2v_output.mp4\", fps=24)\n```\n\n</hfoption>\n<hfoption id=\"inference speed\">\n\n[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Context Parallelism](../../training/distributed_inference#context-parallelism) splits the input sequence across multiple devices to enable processing of long contexts in parallel, reducing memory pressure and latency. 
[Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.\n\n```py\nimport torch\nfrom diffusers import AutoModel, HeliosPipeline\nfrom diffusers.utils import export_to_video\n\nvae = AutoModel.from_pretrained(\"BestWishYsh/Helios-Base\", subfolder=\"vae\", torch_dtype=torch.float32)\n\npipeline = HeliosPipeline.from_pretrained(\n    \"BestWishYsh/Helios-Base\",\n    vae=vae,\n    torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\n\n# attention backend\n# pipeline.transformer.set_attention_backend(\"flash\")\npipeline.transformer.set_attention_backend(\"_flash_3_hub\") # For Hopper GPUs\n\n# torch.compile\ntorch.backends.cudnn.benchmark = True\npipeline.text_encoder.compile(mode=\"max-autotune-no-cudagraphs\", dynamic=False)\npipeline.vae.compile(mode=\"max-autotune-no-cudagraphs\", dynamic=False)\npipeline.transformer.compile(mode=\"max-autotune-no-cudagraphs\", dynamic=False)\n\nprompt = \"\"\"\nA vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue \nand yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with \na variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, \nallowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades \nof red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and \nthe vivid colors of its surroundings. 
A close-up shot with dynamic movement.\n\"\"\"\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,\nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,\nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=99,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_base_t2v_output.mp4\", fps=24)\n```\n\n</hfoption>\n</hfoptions>\n\n\n### Generation with Helios-Base\n\nThe example below demonstrates how to use Helios-Base to generate video based on text, image or video.\n\n<hfoptions id=\"Helios-Base usage\">\n<hfoption id=\"usage\">\n\n```python\nimport torch\nfrom diffusers import AutoModel, HeliosPipeline\nfrom diffusers.utils import export_to_video, load_video, load_image\n\nvae = AutoModel.from_pretrained(\"BestWishYsh/Helios-Base\", subfolder=\"vae\", torch_dtype=torch.float32)\n\npipeline = HeliosPipeline.from_pretrained(\n    \"BestWishYsh/Helios-Base\",\n    vae=vae,\n    torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\n\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,\nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,\nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\n# For Text-to-Video\nprompt = \"\"\"\nA vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. 
The fish has bright blue \nand yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with \na variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, \nallowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades \nof red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and \nthe vivid colors of its surroundings. A close-up shot with dynamic movement.\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=99,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_base_t2v_output.mp4\", fps=24)\n\n# For Image-to-Video\nprompt = \"\"\"\nA towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, \nilluminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, \ncasting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes \napparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and \nrelentless force, with every droplet and ripple shimmering in the light. 
The dynamic motion and vivid colors evoke both awe and \nrespect for nature’s might.\n\"\"\"\nimage_path = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    image=load_image(image_path).resize((640, 384)),\n    num_frames=99,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_base_i2v_output.mp4\", fps=24)\n\n# For Video-to-Video\nprompt = \"\"\"\nA bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees \nunder a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, \nemphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to \nthe scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. 
\nA front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.\n\"\"\"\nvideo_path = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    video=load_video(video_path),\n    num_frames=99,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_base_v2v_output.mp4\", fps=24)\n```\n\n</hfoption>\n</hfoptions>\n\n\n### Generation with Helios-Mid\n\nThe example below demonstrates how to use Helios-Mid to generate video based on text, image or video.\n\n<hfoptions id=\"Helios-Mid usage\">\n<hfoption id=\"usage\">\n\n```python\nimport torch\nfrom diffusers import AutoModel, HeliosPyramidPipeline\nfrom diffusers.utils import export_to_video, load_video, load_image\n\nvae = AutoModel.from_pretrained(\"BestWishYsh/Helios-Mid\", subfolder=\"vae\", torch_dtype=torch.float32)\n\npipeline = HeliosPyramidPipeline.from_pretrained(\n    \"BestWishYsh/Helios-Mid\",\n    vae=vae,\n    torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\n\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,\nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,\nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\n# For Text-to-Video\nprompt = \"\"\"\nA vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue \nand yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. 
The coral reefs are alive with \na variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, \nallowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades \nof red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and \nthe vivid colors of its surroundings. A close-up shot with dynamic movement.\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=99,\n    pyramid_num_inference_steps_list=[20, 20, 20],\n    guidance_scale=5.0,\n    use_zero_init=True,\n    zero_steps=1,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_pyramid_t2v_output.mp4\", fps=24)\n\n# For Image-to-Video\nprompt = \"\"\"\nA towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, \nilluminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, \ncasting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes \napparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and \nrelentless force, with every droplet and ripple shimmering in the light. 
The dynamic motion and vivid colors evoke both awe and \nrespect for nature’s might.\n\"\"\"\nimage_path = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    image=load_image(image_path).resize((640, 384)),\n    num_frames=99,\n    pyramid_num_inference_steps_list=[20, 20, 20],\n    guidance_scale=5.0,\n    use_zero_init=True,\n    zero_steps=1,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_pyramid_i2v_output.mp4\", fps=24)\n\n# For Video-to-Video\nprompt = \"\"\"\nA bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees \nunder a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, \nemphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to \nthe scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. 
\nA front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.\n\"\"\"\nvideo_path = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    video=load_video(video_path),\n    num_frames=99,\n    pyramid_num_inference_steps_list=[20, 20, 20],\n    guidance_scale=5.0,\n    use_zero_init=True,\n    zero_steps=1,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_pyramid_v2v_output.mp4\", fps=24)\n```\n\n</hfoption>\n</hfoptions>\n\n\n### Generation with Helios-Distilled\n\nThe example below demonstrates how to use Helios-Distilled to generate video based on text, image or video.\n\n<hfoptions id=\"Helios-Distilled usage\">\n<hfoption id=\"usage\">\n\n```python\nimport torch\nfrom diffusers import AutoModel, HeliosPyramidPipeline\nfrom diffusers.utils import export_to_video, load_video, load_image\n\nvae = AutoModel.from_pretrained(\"BestWishYsh/Helios-Distilled\", subfolder=\"vae\", torch_dtype=torch.float32)\n\npipeline = HeliosPyramidPipeline.from_pretrained(\n    \"BestWishYsh/Helios-Distilled\",\n    vae=vae,\n    torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\n\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,\nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,\nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\n# For Text-to-Video\nprompt = \"\"\"\nA vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. 
The fish has bright blue \nand yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with \na variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear, \nallowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades \nof red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and \nthe vivid colors of its surroundings. A close-up shot with dynamic movement.\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=240,\n    pyramid_num_inference_steps_list=[2, 2, 2],\n    guidance_scale=1.0,\n    is_amplify_first_chunk=True,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_distilled_t2v_output.mp4\", fps=24)\n\n# For Image-to-Video\nprompt = \"\"\"\nA towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water, \nilluminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest, \ncasting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes \napparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and \nrelentless force, with every droplet and ripple shimmering in the light. 
The dynamic motion and vivid colors evoke both awe and \nrespect for nature’s might.\n\"\"\"\nimage_path = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    image=load_image(image_path).resize((640, 384)),\n    num_frames=240,\n    pyramid_num_inference_steps_list=[2, 2, 2],\n    guidance_scale=1.0,\n    is_amplify_first_chunk=True,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_distilled_i2v_output.mp4\", fps=24)\n\n# For Video-to-Video\nprompt = \"\"\"\nA bright yellow Lamborghini Huracán Tecnica speeds along a curving mountain road, surrounded by lush green trees \nunder a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop, \nemphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to \nthe scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere. 
\nA front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.\n\"\"\"\nvideo_path = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    video=load_video(video_path),\n    num_frames=240,\n    pyramid_num_inference_steps_list=[2, 2, 2],\n    guidance_scale=1.0,\n    is_amplify_first_chunk=True,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).frames[0]\nexport_to_video(output, \"helios_distilled_v2v_output.mp4\", fps=24)\n```\n\n</hfoption>\n</hfoptions>\n\n\n## HeliosPipeline\n\n[[autodoc]] HeliosPipeline\n\n  - all\n  - __call__\n\n## HeliosPyramidPipeline\n\n[[autodoc]] HeliosPyramidPipeline\n\n  - all\n  - __call__\n\n## HeliosPipelineOutput\n\n[[autodoc]] pipelines.helios.pipeline_output.HeliosPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/hidream.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# HiDreamImage\n\n[HiDream-I1](https://huggingface.co/HiDream-ai) by HiDream.ai\n\n> [!TIP]\n> [Caching](../../optimization/cache) may speed up inference by storing and reusing intermediate outputs.\n\n## Available models\n\nThe following models are available for the [`HiDreamImagePipeline`] pipeline:\n\n| Model name | Description |\n|:---|:---|\n| [`HiDream-ai/HiDream-I1-Full`](https://huggingface.co/HiDream-ai/HiDream-I1-Full) | - |\n| [`HiDream-ai/HiDream-I1-Dev`](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) | - |\n| [`HiDream-ai/HiDream-I1-Fast`](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) | - |\n\n## HiDreamImagePipeline\n\n[[autodoc]] HiDreamImagePipeline\n  - all\n  - __call__\n\n## HiDreamImagePipelineOutput\n\n[[autodoc]] pipelines.hidream_image.pipeline_output.HiDreamImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/hunyuan_video.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n<div style=\"float: right;\">\n  <div class=\"flex flex-wrap space-x-1\">\n    <a href=\"https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference\" target=\"_blank\" rel=\"noopener\">\n      <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n    </a>\n  </div>\n</div>\n\n# HunyuanVideo\n\n[HunyuanVideo](https://huggingface.co/papers/2412.03603) is a 13B parameter diffusion transformer model designed to be competitive with closed-source video foundation models and enable wider community access. This model uses a \"dual-stream to single-stream\" architecture to separately process the video and text tokens first, before concatenating and feeding them to the transformer to fuse the multimodal information. A pretrained multimodal large language model (MLLM) is used as the encoder because it has better image-text alignment, better image detail description and reasoning, and it can be used as a zero-shot learner if system instructions are added to user prompts. 
Finally, HunyuanVideo uses a 3D causal variational autoencoder to more efficiently process video data at the original resolution and frame rate.\n\nYou can find all the original HunyuanVideo checkpoints under the [Tencent](https://huggingface.co/tencent) organization.\n\n> [!TIP]\n> Click on the HunyuanVideo models in the right sidebar for more examples of video generation tasks.\n>\n> The examples below use a checkpoint from [hunyuanvideo-community](https://huggingface.co/hunyuanvideo-community) because the weights are stored in a layout compatible with Diffusers.\n\nThe example below demonstrates how to generate a video optimized for memory or inference speed.\n\n<hfoptions id=\"usage\">\n<hfoption id=\"memory\">\n\nRefer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.\n\nThe quantized HunyuanVideo model below requires ~14GB of VRAM.\n\n```py\nimport torch\nfrom diffusers import AutoModel, HunyuanVideoPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom diffusers.utils import export_to_video\n\n# quantize weights to int4 with bitsandbytes\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\n      \"load_in_4bit\": True,\n      \"bnb_4bit_quant_type\": \"nf4\",\n      \"bnb_4bit_compute_dtype\": torch.bfloat16\n      },\n    components_to_quantize=\"transformer\"\n)\n\npipeline = HunyuanVideoPipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n)\n\n# model-offloading and tiling\npipeline.enable_model_cpu_offload()\npipeline.vae.enable_tiling()\n\nprompt = \"A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys.\"\nvideo = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]\nexport_to_video(video, \"output.mp4\", fps=15)\n```\n\n</hfoption>\n<hfoption 
id=\"inference speed\">\n\n[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.\n\n```py\nimport torch\nfrom diffusers import AutoModel, HunyuanVideoPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom diffusers.utils import export_to_video\n\n# quantize weights to int4 with bitsandbytes\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\n      \"load_in_4bit\": True,\n      \"bnb_4bit_quant_type\": \"nf4\",\n      \"bnb_4bit_compute_dtype\": torch.bfloat16\n      },\n    components_to_quantize=\"transformer\"\n)\n\npipeline = HunyuanVideoPipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n)\n\n# model-offloading and tiling\npipeline.enable_model_cpu_offload()\npipeline.vae.enable_tiling()\n\n# torch.compile\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.transformer = torch.compile(\n    pipeline.transformer, mode=\"max-autotune\", fullgraph=True\n)\n\nprompt = \"A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys.\"\nvideo = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]\nexport_to_video(video, \"output.mp4\", fps=15)\n```\n\n</hfoption>\n</hfoptions>\n\n## Notes\n\n- HunyuanVideo supports LoRAs with [`~loaders.HunyuanVideoLoraLoaderMixin.load_lora_weights`].\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```py\n  import torch\n  from diffusers import AutoModel, HunyuanVideoPipeline\n  from diffusers.quantizers import PipelineQuantizationConfig\n  from diffusers.utils import export_to_video\n\n  # quantize weights to int4 with bitsandbytes\n  pipeline_quant_config = PipelineQuantizationConfig(\n      quant_backend=\"bitsandbytes_4bit\",\n      quant_kwargs={\n        \"load_in_4bit\": True,\n        
\"bnb_4bit_quant_type\": \"nf4\",\n        \"bnb_4bit_compute_dtype\": torch.bfloat16\n        },\n      components_to_quantize=\"transformer\"\n  )\n\n  pipeline = HunyuanVideoPipeline.from_pretrained(\n      \"hunyuanvideo-community/HunyuanVideo\",\n      quantization_config=pipeline_quant_config,\n      torch_dtype=torch.bfloat16,\n  )\n\n  # load LoRA weights\n  pipeline.load_lora_weights(\"https://huggingface.co/lucataco/hunyuan-steamboat-willie-10\", adapter_name=\"steamboat-willie\")\n  pipeline.set_adapters(\"steamboat-willie\", 0.9)\n\n  # model-offloading and tiling\n  pipeline.enable_model_cpu_offload()\n  pipeline.vae.enable_tiling()\n\n  # use \"In the style of SWR\" to trigger the LoRA\n  prompt = \"\"\"\n  In the style of SWR. A black and white animated scene featuring a fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys.\n  \"\"\"\n  video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]\n  export_to_video(video, \"output.mp4\", fps=15)\n  ```\n\n  </details>\n\n- Refer to the table below for recommended inference values.\n\n  | parameter | recommended value |\n  |---|---|\n  | text encoder dtype | `torch.float16` |\n  | transformer dtype | `torch.bfloat16` |\n  | vae dtype | `torch.float16` |\n  | `num_frames` | `4 * k + 1` for an integer `k` |\n\n- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution videos.\n\n## HunyuanVideoPipeline\n\n[[autodoc]] HunyuanVideoPipeline\n  - all\n  - __call__\n\n## HunyuanVideoPipelineOutput\n\n[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/hunyuan_video15.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n\n# HunyuanVideo-1.5\n\nHunyuanVideo-1.5 is a lightweight yet powerful video generation model that achieves state-of-the-art visual quality and motion coherence with only 8.3 billion parameters, enabling efficient inference on consumer-grade GPUs. This achievement is built upon several key components, including meticulous data curation, an advanced DiT architecture with selective and sliding tile attention (SSTA), enhanced bilingual understanding through glyph-aware text encoding, progressive pre-training and post-training, and an efficient video super-resolution network. Leveraging these designs, we developed a unified framework capable of high-quality text-to-video and image-to-video generation across multiple durations and resolutions. 
Extensive experiments demonstrate that this compact and proficient model establishes a new state-of-the-art among open-source models.\n\nYou can find all the original HunyuanVideo checkpoints under the [Tencent](https://huggingface.co/tencent) organization.\n\n> [!TIP]\n> Click on the HunyuanVideo models in the right sidebar for more examples of video generation tasks.\n>\n> The examples below use a checkpoint from [hunyuanvideo-community](https://huggingface.co/hunyuanvideo-community) because the weights are stored in a layout compatible with Diffusers.\n\nThe example below demonstrates how to generate a video optimized for memory or inference speed.\n\n<hfoptions id=\"usage\">\n<hfoption id=\"memory\">\n\nRefer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.\n\n\n```py\nimport torch\nfrom diffusers import HunyuanVideo15Pipeline\nfrom diffusers.utils import export_to_video\n\n\npipeline = HunyuanVideo15Pipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v\",\n    torch_dtype=torch.bfloat16,\n)\n\n# model-offloading and tiling\npipeline.enable_model_cpu_offload()\npipeline.vae.enable_tiling()\n\nprompt = \"A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys.\"\nvideo = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]\nexport_to_video(video, \"output.mp4\", fps=15)\n```\n\n</hfoption>\n</hfoptions>\n\n## Notes\n\n- HunyuanVideo-1.5 uses attention masks with variable-length sequences. 
For best performance, we recommend using an attention backend that handles padding efficiently.\n\n    - **H100/H800:** `_flash_3_hub` or `_flash_3_varlen_hub`\n    - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen_hub`\n    - **Other GPUs:** `sage_hub`\n\nRefer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.\n\n\n```py\npipeline.transformer.set_attention_backend(\"flash_hub\")  # or your preferred backend\n```\n\n- [`HunyuanVideo15Pipeline`] uses a guider and does not take a `guidance_scale` parameter at runtime.\n\nYou can check the default guider configuration with `pipeline.guider`:\n\n```py\n>>> pipeline.guider\nClassifierFreeGuidance {\n  \"_class_name\": \"ClassifierFreeGuidance\",\n  \"_diffusers_version\": \"0.36.0.dev0\",\n  \"enabled\": true,\n  \"guidance_rescale\": 0.0,\n  \"guidance_scale\": 6.0,\n  \"start\": 0.0,\n  \"stop\": 1.0,\n  \"use_original_formulation\": false\n}\n\nState:\n  step: None\n  num_inference_steps: None\n  timestep: None\n  count_prepared: 0\n  enabled: True\n  num_conditions: 2\n```\n\nTo update the guider configuration, run `pipeline.guider = pipeline.guider.new(...)`:\n\n```py\npipeline.guider = pipeline.guider.new(guidance_scale=5.0)\n```\n\nRead more about guiders in the [Guiders](../../using-diffusers/guiders) guide.\n\n\n\n## HunyuanVideo15Pipeline\n\n[[autodoc]] HunyuanVideo15Pipeline\n  - all\n  - __call__\n\n## HunyuanVideo15ImageToVideoPipeline\n\n[[autodoc]] HunyuanVideo15ImageToVideoPipeline\n  - all\n  - __call__\n\n## HunyuanVideo15PipelineOutput\n\n[[autodoc]] pipelines.hunyuan_video1_5.pipeline_output.HunyuanVideo15PipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/hunyuandit.md",
    "content": "<!--Copyright 2025 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Hunyuan-DiT\n![chinese elements understanding](https://github.com/gnobitab/diffusers-hunyuan/assets/1157982/39b99036-c3cb-4f16-bb1a-40ec25eda573)\n\n[Hunyuan-DiT : A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding](https://huggingface.co/papers/2405.08748) from Tencent Hunyuan.\n\nThe abstract from the paper is:\n\n*We present Hunyuan-DiT, a text-to-image diffusion transformer with fine-grained understanding of both English and Chinese. To construct Hunyuan-DiT, we carefully design the transformer structure, text encoder, and positional encoding. We also build from scratch a whole data pipeline to update and evaluate data for iterative model optimization. For fine-grained language understanding, we train a Multimodal Large Language Model to refine the captions of the images. Finally, Hunyuan-DiT can perform multi-turn multimodal dialogue with users, generating and refining images according to the context. 
Through our holistic human evaluation protocol with more than 50 professional human evaluators, Hunyuan-DiT sets a new state-of-the-art in Chinese-to-image generation compared with other open-source models.*\n\n\nYou can find the original codebase at [Tencent/HunyuanDiT](https://github.com/Tencent/HunyuanDiT) and all the available checkpoints at [Tencent-Hunyuan](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT).\n\n**Highlights**: HunyuanDiT supports Chinese/English-to-image, multi-resolution generation.\n\nHunyuanDiT has the following components:\n* It uses a diffusion transformer as the backbone\n* It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n> [!TIP]\n> You can further improve generation quality by passing the generated image from [`HunyuanDiTPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.\n\n## Optimization\n\nYou can optimize the pipeline's runtime and memory consumption with `torch.compile` and feed-forward chunking. 
To learn about other optimization methods, check out the [Speed up inference](../../optimization/fp16) and [Reduce memory usage](../../optimization/memory) guides.\n\n### Inference\n\nUse [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.\n\nFirst, load the pipeline:\n\n```python\nfrom diffusers import HunyuanDiTPipeline\nimport torch\n\npipeline = HunyuanDiTPipeline.from_pretrained(\n    \"Tencent-Hunyuan/HunyuanDiT-Diffusers\", torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\nThen change the memory layout of the pipeline's `transformer` and `vae` components to `torch.channels_last`:\n\n```python\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.vae.to(memory_format=torch.channels_last)\n```\n\nFinally, compile the components and run inference:\n\n```python\npipeline.transformer = torch.compile(pipeline.transformer, mode=\"max-autotune\", fullgraph=True)\npipeline.vae.decode = torch.compile(pipeline.vae.decode, mode=\"max-autotune\", fullgraph=True)\n\n# The prompt is Chinese for \"An astronaut riding a horse\"\nimage = pipeline(prompt=\"一个宇航员在骑马\").images[0]\n```\n\nThe [benchmark](https://gist.github.com/sayakpaul/29d3a14905cfcbf611fe71ebd22e9b23) results on an 80GB A100 machine are:\n\n```bash\nWith torch.compile(): Average inference time: 12.470 seconds.\nWithout torch.compile(): Average inference time: 20.570 seconds.\n```\n\n### Memory optimization\n\nBy loading the T5 text encoder in 8 bits, you can run the pipeline in just under 6 GB of GPU VRAM. Refer to [this script](https://gist.github.com/sayakpaul/3154605f6af05b98a41081aaba5ca43e) for details.\n\nFurthermore, you can use the [`~HunyuanDiT2DModel.enable_forward_chunking`] method to reduce memory usage. Feed-forward chunking runs the feed-forward layers in a transformer block in a loop instead of all at once. 
This gives you a trade-off between memory consumption and inference runtime.\n\n```diff\n+ pipeline.transformer.enable_forward_chunking(chunk_size=1, dim=1)\n```\n\n\n## HunyuanDiTPipeline\n\n[[autodoc]] HunyuanDiTPipeline\n\t- all\n\t- __call__\n\n"
  },
  {
    "path": "docs/source/en/api/pipelines/hunyuanimage21.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# HunyuanImage2.1\n\n\nHunyuanImage-2.1 is a 17B text-to-image model that is capable of generating 2K (2048 x 2048) resolution images\n\nHunyuanImage-2.1 comes in the following variants:\n\n| model type | model id |\n|:----------:|:--------:|\n| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |\n| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |\n| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |\n\n> [!TIP]\n> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.\n\n## HunyuanImage-2.1\n\nHunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../../using-diffusers/guiders)) and does not take a `guidance_scale` parameter at runtime. 
To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.\n\n```python\nimport torch\nfrom diffusers import HunyuanImagePipeline\n\npipe = HunyuanImagePipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanImage-2.1-Diffusers\", \n    torch_dtype=torch.bfloat16\n)\npipe = pipe.to(\"cuda\")\n``` \n\nYou can inspect the `guider` object:\n\n```py\n>>> pipe.guider\nAdaptiveProjectedMixGuidance {\n  \"_class_name\": \"AdaptiveProjectedMixGuidance\",\n  \"_diffusers_version\": \"0.36.0.dev0\",\n  \"adaptive_projected_guidance_momentum\": -0.5,\n  \"adaptive_projected_guidance_rescale\": 10.0,\n  \"adaptive_projected_guidance_scale\": 10.0,\n  \"adaptive_projected_guidance_start_step\": 5,\n  \"enabled\": true,\n  \"eta\": 0.0,\n  \"guidance_rescale\": 0.0,\n  \"guidance_scale\": 3.5,\n  \"start\": 0.0,\n  \"stop\": 1.0,\n  \"use_original_formulation\": false\n}\n\nState:\n  step: None\n  num_inference_steps: None\n  timestep: None\n  count_prepared: 0\n  enabled: True\n  num_conditions: 2\n  momentum_buffer: None\n  is_apg_enabled: False\n  is_cfg_enabled: True\n```\n\nTo update the guider with a different configuration, use the `new()` method. 
For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:\n\n```py\nimport torch\nfrom diffusers import HunyuanImagePipeline\n\npipe = HunyuanImagePipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanImage-2.1-Diffusers\", \n    torch_dtype=torch.bfloat16\n)\npipe = pipe.to(\"cuda\")\n\n# Update the guider configuration\npipe.guider = pipe.guider.new(guidance_scale=5.0)\n\nprompt = (\n    \"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, \"\n    \"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a \"\n    \"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style.\"\n)\n\nimage = pipe(\n    prompt=prompt, \n    num_inference_steps=50, \n    height=2048, \n    width=2048,\n).images[0]\nimage.save(\"image.png\")\n```\n\n\n## HunyuanImage-2.1-Distilled\n\nUse `distilled_guidance_scale` with the guidance-distilled checkpoint:\n\n```py\nimport torch\nfrom diffusers import HunyuanImagePipeline\npipe = HunyuanImagePipeline.from_pretrained(\"hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers\", torch_dtype=torch.bfloat16)\npipe = pipe.to(\"cuda\")\n\nprompt = (\n    \"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, \"\n    \"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a \"\n    \"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style.\"\n)\n\n# Fixed seed for reproducible output (the seed value is arbitrary)\ngenerator = torch.Generator(device=\"cuda\").manual_seed(0)\n\nout = pipe(\n    prompt,\n    num_inference_steps=8,\n    distilled_guidance_scale=3.25,\n    height=2048,\n    width=2048,\n    generator=generator,\n).images[0]\n```\n\n\n## HunyuanImagePipeline\n\n[[autodoc]] HunyuanImagePipeline\n  - all\n  - __call__\n\n## HunyuanImageRefinerPipeline\n\n[[autodoc]] 
HunyuanImageRefinerPipeline\n  - all\n  - __call__\n\n\n## HunyuanImagePipelineOutput\n\n[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput"
  },
  {
    "path": "docs/source/en/api/pipelines/i2vgenxl.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# I2VGen-XL\n\n[I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models](https://hf.co/papers/2311.04145.pdf) by Shiwei Zhang, Jiayu Wang, Yingya Zhang, Kang Zhao, Hangjie Yuan, Zhiwu Qin, Xiang Wang, Deli Zhao, and Jingren Zhou.\n\nThe abstract from the paper is:\n\n*Video synthesis has recently made remarkable strides benefiting from the rapid development of diffusion models. However, it still encounters challenges in terms of semantic accuracy, clarity and spatio-temporal continuity. They primarily arise from the scarcity of well-aligned text-video data and the complex inherent structure of videos, making it difficult for the model to simultaneously ensure semantic and qualitative excellence. In this report, we propose a cascaded I2VGen-XL approach that enhances model performance by decoupling these two factors and ensures the alignment of the input data by utilizing static images as a form of crucial guidance. 
I2VGen-XL consists of two stages: i) the base stage guarantees coherent semantics and preserves content from input images by using two hierarchical encoders, and ii) the refinement stage enhances the video's details by incorporating an additional brief text and improves the resolution to 1280×720. To improve the diversity, we collect around 35 million single-shot text-video pairs and 6 billion text-image pairs to optimize the model. By this means, I2VGen-XL can simultaneously enhance the semantic accuracy, continuity of details and clarity of generated videos. Through extensive experiments, we have investigated the underlying principles of I2VGen-XL and compared it with current top methods, which can demonstrate its effectiveness on diverse data. The source code and models will be publicly available at [this https URL](https://i2vgen-xl.github.io/).*\n\nThe original codebase can be found [here](https://github.com/ali-vilab/i2vgen-xl/). The model checkpoints can be found [here](https://huggingface.co/ali-vilab/).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the [\"Reduce memory usage\"] section [here](../../using-diffusers/svd#reduce-memory-usage).\n\nSample output with I2VGenXL:\n\n<table>\n    <tr>\n        <td><center>\n        library.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/i2vgen-xl-example.gif\"\n            alt=\"library\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\n## Notes\n\n* I2VGenXL always uses a `clip_skip` value of 1. 
This means it leverages the penultimate layer representations from the text encoder of CLIP.\n* It can generate videos of quality that is often on par with [Stable Video Diffusion](../../using-diffusers/svd) (SVD).\n* Unlike SVD, it additionally accepts text prompts as inputs.\n* It can generate higher resolution videos.\n* When using the [`DDIMScheduler`] (which is the default for this pipeline), fewer than 50 inference steps leads to bad results.\n* This implementation is a 1-stage variant of I2VGenXL. The main figure in the [I2VGen-XL](https://huggingface.co/papers/2311.04145) paper shows a 2-stage variant; however, the 1-stage variant works well. See [this discussion](https://github.com/huggingface/diffusers/discussions/7952) for more details.\n\n## I2VGenXLPipeline\n[[autodoc]] I2VGenXLPipeline\n\t- all\n\t- __call__\n\n## I2VGenXLPipelineOutput\n[[autodoc]] pipelines.i2vgen_xl.pipeline_i2vgen_xl.I2VGenXLPipelineOutput"
  },
  {
    "path": "docs/source/en/api/pipelines/kandinsky.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\nhttp://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kandinsky 2.1\n\nKandinsky 2.1 is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Vladimir Arkhipkin](https://github.com/oriBetelgeuse), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey), and [Denis Dimitrov](https://github.com/denndimitrov).\n\nThe description from it's GitHub page is:\n\n*Kandinsky 2.1 inherits best practicies from Dall-E 2 and Latent diffusion, while introducing some new ideas. As text and image encoder it uses CLIP model and diffusion image prior (mapping) between latent spaces of CLIP modalities. 
This approach increases the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation.*\n\nThe original codebase can be found at [ai-forever/Kandinsky-2](https://github.com/ai-forever/Kandinsky-2).\n\n> [!TIP]\n> Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## KandinskyPriorPipeline\n\n[[autodoc]] KandinskyPriorPipeline\n\t- all\n\t- __call__\n\t- interpolate\n\n## KandinskyPipeline\n\n[[autodoc]] KandinskyPipeline\n\t- all\n\t- __call__\n\n## KandinskyCombinedPipeline\n\n[[autodoc]] KandinskyCombinedPipeline\n\t- all\n\t- __call__\n\n## KandinskyImg2ImgPipeline\n\n[[autodoc]] KandinskyImg2ImgPipeline\n\t- all\n\t- __call__\n\n## KandinskyImg2ImgCombinedPipeline\n\n[[autodoc]] KandinskyImg2ImgCombinedPipeline\n\t- all\n\t- __call__\n\n## KandinskyInpaintPipeline\n\n[[autodoc]] KandinskyInpaintPipeline\n\t- all\n\t- __call__\n\n## KandinskyInpaintCombinedPipeline\n\n[[autodoc]] KandinskyInpaintCombinedPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/kandinsky3.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\nhttp://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kandinsky 3\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nKandinsky 3 is created by [Vladimir Arkhipkin](https://github.com/oriBetelgeuse),[Anastasia Maltseva](https://github.com/NastyaMittseva),[Igor Pavlov](https://github.com/boomb0om),[Andrei Filatov](https://github.com/anvilarth),[Arseniy Shakhmatov](https://github.com/cene555),[Andrey Kuznetsov](https://github.com/kuznetsoffandrey),[Denis Dimitrov](https://github.com/denndimitrov), [Zein Shaheen](https://github.com/zeinsh)\n\nThe description from it's GitHub page:\n\n*Kandinsky 3.0 is an open-source text-to-image diffusion model built upon the Kandinsky2-x model family. In comparison to its predecessors, enhancements have been made to the text understanding and visual quality of the model, achieved by increasing the size of the text encoder and Diffusion U-Net models, respectively.*\n\nIts architecture includes 3 main components:\n1. [FLAN-UL2](https://huggingface.co/google/flan-ul2), which is an encoder decoder model based on the T5 architecture.\n2. New U-Net architecture featuring BigGAN-deep blocks doubles depth while maintaining the same number of parameters.\n3. 
Sber-MoVQGAN is a decoder proven to have superior results in image restoration.\n\n\n\nThe original codebase can be found at [ai-forever/Kandinsky-3](https://github.com/ai-forever/Kandinsky-3).\n\n> [!TIP]\n> Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.\n\n> [!TIP]\n> Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## Kandinsky3Pipeline\n\n[[autodoc]] Kandinsky3Pipeline\n\t- all\n\t- __call__\n\n## Kandinsky3Img2ImgPipeline\n\n[[autodoc]] Kandinsky3Img2ImgPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/kandinsky5_image.md",
    "content": "<!--Copyright 2025 The HuggingFace Team and Kandinsky Lab Team. All rights reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\nhttp://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kandinsky 5.0 Image\n\n[Kandinsky 5.0](https://arxiv.org/abs/2511.14993) is a family of diffusion models for Video & Image generation. \n\nKandinsky 5.0 Image Lite is a lightweight image generation model (6B parameters).\n\nThe model introduces several key innovations:\n- **Latent diffusion pipeline** with **Flow Matching** for improved training stability\n- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings\n- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding\n- **Flux VAE** for efficient image encoding and decoding\n\nThe original codebase can be found at [kandinskylab/Kandinsky-5](https://github.com/kandinskylab/Kandinsky-5).\n\n> [!TIP]\n> Check out the [Kandinsky Lab](https://huggingface.co/kandinskylab) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.\n\n\n## Available Models\n\nKandinsky 5.0 Image Lite:\n\n| model_id | Description | Use Cases |\n|------------|-------------|-----------|\n| [**kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers) | 6B image Supervised Fine-Tuned model | Highest generation quality |\n| 
[**kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers) | 6B image editing Supervised Fine-Tuned model | Highest generation quality |\n| [**kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers) | 6B image Base pretrained model | Research and fine-tuning |\n| [**kandinskylab/Kandinsky-5.0-I2I-Lite-pretrain-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-I2I-Lite-pretrain-Diffusers) | 6B image editing Base pretrained model | Research and fine-tuning |\n\n## Usage Examples\n\n### Basic Text-to-Image Generation\n\n```python\nimport torch\nfrom diffusers import Kandinsky5T2IPipeline\n\n# Load the pipeline\nmodel_id = \"kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers\"\npipe = Kandinsky5T2IPipeline.from_pretrained(model_id)\n_ = pipe.to(device='cuda',dtype=torch.bfloat16)\n\n# Generate image\nprompt = \"A fluffy, expressive cat wearing a bright red hat with a soft, slightly textured fabric. The hat should look cozy and well-fitted on the cat’s head. On the front of the hat, add clean, bold white text that reads “SWEET”, clearly visible and neatly centered. 
Ensure the overall lighting highlights the hat’s color and the cat’s fur details.\"\n\noutput = pipe(\n    prompt=prompt,\n    negative_prompt=\"\",\n    height=1024,\n    width=1024,\n    num_inference_steps=50,\n    guidance_scale=3.5,\n).image[0]\n```\n\n### Basic Image-to-Image Generation\n\n```python\nimport torch\nfrom diffusers import Kandinsky5I2IPipeline\nfrom diffusers.utils import load_image \n# Load the pipeline\nmodel_id = \"kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers\"\npipe = Kandinsky5I2IPipeline.from_pretrained(model_id)\n\n_ = pipe.to(device='cuda',dtype=torch.bfloat16)\npipe.enable_model_cpu_offload()                                               # <--- Enable CPU offloading for single GPU inference\n\n# Edit the input image\nimage = load_image(\n    \"https://huggingface.co/kandinsky-community/kandinsky-3/resolve/main/assets/title.jpg?download=true\"\n)\n\nprompt = \"Change the background from a winter night scene to a bright summer day. Place the character on a sandy beach with clear blue sky, soft sunlight, and gentle waves in the distance. Replace the winter clothing with a light short-sleeved T-shirt (in soft pastel colors) and casual shorts. Ensure the character’s fur reflects warm daylight instead of cold winter tones. Add small beach details such as seashells, footprints in the sand, and a few scattered beach toys nearby. 
Keep the oranges in the scene, but place them naturally on the sand.\"\nnegative_prompt = \"\"\n\noutput = pipe(\n    image=image,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    guidance_scale=3.5,\n).image[0]\n```\n\n\n## Kandinsky5T2IPipeline\n\n[[autodoc]] Kandinsky5T2IPipeline\n    - all\n    - __call__\n\n## Kandinsky5I2IPipeline\n\n[[autodoc]] Kandinsky5I2IPipeline\n    - all\n    - __call__\n\n\n## Citation\n```bibtex\n@misc{kandinsky2025,\n    author = {Alexander Belykh and Alexander Varlamov and Alexey Letunovskiy and Anastasia Aliaskina and Anastasia Maltseva and Anastasiia Kargapoltseva and Andrey Shutkin and Anna Averchenkova and Anna Dmitrienko and Bulat Akhmatov and Denis Dimitrov and Denis Koposov and Denis Parkhomenko and Dmitrii and Ilya Vasiliev and Ivan Kirillov and Julia Agafonova and Kirill Chernyshev and Kormilitsyn Semen and Lev Novitskiy and Maria Kovaleva and Mikhail Mamaev and Mikhailov and Nikita Kiselev and Nikita Osterov and Nikolai Gerasimenko and Nikolai Vaulin and Olga Kim and Olga Vdovchenko and Polina Gavrilova and Polina Mikhailova and Tatiana Nikulina and Viacheslav Vasilev and Vladimir Arkhipkin and Vladimir Korviakov and Vladimir Polovnikov and Yury Kolabushin},\n    title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},\n    howpublished = {\\url{https://github.com/kandinskylab/Kandinsky-5}},\n    year = 2025\n}\n```\n"
  },
  {
    "path": "docs/source/en/api/pipelines/kandinsky5_video.md",
    "content": "<!--Copyright 2025 The HuggingFace Team Kandinsky Lab Team. All rights reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\nhttp://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kandinsky 5.0 Video\n\n[Kandinsky 5.0](https://arxiv.org/abs/2511.14993) is a family of diffusion models for Video & Image generation.\n\nKandinsky 5.0 Lite line-up of lightweight video generation models (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.\n\nKandinsky 5.0 Pro line-up of large high quality video generation models (19B parameters). 
It offers high-quality generation in HD and more generation formats like I2V.\n\nThe model introduces several key innovations:\n- **Latent diffusion pipeline** with **Flow Matching** for improved training stability\n- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings\n- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding\n- **HunyuanVideo 3D VAE** for efficient video encoding and decoding\n- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing\n\nThe original codebase can be found at [kandinskylab/Kandinsky-5](https://github.com/kandinskylab/Kandinsky-5).\n\n> [!TIP]\n> Check out the [Kandinsky Lab](https://huggingface.co/kandinskylab) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.\n\n## Available Models\n\nKandinsky 5.0 Pro:\n\n| model_id | Description | Use Cases |\n|------------|-------------|-----------|\n| **kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers** | 5 second Text-to-Video Pro model | High-quality text-to-video generation |\n| **kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers** | 5 second Image-to-Video Pro model | High-quality image-to-video generation |\n\nKandinsky 5.0 T2V Lite:\n\n| model_id | Description | Use Cases |\n|------------|-------------|-----------|\n| **kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |\n| **kandinskylab/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |\n| **kandinskylab/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |\n| **kandinskylab/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |\n| 
**kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |\n| **kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |\n| **kandinskylab/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |\n| **kandinskylab/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |\n\n\n## Usage Examples\n\n### Basic Text-to-Video Generation\n\n#### Pro\n**⚠️ Warning!** All Pro models should be run with `pipe.enable_model_cpu_offload()`.  \n```python\nimport torch\nfrom diffusers import Kandinsky5T2VPipeline\nfrom diffusers.utils import export_to_video\n\n# Load the pipeline\nmodel_id = \"kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers\"\npipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)\n\npipe = pipe.to(\"cuda\")\npipe.transformer.set_attention_backend(\"flex\")                            # <--- Set attention backend to Flex\npipe.enable_model_cpu_offload()                                           # <--- Enable CPU offloading for single GPU inference\npipe.transformer.compile(mode=\"max-autotune-no-cudagraphs\", dynamic=True) # <--- Compile with max-autotune-no-cudagraphs\n\n# Generate video\nprompt = \"A cat and a dog baking a cake together in a kitchen.\"\nnegative_prompt = \"Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards\"\n\noutput = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=768,\n    width=1024,\n    num_frames=121,  # ~5 seconds at 24fps\n    num_inference_steps=50,\n    guidance_scale=5.0,\n).frames[0]\n\nexport_to_video(output, \"output.mp4\", fps=24, quality=9)\n```\n\n#### Lite\n```python\nimport 
torch\nfrom diffusers import Kandinsky5T2VPipeline\nfrom diffusers.utils import export_to_video\n\n# Load the pipeline\nmodel_id = \"kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers\"\npipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)\npipe = pipe.to(\"cuda\")\n\n# Generate video\nprompt = \"A cat and a dog baking a cake together in a kitchen.\"\nnegative_prompt = \"Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards\"\n\noutput = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=512,\n    width=768,\n    num_frames=121,  # ~5 seconds at 24fps\n    num_inference_steps=50,\n    guidance_scale=5.0,\n).frames[0]\n\nexport_to_video(output, \"output.mp4\", fps=24, quality=9)\n```\n\n### 10 second Models\n**⚠️ Warning!** All 10 second models should be used with Flex attention and `max-autotune-no-cudagraphs` compilation:\n\n```python\npipe = Kandinsky5T2VPipeline.from_pretrained(\n    \"kandinskylab/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers\",\n    torch_dtype=torch.bfloat16\n)\npipe = pipe.to(\"cuda\")\n\npipe.transformer.set_attention_backend(\n    \"flex\"\n)                                       # <--- Set attention backend to Flex\npipe.transformer.compile(\n    mode=\"max-autotune-no-cudagraphs\",\n    dynamic=True\n)                                       # <--- Compile with max-autotune-no-cudagraphs\n\nprompt = \"A cat and a dog baking a cake together in a kitchen.\"\nnegative_prompt = \"Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards\"\n\noutput = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=512,\n    width=768,\n    num_frames=241,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n).frames[0]\n\nexport_to_video(output, \"output.mp4\", fps=24, quality=9)\n```\n\n### Diffusion Distilled model\n**⚠️ Warning!** All nocfg and 
diffusion distilled models should be run without CFG (`guidance_scale=1.0`):\n\n```python\nmodel_id = \"kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers\"\npipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)\npipe = pipe.to(\"cuda\")\n\noutput = pipe(\n    prompt=\"A beautiful sunset over mountains\",\n    num_inference_steps=16,  # <--- Model is distilled to 16 steps\n    guidance_scale=1.0,      # <--- no CFG\n).frames[0]\n\nexport_to_video(output, \"output.mp4\", fps=24, quality=9)\n```\n\n\n### Basic Image-to-Video Generation\n**⚠️ Warning!** All Pro models should be run with `pipe.enable_model_cpu_offload()`.  \n```python\nimport torch\nfrom diffusers import Kandinsky5I2VPipeline\nfrom diffusers.utils import export_to_video, load_image\n\n# Load the pipeline\nmodel_id = \"kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers\"\npipe = Kandinsky5I2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)\n\npipe = pipe.to(\"cuda\")\npipe.transformer.set_attention_backend(\"flex\")                            # <--- Set attention backend to Flex\npipe.enable_model_cpu_offload()                                           # <--- Enable CPU offloading for single-GPU inference\npipe.transformer.compile(mode=\"max-autotune-no-cudagraphs\", dynamic=True) # <--- Compile with max-autotune-no-cudagraphs\n\n# Load the conditioning image and resize it to the target resolution\nimage = load_image(\n    \"https://huggingface.co/kandinsky-community/kandinsky-3/resolve/main/assets/title.jpg?download=true\"\n)\nheight = 896\nwidth = 896\nimage = image.resize((width, height))\n\n# Generate video\nprompt = \"A funny furry creature smiles happily and holds a sign that says 'Kandinsky'\"\nnegative_prompt = \"\"\n\noutput = pipe(\n    image=image,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    num_frames=121,  # ~5 seconds at 24fps\n    num_inference_steps=50,\n    guidance_scale=5.0,\n).frames[0]\n\nexport_to_video(output, \"output.mp4\", fps=24, 
quality=9)\n```\n\n\n\n## Kandinsky 5.0 Pro Side-by-Side evaluation\n\n<table border=\"0\" style=\"width: 200; text-align: left; margin-top: 20px;\">\n  <tr>\n      <td>\n          <img width=\"200\" alt=\"image\" src=\"https://github.com/user-attachments/assets/73e5ff00-2735-40fd-8f01-767de9181918\" />\n      </td>\n      <td>\n         <img width=\"200\" alt=\"image\" src=\"https://github.com/user-attachments/assets/f449a9e7-74b7-481d-82da-02723e396acd\" />\n      </td>\n\n  <tr>\n      <td>\n          Comparison with Veo 3 \n      </td>\n      <td>\n          Comparison with Veo 3 fast\n      </td>\n  <tr>\n      <td>\n          <img width=\"200\" alt=\"image\" src=\"https://github.com/user-attachments/assets/a6902fb6-b5e8-4093-adad-aa4caab79c6d\" />\n      </td>\n      <td>\n          <img width=\"200\" alt=\"image\" src=\"https://github.com/user-attachments/assets/09986015-3d07-4de8-b942-c145039b9b2d\" />\n      </td>\n  <tr>\n      <td>\n          Comparison with Wan 2.2 A14B Text-to-Video mode\n      </td>\n      <td>\n          Comparison with Wan 2.2 A14B Image-to-Video mode\n      </td>\n\n</table>\n\n\n## Kandinsky 5.0 Lite Side-by-Side evaluation\n\nThe evaluation is based on the expanded prompts from the [Movie Gen benchmark](https://github.com/facebookresearch/MovieGenBench), which are available in the expanded_prompt column of the benchmark/moviegen_bench.csv file.\n\n<table border=\"0\" style=\"width: 400; text-align: left; margin-top: 20px;\">\n  <tr>\n      <td>\n          <img src=\"https://github.com/kandinskylab/kandinsky-5/raw/main/assets/sbs/kandinsky_5_video_lite_vs_sora.jpg\" width=400 >\n      </td>\n      <td>\n          <img src=\"https://github.com/kandinskylab/kandinsky-5/raw/main/assets/sbs/kandinsky_5_video_lite_vs_wan_2.1_14B.jpg\" width=400 >\n      </td>\n  <tr>\n      <td>\n          <img src=\"https://github.com/kandinskylab/kandinsky-5/raw/main/assets/sbs/kandinsky_5_video_lite_vs_wan_2.2_5B.jpg\" width=400 >\n      </td>\n     
 <td>\n          <img src=\"https://github.com/kandinskylab/kandinsky-5/raw/main/assets/sbs/kandinsky_5_video_lite_vs_wan_2.2_A14B.jpg\" width=400 >\n      </td>\n  <tr>\n      <td>\n          <img src=\"https://github.com/kandinskylab/kandinsky-5/raw/main/assets/sbs/kandinsky_5_video_lite_vs_wan_2.1_1.3B.jpg\" width=400 >\n      </td>\n\n</table>\n\n\n\n\n## Kandinsky 5.0 Lite Distill Side-by-Side evaluation\n\n<table border=\"0\" style=\"width: 400; text-align: left; margin-top: 20px;\">\n  <tr>\n      <td>\n          <img src=\"https://github.com/kandinskylab/kandinsky-5/raw/main/assets/sbs/kandinsky_5_video_lite_5s_vs_kandinsky_5_video_lite_distill_5s.jpg\" width=400 >\n      </td>\n      <td>\n          <img src=\"https://github.com/kandinskylab/kandinsky-5/raw/main/assets/sbs/kandinsky_5_video_lite_10s_vs_kandinsky_5_video_lite_distill_10s.jpg\" width=400 >\n      </td>\n\n</table>\n\n## Kandinsky5T2VPipeline\n\n[[autodoc]] Kandinsky5T2VPipeline\n    - all\n    - __call__\n\n## Kandinsky5I2VPipeline\n\n[[autodoc]] Kandinsky5I2VPipeline\n    - all\n    - __call__\n\n\n## Citation\n```bibtex\n@misc{kandinsky2025,\n    author = {Alexander Belykh and Alexander Varlamov and Alexey Letunovskiy and Anastasia Aliaskina and Anastasia Maltseva and Anastasiia Kargapoltseva and Andrey Shutkin and Anna Averchenkova and Anna Dmitrienko and Bulat Akhmatov and Denis Dimitrov and Denis Koposov and Denis Parkhomenko and Dmitrii and Ilya Vasiliev and Ivan Kirillov and Julia Agafonova and Kirill Chernyshev and Kormilitsyn Semen and Lev Novitskiy and Maria Kovaleva and Mikhail Mamaev and Mikhailov and Nikita Kiselev and Nikita Osterov and Nikolai Gerasimenko and Nikolai Vaulin and Olga Kim and Olga Vdovchenko and Polina Gavrilova and Polina Mikhailova and Tatiana Nikulina and Viacheslav Vasilev and Vladimir Arkhipkin and Vladimir Korviakov and Vladimir Polovnikov and Yury Kolabushin},\n    title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},\n    
howpublished = {\\url{https://github.com/kandinskylab/Kandinsky-5}},\n    year = 2025\n}\n```\n"
  },
  {
    "path": "docs/source/en/api/pipelines/kandinsky_v22.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\nhttp://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kandinsky 2.2\n\nKandinsky 2.2 is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Vladimir Arkhipkin](https://github.com/oriBetelgeuse), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey), and [Denis Dimitrov](https://github.com/denndimitrov).\n\nThe description from it's GitHub page is:\n\n*Kandinsky 2.2 brings substantial improvements upon its predecessor, Kandinsky 2.1, by introducing a new, more powerful image encoder - CLIP-ViT-G and the ControlNet support. The switch to CLIP-ViT-G as the image encoder significantly increases the model's capability to generate more aesthetic pictures and better understand text, thus enhancing the model's overall performance. The addition of the ControlNet mechanism allows the model to effectively control the process of generating images. 
This leads to more accurate and visually appealing outputs and opens new possibilities for text-guided image manipulation.*\n\nThe original codebase can be found at [ai-forever/Kandinsky-2](https://github.com/ai-forever/Kandinsky-2).\n\n> [!TIP]\n> Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.\n\n> [!TIP]\n> Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## KandinskyV22PriorPipeline\n\n[[autodoc]] KandinskyV22PriorPipeline\n\t- all\n\t- __call__\n\t- interpolate\n\n## KandinskyV22Pipeline\n\n[[autodoc]] KandinskyV22Pipeline\n\t- all\n\t- __call__\n\n## KandinskyV22CombinedPipeline\n\n[[autodoc]] KandinskyV22CombinedPipeline\n\t- all\n\t- __call__\n\n## KandinskyV22ControlnetPipeline\n\n[[autodoc]] KandinskyV22ControlnetPipeline\n\t- all\n\t- __call__\n\n## KandinskyV22PriorEmb2EmbPipeline\n\n[[autodoc]] KandinskyV22PriorEmb2EmbPipeline\n\t- all\n\t- __call__\n\t- interpolate\n\n## KandinskyV22Img2ImgPipeline\n\n[[autodoc]] KandinskyV22Img2ImgPipeline\n\t- all\n\t- __call__\n\n## KandinskyV22Img2ImgCombinedPipeline\n\n[[autodoc]] KandinskyV22Img2ImgCombinedPipeline\n\t- all\n\t- __call__\n\n## KandinskyV22ControlnetImg2ImgPipeline\n\n[[autodoc]] KandinskyV22ControlnetImg2ImgPipeline\n\t- all\n\t- __call__\n\n## KandinskyV22InpaintPipeline\n\n[[autodoc]] KandinskyV22InpaintPipeline\n\t- all\n\t- __call__\n\n## KandinskyV22InpaintCombinedPipeline\n\n[[autodoc]] KandinskyV22InpaintCombinedPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/kolors.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kolors: Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/kolors_header_collage.png)\n\nKolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](https://github.com/Kwai-Kolors/Kolors). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. 
For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).\n\nThe abstract from the technical report is:\n\n*We present Kolors, a latent diffusion model for text-to-image synthesis, characterized by its profound understanding of both English and Chinese, as well as an impressive degree of photorealism. There are three key insights contributing to the development of Kolors. Firstly, unlike large language model T5 used in Imagen and Stable Diffusion 3, Kolors is built upon the General Language Model (GLM), which enhances its comprehension capabilities in both English and Chinese. Moreover, we employ a multimodal large language model to recaption the extensive training dataset for fine-grained text understanding. These strategies significantly improve Kolors’ ability to comprehend intricate semantics, particularly those involving multiple entities, and enable its advanced text rendering capabilities. Secondly, we divide the training of Kolors into two phases: the concept learning phase with broad knowledge and the quality improvement phase with specifically curated high-aesthetic data. Furthermore, we investigate the critical role of the noise schedule and introduce a novel schedule to optimize high-resolution image generation. These strategies collectively enhance the visual appeal of the generated high-resolution images. Lastly, we propose a category-balanced benchmark KolorsPrompts, which serves as a guide for the training and evaluation of Kolors. Consequently, even when employing the commonly used U-Net backbone, Kolors has demonstrated remarkable performance in human evaluations, surpassing the existing open-source models and achieving Midjourney-v6 level performance, especially in terms of visual appeal. 
We will release the code and weights of Kolors at <https://github.com/Kwai-Kolors/Kolors>, and hope that it will benefit future research and applications in the visual generation community.*\n\n## Usage Example\n\n```python\nimport torch\n\nfrom diffusers import DPMSolverMultistepScheduler, KolorsPipeline\n\npipe = KolorsPipeline.from_pretrained(\"Kwai-Kolors/Kolors-diffusers\", torch_dtype=torch.float16, variant=\"fp16\")\npipe.to(\"cuda\")\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)\n\nimage = pipe(\n    prompt='一张瓢虫的照片，微距，变焦，高质量，电影，拿着一个牌子，写着\"可图\"',\n    negative_prompt=\"\",\n    guidance_scale=6.5,\n    num_inference_steps=25,\n).images[0]\n\nimage.save(\"kolors_sample.png\")\n```\n\n### IP Adapter\n\nKolors needs a different IP Adapter to work, and it uses [Openai-CLIP-336](https://huggingface.co/openai/clip-vit-large-patch14-336) as an image encoder.\n\n> [!TIP]\n> Using an IP Adapter with Kolors requires more than 24GB of VRAM. To use it, we recommend using [`~DiffusionPipeline.enable_model_cpu_offload`] on consumer GPUs.\n\n> [!TIP]\n> While Kolors is integrated in Diffusers, you need to load the image encoder from a revision to use the safetensor files. 
You can still use the main branch of the original repository if you're comfortable loading pickle checkpoints.\n\n```python\nimport torch\nfrom transformers import CLIPVisionModelWithProjection\n\nfrom diffusers import DPMSolverMultistepScheduler, KolorsPipeline\nfrom diffusers.utils import load_image\n\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(\n    \"Kwai-Kolors/Kolors-IP-Adapter-Plus\",\n    subfolder=\"image_encoder\",\n    low_cpu_mem_usage=True,\n    torch_dtype=torch.float16,\n    revision=\"refs/pr/4\",\n)\n\npipe = KolorsPipeline.from_pretrained(\n    \"Kwai-Kolors/Kolors-diffusers\", image_encoder=image_encoder, torch_dtype=torch.float16, variant=\"fp16\"\n)\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)\n\npipe.load_ip_adapter(\n    \"Kwai-Kolors/Kolors-IP-Adapter-Plus\",\n    subfolder=\"\",\n    weight_name=\"ip_adapter_plus_general.safetensors\",\n    revision=\"refs/pr/4\",\n    image_encoder_folder=None,\n)\npipe.enable_model_cpu_offload()\n\nipa_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/cat_square.png\")\n\nimage = pipe(\n    prompt=\"best quality, high quality\",\n    negative_prompt=\"\",\n    guidance_scale=6.5,\n    num_inference_steps=25,\n    ip_adapter_image=ipa_image,\n).images[0]\n\nimage.save(\"kolors_ipa_sample.png\")\n```\n\n## KolorsPipeline\n\n[[autodoc]] KolorsPipeline\n\n- all\n- __call__\n\n## KolorsImg2ImgPipeline\n\n[[autodoc]] KolorsImg2ImgPipeline\n\n- all\n- __call__\n\n"
  },
  {
    "path": "docs/source/en/api/pipelines/latent_consistency_models.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Latent Consistency Models\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nLatent Consistency Models (LCMs) were proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://huggingface.co/papers/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.\n\nThe abstract of the paper is as follows:\n\n*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (rombach et al). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. 
Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference. Project Page: [this https URL](https://latent-consistency-models.github.io/).*\n\nA demo for the [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) checkpoint can be found [here](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).\n\nThe pipelines were contributed by [luosiallen](https://luosiallen.github.io/), [nagolinc](https://github.com/nagolinc), and [dg845](https://github.com/dg845).\n\n\n## LatentConsistencyModelPipeline\n\n[[autodoc]] LatentConsistencyModelPipeline\n    - all\n    - __call__\n    - enable_freeu\n    - disable_freeu\n    - enable_vae_slicing\n    - disable_vae_slicing\n    - enable_vae_tiling\n    - disable_vae_tiling\n\n## LatentConsistencyModelImg2ImgPipeline\n\n[[autodoc]] LatentConsistencyModelImg2ImgPipeline\n    - all\n    - __call__\n    - enable_freeu\n    - disable_freeu\n    - enable_vae_slicing\n    - disable_vae_slicing\n    - enable_vae_tiling\n    - disable_vae_tiling\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/latent_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Latent Diffusion\n\nLatent Diffusion was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.\n\nThe abstract from the paper is:\n\n*By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. 
By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs.*\n\nThe original codebase can be found at [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## LDMTextToImagePipeline\n[[autodoc]] LDMTextToImagePipeline\n\t- all\n\t- __call__\n\n## LDMSuperResolutionPipeline\n[[autodoc]] LDMSuperResolutionPipeline\n\t- all\n\t- __call__\n\n## ImagePipelineOutput\n[[autodoc]] pipelines.ImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/latte.md",
    "content": "<!-- # Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# Latte\n\n![latte text-to-video](https://github.com/Vchitect/Latte/blob/52bc0029899babbd6e9250384c83d8ed2670ff7a/visuals/latte.gif?raw=true)\n\n[Latte: Latent Diffusion Transformer for Video Generation](https://huggingface.co/papers/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.\n\nThe abstract from the paper is:\n\n*We propose a novel Latent Diffusion Transformer, namely Latte, for video generation. Latte first extracts spatio-temporal tokens from input videos and then adopts a series of Transformer blocks to model video distribution in the latent space. In order to model a substantial number of tokens extracted from videos, four efficient variants are introduced from the perspective of decomposing the spatial and temporal dimensions of input videos. To improve the quality of generated videos, we determine the best practices of Latte through rigorous experimental analysis, including video clip patch embedding, model variants, timestep-class information injection, temporal positional embedding, and learning strategies. Our comprehensive evaluation demonstrates that Latte achieves state-of-the-art performance across four standard video generation datasets, i.e., FaceForensics, SkyTimelapse, UCF101, and Taichi-HD. 
In addition, we extend Latte to text-to-video generation (T2V) task, where Latte achieves comparable results compared to recent T2V models. We strongly believe that Latte provides valuable insights for future research on incorporating Transformers into diffusion models for video generation.*\n\n**Highlights**: Latte is a latent diffusion transformer proposed as a backbone for modeling different modalities (trained for text-to-video generation here). It achieves state-of-the-art performance across four standard video benchmarks - [FaceForensics](https://huggingface.co/papers/1803.09179), [SkyTimelapse](https://huggingface.co/papers/1709.07592), [UCF101](https://huggingface.co/papers/1212.0402) and [Taichi-HD](https://huggingface.co/papers/2003.00196). To prepare and download the datasets for evaluation, please refer to [this https URL](https://github.com/Vchitect/Latte/blob/main/docs/datasets_evaluation.md).\n\nThis pipeline was contributed by [maxin-cn](https://github.com/maxin-cn). The original codebase can be found [here](https://github.com/Vchitect/Latte). 
The original weights can be found under [hf.co/maxin-cn](https://huggingface.co/maxin-cn).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n### Inference\n\nUse [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.\n\nFirst, load the pipeline:\n\n```python\nimport torch\nfrom diffusers import LattePipeline\n\npipeline = LattePipeline.from_pretrained(\n\t\"maxin-cn/Latte-1\", torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\nThen change the memory layout of the pipeline's `transformer` and `vae` components to `torch.channels_last`:\n\n```python\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.vae.to(memory_format=torch.channels_last)\n```\n\nFinally, compile the components and run inference:\n\n```python\npipeline.transformer = torch.compile(pipeline.transformer)\npipeline.vae.decode = torch.compile(pipeline.vae.decode)\n\nvideo = pipeline(prompt=\"A dog wearing sunglasses floating in space, surreal, nebulae in background\").frames[0]\n```\n\nThe [benchmark](https://gist.github.com/a-r-r-o-w/4e1694ca46374793c0361d740a99ff19) results on an 80GB A100 machine are:\n\n```\nWithout torch.compile(): Average inference time: 16.246 seconds.\nWith torch.compile(): Average inference time: 14.573 seconds.\n```\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. 
However, quantization may have varying impact on video quality depending on the video model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LattePipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LatteTransformer3DModel, LattePipeline\nfrom diffusers.utils import export_to_gif\nfrom transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = T5EncoderModel.from_pretrained(\n    \"maxin-cn/Latte-1\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = LatteTransformer3DModel.from_pretrained(\n    \"maxin-cn/Latte-1\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = LattePipeline.from_pretrained(\n    \"maxin-cn/Latte-1\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nprompt = \"A small cactus with a happy face in the Sahara desert.\"\nvideo = pipeline(prompt).frames[0]\nexport_to_gif(video, \"latte.gif\")\n```\n\n## LattePipeline\n\n[[autodoc]] LattePipeline\n  - all\n  - __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/ledits_pp.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LEDITS++\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nLEDITS++ was proposed in [LEDITS++: Limitless Image Editing using Text-to-Image Models](https://huggingface.co/papers/2311.16711) by Manuel Brack, Felix Friedrich, Katharina Kornmeier, Linoy Tsaban, Patrick Schramowski, Kristian Kersting, Apolinário Passos.\n\nThe abstract from the paper is:\n\n*Text-to-image diffusion models have recently received increasing interest for their astonishing ability to produce high-fidelity images from solely text inputs. Subsequent research efforts aim to exploit and apply their capabilities to real image editing. However, existing image-to-image methods are often inefficient, imprecise, and of limited versatility. They either require time-consuming fine-tuning, deviate unnecessarily strongly from the input image, and/or lack support for multiple, simultaneous edits. To address these issues, we introduce LEDITS++, an efficient yet versatile and precise textual image manipulation technique. LEDITS++'s novel inversion approach requires no tuning nor optimization and produces high-fidelity results with a few diffusion steps. Second, our methodology supports multiple simultaneous edits and is architecture-agnostic. 
Third, we use a novel implicit masking technique that limits changes to relevant image regions. We propose the novel TEdBench++ benchmark as part of our exhaustive evaluation. Our results demonstrate the capabilities of LEDITS++ and its improvements over previous methods. The project page is available at https://leditsplusplus-project.static.hf.space .*\n\n> [!TIP]\n> You can find additional information about LEDITS++ on the [project page](https://leditsplusplus-project.static.hf.space/index.html) and try it out in a [demo](https://huggingface.co/spaces/editing-images/leditsplusplus).\n\n> [!WARNING]\n> Due to backward compatibility issues with the current Diffusers implementation of [`~schedulers.DPMSolverMultistepScheduler`], this implementation of LEDITS++ can no longer guarantee perfect inversion.\n> This issue is unlikely to have any noticeable effect on applied use cases. However, we provide an alternative implementation that guarantees perfect inversion in a dedicated [GitHub repo](https://github.com/ml-research/ledits_pp).\n\nWe provide two distinct pipelines based on different pre-trained models.\n\n## LEditsPPPipelineStableDiffusion\n[[autodoc]] pipelines.ledits_pp.LEditsPPPipelineStableDiffusion\n\t- all\n\t- __call__\n\t- invert\n\n## LEditsPPPipelineStableDiffusionXL\n[[autodoc]] pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL\n\t- all\n\t- __call__\n\t- invert\n\n\n\n## LEditsPPDiffusionPipelineOutput\n[[autodoc]] pipelines.ledits_pp.pipeline_output.LEditsPPDiffusionPipelineOutput\n\t- all\n\n## LEditsPPInversionPipelineOutput\n[[autodoc]] pipelines.ledits_pp.pipeline_output.LEditsPPInversionPipelineOutput\n\t- all\n"
  },
  {
    "path": "docs/source/en/api/pipelines/longcat_image.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LongCat-Image\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n\nWe introduce LongCat-Image, a pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models.\n\n\n### Key Features\n- 🌟 **Exceptional Efficiency and Performance**: With only **6B parameters**, LongCat-Image surpasses numerous open-source models that are several times larger across multiple benchmarks, demonstrating the immense potential of efficient model design.\n- 🌟 **Superior Editing Performance**: LongCat-Image-Edit model achieves state-of-the-art performance among open-source models, delivering leading instruction-following and image quality with superior visual consistency.\n- 🌟 **Powerful Chinese Text Rendering**: LongCat-Image demonstrates superior accuracy and stability in rendering common Chinese characters compared to existing SOTA open-source models and achieves industry-leading coverage of the Chinese dictionary.\n- 🌟 **Remarkable Photorealism**: Through an innovative data strategy and training framework, LongCat-Image achieves remarkable photorealism in generated images.\n- 🌟 **Comprehensive 
Open-Source Ecosystem**: We provide a complete toolchain, from intermediate checkpoints to full training code, significantly lowering the barrier for further research and development.\n\nFor more details, please refer to the comprehensive [***LongCat-Image Technical Report***](https://arxiv.org/abs/2412.11963).\n\n\n## Usage Example\n\n```py\nimport torch\nfrom diffusers import LongCatImagePipeline\n\nweight_dtype = torch.bfloat16\npipe = LongCatImagePipeline.from_pretrained(\"meituan-longcat/LongCat-Image\", torch_dtype=weight_dtype)\npipe.to(\"cuda\")\n# Alternative to `pipe.to(\"cuda\")` that lowers VRAM usage:\n# pipe.enable_model_cpu_offload()\n\n# Chinese prompt describing a young woman in a yellow knit sweater in front of a sunlit brick wall\nprompt = '一个年轻的亚裔女性，身穿黄色针织衫，搭配白色项链。她的双手放在膝盖上，表情恬静。背景是一堵粗糙的砖墙，午后的阳光温暖地洒在她身上，营造出一种宁静而温馨的氛围。镜头采用中距离视角，突出她的神态和服饰的细节。光线柔和地打在她的脸上，强调她的五官和饰品的质感，增加画面的层次感与亲和力。整个画面构图简洁，砖墙的纹理与阳光的光影效果相得益彰，突显出人物的优雅与从容。'\nimage = pipe(\n    prompt,\n    height=768,\n    width=1344,\n    guidance_scale=4.0,\n    num_inference_steps=50,\n    num_images_per_prompt=1,\n    generator=torch.Generator(\"cpu\").manual_seed(43),\n    enable_cfg_renorm=True,\n    enable_prompt_rewrite=True,\n).images[0]\nimage.save(\"./longcat_image_t2i_example.png\")\n```\n\n\nThis pipeline was contributed by the LongCat-Image Team. 
The original codebase can be found [here](https://github.com/meituan-longcat/LongCat-Image).\n\nAvailable models:\n<div style=\"overflow-x: auto; margin-bottom: 16px;\">\n  <table style=\"border-collapse: collapse; width: 100%;\">\n    <thead>\n      <tr>\n        <th style=\"white-space: nowrap; padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;\">Models</th>\n        <th style=\"white-space: nowrap; padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;\">Type</th>\n        <th style=\"padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;\">Description</th>\n        <th style=\"padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;\">Download Link</th>\n      </tr>\n    </thead>\n    <tbody>\n      <tr>\n        <td style=\"white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;\">LongCat&#8209;Image</td>\n        <td style=\"white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;\">Text&#8209;to&#8209;Image</td>\n        <td style=\"padding: 8px; border: 1px solid #d0d7de;\">Final Release. The standard model for out&#8209;of&#8209;the&#8209;box inference.</td>\n        <td style=\"padding: 8px; border: 1px solid #d0d7de;\">\n          <span style=\"white-space: nowrap;\">🤗&nbsp;<a href=\"https://huggingface.co/meituan-longcat/LongCat-Image\">Huggingface</a></span>\n        </td>\n      </tr>\n      <tr>\n        <td style=\"white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;\">LongCat&#8209;Image&#8209;Dev</td>\n        <td style=\"white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;\">Text&#8209;to&#8209;Image</td>\n        <td style=\"padding: 8px; border: 1px solid #d0d7de;\">Development. 
Mid-training checkpoint, suitable for fine-tuning.</td>\n        <td style=\"padding: 8px; border: 1px solid #d0d7de;\">\n          <span style=\"white-space: nowrap;\">🤗&nbsp;<a href=\"https://huggingface.co/meituan-longcat/LongCat-Image-Dev\">Huggingface</a></span>\n        </td>\n      </tr>\n      <tr>\n        <td style=\"white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;\">LongCat&#8209;Image&#8209;Edit</td>\n        <td style=\"white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;\">Image Editing</td>\n        <td style=\"padding: 8px; border: 1px solid #d0d7de;\">Specialized model for image editing.</td>\n        <td style=\"padding: 8px; border: 1px solid #d0d7de;\">\n          <span style=\"white-space: nowrap;\">🤗&nbsp;<a href=\"https://huggingface.co/meituan-longcat/LongCat-Image-Edit\">Huggingface</a></span>\n        </td>\n      </tr>\n    </tbody>\n  </table>\n</div>\n\n## LongCatImagePipeline\n\n[[autodoc]] LongCatImagePipeline\n- all\n- __call__\n\n## LongCatImagePipelineOutput\n\n[[autodoc]] pipelines.longcat_image.pipeline_output.LongCatImagePipelineOutput\n\n\n\n"
  },
  {
    "path": "docs/source/en/api/pipelines/ltx2.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# LTX-2\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nLTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.\n\nYou can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.\n\nThe original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).\n\n## Two-stages Generation\nThe two-stage pipeline is recommended for production-quality generation. It is composed of the following stages:\n\n- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning.\n- Stage 2: Upsample the Stage 1 output by 2x and refine details using a distilled LoRA model to improve fidelity and visual quality. 
Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness.\n\nSample usage of text-to-video two stages pipeline\n\n```py\nimport torch\nfrom diffusers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline\nfrom diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel\nfrom diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES\nfrom diffusers.pipelines.ltx2.export_utils import encode_video\n\ndevice = \"cuda:0\"\nwidth = 768\nheight = 512\n\npipe = LTX2Pipeline.from_pretrained(\n    \"Lightricks/LTX-2\", torch_dtype=torch.bfloat16\n)\npipe.enable_sequential_cpu_offload(device=device)\n\nprompt = \"A beautiful sunset over the ocean\"\nnegative_prompt = \"shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.\"\n\n# Stage 1 default (non-distilled) inference\nframe_rate = 24.0\nvideo_latent, audio_latent = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=width,\n    height=height,\n    num_frames=121,\n    frame_rate=frame_rate,\n    num_inference_steps=40,\n    sigmas=None,\n    guidance_scale=4.0,\n    output_type=\"latent\",\n    return_dict=False,\n)\n\nlatent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(\n    \"Lightricks/LTX-2\",\n    subfolder=\"latent_upsampler\",\n    torch_dtype=torch.bfloat16,\n)\nupsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)\nupsample_pipe.enable_model_cpu_offload(device=device)\nupscaled_video_latent = upsample_pipe(\n    latents=video_latent,\n    output_type=\"latent\",\n    return_dict=False,\n)[0]\n\n# Load Stage 2 distilled LoRA\npipe.load_lora_weights(\n    \"Lightricks/LTX-2\", adapter_name=\"stage_2_distilled\", 
weight_name=\"ltx-2-19b-distilled-lora-384.safetensors\"\n)\npipe.set_adapters(\"stage_2_distilled\", 1.0)\n# VAE tiling is usually necessary to avoid OOM error when VAE decoding\npipe.vae.enable_tiling()\n# Change scheduler to use Stage 2 distilled sigmas as is\nnew_scheduler = FlowMatchEulerDiscreteScheduler.from_config(\n    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None\n)\npipe.scheduler = new_scheduler\n# Stage 2 inference with distilled LoRA and sigmas\nvideo, audio = pipe(\n    latents=upscaled_video_latent,\n    audio_latents=audio_latent,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_inference_steps=3,\n    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218\n    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,\n    guidance_scale=1.0,\n    output_type=\"np\",\n    return_dict=False,\n)\n\nencode_video(\n    video[0],\n    fps=frame_rate,\n    audio=audio[0].float().cpu(),\n    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,\n    output_path=\"ltx2_lora_distilled_sample.mp4\",\n)\n```\n\n## Distilled checkpoint generation\nFastest two-stages generation pipeline using a distilled checkpoint.\n\n```py\nimport torch\nfrom diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline\nfrom diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel\nfrom diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES\nfrom diffusers.pipelines.ltx2.export_utils import encode_video\n\ndevice = \"cuda\"\nwidth = 768\nheight = 512\nrandom_seed = 42\ngenerator = torch.Generator(device).manual_seed(random_seed)\nmodel_path = \"rootonchair/LTX-2-19b-distilled\"\n\npipe = LTX2Pipeline.from_pretrained(\n    model_path, torch_dtype=torch.bfloat16\n)\npipe.enable_sequential_cpu_offload(device=device)\n\nprompt = \"A beautiful sunset 
over the ocean\"\nnegative_prompt = \"shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.\"\n\nframe_rate = 24.0\nvideo_latent, audio_latent = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=width,\n    height=height,\n    num_frames=121,\n    frame_rate=frame_rate,\n    num_inference_steps=8,\n    sigmas=DISTILLED_SIGMA_VALUES,\n    guidance_scale=1.0,\n    generator=generator,\n    output_type=\"latent\",\n    return_dict=False,\n)\n\nlatent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(\n    model_path,\n    subfolder=\"latent_upsampler\",\n    torch_dtype=torch.bfloat16,\n)\nupsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)\nupsample_pipe.enable_model_cpu_offload(device=device)\nupscaled_video_latent = upsample_pipe(\n    latents=video_latent,\n    output_type=\"latent\",\n    return_dict=False,\n)[0]\n\nvideo, audio = pipe(\n    latents=upscaled_video_latent,\n    audio_latents=audio_latent,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_inference_steps=3,\n    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178\n    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,\n    generator=generator,\n    guidance_scale=1.0,\n    output_type=\"np\",\n    return_dict=False,\n)\n\nencode_video(\n    video[0],\n    fps=frame_rate,\n    audio=audio[0].float().cpu(),\n    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,\n    output_path=\"ltx2_distilled_sample.mp4\",\n)\n```\n\n## Condition Pipeline Generation\n\nYou can use `LTX2ConditionPipeline` to specify image and/or video conditions at arbitrary latent indices. 
For example, we can specify both a first-frame and last-frame condition to perform first-last-frame-to-video (FLF2V) generation:\n\n```py\nimport torch\nfrom diffusers import LTX2ConditionPipeline, LTX2LatentUpsamplePipeline\nfrom diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel\nfrom diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition\nfrom diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES\nfrom diffusers.pipelines.ltx2.export_utils import encode_video\nfrom diffusers.utils import load_image\n\ndevice = \"cuda\"\nwidth = 768\nheight = 512\nrandom_seed = 42\ngenerator = torch.Generator(device).manual_seed(random_seed)\nmodel_path = \"rootonchair/LTX-2-19b-distilled\"\n\npipe = LTX2ConditionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)\npipe.enable_sequential_cpu_offload(device=device)\npipe.vae.enable_tiling()\n\nprompt = (\n    \"CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are \"\n    \"delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright \"\n    \"sunshine. 
The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, \"\n    \"low-angle perspective.\"\n)\n\nfirst_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png\",\n)\nlast_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png\",\n)\nfirst_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0)\nlast_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0)\nconditions = [first_cond, last_cond]\n\nframe_rate = 24.0\nvideo_latent, audio_latent = pipe(\n    conditions=conditions,\n    prompt=prompt,\n    width=width,\n    height=height,\n    num_frames=121,\n    frame_rate=frame_rate,\n    num_inference_steps=8,\n    sigmas=DISTILLED_SIGMA_VALUES,\n    guidance_scale=1.0,\n    generator=generator,\n    output_type=\"latent\",\n    return_dict=False,\n)\n\nlatent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(\n    model_path,\n    subfolder=\"latent_upsampler\",\n    torch_dtype=torch.bfloat16,\n)\nupsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)\nupsample_pipe.enable_model_cpu_offload(device=device)\nupscaled_video_latent = upsample_pipe(\n    latents=video_latent,\n    output_type=\"latent\",\n    return_dict=False,\n)[0]\n\nvideo, audio = pipe(\n    latents=upscaled_video_latent,\n    audio_latents=audio_latent,\n    prompt=prompt,\n    width=width * 2,\n    height=height * 2,\n    num_inference_steps=3,\n    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,\n    generator=generator,\n    guidance_scale=1.0,\n    output_type=\"np\",\n    return_dict=False,\n)\n\nencode_video(\n    video[0],\n    fps=frame_rate,\n    audio=audio[0].float().cpu(),\n    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,\n    output_path=\"ltx2_distilled_flf2v.mp4\",\n)\n```\n\nYou can use both image 
and video conditions:\n\n```py\nimport torch\nfrom diffusers import LTX2ConditionPipeline\nfrom diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition\nfrom diffusers.pipelines.ltx2.export_utils import encode_video\nfrom diffusers.utils import load_image, load_video\n\ndevice = \"cuda\"\nwidth = 768\nheight = 512\nrandom_seed = 42\ngenerator = torch.Generator(device).manual_seed(random_seed)\nmodel_path = \"rootonchair/LTX-2-19b-distilled\"\n\npipe = LTX2ConditionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)\npipe.enable_sequential_cpu_offload(device=device)\npipe.vae.enable_tiling()\n\nprompt = (\n    \"The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is \"\n    \"divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features \"\n    \"dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered \"\n    \"clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, \"\n    \"with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The \"\n    \"landscape is characterized by rugged terrain and a river visible in the distance. 
The scene captures the \"\n    \"solitude and beauty of a winter drive through a mountainous region.\"\n)\nnegative_prompt = (\n    \"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, \"\n    \"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, \"\n    \"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, \"\n    \"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of \"\n    \"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent \"\n    \"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny \"\n    \"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, \"\n    \"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, \"\n    \"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward \"\n    \"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, \"\n    \"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts.\"\n)\n\ncond_video = load_video(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4\"\n)\ncond_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg\"\n)\nvideo_cond = LTX2VideoCondition(frames=cond_video, index=0, strength=1.0)\nimage_cond = LTX2VideoCondition(frames=cond_image, index=8, strength=1.0)\nconditions = [video_cond, image_cond]\n\nframe_rate = 24.0\nvideo, audio = pipe(\n    conditions=conditions,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=width,\n    
height=height,\n    num_frames=121,\n    frame_rate=frame_rate,\n    num_inference_steps=40,\n    guidance_scale=4.0,\n    generator=generator,\n    output_type=\"np\",\n    return_dict=False,\n)\n\nencode_video(\n    video[0],\n    fps=frame_rate,\n    audio=audio[0].float().cpu(),\n    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,\n    output_path=\"ltx2_cond_video.mp4\",\n)\n```\n\nBecause the conditioning is done via latent frames, the 8 data space frames corresponding to the specified latent frame for an image condition will tend to be static.\n\n## LTX2Pipeline\n\n[[autodoc]] LTX2Pipeline\n  - all\n  - __call__\n\n## LTX2ImageToVideoPipeline\n\n[[autodoc]] LTX2ImageToVideoPipeline\n  - all\n  - __call__\n\n## LTX2ConditionPipeline\n\n[[autodoc]] LTX2ConditionPipeline\n  - all\n  - __call__\n\n## LTX2LatentUpsamplePipeline\n\n[[autodoc]] LTX2LatentUpsamplePipeline\n  - all\n  - __call__\n\n## LTX2PipelineOutput\n\n[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/ltx_video.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n<div style=\"float: right;\">\n  <div class=\"flex flex-wrap space-x-1\">\n    <a href=\"https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference\" target=\"_blank\" rel=\"noopener\">\n      <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n    </a>\n    <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n  </div>\n</div>\n\n# LTX-Video\n\n[LTX-Video](https://huggingface.co/Lightricks/LTX-Video) is a diffusion transformer designed for fast and real-time generation of high-resolution videos from text and images. The main feature of LTX-Video is the Video-VAE. The Video-VAE has a higher pixel to latent compression ratio (1:192) which enables more efficient video data processing and faster generation speed. 
To prevent finer details from being lost during generation, the Video-VAE decoder performs the latent-to-pixel conversion *and* the last denoising step.\n\nYou can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.\n\n> [!TIP]\n> Click on the LTX-Video models in the right sidebar for more examples of other video generation tasks.\n\nThe example below demonstrates how to generate a video optimized for memory or inference speed.\n\n<hfoptions id=\"usage\">\n<hfoption id=\"memory\">\n\nRefer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.\n\nThe LTX-Video model below requires ~10GB of VRAM.\n\n```py\nimport torch\nfrom diffusers import LTXPipeline, AutoModel\nfrom diffusers.hooks import apply_group_offloading\nfrom diffusers.utils import export_to_video\n\n# fp8 layerwise weight-casting\ntransformer = AutoModel.from_pretrained(\n    \"Lightricks/LTX-Video\",\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16\n)\ntransformer.enable_layerwise_casting(\n    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16\n)\n\npipeline = LTXPipeline.from_pretrained(\"Lightricks/LTX-Video\", transformer=transformer, torch_dtype=torch.bfloat16)\n\n# group-offloading\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\npipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=\"leaf_level\", use_stream=True)\napply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=\"block_level\", num_blocks_per_group=2)\napply_group_offloading(pipeline.vae, onload_device=onload_device, offload_type=\"leaf_level\")\n\nprompt = \"\"\"\nA woman with long brown hair and light skin smiles at another woman with long blonde hair.\nThe woman with brown hair wears a black jacket and has a small, barely noticeable mole on 
her right cheek.\nThe camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and \nnatural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage\n\"\"\"\nnegative_prompt = \"worst quality, inconsistent motion, blurry, jittery, distorted\"\n\nvideo = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=768,\n    height=512,\n    num_frames=161,\n    decode_timestep=0.03,\n    decode_noise_scale=0.025,\n    num_inference_steps=50,\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=24)\n```\n\n</hfoption>\n<hfoption id=\"inference speed\">\n\n[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.\n\n```py\nimport torch\nfrom diffusers import LTXPipeline\nfrom diffusers.utils import export_to_video\n\npipeline = LTXPipeline.from_pretrained(\n    \"Lightricks/LTX-Video\", torch_dtype=torch.bfloat16\n)\n\n# torch.compile\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.transformer = torch.compile(\n    pipeline.transformer, mode=\"max-autotune\", fullgraph=True\n)\n\nprompt = \"\"\"\nA woman with long brown hair and light skin smiles at another woman with long blonde hair.\nThe woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek.\nThe camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and \nnatural, likely from the setting sun, casting a soft glow on the scene. 
The scene appears to be real-life footage\n\"\"\"\nnegative_prompt = \"worst quality, inconsistent motion, blurry, jittery, distorted\"\n\nvideo = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=768,\n    height=512,\n    num_frames=161,\n    decode_timestep=0.03,\n    decode_noise_scale=0.025,\n    num_inference_steps=50,\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=24)\n```\n\n</hfoption>\n</hfoptions>\n\n## Notes\n\n- Refer to the following recommended settings for generation from the [LTX-Video](https://github.com/Lightricks/LTX-Video) repository.\n\n  - The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.\n  - For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.\n  - For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.\n  - For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.\n\n- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. 
During inference, a low-resolution video is quickly generated first and then upscaled and refined.\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```py\n  import torch\n  from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline\n  from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition\n  from diffusers.utils import export_to_video, load_video\n\n  pipeline = LTXConditionPipeline.from_pretrained(\"Lightricks/LTX-Video-0.9.7-dev\", torch_dtype=torch.bfloat16)\n  pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(\"Lightricks/ltxv-spatial-upscaler-0.9.7\", vae=pipeline.vae, torch_dtype=torch.bfloat16)\n  pipeline.to(\"cuda\")\n  pipe_upsample.to(\"cuda\")\n  pipeline.vae.enable_tiling()\n\n  def round_to_nearest_resolution_acceptable_by_vae(height, width):\n      # Height and width must be multiples of the VAE's spatial compression ratio\n      height = height - (height % pipeline.vae_spatial_compression_ratio)\n      width = width - (width % pipeline.vae_spatial_compression_ratio)\n      return height, width\n\n  video = load_video(\n      \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4\"\n  )[:21]  # only use the first 21 frames as conditioning\n  condition1 = LTXVideoCondition(video=video, frame_index=0)\n\n  prompt = \"\"\"\n  The video depicts a winding mountain road covered in snow, with a single vehicle \n  traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. \n  The landscape is characterized by rugged terrain and a river visible in the distance. \n  The scene captures the solitude and beauty of a winter drive through a mountainous region.\n  \"\"\"\n  negative_prompt = \"worst quality, inconsistent motion, blurry, jittery, distorted\"\n  expected_height, expected_width = 768, 1152\n  downscale_factor = 2 / 3\n  num_frames = 161\n\n  # 1. 
Generate video at smaller resolution\n  # Text-only conditioning is also supported without the need to pass `conditions`\n  downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)\n  downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)\n  latents = pipeline(\n      conditions=[condition1],\n      prompt=prompt,\n      negative_prompt=negative_prompt,\n      width=downscaled_width,\n      height=downscaled_height,\n      num_frames=num_frames,\n      num_inference_steps=30,\n      decode_timestep=0.05,\n      decode_noise_scale=0.025,\n      image_cond_noise_scale=0.0,\n      guidance_scale=5.0,\n      guidance_rescale=0.7,\n      generator=torch.Generator().manual_seed(0),\n      output_type=\"latent\",\n  ).frames\n\n  # 2. Upscale generated video using latent upsampler with fewer inference steps\n  # The available latent upsampler upscales the height/width by 2x\n  upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2\n  upscaled_latents = pipe_upsample(\n      latents=latents,\n      output_type=\"latent\"\n  ).frames\n\n  # 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)\n  video = pipeline(\n      conditions=[condition1],\n      prompt=prompt,\n      negative_prompt=negative_prompt,\n      width=upscaled_width,\n      height=upscaled_height,\n      num_frames=num_frames,\n      denoise_strength=0.4,  # Effectively, 4 inference steps out of 10\n      num_inference_steps=10,\n      latents=upscaled_latents,\n      decode_timestep=0.05,\n      decode_noise_scale=0.025,\n      image_cond_noise_scale=0.0,\n      guidance_scale=5.0,\n      guidance_rescale=0.7,\n      generator=torch.Generator().manual_seed(0),\n      output_type=\"pil\",\n  ).frames[0]\n\n  # 4. 
Downscale the video to the expected resolution\n  video = [frame.resize((expected_width, expected_height)) for frame in video]\n\n  export_to_video(video, \"output.mp4\", fps=24)\n  ```\n\n  </details>\n\n- The LTX-Video 0.9.7 distilled model is guidance- and timestep-distilled to speed up generation. It requires `guidance_scale` to be set to `1.0`, and `num_inference_steps` should be set between `4` and `10` for good generation quality. You should also use the following custom timesteps for the best results.\n\n  - Base model inference to prepare for upscaling: `[1000, 993, 987, 981, 975, 909, 725, 0.03]`.\n  - Upscaling: `[1000, 909, 725, 421, 0]`.\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```py\n  import torch\n  from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline\n  from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition\n  from diffusers.utils import export_to_video, load_video\n\n  pipeline = LTXConditionPipeline.from_pretrained(\"Lightricks/LTX-Video-0.9.7-distilled\", torch_dtype=torch.bfloat16)\n  pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(\"Lightricks/ltxv-spatial-upscaler-0.9.7\", vae=pipeline.vae, torch_dtype=torch.bfloat16)\n  pipeline.to(\"cuda\")\n  pipe_upsample.to(\"cuda\")\n  pipeline.vae.enable_tiling()\n\n  def round_to_nearest_resolution_acceptable_by_vae(height, width):\n      height = height - (height % pipeline.vae_spatial_compression_ratio)\n      width = width - (width % pipeline.vae_spatial_compression_ratio)\n      return height, width\n\n  prompt = \"\"\"\n  artistic anatomical 3d render, utlra quality, human half full male body with transparent \n  skin revealing structure instead of organs, muscular, intricate creative patterns, \n  monochromatic with backlighting, lightning mesh, scientific concept art, blending biology \n  with botany, surreal and ethereal quality, unreal engine 5, ray tracing, ultra realistic, \n  16K UHD, rich details. 
camera zooms out in a rotating fashion\n  \"\"\"\n  negative_prompt = \"worst quality, inconsistent motion, blurry, jittery, distorted\"\n  expected_height, expected_width = 768, 1152\n  downscale_factor = 2 / 3\n  num_frames = 161\n\n  # 1. Generate video at smaller resolution\n  downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)\n  downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)\n  latents = pipeline(\n      prompt=prompt,\n      negative_prompt=negative_prompt,\n      width=downscaled_width,\n      height=downscaled_height,\n      num_frames=num_frames,\n      timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],\n      decode_timestep=0.05,\n      decode_noise_scale=0.025,\n      image_cond_noise_scale=0.0,\n      guidance_scale=1.0,\n      guidance_rescale=0.7,\n      generator=torch.Generator().manual_seed(0),\n      output_type=\"latent\",\n  ).frames\n\n  # 2. Upscale generated video using latent upsampler with fewer inference steps\n  # The available latent upsampler upscales the height/width by 2x\n  upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2\n  upscaled_latents = pipe_upsample(\n      latents=latents,\n      adain_factor=1.0,\n      output_type=\"latent\"\n  ).frames\n\n  # 3. 
Denoise the upscaled video with few steps to improve texture (optional, but recommended)\n  video = pipeline(\n      prompt=prompt,\n      negative_prompt=negative_prompt,\n      width=upscaled_width,\n      height=upscaled_height,\n      num_frames=num_frames,\n      denoise_strength=0.999,  # Effectively, 4 inference steps out of 5\n      timesteps=[1000, 909, 725, 421, 0],\n      latents=upscaled_latents,\n      decode_timestep=0.05,\n      decode_noise_scale=0.025,\n      image_cond_noise_scale=0.0,\n      guidance_scale=1.0,\n      guidance_rescale=0.7,\n      generator=torch.Generator().manual_seed(0),\n      output_type=\"pil\",\n  ).frames[0]\n\n  # 4. Downscale the video to the expected resolution\n  video = [frame.resize((expected_width, expected_height)) for frame in video]\n\n  export_to_video(video, \"output.mp4\", fps=24)\n  ```\n\n  </details>\n\n- LTX-Video 0.9.8 distilled model is similar to the 0.9.7 variant. It is guidance and timestep-distilled, and similar inference code can be used as above. An improvement of this version is that it supports generating very long videos. Additionally, it supports using tone mapping to improve the quality of the generated video using the `tone_map_compression_ratio` parameter. 
The default value of `0.6` is recommended.\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```python\n  import torch\n  from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline\n  from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition\n  from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel\n  from diffusers.utils import export_to_video, load_video\n\n  pipeline = LTXConditionPipeline.from_pretrained(\"Lightricks/LTX-Video-0.9.8-13B-distilled\", torch_dtype=torch.bfloat16)\n  # TODO: Update the checkpoint here once updated in LTX org\n  upsampler = LTXLatentUpsamplerModel.from_pretrained(\"a-r-r-o-w/LTX-0.9.8-Latent-Upsampler\", torch_dtype=torch.bfloat16)\n  pipe_upsample = LTXLatentUpsamplePipeline(vae=pipeline.vae, latent_upsampler=upsampler).to(torch.bfloat16)\n  pipeline.to(\"cuda\")\n  pipe_upsample.to(\"cuda\")\n  pipeline.vae.enable_tiling()\n\n  def round_to_nearest_resolution_acceptable_by_vae(height, width):\n      height = height - (height % pipeline.vae_spatial_compression_ratio)\n      width = width - (width % pipeline.vae_spatial_compression_ratio)\n      return height, width\n\n  prompt = \"\"\"The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys.The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. 
The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature.\"\"\"\n  # prompt = \"\"\"A woman walks away from a white Jeep parked on a city street at night, then ascends a staircase and knocks on a door. The woman, wearing a dark jacket and jeans, walks away from the Jeep parked on the left side of the street, her back to the camera; she walks at a steady pace, her arms swinging slightly by her sides; the street is dimly lit, with streetlights casting pools of light on the wet pavement; a man in a dark jacket and jeans walks past the Jeep in the opposite direction; the camera follows the woman from behind as she walks up a set of stairs towards a building with a green door; she reaches the top of the stairs and turns left, continuing to walk towards the building; she reaches the door and knocks on it with her right hand; the camera remains stationary, focused on the doorway; the scene is captured in real-life footage.\"\"\"\n  negative_prompt = \"bright colors, symbols, graffiti, watermarks, worst quality, inconsistent motion, blurry, jittery, distorted\"\n  expected_height, expected_width = 480, 832\n  downscale_factor = 2 / 3\n  # num_frames = 161\n  num_frames = 361\n\n  # 1. 
Generate video at smaller resolution\n  downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)\n  downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)\n  latents = pipeline(\n      prompt=prompt,\n      negative_prompt=negative_prompt,\n      width=downscaled_width,\n      height=downscaled_height,\n      num_frames=num_frames,\n      timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],\n      decode_timestep=0.05,\n      decode_noise_scale=0.025,\n      image_cond_noise_scale=0.0,\n      guidance_scale=1.0,\n      guidance_rescale=0.7,\n      generator=torch.Generator().manual_seed(0),\n      output_type=\"latent\",\n  ).frames\n\n  # 2. Upscale generated video using latent upsampler with fewer inference steps\n  # The available latent upsampler upscales the height/width by 2x\n  upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2\n  upscaled_latents = pipe_upsample(\n      latents=latents,\n      adain_factor=1.0,\n      tone_map_compression_ratio=0.6,\n      output_type=\"latent\"\n  ).frames\n\n  # 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)\n  video = pipeline(\n      prompt=prompt,\n      negative_prompt=negative_prompt,\n      width=upscaled_width,\n      height=upscaled_height,\n      num_frames=num_frames,\n      denoise_strength=0.999,  # Effectively, 4 inference steps out of 5\n      timesteps=[1000, 909, 725, 421, 0],\n      latents=upscaled_latents,\n      decode_timestep=0.05,\n      decode_noise_scale=0.025,\n      image_cond_noise_scale=0.0,\n      guidance_scale=1.0,\n      guidance_rescale=0.7,\n      generator=torch.Generator().manual_seed(0),\n      output_type=\"pil\",\n  ).frames[0]\n\n  # 4. 
Downscale the video to the expected resolution\n  video = [frame.resize((expected_width, expected_height)) for frame in video]\n\n  export_to_video(video, \"output.mp4\", fps=24)\n  ```\n\n  </details>\n\n- LTX-Video supports LoRAs with [`~loaders.LTXVideoLoraLoaderMixin.load_lora_weights`].\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```py\n  import torch\n  from diffusers import LTXConditionPipeline\n  from diffusers.utils import export_to_video, load_image\n\n  pipeline = LTXConditionPipeline.from_pretrained(\n      \"Lightricks/LTX-Video-0.9.5\", torch_dtype=torch.bfloat16\n  )\n\n  pipeline.load_lora_weights(\"Lightricks/LTX-Video-Cakeify-LoRA\", adapter_name=\"cakeify\")\n  pipeline.set_adapters(\"cakeify\")\n\n  # use \"CAKEIFY\" to trigger the LoRA\n  prompt = \"CAKEIFY a person using a knife to cut a cake shaped like a Pikachu plushie\"\n  image = load_image(\"https://huggingface.co/Lightricks/LTX-Video-Cakeify-LoRA/resolve/main/assets/images/pikachu.png\")\n\n  video = pipeline(\n      prompt=prompt,\n      image=image,\n      width=576,\n      height=576,\n      num_frames=161,\n      decode_timestep=0.03,\n      decode_noise_scale=0.025,\n      num_inference_steps=50,\n  ).frames[0]\n  export_to_video(video, \"output.mp4\", fps=26)\n  ```\n\n  </details>\n\n- LTX-Video supports loading from single files, such as [GGUF checkpoints](../../quantization/gguf), with [`loaders.FromOriginalModelMixin.from_single_file`] or [`loaders.FromSingleFileMixin.from_single_file`].\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```py\n  import torch\n  from diffusers.utils import export_to_video\n  from diffusers import LTXPipeline, AutoModel, GGUFQuantizationConfig\n\n  transformer = AutoModel.from_single_file(\n      \"https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf\",\n      quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),\n      torch_dtype=torch.bfloat16\n  )\n  pipeline = 
LTXPipeline.from_pretrained(\n      \"Lightricks/LTX-Video\",\n      transformer=transformer,\n      torch_dtype=torch.bfloat16\n  )\n  ```\n\n  </details>\n\n## LTXI2VLongMultiPromptPipeline\n\n[[autodoc]] LTXI2VLongMultiPromptPipeline\n  - all\n  - __call__\n\n## LTXPipeline\n\n[[autodoc]] LTXPipeline\n  - all\n  - __call__\n\n## LTXImageToVideoPipeline\n\n[[autodoc]] LTXImageToVideoPipeline\n  - all\n  - __call__\n\n## LTXConditionPipeline\n\n[[autodoc]] LTXConditionPipeline\n  - all\n  - __call__\n\n## LTXLatentUpsamplePipeline\n\n[[autodoc]] LTXLatentUpsamplePipeline\n  - all\n  - __call__\n\n## LTXPipelineOutput\n\n[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/lumina.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Lumina-T2X\n![concepts](https://github.com/Alpha-VLLM/Lumina-T2X/assets/54879512/9f52eabb-07dc-4881-8257-6d8a5f2a0a5a)\n\n[Lumina-Next : Making Lumina-T2X Stronger and Faster with Next-DiT](https://github.com/Alpha-VLLM/Lumina-T2X/blob/main/assets/lumina-next.pdf) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.\n\nThe abstract from the paper is:\n\n*Lumina-T2X is a nascent family of Flow-based Large Diffusion Transformers (Flag-DiT) that establishes a unified framework for transforming noise into various modalities, such as images and videos, conditioned on text instructions. Despite its promising capabilities, Lumina-T2X still encounters challenges including training instability, slow inference, and extrapolation artifacts. In this paper, we present Lumina-Next, an improved version of Lumina-T2X, showcasing stronger generation performance with increased training and inference efficiency. We begin with a comprehensive analysis of the Flag-DiT architecture and identify several suboptimal components, which we address by introducing the Next-DiT architecture with 3D RoPE and sandwich normalizations. To enable better resolution extrapolation, we thoroughly compare different context extrapolation methods applied to text-to-image generation with 3D RoPE, and propose Frequency- and Time-Aware Scaled RoPE tailored for diffusion transformers. 
Additionally, we introduce a sigmoid time discretization schedule to reduce sampling steps in solving the Flow ODE and the Context Drop method to merge redundant visual tokens for faster network evaluation, effectively boosting the overall sampling speed. Thanks to these improvements, Lumina-Next not only improves the quality and efficiency of basic text-to-image generation but also demonstrates superior resolution extrapolation capabilities and multilingual generation using decoder-based LLMs as the text encoder, all in a zero-shot manner. To further validate Lumina-Next as a versatile generative framework, we instantiate it on diverse tasks including visual recognition, multi-view, audio, music, and point cloud generation, showcasing strong performance across these domains. By releasing all codes and model weights at https://github.com/Alpha-VLLM/Lumina-T2X, we aim to advance the development of next-generation generative AI capable of universal modeling.*\n\n**Highlights**: Lumina-Next is a next-generation Diffusion Transformer that significantly enhances text-to-image generation, multilingual generation, and multitask performance by introducing the Next-DiT architecture, 3D RoPE, and frequency- and time-aware RoPE, among other improvements.\n\nLumina-Next has the following components:\n* It improves sampling efficiency with fewer and faster steps.\n* It uses Next-DiT as the transformer backbone with sandwich normalization, 3D RoPE, and Grouped-Query Attention.\n* It uses a Frequency- and Time-Aware Scaled RoPE.\n\n---\n\n[Lumina-T2X: Transforming Text into Any Modality, Resolution, and Duration via Flow-based Large Diffusion Transformers](https://huggingface.co/papers/2405.05945) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.\n\nThe abstract from the paper is:\n\n*Sora unveils the potential of scaling Diffusion Transformer for generating photorealistic images and videos at arbitrary resolutions, aspect ratios, and durations, yet it still lacks sufficient 
implementation details. In this technical report, we introduce the Lumina-T2X family - a series of Flow-based Large Diffusion Transformers (Flag-DiT) equipped with zero-initialized attention, as a unified framework designed to transform noise into images, videos, multi-view 3D objects, and audio clips conditioned on text instructions. By tokenizing the latent spatial-temporal space and incorporating learnable placeholders such as [nextline] and [nextframe] tokens, Lumina-T2X seamlessly unifies the representations of different modalities across various spatial-temporal resolutions. This unified approach enables training within a single framework for different modalities and allows for flexible generation of multimodal data at any resolution, aspect ratio, and length during inference. Advanced techniques like RoPE, RMSNorm, and flow matching enhance the stability, flexibility, and scalability of Flag-DiT, enabling models of Lumina-T2X to scale up to 7 billion parameters and extend the context window to 128K tokens. This is particularly beneficial for creating ultra-high-definition images with our Lumina-T2I model and long 720p videos with our Lumina-T2V model. Remarkably, Lumina-T2I, powered by a 5-billion-parameter Flag-DiT, requires only 35% of the training computational costs of a 600-million-parameter naive DiT. Our further comprehensive analysis underscores Lumina-T2X's preliminary capability in resolution extrapolation, high-resolution editing, generating consistent 3D views, and synthesizing videos with seamless transitions. 
We expect that the open-sourcing of Lumina-T2X will further foster creativity, transparency, and diversity in the generative AI community.*\n\n\nYou can find the original codebase at [Alpha-VLLM](https://github.com/Alpha-VLLM/Lumina-T2X) and all the available checkpoints at [Alpha-VLLM Lumina Family](https://huggingface.co/collections/Alpha-VLLM/lumina-family-66423205bedb81171fd0644b).\n\n**Highlights**: Lumina-T2X supports Any Modality, Resolution, and Duration.\n\nLumina-T2X has the following components:\n* It uses a Flow-based Large Diffusion Transformer as the backbone.\n* It supports any modality with one backbone and the corresponding encoder and decoder.\n\nThis pipeline was contributed by [PommesPeter](https://github.com/PommesPeter). The original codebase can be found [here](https://github.com/Alpha-VLLM/Lumina-T2X). The original weights can be found under [hf.co/Alpha-VLLM](https://huggingface.co/Alpha-VLLM).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n### Inference (Text-to-Image)\n\nUse [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.\n\nFirst, load the pipeline:\n\n```python\nfrom diffusers import LuminaPipeline\nimport torch\n\npipeline = LuminaPipeline.from_pretrained(\n\t\"Alpha-VLLM/Lumina-Next-SFT-diffusers\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n```\n\nThen change the memory layout of the pipeline's `transformer` and `vae` components to `torch.channels_last`:\n\n```python\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.vae.to(memory_format=torch.channels_last)\n```\n\nFinally, compile the components and run 
inference:\n\n```python\npipeline.transformer = torch.compile(pipeline.transformer, mode=\"max-autotune\", fullgraph=True)\npipeline.vae.decode = torch.compile(pipeline.vae.decode, mode=\"max-autotune\", fullgraph=True)\n\nimage = pipeline(prompt=\"Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures\").images[0]\n```\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may affect image quality to varying degrees depending on the model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LuminaPipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline\nfrom transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = T5EncoderModel.from_pretrained(\n    \"Alpha-VLLM/Lumina-Next-SFT-diffusers\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = Transformer2DModel.from_pretrained(\n    \"Alpha-VLLM/Lumina-Next-SFT-diffusers\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = LuminaPipeline.from_pretrained(\n    \"Alpha-VLLM/Lumina-Next-SFT-diffusers\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nprompt = \"a tiny 
astronaut hatching from an egg on the moon\"\nimage = pipeline(prompt).images[0]\nimage.save(\"lumina.png\")\n```\n\n## LuminaPipeline\n\n[[autodoc]] LuminaPipeline\n\t- all\n\t- __call__\n\n"
  },
  {
    "path": "docs/source/en/api/pipelines/lumina2.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# Lumina2\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[Lumina Image 2.0: A Unified and Efficient Image Generative Model](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) is a 2 billion parameter flow-based diffusion transformer capable of generating diverse images from text descriptions.\n\nThe abstract from the paper is:\n\n*We introduce Lumina-Image 2.0, an advanced text-to-image model that surpasses previous state-of-the-art methods across multiple benchmarks, while also shedding light on its potential to evolve into a generalist vision intelligence model. Lumina-Image 2.0 exhibits three key properties: (1) Unification – it adopts a unified architecture that treats text and image tokens as a joint sequence, enabling natural cross-modal interactions and facilitating task expansion. Besides, since high-quality captioners can provide semantically better-aligned text-image training pairs, we introduce a unified captioning system, UniCaptioner, which generates comprehensive and precise captions for the model. This not only accelerates model convergence but also enhances prompt adherence, variable-length prompt handling, and task generalization via prompt templates. 
(2) Efficiency – to improve the efficiency of the unified architecture, we develop a set of optimization techniques that improve semantic learning and fine-grained texture generation during training while incorporating inference-time acceleration strategies without compromising image quality. (3) Transparency – we open-source all training details, code, and models to ensure full reproducibility, aiming to bridge the gap between well-resourced closed-source research teams and independent developers.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## Using Single File loading with Lumina Image 2.0\n\nSingle file loading for Lumina Image 2.0 is available for the `Lumina2Transformer2DModel`.\n\n```python\nimport torch\nfrom diffusers import Lumina2Transformer2DModel, Lumina2Pipeline\n\nckpt_path = \"https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth\"\ntransformer = Lumina2Transformer2DModel.from_single_file(\n    ckpt_path, torch_dtype=torch.bfloat16\n)\n\npipe = Lumina2Pipeline.from_pretrained(\n    \"Alpha-VLLM/Lumina-Image-2.0\", transformer=transformer, torch_dtype=torch.bfloat16\n)\npipe.enable_model_cpu_offload()\nimage = pipe(\n    \"a cat holding a sign that says hello\",\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n).images[0]\nimage.save(\"lumina-single-file.png\")\n\n```\n\n## Using GGUF Quantized Checkpoints with Lumina Image 2.0\n\nGGUF quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`.\n\n```python\nimport torch\nfrom diffusers import Lumina2Transformer2DModel, Lumina2Pipeline, GGUFQuantizationConfig\n\nckpt_path = 
\"https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf\"\ntransformer = Lumina2Transformer2DModel.from_single_file(\n    ckpt_path,\n    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),\n    torch_dtype=torch.bfloat16,\n)\n\npipe = Lumina2Pipeline.from_pretrained(\n    \"Alpha-VLLM/Lumina-Image-2.0\", transformer=transformer, torch_dtype=torch.bfloat16\n)\npipe.enable_model_cpu_offload()\nimage = pipe(\n    \"a cat holding a sign that says hello\",\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n).images[0]\nimage.save(\"lumina-gguf.png\")\n```\n\n## Lumina2Pipeline\n\n[[autodoc]] Lumina2Pipeline\n  - all\n  - __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/marigold.md",
    "content": "<!--\nCopyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.\nCopyright 2024-2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Marigold Computer Vision\n\n![marigold](https://marigoldmonodepth.github.io/images/teaser_collage_compressed.jpg)\n\nMarigold was proposed in \n[Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation](https://huggingface.co/papers/2312.02145), \na CVPR 2024 Oral paper by \n[Bingxin Ke](http://www.kebingxin.com/), \n[Anton Obukhov](https://www.obukhov.ai/), \n[Shengyu Huang](https://shengyuh.github.io/), \n[Nando Metzger](https://nandometzger.github.io/), \n[Rodrigo Caye Daudt](https://rcdaudt.github.io/), and \n[Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en).\nThe core idea is to **repurpose the generative prior of Text-to-Image Latent Diffusion Models (LDMs) for traditional \ncomputer vision tasks**.\nThis approach was explored by fine-tuning Stable Diffusion for **Monocular Depth Estimation**, as demonstrated in the \nteaser above.\n\nMarigold was later extended in the follow-up paper, \n[Marigold: Affordable Adaptation of Diffusion-Based Image Generators for Image Analysis](https://huggingface.co/papers/2312.02145), \nauthored by \n[Bingxin Ke](http://www.kebingxin.com/), \n[Kevin Qu](https://www.linkedin.com/in/kevin-qu-b3417621b/?locale=en_US), \n[Tianfu Wang](https://tianfwang.github.io/), \n[Nando 
Metzger](https://nandometzger.github.io/), \n[Shengyu Huang](https://shengyuh.github.io/), \n[Bo Li](https://www.linkedin.com/in/bobboli0202/), \n[Anton Obukhov](https://www.obukhov.ai/), and \n[Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en).\nThis work expanded Marigold to support new modalities such as **Surface Normals** and **Intrinsic Image Decomposition** \n(IID), introduced a training protocol for **Latent Consistency Models** (LCM), and demonstrated **High-Resolution** (HR) \nprocessing capability.\n\n> [!TIP]\n> The early Marigold models (`v1-0` and earlier) were optimized for best results with at least 10 inference steps.\n> LCM models were later developed to enable high-quality inference in just 1 to 4 steps.\n> Marigold models `v1-1` and later use the DDIM scheduler to achieve optimal \n> results in as few as 1 to 4 steps.\n\n## Available Pipelines\n\nEach pipeline is tailored for a specific computer vision task, processing an input RGB image and generating a \ncorresponding prediction.\nCurrently, the following computer vision tasks are implemented:\n\n| Pipeline                                                                                                                                          | Recommended Model Checkpoints                                                                                                                                                                           |                              Spaces (Interactive Apps)                               | Predicted Modalities                                                                                                                                                               
|\n|---------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py)           | [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1)                                                                                                                       |          [Depth Estimation](https://huggingface.co/spaces/prs-eth/marigold)          | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity)                                                                   |\n| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py)       | [prs-eth/marigold-normals-v1-1](https://huggingface.co/prs-eth/marigold-normals-v1-1)                                                                                                                   | [Surface Normals Estimation](https://huggingface.co/spaces/prs-eth/marigold-normals) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping)                                                                                                                    |\n| 
[MarigoldIntrinsicsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py) | [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1),<br>[prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | [Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid)  | [Albedo](https://en.wikipedia.org/wiki/Albedo), [Materials](https://www.n.aiq3d.com/wiki/roughnessmetalnessao-map), [Lighting](https://en.wikipedia.org/wiki/Diffuse_reflection)   |\n\n## Available Checkpoints\n\nAll original checkpoints are available under the [PRS-ETH](https://huggingface.co/prs-eth/) organization on Hugging Face.\nThey are designed for use with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold), which can also be used to train \nnew model checkpoints.\nThe following is a summary of the recommended checkpoints, all of which produce reliable results with 1 to 4 steps. \n\n| Checkpoint                                                                                          | Modality     | Comment                                                                                                                                                                              |\n|-----------------------------------------------------------------------------------------------------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1)                   | Depth        | Affine-invariant depth prediction assigns each pixel a value between 0 (near plane) and 1 (far plane), with both planes determined by the model during inference.                    
|\n| [prs-eth/marigold-normals-v1-1](https://huggingface.co/prs-eth/marigold-normals-v1-1)               | Normals      | Surface normal predictions are unit-length 3D vectors expressed in screen-space camera coordinates, with components in the range from -1 to 1.                                       |\n| [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1) | Intrinsics   | InteriorVerse decomposition consists of Albedo and two BRDF material properties: Roughness and Metallicity.                                                                          |\n| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1)     | Intrinsics   | HyperSim decomposition of an image $I$ consists of Albedo $A$, Diffuse shading $S$, and Non-diffuse residual $R$: $I = A*S+R$. |\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff \n> between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to \n> efficiently load the same components into multiple pipelines. 
\n> To learn more about reducing the memory usage of this pipeline, refer to the \n> [Reduce memory usage](../../using-diffusers/svd#reduce-memory-usage) section.\n\n> [!WARNING]\n> Marigold pipelines were designed and tested with the scheduler embedded in the model checkpoint.\n> The optimal number of inference steps varies by scheduler, with no universal value that works best across all cases.\n> To accommodate this, the `num_inference_steps` parameter in the pipeline's `__call__` method defaults to `None` (see the \n> API reference).\n> Unless set explicitly, it inherits the value from the `default_denoising_steps` field in the checkpoint configuration \n> file (`model_index.json`).\n> This ensures high-quality predictions when invoking the pipeline with only the `image` argument.\n\nSee also Marigold [usage examples](../../using-diffusers/marigold_usage).\n\n## Marigold Depth Prediction API\n\n[[autodoc]] MarigoldDepthPipeline\n\t- __call__\n\n[[autodoc]] pipelines.marigold.pipeline_marigold_depth.MarigoldDepthOutput\n\n[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth\n\n## Marigold Normals Estimation API\n\n[[autodoc]] MarigoldNormalsPipeline\n\t- __call__\n\n[[autodoc]] pipelines.marigold.pipeline_marigold_normals.MarigoldNormalsOutput\n\n[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals\n\n## Marigold Intrinsic Image Decomposition API\n\n[[autodoc]] MarigoldIntrinsicsPipeline\n\t- __call__\n\n[[autodoc]] pipelines.marigold.pipeline_marigold_intrinsics.MarigoldIntrinsicsOutput\n\n[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_intrinsics\n"
  },
  {
    "path": "docs/source/en/api/pipelines/mochi.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n# Mochi 1 Preview\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n> [!TIP]\n> Only a research preview of the model weights is available at the moment.\n\n[Mochi 1](https://huggingface.co/genmo/mochi-1-preview) is a video generation model by Genmo with a strong focus on prompt adherence and motion quality. The model features a 10B-parameter Asymmetric Diffusion Transformer (AsymmDiT) architecture, and uses non-square QKV and output projection layers to reduce inference memory requirements. A single T5-XXL model is used to encode prompts.\n\n*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. 
The model is released under a permissive Apache 2.0 license.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`MochiPipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, MochiTransformer3DModel, MochiPipeline\nfrom diffusers.utils import export_to_video\nfrom transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = T5EncoderModel.from_pretrained(\n    \"genmo/mochi-1-preview\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = MochiTransformer3DModel.from_pretrained(\n    \"genmo/mochi-1-preview\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = MochiPipeline.from_pretrained(\n    \"genmo/mochi-1-preview\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nvideo = pipeline(\n  \"Close-up of a cats eye, with the 
galaxy reflected in the cats eye. Ultra high resolution 4k.\",\n  num_inference_steps=28,\n  guidance_scale=3.5\n).frames[0]\nexport_to_video(video, \"cat.mp4\")\n```\n\n## Generating videos with Mochi-1 Preview\n\nThe following example downloads the full-precision `mochi-1-preview` weights and produces the highest-quality results, but requires at least 42GB of VRAM to run.\n\n```python\nimport torch\nfrom diffusers import MochiPipeline\nfrom diffusers.utils import export_to_video\n\npipe = MochiPipeline.from_pretrained(\"genmo/mochi-1-preview\")\n\n# Enable memory savings\npipe.enable_model_cpu_offload()\npipe.enable_vae_tiling()\n\nprompt = \"Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.\"\n\nwith torch.autocast(\"cuda\", torch.bfloat16, cache_enabled=False):\n    frames = pipe(prompt, num_frames=85).frames[0]\n\nexport_to_video(frames, \"mochi.mp4\", fps=30)\n```\n\n## Using a lower precision variant to save memory\n\nThe following example uses the `bfloat16` variant of the model and requires 22GB of VRAM to run, at the cost of a slight drop in the quality of the generated video.\n\n```python\nimport torch\nfrom diffusers import MochiPipeline\nfrom diffusers.utils import export_to_video\n\npipe = MochiPipeline.from_pretrained(\"genmo/mochi-1-preview\", variant=\"bf16\", torch_dtype=torch.bfloat16)\n\n# Enable memory savings\npipe.enable_model_cpu_offload()\npipe.enable_vae_tiling()\n\nprompt = \"Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.\"\nframes = pipe(prompt, num_frames=85).frames[0]\n\nexport_to_video(frames, \"mochi.mp4\", fps=30)\n```\n\n## Reproducing the results from the Genmo Mochi repo\n\nThe [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. 
The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the original implementation, please refer to the following example.\n\n> [!TIP]\n> The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.\n>\n> When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.\n\n> [!TIP]\n> Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.\n\n```python\nimport torch\nfrom torch.nn.attention import SDPBackend, sdpa_kernel\n\nfrom diffusers import MochiPipeline\nfrom diffusers.utils import export_to_video\nfrom diffusers.video_processor import VideoProcessor\n\npipe = MochiPipeline.from_pretrained(\"genmo/mochi-1-preview\", force_zeros_for_empty_prompt=True)\npipe.enable_vae_tiling()\npipe.enable_model_cpu_offload()\n\nprompt =  \"An aerial shot of a parade of elephants walking across the African savannah. 
The camera showcases the herd and the surrounding landscape.\"\n\nwith torch.no_grad():\n    prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (\n        pipe.encode_prompt(prompt=prompt)\n    )\n\nwith torch.autocast(\"cuda\", torch.bfloat16):\n    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):\n        frames = pipe(\n            prompt_embeds=prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_embeds=negative_prompt_embeds,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            guidance_scale=4.5,\n            num_inference_steps=64,\n            height=480,\n            width=848,\n            num_frames=163,\n            generator=torch.Generator(\"cuda\").manual_seed(0),\n            output_type=\"latent\",\n            return_dict=False,\n        )[0]\n\nvideo_processor = VideoProcessor(vae_scale_factor=8)\nhas_latents_mean = hasattr(pipe.vae.config, \"latents_mean\") and pipe.vae.config.latents_mean is not None\nhas_latents_std = hasattr(pipe.vae.config, \"latents_std\") and pipe.vae.config.latents_std is not None\nif has_latents_mean and has_latents_std:\n    latents_mean = (\n        torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)\n    )\n    latents_std = (\n        torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)\n    )\n    frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean\nelse:\n    frames = frames / pipe.vae.config.scaling_factor\n\nwith torch.no_grad():\n    video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]\n\nvideo = video_processor.postprocess_video(video)[0]\nexport_to_video(video, \"mochi.mp4\", fps=30)\n```\n\n## Running inference with multiple GPUs\n\nIt is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options 
in `from_pretrained`. In the following example, we split the model across two GPUs, each with 24GB of VRAM.\n\n```python\nimport torch\nfrom diffusers import MochiPipeline, MochiTransformer3DModel\nfrom diffusers.utils import export_to_video\n\nmodel_id = \"genmo/mochi-1-preview\"\ntransformer = MochiTransformer3DModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    device_map=\"auto\",\n    max_memory={0: \"24GB\", 1: \"24GB\"}\n)\n\npipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)\npipe.enable_model_cpu_offload()\npipe.enable_vae_tiling()\n\nwith torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n    frames = pipe(\n        prompt=\"Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.\",\n        negative_prompt=\"\",\n        height=480,\n        width=848,\n        num_frames=85,\n        num_inference_steps=50,\n        guidance_scale=4.5,\n        num_videos_per_prompt=1,\n        generator=torch.Generator(device=\"cuda\").manual_seed(0),\n        max_sequence_length=256,\n        output_type=\"pil\",\n    ).frames[0]\n\nexport_to_video(frames, \"output.mp4\", fps=30)\n```\n\n## Using single file loading with the Mochi Transformer\n\nYou can use `from_single_file` to load the Mochi transformer in its original format.\n\n> [!TIP]\n> Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.\n\n```python\nimport torch\nfrom diffusers import MochiPipeline, MochiTransformer3DModel\nfrom diffusers.utils import export_to_video\n\nmodel_id = \"genmo/mochi-1-preview\"\n\nckpt_path = \"https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors\"\n\n# Load the original-format checkpoint with from_single_file, not from_pretrained\ntransformer = MochiTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)\n\npipe = MochiPipeline.from_pretrained(model_id,  
transformer=transformer)\npipe.enable_model_cpu_offload()\npipe.enable_vae_tiling()\n\nwith torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n    frames = pipe(\n        prompt=\"Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.\",\n        negative_prompt=\"\",\n        height=480,\n        width=848,\n        num_frames=85,\n        num_inference_steps=50,\n        guidance_scale=4.5,\n        num_videos_per_prompt=1,\n        generator=torch.Generator(device=\"cuda\").manual_seed(0),\n        max_sequence_length=256,\n        output_type=\"pil\",\n    ).frames[0]\n\nexport_to_video(frames, \"output.mp4\", fps=30)\n```\n\n## MochiPipeline\n\n[[autodoc]] MochiPipeline\n  - all\n  - __call__\n\n## MochiPipelineOutput\n\n[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/musicldm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# MusicLDM\n\nMusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.\nMusicLDM takes a text prompt as input and predicts the corresponding music sample.\n\nInspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview) and [AudioLDM](https://huggingface.co/docs/diffusers/api/pipelines/audioldm),\nMusicLDM is a text-to-music _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap)\nlatents.\n\nMusicLDM is trained on a corpus of 466 hours of music data. Beat-synchronous data augmentation strategies are applied to the music samples, both in the time domain and in the latent space. Using beat-synchronous data augmentation strategies encourages the model to interpolate between the training samples, but stay within the domain of the training data. 
The result is generated music that is more diverse while staying faithful to the corresponding style.\n\nThe abstract of the paper is the following:\n\n*Diffusion models have shown promising results in cross-modal generation tasks, including text-to-image and text-to-audio generation. However, generating music, as a special type of audio, presents unique challenges due to limited availability of music data and sensitive issues related to copyright and plagiarism. In this paper, to tackle these challenges, we first construct a state-of-the-art text-to-music model, MusicLDM, that adapts Stable Diffusion and AudioLDM architectures to the music domain. We achieve this by retraining the contrastive language-audio pretraining model (CLAP) and the Hifi-GAN vocoder, as components of MusicLDM, on a collection of music data samples. Then, to address the limitations of training data and to avoid plagiarism, we leverage a beat tracking model and propose two different mixup strategies for data augmentation: beat-synchronous audio mixup and beat-synchronous latent mixup, which recombine training audio directly or via a latent embeddings space, respectively. Such mixup strategies encourage the model to interpolate between musical training samples and generate new music within the convex hull of the training data, making the generated music more diverse while still staying faithful to the corresponding style. 
In addition to popular evaluation metrics, we design several new evaluation metrics based on CLAP score to demonstrate that our proposed MusicLDM and beat-synchronous mixup strategies improve both the quality and novelty of generated music, as well as the correspondence between input text and generated music.*\n\nThis pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi).\n\n## Tips\n\nWhen constructing a prompt, keep in mind:\n\n* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, \"high quality\" or \"clear\") and make the prompt context specific where possible (e.g. \"melodic techno with a fast beat and synths\" works better than \"techno\").\n* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of \"low quality, average quality\".\n\nDuring inference:\n\n* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.\n* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.\n* The _length_ of the generated audio sample can be controlled by varying the `audio_length_in_s` argument.\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## MusicLDMPipeline\n[[autodoc]] MusicLDMPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/omnigen.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n# OmniGen\n\n[OmniGen: Unified Image Generation](https://huggingface.co/papers/2409.11340) from BAAI, by Shitao Xiao, Yueze Wang, Junjie Zhou, Huaying Yuan, Xingrun Xing, Ruiran Yan, Chaofan Li, Shuting Wang, Tiejun Huang, Zheng Liu.\n\nThe abstract from the paper is:\n\n*The emergence of Large Language Models (LLMs) has unified language  generation tasks and revolutionized human-machine interaction.  However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 
3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism.  This work represents the first attempt at a general-purpose image generation model,  and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\nThis pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).\n\n## Inference\n\nFirst, load the pipeline:\n\n```python\nimport torch\nfrom diffusers import OmniGenPipeline\n\npipe = OmniGenPipeline.from_pretrained(\"Shitao/OmniGen-v1-diffusers\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n```\n\nFor text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image. \nSet the `height` and `width` parameters to generate images of a different size.\n\n```python\nprompt = \"Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. 
Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD.\"\nimage = pipe(\n    prompt=prompt,\n    height=1024,\n    width=1024,\n    guidance_scale=3,\n    generator=torch.Generator(device=\"cpu\").manual_seed(111),\n).images[0]\nimage.save(\"output.png\")\n```\n\nOmniGen supports multimodal inputs. \nWhen the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image. \nIt is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.\n\n```python\nfrom diffusers.utils import load_image\n\n# `pipe` is the OmniGenPipeline loaded above\nprompt = \"<img><|image_1|></img> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola.\"\ninput_images = [load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png\")]\nimage = pipe(\n    prompt=prompt,\n    input_images=input_images,\n    guidance_scale=2,\n    img_guidance_scale=1.6,\n    use_input_image_size_as_output=True,\n    generator=torch.Generator(device=\"cpu\").manual_seed(222),\n).images[0]\nimage.save(\"output.png\")\n```\n\n## OmniGenPipeline\n\n[[autodoc]] OmniGenPipeline\n  - all\n  - __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Pipelines\n\nPipelines provide a simple way to run state-of-the-art diffusion models in inference by bundling all of the necessary components (multiple independently-trained models, schedulers, and processors) into a single end-to-end class. Pipelines are flexible and they can be adapted to use different schedulers or even model components.\n\nAll pipelines are built from the base [`DiffusionPipeline`] class which provides basic functionality for loading, downloading, and saving all the components. Specific pipeline types (for example [`StableDiffusionPipeline`]) loaded with [`~DiffusionPipeline.from_pretrained`] are automatically detected and the pipeline components are loaded and passed to the `__init__` function of the pipeline.\n\n> [!WARNING]\n> You shouldn't use the [`DiffusionPipeline`] class for training. Individual components (for example, [`UNet2DModel`] and [`UNet2DConditionModel`]) of diffusion pipelines are usually trained individually, so we suggest directly working with them instead.\n>\n> <br>\n>\n> Pipelines do not offer any training functionality. You'll notice PyTorch's autograd is disabled by decorating the [`~DiffusionPipeline.__call__`] method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should not be used for training. 
If you're interested in training, please take a look at the [Training](../../training/overview) guides instead!\n\nThe table below lists all the pipelines currently available in 🤗 Diffusers and the tasks they support. Click on a pipeline to view its abstract and published paper.\n\n| Pipeline | Tasks |\n|---|---|\n| [aMUSEd](amused) | text2image |\n| [AnimateDiff](animatediff) | text2video |\n| [Attend-and-Excite](attend_and_excite) | text2image |\n| [AudioLDM](audioldm) | text2audio |\n| [AudioLDM2](audioldm2) | text2audio |\n| [AuraFlow](aura_flow) | text2image |\n| [BLIP Diffusion](blip_diffusion) | text2image |\n| [Bria 3.2](bria_3_2) | text2image |\n| [CogVideoX](cogvideox) | text2video |\n| [Consistency Models](consistency_models) | unconditional image generation |\n| [ControlNet](controlnet) | text2image, image2image, inpainting |\n| [ControlNet with Flux.1](controlnet_flux) | text2image |\n| [ControlNet with Hunyuan-DiT](controlnet_hunyuandit) | text2image |\n| [ControlNet with Stable Diffusion 3](controlnet_sd3) | text2image |\n| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |\n| [ControlNet-XS](controlnetxs) | text2image |\n| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |\n| [Cosmos](cosmos) | text2video, video2video |\n| [Dance Diffusion](dance_diffusion) | unconditional audio generation |\n| [DDIM](ddim) | unconditional image generation |\n| [DDPM](ddpm) | unconditional image generation |\n| [DeepFloyd IF](deepfloyd_if) | text2image, image2image, inpainting, super-resolution |\n| [DiffEdit](diffedit) | inpainting |\n| [DiT](dit) | text2image |\n| [Flux](flux) | text2image |\n| [Hunyuan-DiT](hunyuandit) | text2image |\n| [I2VGen-XL](i2vgenxl) | image2video |\n| [InstructPix2Pix](pix2pix) | image editing |\n| [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation |\n| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |\n| [Kandinsky 3](kandinsky3) | text2image, 
image2image |\n| [Kolors](kolors) | text2image |\n| [Latent Consistency Models](latent_consistency_models) | text2image |\n| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |\n| [Latte](latte) | text2image |\n| [LEDITS++](ledits_pp) | image editing |\n| [Lumina-T2X](lumina) | text2image |\n| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |\n| [MultiDiffusion](panorama) | text2image |\n| [MusicLDM](musicldm) | text2audio |\n| [PAG](pag) | text2image |\n| [Paint by Example](paint_by_example) | inpainting |\n| [PIA](pia) | image2video |\n| [PixArt-α](pixart) | text2image |\n| [PixArt-Σ](pixart_sigma) | text2image |\n| [Self-Attention Guidance](self_attention_guidance) | text2image |\n| [Semantic Guidance](semantic_stable_diffusion) | text2image |\n| [Shap-E](shap_e) | text-to-3D, image-to-3D |\n| [Stable Audio](stable_audio) | text2audio |\n| [Stable Cascade](stable_cascade) | text2image |\n| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |\n| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |\n| [Stable Diffusion XL Turbo](stable_diffusion/sdxl_turbo) | text2image, image2image, inpainting |\n| [Stable unCLIP](stable_unclip) | text2image, image variation |\n| [T2I-Adapter](stable_diffusion/adapter) | text2image |\n| [Text2Video](text_to_video) | text2video, video2video |\n| [Text2Video-Zero](text_to_video_zero) | text2video |\n| [unCLIP](unclip) | text2image, image variation |\n| [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation |\n| [Value-guided planning](value_guided_sampling) | value guided sampling |\n| [Wuerstchen](wuerstchen) | text2image |\n| [VisualCloze](visualcloze) | text2image, image2image, subject driven generation, inpainting, style transfer, image 
restoration, image editing, [depth,normal,edge,pose]2image, [depth,normal,edge,pose]-estimation, virtual try-on, image relighting |\n\n## DiffusionPipeline\n\n[[autodoc]] DiffusionPipeline\n\t- all\n\t- __call__\n\t- device\n\t- to\n\t- components\n\n\n[[autodoc]] pipelines.StableDiffusionMixin.enable_freeu\n\n[[autodoc]] pipelines.StableDiffusionMixin.disable_freeu\n\n## PushToHubMixin\n\n[[autodoc]] utils.PushToHubMixin\n\n## Callbacks\n\n[[autodoc]] callbacks.PipelineCallback\n\n[[autodoc]] callbacks.SDCFGCutoffCallback\n\n[[autodoc]] callbacks.SDXLCFGCutoffCallback\n\n[[autodoc]] callbacks.SDXLControlnetCFGCutoffCallback\n\n[[autodoc]] callbacks.IPAdapterScaleCutoffCallback\n\n[[autodoc]] callbacks.SD3CFGCutoffCallback\n"
  },
  {
    "path": "docs/source/en/api/pipelines/ovis_image.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Ovis-Image\n\n![concepts](https://github.com/AIDC-AI/Ovis-Image/blob/main/docs/imgs/ovis_image_case.png)\n\nOvis-Image is a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints.\n\n[Ovis-Image Technical Report](https://arxiv.org/abs/2511.22982) from Alibaba Group, by Guo-Hua Wang, Liangfu Cao, Tianyu Cui, Minghao Fu, Xiaohao Chen, Pengxin Zhan, Jianshan Zhao, Lan Li, Bowen Fu, Jiaqi Liu, Qing-Guo Chen.\n\nThe abstract from the paper is:\n\n*We introduce Ovis-Image, a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints. Built upon our previous Ovis-U1 framework, Ovis-Image integrates a diffusion-based visual decoder with the stronger Ovis 2.5 multimodal backbone, leveraging a text-centric training pipeline that combines large-scale pre-training with carefully tailored post-training refinements. Despite its compact architecture, Ovis-Image achieves text rendering performance on par with significantly larger open models such as Qwen-Image and approaches closed-source systems like Seedream and GPT4o. 
Crucially, the model remains deployable on a single high-end GPU with moderate memory, narrowing the gap between frontier-level text rendering and practical deployment. Our results indicate that combining a strong multimodal backbone with a carefully designed, text-focused training recipe is sufficient to achieve reliable bilingual text rendering without resorting to oversized or proprietary models.*\n\n**Highlights**: \n\n*   **Strong text rendering at a compact 7B scale**: Ovis-Image is a 7B text-to-image model that delivers text rendering quality comparable to much larger 20B-class systems such as Qwen-Image and competitive with leading closed-source models like GPT4o in text-centric scenarios, while remaining small enough to run on widely accessible hardware.\n*   **High fidelity on text-heavy, layout-sensitive prompts**: The model excels on prompts that demand tight alignment between linguistic content and rendered typography (e.g., posters, banners, logos, UI mockups, infographics), producing legible, correctly spelled, and semantically consistent text across diverse fonts, sizes, and aspect ratios without compromising overall visual quality.\n*   **Efficiency and deployability**: With its 7B parameter budget and streamlined architecture, Ovis-Image fits on a single high-end GPU with moderate memory, supports low-latency interactive use, and scales to batch production serving, bringing near–frontier text rendering to applications where tens-of-billions–parameter models are impractical.\n\n\nThis pipeline was contributed by Ovis-Image Team. 
The original codebase can be found [here](https://github.com/AIDC-AI/Ovis-Image).\n\nAvailable models:\n\n| Model | Recommended dtype |\n|:-----:|:-----------------:|\n| [`AIDC-AI/Ovis-Image-7B`](https://huggingface.co/AIDC-AI/Ovis-Image-7B) | `torch.bfloat16` |\n\nRefer to [this](https://huggingface.co/collections/AIDC-AI/ovis-image) collection for more information.\n\n## OvisImagePipeline\n\n[[autodoc]] OvisImagePipeline\n\t- all\n\t- __call__\n\n## OvisImagePipelineOutput\n\n[[autodoc]] pipelines.ovis_image.pipeline_output.OvisImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/pag.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Perturbed-Attention Guidance\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules.\n\nPAG was introduced in [Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance](https://huggingface.co/papers/2403.17377) by Donghoon Ahn, Hyoungwon Cho, Jaewon Min, Wooseok Jang, Jungwoo Kim, SeonHwa Kim, Hyun Hee Park, Kyong Hwan Jin and Seungryong Kim.\n\nThe abstract from the paper is:\n\n*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. 
In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*\n\nTo use PAG, pass `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression that identifies one or more layers.\n\n- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`\n- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`\n- Partial identifier as a RegEx: `down_blocks.2`, or `attn1`\n- List of identifiers (can be a combination of strings and RegEx): `[\"blocks.1\", \"blocks.(14|20)\", r\"down_blocks\\.(2|3)\"]`\n\n> [!WARNING]\n> Since RegEx is supported for matching layer identifiers, it is crucial to use it correctly; otherwise, the behavior may be unexpected. 
The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while doable, may bypass our basic validation checks and give you unexpected results.\n\n## AnimateDiffPAGPipeline\n[[autodoc]] AnimateDiffPAGPipeline\n  - all\n  - __call__\n\n## HunyuanDiTPAGPipeline\n[[autodoc]] HunyuanDiTPAGPipeline\n  - all\n  - __call__\n\n## KolorsPAGPipeline\n[[autodoc]] KolorsPAGPipeline\n  - all\n  - __call__\n\n## StableDiffusionPAGInpaintPipeline\n[[autodoc]] StableDiffusionPAGInpaintPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionPAGPipeline\n[[autodoc]] StableDiffusionPAGPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionPAGImg2ImgPipeline\n[[autodoc]] StableDiffusionPAGImg2ImgPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionControlNetPAGPipeline\n[[autodoc]] StableDiffusionControlNetPAGPipeline\n\n## StableDiffusionControlNetPAGInpaintPipeline\n[[autodoc]] StableDiffusionControlNetPAGInpaintPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLPAGPipeline\n[[autodoc]] StableDiffusionXLPAGPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLPAGImg2ImgPipeline\n[[autodoc]] StableDiffusionXLPAGImg2ImgPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLPAGInpaintPipeline\n[[autodoc]] StableDiffusionXLPAGInpaintPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLControlNetPAGPipeline\n[[autodoc]] StableDiffusionXLControlNetPAGPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLControlNetPAGImg2ImgPipeline\n[[autodoc]] StableDiffusionXLControlNetPAGImg2ImgPipeline\n\t- all\n\t- __call__\n\n## StableDiffusion3PAGPipeline\n[[autodoc]] StableDiffusion3PAGPipeline\n\t- all\n\t- __call__\n\n## StableDiffusion3PAGImg2ImgPipeline\n[[autodoc]] StableDiffusion3PAGImg2ImgPipeline\n\t- all\n\t- __call__\n\n## PixArtSigmaPAGPipeline\n[[autodoc]] PixArtSigmaPAGPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/paint_by_example.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Paint by Example\n\n[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://huggingface.co/papers/2211.13227) is by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen.\n\nThe abstract from the paper is:\n\n*Language-guided image editing has achieved great success recently. In this paper, for the first time, we investigate exemplar-guided image editing for more precise control. We achieve this goal by leveraging self-supervised training to disentangle and re-organize the source image and the exemplar. However, the naive approach will cause obvious fusing artifacts. We carefully analyze it and propose an information bottleneck and strong augmentations to avoid the trivial solution of directly copying and pasting the exemplar image. Meanwhile, to ensure the controllability of the editing process, we design an arbitrary shape mask for the exemplar image and leverage the classifier-free guidance to increase the similarity to the exemplar image. The whole framework involves a single forward of the diffusion model without any iterative optimization. 
We demonstrate that our method achieves an impressive performance and enables controllable editing on in-the-wild images with high fidelity.*\n\nThe original codebase can be found at [Fantasy-Studio/Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example), and you can try it out in a [demo](https://huggingface.co/spaces/Fantasy-Studio/Paint-by-Example).\n\n## Tips\n\nPaint by Example is supported by the official [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example) checkpoint. The checkpoint is warm-started from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) to inpaint partly masked images conditioned on example and reference images.\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## PaintByExamplePipeline\n[[autodoc]] PaintByExamplePipeline\n\t- all\n\t- __call__\n\n## StableDiffusionPipelineOutput\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/panorama.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# MultiDiffusion\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation](https://huggingface.co/papers/2302.08113) is by Omer Bar-Tal, Lior Yariv, Yaron Lipman, and Tali Dekel.\n\nThe abstract from the paper is:\n\n*Recent advances in text-to-image generation with diffusion models present transformative capabilities in image quality. However, user controllability of the generated image, and fast adaptation to new tasks still remains an open challenge, currently mostly addressed by costly and long re-training and fine-tuning or ad-hoc adaptations to specific image generation tasks. In this work, we present MultiDiffusion, a unified framework that enables versatile and controllable image generation, using a pre-trained text-to-image diffusion model, without any further training or finetuning. 
At the center of our approach is a new generation process, based on an optimization task that binds together multiple diffusion generation processes with a shared set of parameters or constraints. We show that MultiDiffusion can be readily applied to generate high quality and diverse images that adhere to user-provided controls, such as desired aspect ratio (e.g., panorama), and spatial guiding signals, ranging from tight segmentation masks to bounding boxes.*\n\nYou can find additional information about MultiDiffusion on the [project page](https://multidiffusion.github.io/), [original codebase](https://github.com/omerbt/MultiDiffusion), and try it out in a [demo](https://huggingface.co/spaces/weizmannscience/MultiDiffusion).\n\n## Tips\n\nWhen calling [`StableDiffusionPanoramaPipeline`], you can set the `view_batch_size` parameter to a value greater than 1.\nOn high-performance GPUs, this can speed up the generation process at the cost of higher VRAM usage.\n\nTo generate panorama-like images, make sure you pass the `width` parameter accordingly. We recommend a width of 2048, which is the default.\n\nCircular padding removes stitching artifacts in panoramas by ensuring a seamless transition from the rightmost part to the leftmost part. By enabling circular padding (set `circular_padding=True`), the operation applies additional crops after the rightmost point of the image, allowing the model to \"see\" the transition from the rightmost part to the leftmost part. This helps maintain visual consistency in a 360-degree sense and creates a proper \"panorama\" that can be viewed using 360-degree panorama viewers. 
When decoding latents in Stable Diffusion, circular padding is applied to ensure that the decoded latents match in the RGB space.\n\nFor example, without circular padding, there is a stitching artifact (default):\n![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20no_circular_padding.png)\n\nBut with circular padding, the right and the left parts are matching (`circular_padding=True`):\n![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20circular_padding.png)\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableDiffusionPanoramaPipeline\n[[autodoc]] StableDiffusionPanoramaPipeline\n\t- __call__\n\t- all\n\n## StableDiffusionPipelineOutput\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/pia.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Image-to-Video Generation with PIA (Personalized Image Animator)\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n## Overview\n\n[PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models](https://huggingface.co/papers/2312.13964) by Yiming Zhang, Zhening Xing, Yanhong Zeng, Youqing Fang, Kai Chen\n\nRecent advancements in personalized text-to-image (T2I) models have revolutionized content creation, empowering non-experts to generate stunning images with unique styles. While promising, adding realistic motions into these personalized images by text poses significant challenges in preserving distinct styles, high-fidelity details, and achieving motion controllability by text. In this paper, we present PIA, a Personalized Image Animator that excels in aligning with condition images, achieving motion controllability by text, and the compatibility with various personalized T2I models without specific tuning. 
To achieve these goals, PIA builds upon a base T2I model with well-trained temporal alignment layers, allowing for the seamless transformation of any personalized T2I model into an image animation model. A key component of PIA is the introduction of the condition module, which utilizes the condition frame and inter-frame affinity as input to transfer appearance information guided by the affinity hint for individual frame synthesis in the latent space. This design mitigates the challenges of appearance-related image alignment within and allows for a stronger focus on aligning with motion-related guidance.\n\n[Project page](https://pi-animator.github.io/)\n\n## Available Pipelines\n\n| Pipeline | Tasks | Demo |\n|---|---|:---:|\n| [PIAPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pia/pipeline_pia.py) | *Image-to-Video Generation with PIA* | |\n\n## Available checkpoints\n\nMotion Adapter checkpoints for PIA can be found under the [OpenMMLab org](https://huggingface.co/openmmlab/PIA-condition-adapter). These checkpoints are meant to work with any model based on Stable Diffusion 1.5.\n\n## Usage example\n\nPIA works with a MotionAdapter checkpoint and a Stable Diffusion 1.5 model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in the Stable Diffusion UNet. 
In addition to the motion modules, PIA also replaces the input convolution layer of the SD 1.5 UNet model with a 9 channel input convolution layer.\n\nThe following example demonstrates how to use PIA to generate a video from a single image.\n\n```python\nimport torch\nfrom diffusers import (\n    EulerDiscreteScheduler,\n    MotionAdapter,\n    PIAPipeline,\n)\nfrom diffusers.utils import export_to_gif, load_image\n\nadapter = MotionAdapter.from_pretrained(\"openmmlab/PIA-condition-adapter\")\npipe = PIAPipeline.from_pretrained(\"SG161222/Realistic_Vision_V6.0_B1_noVAE\", motion_adapter=adapter, torch_dtype=torch.float16)\n\npipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)\npipe.enable_model_cpu_offload()\npipe.enable_vae_slicing()\n\nimage = load_image(\n    \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true\"\n)\nimage = image.resize((512, 512))\nprompt = \"cat in a field\"\nnegative_prompt = \"wrong white balance, dark, sketches,worst quality,low quality\"\n\ngenerator = torch.Generator(\"cpu\").manual_seed(0)\noutput = pipe(image=image, prompt=prompt, generator=generator)\nframes = output.frames[0]\nexport_to_gif(frames, \"pia-animation.gif\")\n```\n\nHere are some sample outputs:\n\n<table>\n    <tr>\n        <td><center>\n        cat in a field.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-default-output.gif\"\n            alt=\"cat in a field\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\n\n> [!TIP]\n> If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the PIA checkpoints can be sensitive to the beta schedule of the scheduler. 
We recommend setting this to `linear`.\n\n## Using FreeInit\n\n[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://huggingface.co/papers/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.\n\nFreeInit is an effective method that improves the temporal consistency and overall quality of videos generated by video diffusion models without any additional training. It can be applied seamlessly at inference time to PIA, AnimateDiff, ModelScope, VideoCrafter and various other video generation models, and works by iteratively refining the latent-initialization noise. More details can be found in the paper.\n\nThe following example demonstrates the usage of FreeInit.\n\n```python\nimport torch\nfrom diffusers import (\n    DDIMScheduler,\n    MotionAdapter,\n    PIAPipeline,\n)\nfrom diffusers.utils import export_to_gif, load_image\n\nadapter = MotionAdapter.from_pretrained(\"openmmlab/PIA-condition-adapter\")\npipe = PIAPipeline.from_pretrained(\"SG161222/Realistic_Vision_V6.0_B1_noVAE\", motion_adapter=adapter)\n\n# enable FreeInit\n# Refer to the enable_free_init documentation for a full list of configurable parameters\npipe.enable_free_init(method=\"butterworth\", use_fast_sampling=True)\n\n# Memory saving options\npipe.enable_model_cpu_offload()\npipe.enable_vae_slicing()\n\npipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)\nimage = load_image(\n    \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true\"\n)\nimage = image.resize((512, 512))\nprompt = \"cat in a field\"\nnegative_prompt = \"wrong white balance, dark, sketches,worst quality,low quality\"\n\ngenerator = torch.Generator(\"cpu\").manual_seed(0)\n\noutput = pipe(image=image, prompt=prompt, generator=generator)\nframes = output.frames[0]\nexport_to_gif(frames, \"pia-freeinit-animation.gif\")\n```\n\n<table>\n    <tr>\n        <td><center>\n        cat in a field.\n        <br>\n        
<img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pia-freeinit-output-cat.gif\"\n            alt=\"cat in a field\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\n\n> [!WARNING]\n> FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).\n\n## PIAPipeline\n\n[[autodoc]] PIAPipeline\n\t- all\n\t- __call__\n    - enable_freeu\n    - disable_freeu\n    - enable_free_init\n    - disable_free_init\n    - enable_vae_slicing\n    - disable_vae_slicing\n    - enable_vae_tiling\n    - disable_vae_tiling\n\n## PIAPipelineOutput\n\n[[autodoc]] pipelines.pia.PIAPipelineOutput"
  },
  {
    "path": "docs/source/en/api/pipelines/pix2pix.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# InstructPix2Pix\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/papers/2211.09800) is by Tim Brooks, Aleksander Holynski and Alexei A. Efros.\n\nThe abstract from the paper is:\n\n*We propose a method for editing images from human instructions: given an input image and a written instruction that tells the model what to do, our model follows these instructions to edit the image. To obtain training data for this problem, we combine the knowledge of two large pretrained models -- a language model (GPT-3) and a text-to-image model (Stable Diffusion) -- to generate a large dataset of image editing examples. Our conditional diffusion model, InstructPix2Pix, is trained on our generated data, and generalizes to real images and user-written instructions at inference time. Since it performs edits in the forward pass and does not require per example fine-tuning or inversion, our model edits images quickly, in a matter of seconds. 
We show compelling editing results for a diverse collection of input images and written instructions.*\n\nYou can find additional information about InstructPix2Pix on the [project page](https://www.timothybrooks.com/instruct-pix2pix), [original codebase](https://github.com/timothybrooks/instruct-pix2pix), and try it out in a [demo](https://huggingface.co/spaces/timbrooks/instruct-pix2pix).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableDiffusionInstructPix2PixPipeline\n[[autodoc]] StableDiffusionInstructPix2PixPipeline\n\t- __call__\n\t- all\n\t- load_textual_inversion\n\t- load_lora_weights\n\t- save_lora_weights\n\n## StableDiffusionXLInstructPix2PixPipeline\n[[autodoc]] StableDiffusionXLInstructPix2PixPipeline\n\t- __call__\n\t- all\n"
  },
  {
    "path": "docs/source/en/api/pipelines/pixart.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# PixArt-α\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/header_collage.png)\n\n[PixArt-α: Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis](https://huggingface.co/papers/2310.00426) is by Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li.\n\nThe abstract from the paper is:\n\n*The most advanced text-to-image (T2I) models require significant training costs (e.g., millions of GPU hours), seriously hindering the fundamental innovation for the AIGC community while increasing CO2 emissions. This paper introduces PIXART-α, a Transformer-based T2I diffusion model whose image generation quality is competitive with state-of-the-art image generators (e.g., Imagen, SDXL, and even Midjourney), reaching near-commercial application standards. Additionally, it supports high-resolution image synthesis up to 1024px resolution with low training cost, as shown in Figure 1 and 2. 
To achieve this goal, three core designs are proposed: (1) Training strategy decomposition: We devise three distinct training steps that separately optimize pixel dependency, text-image alignment, and image aesthetic quality; (2) Efficient T2I Transformer: We incorporate cross-attention modules into Diffusion Transformer (DiT) to inject text conditions and streamline the computation-intensive class-condition branch; (3) High-informative data: We emphasize the significance of concept density in text-image pairs and leverage a large Vision-Language model to auto-label dense pseudo-captions to assist text-image alignment learning. As a result, PIXART-α's training speed markedly surpasses existing large-scale T2I models, e.g., PIXART-α only takes 10.8% of Stable Diffusion v1.5's training time (675 vs. 6,250 A100 GPU days), saving nearly $300,000 ($26,000 vs. $320,000) and reducing 90% CO2 emissions. Moreover, compared with a larger SOTA model, RAPHAEL, our training cost is merely 1%. Extensive experiments demonstrate that PIXART-α excels in image quality, artistry, and semantic control. We hope PIXART-α will provide new insights to the AIGC community and startups to accelerate building their own high-quality yet low-cost generative models from scratch.*\n\nYou can find the original codebase at [PixArt-alpha/PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) and all the available checkpoints at [PixArt-alpha](https://huggingface.co/PixArt-alpha).\n\nSome notes about this pipeline:\n\n* It uses a Transformer backbone (instead of a UNet) for denoising. As such it has a similar architecture as [DiT](./dit).\n* It was trained using text conditions computed from T5. This aspect makes the pipeline better at following complex text prompts with intricate details.\n* It is good at producing high-resolution images at different aspect ratios. 
To get the best results, the authors recommend some size brackets which can be found [here](https://github.com/PixArt-alpha/PixArt-alpha/blob/08fbbd281ec96866109bdd2cdb75f2f58fb17610/diffusion/data/datasets/utils.py).\n* It rivals the quality of state-of-the-art text-to-image generation systems (as of this writing) such as Stable Diffusion XL, Imagen, and DALL-E 2, while being more efficient than them.\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## Inference with under 8GB GPU VRAM\n\nRun the [`PixArtAlphaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.\n\nFirst, install the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library:\n\n```bash\npip install -U bitsandbytes\n```\n\nThen load the text encoder in 8-bit:\n\n```python\nfrom transformers import T5EncoderModel\nfrom diffusers import PixArtAlphaPipeline\nimport torch\n\ntext_encoder = T5EncoderModel.from_pretrained(\n    \"PixArt-alpha/PixArt-XL-2-1024-MS\",\n    subfolder=\"text_encoder\",\n    load_in_8bit=True,\n    device_map=\"auto\",\n\n)\npipe = PixArtAlphaPipeline.from_pretrained(\n    \"PixArt-alpha/PixArt-XL-2-1024-MS\",\n    text_encoder=text_encoder,\n    transformer=None,\n    device_map=\"auto\"\n)\n```\n\nNow, use the `pipe` to encode a prompt:\n\n```python\nwith torch.no_grad():\n    prompt = \"cute cat\"\n    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)\n```\n\nSince text embeddings have been computed, remove the `text_encoder` and `pipe` from the memory, and free up some GPU VRAM:\n\n```python\nimport gc\n\ndef flush():\n    
gc.collect()\n    torch.cuda.empty_cache()\n\ndel text_encoder\ndel pipe\nflush()\n```\n\nThen compute the latents with the prompt embeddings as inputs:\n\n```python\npipe = PixArtAlphaPipeline.from_pretrained(\n    \"PixArt-alpha/PixArt-XL-2-1024-MS\",\n    text_encoder=None,\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\nlatents = pipe(\n    negative_prompt=None,\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    prompt_attention_mask=prompt_attention_mask,\n    negative_prompt_attention_mask=negative_prompt_attention_mask,\n    num_images_per_prompt=1,\n    output_type=\"latent\",\n).images\n\ndel pipe.transformer\nflush()\n```\n\n> [!TIP]\n> Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.\n\nOnce the latents are computed, pass it off to the VAE to decode into a real image:\n\n```python\nwith torch.no_grad():\n    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]\nimage = pipe.image_processor.postprocess(image, output_type=\"pil\")[0]\nimage.save(\"cat.png\")\n```\n\nBy deleting components you aren't using and flushing the GPU VRAM, you should be able to run [`PixArtAlphaPipeline`] with under 8GB GPU VRAM.\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png)\n\nIf you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e).\n\n> [!WARNING]\n> Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.\n\nWhile loading the `text_encoder`, you set `load_in_8bit` to `True`. 
You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB.\n\n## PixArtAlphaPipeline\n\n[[autodoc]] PixArtAlphaPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/pixart_sigma.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# PixArt-Σ\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/header_collage_sigma.jpg)\n\n[PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation](https://huggingface.co/papers/2403.04692) is Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, and Zhenguo Li.\n\nThe abstract from the paper is:\n\n*In this paper, we introduce PixArt-Σ, a Diffusion Transformer model (DiT) capable of directly generating images at 4K resolution. PixArt-Σ represents a significant advancement over its predecessor, PixArt-α, offering images of markedly higher fidelity and improved alignment with text prompts. A key feature of PixArt-Σ is its training efficiency. Leveraging the foundational pre-training of PixArt-α, it evolves from the ‘weaker’ baseline to a ‘stronger’ model via incorporating higher quality data, a process we term “weak-to-strong training”. The advancements in PixArt-Σ are twofold: (1) High-Quality Training Data: PixArt-Σ incorporates superior-quality image data, paired with more precise and detailed image captions. 
(2) Efficient Token Compression: we propose a novel attention module within the DiT framework that compresses both keys and values, significantly improving efficiency and facilitating ultra-high-resolution image generation. Thanks to these improvements, PixArt-Σ achieves superior image quality and user prompt adherence capabilities with significantly smaller model size (0.6B parameters) than existing text-to-image diffusion models, such as SDXL (2.6B parameters) and SD Cascade (5.1B parameters). Moreover, PixArt-Σ’s capability to generate 4K images supports the creation of high-resolution posters and wallpapers, efficiently bolstering the production of highquality visual content in industries such as film and gaming.*\n\nYou can find the original codebase at [PixArt-alpha/PixArt-sigma](https://github.com/PixArt-alpha/PixArt-sigma) and all the available checkpoints at [PixArt-alpha](https://huggingface.co/PixArt-alpha).\n\nSome notes about this pipeline:\n\n* It uses a Transformer backbone (instead of a UNet) for denoising. As such it has a similar architecture as [DiT](https://hf.co/docs/transformers/model_doc/dit).\n* It was trained using text conditions computed from T5. This aspect makes the pipeline better at following complex text prompts with intricate details.\n* It is good at producing high-resolution images at different aspect ratios. 
To get the best results, the authors recommend some size brackets which can be found [here](https://github.com/PixArt-alpha/PixArt-sigma/blob/master/diffusion/data/datasets/utils.py).\n* It rivals the quality of state-of-the-art text-to-image generation systems (as of this writing) such as PixArt-α, Stable Diffusion XL, Playground V2.0 and DALL-E 3, while being more efficient than them.\n* It shows the ability of generating super high resolution images, such as 2048px or even 4K.\n* It shows that text-to-image models can grow from a weak model to a stronger one through several improvements (VAEs, datasets, and so on.)\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n> [!TIP]\n> You can further improve generation quality by passing the generated image from [`PixArtSigmaPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.\n\n## Inference with under 8GB GPU VRAM\n\nRun the [`PixArtSigmaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. 
Let's walk through a full-fledged example.\n\nFirst, install the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library:\n\n```bash\npip install -U bitsandbytes\n```\n\nThen load the text encoder in 8-bit:\n\n```python\nfrom transformers import T5EncoderModel\nfrom diffusers import PixArtSigmaPipeline\nimport torch\n\ntext_encoder = T5EncoderModel.from_pretrained(\n    \"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS\",\n    subfolder=\"text_encoder\",\n    load_in_8bit=True,\n    device_map=\"auto\",\n)\npipe = PixArtSigmaPipeline.from_pretrained(\n    \"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS\",\n    text_encoder=text_encoder,\n    transformer=None,\n    device_map=\"balanced\"\n)\n```\n\nNow, use the `pipe` to encode a prompt:\n\n```python\nwith torch.no_grad():\n    prompt = \"cute cat\"\n    prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)\n```\n\nSince text embeddings have been computed, remove the `text_encoder` and `pipe` from the memory, and free up some GPU VRAM:\n\n```python\nimport gc\n\ndef flush():\n    gc.collect()\n    torch.cuda.empty_cache()\n\ndel text_encoder\ndel pipe\nflush()\n```\n\nThen compute the latents with the prompt embeddings as inputs:\n\n```python\npipe = PixArtSigmaPipeline.from_pretrained(\n    \"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS\",\n    text_encoder=None,\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\nlatents = pipe(\n    negative_prompt=None,\n    prompt_embeds=prompt_embeds,\n    negative_prompt_embeds=negative_embeds,\n    prompt_attention_mask=prompt_attention_mask,\n    negative_prompt_attention_mask=negative_prompt_attention_mask,\n    num_images_per_prompt=1,\n    output_type=\"latent\",\n).images\n\ndel pipe.transformer\nflush()\n```\n\n> [!TIP]\n> Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.\n\nOnce the latents are computed, pass it off to the VAE to decode into a real 
image:\n\n```python\nwith torch.no_grad():\n    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]\nimage = pipe.image_processor.postprocess(image, output_type=\"pil\")[0]\nimage.save(\"cat.png\")\n```\n\nBy deleting components you aren't using and flushing the GPU VRAM, you should be able to run [`PixArtSigmaPipeline`] with under 8GB GPU VRAM.\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png)\n\nIf you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e).\n\n> [!WARNING]\n> Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.\n\nWhile loading the `text_encoder`, you set `load_in_8bit` to `True`. You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB.\n\n## PixArtSigmaPipeline\n\n[[autodoc]] PixArtSigmaPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/prx.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# PRX\n\n\nPRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.\n\n## Available models\n\nPRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. 
Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.\n\n\n| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |\n|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|\n| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |\n| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |\n| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |\n| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |\n| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |\n| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with 
detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |\n| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |\n| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |\n\nRefer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.\n\n## Loading the pipeline\n\nLoad the pipeline with [`~DiffusionPipeline.from_pretrained`].\n\n```py\nimport torch\nfrom diffusers.pipelines.prx import PRXPipeline\n\n# Load pipeline - VAE and text encoder will be loaded from HuggingFace\npipe = PRXPipeline.from_pretrained(\"Photoroom/prx-512-t2i-sft\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\nprompt = \"A front-facing portrait of a lion in the golden savanna at sunset.\"\nimage = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]\nimage.save(\"prx_output.png\")\n```\n\n### Manual Component Loading\n\nLoad components individually to customize the pipeline, for instance, to use quantized models.\n\n```py\nimport torch\nfrom diffusers.pipelines.prx import PRXPipeline\nfrom diffusers.models import AutoencoderKL, AutoencoderDC\nfrom diffusers.models.transformers.transformer_prx import PRXTransformer2DModel\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom transformers import T5GemmaModel, GemmaTokenizerFast\nfrom diffusers import BitsAndBytesConfig 
as DiffusersBitsAndBytesConfig\nfrom transformers import BitsAndBytesConfig\n\n# Quantization config for diffusers models (transformer, VAE)\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\n# Load transformer\ntransformer = PRXTransformer2DModel.from_pretrained(\n    \"checkpoints/prx-512-t2i-sft\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.bfloat16,\n)\n\n# Load scheduler\nscheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n    \"checkpoints/prx-512-t2i-sft\", subfolder=\"scheduler\"\n)\n\n# Load T5Gemma text encoder, quantized with the transformers BitsAndBytesConfig\ntext_encoder_quant_config = BitsAndBytesConfig(load_in_8bit=True)\nt5gemma_model = T5GemmaModel.from_pretrained(\"google/t5gemma-2b-2b-ul2\",\n                                            quantization_config=text_encoder_quant_config,\n                                            torch_dtype=torch.bfloat16)\ntext_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)\ntokenizer = GemmaTokenizerFast.from_pretrained(\"google/t5gemma-2b-2b-ul2\")\ntokenizer.model_max_length = 256\n\n# Load VAE - choose either Flux VAE or DC-AE\n# Flux VAE\nvae = AutoencoderKL.from_pretrained(\"black-forest-labs/FLUX.1-dev\",\n                                    subfolder=\"vae\",\n                                    quantization_config=quant_config,\n                                    torch_dtype=torch.bfloat16)\n\npipe = PRXPipeline(\n    transformer=transformer,\n    scheduler=scheduler,\n    text_encoder=text_encoder,\n    tokenizer=tokenizer,\n    vae=vae\n)\npipe.to(\"cuda\")\n```\n\n\n## Memory Optimization\n\nFor memory-constrained environments, enable one of the two CPU offloading modes:\n\n```py\nimport torch\nfrom diffusers.pipelines.prx import PRXPipeline\n\npipe = PRXPipeline.from_pretrained(\"Photoroom/prx-512-t2i-sft\", torch_dtype=torch.bfloat16)\npipe.enable_model_cpu_offload()  # Offload components to CPU when not in use\n\n# Or, instead of the call above, use sequential CPU offload for even lower memory\n# pipe.enable_sequential_cpu_offload()\n```\n\n## PRXPipeline\n\n[[autodoc]] PRXPipeline\n  - all\n  - __call__\n\n## PRXPipelineOutput\n\n[[autodoc]] 
pipelines.prx.pipeline_output.PRXPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/qwenimage.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# QwenImage\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nQwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.\n\nQwen-Image comes in the following variants:\n\n| model type | model id |\n|:----------:|:--------:|\n| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |\n| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |\n| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |\n\n> [!TIP]\n> See the [Caching](../../optimization/cache) guide to speed up inference by storing and reusing intermediate outputs.\n\n## LoRA for faster inference\n\nUse a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the\nnumber of steps. 
Refer to the code snippet below:\n\n<details>\n<summary>Code</summary>\n\n```py\nfrom diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler\nimport torch \nimport math\n\nckpt_id = \"Qwen/Qwen-Image\"\n\n# From\n# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10\nscheduler_config = {\n    \"base_image_seq_len\": 256,\n    \"base_shift\": math.log(3),  # We use shift=3 in distillation\n    \"invert_sigmas\": False,\n    \"max_image_seq_len\": 8192,\n    \"max_shift\": math.log(3),  # We use shift=3 in distillation\n    \"num_train_timesteps\": 1000,\n    \"shift\": 1.0,\n    \"shift_terminal\": None,  # set shift_terminal to None\n    \"stochastic_sampling\": False,\n    \"time_shift_type\": \"exponential\",\n    \"use_beta_sigmas\": False,\n    \"use_dynamic_shifting\": True,\n    \"use_exponential_sigmas\": False,\n    \"use_karras_sigmas\": False,\n}\nscheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)\npipe = DiffusionPipeline.from_pretrained(\n    ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16\n).to(\"cuda\")\npipe.load_lora_weights(\n    \"lightx2v/Qwen-Image-Lightning\", weight_name=\"Qwen-Image-Lightning-8steps-V1.0.safetensors\"\n)\n\nprompt = \"a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition.\"\nnegative_prompt = \" \"\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=1024,\n    height=1024,\n    num_inference_steps=8,\n    true_cfg_scale=1.0,\n    generator=torch.manual_seed(0),\n).images[0]\nimage.save(\"qwen_fewsteps.png\")\n```\n\n</details>\n\n> [!TIP]\n> The `guidance_scale` parameter in the pipeline is there to support future guidance-distilled models when they come up. Note that passing `guidance_scale` to the pipeline is ineffective. 
To enable classifier-free guidance, pass both `true_cfg_scale` and `negative_prompt`; even an empty negative prompt like \" \" enables the classifier-free guidance computations.\n\n## Multi-image reference with QwenImageEditPlusPipeline\n\nWith [`QwenImageEditPlusPipeline`], you can provide multiple images as input references.\n\n```py\nimport torch\nfrom PIL import Image\nfrom diffusers import QwenImageEditPlusPipeline\nfrom diffusers.utils import load_image\n\npipe = QwenImageEditPlusPipeline.from_pretrained(\n    \"Qwen/Qwen-Image-Edit-2509\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nimage_1 = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg\")\nimage_2 = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png\")\nimage = pipe(\n    image=[image_1, image_2],\n    prompt='''put the penguin and the cat at a game show called \"Qwen Edit Plus Games\"''',\n    num_inference_steps=50\n).images[0]\n```\n\n## Performance\n\n### torch.compile\n\nUsing `torch.compile` on the transformer provides a ~2.4x speedup (A100 80GB: 4.70s → 1.93s):\n\n```python\nimport torch\nfrom diffusers import QwenImagePipeline\n\npipe = QwenImagePipeline.from_pretrained(\"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16).to(\"cuda\")\npipe.transformer = torch.compile(pipe.transformer)\n\n# First call triggers compilation (~7s overhead)\n# Subsequent calls run ~2.4x faster\nimage = pipe(\"a cat\", num_inference_steps=50).images[0]\n```\n\n### Batched Inference with Variable-Length Prompts\n\nWhen using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. 
This ensures padding tokens do not influence the generated output.\n\n```python\n# CFG with different prompt lengths works correctly\nimage = pipe(\n    prompt=\"A cat\",\n    negative_prompt=\"blurry, low quality, distorted\",\n    true_cfg_scale=3.5,\n    num_inference_steps=50,\n).images[0]\n```\n\nFor detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).\n\n## QwenImagePipeline\n\n[[autodoc]] QwenImagePipeline\n  - all\n  - __call__\n\n## QwenImageImg2ImgPipeline\n\n[[autodoc]] QwenImageImg2ImgPipeline\n  - all\n  - __call__\n\n## QwenImageInpaintPipeline\n\n[[autodoc]] QwenImageInpaintPipeline\n  - all\n  - __call__\n\n## QwenImageEditPipeline\n\n[[autodoc]] QwenImageEditPipeline\n  - all\n  - __call__\n\n## QwenImageEditInpaintPipeline\n\n[[autodoc]] QwenImageEditInpaintPipeline\n  - all\n  - __call__\n\n## QwenImageControlNetPipeline\n\n[[autodoc]] QwenImageControlNetPipeline\n  - all\n  - __call__\n\n## QwenImageEditPlusPipeline\n\n[[autodoc]] QwenImageEditPlusPipeline\n  - all\n  - __call__\n\n## QwenImageLayeredPipeline\n\n[[autodoc]] QwenImageLayeredPipeline\n  - all\n  - __call__\n\n## QwenImagePipelineOutput\n\n[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput"
  },
  {
    "path": "docs/source/en/api/pipelines/sana.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# SanaPipeline\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\n[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.\n\nThe abstract from the paper is:\n\n*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. 
(3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\nThis pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj) and [chenjy2003](https://github.com/chenjy2003). The original codebase can be found [here](https://github.com/NVlabs/Sana). 
The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model).\n\nAvailable models:\n\n| Model | Recommended dtype |\n|:-----:|:-----------------:|\n| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |\n| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |\n| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |\n| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |\n| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |\n| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |\n| [`Efficient-Large-Model/Sana_600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_512px_diffusers) | `torch.float16` |\n\nRefer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) collection for more information.\n\nNote: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype. \n\n> [!TIP]\n> Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `\"fp16\"` for models with recommended dtype as `torch.float16`, and `\"bf16\"` for models with recommended dtype as `torch.bfloat16`. 
By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcast on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained).\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on image quality depending on the model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaPipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline\nfrom transformers import BitsAndBytesConfig, AutoModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = AutoModel.from_pretrained(\n    \"Efficient-Large-Model/Sana_1600M_1024px_diffusers\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = SanaTransformer2DModel.from_pretrained(\n    \"Efficient-Large-Model/Sana_1600M_1024px_diffusers\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = SanaPipeline.from_pretrained(\n    \"Efficient-Large-Model/Sana_1600M_1024px_diffusers\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nprompt = \"a tiny astronaut hatching from an egg on the moon\"\nimage = 
pipeline(prompt).images[0]\nimage.save(\"sana.png\")\n```\n\n## SanaPipeline\n\n[[autodoc]] SanaPipeline\n  - all\n  - __call__\n\n## SanaPAGPipeline\n\n[[autodoc]] SanaPAGPipeline\n  - all\n  - __call__\n\n## SanaPipelineOutput\n\n[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/sana_sprint.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# SANA-Sprint\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han\n\nThe abstract from the paper is:\n\n*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. 
(3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*\n\nThis pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).\n\nAvailable models:\n\n|                                                                    Model                                                                    | Recommended dtype |\n|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|\n| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16`  |\n| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16`  |\n\nRefer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.\n\nNote: The recommended dtype mentioned is for the transformer weights. 
The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype. \n\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on image quality depending on the model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline\nfrom transformers import BitsAndBytesConfig, AutoModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = AutoModel.from_pretrained(\n    \"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.bfloat16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = SanaTransformer2DModel.from_pretrained(\n    \"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.bfloat16,\n)\n\npipeline = SanaSprintPipeline.from_pretrained(\n    \"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.bfloat16,\n    device_map=\"balanced\",\n)\n\nprompt = \"a tiny astronaut hatching from an egg on the moon\"\nimage = pipeline(prompt).images[0]\nimage.save(\"sana.png\")\n```\n\n## Setting 
`max_timesteps`\n\nUsers can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.\n\n## Image to Image \n\nThe [`SanaSprintImg2ImgPipeline`] is a pipeline for image-to-image generation. It takes an input image and a prompt, and generates a new image based on the input image and the prompt.\n\n```py\nimport torch\nfrom diffusers import SanaSprintImg2ImgPipeline\nfrom diffusers.utils.loading_utils import load_image\n\nimage = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png\"\n)\n\npipe = SanaSprintImg2ImgPipeline.from_pretrained(\n    \"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers\", \n    torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\nimage = pipe(\n    prompt=\"a cute pink bear\", \n    image=image, \n    strength=0.5, \n    height=832, \n    width=480\n).images[0]\nimage.save(\"output.png\")\n```\n\n## SanaSprintPipeline\n\n[[autodoc]] SanaSprintPipeline\n  - all\n  - __call__\n\n## SanaSprintImg2ImgPipeline\n\n[[autodoc]] SanaSprintImg2ImgPipeline\n  - all\n  - __call__\n\n\n## SanaPipelineOutput\n\n[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/sana_video.md",
    "content": "<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n# Sana-Video\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\n[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.\n\nThe abstract from the paper is:\n\n*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. 
(2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*\n\nThis pipeline was contributed by the SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).\n\nAvailable models:\n\n| Model | Recommended dtype |\n|:-----:|:-----------------:|\n| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/SANA-Video_2B_480p_diffusers) | `torch.bfloat16` |\n\nRefer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.\n\nNote: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype. 
\n\n\n## Generation Pipelines\n\n<hfoptions id=\"generation pipelines\">\n<hfoption id=\"Text-to-Video\">\n\nThe example below demonstrates how to use the text-to-video pipeline to generate a video using a text description.\n\n```python\nimport torch\nfrom diffusers import SanaVideoPipeline\nfrom diffusers.utils import export_to_video\n\npipe = SanaVideoPipeline.from_pretrained(\n    \"Efficient-Large-Model/SANA-Video_2B_480p_diffusers\", \n    torch_dtype=torch.bfloat16,\n)\npipe.text_encoder.to(torch.bfloat16)\npipe.vae.to(torch.float32)\npipe.to(\"cuda\")\n\nprompt = \"A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.\"\nnegative_prompt = \"A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience.\"\nmotion_scale = 30\nmotion_prompt = f\" motion score: {motion_scale}.\"\nprompt = prompt + motion_prompt\n\nvideo = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=480,\n    width=832,\n    frames=81,\n    guidance_scale=6,\n    num_inference_steps=50,\n    generator=torch.Generator(device=\"cuda\").manual_seed(0),\n).frames[0]\n\nexport_to_video(video, \"sana_video.mp4\", fps=16)\n```\n\n</hfoption>\n<hfoption id=\"Image-to-Video\">\n\nThe example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.\n\n```python\nimport torch\nfrom diffusers import FlowMatchEulerDiscreteScheduler, SanaImageToVideoPipeline\nfrom diffusers.utils import export_to_video, load_image\n\npipe = SanaImageToVideoPipeline.from_pretrained(\n    \"Efficient-Large-Model/SANA-Video_2B_480p_diffusers\",\n    torch_dtype=torch.bfloat16,\n)\npipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)\npipe.vae.to(torch.float32)\npipe.text_encoder.to(torch.bfloat16)\npipe.to(\"cuda\")\n\nimage = 
load_image(\"https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png\")\nprompt = \"A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle.\"\nnegative_prompt = \"A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience.\"\nmotion_scale = 30\nmotion_prompt = f\" motion score: {motion_scale}.\"\nprompt = prompt + motion_prompt\n\nvideo = pipe(\n    image=image,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=480,\n    width=832,\n    frames=81,\n    guidance_scale=6,\n    num_inference_steps=50,\n    generator=torch.Generator(device=\"cuda\").manual_seed(0),\n).frames[0]\n\nexport_to_video(video, \"sana-i2v.mp4\", fps=16)\n```\n\n</hfoption>\n</hfoptions>\n\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. 
The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline\nfrom diffusers.utils import export_to_video\nfrom transformers import BitsAndBytesConfig, AutoModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = AutoModel.from_pretrained(\n    \"Efficient-Large-Model/SANA-Video_2B_480p_diffusers\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = SanaVideoTransformer3DModel.from_pretrained(\n    \"Efficient-Large-Model/SANA-Video_2B_480p_diffusers\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = SanaVideoPipeline.from_pretrained(\n    \"Efficient-Large-Model/SANA-Video_2B_480p_diffusers\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nmodel_score = 30\nprompt = \"Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. 
The camera focuses on his lifted gaze, clear and emotional.\"\nnegative_prompt = \"A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience.\"\nmotion_prompt = f\" motion score: {model_score}.\"\nprompt = prompt + motion_prompt\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=480,\n    width=832,\n    num_frames=81,\n    guidance_scale=6.0,\n    num_inference_steps=50\n).frames[0]\nexport_to_video(output, \"sana-video-output.mp4\", fps=16)\n```\n\n## SanaVideoPipeline\n\n[[autodoc]] SanaVideoPipeline\n  - all\n  - __call__\n\n\n## SanaImageToVideoPipeline\n\n[[autodoc]] SanaImageToVideoPipeline\n  - all\n  - __call__\n\n\n## SanaVideoPipelineOutput\n\n[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/self_attention_guidance.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Self-Attention Guidance\n\n[Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) is by Susung Hong et al.\n\nThe abstract from the paper is:\n\n*Denoising diffusion models (DDMs) have attracted attention for their exceptional generation quality and diversity. This success is largely attributed to the use of class- or text-conditional diffusion guidance methods, such as classifier and classifier-free guidance. In this paper, we present a more comprehensive perspective that goes beyond the traditional guidance methods. From this generalized perspective, we introduce novel condition- and training-free strategies to enhance the quality of generated images. As a simple solution, blur guidance improves the suitability of intermediate samples for their fine-scale information and structures, enabling diffusion models to generate higher quality samples with a moderate guidance scale. Improving upon this, Self-Attention Guidance (SAG) uses the intermediate self-attention maps of diffusion models to enhance their stability and efficacy. 
Specifically, SAG adversarially blurs only the regions that diffusion models attend to at each iteration and guides them accordingly. Our experimental results show that our SAG improves the performance of various diffusion models, including ADM, IDDPM, Stable Diffusion, and DiT. Moreover, combining SAG with conventional guidance methods leads to further improvement.*\n\nYou can find additional information about Self-Attention Guidance on the [project page](https://ku-cvlab.github.io/Self-Attention-Guidance), [original codebase](https://github.com/KU-CVLAB/Self-Attention-Guidance), and try it out in a [demo](https://huggingface.co/spaces/susunghong/Self-Attention-Guidance) or [notebook](https://colab.research.google.com/github/SusungHong/Self-Attention-Guidance/blob/main/SAG_Stable.ipynb).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableDiffusionSAGPipeline\n[[autodoc]] StableDiffusionSAGPipeline\n\t- __call__\n\t- all\n\n## StableDiffusionOutput\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/semantic_stable_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Semantic Guidance\n\nSemantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Text-to-Image Models using Semantic Guidance](https://huggingface.co/papers/2301.12247) and provides strong semantic control over image generation.\nSmall changes to the text prompt usually result in entirely different output images. However, with SEGA a variety of changes to the image are enabled that can be controlled easily and intuitively, while staying true to the original image composition.\n\nThe abstract from the paper is:\n\n*Text-to-image diffusion models have recently received a lot of interest for their astonishing ability to produce high-fidelity images from text only. However, achieving one-shot generation that aligns with the user's intent is nearly impossible, yet small changes to the input prompt often result in very different images. This leaves the user with little semantic control. To put the user in control, we show how to interact with the diffusion process to flexibly steer it along semantic directions. 
This semantic guidance (SEGA) generalizes to any generative architecture using classifier-free guidance. More importantly, it allows for subtle and extensive edits, changes in composition and style, as well as optimizing the overall artistic conception. We demonstrate SEGA's effectiveness on both latent and pixel-based diffusion models such as Stable Diffusion, Paella, and DeepFloyd-IF using a variety of tasks, thus providing strong evidence for its versatility, flexibility, and improvements over existing methods.*\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## SemanticStableDiffusionPipeline\n[[autodoc]] SemanticStableDiffusionPipeline\n\t- all\n\t- __call__\n\n## SemanticStableDiffusionPipelineOutput\n[[autodoc]] pipelines.semantic_stable_diffusion.pipeline_output.SemanticStableDiffusionPipelineOutput\n\t- all\n"
  },
  {
    "path": "docs/source/en/api/pipelines/shap_e.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\nhttp://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Shap-E\n\nThe Shap-E model was proposed in [Shap-E: Generating Conditional 3D Implicit Functions](https://huggingface.co/papers/2305.02463) by Alex Nichol and Heewoo Jun from [OpenAI](https://github.com/openai).\n\nThe abstract from the paper is:\n\n*We present Shap-E, a conditional generative model for 3D assets. Unlike recent work on 3D generative models which produce a single output representation, Shap-E directly generates the parameters of implicit functions that can be rendered as both textured meshes and neural radiance fields. We train Shap-E in two stages: first, we train an encoder that deterministically maps 3D assets into the parameters of an implicit function; second, we train a conditional diffusion model on outputs of the encoder. When trained on a large dataset of paired 3D and text data, our resulting models are capable of generating complex and diverse 3D assets in a matter of seconds. 
When compared to Point-E, an explicit generative model over point clouds, Shap-E converges faster and reaches comparable or better sample quality despite modeling a higher-dimensional, multi-representation output space.*\n\nThe original codebase can be found at [openai/shap-e](https://github.com/openai/shap-e).\n\n> [!TIP]\n> See the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## ShapEPipeline\n[[autodoc]] ShapEPipeline\n\t- all\n\t- __call__\n\n## ShapEImg2ImgPipeline\n[[autodoc]] ShapEImg2ImgPipeline\n\t- all\n\t- __call__\n\n## ShapEPipelineOutput\n[[autodoc]] pipelines.shap_e.pipeline_shap_e.ShapEPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/skyreels_v2.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n<div style=\"float: right;\">\n  <div class=\"flex flex-wrap space-x-1\">\n    <a href=\"https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference\" target=\"_blank\" rel=\"noopener\">\n      <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n    </a>\n  </div>\n</div>\n\n# SkyReels-V2: Infinite-length Film Generative model\n\n[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team from Skywork AI.\n\n*Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. 
To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. 
All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).*\n\nYou can find all the original SkyReels-V2 checkpoints under the [Skywork](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) organization.\n\nThe following SkyReels-V2 models are supported in Diffusers:\n- [SkyReels-V2 DF 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers)\n- [SkyReels-V2 DF 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P-Diffusers)\n- [SkyReels-V2 DF 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P-Diffusers)\n- [SkyReels-V2 T2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P-Diffusers)\n- [SkyReels-V2 T2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P-Diffusers)\n- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)\n- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)\n- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)\n\nThis model was contributed by [M. 
Tolga Cangöz](https://github.com/tolgacangoz).\n\n> [!TIP]\n> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.\n\n### A _Visual_ Demonstration\n\nThe example below has the following parameters:\n\n- `base_num_frames=97`\n- `num_frames=97`\n- `num_inference_steps=30`\n- `ar_step=5`\n- `causal_block_size=5`\n\nWith `vae_scale_factor_temporal=4`, expect `5` blocks of `5` frames each as calculated by:\n\n`num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each`\n\nAnd the maximum context length in the latent space is calculated with `base_num_latent_frames`:\n\n`base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 -> 25//5 = 5 blocks`\n\nAsynchronous Processing Timeline:\n```text\n┌─────────────────────────────────────────────────────────────────┐\n│ Steps:    1    6   11   16   21   26   31   36   41   46   50   │\n│ Block 1: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■]                       │\n│ Block 2:      [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■]                  │\n│ Block 3:           [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■]             │\n│ Block 4:                [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■]        │\n│ Block 5:                     [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■]   │\n└─────────────────────────────────────────────────────────────────┘\n```\n\nFor Long Videos (`num_frames` > `base_num_frames`):\n`base_num_frames` acts as the \"sliding window size\" for processing long videos.\n\nExample: `257`-frame video with `base_num_frames=97`, `overlap_history=17`\n```text\n┌──── Iteration 1 (frames 1-97) ────┐\n│ Processing window: 97 frames      │ → 5 blocks,\n│ Generates: frames 1-97            │   async processing\n└───────────────────────────────────┘\n            ┌────── Iteration 2 (frames 81-177) ──────┐\n            │ Processing window: 97 frames            │\n            │ Overlap: 17 frames (81-97) from prev    │ → 5 blocks,\n            │ Generates: frames 98-177                │   async 
processing\n            └─────────────────────────────────────────┘\n                        ┌────── Iteration 3 (frames 161-257) ──────┐\n                        │ Processing window: 97 frames             │\n                        │ Overlap: 17 frames (161-177) from prev   │ → 5 blocks,\n                        │ Generates: frames 178-257                │   async processing\n                        └──────────────────────────────────────────┘\n```\n\nEach iteration independently runs the asynchronous processing with its own `5` blocks.\n`base_num_frames` controls:\n1. Memory usage (larger window = more VRAM)\n2. Model context length (must match training constraints)\n3. Number of blocks per iteration (`base_num_latent_frames // causal_block_size`)\n\nEach block takes `30` steps to complete denoising.\nBlock N starts at step: `1 + (N-1) x ar_step`\nTotal steps: `30 + (5-1) x 5 = 50` steps\n\n\nSynchronous mode (`ar_step=0`) would process all blocks/frames simultaneously:\n```text\n┌──────────────────────────────────────────────┐\n│ Steps:       1            ...            
30  │\n│ All blocks: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │\n└──────────────────────────────────────────────┘\n```\nTotal steps: `30` steps\n\n\nAn example on how the step matrix is constructed for asynchronous processing:\nGiven the parameters: (`num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5`)\n```\n- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25\n- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,\n                   941, 932, 922, 912, 901, 888, 874, 859, 841, 822,\n                   799, 773, 743, 708, 666, 615, 551, 470, 363, 216]\n```\n\nThe algorithm creates a `50x25` `step_matrix` where:\n```\n- Row 1:  [999×5, 999×5, 999×5, 999×5, 999×5]\n- Row 2:  [995×5, 999×5, 999×5, 999×5, 999×5]\n- Row 3:  [991×5, 999×5, 999×5, 999×5, 999×5]\n- ...\n- Row 7:  [969×5, 995×5, 999×5, 999×5, 999×5]\n- ...\n- Row 21: [799×5, 888×5, 941×5, 975×5, 999×5]\n- ...\n- Row 35: [  0×5, 216×5, 666×5, 822×5, 901×5]\n- ...\n- Row 42: [  0×5,   0×5,   0×5, 551×5, 773×5]\n- ...\n- Row 50: [  0×5,   0×5,   0×5,   0×5, 216×5]\n```\n\nDetailed Row `6` Analysis:\n```\n- step_matrix[5]:      [ 975×5,  999×5,   999×5,   999×5,   999×5]\n- step_index[5]:       [   6×5,    1×5,     0×5,     0×5,     0×5]\n- step_update_mask[5]: [True×5, True×5, False×5, False×5, False×5]\n- valid_interval[5]:   (0, 25)\n```\n\nKey Pattern: Block `i` lags behind Block `i-1` by exactly `ar_step=5` timesteps, creating the\nstaggered \"diffusion forcing\" effect where later blocks condition on cleaner earlier blocks.\n\n\n### Text-to-Video Generation\n\nThe example below demonstrates how to generate a video from text.\n\n<hfoptions id=\"T2V usage\">\n<hfoption id=\"T2V memory\">\n\nRefer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.\n\nFrom the original repo:\n>You can use --ar_step 5 to enable asynchronous inference. 
When asynchronous inference, --causal_block_size 5 is recommended while it is not supposed to be set for synchronous generation... Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance.\n\n```py\nimport torch\nfrom diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler\nfrom diffusers.utils import export_to_video\n\n\nmodel_id = \"Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers\"\nvae = AutoModel.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\n\npipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(\n    model_id,\n    vae=vae,\n    torch_dtype=torch.bfloat16,\n)\npipeline.to(\"cuda\")\nflow_shift = 8.0  # 8.0 for T2V, 5.0 for I2V\npipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)\n\nprompt = \"A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. 
The kitchen is cozy, with sunlight streaming through the window.\"\n\noutput = pipeline(\n    prompt=prompt,\n    num_inference_steps=30,\n    height=544,  # 720 for 720P\n    width=960,   # 1280 for 720P\n    num_frames=97,\n    base_num_frames=97,  # 121 for 720P\n    ar_step=5,  # Controls asynchronous inference (0 for synchronous mode)\n    causal_block_size=5,  # Number of frames in each block for asynchronous processing\n    overlap_history=None,  # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations\n    addnoise_condition=20,  # Improves consistency in long video generation\n).frames[0]\nexport_to_video(output, \"video.mp4\", fps=24, quality=8)\n```\n\n</hfoption>\n</hfoptions>\n\n### First-Last-Frame-to-Video Generation\n\nThe example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.\n\n<hfoptions id=\"FLF2V usage\">\n<hfoption id=\"usage\">\n\n```python\nimport numpy as np\nimport torch\nimport torchvision.transforms.functional as TF\nfrom diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline, UniPCMultistepScheduler\nfrom diffusers.utils import export_to_video, load_image\n\n\nmodel_id = \"Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers\"\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\npipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(\n    model_id, vae=vae, torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\nflow_shift = 5.0  # 8.0 for T2V, 5.0 for I2V\npipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)\n\nfirst_frame = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png\")\nlast_frame = 
load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png\")\n\ndef aspect_ratio_resize(image, pipeline, max_area=720 * 1280):\n    aspect_ratio = image.height / image.width\n    mod_value = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size[1]\n    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\n    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\n    image = image.resize((width, height))\n    return image, height, width\n\ndef center_crop_resize(image, height, width):\n    # Calculate resize ratio to match first frame dimensions\n    resize_ratio = max(width / image.width, height / image.height)\n\n    # Resize the image\n    width = round(image.width * resize_ratio)\n    height = round(image.height * resize_ratio)\n    size = [width, height]\n    image = TF.center_crop(image, size)\n\n    return image, height, width\n\nfirst_frame, height, width = aspect_ratio_resize(first_frame, pipeline)\nif last_frame.size != first_frame.size:\n    last_frame, _, _ = center_crop_resize(last_frame, height, width)\n\nprompt = \"CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. 
The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective.\"\n\noutput = pipeline(\n    image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0\n).frames[0]\nexport_to_video(output, \"video.mp4\", fps=24, quality=8)\n```\n\n</hfoption>\n</hfoptions>\n\n\n### Video-to-Video Generation\n\n<hfoptions id=\"V2V usage\">\n<hfoption id=\"usage\">\n\n`SkyReelsV2DiffusionForcingVideoToVideoPipeline` extends a given video.\n\n```python\nimport numpy as np\nimport torch\nimport torchvision.transforms.functional as TF\nfrom diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPipeline, UniPCMultistepScheduler\nfrom diffusers.utils import export_to_video, load_video\n\n\nmodel_id = \"Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers\"\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\npipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(\n    model_id, vae=vae, torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\nflow_shift = 5.0  # 8.0 for T2V, 5.0 for I2V\npipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)\n\nvideo = load_video(\"input_video.mp4\")\n\nprompt = \"CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. 
The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective.\"\n\noutput = pipeline(\n    video=video, prompt=prompt, height=720, width=1280, guidance_scale=5.0, overlap_history=17,\n    num_inference_steps=30, num_frames=257, base_num_frames=121,\n    # Optionally pass ar_step=5 and causal_block_size=5 to enable asynchronous inference.\n).frames[0]\nexport_to_video(output, \"video.mp4\", fps=24, quality=8)\n# Total frames will be the number of frames of the given video + 257\n```\n\n</hfoption>\n</hfoptions>\n\n## Notes\n\n- SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`].\n\n`SkyReelsV2Pipeline` and `SkyReelsV2ImageToVideoPipeline` are also available without the Diffusion Forcing framework applied.\n\n\n## SkyReelsV2DiffusionForcingPipeline\n\n[[autodoc]] SkyReelsV2DiffusionForcingPipeline\n  - all\n  - __call__\n\n## SkyReelsV2DiffusionForcingImageToVideoPipeline\n\n[[autodoc]] SkyReelsV2DiffusionForcingImageToVideoPipeline\n  - all\n  - __call__\n\n## SkyReelsV2DiffusionForcingVideoToVideoPipeline\n\n[[autodoc]] SkyReelsV2DiffusionForcingVideoToVideoPipeline\n  - all\n  - __call__\n\n## SkyReelsV2Pipeline\n\n[[autodoc]] SkyReelsV2Pipeline\n  - all\n  - __call__\n\n## SkyReelsV2ImageToVideoPipeline\n\n[[autodoc]] SkyReelsV2ImageToVideoPipeline\n  - all\n  - __call__\n\n## SkyReelsV2PipelineOutput\n\n[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_audio.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Audio\n\nStable Audio was proposed in [Stable Audio Open](https://huggingface.co/papers/2407.14358) by Zach Evans et al. It takes a text prompt as input and predicts the corresponding sound or music sample.\n\nStable Audio Open generates variable-length (up to 47s) stereo audio at 44.1kHz from text prompts. It comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder.\n\nStable Audio is trained on a corpus of around 48k audio recordings, where around 47k are from Freesound and the rest are from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train the autoencoder and the DiT.\n\nThe abstract of the paper is the following:\n*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. 
Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.*\n\nThis pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). The original codebase can be found at [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools).\n\n## Tips\n\nWhen constructing a prompt, keep in mind:\n\n* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, \"high quality\" or \"clear\") and make the prompt context specific where possible (e.g. \"melodic techno with a fast beat and synths\" works better than \"techno\").\n* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of \"low quality, average quality\".\n\nDuring inference:\n\n* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; more steps give higher quality audio at the expense of slower inference.\n* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on audio quality depending on the model.\n\nRefer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. 
The example below demonstrates how to load a quantized [`StableAudioPipeline`] for inference with bitsandbytes.\n\n```py\nimport soundfile as sf\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, StableAudioDiTModel, StableAudioPipeline\nfrom transformers import BitsAndBytesConfig, T5EncoderModel\n\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = T5EncoderModel.from_pretrained(\n    \"stabilityai/stable-audio-open-1.0\",\n    subfolder=\"text_encoder\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = StableAudioDiTModel.from_pretrained(\n    \"stabilityai/stable-audio-open-1.0\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = StableAudioPipeline.from_pretrained(\n    \"stabilityai/stable-audio-open-1.0\",\n    text_encoder=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\n# Seed the generator for reproducible results\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n\nprompt = \"The sound of a hammer hitting a wooden surface.\"\nnegative_prompt = \"Low quality.\"\naudio = pipeline(\n    prompt,\n    negative_prompt=negative_prompt,\n    num_inference_steps=200,\n    audio_end_in_s=10.0,\n    num_waveforms_per_prompt=3,\n    generator=generator,\n).audios\n\n# Save the best-ranked waveform as (num_samples, num_channels)\noutput = audio[0].T.float().cpu().numpy()\nsf.write(\"hammer.wav\", output, pipeline.vae.sampling_rate)\n```\n\n\n## StableAudioPipeline\n[[autodoc]] StableAudioPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_cascade.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Cascade\n\nThis model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture, and its main\ndifference from other models like Stable Diffusion is that it works in a much smaller latent space. Why is this\nimportant? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.\nHow small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being\nencoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a\n1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the\nhighly compressed latent space. Previous versions of this architecture achieved a 16x cost reduction over Stable\nDiffusion 1.5.\n\nTherefore, this kind of model is well suited for use cases where efficiency is important. Furthermore, all known extensions\nlike finetuning, LoRA, ControlNet, IP-Adapter, LCM etc. 
are possible with this method as well.\n\nThe original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade).\n\n## Model Overview\nStable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images,\nhence the name \"Stable Cascade\".\n\nStages A and B are used to compress images, similar to the job of the VAE in Stable Diffusion.\nHowever, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a\nspatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves\na compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the\nimage. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible\nfor generating the small 24 x 24 latents given a text prompt.\n\nThe Stage C model operates on the small 24 x 24 latents and denoises the latents conditioned on text prompts. The model is also the largest component in the Cascade pipeline and is meant to be used with the `StableCascadePriorPipeline`.\n\nThe Stage B and Stage A models are used with the `StableCascadeDecoderPipeline` and are responsible for generating the final image given the small 24 x 24 latents.\n\n> [!WARNING]\n> There are some restrictions on data types that can be used with the Stable Cascade models. The official checkpoints for the `StableCascadePriorPipeline` do not support the `torch.float16` data type. Please use `torch.bfloat16` instead.\n>\n> In order to use the `torch.bfloat16` data type with the `StableCascadeDecoderPipeline` you need to have PyTorch 2.2.0 or higher installed. 
This also means that using the `StableCascadeCombinedPipeline` with `torch.bfloat16` requires PyTorch 2.2.0 or higher, since it calls the `StableCascadeDecoderPipeline` internally.\n>\n> If it is not possible to install PyTorch 2.2.0 or higher in your environment, the `StableCascadeDecoderPipeline` can be used on its own with the `torch.float16` data type. You can download the full precision or `bf16` variant weights for the pipeline and cast the weights to `torch.float16`.\n\n## Usage example\n\n```python\nimport torch\nfrom diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline\n\nprompt = \"an image of a shiba inu, donning a spacesuit and helmet\"\nnegative_prompt = \"\"\n\nprior = StableCascadePriorPipeline.from_pretrained(\"stabilityai/stable-cascade-prior\", variant=\"bf16\", torch_dtype=torch.bfloat16)\ndecoder = StableCascadeDecoderPipeline.from_pretrained(\"stabilityai/stable-cascade\", variant=\"bf16\", torch_dtype=torch.float16)\n\nprior.enable_model_cpu_offload()\nprior_output = prior(\n    prompt=prompt,\n    height=1024,\n    width=1024,\n    negative_prompt=negative_prompt,\n    guidance_scale=4.0,\n    num_images_per_prompt=1,\n    num_inference_steps=20\n)\n\ndecoder.enable_model_cpu_offload()\ndecoder_output = decoder(\n    image_embeddings=prior_output.image_embeddings.to(torch.float16),\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    guidance_scale=0.0,\n    output_type=\"pil\",\n    num_inference_steps=10\n).images[0]\ndecoder_output.save(\"cascade.png\")\n```\n\n## Using the Lite Versions of the Stage B and Stage C models\n\n```python\nimport torch\nfrom diffusers import (\n    StableCascadeDecoderPipeline,\n    StableCascadePriorPipeline,\n    StableCascadeUNet,\n)\n\nprompt = \"an image of a shiba inu, donning a spacesuit and helmet\"\nnegative_prompt = \"\"\n\nprior_unet = StableCascadeUNet.from_pretrained(\"stabilityai/stable-cascade-prior\", subfolder=\"prior_lite\")\ndecoder_unet = 
StableCascadeUNet.from_pretrained(\"stabilityai/stable-cascade\", subfolder=\"decoder_lite\")\n\nprior = StableCascadePriorPipeline.from_pretrained(\"stabilityai/stable-cascade-prior\", prior=prior_unet)\ndecoder = StableCascadeDecoderPipeline.from_pretrained(\"stabilityai/stable-cascade\", decoder=decoder_unet)\n\nprior.enable_model_cpu_offload()\nprior_output = prior(\n    prompt=prompt,\n    height=1024,\n    width=1024,\n    negative_prompt=negative_prompt,\n    guidance_scale=4.0,\n    num_images_per_prompt=1,\n    num_inference_steps=20\n)\n\ndecoder.enable_model_cpu_offload()\ndecoder_output = decoder(\n    image_embeddings=prior_output.image_embeddings,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    guidance_scale=0.0,\n    output_type=\"pil\",\n    num_inference_steps=10\n).images[0]\ndecoder_output.save(\"cascade.png\")\n```\n\n## Loading original checkpoints with `from_single_file`\n\nLoading the original format checkpoints is supported via `from_single_file` method in the StableCascadeUNet.\n\n```python\nimport torch\nfrom diffusers import (\n    StableCascadeDecoderPipeline,\n    StableCascadePriorPipeline,\n    StableCascadeUNet,\n)\n\nprompt = \"an image of a shiba inu, donning a spacesuit and helmet\"\nnegative_prompt = \"\"\n\nprior_unet = StableCascadeUNet.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-cascade/resolve/main/stage_c_bf16.safetensors\",\n    torch_dtype=torch.bfloat16\n)\ndecoder_unet = StableCascadeUNet.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors\",\n    torch_dtype=torch.bfloat16\n)\n\nprior = StableCascadePriorPipeline.from_pretrained(\"stabilityai/stable-cascade-prior\", prior=prior_unet, torch_dtype=torch.bfloat16)\ndecoder = StableCascadeDecoderPipeline.from_pretrained(\"stabilityai/stable-cascade\", decoder=decoder_unet, torch_dtype=torch.bfloat16)\n\nprior.enable_model_cpu_offload()\nprior_output = prior(\n    
prompt=prompt,\n    height=1024,\n    width=1024,\n    negative_prompt=negative_prompt,\n    guidance_scale=4.0,\n    num_images_per_prompt=1,\n    num_inference_steps=20\n)\n\ndecoder.enable_model_cpu_offload()\ndecoder_output = decoder(\n    image_embeddings=prior_output.image_embeddings,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    guidance_scale=0.0,\n    output_type=\"pil\",\n    num_inference_steps=10\n).images[0]\ndecoder_output.save(\"cascade-single-file.png\")\n```\n\n## Uses\n\n### Direct Use\n\nThe model is intended for research purposes for now. Possible research areas and tasks include\n\n- Research on generative models.\n- Safe deployment of models which have the potential to generate harmful content.\n- Probing and understanding the limitations and biases of generative models.\n- Generation of artworks and use in design and other artistic processes.\n- Applications in educational or creative tools.\n\nExcluded uses are described below.\n\n### Out-of-Scope Use\n\nThe model was not trained to be factual or true representations of people or events,\nand therefore using the model to generate such content is out-of-scope for the abilities of this model.\nThe model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).\n\n## Limitations and Bias\n\n### Limitations\n- Faces and people in general may not be generated properly.\n- The autoencoding part of the model is lossy.\n\n\n## StableCascadeCombinedPipeline\n\n[[autodoc]] StableCascadeCombinedPipeline\n\t- all\n\t- __call__\n\n## StableCascadePriorPipeline\n\n[[autodoc]] StableCascadePriorPipeline\n\t- all\n\t- __call__\n\n## StableCascadePriorPipelineOutput\n\n[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput\n\n## StableCascadeDecoderPipeline\n\n[[autodoc]] StableCascadeDecoderPipeline\n\t- all\n\t- __call__\n\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/adapter.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# T2I-Adapter\n\n[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.08453) by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie.\n\nUsing the pretrained models we can provide control images (for example, a depth map) to control Stable Diffusion text-to-image generation so that it follows the structure of the depth image and fills in the details.\n\nThe abstract of the paper is the following:\n\n*The incredible generative ability of large-scale text-to-image (T2I) models has demonstrated strong power of learning complex structures and meaningful semantics. However, relying solely on text prompts cannot fully take advantage of the knowledge learned by the model, especially when flexible and accurate controlling (e.g., color and structure) is needed. In this paper, we aim to ``dig out\" the capabilities that T2I models have implicitly learned, and then explicitly use them to control the generation more granularly. Specifically, we propose to learn simple and lightweight T2I-Adapters to align internal knowledge in T2I models with external control signals, while freezing the original large T2I models. 
In this way, we can train various adapters according to different conditions, achieving rich control and editing effects in the color and structure of the generation results. Further, the proposed T2I-Adapters have attractive properties of practical value, such as composability and generalization ability. Extensive experiments demonstrate that our T2I-Adapter has promising generation quality and a wide range of applications.*\n\nThis model was contributed by the community contributor [HimariO](https://github.com/HimariO) ❤️ .\n\n## StableDiffusionAdapterPipeline\n\n[[autodoc]] StableDiffusionAdapterPipeline\n    - all\n    - __call__\n    - enable_attention_slicing\n    - disable_attention_slicing\n    - enable_vae_slicing\n    - disable_vae_slicing\n    - enable_xformers_memory_efficient_attention\n    - disable_xformers_memory_efficient_attention\n\n## StableDiffusionXLAdapterPipeline\n\n[[autodoc]] StableDiffusionXLAdapterPipeline\n    - all\n    - __call__\n    - enable_attention_slicing\n    - disable_attention_slicing\n    - enable_vae_slicing\n    - disable_vae_slicing\n    - enable_xformers_memory_efficient_attention\n    - disable_xformers_memory_efficient_attention\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/depth2img.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Depth-to-image\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nThe Stable Diffusion model can also infer depth based on an image using [MiDaS](https://github.com/isl-org/MiDaS). This allows you to pass a text prompt and an initial image to condition the generation of new images as well as a `depth_map` to preserve the image structure.\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n>\n> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!\n\n## StableDiffusionDepth2ImgPipeline\n\n[[autodoc]] StableDiffusionDepth2ImgPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\t- load_textual_inversion\n\t- load_lora_weights\n\t- save_lora_weights\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/gligen.md",
    "content": "<!--Copyright 2025 The GLIGEN Authors and The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# GLIGEN (Grounded Language-to-Image Generation)\n\nThe GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.\n\nThe abstract from the [paper](https://huggingface.co/papers/2301.07093) is:\n\n*Large-scale text-to-image diffusion models have made amazing advances. 
However, the status quo is to use text input alone, which can impede controllability. In this work, we propose GLIGEN, Grounded-Language-to-Image Generation, a novel approach that builds upon and extends the functionality of existing pre-trained text-to-image diffusion models by enabling them to also be conditioned on grounding inputs. To preserve the vast concept knowledge of the pre-trained model, we freeze all of its weights and inject the grounding information into new trainable layers via a gated mechanism. Our model achieves open-world grounded text2img generation with caption and bounding box condition inputs, and the grounding ability generalizes well to novel spatial configurations and concepts. GLIGEN’s zeroshot performance on COCO and LVIS outperforms existing supervised layout-to-image baselines by a large margin.*\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality and how to reuse pipeline components efficiently!\n>\n> If you want to use one of the official checkpoints for a task, explore the [gligen](https://huggingface.co/gligen) Hub organizations!\n\n[`StableDiffusionGLIGENPipeline`] was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful) and [`StableDiffusionGLIGENTextImagePipeline`] was contributed by [Nguyễn Công Tú Anh](https://github.com/tuanh123789).\n\n## StableDiffusionGLIGENPipeline\n\n[[autodoc]] StableDiffusionGLIGENPipeline\n\t- all\n\t- __call__\n\t- enable_vae_slicing\n\t- disable_vae_slicing\n\t- enable_vae_tiling\n\t- disable_vae_tiling\n\t- enable_model_cpu_offload\n\t- prepare_latents\n\t- enable_fuser\n\n## StableDiffusionGLIGENTextImagePipeline\n\n[[autodoc]] StableDiffusionGLIGENTextImagePipeline\n\t- all\n\t- __call__\n\t- enable_vae_slicing\n\t- disable_vae_slicing\n\t- enable_vae_tiling\n\t- disable_vae_tiling\n\t- 
enable_model_cpu_offload\n\t- prepare_latents\n\t- enable_fuser\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/image_variation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Image variation\n\nThe Stable Diffusion model can also generate variations from an input image. It uses a fine-tuned version of a Stable Diffusion model by [Justin Pinkney](https://www.justinpinkney.com/) from [Lambda](https://lambdalabs.com/).\n\nThe original codebase can be found at [LambdaLabsML/lambda-diffusers](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) and additional official checkpoints for image variation can be found at [lambdalabs/sd-image-variations-diffusers](https://huggingface.co/lambdalabs/sd-image-variations-diffusers).\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](./overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n\n## StableDiffusionImageVariationPipeline\n\n[[autodoc]] StableDiffusionImageVariationPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/img2img.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Image-to-image\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nThe Stable Diffusion model can also be applied to image-to-image generation by passing a text prompt and an initial image to condition the generation of new images.\n\nThe [`StableDiffusionImg2ImgPipeline`] uses the diffusion-denoising mechanism proposed in [SDEdit: Guided Image Synthesis and Editing with Stochastic Differential Equations](https://huggingface.co/papers/2108.01073) by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan Zhu, Stefano Ermon.\n\nThe abstract from the paper is:\n\n*Guided image synthesis enables everyday users to create and edit photo-realistic images with minimum effort. The key challenge is balancing faithfulness to the user input (e.g., hand-drawn colored strokes) and realism of the synthesized image. Existing GAN-based methods attempt to achieve such balance using either conditional GANs or GAN inversions, which are challenging and often require additional training data or loss functions for individual applications. 
To address these issues, we introduce a new image synthesis and editing method, Stochastic Differential Editing (SDEdit), based on a diffusion model generative prior, which synthesizes realistic images by iteratively denoising through a stochastic differential equation (SDE). Given an input image with user guide of any type, SDEdit first adds noise to the input, then subsequently denoises the resulting image through the SDE prior to increase its realism. SDEdit does not require task-specific training or inversions and can naturally achieve the balance between realism and faithfulness. SDEdit significantly outperforms state-of-the-art GAN-based methods by up to 98.09% on realism and 91.72% on overall satisfaction scores, according to a human perception study, on multiple tasks, including stroke-based image synthesis and editing as well as image compositing.*\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n\n## StableDiffusionImg2ImgPipeline\n\n[[autodoc]] StableDiffusionImg2ImgPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\t- load_textual_inversion\n\t- from_single_file\n\t- load_lora_weights\n\t- save_lora_weights\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/inpaint.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Inpainting\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nThe Stable Diffusion model can also be applied to inpainting which lets you edit specific parts of an image by providing a mask and a text prompt using Stable Diffusion.\n\n## Tips\n\nIt is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such\nas [stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting). 
Default\ntext-to-image Stable Diffusion checkpoints, such as\n[stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are also compatible but they might be less performant.\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n>\n> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!\n\n## StableDiffusionInpaintPipeline\n\n[[autodoc]] StableDiffusionInpaintPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\t- load_textual_inversion\n\t- load_lora_weights\n\t- save_lora_weights\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/latent_upscale.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Latent upscaler\n\nThe Stable Diffusion latent upscaler model was created by [Katherine Crowson](https://github.com/crowsonkb/k-diffusion) in collaboration with [Stability AI](https://stability.ai/). It is used to enhance the output image resolution by a factor of 2 (see this demo [notebook](https://colab.research.google.com/drive/1o1qYJcFeywzCIdkfKJy7cTpgZTCM2EI4) for a demonstration of the original implementation).\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n>\n> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!\n\n## StableDiffusionLatentUpscalePipeline\n\n[[autodoc]] StableDiffusionLatentUpscalePipeline\n\t- all\n\t- __call__\n\t- enable_sequential_cpu_offload\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md",
    "content": "<!--Copyright 2025 The Intel Labs Team Authors and HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Text-to-(RGB, depth)\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nLDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://huggingface.co/papers/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal. LDM3D generates an image and a depth map from a given text prompt unlike the existing text-to-image diffusion models such as [Stable Diffusion](./overview) which only generates an image. With almost the same number of parameters, LDM3D achieves to create a latent space that can compress both the RGB images and the depth maps.\n\nTwo checkpoints are available for use:\n- [ldm3d-original](https://huggingface.co/Intel/ldm3d). The original checkpoint used in the [paper](https://huggingface.co/papers/2305.10853)\n- [ldm3d-4c](https://huggingface.co/Intel/ldm3d-4c). 
The new version of LDM3D, which uses 4-channel inputs instead of 6-channel inputs and is fine-tuned on higher-resolution images.\n\n\nThe abstract from the paper is:\n\n*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n\n## StableDiffusionLDM3DPipeline\n\n[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline\n\t- all\n\t- __call__\n\n\n## LDM3DPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput\n\t- all\n\t- __call__\n\n# Upscaler\n\n[LDM3D-VR](https://huggingface.co/papers/2311.03226) is an extended version of LDM3D.\n\nThe abstract from the paper is:\n\n*Latent diffusion models have proven to be state-of-the-art in the creation and manipulation of visual outputs. 
However, as far as we know, the generation of depth maps jointly with RGB is still limited. We introduce LDM3D-VR, a suite of diffusion models targeting virtual reality development that includes LDM3D-pano and LDM3D-SR. These models enable the generation of panoramic RGBD based on textual prompts and the upscaling of low-resolution inputs to high-resolution RGBD, respectively. Our models are fine-tuned from existing pretrained models on datasets containing panoramic/high-resolution RGB images, depth maps and captions. Both models are evaluated in comparison to existing related methods*\n\nTwo checkpoints are available for use:\n- [ldm3d-pano](https://huggingface.co/Intel/ldm3d-pano). This checkpoint enables the generation of panoramic images and must be used with the StableDiffusionLDM3DPipeline.\n- [ldm3d-sr](https://huggingface.co/Intel/ldm3d-sr). This checkpoint enables the upscaling of RGB and depth images. It can be used in cascade after the original LDM3D pipeline using the StableDiffusionUpscaleLDM3DPipeline from the community pipelines.\n\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Diffusion pipelines\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nStable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). Latent diffusion applies the diffusion process over a lower dimensional latent space to reduce memory and compute complexity. This specific type of diffusion model was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.\n\nStable Diffusion is trained on 512x512 images from a subset of the LAION-5B dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. 
With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs.\n\nFor more details about how Stable Diffusion works and how it differs from the base latent diffusion model, take a look at the Stability AI [announcement](https://stability.ai/blog/stable-diffusion-announcement) and our own [blog post](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) for more technical details.\n\nYou can find the original codebase for Stable Diffusion v1.0 at [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) and Stable Diffusion v2.0 at [Stability-AI/stablediffusion](https://github.com/Stability-AI/stablediffusion) as well as their original scripts for various tasks. Additional official checkpoints for the different Stable Diffusion versions and tasks can be found on the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations. Explore these organizations to find the best checkpoint for your use-case!\n\nThe table below summarizes the available Stable Diffusion pipelines, their supported tasks, and an interactive demo:\n\n<div class=\"flex justify-center\">\n    <div class=\"rounded-xl border border-gray-200\">\n    <table class=\"min-w-full divide-y-2 divide-gray-200 bg-white text-sm\">\n        <thead>\n        <tr>\n            <th class=\"px-4 py-2 font-medium text-gray-900 text-left\">\n            Pipeline\n            </th>\n            <th class=\"px-4 py-2 font-medium text-gray-900 text-left\">\n            Supported tasks\n            </th>\n            <th class=\"px-4 py-2 font-medium text-gray-900 text-left\">\n            🤗 Space\n            </th>\n        </tr>\n        </thead>\n        <tbody class=\"divide-y divide-gray-200\">\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./text2img\">StableDiffusion</a>\n            </td>\n            <td class=\"px-4 py-2 
text-gray-700\">text-to-image</td>\n            <td class=\"px-4 py-2\"><a href=\"https://huggingface.co/spaces/stabilityai/stable-diffusion\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./img2img\">StableDiffusionImg2Img</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">image-to-image</td>\n            <td class=\"px-4 py-2\"><a href=\"https://huggingface.co/spaces/huggingface/diffuse-the-rest\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./inpaint\">StableDiffusionInpaint</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">inpainting</td>\n            <td class=\"px-4 py-2\"><a href=\"https://huggingface.co/spaces/stable-diffusion-v1-5/stable-diffusion-inpainting\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./depth2img\">StableDiffusionDepth2Img</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">depth-to-image</td>\n            <td class=\"px-4 py-2\"><a href=\"https://huggingface.co/spaces/radames/stable-diffusion-depth2img\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./image_variation\">StableDiffusionImageVariation</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">image variation</td>\n            <td class=\"px-4 py-2\"><a 
href=\"https://huggingface.co/spaces/lambdalabs/stable-diffusion-image-variations\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./stable_diffusion_safe\">StableDiffusionPipelineSafe</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">filtered text-to-image</td>\n            <td class=\"px-4 py-2\"><a href=\"https://huggingface.co/spaces/AIML-TUDA/unsafe-vs-safe-stable-diffusion\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./stable_diffusion_2\">StableDiffusion2</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">text-to-image, inpainting, depth-to-image, super-resolution</td>\n            <td class=\"px-4 py-2\"><a href=\"https://huggingface.co/spaces/stabilityai/stable-diffusion\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./stable_diffusion_xl\">StableDiffusionXL</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">text-to-image, image-to-image</td>\n            <td class=\"px-4 py-2\"><a href=\"https://huggingface.co/spaces/RamAnanth1/stable-diffusion-xl\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./latent_upscale\">StableDiffusionLatentUpscale</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">super-resolution</td>\n            <td class=\"px-4 py-2\"><a 
href=\"https://huggingface.co/spaces/huggingface-projects/stable-diffusion-latent-upscaler\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./upscale\">StableDiffusionUpscale</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">super-resolution</td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./ldm3d_diffusion\">StableDiffusionLDM3D</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">text-to-rgb, text-to-depth, text-to-pano</td>\n            <td class=\"px-4 py-2\"><a href=\"https://huggingface.co/spaces/r23/ldm3d-space\"><img src=\"https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue\"/></a>\n            </td>\n        </tr>\n        <tr>\n            <td class=\"px-4 py-2 text-gray-700\">\n            <a href=\"./ldm3d_diffusion\">StableDiffusionUpscaleLDM3D</a>\n            </td>\n            <td class=\"px-4 py-2 text-gray-700\">ldm3d super-resolution</td>\n        </tr>\n        </tbody>\n    </table>\n    </div>\n</div>\n\n## Tips\n\nTo help you get the most out of the Stable Diffusion pipelines, here are a few tips for improving performance and usability. These tips are applicable to all Stable Diffusion pipelines.\n\n### Explore tradeoff between speed and quality\n\n[`StableDiffusionPipeline`] uses the [`PNDMScheduler`] by default, but 🤗 Diffusers provides many other schedulers (some of which are faster or output better quality) that are compatible. 
For example, if you want to use the [`EulerDiscreteScheduler`] instead of the default:\n\n```py\nfrom diffusers import StableDiffusionPipeline, EulerDiscreteScheduler\n\npipeline = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\")\npipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)\n\n# or\neuler_scheduler = EulerDiscreteScheduler.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"scheduler\")\npipeline = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", scheduler=euler_scheduler)\n```\n\n### Reuse pipeline components to save memory\n\nTo save memory and use the same components across multiple pipelines, use the `.components` method to avoid loading weights into RAM more than once.\n\n```py\nfrom diffusers import (\n    StableDiffusionPipeline,\n    StableDiffusionImg2ImgPipeline,\n    StableDiffusionInpaintPipeline,\n)\n\ntext2img = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\")\nimg2img = StableDiffusionImg2ImgPipeline(**text2img.components)\ninpaint = StableDiffusionInpaintPipeline(**text2img.components)\n\n# now you can use text2img(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline\n```\n\n### Create web demos using `gradio`\n\nThe Stable Diffusion pipelines are automatically supported in [Gradio](https://github.com/gradio-app/gradio/), a library that makes creating beautiful and user-friendly machine learning apps on the web a breeze. First, make sure you have Gradio installed:\n\n```sh\npip install -U gradio\n```\n\nThen, create a web demo around any Stable Diffusion-based pipeline. 
For example, you can create an image generation pipeline in a single line of code with Gradio's [`Interface.from_pipeline`](https://www.gradio.app/docs/interface#interface-from-pipeline) function:\n\n```py\nfrom diffusers import StableDiffusionPipeline\nimport gradio as gr\n\npipe = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\")\n\ngr.Interface.from_pipeline(pipe).launch()\n```\n\nwhich opens an intuitive drag-and-drop interface in your browser:\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gradio-panda.png)\n\nSimilarly, you could create a demo for an image-to-image pipeline with:\n\n```py\nfrom diffusers import StableDiffusionImg2ImgPipeline\nimport gradio as gr\n\n\npipe = StableDiffusionImg2ImgPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\n\ngr.Interface.from_pipeline(pipe).launch()\n```\n\nBy default, the web demo runs on a local server. If you'd like to share it with others, you can generate a temporary public\nlink by setting `share=True` in `launch()`. Or, you can host your demo on [Hugging Face Spaces](https://huggingface.co/spaces) for a permanent link."
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/sdxl_turbo.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# SDXL Turbo\n\nStable Diffusion XL (SDXL) Turbo was proposed in [Adversarial Diffusion Distillation](https://stability.ai/research/adversarial-diffusion-distillation) by Axel Sauer, Dominik Lorenz, Andreas Blattmann, and Robin Rombach.\n\nThe abstract from the paper is:\n\n*We introduce Adversarial Diffusion Distillation (ADD), a novel training approach that efficiently samples large-scale foundational image diffusion models in just 1–4 steps while maintaining high image quality. We use score distillation to leverage large-scale off-the-shelf image diffusion models as a teacher signal in combination with an adversarial loss to ensure high image fidelity even in the low-step regime of one or two sampling steps. Our analyses show that our model clearly outperforms existing few-step methods (GANs,Latent Consistency Models) in a single step and reaches the performance of state-of-the-art diffusion models (SDXL) in only four steps. ADD is the first method to unlock single-step, real-time image synthesis with foundation models.*\n\n## Tips\n\n- SDXL Turbo uses the exact same architecture as [SDXL](./stable_diffusion_xl), which means it also has the same API. 
Please refer to the [SDXL](./stable_diffusion_xl) API reference for more details.\n- Disable guidance when using SDXL Turbo by setting `guidance_scale=0.0`.\n- SDXL Turbo should use `timestep_spacing='trailing'` for the scheduler and between 1 and 4 inference steps.\n- SDXL Turbo has been trained to generate images of size 512x512.\n- SDXL Turbo is open-access but not open-source, meaning that you might have to buy a model license in order to use it for commercial applications. Make sure to read the [official model card](https://huggingface.co/stabilityai/sdxl-turbo) to learn more.\n\n> [!TIP]\n> To learn how to use SDXL Turbo for various tasks, how to optimize performance, and other usage examples, take a look at the [SDXL Turbo](../../../using-diffusers/sdxl_turbo) guide.\n>\n> Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the official base and refiner model checkpoints!\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Diffusion 2\n\nStable Diffusion 2 is a text-to-image _latent diffusion_ model built upon the work of the original [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release), and it was led by Robin Rombach and Katherine Crowson from [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/).\n\n*The Stable Diffusion 2.0 release includes robust text-to-image models trained using a brand new text encoder (OpenCLIP), developed by LAION with support from Stability AI, which greatly improves the quality of the generated images compared to earlier V1 releases. 
The text-to-image models in this release can generate images with default resolutions of both 512x512 pixels and 768x768 pixels.\nThese models are trained on an aesthetic subset of the [LAION-5B dataset](https://laion.ai/blog/laion-5b/) created by the DeepFloyd team at Stability AI, which is then further filtered to remove adult content using [LAION’s NSFW filter](https://openreview.net/forum?id=M3Y74vmsMcY).*\n\nFor more details about how Stable Diffusion 2 works and how it differs from the original Stable Diffusion, please refer to the official [announcement post](https://stability.ai/blog/stable-diffusion-v2-release).\n\nThe architecture of Stable Diffusion 2 is more or less identical to the original [Stable Diffusion model](./text2img) so check out its API documentation to learn how to use Stable Diffusion 2. We recommend using the [`DPMSolverMultistepScheduler`] as it gives a reasonable speed/quality trade-off and can be run with as few as 20 steps.\n\nStable Diffusion 2 is available for tasks like text-to-image, inpainting, super-resolution, and depth-to-image:\n\n| Task                    | Repository                                                                                                    |\n|-------------------------|---------------------------------------------------------------------------------------------------------------|\n| text-to-image (512x512) | [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base)             |\n| text-to-image (768x768) | [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2)                       |\n| inpainting              | [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) |\n| super-resolution        | [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler)               |\n| depth-to-image          | 
[stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth)           |\n\nHere are some examples for how to use Stable Diffusion 2 for each task:\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n>\n> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!\n\n## Text-to-image\n\n```py\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\nimport torch\n\nrepo_id = \"stabilityai/stable-diffusion-2-base\"\npipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, variant=\"fp16\")\n\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\npipe = pipe.to(\"cuda\")\n\nprompt = \"High quality photo of an astronaut riding a horse in space\"\nimage = pipe(prompt, num_inference_steps=25).images[0]\nimage\n```\n\n## Inpainting\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = load_image(img_url).resize((512, 512))\nmask_image = load_image(mask_url).resize((512, 512))\n\nrepo_id = \"stabilityai/stable-diffusion-2-inpainting\"\npipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, variant=\"fp16\")\n\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\npipe = pipe.to(\"cuda\")\n\nprompt = \"Face of a yellow cat, high 
resolution, sitting on a park bench\"\nimage = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=25).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n## Super-resolution\n\n```py\nfrom diffusers import StableDiffusionUpscalePipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\n\n# load model and scheduler\nmodel_id = \"stabilityai/stable-diffusion-x4-upscaler\"\npipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)\npipeline = pipeline.to(\"cuda\")\n\n# let's download an image\nurl = \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png\"\nlow_res_img = load_image(url)\nlow_res_img = low_res_img.resize((128, 128))\nprompt = \"a white cat\"\nupscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]\nmake_image_grid([low_res_img.resize((512, 512)), upscaled_image.resize((512, 512))], rows=1, cols=2)\n```\n\n## Depth-to-image\n\n```py\nimport torch\nfrom diffusers import StableDiffusionDepth2ImgPipeline\nfrom diffusers.utils import load_image, make_image_grid\n\npipe = StableDiffusionDepth2ImgPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-depth\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\n\nurl = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\ninit_image = load_image(url)\nprompt = \"two tigers\"\nnegative_prompt = \"bad, deformed, ugly, bad anatomy\"\nimage = pipe(prompt=prompt, image=init_image, negative_prompt=negative_prompt, strength=0.7).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Diffusion 3\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\nStable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://huggingface.co/papers/2403.03206) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.\n\nThe abstract from the paper is:\n\n*Diffusion models create data from noise by inverting the forward paths of data towards noise and have emerged as a powerful generative modeling technique for high-dimensional, perceptual data such as images and videos. Rectified flow is a recent generative model formulation that connects data and noise in a straight line. Despite its better theoretical properties and conceptual simplicity, it is not yet decisively established as standard practice. In this work, we improve existing noise sampling techniques for training rectified flow models by biasing them towards perceptually relevant scales. 
Through a large-scale study, we demonstrate the superior performance of this approach compared to established diffusion formulations for high-resolution text-to-image synthesis. Additionally, we present a novel transformer-based architecture for text-to-image generation that uses separate weights for the two modalities and enables a bidirectional flow of information between image and text tokens, improving text comprehension, typography, and human preference ratings. We demonstrate that this architecture follows predictable scaling trends and correlates lower validation loss to improved text-to-image synthesis as measured by various metrics and human evaluations.*\n\n\n## Usage Example\n\n_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate._\n\nUse the command below to log in:\n\n```bash\nhf auth login\n```\n\n> [!TIP]\n> The SD3 pipeline uses three text encoders to generate an image. Model offloading is necessary in order for it to run on most commodity hardware. Please use the `torch.float16` data type for additional memory savings.\n\n```python\nimport torch\nfrom diffusers import StableDiffusion3Pipeline\n\npipe = StableDiffusion3Pipeline.from_pretrained(\"stabilityai/stable-diffusion-3-medium-diffusers\", torch_dtype=torch.float16)\npipe.to(\"cuda\")\n\nimage = pipe(\n    prompt=\"a photo of a cat holding a sign that says hello world\",\n    negative_prompt=\"\",\n    num_inference_steps=28,\n    height=1024,\n    width=1024,\n    guidance_scale=7.0,\n).images[0]\n\nimage.save(\"sd3_hello_world.png\")\n```\n\n**Note:** Stable Diffusion 3.5 can also be run using the SD3 pipeline, and all mentioned optimizations and techniques apply to it as well. 
In total there are three official models in the SD3 family:\n- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)\n- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)\n- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)\n\n## Image Prompting with IP-Adapters\n\nAn IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need:\n\n- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.\n- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.\n- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection. \n\nIP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. 
A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally.\n\n```python\nimport torch\nfrom PIL import Image\n\nfrom diffusers import StableDiffusion3Pipeline\nfrom transformers import SiglipVisionModel, SiglipImageProcessor\n\nimage_encoder_id = \"google/siglip-so400m-patch14-384\"\nip_adapter_id = \"InstantX/SD3.5-Large-IP-Adapter\"\n\nfeature_extractor = SiglipImageProcessor.from_pretrained(\n    image_encoder_id,\n    torch_dtype=torch.float16\n)\nimage_encoder = SiglipVisionModel.from_pretrained(\n    image_encoder_id,\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipe = StableDiffusion3Pipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3.5-large\",\n    torch_dtype=torch.float16,\n    feature_extractor=feature_extractor,\n    image_encoder=image_encoder,\n).to(\"cuda\")\n\npipe.load_ip_adapter(ip_adapter_id)\npipe.set_ip_adapter_scale(0.6)\n\nref_img = Image.open(\"image.jpg\").convert('RGB')\n\nimage = pipe(\n    width=1024,\n    height=1024,\n    prompt=\"a cat\",\n    negative_prompt=\"lowres, low quality, worst quality\",\n    num_inference_steps=24,\n    guidance_scale=5.0,\n    ip_adapter_image=ref_img\n).images[0]\n\nimage.save(\"result.jpg\")\n```\n\n<div class=\"justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd3_ip_adapter_example.png\"/>\n    <figcaption class=\"mt-2 text-sm text-center text-gray-500\">IP-Adapter examples with prompt \"a cat\"</figcaption>\n</div>\n\n\n> [!TIP]\n> Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.\n\n\n## Memory Optimisations for SD3\n\nSD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. 
The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.\n\n### Running Inference with Model Offloading\n\nThe most basic memory optimization available in Diffusers allows you to offload the components of the model to CPU during inference in order to save memory, while seeing a slight increase in inference latency. Model offloading will only move a model component onto the GPU when it needs to be executed, while keeping the remaining components on the CPU.\n\n```python\nimport torch\nfrom diffusers import StableDiffusion3Pipeline\n\npipe = StableDiffusion3Pipeline.from_pretrained(\"stabilityai/stable-diffusion-3-medium-diffusers\", torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n\nimage = pipe(\n    prompt=\"a photo of a cat holding a sign that says hello world\",\n    negative_prompt=\"\",\n    num_inference_steps=28,\n    height=1024,\n    width=1024,\n    guidance_scale=7.0,\n).images[0]\n\nimage.save(\"sd3_hello_world.png\")\n```\n\n### Dropping the T5 Text Encoder during Inference\n\nRemoving the memory-intensive 4.7B parameter T5-XXL text encoder during inference can significantly decrease the memory requirements for SD3 with only a slight loss in performance.\n\n```python\nimport torch\nfrom diffusers import StableDiffusion3Pipeline\n\npipe = StableDiffusion3Pipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3-medium-diffusers\",\n    text_encoder_3=None,\n    tokenizer_3=None,\n    torch_dtype=torch.float16\n)\npipe.to(\"cuda\")\n\nimage = pipe(\n    prompt=\"a photo of a cat holding a sign that says hello world\",\n    negative_prompt=\"\",\n    num_inference_steps=28,\n    height=1024,\n    width=1024,\n    guidance_scale=7.0,\n).images[0]\n\nimage.save(\"sd3_hello_world-no-T5.png\")\n```\n\n### Using a Quantized Version of the T5 Text Encoder\n\nWe can leverage the `bitsandbytes` library to load and quantize the T5-XXL text encoder to 8-bit precision. 
This allows you to keep using all three text encoders while only slightly impacting performance.\n\nFirst install the `bitsandbytes` library.\n\n```shell\npip install bitsandbytes\n```\n\nThen load the T5-XXL model using the `BitsAndBytesConfig`.\n\n```python\nimport torch\nfrom diffusers import StableDiffusion3Pipeline\nfrom transformers import T5EncoderModel, BitsAndBytesConfig\n\nquantization_config = BitsAndBytesConfig(load_in_8bit=True)\n\nmodel_id = \"stabilityai/stable-diffusion-3-medium-diffusers\"\ntext_encoder = T5EncoderModel.from_pretrained(\n    model_id,\n    subfolder=\"text_encoder_3\",\n    quantization_config=quantization_config,\n)\npipe = StableDiffusion3Pipeline.from_pretrained(\n    model_id,\n    text_encoder_3=text_encoder,\n    device_map=\"balanced\",\n    torch_dtype=torch.float16\n)\n\nimage = pipe(\n    prompt=\"a photo of a cat holding a sign that says hello world\",\n    negative_prompt=\"\",\n    num_inference_steps=28,\n    height=1024,\n    width=1024,\n    guidance_scale=7.0,\n).images[0]\n\nimage.save(\"sd3_hello_world-8bit-T5.png\")\n```\n\nYou can find the end-to-end script [here](https://gist.github.com/sayakpaul/82acb5976509851f2db1a83456e504f1).\n\n## Performance Optimizations for SD3\n\n### Using Torch Compile to Speed Up Inference\n\nUsing compiled components in the SD3 pipeline can speed up inference by as much as 4X. 
The following code snippet demonstrates how to compile the Transformer and VAE components of the SD3 pipeline.\n\n```python\nimport torch\nfrom diffusers import StableDiffusion3Pipeline\n\ntorch.set_float32_matmul_precision(\"high\")\n\ntorch._inductor.config.conv_1x1_as_mm = True\ntorch._inductor.config.coordinate_descent_tuning = True\ntorch._inductor.config.epilogue_fusion = False\ntorch._inductor.config.coordinate_descent_check_all_directions = True\n\npipe = StableDiffusion3Pipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3-medium-diffusers\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipe.set_progress_bar_config(disable=True)\n\npipe.transformer.to(memory_format=torch.channels_last)\npipe.vae.to(memory_format=torch.channels_last)\n\npipe.transformer = torch.compile(pipe.transformer, mode=\"max-autotune\", fullgraph=True)\npipe.vae.decode = torch.compile(pipe.vae.decode, mode=\"max-autotune\", fullgraph=True)\n\n# Warm Up\nprompt = \"a photo of a cat holding a sign that says hello world\"\nfor _ in range(3):\n    _ = pipe(prompt=prompt, generator=torch.manual_seed(1))\n\n# Run Inference\nimage = pipe(prompt=prompt, generator=torch.manual_seed(1)).images[0]\nimage.save(\"sd3_hello_world.png\")\n```\n\nCheck out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).\n\n## Quantization\n\nQuantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on image quality depending on the model.\n\nRefer to the [Quantization](../../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. 
The example below demonstrates how to load a quantized [`StableDiffusion3Pipeline`] for inference with bitsandbytes.\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SD3Transformer2DModel, StableDiffusion3Pipeline\nfrom transformers import BitsAndBytesConfig, T5EncoderModel\n\n# quantize the T5-XXL text encoder (the pipeline's third text encoder) to 8-bit\nquant_config = BitsAndBytesConfig(load_in_8bit=True)\ntext_encoder_8bit = T5EncoderModel.from_pretrained(\n    \"stabilityai/stable-diffusion-3.5-large\",\n    subfolder=\"text_encoder_3\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\n# quantize the transformer to 8-bit\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = SD3Transformer2DModel.from_pretrained(\n    \"stabilityai/stable-diffusion-3.5-large\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\npipeline = StableDiffusion3Pipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3.5-large\",\n    text_encoder_3=text_encoder_8bit,\n    transformer=transformer_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n)\n\nprompt = \"a tiny astronaut hatching from an egg on the moon\"\nimage = pipeline(prompt, num_inference_steps=28, guidance_scale=7.0).images[0]\nimage.save(\"sd3.png\")\n```\n\n## Using Long Prompts with the T5 Text Encoder\n\nBy default, the T5 Text Encoder prompt uses a maximum sequence length of `256`. This can be adjusted by setting the `max_sequence_length` to accept fewer or more tokens. Keep in mind that longer sequences require additional resources and result in longer generation times, such as during batch inference.\n\n```python\nprompt = \"A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. 
However, instead of the usual grey skin, the creature’s body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree.  As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight\"\n\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=\"\",\n    num_inference_steps=28,\n    guidance_scale=4.5,\n    max_sequence_length=512,\n).images[0]\n```\n\n### Sending a different prompt to the T5 Text Encoder\n\nYou can send a different prompt to the CLIP Text Encoders and the T5 Text Encoder to prevent the prompt from being truncated by the CLIP Text Encoders and to improve generation.\n\n> [!TIP]\n> The prompt with the CLIP Text Encoders is still truncated to the 77 token limit.\n\n```python\nprompt = \"A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. A river of warm, melted butter, pancake-like foliage in the background, a towering pepper mill standing in for a tree.\"\n\nprompt_3 = \"A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creature’s body resembles a golden-brown, crispy waffle fresh off the griddle. 
The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree.  As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight\"\n\nimage = pipe(\n    prompt=prompt,\n    prompt_3=prompt_3,\n    negative_prompt=\"\",\n    num_inference_steps=28,\n    guidance_scale=4.5,\n    max_sequence_length=512,\n).images[0]\n```\n\n## Tiny AutoEncoder for Stable Diffusion 3\n\nTiny AutoEncoder for Stable Diffusion (TAESD3) is a tiny distilled version of Stable Diffusion 3's VAE by [Ollin Boer Bohan](https://github.com/madebyollin/taesd) that can decode [`StableDiffusion3Pipeline`] latents almost instantly.\n\nTo use with Stable Diffusion 3:\n\n```python\nimport torch\nfrom diffusers import StableDiffusion3Pipeline, AutoencoderTiny\n\npipe = StableDiffusion3Pipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3-medium-diffusers\", torch_dtype=torch.float16\n)\npipe.vae = AutoencoderTiny.from_pretrained(\"madebyollin/taesd3\", torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\n\nprompt = \"slice of delicious New York-style berry cheesecake\"\nimage = pipe(prompt, num_inference_steps=25).images[0]\nimage.save(\"cheesecake.png\")\n```\n\n## Loading the original checkpoints via `from_single_file`\n\nThe `SD3Transformer2DModel` and `StableDiffusion3Pipeline` classes support loading the original checkpoints via the `from_single_file` method. 
This method allows you to load the original checkpoint files that were used to train the models.\n\n## Loading the original checkpoints for the `SD3Transformer2DModel`\n\n```python\nfrom diffusers import SD3Transformer2DModel\n\nmodel = SD3Transformer2DModel.from_single_file(\"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors\")\n```\n\n## Loading the single checkpoint for the `StableDiffusion3Pipeline`\n\n### Loading the single file checkpoint without T5\n\n```python\nimport torch\nfrom diffusers import StableDiffusion3Pipeline\n\npipe = StableDiffusion3Pipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors\",\n    torch_dtype=torch.float16,\n    text_encoder_3=None\n)\npipe.enable_model_cpu_offload()\n\nimage = pipe(\"a picture of a cat holding a sign that says hello world\").images[0]\nimage.save('sd3-single-file.png')\n```\n\n### Loading the single file checkpoint with T5\n\n> [!TIP]\n> The following example loads a checkpoint stored in an 8-bit floating point format, which requires PyTorch 2.3 or later.\n\n```python\nimport torch\nfrom diffusers import StableDiffusion3Pipeline\n\npipe = StableDiffusion3Pipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors\",\n    torch_dtype=torch.float16,\n)\npipe.enable_model_cpu_offload()\n\nimage = pipe(\"a picture of a cat holding a sign that says hello world\").images[0]\nimage.save('sd3-single-file-t5-fp8.png')\n```\n\n### Loading the single file checkpoint for the Stable Diffusion 3.5 Transformer Model\n\n```python\nimport torch\nfrom diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline\n\ntransformer = SD3Transformer2DModel.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large.safetensors\",\n    
torch_dtype=torch.bfloat16,\n)\npipe = StableDiffusion3Pipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3.5-large\",\n    transformer=transformer,\n    torch_dtype=torch.bfloat16,\n)\npipe.enable_model_cpu_offload()\nimage = pipe(\"a cat holding a sign that says hello world\").images[0]\nimage.save(\"sd35.png\")\n```\n\n## StableDiffusion3Pipeline\n\n[[autodoc]] StableDiffusion3Pipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Safe Stable Diffusion\n\nSafe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105) and mitigates inappropriate degeneration from Stable Diffusion models because they're trained on unfiltered web-crawled datasets. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content.\n\nThe abstract from the paper is:\n\n*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). 
Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*\n\n## Tips\n\nUse the `safety_concept` property of [`StableDiffusionPipelineSafe`] to check and edit the current safety concept:\n\n```python\n>>> from diffusers import StableDiffusionPipelineSafe\n\n>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(\"AIML-TUDA/stable-diffusion-safe\")\n>>> pipeline.safety_concept\n'an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity, bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child abuse, brutality, cruelty'\n```\nFor each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`].\n\nThere are 4 configurations (`SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONG`, and `SafetyConfig.MAX`) that can be applied:\n\n```python\n>>> from diffusers import StableDiffusionPipelineSafe\n>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig\n\n>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(\"AIML-TUDA/stable-diffusion-safe\")\n>>> prompt = \"the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. 
leyendecker\"\n>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)\n```\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n\n## StableDiffusionPipelineSafe\n\n[[autodoc]] StableDiffusionPipelineSafe\n\t- all\n\t- __call__\n\n## StableDiffusionSafePipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Diffusion XL\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n  <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\">\n</div>\n\nStable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.\n\nThe abstract from the paper is:\n\n*We present SDXL, a latent diffusion model for text-to-image synthesis. Compared to previous versions of Stable Diffusion, SDXL leverages a three times larger UNet backbone: The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. We design multiple novel conditioning schemes and train SDXL on multiple aspect ratios. We also introduce a refinement model which is used to improve the visual fidelity of samples generated by SDXL using a post-hoc image-to-image technique. 
We demonstrate that SDXL shows drastically improved performance compared the previous versions of Stable Diffusion and achieves results competitive with those of black-box state-of-the-art image generators.*\n\n## Tips\n\n- Using SDXL with a DPM++ scheduler for less than 50 steps is known to produce [visual artifacts](https://github.com/huggingface/diffusers/issues/5433) because the solver becomes numerically unstable. To fix this issue, take a look at this [PR](https://github.com/huggingface/diffusers/pull/5541) which recommends for ODE/SDE solvers:\n\t- set `use_karras_sigmas=True` or `lu_lambdas=True` to improve image quality\n\t- set `euler_at_final=True` if you're using a solver with uniform step sizes (DPM++2M or DPM++2M SDE)\n- Most SDXL checkpoints work best with an image size of 1024x1024. Image sizes of 768x768 and 512x512 are also supported, but the results aren't as good. Anything below 512x512 is not recommended and likely won't be for default checkpoints like [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).\n- SDXL can pass a different prompt for each of the text encoders it was trained on. 
We can even pass different parts of the same prompt to the text encoders.\n- SDXL output images can be improved by making use of a refiner model in an image-to-image setting.\n- SDXL offers `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to negatively condition the model on image resolution and cropping parameters.\n\n> [!TIP]\n> To learn how to use SDXL for various tasks, how to optimize performance, and other usage examples, take a look at the [Stable Diffusion XL](../../../using-diffusers/sdxl) guide.\n>\n> Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the official base and refiner model checkpoints!\n\n## StableDiffusionXLPipeline\n\n[[autodoc]] StableDiffusionXLPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLImg2ImgPipeline\n\n[[autodoc]] StableDiffusionXLImg2ImgPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLInpaintPipeline\n\n[[autodoc]] StableDiffusionXLInpaintPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/svd.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Video Diffusion\n\nStable Video Diffusion was proposed in [Stable Video Diffusion: Scaling Latent Video Diffusion Models to Large Datasets](https://hf.co/papers/2311.15127) by Andreas Blattmann, Tim Dockhorn, Sumith Kulal, Daniel Mendelevitch, Maciej Kilian, Dominik Lorenz, Yam Levi, Zion English, Vikram Voleti, Adam Letts, Varun Jampani, Robin Rombach.\n\nThe abstract from the paper is:\n\n*We present Stable Video Diffusion - a latent video diffusion model for high-resolution, state-of-the-art text-to-video and image-to-video generation. Recently, latent diffusion models trained for 2D image synthesis have been turned into generative video models by inserting temporal layers and finetuning them on small, high-quality video datasets. However, training methods in the literature vary widely, and the field has yet to agree on a unified strategy for curating video data. In this paper, we identify and evaluate three different stages for successful training of video LDMs: text-to-image pretraining, video pretraining, and high-quality video finetuning. Furthermore, we demonstrate the necessity of a well-curated pretraining dataset for generating high-quality videos and present a systematic curation process to train a strong base model, including captioning and filtering strategies. 
We then explore the impact of finetuning our base model on high-quality data and train a text-to-video model that is competitive with closed-source video generation. We also show that our base model provides a powerful motion representation for downstream tasks such as image-to-video generation and adaptability to camera motion-specific LoRA modules. Finally, we demonstrate that our model provides a strong multi-view 3D-prior and can serve as a base to finetune a multi-view diffusion model that jointly generates multiple views of objects in a feedforward fashion, outperforming image-based methods at a fraction of their compute budget. We release code and model weights at this https URL.*\n\n> [!TIP]\n> To learn how to use Stable Video Diffusion, take a look at the [Stable Video Diffusion](../../../using-diffusers/svd) guide.\n>\n> <br>\n>\n> Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the [base](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) and [extended frame](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) checkpoints!\n\n## Tips\n\nVideo generation is memory-intensive and one way to reduce your memory usage is to set `enable_forward_chunking` on the pipeline's UNet so you don't run the entire feedforward layer at once. Breaking it up into chunks in a loop is more efficient.\n\nCheck out the [Text or image-to-video](../../../using-diffusers/text-img2vid) guide for more details about how certain parameters can affect video generation and how to optimize inference by reducing memory usage.\n\n## StableVideoDiffusionPipeline\n\n[[autodoc]] StableVideoDiffusionPipeline\n\n## StableVideoDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_video_diffusion.StableVideoDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/text2img.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Text-to-image\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nThe Stable Diffusion model was created by researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), [Runway](https://github.com/runwayml), and [LAION](https://laion.ai/). The [`StableDiffusionPipeline`] is capable of generating photorealistic images given any text input. It's trained on 512x512 images from a subset of the LAION-5B dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs. Latent diffusion is the research on top of which Stable Diffusion was built. It was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer.\n\nThe abstract from the paper is:\n\n*By decomposing the image formation process into a sequential application of denoising autoencoders, diffusion models (DMs) achieve state-of-the-art synthesis results on image data and beyond. 
Additionally, their formulation allows for a guiding mechanism to control the image generation process without retraining. However, since these models typically operate directly in pixel space, optimization of powerful DMs often consumes hundreds of GPU days and inference is expensive due to sequential evaluations. To enable DM training on limited computational resources while retaining their quality and flexibility, we apply them in the latent space of powerful pretrained autoencoders. In contrast to previous work, training diffusion models on such a representation allows for the first time to reach a near-optimal point between complexity reduction and detail preservation, greatly boosting visual fidelity. By introducing cross-attention layers into the model architecture, we turn diffusion models into powerful and flexible generators for general conditioning inputs such as text or bounding boxes and high-resolution synthesis becomes possible in a convolutional manner. Our latent diffusion models (LDMs) achieve a new state of the art for image inpainting and highly competitive performance on various tasks, including unconditional image generation, semantic scene synthesis, and super-resolution, while significantly reducing computational requirements compared to pixel-based DMs. 
Code is available at https://github.com/CompVis/latent-diffusion.*\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n>\n> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!\n\n## StableDiffusionPipeline\n\n[[autodoc]] StableDiffusionPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_vae_slicing\n\t- disable_vae_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\t- enable_vae_tiling\n\t- disable_vae_tiling\n\t- load_textual_inversion\n\t- from_single_file\n\t- load_lora_weights\n\t- save_lora_weights\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_diffusion/upscale.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Super-resolution\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nThe Stable Diffusion upscaler diffusion model was created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/), and [LAION](https://laion.ai/). It is used to enhance the resolution of input images by a factor of 4.\n\n> [!TIP]\n> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!\n>\n> If you're interested in using one of the official checkpoints for a task, explore the [CompVis](https://huggingface.co/CompVis) and [Stability AI](https://huggingface.co/stabilityai) Hub organizations!\n\n## StableDiffusionUpscalePipeline\n\n[[autodoc]] StableDiffusionUpscalePipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\n## StableDiffusionPipelineOutput\n\n[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/stable_unclip.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable unCLIP\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nStable unCLIP checkpoints are finetuned from [Stable Diffusion 2.1](./stable_diffusion/stable_diffusion_2) checkpoints to condition on CLIP image embeddings.\nStable unCLIP still conditions on text embeddings. Given the two separate conditionings, stable unCLIP can be used\nfor text guided image variation. When combined with an unCLIP prior, it can also be used for full text to image generation.\n\nThe abstract from the paper is:\n\n*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. 
Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.*\n\n## Tips\n\nStable unCLIP takes  `noise_level` as input during inference which determines how much noise is added to the image embeddings. A higher `noise_level` increases variation in the final un-noised images. By default, we do not add any additional noise to the image embeddings (`noise_level = 0`).\n\n### Text-to-Image Generation\nStable unCLIP can be leveraged for text-to-image generation by pipelining it with the prior model of KakaoBrain's open source DALL-E 2 replication [Karlo](https://huggingface.co/kakaobrain/karlo-v1-alpha):\n\n```python\nimport torch\nfrom diffusers import UnCLIPScheduler, DDPMScheduler, StableUnCLIPPipeline\nfrom diffusers.models import PriorTransformer\nfrom transformers import CLIPTokenizer, CLIPTextModelWithProjection\n\nprior_model_id = \"kakaobrain/karlo-v1-alpha\"\ndata_type = torch.float16\nprior = PriorTransformer.from_pretrained(prior_model_id, subfolder=\"prior\", torch_dtype=data_type)\n\nprior_text_model_id = \"openai/clip-vit-large-patch14\"\nprior_tokenizer = CLIPTokenizer.from_pretrained(prior_text_model_id)\nprior_text_model = CLIPTextModelWithProjection.from_pretrained(prior_text_model_id, torch_dtype=data_type)\nprior_scheduler = UnCLIPScheduler.from_pretrained(prior_model_id, subfolder=\"prior_scheduler\")\nprior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)\n\nstable_unclip_model_id = \"stabilityai/stable-diffusion-2-1-unclip-small\"\n\npipe = StableUnCLIPPipeline.from_pretrained(\n    stable_unclip_model_id,\n    torch_dtype=data_type,\n    variant=\"fp16\",\n    prior_tokenizer=prior_tokenizer,\n    prior_text_encoder=prior_text_model,\n    prior=prior,\n    
prior_scheduler=prior_scheduler,\n)\n\npipe = pipe.to(\"cuda\")\nwave_prompt = \"dramatic wave, the Oceans roar, Strong wave spiral across the oceans as the waves unfurl into roaring crests; perfect wave form; perfect wave shape; dramatic wave shape; wave shape unbelievable; wave; wave shape spectacular\"\n\nimage = pipe(prompt=wave_prompt).images[0]\nimage\n```\n> [!WARNING]\n> For text-to-image we use `stabilityai/stable-diffusion-2-1-unclip-small` as it was trained on CLIP ViT-L/14 embeddings, the same as the Karlo model prior. [stabilityai/stable-diffusion-2-1-unclip](https://hf.co/stabilityai/stable-diffusion-2-1-unclip) was trained on OpenCLIP ViT-H, so we don't recommend its use.\n\n### Text guided Image-to-Image Variation\n\n```python\nfrom diffusers import StableUnCLIPImg2ImgPipeline\nfrom diffusers.utils import load_image\nimport torch\n\npipe = StableUnCLIPImg2ImgPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-1-unclip\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipe = pipe.to(\"cuda\")\n\nurl = \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/tarsila_do_amaral.png\"\ninit_image = load_image(url)\n\nimages = pipe(init_image).images\nimages[0].save(\"variation_image.png\")\n```\n\nOptionally, you can also pass a prompt to `pipe` such as:\n\n```python\nprompt = \"A fantasy landscape, trending on artstation\"\n\nimage = pipe(init_image, prompt=prompt).images[0]\nimage\n```\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## StableUnCLIPPipeline\n\n[[autodoc]] StableUnCLIPPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- 
enable_vae_slicing\n\t- disable_vae_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\n## StableUnCLIPImg2ImgPipeline\n\n[[autodoc]] StableUnCLIPImg2ImgPipeline\n\t- all\n\t- __call__\n\t- enable_attention_slicing\n\t- disable_attention_slicing\n\t- enable_vae_slicing\n\t- disable_vae_slicing\n\t- enable_xformers_memory_efficient_attention\n\t- disable_xformers_memory_efficient_attention\n\n## ImagePipelineOutput\n[[autodoc]] pipelines.ImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/text_to_video.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Text-to-video\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[ModelScope Text-to-Video Technical Report](https://huggingface.co/papers/2308.06571) is by Jiuniu Wang, Hangjie Yuan, Dayou Chen, Yingya Zhang, Xiang Wang, Shiwei Zhang.\n\nThe abstract from the paper is:\n\n*This paper introduces ModelScopeT2V, a text-to-video synthesis model that evolves from a text-to-image synthesis model (i.e., Stable Diffusion). ModelScopeT2V incorporates spatio-temporal blocks to ensure consistent frame generation and smooth movement transitions. The model could adapt to varying frame numbers during training and inference, rendering it suitable for both image-text and video-text datasets. ModelScopeT2V brings together three components (i.e., VQGAN, a text encoder, and a denoising UNet), totally comprising 1.7 billion parameters, in which 0.5 billion parameters are dedicated to temporal capabilities. The model demonstrates superior performance over state-of-the-art methods across three evaluation metrics. 
The code and an online demo are available at https://modelscope.cn/models/damo/text-to-video-synthesis/summary.*\n\nYou can find additional information about Text-to-Video on the [project page](https://modelscope.cn/models/damo/text-to-video-synthesis/summary), [original codebase](https://github.com/modelscope/modelscope/), and try it out in a [demo](https://huggingface.co/spaces/damo-vilab/modelscope-text-to-video-synthesis). Official checkpoints can be found at [damo-vilab](https://huggingface.co/damo-vilab) and [cerspense](https://huggingface.co/cerspense).\n\n## Usage example\n\n### `text-to-video-ms-1.7b`\n\nLet's start by generating a short video with the default length of 16 frames (2s at 8 fps):\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import export_to_video\n\npipe = DiffusionPipeline.from_pretrained(\"damo-vilab/text-to-video-ms-1.7b\", torch_dtype=torch.float16, variant=\"fp16\")\npipe = pipe.to(\"cuda\")\n\nprompt = \"Spiderman is surfing\"\nvideo_frames = pipe(prompt).frames[0]\nvideo_path = export_to_video(video_frames)\nvideo_path\n```\n\nDiffusers supports different optimization techniques to improve the latency\nand memory footprint of a pipeline. 
Since videos are often more memory-heavy than images,\nwe can enable CPU offloading and VAE slicing to keep the memory footprint at bay.\n\nLet's generate a video of 8 seconds (64 frames) on the same GPU using CPU offloading and VAE slicing:\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import export_to_video\n\npipe = DiffusionPipeline.from_pretrained(\"damo-vilab/text-to-video-ms-1.7b\", torch_dtype=torch.float16, variant=\"fp16\")\npipe.enable_model_cpu_offload()\n\n# memory optimization\npipe.enable_vae_slicing()\n\nprompt = \"Darth Vader surfing a wave\"\nvideo_frames = pipe(prompt, num_frames=64).frames[0]\nvideo_path = export_to_video(video_frames)\nvideo_path\n```\n\nIt takes just **7 GB of GPU memory** to generate the 64 video frames using PyTorch 2.0, \"fp16\" precision, and the techniques mentioned above.\n\nWe can also easily switch to a different scheduler, using the same method we'd use for Stable Diffusion:\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\nfrom diffusers.utils import export_to_video\n\npipe = DiffusionPipeline.from_pretrained(\"damo-vilab/text-to-video-ms-1.7b\", torch_dtype=torch.float16, variant=\"fp16\")\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\npipe.enable_model_cpu_offload()\n\nprompt = \"Spiderman is surfing\"\nvideo_frames = pipe(prompt, num_inference_steps=25).frames[0]\nvideo_path = export_to_video(video_frames)\nvideo_path\n```\n\nHere are some sample outputs:\n\n<table>\n    <tr>\n        <td><center>\n        An astronaut riding a horse.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astr.gif\"\n            alt=\"An astronaut riding a horse.\"\n            style=\"width: 300px;\" />\n        </center></td>\n        <td ><center>\n        Darth vader surfing in waves.\n        <br>\n        <img 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vader.gif\"\n            alt=\"Darth vader surfing in waves.\"\n            style=\"width: 300px;\" />\n        </center></td>\n    </tr>\n</table>\n\n### `cerspense/zeroscope_v2_576w` & `cerspense/zeroscope_v2_XL`\n\nThe Zeroscope checkpoints are watermark-free models trained on specific sizes such as `576x320` and `1024x576`.\nFirst generate a video using the lower resolution checkpoint [`cerspense/zeroscope_v2_576w`](https://huggingface.co/cerspense/zeroscope_v2_576w) with [`TextToVideoSDPipeline`],\nthen upscale it using [`VideoToVideoSDPipeline`] and [`cerspense/zeroscope_v2_XL`](https://huggingface.co/cerspense/zeroscope_v2_XL).\n\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\nfrom diffusers.utils import export_to_video\nfrom PIL import Image\n\npipe = DiffusionPipeline.from_pretrained(\"cerspense/zeroscope_v2_576w\", torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n\n# memory optimization\npipe.unet.enable_forward_chunking(chunk_size=1, dim=1)\npipe.enable_vae_slicing()\n\nprompt = \"Darth Vader surfing a wave\"\nvideo_frames = pipe(prompt, num_frames=24).frames[0]\nvideo_path = export_to_video(video_frames)\nvideo_path\n```\n\nNow the video can be upscaled:\n\n```py\npipe = DiffusionPipeline.from_pretrained(\"cerspense/zeroscope_v2_XL\", torch_dtype=torch.float16)\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\npipe.enable_model_cpu_offload()\n\n# memory optimization\npipe.unet.enable_forward_chunking(chunk_size=1, dim=1)\npipe.enable_vae_slicing()\n\nvideo = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]\n\nvideo_frames = pipe(prompt, video=video, strength=0.6).frames[0]\nvideo_path = export_to_video(video_frames)\nvideo_path\n```\n\nHere are some sample outputs:\n\n<table>\n    <tr>\n        <td ><center>\n        Darth vader 
surfing in waves.\n        <br>\n        <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/darthvader_cerpense.gif\"\n            alt=\"Darth vader surfing in waves.\"\n            style=\"width: 576px;\" />\n        </center></td>\n    </tr>\n</table>\n\n## Tips\n\nVideo generation is memory-intensive and one way to reduce your memory usage is to set `enable_forward_chunking` on the pipeline's UNet so you don't run the entire feedforward layer at once. Breaking it up into chunks in a loop is more efficient.\n\nCheck out the [Text or image-to-video](../../using-diffusers/text-img2vid) guide for more details about how certain parameters can affect video generation and how to optimize inference by reducing memory usage.\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## TextToVideoSDPipeline\n[[autodoc]] TextToVideoSDPipeline\n\t- all\n\t- __call__\n\n## VideoToVideoSDPipeline\n[[autodoc]] VideoToVideoSDPipeline\n\t- all\n\t- __call__\n\n## TextToVideoSDPipelineOutput\n[[autodoc]] pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/text_to_video_zero.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# Text2Video-Zero\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[Text2Video-Zero: Text-to-Image Diffusion Models are Zero-Shot Video Generators](https://huggingface.co/papers/2303.13439) is by Levon Khachatryan, Andranik Movsisyan, Vahram Tadevosyan, Roberto Henschel, [Zhangyang Wang](https://www.ece.utexas.edu/people/faculty/atlas-wang), Shant Navasardyan, [Humphrey Shi](https://www.humphreyshi.com).\n\nText2Video-Zero enables zero-shot video generation using either:\n1. A textual prompt\n2. A prompt combined with guidance from poses or edges\n3. Video Instruct-Pix2Pix (instruction-guided video editing)\n\nResults are temporally consistent and closely follow the guidance and textual prompts.\n\n![teaser-img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2v_zero_teaser.png)\n\nThe abstract from the paper is:\n\n*Recent text-to-video generation approaches rely on computationally heavy training and require large-scale video datasets. 
In this paper, we introduce a new task of zero-shot text-to-video generation and propose a low-cost approach (without any training or optimization) by leveraging the power of existing text-to-image synthesis methods (e.g., Stable Diffusion), making them suitable for the video domain.\nOur key modifications include (i) enriching the latent codes of the generated frames with motion dynamics to keep the global scene and the background time consistent; and (ii) reprogramming frame-level self-attention using a new cross-frame attention of each frame on the first frame, to preserve the context, appearance, and identity of the foreground object.\nExperiments show that this leads to low overhead, yet high-quality and remarkably consistent video generation. Moreover, our approach is not limited to text-to-video synthesis but is also applicable to other tasks such as conditional and content-specialized video generation, and Video Instruct-Pix2Pix, i.e., instruction-guided video editing.\nAs experiments show, our method performs comparably or sometimes better than recent approaches, despite not being trained on additional video data.*\n\nYou can find additional information about Text2Video-Zero on the [project page](https://text2video-zero.github.io/), [paper](https://huggingface.co/papers/2303.13439), and [original codebase](https://github.com/Picsart-AI-Research/Text2Video-Zero).\n\n## Usage example\n\n### Text-To-Video\n\nTo generate a video from prompt, run the following Python code:\n```python\nimport torch\nfrom diffusers import TextToVideoZeroPipeline\nimport imageio\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A panda is playing guitar on times square\"\nresult = pipe(prompt=prompt).images\nresult = [(r * 255).astype(\"uint8\") for r in result]\nimageio.mimsave(\"video.mp4\", result, fps=4)\n```\nYou can change these parameters in the 
pipeline call:\n* Motion field strength (see the [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1):\n    * `motion_field_strength_x` and `motion_field_strength_y`. Default: `motion_field_strength_x=12`, `motion_field_strength_y=12`\n* `T` and `T'` (see the [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1)\n    * `t0` and `t1` in the range `{0, ..., num_inference_steps}`. Default: `t0=45`, `t1=48`\n* Video length:\n    * `video_length`, the number of frames to generate. Default: `video_length=8`\n\nWe can also generate longer videos by doing the processing in a chunk-by-chunk manner:\n```python\nimport torch\nfrom diffusers import TextToVideoZeroPipeline\nimport numpy as np\nimport imageio\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\nseed = 0\nvideo_length = 24  # 24 frames ÷ 4 fps = 6 seconds\nchunk_size = 8\nprompt = \"A panda is playing guitar on times square\"\n\n# Generate the video chunk-by-chunk\nresult = []\nchunk_ids = np.arange(0, video_length, chunk_size - 1)\ngenerator = torch.Generator(device=\"cuda\")\nfor i in range(len(chunk_ids)):\n    print(f\"Processing chunk {i + 1} / {len(chunk_ids)}\")\n    ch_start = chunk_ids[i]\n    ch_end = video_length if i == len(chunk_ids) - 1 else chunk_ids[i + 1]\n    # Attach the first frame for Cross Frame Attention\n    frame_ids = [0] + list(range(ch_start, ch_end))\n    # Fix the seed for the temporal consistency\n    generator.manual_seed(seed)\n    output = pipe(prompt=prompt, video_length=len(frame_ids), generator=generator, frame_ids=frame_ids)\n    result.append(output.images[1:])\n\n# Concatenate chunks and save\nresult = np.concatenate(result)\nresult = [(r * 255).astype(\"uint8\") for r in result]\nimageio.mimsave(\"video.mp4\", result, fps=4)\n```\n\n\n- #### SDXL Support\nIn order to use the SDXL model when generating a video from a prompt, use the 
`TextToVideoZeroSDXLPipeline` pipeline:\n\n```python\nimport torch\nfrom diffusers import TextToVideoZeroSDXLPipeline\n\nmodel_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\npipe = TextToVideoZeroSDXLPipeline.from_pretrained(\n    model_id, torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n```\n\n### Text-To-Video with Pose Control\nTo generate a video from a prompt with additional pose control:\n\n1. Download a demo video\n\n    ```python\n    from huggingface_hub import hf_hub_download\n\n    filename = \"__assets__/poses_skeleton_gifs/dance1_corr.mp4\"\n    repo_id = \"PAIR/Text2Video-Zero\"\n    video_path = hf_hub_download(repo_type=\"space\", repo_id=repo_id, filename=filename)\n    ```\n\n\n2. Read video containing extracted pose images\n    ```python\n    from PIL import Image\n    import imageio\n\n    reader = imageio.get_reader(video_path, \"ffmpeg\")\n    frame_count = 8\n    pose_images = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]\n    ```\n    To extract poses from an actual video, read the [ControlNet documentation](controlnet).\n\n3. 
Run `StableDiffusionControlNetPipeline` with our custom attention processor\n\n    ```python\n    import torch\n    from diffusers import StableDiffusionControlNetPipeline, ControlNetModel\n    from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor\n\n    model_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n    controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-openpose\", torch_dtype=torch.float16)\n    pipe = StableDiffusionControlNetPipeline.from_pretrained(\n        model_id, controlnet=controlnet, torch_dtype=torch.float16\n    ).to(\"cuda\")\n\n    # Set the attention processor\n    pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))\n    pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))\n\n    # fix latents for all frames\n    latents = torch.randn((1, 4, 64, 64), device=\"cuda\", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)\n\n    prompt = \"Darth Vader dancing in a desert\"\n    result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images\n    imageio.mimsave(\"video.mp4\", result, fps=4)\n    ```\n- #### SDXL Support\n\n\tSince our attention processor also works with SDXL, it can be utilized to generate a video from a prompt using ControlNet models powered by SDXL:\n\t```python\n\timport torch\n\tfrom diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel\n\tfrom diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor\n\n\tcontrolnet_model_id = 'thibaud/controlnet-openpose-sdxl-1.0'\n\tmodel_id = 'stabilityai/stable-diffusion-xl-base-1.0'\n\n\tcontrolnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16)\n\tpipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n\t\tmodel_id, controlnet=controlnet, torch_dtype=torch.float16\n\t).to('cuda')\n\n\t# Set the attention 
processor\n\tpipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))\n\tpipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))\n\n\t# fix latents for all frames\n\tlatents = torch.randn((1, 4, 128, 128), device=\"cuda\", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)\n\n\tprompt = \"Darth Vader dancing in a desert\"\n\tresult = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images\n\timageio.mimsave(\"video.mp4\", result, fps=4)\n\t```\n\n### Text-To-Video with Edge Control\n\nTo generate a video from prompt with additional Canny edge control, follow the same steps described above for pose-guided generation using [Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny).\n\n\n### Video Instruct-Pix2Pix\n\nTo perform text-guided video editing (with [InstructPix2Pix](pix2pix)):\n\n1. Download a demo video\n\n    ```python\n    from huggingface_hub import hf_hub_download\n\n    filename = \"__assets__/pix2pix video/camel.mp4\"\n    repo_id = \"PAIR/Text2Video-Zero\"\n    video_path = hf_hub_download(repo_type=\"space\", repo_id=repo_id, filename=filename)\n    ```\n\n2. Read video from path\n    ```python\n    from PIL import Image\n    import imageio\n\n    reader = imageio.get_reader(video_path, \"ffmpeg\")\n    frame_count = 8\n    video = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]\n    ```\n\n3. 
Run `StableDiffusionInstructPix2PixPipeline` with our custom attention processor\n    ```python\n    import torch\n    from diffusers import StableDiffusionInstructPix2PixPipeline\n    from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor\n\n    model_id = \"timbrooks/instruct-pix2pix\"\n    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n    pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=3))\n\n    prompt = \"make it Van Gogh Starry Night style\"\n    result = pipe(prompt=[prompt] * len(video), image=video).images\n    imageio.mimsave(\"edited_video.mp4\", result, fps=4)\n    ```\n\n\n### DreamBooth specialization\n\nMethods **Text-To-Video**, **Text-To-Video with Pose Control** and **Text-To-Video with Edge Control**\ncan run with custom [DreamBooth](../../training/dreambooth) models, as shown below for\n[Canny edge ControlNet model](https://huggingface.co/lllyasviel/sd-controlnet-canny) and\n[Avatar style DreamBooth](https://huggingface.co/PAIR/text2video-zero-controlnet-canny-avatar) model:\n\n1. Download a demo video\n\n    ```python\n    from huggingface_hub import hf_hub_download\n\n    filename = \"__assets__/canny_videos_mp4/girl_turning.mp4\"\n    repo_id = \"PAIR/Text2Video-Zero\"\n    video_path = hf_hub_download(repo_type=\"space\", repo_id=repo_id, filename=filename)\n    ```\n\n2. Read video from path\n    ```python\n    from PIL import Image\n    import imageio\n\n    reader = imageio.get_reader(video_path, \"ffmpeg\")\n    frame_count = 8\n    canny_edges = [Image.fromarray(reader.get_data(i)) for i in range(frame_count)]\n    ```\n\n3. 
Run `StableDiffusionControlNetPipeline` with custom trained DreamBooth model\n    ```python\n    import torch\n    from diffusers import StableDiffusionControlNetPipeline, ControlNetModel\n    from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor\n\n    # set model id to custom model\n    model_id = \"PAIR/text2video-zero-controlnet-canny-avatar\"\n    controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\n    pipe = StableDiffusionControlNetPipeline.from_pretrained(\n        model_id, controlnet=controlnet, torch_dtype=torch.float16\n    ).to(\"cuda\")\n\n    # Set the attention processor\n    pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))\n    pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))\n\n    # fix latents for all frames\n    latents = torch.randn((1, 4, 64, 64), device=\"cuda\", dtype=torch.float16).repeat(len(canny_edges), 1, 1, 1)\n\n    prompt = \"oil painting of a beautiful girl avatar style\"\n    result = pipe(prompt=[prompt] * len(canny_edges), image=canny_edges, latents=latents).images\n    imageio.mimsave(\"video.mp4\", result, fps=4)\n    ```\n\nYou can filter out some available DreamBooth-trained models with [this link](https://huggingface.co/models?search=dreambooth).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## TextToVideoZeroPipeline\n[[autodoc]] TextToVideoZeroPipeline\n\t- all\n\t- __call__\n\n## TextToVideoZeroSDXLPipeline\n[[autodoc]] TextToVideoZeroSDXLPipeline\n\t- all\n\t- __call__\n\n## TextToVideoPipelineOutput\n[[autodoc]] 
pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/unclip.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\nhttp://www.apache.org/licenses/LICENSE-2.0\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# unCLIP\n\n[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).\n\nThe abstract from the paper is following:\n\n*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. 
Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.*\n\nYou can find lucidrains' DALL-E 2 recreation at [lucidrains/DALLE2-pytorch](https://github.com/lucidrains/DALLE2-pytorch).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## UnCLIPPipeline\n[[autodoc]] UnCLIPPipeline\n\t- all\n\t- __call__\n\n## UnCLIPImageVariationPipeline\n[[autodoc]] UnCLIPImageVariationPipeline\n\t- all\n\t- __call__\n\n## ImagePipelineOutput\n[[autodoc]] pipelines.ImagePipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/unidiffuser.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n# UniDiffuser\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\nThe UniDiffuser model was proposed in [One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale](https://huggingface.co/papers/2303.06555) by Fan Bao, Shen Nie, Kaiwen Xue, Chongxuan Li, Shi Pu, Yaole Wang, Gang Yue, Yue Cao, Hang Su, Jun Zhu.\n\nThe abstract from the paper is:\n\n*This paper proposes a unified diffusion framework (dubbed UniDiffuser) to fit all distributions relevant to a set of multi-modal data in one model. Our key insight is -- learning diffusion models for marginal, conditional, and joint distributions can be unified as predicting the noise in the perturbed data, where the perturbation levels (i.e. timesteps) can be different for different modalities. 
Inspired by the unified view, UniDiffuser learns all distributions simultaneously with a minimal modification to the original diffusion model -- perturbs data in all modalities instead of a single modality, inputs individual timesteps in different modalities, and predicts the noise of all modalities instead of a single modality. UniDiffuser is parameterized by a transformer for diffusion models to handle input types of different modalities. Implemented on large-scale paired image-text data, UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead. In particular, UniDiffuser is able to produce perceptually realistic samples in all tasks and its quantitative results (e.g., the FID and CLIP score) are not only superior to existing general-purpose models but also comparable to the bespoken models (e.g., Stable Diffusion and DALL-E 2) in representative tasks (e.g., text-to-image generation).*\n\nYou can find the original codebase at [thu-ml/unidiffuser](https://github.com/thu-ml/unidiffuser) and additional checkpoints at [thu-ml](https://huggingface.co/thu-ml).\n\n> [!WARNING]\n> There is currently an issue on PyTorch 1.X where the output images are all black or the pixel values become `NaNs`. This issue can be mitigated by switching to PyTorch 2.X.\n\nThis pipeline was contributed by [dg845](https://github.com/dg845). 
❤️\n\n## Usage Examples\n\nBecause the UniDiffuser model is trained to model the joint distribution of (image, text) pairs, it is capable of performing a diverse range of generation tasks:\n\n### Unconditional Image and Text Generation\n\nUnconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a [`UniDiffuserPipeline`] will produce an (image, text) pair:\n\n```python\nimport torch\n\nfrom diffusers import UniDiffuserPipeline\n\ndevice = \"cuda\"\nmodel_id_or_path = \"thu-ml/unidiffuser-v1\"\npipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\npipe.to(device)\n\n# Unconditional image and text generation. The generation task is automatically inferred.\nsample = pipe(num_inference_steps=20, guidance_scale=8.0)\nimage = sample.images[0]\ntext = sample.text[0]\nimage.save(\"unidiffuser_joint_sample_image.png\")\nprint(text)\n```\n\nThis is also called \"joint\" generation in the UniDiffuser paper, since we are sampling from the joint image-text distribution.\n\nNote that the generation task is inferred from the inputs used when calling the pipeline.\nIt is also possible to specify the unconditional generation task (\"mode\") manually with [`UniDiffuserPipeline.set_joint_mode`]:\n\n```python\n# Equivalent to the above.\npipe.set_joint_mode()\nsample = pipe(num_inference_steps=20, guidance_scale=8.0)\n```\n\nWhen the mode is set manually, subsequent calls to the pipeline will use the set mode without attempting to infer the mode.\nYou can reset the mode with [`UniDiffuserPipeline.reset_mode`], after which the pipeline will once again infer the mode.\n\nYou can also generate only an image or only text (which the UniDiffuser paper calls \"marginal\" generation since we sample from the marginal distribution of images and text, respectively):\n\n```python\n# Unlike other generation tasks, image-only and text-only generation don't use classifier-free guidance\n# Image-only 
generation\npipe.set_image_mode()\nsample_image = pipe(num_inference_steps=20).images[0]\n# Text-only generation\npipe.set_text_mode()\nsample_text = pipe(num_inference_steps=20).text[0]\n```\n\n### Text-to-Image Generation\n\nUniDiffuser is also capable of sampling from conditional distributions; that is, the distribution of images conditioned on a text prompt or the distribution of texts conditioned on an image.\nHere is an example of sampling from the conditional image distribution (text-to-image generation or text-conditioned image generation):\n\n```python\nimport torch\n\nfrom diffusers import UniDiffuserPipeline\n\ndevice = \"cuda\"\nmodel_id_or_path = \"thu-ml/unidiffuser-v1\"\npipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\npipe.to(device)\n\n# Text-to-image generation\nprompt = \"an elephant under the sea\"\n\nsample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0)\nt2i_image = sample.images[0]\nt2i_image\n```\n\nThe `text2img` mode requires that either an input `prompt` or `prompt_embeds` be supplied. 
You can set the `text2img` mode manually with [`UniDiffuserPipeline.set_text_to_image_mode`].\n\n### Image-to-Text Generation\n\nSimilarly, UniDiffuser can also produce text samples given an image (image-to-text or image-conditioned text generation):\n\n```python\nimport torch\n\nfrom diffusers import UniDiffuserPipeline\nfrom diffusers.utils import load_image\n\ndevice = \"cuda\"\nmodel_id_or_path = \"thu-ml/unidiffuser-v1\"\npipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\npipe.to(device)\n\n# Image-to-text generation\nimage_url = \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg\"\ninit_image = load_image(image_url).resize((512, 512))\n\nsample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)\ni2t_text = sample.text[0]\nprint(i2t_text)\n```\n\nThe `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuserPipeline.set_image_to_text_mode`].\n\n### Image Variation\n\nThe UniDiffuser authors suggest performing image variation through a \"round-trip\" generation method, where given an input image, we first perform an image-to-text generation, and then perform a text-to-image generation on the outputs of the first generation.\nThis produces a new image which is semantically similar to the input image:\n\n```python\nimport torch\n\nfrom diffusers import UniDiffuserPipeline\nfrom diffusers.utils import load_image\n\ndevice = \"cuda\"\nmodel_id_or_path = \"thu-ml/unidiffuser-v1\"\npipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\npipe.to(device)\n\n# Image variation can be performed with an image-to-text generation followed by a text-to-image generation:\n# 1. 
Image-to-text generation\nimage_url = \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg\"\ninit_image = load_image(image_url).resize((512, 512))\n\nsample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0)\ni2t_text = sample.text[0]\nprint(i2t_text)\n\n# 2. Text-to-image generation\nsample = pipe(prompt=i2t_text, num_inference_steps=20, guidance_scale=8.0)\nfinal_image = sample.images[0]\nfinal_image.save(\"unidiffuser_image_variation_sample.png\")\n```\n\n### Text Variation\n\nSimilarly, text variation can be performed on an input prompt with a text-to-image generation followed by an image-to-text generation:\n\n```python\nimport torch\n\nfrom diffusers import UniDiffuserPipeline\n\ndevice = \"cuda\"\nmodel_id_or_path = \"thu-ml/unidiffuser-v1\"\npipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\npipe.to(device)\n\n# Text variation can be performed with a text-to-image generation followed by an image-to-text generation:\n# 1. Text-to-image generation\nprompt = \"an elephant under the sea\"\n\nsample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0)\nt2i_image = sample.images[0]\nt2i_image.save(\"unidiffuser_text2img_sample_image.png\")\n\n# 2. Image-to-text generation\nsample = pipe(image=t2i_image, num_inference_steps=20, guidance_scale=8.0)\nfinal_prompt = sample.text[0]\nprint(final_prompt)\n```\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## UniDiffuserPipeline\n[[autodoc]] UniDiffuserPipeline\n\t- all\n\t- __call__\n\n## ImageTextPipelineOutput\n[[autodoc]] pipelines.ImageTextPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/value_guided_sampling.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Value-guided planning\n\n> [!WARNING]\n> 🧪 This is an experimental pipeline for reinforcement learning!\n\nThis pipeline is based on the [Planning with Diffusion for Flexible Behavior Synthesis](https://huggingface.co/papers/2205.09991) paper by Michael Janner, Yilun Du, Joshua B. Tenenbaum, Sergey Levine.\n\nThe abstract from the paper is:\n\n*Model-based reinforcement learning methods often use learning only for the purpose of estimating an approximate dynamics model, offloading the rest of the decision-making work to classical trajectory optimizers. While conceptually simple, this combination has a number of empirical shortcomings, suggesting that learned models may not be well-suited to standard trajectory optimization. In this paper, we consider what it would look like to fold as much of the trajectory optimization pipeline as possible into the modeling problem, such that sampling from the model and planning with it become nearly identical. The core of our technical approach lies in a diffusion probabilistic model that plans by iteratively denoising trajectories. 
We show how classifier-guided sampling and image inpainting can be reinterpreted as coherent planning strategies, explore the unusual and useful properties of diffusion-based planning methods, and demonstrate the effectiveness of our framework in control settings that emphasize long-horizon decision-making and test-time flexibility.*\n\nYou can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/drive/1rXm8CX4ZdN5qivjJ2lhwhkOmt_m0CvU0#scrollTo=6HXJvhyqcITc&uniqifier=1).\n\nThe script to run the model is available [here](https://github.com/huggingface/diffusers/tree/main/examples/reinforcement_learning).\n\n> [!TIP]\n> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.\n\n## ValueGuidedRLPipeline\n[[autodoc]] diffusers.experimental.ValueGuidedRLPipeline\n"
  },
  {
    "path": "docs/source/en/api/pipelines/visualcloze.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n-->\n\n# VisualCloze\n\n[VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning](https://huggingface.co/papers/2504.07960) is an innovative in-context learning based universal image generation framework that offers key capabilities:\n1. Support for various in-domain tasks\n2. Generalization to unseen tasks through in-context learning\n3. Unify multiple tasks into one step and generate both target image and intermediate results\n4. Support reverse-engineering conditions from target images\n\n## Overview\n\nThe abstract from the paper is:\n\n*Recent progress in diffusion models significantly advances various image generation tasks. However, the current mainstream approach remains focused on building task-specific models, which have limited efficiency when supporting a wide range of different needs. While universal models attempt to address this limitation, they face critical challenges, including generalizable task instruction, appropriate task distributions, and unified architectural design. To tackle these challenges, we propose VisualCloze, a universal image generation framework, which supports a wide range of in-domain tasks, generalization to unseen ones, unseen unification of multiple tasks, and reverse generation. 
Unlike existing methods that rely on language-based task instruction, leading to task ambiguity and weak generalization, we integrate visual in-context learning, allowing models to identify tasks from visual demonstrations. Meanwhile, the inherent sparsity of visual task distributions hampers the learning of transferable knowledge across tasks. To this end, we introduce Graph200K, a graph-structured dataset that establishes various interrelated tasks, enhancing task density and transferable knowledge. Furthermore, we uncover that our unified image generation formulation shared a consistent objective with image infilling, enabling us to leverage the strong generative priors of pre-trained infilling models without modifying the architectures. The codes, dataset, and models are available at https://visualcloze.github.io.*\n\n## Inference\n\n### Model loading\n\nVisualCloze is a two-stage cascade pipeline, containing `VisualClozeGenerationPipeline` and `VisualClozeUpsamplingPipeline`.\n- In `VisualClozeGenerationPipeline`, each image is downsampled before concatenating images into a grid layout, avoiding excessively high resolutions. VisualCloze releases two models suitable for diffusers, i.e., [VisualClozePipeline-384](https://huggingface.co/VisualCloze/VisualClozePipeline-384) and [VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-512), which downsample images to resolutions of 384 and 512, respectively. 
\n- `VisualClozeUpsamplingPipeline` uses [SDEdit](https://huggingface.co/papers/2108.01073) to enable high-resolution image synthesis.\n\nThe `VisualClozePipeline` integrates both stages to support convenient end-to-end sampling, while also allowing users to utilize each pipeline independently as needed.\n\n### Input Specifications\n\n#### Task and Content Prompts\n- Task prompt: Required to describe the generation task intention\n- Content prompt: Optional description or caption of the target image\n- When content prompt is not needed, pass `None`\n- For batch inference, pass `List[str|None]`\n\n#### Image Input Format\n- Format: `List[List[Image|None]]`\n- Structure:\n  - All rows except the last represent in-context examples\n  - Last row represents the current query (target image set to `None`)\n- For batch inference, pass `List[List[List[Image|None]]]`\n\n#### Resolution Control\n- Default behavior:\n  - Initial generation in the first stage: area of ${pipe.resolution}^2$\n  - Upsampling in the second stage: 3x factor\n- Custom resolution: Adjust using `upsampling_height` and `upsampling_width` parameters\n\n### Examples\n\nFor comprehensive examples covering a wide range of tasks, please refer to the [Online Demo](https://huggingface.co/spaces/VisualCloze/VisualCloze) and [GitHub Repository](https://github.com/lzyhha/VisualCloze). 
Below are simple examples for three cases: mask-to-image conversion, edge detection, and subject-driven generation.\n\n#### Example for mask2image\n\n```python\nimport torch\nfrom diffusers import VisualClozePipeline\nfrom diffusers.utils import load_image\n\npipe = VisualClozePipeline.from_pretrained(\"VisualCloze/VisualClozePipeline-384\", resolution=384, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# Load in-context images (make sure the paths are correct and accessible)\nimage_paths = [\n    # in-context examples\n    [\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg'),\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg'),\n    ],\n    # query with the target image\n    [\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg'),\n        None, # No image needed for the target image\n    ],\n]\n\n# Task and content prompt\ntask_prompt = \"In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding.\"\ncontent_prompt = \"\"\"Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. \nThe eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. \nIts plumage is a mix of dark brown and golden hues, with intricate feather details. \nThe background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. \nThe foreground includes rugged rocks and patches of green moss. 
Photorealistic, medium depth of field, \nsoft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, \ntranquil, majestic, wildlife photography.\"\"\"\n\n# Run the pipeline\nimage_result = pipe(\n    task_prompt=task_prompt,\n    content_prompt=content_prompt,\n    image=image_paths,\n    upsampling_width=1344,\n    upsampling_height=768,\n    upsampling_strength=0.4,\n    guidance_scale=30,\n    num_inference_steps=30,\n    max_sequence_length=512,\n    generator=torch.Generator(\"cpu\").manual_seed(0)\n).images[0][0]\n\n# Save the resulting image\nimage_result.save(\"visualcloze.png\")\n```\n\n#### Example for edge-detection\n\n```python\nimport torch\nfrom diffusers import VisualClozePipeline\nfrom diffusers.utils import load_image\n\npipe = VisualClozePipeline.from_pretrained(\"VisualCloze/VisualClozePipeline-384\", resolution=384, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# Load in-context images (make sure the paths are correct and accessible)\nimage_paths = [\n    # in-context examples\n    [\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_image.jpg'),\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_edge.jpg'),\n    ],\n    [\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_image.jpg'),\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_edge.jpg'),\n    ],\n    # query with the target image\n    [\n        
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_query_image.jpg'),\n        None, # No image needed for the target image\n    ],\n]\n\n# Task and content prompt\ntask_prompt = \"Each row illustrates a pathway from [IMAGE1] a sharp and beautifully composed photograph to [IMAGE2] edge map with natural well-connected outlines using a clear logical task.\"\ncontent_prompt = \"\"\n\n# Run the pipeline\nimage_result = pipe(\n    task_prompt=task_prompt,\n    content_prompt=content_prompt,\n    image=image_paths,\n    upsampling_width=864,\n    upsampling_height=1152,\n    upsampling_strength=0.4,\n    guidance_scale=30,\n    num_inference_steps=30,\n    max_sequence_length=512,\n    generator=torch.Generator(\"cpu\").manual_seed(0)\n).images[0][0]\n\n# Save the resulting image\nimage_result.save(\"visualcloze.png\")\n```\n\n#### Example for subject-driven generation\n\n```python\nimport torch\nfrom diffusers import VisualClozePipeline\nfrom diffusers.utils import load_image\n\npipe = VisualClozePipeline.from_pretrained(\"VisualCloze/VisualClozePipeline-384\", resolution=384, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# Load in-context images (make sure the paths are correct and accessible)\nimage_paths = [\n    # in-context examples\n    [\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_reference.jpg'),\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_depth.jpg'),\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_image.jpg'),\n    ],\n    [\n        
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_reference.jpg'),\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_depth.jpg'),\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_image.jpg'),\n    ],\n    # query with the target image\n    [\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_reference.jpg'),\n        load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_depth.jpg'),\n        None, # No image needed for the target image\n    ],\n]\n\n# Task and content prompt\ntask_prompt = \"\"\"Each row describes a process that begins with [IMAGE1] an image containing the key object, \n[IMAGE2] depth map revealing gray-toned spatial layers and results in \n[IMAGE3] an image with artistic qualitya high-quality image with exceptional detail.\"\"\"\ncontent_prompt = \"\"\"A vintage porcelain collector's item. 
Beneath a blossoming cherry tree in early spring, \nthis treasure is photographed up close, with soft pink petals drifting through the air and vibrant blossoms framing the scene.\"\"\"\n\n# Run the pipeline\nimage_result = pipe(\n    task_prompt=task_prompt,\n    content_prompt=content_prompt,\n    image=image_paths,\n    upsampling_width=1024,\n    upsampling_height=1024,\n    upsampling_strength=0.2,\n    guidance_scale=30,\n    num_inference_steps=30,\n    max_sequence_length=512,\n    generator=torch.Generator(\"cpu\").manual_seed(0)\n).images[0][0]\n\n# Save the resulting image\nimage_result.save(\"visualcloze.png\")\n```\n\n#### Utilize each pipeline independently \n\n```python\nimport torch\nfrom diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline\nfrom diffusers.utils import load_image\nfrom PIL import Image\n\npipe = VisualClozeGenerationPipeline.from_pretrained(\n    \"VisualCloze/VisualClozePipeline-384\", resolution=384, torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nimage_paths = [\n    # in-context examples\n    [\n        load_image(\n            \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg\"\n        ),\n        load_image(\n            \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg\"\n        ),\n    ],\n    # query with the target image\n    [\n        load_image(\n            \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg\"\n        ),\n        None,  # No image needed for the target image\n    ],\n]\ntask_prompt = \"In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color 
coding.\"\ncontent_prompt = \"Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography.\"\n\n# Stage 1: Generate initial image\nimage = pipe(\n    task_prompt=task_prompt,\n    content_prompt=content_prompt,\n    image=image_paths,\n    guidance_scale=30,\n    num_inference_steps=30,\n    max_sequence_length=512,\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n).images[0][0]\n\n# Stage 2 (optional): Upsample the generated image\npipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)\npipe_upsample.to(\"cuda\")\n\nmask_image = Image.new(\"RGB\", image.size, (255, 255, 255))\n\nimage = pipe_upsample(\n    image=image,\n    mask_image=mask_image,\n    prompt=content_prompt,\n    width=1344,\n    height=768,\n    strength=0.4,\n    guidance_scale=30,\n    num_inference_steps=30,\n    max_sequence_length=512,\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n).images[0]\n\nimage.save(\"visualcloze.png\")\n```\n\n## VisualClozePipeline\n\n[[autodoc]] VisualClozePipeline\n  - all\n  - __call__\n\n## VisualClozeGenerationPipeline\n\n[[autodoc]] VisualClozeGenerationPipeline\n  - all\n  - __call__\n"
  },
  {
    "path": "docs/source/en/api/pipelines/wan.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License. -->\n\n<div style=\"float: right;\">\n  <div class=\"flex flex-wrap space-x-1\">\n    <a href=\"https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference\" target=\"_blank\" rel=\"noopener\">\n      <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n    </a>\n  </div>\n</div>\n\n# Wan\n\n[Wan-2.1](https://huggingface.co/papers/2503.20314) by the Wan Team.\n\n*This report presents Wan, a comprehensive and open suite of video foundation models designed to push the boundaries of video generation. Built upon the mainstream diffusion transformer paradigm, Wan achieves significant advancements in generative capabilities through a series of innovations, including our novel VAE, scalable pre-training strategies, large-scale data curation, and automated evaluation metrics. These contributions collectively enhance the model's performance and versatility. Specifically, Wan is characterized by four key features: Leading Performance: The 14B model of Wan, trained on a vast dataset comprising billions of images and videos, demonstrates the scaling laws of video generation with respect to both data and model size. 
It consistently outperforms the existing open-source models as well as state-of-the-art commercial solutions across multiple internal and external benchmarks, demonstrating a clear and significant performance superiority. Comprehensiveness: Wan offers two capable models, i.e., 1.3B and 14B parameters, for efficiency and effectiveness respectively. It also covers multiple downstream applications, including image-to-video, instruction-guided video editing, and personal video generation, encompassing up to eight tasks. Consumer-Grade Efficiency: The 1.3B model demonstrates exceptional resource efficiency, requiring only 8.19 GB VRAM, making it compatible with a wide range of consumer-grade GPUs. Openness: We open-source the entire series of Wan, including source code and all models, with the goal of fostering the growth of the video generation community. This openness seeks to significantly expand the creative possibilities of video production in the industry and provide academia with high-quality video foundation models. 
All the code and models are available at [this https URL](https://github.com/Wan-Video/Wan2.1).*\n\nYou can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization.\n\nThe following Wan models are supported in Diffusers:\n\n- [Wan 2.1 T2V 1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)\n- [Wan 2.1 T2V 14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)\n- [Wan 2.1 I2V 14B - 480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)\n- [Wan 2.1 I2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)\n- [Wan 2.1 FLF2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)\n- [Wan 2.1 VACE 1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers)\n- [Wan 2.1 VACE 14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers)\n- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)\n- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)\n- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)\n- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)\n\n> [!TIP]\n> Click on the Wan models in the right sidebar for more examples of video generation.\n\n### Text-to-Video Generation\n\nThe example below demonstrates how to generate a video from text optimized for memory or inference speed.\n\n<hfoptions id=\"T2V usage\">\n<hfoption id=\"T2V memory\">\n\nRefer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.\n\nThe Wan2.1 text-to-video model below requires ~13GB of VRAM.\n\n```py\n# pip install ftfy\nimport torch\nimport numpy as np\nfrom diffusers import AutoModel, WanPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom diffusers.hooks.group_offloading import apply_group_offloading\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import 
UMT5EncoderModel\n\ntext_encoder = UMT5EncoderModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"text_encoder\", torch_dtype=torch.bfloat16)\nvae = AutoModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"vae\", torch_dtype=torch.float32)\ntransformer = AutoModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n\n# group-offloading\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\napply_group_offloading(text_encoder,\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"block_level\",\n    num_blocks_per_group=4\n)\ntransformer.enable_group_offload(\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True\n)\n\npipeline = WanPipeline.from_pretrained(\n    \"Wan-AI/Wan2.1-T2V-14B-Diffusers\",\n    vae=vae,\n    transformer=transformer,\n    text_encoder=text_encoder,\n    torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\n\nprompt = \"\"\"\nThe camera rushes from far to near in a low-angle shot,\nrevealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in\nfor a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.\nBirch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic\nshadows and warm highlights. 
Medium composition, front view, low angle, with depth of field.\n\"\"\"\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,\nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,\nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=81,\n    guidance_scale=5.0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\n</hfoption>\n<hfoption id=\"T2V inference speed\">\n\n[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.\n\n```py\n# pip install ftfy\nimport torch\nimport numpy as np\nfrom diffusers import AutoModel, WanPipeline\nfrom diffusers.hooks.group_offloading import apply_group_offloading\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import UMT5EncoderModel\n\ntext_encoder = UMT5EncoderModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"text_encoder\", torch_dtype=torch.bfloat16)\nvae = AutoModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"vae\", torch_dtype=torch.float32)\ntransformer = AutoModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n\npipeline = WanPipeline.from_pretrained(\n    \"Wan-AI/Wan2.1-T2V-14B-Diffusers\",\n    vae=vae,\n    transformer=transformer,\n    text_encoder=text_encoder,\n    torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\n\n# torch.compile\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.transformer = torch.compile(\n    
pipeline.transformer, mode=\"max-autotune\", fullgraph=True\n)\n\nprompt = \"\"\"\nThe camera rushes from far to near in a low-angle shot,\nrevealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in\nfor a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.\nBirch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic\nshadows and warm highlights. Medium composition, front view, low angle, with depth of field.\n\"\"\"\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,\nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,\nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=81,\n    guidance_scale=5.0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\n</hfoption>\n</hfoptions>\n\n### First-Last-Frame-to-Video Generation\n\nThe example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.\n\n<hfoptions id=\"FLF2V usage\">\n<hfoption id=\"usage\">\n\n```python\nimport numpy as np\nimport torch\nimport torchvision.transforms.functional as TF\nfrom diffusers import AutoencoderKLWan, WanImageToVideoPipeline\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import CLIPVisionModel\n\n\nmodel_id = \"Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers\"\nimage_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder=\"image_encoder\", torch_dtype=torch.float32)\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", 
torch_dtype=torch.float32)\npipe = WanImageToVideoPipeline.from_pretrained(\n    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nfirst_frame = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png\")\nlast_frame = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png\")\n\ndef aspect_ratio_resize(image, pipe, max_area=720 * 1280):\n    aspect_ratio = image.height / image.width\n    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]\n    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\n    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\n    image = image.resize((width, height))\n    return image, height, width\n\ndef center_crop_resize(image, height, width):\n    # Scale the image so that it fully covers the target (first frame) dimensions\n    resize_ratio = max(width / image.width, height / image.height)\n    image = image.resize((round(image.width * resize_ratio), round(image.height * resize_ratio)))\n\n    # Center-crop to the target size; torchvision expects [height, width]\n    image = TF.center_crop(image, [height, width])\n\n    return image, height, width\n\nfirst_frame, height, width = aspect_ratio_resize(first_frame, pipe)\nif last_frame.size != first_frame.size:\n    last_frame, _, _ = center_crop_resize(last_frame, height, width)\n\nprompt = \"CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. 
The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective.\"\n\noutput = pipe(\n    image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\n</hfoption>\n</hfoptions>\n\n### Any-to-Video Controllable Generation\n\nWan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include:\n- Control to Video (Depth, Pose, Sketch, Flow, Grayscale, Scribble, Layout, Boundary Box, etc.). Recommended library for preprocessing videos to obtain control videos: [huggingface/controlnet_aux]()\n- Image/Video to Video (first frame, last frame, starting clip, ending clip, random clips)\n- Inpainting and Outpainting\n- Subject to Video (faces, object, characters, etc.)\n- Composition to Video (reference anything, animate anything, swap anything, expand anything, move anything, etc.)\n\nThe code snippets available in [this](https://github.com/huggingface/diffusers/pull/11582) pull request demonstrate some examples of how videos can be generated with controllability signals.\n\nThe general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.\n\n### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication\n\n[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.\n\n*We introduce Wan-Animate, a unified framework for character animation and replacement. 
Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*\n\nThe project page: https://humanaigc.github.io/wan-animate\n\nThis model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).\n\n#### Usage\n\nThe Wan-Animate pipeline supports two modes of operation:\n\n1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos\n2. 
**Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene\n\n##### Prerequisites\n\nBefore using the pipeline, you need to preprocess your reference video to extract:\n- **Pose video**: Contains skeletal keypoints representing body motion\n- **Face video**: Contains facial feature representations for expression control\n\nFor replacement mode, you additionally need:\n- **Background video**: The original video containing the scene\n- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)\n\n> [!NOTE]\n> Raw videos should not be used for inputs such as `pose_video`, which the pipeline expects to be preprocessed to extract the proper information. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release.\n\nThe example below demonstrates how to use the Wan-Animate pipeline:\n\n<hfoptions id=\"Animate usage\">\n<hfoption id=\"Animation mode\">\n\n```python\nimport numpy as np\nimport torch\nfrom diffusers import AutoencoderKLWan, WanAnimatePipeline\nfrom diffusers.utils import export_to_video, load_image, load_video\n\nmodel_id = \"Wan-AI/Wan2.2-Animate-14B-Diffusers\"\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\npipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# Load character image and preprocessed videos\nimage = load_image(\"path/to/character.jpg\")\npose_video = load_video(\"path/to/pose_video.mp4\")  # Preprocessed skeletal keypoints\nface_video = load_video(\"path/to/face_video.mp4\")  # Preprocessed facial features\n\n# Resize image to match VAE constraints\ndef aspect_ratio_resize(image, pipe, max_area=720 * 1280):\n    aspect_ratio = image.height / 
image.width\n    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]\n    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\n    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\n    image = image.resize((width, height))\n    return image, height, width\n\nimage, height, width = aspect_ratio_resize(image, pipe)\n\nprompt = \"A person dancing energetically in a studio with dynamic lighting and professional camera work\"\nnegative_prompt = \"blurry, low quality, distorted, deformed, static, poorly drawn\"\n\n# Generate animated video\noutput = pipe(\n    image=image,\n    pose_video=pose_video,\n    face_video=face_video,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    segment_frame_length=77,\n    guidance_scale=1.0,\n    mode=\"animate\",  # Animation mode (default)\n).frames[0]\nexport_to_video(output, \"animated_character.mp4\", fps=30)\n```\n\n</hfoption>\n<hfoption id=\"Replacement mode\">\n\n```python\nimport numpy as np\nimport torch\nfrom diffusers import AutoencoderKLWan, WanAnimatePipeline\nfrom diffusers.utils import export_to_video, load_image, load_video\n\nmodel_id = \"Wan-AI/Wan2.2-Animate-14B-Diffusers\"\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\npipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# Load all required inputs for replacement mode\nimage = load_image(\"path/to/new_character.jpg\")\npose_video = load_video(\"path/to/pose_video.mp4\")  # Preprocessed skeletal keypoints\nface_video = load_video(\"path/to/face_video.mp4\")  # Preprocessed facial features\nbackground_video = load_video(\"path/to/background_video.mp4\")  # Original scene\nmask_video = load_video(\"path/to/mask_video.mp4\")  # Black: preserve, White: generate\n\n# Resize image to match video dimensions\ndef aspect_ratio_resize(image, 
pipe, max_area=720 * 1280):\n    aspect_ratio = image.height / image.width\n    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]\n    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\n    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\n    image = image.resize((width, height))\n    return image, height, width\n\nimage, height, width = aspect_ratio_resize(image, pipe)\n\nprompt = \"A person seamlessly integrated into the scene with consistent lighting and environment\"\nnegative_prompt = \"blurry, low quality, inconsistent lighting, floating, disconnected from scene\"\n\n# Replace character in background video\noutput = pipe(\n    image=image,\n    pose_video=pose_video,\n    face_video=face_video,\n    background_video=background_video,\n    mask_video=mask_video,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    segment_frame_length=77,\n    guidance_scale=1.0,\n    mode=\"replace\",  # Replacement mode\n).frames[0]\nexport_to_video(output, \"character_replaced.mp4\", fps=30)\n```\n\n</hfoption>\n<hfoption id=\"Advanced options\">\n\n```python\nimport numpy as np\nimport torch\nfrom diffusers import AutoencoderKLWan, WanAnimatePipeline\nfrom diffusers.utils import export_to_video, load_image, load_video\n\nmodel_id = \"Wan-AI/Wan2.2-Animate-14B-Diffusers\"\nvae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\npipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\nimage = load_image(\"path/to/character.jpg\")\npose_video = load_video(\"path/to/pose_video.mp4\")\nface_video = load_video(\"path/to/face_video.mp4\")\n\ndef aspect_ratio_resize(image, pipe, max_area=720 * 1280):\n    aspect_ratio = image.height / image.width\n    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]\n    height = 
round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value\n    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value\n    image = image.resize((width, height))\n    return image, height, width\n\nimage, height, width = aspect_ratio_resize(image, pipe)\n\nprompt = \"A person dancing energetically in a studio\"\nnegative_prompt = \"blurry, low quality\"\n\n# Advanced: Use temporal guidance and custom callback\ndef callback_fn(pipe, step_index, timestep, callback_kwargs):\n    # You can modify latents or other tensors here\n    print(f\"Step {step_index}, Timestep {timestep}\")\n    return callback_kwargs\n\noutput = pipe(\n    image=image,\n    pose_video=pose_video,\n    face_video=face_video,\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    height=height,\n    width=width,\n    segment_frame_length=77,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    prev_segment_conditioning_frames=5,  # Use 5 frames for temporal guidance (1 or 5 recommended)\n    callback_on_step_end=callback_fn,\n    callback_on_step_end_tensor_inputs=[\"latents\"],\n).frames[0]\nexport_to_video(output, \"animated_advanced.mp4\", fps=30)\n```\n\n</hfoption>\n</hfoptions>\n\n#### Key Parameters\n\n- **mode**: Choose between `\"animate\"` (default) or `\"replace\"`\n- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory\n- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions. 
(Note that CFG will only target the text prompt and face conditioning.)\n\n\n## Notes\n\n- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```py\n  # pip install ftfy\n  import torch\n  from diffusers import AutoModel, WanPipeline\n  from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler\n  from diffusers.utils import export_to_video\n\n  vae = AutoModel.from_pretrained(\n      \"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\", subfolder=\"vae\", torch_dtype=torch.float32\n  )\n  pipeline = WanPipeline.from_pretrained(\n      \"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\", vae=vae, torch_dtype=torch.bfloat16\n  )\n  pipeline.scheduler = UniPCMultistepScheduler.from_config(\n      pipeline.scheduler.config, flow_shift=5.0\n  )\n  pipeline.to(\"cuda\")\n\n  pipeline.load_lora_weights(\"benjamin-paine/steamboat-willie-1.3b\", adapter_name=\"steamboat-willie\")\n  pipeline.set_adapters(\"steamboat-willie\")\n\n  pipeline.enable_model_cpu_offload()\n\n  # use \"steamboat willie style\" to trigger the LoRA\n  prompt = \"\"\"\n  steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,\n  revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in\n  for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.\n  Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic\n  shadows and warm highlights. 
Medium composition, front view, low angle, with depth of field.\n  \"\"\"\n\n  output = pipeline(\n      prompt=prompt,\n      num_frames=81,\n      guidance_scale=5.0,\n  ).frames[0]\n  export_to_video(output, \"output.mp4\", fps=16)\n  ```\n\n  </details>\n\n- [`WanTransformer3DModel`] and [`AutoencoderKLWan`] support loading from single files with [`~loaders.FromSingleFileMixin.from_single_file`].\n\n  <details>\n  <summary>Show example code</summary>\n\n  ```py\n  # pip install ftfy\n  import torch\n  from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan\n\n  vae = AutoencoderKLWan.from_single_file(\n      \"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors\"\n  )\n  transformer = WanTransformer3DModel.from_single_file(\n      \"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors\",\n      torch_dtype=torch.bfloat16\n  )\n  pipeline = WanPipeline.from_pretrained(\n      \"Wan-AI/Wan2.1-T2V-1.3B-Diffusers\",\n      vae=vae,\n      transformer=transformer,\n      torch_dtype=torch.bfloat16\n  )\n  ```\n\n  </details>\n\n- Set the [`AutoencoderKLWan`] dtype to `torch.float32` for better decoding quality.\n\n- The number of frames (`num_frames`) should be of the form `4 * k + 1`, where `k` is an integer (for example, `num_frames=81`).\n\n- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution videos.\n\n- Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involved. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.\n\n- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. 
Set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this example](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this one](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) to learn more.\n\n## WanPipeline\n\n[[autodoc]] WanPipeline\n  - all\n  - __call__\n\n## WanImageToVideoPipeline\n\n[[autodoc]] WanImageToVideoPipeline\n  - all\n  - __call__\n\n## WanVACEPipeline\n\n[[autodoc]] WanVACEPipeline\n  - all\n  - __call__\n\n## WanVideoToVideoPipeline\n\n[[autodoc]] WanVideoToVideoPipeline\n  - all\n  - __call__\n\n## WanAnimatePipeline\n\n[[autodoc]] WanAnimatePipeline\n  - all\n  - __call__\n\n## WanPipelineOutput\n\n[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput\n"
  },
  {
    "path": "docs/source/en/api/pipelines/wuerstchen.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Würstchen\n\n> [!WARNING]\n> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n<img src=\"https://github.com/dome272/Wuerstchen/assets/61938694/0617c863-165a-43ee-9303-2a17299a0cf9\">\n\n[Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, Mats L. Richter, Christopher Pal, and Marc Aubreville.\n\nThe abstract from the paper is:\n\n*We introduce Würstchen, a novel architecture for text-to-image synthesis that combines competitive performance with unprecedented cost-effectiveness for large-scale text-to-image diffusion models. A key contribution of our work is to develop a latent diffusion technique in which we learn a detailed but extremely compact semantic image representation used to guide the diffusion process. 
This highly compressed representation of an image provides much more detailed guidance compared to latent representations of language and this significantly reduces the computational requirements to achieve state-of-the-art results. Our approach also improves the quality of text-conditioned image generation based on our user preference study. The training requirements of our approach consists of 24,602 A100-GPU hours - compared to Stable Diffusion 2.1's 200,000 GPU hours. Our approach also requires less training data to achieve these results. Furthermore, our compact latent representations allows us to perform inference over twice as fast, slashing the usual costs and carbon footprint of a state-of-the-art (SOTA) diffusion model significantly, without compromising the end performance. In a broader comparison against SOTA models our approach is substantially more efficient and compares favorably in terms of image quality. We believe that this work motivates more emphasis on the prioritization of both performance and computational accessibility.*\n\n## Würstchen Overview\nWürstchen is a diffusion model whose text-conditional component works in a highly compressed latent space of images. Why is this important? Compressing data can reduce computational costs for both training and inference by orders of magnitude. Training on 1024x1024 images is far more expensive than training on 32x32 images. Other works usually use a relatively small compression, in the range of 4x - 8x spatial compression. Würstchen takes this to an extreme. Through its novel design, we achieve a 42x spatial compression. This was unseen before because common methods fail to faithfully reconstruct detailed images after 16x spatial compression. Würstchen employs a two-stage compression, which we call Stage A and Stage B. Stage A is a VQGAN, and Stage B is a Diffusion Autoencoder (more details can be found in the [paper](https://huggingface.co/papers/2306.00637)). 
A third model, Stage C, is learned in that highly compressed latent space. This training requires fractions of the compute used for current top-performing models, while also allowing cheaper and faster inference.\n\n## Würstchen v2 comes to Diffusers\n\nAfter the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competitive with current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements.\n\n- Higher resolution (1024x1024 up to 2048x2048)\n- Faster inference\n- Multi Aspect Resolution Sampling\n- Better quality\n\n\nWe are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:\n\n- v2-base\n- v2-aesthetic\n- **(default)** v2-interpolated (50% interpolation between v2-base and v2-aesthetic)\n\nWe recommend using v2-interpolated, as it has a nice touch of both photorealism and aesthetics. Use v2-base for fine-tuning, as it does not have a style bias, and use v2-aesthetic for very artistic generations.\nA comparison can be seen here:\n\n<img src=\"https://github.com/dome272/Wuerstchen/assets/61938694/2914830f-cbd3-461c-be64-d50734f4b49d\" width=500>\n\n## Text-to-Image Generation\n\nFor the sake of usability, Würstchen can be used with a single pipeline. 
This pipeline can be used as follows:\n\n```python\nimport torch\nfrom diffusers import AutoPipelineForText2Image\nfrom diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS\n\npipe = AutoPipelineForText2Image.from_pretrained(\"warp-ai/wuerstchen\", torch_dtype=torch.float16).to(\"cuda\")\n\ncaption = \"Anthropomorphic cat dressed as a fire fighter\"\nimages = pipe(\n    caption,\n    width=1024,\n    height=1536,\n    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,\n    prior_guidance_scale=4.0,\n    num_images_per_prompt=2,\n).images\n```\n\nFor explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, Stage A. Each stage has a different job, and they only work together. When generating text-conditional images, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A are both encapsulated in the `decoder_pipeline`. 
For more details, take a look at the [paper](https://huggingface.co/papers/2306.00637).\n\n```python\nimport torch\nfrom diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline\nfrom diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS\n\ndevice = \"cuda\"\ndtype = torch.float16\nnum_images_per_prompt = 2\n\nprior_pipeline = WuerstchenPriorPipeline.from_pretrained(\n    \"warp-ai/wuerstchen-prior\", torch_dtype=dtype\n).to(device)\ndecoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(\n    \"warp-ai/wuerstchen\", torch_dtype=dtype\n).to(device)\n\ncaption = \"Anthropomorphic cat dressed as a fire fighter\"\nnegative_prompt = \"\"\n\nprior_output = prior_pipeline(\n    prompt=caption,\n    height=1024,\n    width=1536,\n    timesteps=DEFAULT_STAGE_C_TIMESTEPS,\n    negative_prompt=negative_prompt,\n    guidance_scale=4.0,\n    num_images_per_prompt=num_images_per_prompt,\n)\ndecoder_output = decoder_pipeline(\n    image_embeddings=prior_output.image_embeddings,\n    prompt=caption,\n    negative_prompt=negative_prompt,\n    guidance_scale=0.0,\n    output_type=\"pil\",\n).images[0]\ndecoder_output\n```\n\n## Speed-Up Inference\nYou can use the `torch.compile` function to gain a speed-up of about 2-3x:\n\n```python\nprior_pipeline.prior = torch.compile(prior_pipeline.prior, mode=\"reduce-overhead\", fullgraph=True)\ndecoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n## Limitations\n\n- Due to the high compression employed by Würstchen, generations can lack a good amount\nof detail. To the human eye, this is especially noticeable in faces, hands etc.\n- **Images can only be generated in 128-pixel steps**, e.g. 
the next higher resolution\nafter 1024x1024 is 1152x1152\n- The model lacks the ability to render correct text in images\n- The model often does not achieve photorealism\n- Difficult compositional prompts are hard for the model\n\nThe original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).\n\n\n## WuerstchenCombinedPipeline\n\n[[autodoc]] WuerstchenCombinedPipeline\n\t- all\n\t- __call__\n\n## WuerstchenPriorPipeline\n\n[[autodoc]] WuerstchenPriorPipeline\n\t- all\n\t- __call__\n\n## WuerstchenPriorPipelineOutput\n\n[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput\n\n## WuerstchenDecoderPipeline\n\n[[autodoc]] WuerstchenDecoderPipeline\n\t- all\n\t- __call__\n\n## Citation\n\n```bibtex\n      @misc{pernias2023wuerstchen,\n            title={Wuerstchen: An Efficient Architecture for Large-Scale Text-to-Image Diffusion Models},\n            author={Pablo Pernias and Dominic Rampas and Mats L. Richter and Christopher J. Pal and Marc Aubreville},\n            year={2023},\n            eprint={2306.00637},\n            archivePrefix={arXiv},\n            primaryClass={cs.CV}\n      }\n```\n"
  },
  {
    "path": "docs/source/en/api/pipelines/z_image.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Z-Image\n\n<div class=\"flex flex-wrap space-x-1\">\n  <img alt=\"LoRA\" src=\"https://img.shields.io/badge/LoRA-d8b4fe?style=flat\"/>\n</div>\n\n[Z-Image](https://huggingface.co/papers/2511.22699) is a powerful and highly efficient image generation model with 6B parameters. Currently there's only one model with two more to be released:\n\n|Model|Hugging Face|\n|---|---|\n|Z-Image-Turbo|https://huggingface.co/Tongyi-MAI/Z-Image-Turbo|\n\n## Z-Image-Turbo\n\nZ-Image-Turbo is a distilled version of Z-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. 
It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.\n\n## Image-to-image\n\nUse [`ZImageImg2ImgPipeline`] to transform an existing image based on a text prompt.\n\n```python\nimport torch\nfrom diffusers import ZImageImg2ImgPipeline\nfrom diffusers.utils import load_image\n\npipe = ZImageImg2ImgPipeline.from_pretrained(\"Tongyi-MAI/Z-Image-Turbo\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\ninit_image = load_image(url).resize((1024, 1024))\n\nprompt = \"A fantasy landscape with mountains and a river, detailed, vibrant colors\"\nimage = pipe(\n    prompt,\n    image=init_image,\n    strength=0.6,\n    num_inference_steps=9,\n    guidance_scale=0.0,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).images[0]\nimage.save(\"zimage_img2img.png\")\n```\n\n## Inpainting\n\nUse [`ZImageInpaintPipeline`] to inpaint specific regions of an image based on a text prompt and mask.\n\n```python\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom diffusers import ZImageInpaintPipeline\nfrom diffusers.utils import load_image\n\npipe = ZImageInpaintPipeline.from_pretrained(\"Tongyi-MAI/Z-Image-Turbo\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\ninit_image = load_image(url).resize((1024, 1024))\n\n# Create a mask (white = inpaint, black = preserve)\nmask = np.zeros((1024, 1024), dtype=np.uint8)\nmask[256:768, 256:768] = 255  # Inpaint center region\nmask_image = Image.fromarray(mask)\n\nprompt = \"A beautiful lake with mountains in the background\"\nimage = pipe(\n    prompt,\n    image=init_image,\n    mask_image=mask_image,\n    strength=1.0,\n    num_inference_steps=9,\n    guidance_scale=0.0,\n    
generator=torch.Generator(\"cuda\").manual_seed(42),\n).images[0]\nimage.save(\"zimage_inpaint.png\")\n```\n\n## ZImagePipeline\n\n[[autodoc]] ZImagePipeline\n\t- all\n\t- __call__\n\n## ZImageImg2ImgPipeline\n\n[[autodoc]] ZImageImg2ImgPipeline\n\t- all\n\t- __call__\n\n## ZImageInpaintPipeline\n\n[[autodoc]] ZImageInpaintPipeline\n\t- all\n\t- __call__\n"
  },
  {
    "path": "docs/source/en/api/quantization.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n-->\n\n# Quantization\n\nQuantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference.\n\n> [!TIP]\n> Learn how to quantize models in the [Quantization](../quantization/overview) guide.\n\n## PipelineQuantizationConfig\n\n[[autodoc]] quantizers.PipelineQuantizationConfig\n\n## BitsAndBytesConfig\n\n[[autodoc]] quantizers.quantization_config.BitsAndBytesConfig\n\n## GGUFQuantizationConfig\n\n[[autodoc]] quantizers.quantization_config.GGUFQuantizationConfig\n\n## QuantoConfig\n\n[[autodoc]] quantizers.quantization_config.QuantoConfig\n\n## TorchAoConfig\n\n[[autodoc]] quantizers.quantization_config.TorchAoConfig\n\n## DiffusersQuantizer\n\n[[autodoc]] quantizers.base.DiffusersQuantizer\n"
  },
  {
    "path": "docs/source/en/api/schedulers/cm_stochastic_iterative.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# CMStochasticIterativeScheduler\n\n[Consistency Models](https://huggingface.co/papers/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever introduced a multistep and onestep scheduler (Algorithm 1) that is capable of generating good samples in one or a small number of steps.\n\nThe abstract from the paper is:\n\n*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. 
When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.*\n\nThe original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models).\n\n## CMStochasticIterativeScheduler\n[[autodoc]] CMStochasticIterativeScheduler\n\n## CMStochasticIterativeSchedulerOutput\n[[autodoc]] schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/consistency_decoder.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ConsistencyDecoderScheduler\n\nThis scheduler is a part of the [`ConsistencyDecoderPipeline`] and was introduced in [DALL-E 3](https://openai.com/dall-e-3).\n\nThe original codebase can be found at [openai/consistencydecoder](https://github.com/openai/consistencydecoder).\n\n\n## ConsistencyDecoderScheduler\n[[autodoc]] schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler\n"
  },
  {
    "path": "docs/source/en/api/schedulers/cosine_dpm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# CosineDPMSolverMultistepScheduler\n\nThe [`CosineDPMSolverMultistepScheduler`] is a variant of [`DPMSolverMultistepScheduler`] with a cosine schedule, proposed by Nichol and Dhariwal (2021).\nIt is used in the [Stable Audio Open](https://huggingface.co/papers/2407.14358) paper and the [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) codebase.\n\nThis scheduler was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe).\n\n## CosineDPMSolverMultistepScheduler\n[[autodoc]] CosineDPMSolverMultistepScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/ddim.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DDIMScheduler\n\n[Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.\n\nThe abstract from the paper is:\n\n*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample.\nTo accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models\nwith the same training procedure as DDPMs. 
In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process.\nWe construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from.\nWe empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.*\n\nThe original codebase of this paper can be found at [ermongroup/ddim](https://github.com/ermongroup/ddim), and you can contact the author on [tsong.me](https://tsong.me/).\n\n## Tips\n\nThe paper [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion. To fix this, the authors propose:\n\n> [!WARNING]\n> 🧪 This is an experimental feature!\n\n1. rescale the noise schedule to enforce zero terminal signal-to-noise ratio (SNR)\n\n```py\npipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)\n```\n\n2. train a model with `v_prediction` (add the following argument to the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts)\n\n```bash\n--prediction_type=\"v_prediction\"\n```\n\n3. change the sampler to always start from the last timestep\n\n```py\npipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing=\"trailing\")\n```\n\n4. 
rescale classifier-free guidance to prevent over-exposure\n\n```py\nimage = pipe(prompt, guidance_rescale=0.7).images[0]\n```\n\nFor example:\n\n```py\nfrom diffusers import DiffusionPipeline, DDIMScheduler\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\"ptx0/pseudo-journey-v2\", torch_dtype=torch.float16)\npipe.scheduler = DDIMScheduler.from_config(\n    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing=\"trailing\"\n)\npipe.to(\"cuda\")\n\nprompt = \"A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k\"\nimage = pipe(prompt, guidance_rescale=0.7).images[0]\nimage\n```\n\n## DDIMScheduler\n[[autodoc]] DDIMScheduler\n\n## DDIMSchedulerOutput\n[[autodoc]] schedulers.scheduling_ddim.DDIMSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/ddim_cogvideox.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# CogVideoXDDIMScheduler\n\n`CogVideoXDDIMScheduler` is based on [Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502), specifically for CogVideoX models.\n\n## CogVideoXDDIMScheduler\n\n[[autodoc]] CogVideoXDDIMScheduler\n"
  },
  {
    "path": "docs/source/en/api/schedulers/ddim_inverse.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DDIMInverseScheduler\n\n`DDIMInverseScheduler` is the inverted scheduler from [Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.\nThe implementation is mostly based on the DDIM inversion definition from [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://huggingface.co/papers/2211.09794).\n\n## DDIMInverseScheduler\n[[autodoc]] DDIMInverseScheduler\n"
  },
  {
    "path": "docs/source/en/api/schedulers/ddpm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DDPMScheduler\n\n[Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2006.11239) (DDPM) by Jonathan Ho, Ajay Jain and Pieter Abbeel proposes a diffusion based model of the same name. In the context of the 🤗 Diffusers library, DDPM refers to the discrete denoising scheduler from the paper as well as the pipeline.\n\nThe abstract from the paper is:\n\n*We present high quality image synthesis results using diffusion probabilistic models, a class of latent variable models inspired by considerations from nonequilibrium thermodynamics. Our best results are obtained by training on a weighted variational bound designed according to a novel connection between diffusion probabilistic models and denoising score matching with Langevin dynamics, and our models naturally admit a progressive lossy decompression scheme that can be interpreted as a generalization of autoregressive decoding. On the unconditional CIFAR10 dataset, we obtain an Inception score of 9.46 and a state-of-the-art FID score of 3.17. On 256x256 LSUN, we obtain sample quality similar to ProgressiveGAN. Our implementation is available at [this https URL](https://github.com/hojonathanho/diffusion).*\n\n## DDPMScheduler\n[[autodoc]] DDPMScheduler\n\n## DDPMSchedulerOutput\n[[autodoc]] schedulers.scheduling_ddpm.DDPMSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/deis.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DEISMultistepScheduler\n\nDiffusion Exponential Integrator Sampler (DEIS) is proposed in [Fast Sampling of Diffusion Models with Exponential Integrator](https://huggingface.co/papers/2204.13902) by Qinsheng Zhang and Yongxin Chen. `DEISMultistepScheduler` is a fast high order solver for diffusion ordinary differential equations (ODEs).\n\nThis implementation modifies the polynomial fitting formula in log-rho space instead of the original linear `t` space in the DEIS paper. The modification enjoys closed-form coefficients for the exponential multistep update instead of relying on the numerical solver.\n\nThe abstract from the paper is:\n\n*The past few years have witnessed the great success of Diffusion models (DMs) in generating high-fidelity samples in generative modeling tasks. A major limitation of the DM is its notoriously slow sampling procedure which normally requires hundreds to thousands of time discretization steps of the learned diffusion process to reach the desired accuracy. Our goal is to develop a fast sampling method for DMs with a much less number of steps while retaining high sample quality. To this end, we systematically analyze the sampling procedure in DMs and identify key factors that affect the sample quality, among which the method of discretization is most crucial. 
By carefully examining the learned diffusion process, we propose Diffusion Exponential Integrator Sampler~(DEIS). It is based on the Exponential Integrator designed for discretizing ordinary differential equations (ODEs) and leverages a semilinear structure of the learned diffusion process to reduce the discretization error. The proposed method can be applied to any DMs and can generate high-fidelity samples in as few as 10 steps. In our experiments, it takes about 3 minutes on one A6000 GPU to generate 50k images from CIFAR10. Moreover, by directly using pre-trained DMs, we achieve the state-of-art sampling performance when the number of score function evaluation~(NFE) is limited, e.g., 4.17 FID with 10 NFEs, 3.37 FID, and 9.74 IS with only 15 NFEs on CIFAR10. Code is available at [this https URL](https://github.com/qsh-zh/deis).*\n\n## Tips\n\nIt is recommended to set `solver_order` to 2 or 3, while `solver_order=1` is equivalent to [`DDIMScheduler`].\n\nDynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space\ndiffusion models, you can set `thresholding=True` to use the dynamic thresholding.\n\n## DEISMultistepScheduler\n[[autodoc]] DEISMultistepScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/dpm_discrete.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# KDPM2DiscreteScheduler\n\nThe `KDPM2DiscreteScheduler` is inspired by the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper, and the scheduler is ported from and created by [Katherine Crowson](https://github.com/crowsonkb/).\n\nThe original codebase can be found at [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion).\n\n## KDPM2DiscreteScheduler\n[[autodoc]] KDPM2DiscreteScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/dpm_discrete_ancestral.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# KDPM2AncestralDiscreteScheduler\n\nThe `KDPM2DiscreteScheduler` with ancestral sampling is inspired by the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper, and the scheduler is ported from and created by [Katherine Crowson](https://github.com/crowsonkb/).\n\nThe original codebase can be found at [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion).\n\n## KDPM2AncestralDiscreteScheduler\n[[autodoc]] KDPM2AncestralDiscreteScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/dpm_sde.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DPMSolverSDEScheduler\n\nThe `DPMSolverSDEScheduler` is inspired by the stochastic sampler from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper, and the scheduler is ported from and created by [Katherine Crowson](https://github.com/crowsonkb/).\n\n## DPMSolverSDEScheduler\n[[autodoc]] DPMSolverSDEScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/edm_euler.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# EDMEulerScheduler\n\nThe Karras formulation of the Euler scheduler (Algorithm 2) from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by [Katherine Crowson](https://github.com/crowsonkb/).\n\n\n## EDMEulerScheduler\n[[autodoc]] EDMEulerScheduler\n\n## EDMEulerSchedulerOutput\n[[autodoc]] schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/edm_multistep_dpm_solver.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# EDMDPMSolverMultistepScheduler\n\n`EDMDPMSolverMultistepScheduler` is a [Karras formulation](https://huggingface.co/papers/2206.00364) of `DPMSolverMultistepScheduler`, a multistep scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.\n\nDPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality\nsamples, and it can generate quite good samples even in 10 steps.\n\n## EDMDPMSolverMultistepScheduler\n[[autodoc]] EDMDPMSolverMultistepScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/euler.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# EulerDiscreteScheduler\n\nThe Euler scheduler (Algorithm 2) is from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by [Katherine Crowson](https://github.com/crowsonkb/).\n\n\n## EulerDiscreteScheduler\n[[autodoc]] EulerDiscreteScheduler\n\n## EulerDiscreteSchedulerOutput\n[[autodoc]] schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/euler_ancestral.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# EulerAncestralDiscreteScheduler\n\nA scheduler that uses ancestral sampling with Euler method steps. This is a fast scheduler which can often generate good outputs in 20-30 steps. The scheduler is based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72) implementation by [Katherine Crowson](https://github.com/crowsonkb/).\n\n## EulerAncestralDiscreteScheduler\n[[autodoc]] EulerAncestralDiscreteScheduler\n\n## EulerAncestralDiscreteSchedulerOutput\n[[autodoc]] schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/flow_match_euler_discrete.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# FlowMatchEulerDiscreteScheduler\n\n`FlowMatchEulerDiscreteScheduler` is based on the flow-matching sampling introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).\n\n## FlowMatchEulerDiscreteScheduler\n[[autodoc]] FlowMatchEulerDiscreteScheduler\n"
  },
  {
    "path": "docs/source/en/api/schedulers/flow_match_heun_discrete.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# FlowMatchHeunDiscreteScheduler\n\n`FlowMatchHeunDiscreteScheduler` is based on the flow-matching sampling introduced in [EDM](https://huggingface.co/papers/2403.03206).\n\n## FlowMatchHeunDiscreteScheduler\n[[autodoc]] FlowMatchHeunDiscreteScheduler\n"
  },
  {
    "path": "docs/source/en/api/schedulers/helios.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# HeliosScheduler\n\n`HeliosScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).\n\n## HeliosScheduler\n[[autodoc]] HeliosScheduler\n\nscheduling_helios\n"
  },
  {
    "path": "docs/source/en/api/schedulers/helios_dmd.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# HeliosDMDScheduler\n\n`HeliosDMDScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).\n\n## HeliosDMDScheduler\n[[autodoc]] HeliosDMDScheduler\n\nscheduling_helios_dmd\n"
  },
  {
    "path": "docs/source/en/api/schedulers/heun.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# HeunDiscreteScheduler\n\nThe Heun scheduler (Algorithm 1) is from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) paper by Karras et al. The scheduler is ported from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library and created by [Katherine Crowson](https://github.com/crowsonkb/).\n\n## HeunDiscreteScheduler\n[[autodoc]] HeunDiscreteScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/ipndm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# IPNDMScheduler\n\n`IPNDMScheduler` is a fourth-order Improved Pseudo Linear Multistep scheduler. The original implementation can be found at [crowsonkb/v-diffusion-pytorch](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296).\n\n## IPNDMScheduler\n[[autodoc]] IPNDMScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/lcm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Latent Consistency Model Multistep Scheduler\n\n## Overview\n\nMultistep and onestep scheduler (Algorithm 3) introduced alongside latent consistency models in the paper [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://huggingface.co/papers/2310.04378) by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao.\nThis scheduler should be able to generate good samples from [`LatentConsistencyModelPipeline`] in 1-8 steps.\n\n## LCMScheduler\n[[autodoc]] LCMScheduler\n"
  },
  {
    "path": "docs/source/en/api/schedulers/lms_discrete.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LMSDiscreteScheduler\n\n`LMSDiscreteScheduler` is a linear multistep scheduler for discrete beta schedules. The scheduler is ported from and created by [Katherine Crowson](https://github.com/crowsonkb/), and the original implementation can be found at [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).\n\n## LMSDiscreteScheduler\n[[autodoc]] LMSDiscreteScheduler\n\n## LMSDiscreteSchedulerOutput\n[[autodoc]] schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/multistep_dpm_solver.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DPMSolverMultistepScheduler\n\n`DPMSolverMultistepScheduler` is a multistep scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.\n\nDPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality\nsamples, and it can generate quite good samples even in 10 steps.\n\n## Tips\n\nIt is recommended to set `solver_order` to 2 for guide sampling, and `solver_order=3` for unconditional sampling.\n\nDynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space\ndiffusion models, you can set both `algorithm_type=\"dpmsolver++\"` and `thresholding=True` to use the dynamic\nthresholding. This thresholding method is unsuitable for latent-space diffusion models such as\nStable Diffusion.\n\nThe SDE variant of DPMSolver and DPM-Solver++ is also supported, but only for the first and second-order solvers. This is a fast SDE solver for the reverse diffusion SDE. 
It is recommended to use the second-order `sde-dpmsolver++`.\n\n## DPMSolverMultistepScheduler\n[[autodoc]] DPMSolverMultistepScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/multistep_dpm_solver_cogvideox.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# CogVideoXDPMScheduler\n\n`CogVideoXDPMScheduler` is based on [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095), specifically for CogVideoX models.\n\n## CogVideoXDPMScheduler\n\n[[autodoc]] CogVideoXDPMScheduler\n"
  },
  {
    "path": "docs/source/en/api/schedulers/multistep_dpm_solver_inverse.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DPMSolverMultistepInverse\n\n`DPMSolverMultistepInverse` is the inverted scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.\n\nThe implementation is mostly based on the DDIM inversion definition of [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://huggingface.co/papers/2211.09794) and notebook implementation of the [`DiffEdit`] latent inversion from [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/diffedit.ipynb).\n\n## Tips\n\nDynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space\ndiffusion models, you can set both `algorithm_type=\"dpmsolver++\"` and `thresholding=True` to use the dynamic\nthresholding. This thresholding method is unsuitable for latent-space diffusion models such as\nStable Diffusion.\n\n## DPMSolverMultistepInverseScheduler\n[[autodoc]] DPMSolverMultistepInverseScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Schedulers\n\n🤗 Diffusers provides many scheduler functions for the diffusion process. A scheduler takes a model's output (the sample which the diffusion process is iterating on) and a timestep to return a denoised sample. The timestep is important because it dictates where in the diffusion process the step is; data is generated by iterating forward *n* timesteps and inference occurs by propagating backward through the timesteps. Based on the timestep, a scheduler may be *discrete* in which case the timestep is an `int` or *continuous* in which case the timestep is a `float`.\n\nDepending on the context, a scheduler defines how to iteratively add noise to an image or how to update a sample based on a model's output:\n\n- during *training*, a scheduler adds noise (there are different algorithms for how to add noise) to a sample to train a diffusion model\n- during *inference*, a scheduler defines how to update a sample based on a pretrained model's output\n\nMany schedulers are implemented from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library by [Katherine Crowson](https://github.com/crowsonkb/), and they're also widely used in A1111. 
To help you map the schedulers from k-diffusion and A1111 to the schedulers in 🤗 Diffusers, take a look at the table below:\n\n| A1111/k-diffusion    | 🤗 Diffusers                         | Usage                                                                                                         |\n|---------------------|-------------------------------------|---------------------------------------------------------------------------------------------------------------|\n| DPM++ 2M            | [`DPMSolverMultistepScheduler`]     |                                                                                                               |\n| DPM++ 2M Karras     | [`DPMSolverMultistepScheduler`]     | init with `use_karras_sigmas=True`                                                                            |\n| DPM++ 2M SDE        | [`DPMSolverMultistepScheduler`]     | init with `algorithm_type=\"sde-dpmsolver++\"`                                                                  |\n| DPM++ 2M SDE Karras | [`DPMSolverMultistepScheduler`]     | init with `use_karras_sigmas=True` and `algorithm_type=\"sde-dpmsolver++\"`                                     |\n| DPM++ 2S a          | N/A                                 | very similar to  `DPMSolverSinglestepScheduler`                         |\n| DPM++ 2S a Karras   | N/A                                 | very similar to  `DPMSolverSinglestepScheduler(use_karras_sigmas=True, ...)` |\n| DPM++ SDE           | [`DPMSolverSinglestepScheduler`]    |                                                                                                               |\n| DPM++ SDE Karras    | [`DPMSolverSinglestepScheduler`]    | init with `use_karras_sigmas=True`                                                                            |\n| DPM2                | [`KDPM2DiscreteScheduler`]          |                                                                                                               |\n| DPM2 Karras 
        | [`KDPM2DiscreteScheduler`]          | init with `use_karras_sigmas=True`                                                                            |\n| DPM2 a              | [`KDPM2AncestralDiscreteScheduler`] |                                                                                                               |\n| DPM2 a Karras       | [`KDPM2AncestralDiscreteScheduler`] | init with `use_karras_sigmas=True`                                                                            |\n| DPM adaptive        | N/A                                 |                                                                                                               |\n| DPM fast            | N/A                                 |                                                                                                               |\n| Euler               | [`EulerDiscreteScheduler`]          |                                                                                                               |\n| Euler a             | [`EulerAncestralDiscreteScheduler`] |                                                                                                               |\n| Heun                | [`HeunDiscreteScheduler`]           |                                                                                                               |\n| LMS                 | [`LMSDiscreteScheduler`]            |                                                                                                               |\n| LMS Karras          | [`LMSDiscreteScheduler`]            | init with `use_karras_sigmas=True`                                                                            |\n| N/A                 | [`DEISMultistepScheduler`]          |                                                                                                               |\n| N/A                 | [`UniPCMultistepScheduler`]         |                            
                                                                                   |\n\n## Noise schedules and schedule types\n| A1111/k-diffusion        | 🤗 Diffusers                                                               |\n|--------------------------|----------------------------------------------------------------------------|\n| Karras                   | init with `use_karras_sigmas=True`                                         |\n| sgm_uniform              | init with `timestep_spacing=\"trailing\"`                                    |\n| simple                   | init with `timestep_spacing=\"trailing\"`                                    |\n| exponential              | init with `timestep_spacing=\"linspace\"`, `use_exponential_sigmas=True`     |\n| beta                     | init with `timestep_spacing=\"linspace\"`, `use_beta_sigmas=True`            |\n\nAll schedulers are built from the base [`SchedulerMixin`] class, which implements low-level utilities shared by all schedulers.\n\n## SchedulerMixin\n[[autodoc]] SchedulerMixin\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n\n## KarrasDiffusionSchedulers\n\n[`KarrasDiffusionSchedulers`] are a broad generalization of schedulers in 🤗 Diffusers. The schedulers in this class are distinguished at a high level by their noise sampling strategy, the type of network and scaling, the training strategy, and how the loss is weighted.\n\nThe different schedulers in this class, depending on the ordinary differential equation (ODE) solver type, fall into the above taxonomy and provide a good abstraction for the design of the main schedulers implemented in 🤗 Diffusers. The schedulers in this class are given [here](https://github.com/huggingface/diffusers/blob/a69754bb879ed55b9b6dc9dd0b3cf4fa4124c765/src/diffusers/schedulers/scheduling_utils.py#L32).\n\n## PushToHubMixin\n\n[[autodoc]] utils.PushToHubMixin\n"
  },
  {
    "path": "docs/source/en/api/schedulers/pndm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# PNDMScheduler\n\n`PNDMScheduler`, or pseudo numerical methods for diffusion models, uses more advanced ODE integration techniques like the Runge-Kutta and linear multi-step method. The original implementation can be found at [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).\n\n## PNDMScheduler\n[[autodoc]] PNDMScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/repaint.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# RePaintScheduler\n\n`RePaintScheduler` is a DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks. It is designed to be used with the [`RePaintPipeline`], and it is based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2201.09865) by Andreas Lugmayr et al.\n\nThe abstract from the paper is:\n\n*Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. 
We validate our method for both faces and general-purpose image inpainting using standard and extreme masks. RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions. GitHub Repository: [this http URL](http://git.io/RePaint).*\n\nThe original implementation can be found at [andreas128/RePaint](https://github.com/andreas128/).\n\n## RePaintScheduler\n[[autodoc]] RePaintScheduler\n\n## RePaintSchedulerOutput\n[[autodoc]] schedulers.scheduling_repaint.RePaintSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/score_sde_ve.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ScoreSdeVeScheduler\n\n`ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler. It was introduced in the [Score-Based Generative Modeling through Stochastic Differential Equations](https://huggingface.co/papers/2011.13456) paper by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, Ben Poole.\n\nThe abstract from the paper is:\n\n*Creating noise from data is easy; creating data from noise is generative modeling. We present a stochastic differential equation (SDE) that smoothly transforms a complex data distribution to a known prior distribution by slowly injecting noise, and a corresponding reverse-time SDE that transforms the prior distribution back into the data distribution by slowly removing the noise. Crucially, the reverse-time SDE depends only on the time-dependent gradient field (a.k.a., score) of the perturbed data distribution. By leveraging advances in score-based generative modeling, we can accurately estimate these scores with neural networks, and use numerical SDE solvers to generate samples. We show that this framework encapsulates previous approaches in score-based generative modeling and diffusion probabilistic modeling, allowing for new sampling procedures and new modeling capabilities. 
In particular, we introduce a predictor-corrector framework to correct errors in the evolution of the discretized reverse-time SDE. We also derive an equivalent neural ODE that samples from the same distribution as the SDE, but additionally enables exact likelihood computation, and improved sampling efficiency. In addition, we provide a new way to solve inverse problems with score-based models, as demonstrated with experiments on class-conditional generation, image inpainting, and colorization. Combined with multiple architectural improvements, we achieve record-breaking performance for unconditional image generation on CIFAR-10 with an Inception score of 9.89 and FID of 2.20, a competitive likelihood of 2.99 bits/dim, and demonstrate high fidelity generation of 1024 x 1024 images for the first time from a score-based generative model.*\n\n## ScoreSdeVeScheduler\n[[autodoc]] ScoreSdeVeScheduler\n\n## SdeVeOutput\n[[autodoc]] schedulers.scheduling_sde_ve.SdeVeOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/score_sde_vp.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ScoreSdeVpScheduler\n\n`ScoreSdeVpScheduler` is a variance preserving stochastic differential equation (SDE) scheduler. It was introduced in the [Score-Based Generative Modeling through Stochastic Differential Equations](https://huggingface.co/papers/2011.13456) paper by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, Ben Poole.\n\nThe abstract from the paper is:\n\n*Creating noise from data is easy; creating data from noise is generative modeling. We present a stochastic differential equation (SDE) that smoothly transforms a complex data distribution to a known prior distribution by slowly injecting noise, and a corresponding reverse-time SDE that transforms the prior distribution back into the data distribution by slowly removing the noise. Crucially, the reverse-time SDE depends only on the time-dependent gradient field (a.k.a., score) of the perturbed data distribution. By leveraging advances in score-based generative modeling, we can accurately estimate these scores with neural networks, and use numerical SDE solvers to generate samples. We show that this framework encapsulates previous approaches in score-based generative modeling and diffusion probabilistic modeling, allowing for new sampling procedures and new modeling capabilities. 
In particular, we introduce a predictor-corrector framework to correct errors in the evolution of the discretized reverse-time SDE. We also derive an equivalent neural ODE that samples from the same distribution as the SDE, but additionally enables exact likelihood computation, and improved sampling efficiency. In addition, we provide a new way to solve inverse problems with score-based models, as demonstrated with experiments on class-conditional generation, image inpainting, and colorization. Combined with multiple architectural improvements, we achieve record-breaking performance for unconditional image generation on CIFAR-10 with an Inception score of 9.89 and FID of 2.20, a competitive likelihood of 2.99 bits/dim, and demonstrate high fidelity generation of 1024 x 1024 images for the first time from a score-based generative model.*\n\n> [!WARNING]\n> 🚧 This scheduler is under construction!\n\n## ScoreSdeVpScheduler\n[[autodoc]] schedulers.deprecated.scheduling_sde_vp.ScoreSdeVpScheduler\n"
  },
  {
    "path": "docs/source/en/api/schedulers/singlestep_dpm_solver.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DPMSolverSinglestepScheduler\n\n`DPMSolverSinglestepScheduler` is a single-step scheduler from [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu.\n\nDPMSolver (and the improved version DPMSolver++) is a fast dedicated high-order solver for diffusion ODEs with a convergence order guarantee. Empirically, DPMSolver sampling with only 20 steps can generate high-quality\nsamples, and it can generate quite good samples even in 10 steps.\n\nThe original implementation can be found at [LuChengTHU/dpm-solver](https://github.com/LuChengTHU/dpm-solver).\n\n## Tips\n\nIt is recommended to set `solver_order=2` for guided sampling and `solver_order=3` for unconditional sampling.\n\nDynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space\ndiffusion models, you can set both `algorithm_type=\"dpmsolver++\"` and `thresholding=True` to use dynamic\nthresholding. 
This thresholding method is unsuitable for latent-space diffusion models such as\nStable Diffusion.\n\n## DPMSolverSinglestepScheduler\n[[autodoc]] DPMSolverSinglestepScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/stochastic_karras_ve.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# KarrasVeScheduler\n\n`KarrasVeScheduler` is a stochastic sampler tailored to variance exploding (VE) models. It is based on the [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) and [Score-Based Generative Modeling through Stochastic Differential Equations](https://huggingface.co/papers/2011.13456) papers.\n\n## KarrasVeScheduler\n[[autodoc]] KarrasVeScheduler\n\n## KarrasVeOutput\n[[autodoc]] schedulers.deprecated.scheduling_karras_ve.KarrasVeOutput"
  },
  {
    "path": "docs/source/en/api/schedulers/tcd.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# TCDScheduler\n\n[Trajectory Consistency Distillation](https://huggingface.co/papers/2402.19159) by Jianbin Zheng, Minghui Hu, Zhongyi Fan, Chaoyue Wang, Changxing Ding, Dacheng Tao and Tat-Jen Cham introduced Strategic Stochastic Sampling (Algorithm 4), which is capable of generating good samples in a small number of steps. An advanced iteration of the multistep scheduler (Algorithm 1) in [Consistency Models](https://huggingface.co/papers/2303.01469), Strategic Stochastic Sampling is specifically tailored to the trajectory consistency function.\n\nThe abstract from the paper is:\n\n*Latent Consistency Model (LCM) extends the Consistency Model to the latent space and leverages the guided consistency distillation technique to achieve impressive performance in accelerating text-to-image synthesis. However, we observed that LCM struggles to generate images with both clarity and detailed intricacy. To address this limitation, we initially delve into and elucidate the underlying causes. Our investigation identifies that the primary issue stems from errors in three distinct areas. Consequently, we introduce Trajectory Consistency Distillation (TCD), which encompasses trajectory consistency function and strategic stochastic sampling. 
The trajectory consistency function diminishes the distillation errors by broadening the scope of the self-consistency boundary condition and endowing the TCD with the ability to accurately trace the entire trajectory of the Probability Flow ODE. Additionally, strategic stochastic sampling is specifically designed to circumvent the accumulated errors inherent in multi-step consistency sampling, which is meticulously tailored to complement the TCD model. Experiments demonstrate that TCD not only significantly enhances image quality at low NFEs but also yields more detailed results compared to the teacher model at high NFEs.*\n\nThe original codebase can be found at [jabir-zheng/TCD](https://github.com/jabir-zheng/TCD).\n\n## TCDScheduler\n[[autodoc]] TCDScheduler\n\n\n## TCDSchedulerOutput\n[[autodoc]] schedulers.scheduling_tcd.TCDSchedulerOutput\n\n"
  },
  {
    "path": "docs/source/en/api/schedulers/unipc.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# UniPCMultistepScheduler\n\n`UniPCMultistepScheduler` is a training-free framework designed for fast sampling of diffusion models. It was introduced in [UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models](https://huggingface.co/papers/2302.04867) by Wenliang Zhao, Lujia Bai, Yongming Rao, Jie Zhou, Jiwen Lu.\n\nIt consists of a corrector (UniC) and a predictor (UniP) that share a unified analytical form and support arbitrary orders.\nUniPC is by design model-agnostic, supporting pixel-space/latent-space DPMs on unconditional/conditional sampling. It can also be applied to both noise prediction and data prediction models. The corrector UniC can also be applied after any off-the-shelf solver to increase the order of accuracy.\n\nThe abstract from the paper is:\n\n*Diffusion probabilistic models (DPMs) have demonstrated a very promising ability in high-resolution image synthesis. However, sampling from a pre-trained DPM is time-consuming due to the multiple evaluations of the denoising network, making it more and more important to accelerate the sampling of DPMs. Despite recent progress in designing fast samplers, existing methods still cannot generate satisfying images in many applications where fewer steps (e.g., <10) are favored. 
In this paper, we develop a unified corrector (UniC) that can be applied after any existing DPM sampler to increase the order of accuracy without extra model evaluations, and derive a unified predictor (UniP) that supports arbitrary order as a byproduct. Combining UniP and UniC, we propose a unified predictor-corrector framework called UniPC for the fast sampling of DPMs, which has a unified analytical form for any order and can significantly improve the sampling quality over previous methods, especially in extremely few steps. We evaluate our methods through extensive experiments including both unconditional and conditional sampling using pixel-space and latent-space DPMs. Our UniPC can achieve 3.87 FID on CIFAR10 (unconditional) and 7.51 FID on ImageNet 256×256 (conditional) with only 10 function evaluations. Code is available at [this https URL](https://github.com/wl-zhao/UniPC).*\n\n## Tips\n\nIt is recommended to set `solver_order=2` for guided sampling and `solver_order=3` for unconditional sampling.\n\nDynamic thresholding from [Imagen](https://huggingface.co/papers/2205.11487) is supported, and for pixel-space\ndiffusion models, you can set both `predict_x0=True` and `thresholding=True` to use dynamic thresholding. This thresholding method is unsuitable for latent-space diffusion models such as Stable Diffusion.\n\n## UniPCMultistepScheduler\n[[autodoc]] UniPCMultistepScheduler\n\n## SchedulerOutput\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/schedulers/vq_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# VQDiffusionScheduler\n\n`VQDiffusionScheduler` converts the transformer model's output into a sample for the unnoised image at the previous diffusion timestep. It was introduced in [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://huggingface.co/papers/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo.\n\nThe abstract from the paper is:\n\n*We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. 
Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.*\n\n## VQDiffusionScheduler\n[[autodoc]] VQDiffusionScheduler\n\n## VQDiffusionSchedulerOutput\n[[autodoc]] schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput\n"
  },
  {
    "path": "docs/source/en/api/utilities.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Utilities\n\nUtility and helper functions for working with 🤗 Diffusers.\n\n## numpy_to_pil\n\n[[autodoc]] utils.numpy_to_pil\n\n## pt_to_pil\n\n[[autodoc]] utils.pt_to_pil\n\n## load_image\n\n[[autodoc]] utils.load_image\n\n## load_video\n\n[[autodoc]] utils.load_video\n\n## export_to_gif\n\n[[autodoc]] utils.export_to_gif\n\n## export_to_video\n\n[[autodoc]] utils.export_to_video\n\n## make_image_grid\n\n[[autodoc]] utils.make_image_grid\n\n## randn_tensor\n\n[[autodoc]] utils.torch_utils.randn_tensor\n\n## apply_layerwise_casting\n\n[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting\n\n## apply_group_offloading\n\n[[autodoc]] hooks.group_offloading.apply_group_offloading\n"
  },
  {
    "path": "docs/source/en/api/video_processor.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Video Processor\n\nThe [`VideoProcessor`] provides a unified API for video pipelines to prepare inputs for VAE encoding and post-processing outputs once they're decoded. The class inherits [`VaeImageProcessor`] so it includes transformations such as resizing, normalization, and conversion between PIL Image, PyTorch, and NumPy arrays.\n\n## VideoProcessor\n\n[[autodoc]] video_processor.VideoProcessor.preprocess_video\n\n[[autodoc]] video_processor.VideoProcessor.postprocess_video\n"
  },
  {
    "path": "docs/source/en/community_projects.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Community Projects\n\nWelcome to Community Projects. This space is dedicated to showcasing the incredible work and innovative applications created by our vibrant community using the `diffusers` library.\n\nThis section aims to:\n\n- Highlight diverse and inspiring projects built with `diffusers`\n- Foster knowledge sharing within our community\n- Provide real-world examples of how `diffusers` can be leveraged\n\nHappy exploring, and thank you for being part of the Diffusers community!\n\n<table>\n    <tr>\n        <th>Project Name</th>\n        <th>Description</th>\n    </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/carson-katri/dream-textures\"> dream-textures </a></td>\n    <td>Stable Diffusion built-in to Blender</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/megvii-research/HiDiffusion\"> HiDiffusion </a></td>\n    <td>Increases the resolution and speed of your diffusion model by only adding a single line of code</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/lllyasviel/IC-Light\"> IC-Light </a></td>\n    <td>IC-Light is a project to manipulate the illumination of images</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/InstantID/InstantID\"> InstantID 
</a></td>\n    <td>InstantID : Zero-shot Identity-Preserving Generation in Seconds</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/Sanster/IOPaint\"> IOPaint </a></td>\n    <td>Image inpainting tool powered by SOTA AI Model. Remove any unwanted object, defect, people from your pictures or erase and replace(powered by stable diffusion) any thing on your pictures.</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/bmaltais/kohya_ss\"> Kohya </a></td>\n    <td>Gradio GUI for Kohya's Stable Diffusion trainers</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/magic-research/magic-animate\"> MagicAnimate </a></td>\n    <td>MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/levihsu/OOTDiffusion\"> OOTDiffusion </a></td>\n    <td>Outfitting Fusion based Latent Diffusion for Controllable Virtual Try-on</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/vladmandic/automatic\"> SD.Next </a></td>\n    <td>SD.Next: Advanced Implementation of Stable Diffusion and other Diffusion-based generative image models</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/ashawkey/stable-dreamfusion\"> stable-dreamfusion </a></td>\n    <td>Text-to-3D & Image-to-3D & Mesh Exportation with NeRF + Diffusion</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/HVision-NKU/StoryDiffusion\"> StoryDiffusion </a></td>\n    <td>StoryDiffusion can create a magic story by generating consistent images and videos.</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/cumulo-autumn/StreamDiffusion\"> StreamDiffusion </a></td>\n    <td>A Pipeline-Level Solution 
for Real-Time Interactive Generation</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/Netwrck/stable-diffusion-server\"> Stable Diffusion Server </a></td>\n    <td>A server configured for Inpainting/Generation/img2img with one stable diffusion model</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/suzukimain/auto_diffusers\"> Model Search </a></td>\n    <td>Search models on Civitai and Hugging Face</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/beinsezii/skrample\"> Skrample </a></td>\n    <td>Fully modular scheduler functions with 1st class diffusers integration.</td>\n  </tr>\n</table>\n"
  },
  {
    "path": "docs/source/en/conceptual/contribution.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# How to contribute to Diffusers 🧨\n\nWe ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation –not just code– are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it!\n\nEveryone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕. <a href=\"https://Discord.gg/G7tWnz98XR\"><img alt=\"Join us on Discord\" src=\"https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white\"></a>\n\nWhichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. 
We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility.\n\nWe enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered.\n\n## Overview\n\nYou can contribute in many ways ranging from answering questions on issues and discussions to adding new diffusion models to the core library.\n\nIn the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.\n\n* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).\n* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose) or new discussions on [the GitHub Discussions tab](https://github.com/huggingface/diffusers/discussions/new/choose).\n* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues) or discussions on [the GitHub Discussions tab](https://github.com/huggingface/diffusers/discussions).\n* 4. Fix a simple issue, marked by the \"Good first issue\" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).\n* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).\n* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).\n* 7. Contribute to the [examples](https://github.com/huggingface/diffusers/tree/main/examples).\n* 8. 
Fix a more difficult issue, marked by the \"Good second issue\" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22).\n* 9. Add a new pipeline, model, or scheduler, see [\"New Pipeline/Model\"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and [\"New scheduler\"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md).\n\nAs said before, **all contributions are valuable to the community**.\nIn the following sections, we explain each contribution in a bit more detail.\n\nFor contributions 4 through 9, you will need to open a PR. How to do so is explained in detail in [Opening a pull request](#how-to-open-a-pr).\n\n### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord\n\nAny question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to):\n- Reports of training or inference experiments in an attempt to share knowledge\n- Presentation of personal projects\n- Questions about non-official training examples\n- Project proposals\n- General feedback\n- Paper summaries\n- Asking for help on personal projects that build on top of the Diffusers library\n- General questions\n- Ethical questions regarding diffusion models\n- ...\n\nEvery question that is asked on the forum or on Discord actively encourages the community to publicly\nshare knowledge and might very well help a beginner in the future who has the same question you're\nhaving. 
Please do pose any questions you might have.\nIn the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from.\n\n**Please** keep in mind that the more effort you put into asking or answering a question, the higher\nthe quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.\nIn short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.\n\n**NOTE about channels**:\n[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.\nIn addition, questions and answers posted in the forum can easily be linked to.\nIn contrast, *Discord* has a chat-like format that invites fast back-and-forth communication.\nWhile it will most likely take less time for you to get an answer to your question on Discord, your\nquestion won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers.\n\n### 2. 
Opening new issues on the GitHub issues tab\n\nThe 🧨 Diffusers library is robust and reliable thanks to the users who notify us of\nthe problems they encounter. So thank you for reporting an issue.\n\nRemember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design.\n\nIn a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).\n\n**Please consider the following guidelines when opening a new issue**:\n- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues).\n- Please never report a new issue on another (related) issue. If another issue is highly related, please\nopen a new issue nevertheless and link to the related issue.\n- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English.\n- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that `python -c \"import diffusers; print(diffusers.__version__)\"` is higher or matches the latest Diffusers version.\n- Remember that the more effort you put into opening a new issue, the higher the quality of your answer will be and the better the overall quality of the Diffusers issues.\n\nNew issues usually include the following.\n\n#### 2.1. 
Reproducible, minimal bug reports\n\nA bug report should always have a reproducible code snippet and be as minimal and concise as possible.\nThis means in more detail:\n- Narrow the bug down as much as you can, **do not just dump your whole code file**.\n- Format your code.\n- Do not include any external libraries other than Diffusers and the libraries it depends on.\n- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.\n- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, they cannot solve it.\n- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.\n- If a model and/or dataset is required to reproduce your issue, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.\n\nFor more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.\n\nYou can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml).\n\n#### 2.2. Feature requests\n\nA world-class feature request addresses the following points:\n\n1. Motivation first:\n* Is it related to a problem/frustration with the library? If so, please explain\nwhy. Providing a code snippet that demonstrates the problem is best.\n* Is it related to something you would need for a project? 
We'd love to hear\nabout it!\n* Is it something you worked on and think could benefit the community?\nAwesome! Tell us what problem it solved for you.\n2. Write a *full paragraph* describing the feature;\n3. Provide a **code snippet** that demonstrates its future use;\n4. In case this is related to a paper, please attach a link;\n5. Attach any additional information (drawings, screenshots, etc.) you think may help.\n\nYou can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=).\n\n#### 2.3 Feedback\n\nFeedback about the library design and why it is good or not good helps the core maintainers immensely to build a user-friendly library. To understand the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too rigidly and thereby restricts use cases, explain why and how it should be changed.\nIf a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions.\n\nYou can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).\n\n#### 2.4 Technical questions\n\nTechnical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. 
Please make sure to link to the code in question and please provide details on\nwhy this part of the code is difficult to understand.\n\nYou can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml).\n\n#### 2.5 Proposal to add a new model, scheduler, or pipeline\n\nIf the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information:\n\n* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release.\n* Link to any of its open-source implementation(s).\n* Link to the model weights if they are available.\n\nIf you are willing to contribute to the model yourself, let us know so we can best guide you. Also, don't forget\nto tag the original author of the component (model, scheduler, pipeline, etc.) by GitHub handle if you can find it.\n\nYou can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml).\n\n### 3. Answering issues on the GitHub issues tab\n\nAnswering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct.\nSome tips to give a high-quality answer to an issue:\n- Be as concise and minimal as possible.\n- Stay on topic. An answer to the issue should concern the issue and only the issue.\n- Provide links to code, papers, or other sources that prove or encourage your point.\n- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet.\n\nAlso, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. 
It is of great\nhelp to the maintainers if you can answer such issues, encouraging the author of the issue to be\nmore precise, providing the link to a duplicate issue, or redirecting them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).\n\nIf you have verified that the reported bug is real and requires a correction in the source code,\nplease have a look at the next sections.\n\nFor all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.\n\n### 4. Fixing a \"Good first issue\"\n\n*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already\nexplains how a potential solution should look so that it is easier to fix.\nIf the issue hasn't been closed and you would like to try to fix this issue, you can just leave a message \"I would like to try this issue.\". There are usually three scenarios:\n- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it.\n- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR.\n- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open source and is completely normal. 
In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR.\n\n\n### 5. Contribute to the documentation\n\nA good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly\nvaluable contribution**.\n\nContributing to the documentation can take many forms:\n\n- Correcting spelling or grammatical errors.\n- Fixing incorrect docstring formatting. If you see that the official documentation is weirdly displayed or a link is broken, we would be very happy if you take some time to correct it.\n- Correcting the shape or dimensions of a docstring input or output tensor.\n- Clarifying documentation that is hard to understand or incorrect.\n- Updating outdated code examples.\n- Translating the documentation to another language.\n\nAnything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected or adjusted in the respective [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source).\n\nPlease have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally.\n\n### 6. Contribute a community pipeline\n\n> [!TIP]\n> Read the [Community pipelines](../using-diffusers/custom_pipeline_overview#community-pipelines) guide to learn more about the difference between a GitHub and Hugging Face Hub community pipeline. 
If you're interested in why we have community pipelines, take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) (basically, we can't maintain all the possible ways diffusion models can be used for inference but we also don't want to prevent the community from building them).\n\nContributing a community pipeline is a great way to share your creativity and work with the community. It lets you build on top of the [`DiffusionPipeline`] so that anyone can load and use it by setting the `custom_pipeline` parameter. This section will walk you through how to create a simple pipeline where the UNet only does a single forward pass and calls the scheduler once (a \"one-step\" pipeline).\n\n1. Create a one_step_unet.py file for your community pipeline. This file can contain whatever package you want to use as long as it's installed by the user. Make sure you only have one pipeline class that inherits from [`DiffusionPipeline`] to load model weights and the scheduler configuration from the Hub. Add a UNet and scheduler to the `__init__` function.\n\n    You should also add the `register_modules` function to ensure your pipeline and its components can be saved with [`~DiffusionPipeline.save_pretrained`].\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nclass UnetSchedulerOneForwardPipeline(DiffusionPipeline):\n    def __init__(self, unet, scheduler):\n        super().__init__()\n\n        self.register_modules(unet=unet, scheduler=scheduler)\n```\n\n1. In the forward pass (which we recommend defining as `__call__`), you can add any feature you'd like. 
For the \"one-step\" pipeline, create a random image and call the UNet and scheduler once by setting `timestep=1`.\n\n```py\n  from diffusers import DiffusionPipeline\n  import torch\n\n  class UnetSchedulerOneForwardPipeline(DiffusionPipeline):\n      def __init__(self, unet, scheduler):\n          super().__init__()\n\n          self.register_modules(unet=unet, scheduler=scheduler)\n\n      def __call__(self):\n          image = torch.randn(\n              (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),\n          )\n          timestep = 1\n\n          model_output = self.unet(image, timestep).sample\n          scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample\n\n          return scheduler_output\n```\n\nNow you can run the pipeline by passing a UNet and scheduler to it or load pretrained weights if the pipeline structure is identical.\n\n```py\nfrom diffusers import DDPMScheduler, UNet2DModel\n\nscheduler = DDPMScheduler()\nunet = UNet2DModel()\n\npipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)\noutput = pipeline()\n# load pretrained weights\npipeline = UnetSchedulerOneForwardPipeline.from_pretrained(\"google/ddpm-cifar10-32\", use_safetensors=True)\noutput = pipeline()\n```\n\nYou can either share your pipeline as a GitHub community pipeline or Hub community pipeline.\n\n<hfoptions id=\"pipeline type\">\n<hfoption id=\"GitHub pipeline\">\n\nShare your GitHub pipeline by opening a pull request on the Diffusers [repository](https://github.com/huggingface/diffusers) and add the one_step_unet.py file to the [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) subfolder.\n\n</hfoption>\n<hfoption id=\"Hub pipeline\">\n\nShare your Hub pipeline by creating a model repository on the Hub and uploading the one_step_unet.py file to it.\n\n</hfoption>\n</hfoptions>\n\n### 7. 
Contribute to training examples\n\nDiffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples).\n\nWe support two types of training examples:\n\n- Official training examples\n- Research training examples\n\nResearch training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders.\nThe official training examples are maintained by the Diffusers' core maintainers whereas the research training examples are maintained by the community.\nThis is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.\nIf the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.\n\nBoth official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. 
To use the\ntraining examples, the user must clone the repository:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\n```\n\nas well as install all additional dependencies required for training:\n\n```bash\ncd diffusers\npip install -r examples/<your-example-folder>/requirements.txt\n```\n\nTherefore, when adding an example, the `requirements.txt` file should define all pip dependencies required for your training example so that once they are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).\n\nTraining examples of the Diffusers library should adhere to the following philosophy:\n- All the code necessary to run the examples should be found in a single Python file.\n- One should be able to run the example from the command line with `python <your-example>.py --args`.\n- Examples should be kept simple and serve as **an example** of how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials.\n\nTo contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of what they should look like.\nWe strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated\nwith Diffusers.\nOnce an example script works, please make sure to add a comprehensive `README.md` that states exactly how to use the example. 
This README should include:\n- An example command on how to run the example script as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch).\n- A link to some training results (logs, models, etc.) that show what the user can expect as shown [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).\n- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).\n\nIf you are contributing to the official training examples, please also make sure to add a test to its folder such as [examples/dreambooth/test_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/test_dreambooth.py). This is not necessary for non-official training examples.\n\n### 8. Fixing a \"Good second issue\"\n\n*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are\nusually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).\nThe issue description usually gives less guidance on how to fix the issue and requires\na decent understanding of the library by the interested contributor.\nIf you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. 
If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR.\nGood second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished, the core maintainers can also jump into your PR and commit to it in order to get it merged.\n\n### 9. Adding pipelines, models, schedulers\n\nPipelines, models, and schedulers are the most important pieces of the Diffusers library.\nThey provide easy access to state-of-the-art diffusion technologies and thus allow the community to\nbuild powerful generative AI applications.\n\nBy adding a new model, pipeline, or scheduler, you might enable a powerful new use case for any of the user interfaces relying on Diffusers, which can be of immense value for the whole generative AI ecosystem.\n\nDiffusers has a couple of open feature requests for all three components - feel free to browse through them\nif you don't know yet what specific component you would like to add:\n- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)\n- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)\n\nBefore adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy\nas this would lead to API inconsistencies. 
If you fundamentally disagree with a design choice, please open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.\n\nPlease make sure to add links to the original codebase/paper to the PR and ideally also ping the original author directly on the PR so that they can follow the progress and potentially help with questions.\n\nIf you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.\n\n#### Copied from mechanism\n\nA unique and important feature to understand when adding any pipeline, model or scheduler code is the `# Copied from` mechanism. You'll see this all over the Diffusers codebase, and the reason we use it is to keep the codebase easy to understand and maintain. Marking code with the `# Copied from` mechanism forces the marked code to be identical to the code it was copied from. This makes it easy to update and propagate changes across many files whenever you run `make fix-copies`.\n\nFor example, in the code example below, [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is the original code and `AltDiffusionPipelineOutput` uses the `# Copied from` mechanism to copy it. 
The only difference is changing the class prefix from `Stable` to `Alt`.\n\n```py\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt\nclass AltDiffusionPipelineOutput(BaseOutput):\n    \"\"\"\n    Output class for Alt Diffusion pipelines.\n\n    Args:\n        images (`List[PIL.Image.Image]` or `np.ndarray`)\n            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,\n            num_channels)`.\n        nsfw_content_detected (`List[bool]`)\n            List indicating whether the corresponding generated image contains \"not-safe-for-work\" (nsfw) content or\n            `None` if safety checking could not be performed.\n    \"\"\"\n```\n\nTo learn more, read this section of the [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) blog post.\n\n## How to write a good issue\n\n**The better your issue is written, the higher the chances that it will be quickly resolved.**\n\n1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose).\n2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simple as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write \"Error in diffusers\".\n3. **Reproducibility**: No reproducible code snippet == no solution. 
If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data.\n4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets.\n5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better.\n6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. 
See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information.\n7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library.\n\n## How to write a good PR\n\n1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged.\n2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of \"also fixing another problem while we're adding it\". It is much more difficult to review pull requests that solve multiple, unrelated problems at once.\n3. If helpful, try to add a code snippet that displays an example of how your addition can be used.\n4. The title of your pull request should be a summary of its contribution.\n5. If your pull request addresses an issue, please mention the issue number in\nthe pull request description to make sure they are linked (and people\nconsulting the issue know you are working on it);\n6. To indicate a work in progress please prefix the title with `[WIP]`. These\nare useful to avoid duplicated work, and to differentiate it from PRs ready\nto be merged;\n7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue).\n8. Make sure existing tests pass;\n9. Add high-coverage tests. 
No quality testing = no merge.\n- If you are adding new `@slow` tests, make sure they pass using\n`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.\nCircleCI does not run the slow tests, but GitHub Actions does every night!\n10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example.\n11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like\n[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files.\nIf an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images\nto this dataset.\n\n## How to open a PR\n\nBefore writing code, we strongly advise you to search through the existing PRs or\nissues to make sure that nobody is already working on the same thing. If you are\nunsure, it is always a good idea to open an issue to get some feedback.\n\nYou will need basic `git` proficiency to be able to contribute to\n🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest\nmanual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro\nGit](https://git-scm.com/book/en/v2) is a very good reference.\n\nFollow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/83bc6c94eaeb6f7704a2a428931cf2d9ad973ae9/setup.py#L270)):\n\n1. Fork the [repository](https://github.com/huggingface/diffusers) by\nclicking on the 'Fork' button on the repository's page. 
This creates a copy of the code\nunder your GitHub user account.\n\n2. Clone your fork to your local disk, and add the base repository as a remote:\n\n ```bash\n $ git clone git@github.com:<your GitHub handle>/diffusers.git\n $ cd diffusers\n $ git remote add upstream https://github.com/huggingface/diffusers.git\n ```\n\n3. Create a new branch to hold your development changes:\n\n ```bash\n $ git checkout -b a-descriptive-name-for-my-changes\n ```\n\n**Do not** work on the `main` branch.\n\n4. Set up a development environment by running the following command in a virtual environment:\n\n ```bash\n $ pip install -e \".[dev]\"\n ```\n\nIf you have already cloned the repo, you might need to `git pull` to get the most recent changes in the\nlibrary.\n\n5. Develop the features on your branch.\n\nAs you work on the features, you should make sure that the test suite\npasses. You should run the tests impacted by your changes like this:\n\n ```bash\n $ pytest tests/<TEST_TO_RUN>.py\n ```\n\nBefore you run the tests, please make sure you install the dependencies required for testing. You can do so\nwith this command:\n\n ```bash\n $ pip install -e \".[test]\"\n ```\n\nYou can also run the full test suite with the following command, but it takes\na beefy machine to produce a result in a decent amount of time now that\nDiffusers has grown a lot. Here is the command for it:\n\n ```bash\n $ make test\n ```\n\n🧨 Diffusers relies on `black` and `isort` to format its source code\nconsistently. After you make changes, apply automatic style corrections and code verifications\nthat can't be automated in one go with:\n\n ```bash\n $ make style\n ```\n\n🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. 
Quality\ncontrol runs in CI, however, you can also run the same checks with:\n\n ```bash\n $ make quality\n ```\n\nOnce you're happy with your changes, add changed files using `git add` and\nmake a commit with `git commit` to record your changes locally:\n\n ```bash\n $ git add modified_file.py\n $ git commit -m \"A descriptive message about your changes.\"\n ```\n\nIt is a good idea to sync your copy of the code with the original\nrepository regularly. This way you can quickly account for changes:\n\n ```bash\n $ git pull upstream main\n ```\n\nPush the changes to your account using:\n\n ```bash\n $ git push -u origin a-descriptive-name-for-my-changes\n ```\n\n6. Once you are satisfied, go to the\nwebpage of your fork on GitHub. Click on 'Pull request' to send your changes\nto the project maintainers for review.\n\n7. It's OK if maintainers ask you for changes. It happens to core contributors\ntoo! So everyone can see the changes in the Pull request, work in your local\nbranch and push the changes to your fork. They will automatically appear in\nthe pull request.\n\n### Tests\n\nAn extensive test suite is included to test the library behavior and several examples. Library tests can be found in\nthe [tests folder](https://github.com/huggingface/diffusers/tree/main/tests).\n\nWe like `pytest` and `pytest-xdist` because it's faster. From the root of the\nrepository, here's how to run tests with `pytest` for the library:\n\n```bash\n$ python -m pytest -n auto --dist=loadfile -s -v ./tests/\n```\n\nIn fact, that's how `make test` is implemented!\n\nYou can specify a smaller set of tests in order to test only the feature\nyou're working on.\n\nBy default, slow tests are skipped. Set the `RUN_SLOW` environment variable to\n`yes` to run them. 
This will download many gigabytes of models — make sure you\nhave enough disk space and a good Internet connection, or a lot of patience!\n\n```bash\n$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/\n```\n\n`unittest` is fully supported, here's how to run tests with it:\n\n```bash\n$ python -m unittest discover -s tests -t . -v\n$ python -m unittest discover -s examples -t examples -v\n```\n\n### Syncing forked main with upstream (HuggingFace) main\n\nTo avoid pinging the upstream repository which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in these PRs,\nwhen syncing the main branch of a forked repository, please, follow these steps:\n1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main.\n2. If a PR is absolutely necessary, use the following steps after checking out your branch:\n```bash\n$ git checkout -b your-branch-for-syncing\n$ git pull --squash --no-commit upstream main\n$ git commit -m '<your message without GitHub references>'\n$ git push --set-upstream origin your-branch-for-syncing\n```\n\n### Style guide\n\nFor documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).\n\n\n## Coding with AI agents\n\nThe repository keeps AI-agent configuration in `.ai/` and exposes local agent files via symlinks.\n\n- **Source of truth** — edit files under `.ai/` (`AGENTS.md` for coding guidelines, `skills/` for on-demand task knowledge)\n- **Don't edit** generated root-level `AGENTS.md`, `CLAUDE.md`, or `.agents/skills`/`.claude/skills` — they are symlinks\n- Setup commands:\n  - `make codex` — symlink guidelines + skills for OpenAI Codex\n  - `make claude` — symlink guidelines + skills for Claude Code\n  - `make clean-ai` — remove all generated symlinks"
  },
  {
    "path": "docs/source/en/conceptual/ethical_guidelines.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 🧨 Diffusers’ Ethical Guidelines\n\n## Preamble\n\n[Diffusers](https://huggingface.co/docs/diffusers/index) provides pre-trained diffusion models and serves as a modular toolbox for inference and training.\n\nGiven its real case applications in the world and potential negative impacts on society, we think it is important to provide the project with ethical guidelines to guide the development, users’ contributions, and usage of the Diffusers library.\n\nThe risks associated with using this technology are still being examined, but to name a few: copyrights issues for artists; deep-fake exploitation; sexual content generation in inappropriate contexts; non-consensual impersonation; harmful social biases perpetuating the oppression of marginalized groups.\nWe will keep tracking risks and adapt the following guidelines based on the community's responsiveness and valuable feedback.\n\n\n## Scope\n\nThe Diffusers community will apply the following ethical guidelines to the project’s development and help coordinate how the community will integrate the contributions, especially concerning sensitive topics related to ethical concerns.\n\n\n## Ethical guidelines\n\nThe following ethical guidelines apply generally, but we will primarily implement them when dealing with ethically sensitive issues while making a technical choice. 
Furthermore, we commit to adapting those ethical principles over time following emerging harms related to the state of the art of the technology in question.\n\n- **Transparency**: we are committed to being transparent in managing PRs, explaining our choices to users, and making technical decisions.\n\n- **Consistency**: we are committed to guaranteeing our users the same level of attention in project management, keeping it technically stable and consistent.\n\n- **Simplicity**: with a desire to make it easy to use and exploit the Diffusers library, we are committed to keeping the project’s goals lean and coherent.\n\n- **Accessibility**: the Diffusers project helps lower the entry bar for contributors who can help run it even without technical expertise. Doing so makes research artifacts more accessible to the community.\n\n- **Reproducibility**: we aim to be transparent about the reproducibility of upstream code, models, and datasets when made available through the Diffusers library.\n\n- **Responsibility**: as a community and through teamwork, we hold a collective responsibility to our users by anticipating and mitigating this technology's potential risks and dangers.\n\n\n## Examples of implementations: Safety features and Mechanisms\n\nThe team works daily to make the technical and non-technical tools available to deal with the potential ethical and social risks associated with diffusion technology. Moreover, the community's input is invaluable in ensuring these features' implementation and raising awareness with us.\n\n- [**Community tab**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): it enables the community to discuss and better collaborate on a project.\n\n- **Bias exploration and evaluation**: the Hugging Face team provides a [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer) to demonstrate the biases in Stable Diffusion interactively. 
In this sense, we support and encourage bias explorers and evaluations.\n\n- **Encouraging safety in deployment**\n\n  - [**Safe Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): It mitigates the well-known issue that models, like Stable Diffusion, that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. Related paper: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105).\n\n  - [**Safety Checker**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): It checks and compares the class probability of a set of hard-coded harmful concepts in the embedding space against an image after it has been generated. The harmful concepts are intentionally hidden to prevent reverse engineering of the checker.\n\n- **Staged release on the Hub**: in particularly sensitive situations, access to some repositories should be restricted. This staged release is an intermediary step that allows the repository’s authors to have more control over its use.\n\n- **Licensing**: [OpenRAILs](https://huggingface.co/blog/open_rail), a new type of licensing, allow us to ensure free access while having a set of restrictions that ensure more responsible use.\n"
  },
  {
    "path": "docs/source/en/conceptual/evaluation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Evaluating Diffusion Models\n\n<a target=\"_blank\" href=\"https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/evaluation.ipynb\">\n    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n</a>\n\n> [!TIP]\n> This document has now grown outdated given the emergence of existing evaluation frameworks for diffusion models for image generation. Please check\n> out works like [HEIM](https://crfm.stanford.edu/helm/heim/latest/), [T2I-Compbench](https://huggingface.co/papers/2307.06350),\n> [GenEval](https://huggingface.co/papers/2310.11513).\n\nEvaluation of generative models like [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) is subjective in nature. But as practitioners and researchers, we often have to make careful choices amongst many different possibilities. So, when working with different generative models (like GANs, Diffusion, etc.), how do we choose one over the other?\n\nQualitative evaluation of such models can be error-prone and might incorrectly influence a decision.\nHowever, quantitative metrics don't necessarily correspond to image quality. 
So, usually, a combination\nof both qualitative and quantitative evaluations provides a stronger signal when choosing one model\nover the other.\n\nIn this document, we provide a non-exhaustive overview of qualitative and quantitative methods to evaluate Diffusion models. For quantitative methods, we specifically focus on how to implement them alongside `diffusers`.\n\nThe methods shown in this document can also be used to evaluate different [noise schedulers](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview) keeping the underlying generation model fixed.\n\n## Scenarios\n\nWe cover Diffusion models with the following pipelines:\n\n- Text-guided image generation (such as the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img)).\n- Text-guided image generation, additionally conditioned on an input image (such as the [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/img2img) and [`StableDiffusionInstructPix2PixPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix)).\n- Class-conditioned image generation models (such as the [`DiTPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit)).\n\n## Qualitative Evaluation\n\nQualitative evaluation typically involves human assessment of generated images. Quality is measured across aspects such as compositionality, image-text alignment, and spatial relations. Common prompts provide a degree of uniformity for subjective metrics.\nDrawBench and PartiPrompts are prompt datasets used for qualitative benchmarking. DrawBench and PartiPrompts were introduced by [Imagen](https://imagen.research.google/) and [Parti](https://parti.research.google/) respectively.\n\nFrom the [official Parti website](https://parti.research.google/):\n\n> PartiPrompts (P2) is a rich set of over 1600 prompts in English that we release as part of this work. 
P2 can be used to measure model capabilities across various categories and challenge aspects.\n\n![parti-prompts](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts.png)\n\nPartiPrompts has the following columns:\n\n- Prompt\n- Category of the prompt (such as “Abstract”, “World Knowledge”, etc.)\n- Challenge reflecting the difficulty (such as “Basic”, “Complex”, “Writing & Symbols”, etc.)\n\nThese benchmarks allow for side-by-side human evaluation of different image generation models.\n\nFor this, the 🧨 Diffusers team has built **Open Parti Prompts**, which is a community-driven qualitative benchmark based on Parti Prompts to compare state-of-the-art open-source diffusion models:\n- [Open Parti Prompts Game](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts): For 10 parti prompts, 4 generated images are shown and the user selects the image that suits the prompt best.\n- [Open Parti Prompts Leaderboard](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard): The leaderboard comparing the currently best open-sourced diffusion models to each other.\n\nTo manually compare images, let’s see how we can use `diffusers` on a couple of PartiPrompts.\n\nBelow we show some prompts sampled across different challenges: Basic, Complex, Linguistic Structures, Imagination, and Writing & Symbols. 
Here we are using PartiPrompts as a [dataset](https://huggingface.co/datasets/nateraw/parti-prompts).\n\n```python\nfrom datasets import load_dataset\n\n# prompts = load_dataset(\"nateraw/parti-prompts\", split=\"train\")\n# prompts = prompts.shuffle()\n# sample_prompts = [prompts[i][\"Prompt\"] for i in range(5)]\n\n# Fixing these sample prompts in the interest of reproducibility.\nsample_prompts = [\n    \"a corgi\",\n    \"a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky\",\n    \"a car with no windows\",\n    \"a cube made of porcupine\",\n    'The saying \"BE EXCELLENT TO EACH OTHER\" written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. A yellow fire hydrant is on a sidewalk in the foreground.',\n]\n```\n\nNow we can use these prompts to generate some images using Stable Diffusion ([v1-4 checkpoint](https://huggingface.co/CompVis/stable-diffusion-v1-4)):\n\n```python\nimport torch\n\nseed = 0\ngenerator = torch.manual_seed(seed)\n\nimages = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generator).images\n```\n\n![parti-prompts-14](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-14.png)\n\nWe can also set `num_images_per_prompt` accordingly to compare different images for the same prompt. Running the same pipeline but with a different checkpoint ([v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)), yields:\n\n![parti-prompts-15](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-15.png)\n\nOnce several images are generated from all the prompts using multiple models (under evaluation), these results are presented to human evaluators for scoring. 
For\nmore details on the DrawBench and PartiPrompts benchmarks, refer to their respective papers.\n\n> [!TIP]\n> It is useful to look at some inference samples while a model is training to measure the\n> training progress. In our [training scripts](https://github.com/huggingface/diffusers/tree/main/examples/), we support this utility with additional support for\n> logging to TensorBoard and Weights & Biases.\n\n## Quantitative Evaluation\n\nIn this section, we will walk you through how to evaluate three different diffusion pipelines using:\n\n- CLIP score\n- CLIP directional similarity\n- FID\n\n### Text-guided image generation\n\n[CLIP score](https://huggingface.co/papers/2104.08718) measures the compatibility of image-caption pairs. Higher CLIP scores imply higher compatibility 🔼. The CLIP score is a quantitative measurement of the qualitative concept \"compatibility\". Image-caption pair compatibility can also be thought of as the semantic similarity between the image and the caption. 
CLIP score was found to have high correlation with human judgement.\n\nLet's first load a [`StableDiffusionPipeline`]:\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_ckpt = \"CompVis/stable-diffusion-v1-4\"\nsd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to(\"cuda\")\n```\n\nGenerate some images with multiple prompts:\n\n```python\nprompts = [\n    \"a photo of an astronaut riding a horse on mars\",\n    \"A high tech solarpunk utopia in the Amazon rainforest\",\n    \"A pikachu fine dining with a view to the Eiffel Tower\",\n    \"A mecha robot in a favela in expressionist style\",\n    \"an insect robot preparing a delicious meal\",\n    \"A small cabin on top of a snowy mountain in the style of Disney, artstation\",\n]\n\nimages = sd_pipeline(prompts, num_images_per_prompt=1, output_type=\"np\").images\n\nprint(images.shape)\n# (6, 512, 512, 3)\n```\n\nAnd then, we calculate the CLIP score.\n\n```python\nfrom torchmetrics.functional.multimodal import clip_score\nfrom functools import partial\n\nclip_score_fn = partial(clip_score, model_name_or_path=\"openai/clip-vit-base-patch16\")\n\ndef calculate_clip_score(images, prompts):\n    images_int = (images * 255).astype(\"uint8\")\n    clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()\n    return round(float(clip_score), 4)\n\nsd_clip_score = calculate_clip_score(images, prompts)\nprint(f\"CLIP score: {sd_clip_score}\")\n# CLIP score: 35.7038\n```\n\nIn the above example, we generated one image per prompt. If we generated multiple images per prompt, we would have to take the average score from the generated images per prompt.\n\nNow, if we wanted to compare two checkpoints compatible with the [`StableDiffusionPipeline`] we should pass a generator while calling the pipeline. 
First, we generate images with a\nfixed seed with the [v1-4 Stable Diffusion checkpoint](https://huggingface.co/CompVis/stable-diffusion-v1-4):\n\n```python\nseed = 0\ngenerator = torch.manual_seed(seed)\n\nimages = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type=\"np\").images\n```\n\nThen we load the [v1-5 checkpoint](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) to generate images:\n\n```python\nmodel_ckpt_1_5 = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nsd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to(\"cuda\")\n\nimages_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type=\"np\").images\n```\n\nAnd finally, we compare their CLIP scores:\n\n```python\nsd_clip_score_1_4 = calculate_clip_score(images, prompts)\nprint(f\"CLIP Score with v-1-4: {sd_clip_score_1_4}\")\n# CLIP Score with v-1-4: 34.9102\n\nsd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts)\nprint(f\"CLIP Score with v-1-5: {sd_clip_score_1_5}\")\n# CLIP Score with v-1-5: 36.2137\n```\n\nIt seems like the [v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint performs better than its predecessor. Note, however, that the number of prompts we used to compute the CLIP scores is quite low. For a more practical evaluation, this number should be way higher, and the prompts should be diverse.\n\n> [!WARNING]\n> By construction, there are some limitations in this score. The captions in the training dataset\n> were crawled from the web and extracted from `alt` and similar tags associated with an image on the internet.\n> They are not necessarily representative of what a human being would use to describe an image. Hence we\n> had to \"engineer\" some prompts here.\n\n### Image-conditioned text-to-image generation\n\nIn this case, we condition the generation pipeline with an input image as well as a text prompt. 
Let's take the [`StableDiffusionInstructPix2PixPipeline`], as an example. It takes an edit instruction as an input prompt and an input image to be edited.\n\nHere is one example:\n\n![edit-instruction](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png)\n\nOne strategy to evaluate such a model is to measure the consistency of the change between the two images (in [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) space) with the change between the two image captions (as shown in [CLIP-Guided Domain Adaptation of Image Generators](https://huggingface.co/papers/2108.00946)). This is referred to as the \"**CLIP directional similarity**\".\n\n- Caption 1 corresponds to the input image (image 1) that is to be edited.\n- Caption 2 corresponds to the edited image (image 2). It should reflect the edit instruction.\n\nFollowing is a pictorial overview:\n\n![edit-consistency](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-consistency.png)\n\nWe have prepared a mini dataset to implement this metric. Let's first load the dataset.\n\n```python\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"sayakpaul/instructpix2pix-demo\", split=\"train\")\ndataset.features\n```\n\n```bash\n{'input': Value(dtype='string', id=None),\n 'edit': Value(dtype='string', id=None),\n 'output': Value(dtype='string', id=None),\n 'image': Image(decode=True, id=None)}\n```\n\nHere we have:\n\n- `input` is a caption corresponding to the `image`.\n- `edit` denotes the edit instruction.\n- `output` denotes the modified caption reflecting the `edit` instruction.\n\nLet's take a look at a sample.\n\n```python\nidx = 0\nprint(f\"Original caption: {dataset[idx]['input']}\")\nprint(f\"Edit instruction: {dataset[idx]['edit']}\")\nprint(f\"Modified caption: {dataset[idx]['output']}\")\n```\n\n```bash\nOriginal caption: 2. 
FAROE ISLANDS: An archipelago of 18 mountainous isles in the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'\nEdit instruction: make the isles all white marble\nModified caption: 2. WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles in the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'\n```\n\nAnd here is the image:\n\n```python\ndataset[idx][\"image\"]\n```\n\n![edit-dataset](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-dataset.png)\n\nWe will first edit the images of our dataset with the edit instruction and compute the directional similarity.\n\nLet's first load the [`StableDiffusionInstructPix2PixPipeline`]:\n\n```python\nfrom diffusers import StableDiffusionInstructPix2PixPipeline\n\ninstruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\n    \"timbrooks/instruct-pix2pix\", torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\nNow, we perform the edits:\n\n```python\nimport numpy as np\n\n\ndef edit_image(input_image, instruction):\n    image = instruct_pix2pix_pipeline(\n        instruction,\n        image=input_image,\n        output_type=\"np\",\n        generator=generator,\n    ).images[0]\n    return image\n\ninput_images = []\noriginal_captions = []\nmodified_captions = []\nedited_images = []\n\nfor idx in range(len(dataset)):\n    input_image = dataset[idx][\"image\"]\n    edit_instruction = dataset[idx][\"edit\"]\n    edited_image = edit_image(input_image, edit_instruction)\n\n    input_images.append(np.array(input_image))\n    
original_captions.append(dataset[idx][\"input\"])\n    modified_captions.append(dataset[idx][\"output\"])\n    edited_images.append(edited_image)\n```\n\nTo measure the directional similarity, we first load CLIP's image and text encoders:\n\n```python\nfrom transformers import (\n    CLIPTokenizer,\n    CLIPTextModelWithProjection,\n    CLIPVisionModelWithProjection,\n    CLIPImageProcessor,\n)\n\nclip_id = \"openai/clip-vit-large-patch14\"\ntokenizer = CLIPTokenizer.from_pretrained(clip_id)\ntext_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(\"cuda\")\nimage_processor = CLIPImageProcessor.from_pretrained(clip_id)\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(\"cuda\")\n```\n\nNotice that we are using a particular CLIP checkpoint, i.e., `openai/clip-vit-large-patch14`. This is because the Stable Diffusion pre-training was performed with this CLIP variant. For more details, refer to the [documentation](https://huggingface.co/docs/transformers/model_doc/clip).\n\nNext, we prepare a PyTorch `nn.Module` to compute directional similarity:\n\n```python\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass DirectionalSimilarity(nn.Module):\n    def __init__(self, tokenizer, text_encoder, image_processor, image_encoder):\n        super().__init__()\n        self.tokenizer = tokenizer\n        self.text_encoder = text_encoder\n        self.image_processor = image_processor\n        self.image_encoder = image_encoder\n\n    def preprocess_image(self, image):\n        image = self.image_processor(image, return_tensors=\"pt\")[\"pixel_values\"]\n        return {\"pixel_values\": image.to(\"cuda\")}\n\n    def tokenize_text(self, text):\n        inputs = self.tokenizer(\n            text,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        return {\"input_ids\": 
inputs.input_ids.to(\"cuda\")}\n\n    def encode_image(self, image):\n        preprocessed_image = self.preprocess_image(image)\n        image_features = self.image_encoder(**preprocessed_image).image_embeds\n        image_features = image_features / image_features.norm(dim=1, keepdim=True)\n        return image_features\n\n    def encode_text(self, text):\n        tokenized_text = self.tokenize_text(text)\n        text_features = self.text_encoder(**tokenized_text).text_embeds\n        text_features = text_features / text_features.norm(dim=1, keepdim=True)\n        return text_features\n\n    def compute_directional_similarity(self, img_feat_one, img_feat_two, text_feat_one, text_feat_two):\n        sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one)\n        return sim_direction\n\n    def forward(self, image_one, image_two, caption_one, caption_two):\n        img_feat_one = self.encode_image(image_one)\n        img_feat_two = self.encode_image(image_two)\n        text_feat_one = self.encode_text(caption_one)\n        text_feat_two = self.encode_text(caption_two)\n        directional_similarity = self.compute_directional_similarity(\n            img_feat_one, img_feat_two, text_feat_one, text_feat_two\n        )\n        return directional_similarity\n```\n\nLet's put `DirectionalSimilarity` to use now.\n\n```python\ndir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder)\nscores = []\n\nfor i in range(len(input_images)):\n    original_image = input_images[i]\n    original_caption = original_captions[i]\n    edited_image = edited_images[i]\n    modified_caption = modified_captions[i]\n\n    similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption)\n    scores.append(float(similarity_score.detach().cpu()))\n\nprint(f\"CLIP directional similarity: {np.mean(scores)}\")\n# CLIP directional similarity: 0.0797976553440094\n```\n\nLike the CLIP 
Score, the higher the CLIP directional similarity, the better it is.\n\nIt should be noted that the `StableDiffusionInstructPix2PixPipeline` exposes two arguments, namely, `image_guidance_scale` and `guidance_scale` that let you control the quality of the final edited image. We encourage you to experiment with these two arguments and see the impact of that on the directional similarity.\n\nWe can extend the idea of this metric to measure how similar the original image and edited version are. To do that, we can just do `F.cosine_similarity(img_feat_two, img_feat_one)`. For these kinds of edits, we would still want the primary semantics of the images to be preserved as much as possible, i.e., a high similarity score.\n\nWe can use these metrics for similar pipelines such as the [`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline).\n\n> [!TIP]\n> Both CLIP score and CLIP direction similarity rely on the CLIP model, which can make the evaluations biased.\n\n***Extending metrics like IS, FID (discussed later), or KID can be difficult*** when the model under evaluation was pre-trained on a large image-captioning dataset (such as the [LAION-5B dataset](https://laion.ai/blog/laion-5b/)). This is because underlying these metrics is an InceptionNet (pre-trained on the ImageNet-1k dataset) used for extracting intermediate image features. The pre-training dataset of Stable Diffusion may have limited overlap with the pre-training dataset of InceptionNet, so it is not a good candidate here for feature extraction.\n\n***Using the above metrics helps evaluate models that are class-conditioned. For example, [DiT](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit). 
It was pre-trained being conditioned on the ImageNet-1k classes.***\n\n### Class-conditioned image generation\n\nClass-conditioned generative models are usually pre-trained on a class-labeled dataset such as [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k). Popular metrics for evaluating these models include Fréchet Inception Distance (FID), Kernel Inception Distance (KID), and Inception Score (IS). In this document, we focus on FID ([Heusel et al.](https://huggingface.co/papers/1706.08500)). We show how to compute it with the [`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit), which uses the [DiT model](https://huggingface.co/papers/2212.09748) under the hood.\n\nFID aims to measure how similar two datasets of images are. As per [this resource](https://mmgeneration.readthedocs.io/en/latest/quick_run.html#fid):\n\n> Fréchet Inception Distance is a measure of similarity between two datasets of images. It was shown to correlate well with the human judgment of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. FID is calculated by computing the Fréchet distance between two Gaussians fitted to feature representations of the Inception network.\n\nThese two datasets are essentially the dataset of real images and the dataset of fake images (generated images in our case). FID is usually calculated with two large datasets. 
However, for this document, we will work with two mini datasets.\n\nLet's first download a few images from the ImageNet-1k training set:\n\n```python\nfrom zipfile import ZipFile\nimport requests\n\n\ndef download(url, local_filepath):\n    r = requests.get(url)\n    with open(local_filepath, \"wb\") as f:\n        f.write(r.content)\n    return local_filepath\n\ndummy_dataset_url = \"https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip\"\nlocal_filepath = download(dummy_dataset_url, dummy_dataset_url.split(\"/\")[-1])\n\nwith ZipFile(local_filepath, \"r\") as zipper:\n    zipper.extractall(\".\")\n```\n\n```python\nfrom PIL import Image\nimport os\nimport numpy as np\n\ndataset_path = \"sample-imagenet-images\"\nimage_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])\n\nreal_images = [np.array(Image.open(path).convert(\"RGB\")) for path in image_paths]\n```\n\nThese are 10 images from the following ImageNet-1k classes: \"cassette_player\", \"chain_saw\" (x2), \"church\", \"gas_pump\" (x3), \"parachute\" (x2), and \"tench\".\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/real-images.png\" alt=\"real-images\"><br>\n    <em>Real images.</em>\n</p>\n\nNow that the images are loaded, let's apply some lightweight pre-processing on them to use them for FID calculation.\n\n```python\nfrom torchvision.transforms import functional as F\nimport torch\n\n\ndef preprocess_image(image):\n    image = torch.tensor(image).unsqueeze(0)\n    image = image.permute(0, 3, 1, 2) / 255.0\n    return F.center_crop(image, (256, 256))\n\nreal_images = torch.cat([preprocess_image(image) for image in real_images])\nprint(real_images.shape)\n# torch.Size([10, 3, 256, 256])\n```\n\nWe now load the [`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit) to generate images conditioned on the above-mentioned 
classes.\n\n```python\nfrom diffusers import DiTPipeline, DPMSolverMultistepScheduler\n\ndit_pipeline = DiTPipeline.from_pretrained(\"facebook/DiT-XL-2-256\", torch_dtype=torch.float16)\ndit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)\ndit_pipeline = dit_pipeline.to(\"cuda\")\n\nseed = 0\ngenerator = torch.manual_seed(seed)\n\n\nwords = [\n    \"cassette player\",\n    \"chainsaw\",\n    \"chainsaw\",\n    \"church\",\n    \"gas pump\",\n    \"gas pump\",\n    \"gas pump\",\n    \"parachute\",\n    \"parachute\",\n    \"tench\",\n]\n\nclass_ids = dit_pipeline.get_label_ids(words)\noutput = dit_pipeline(class_labels=class_ids, generator=generator, output_type=\"np\")\n\nfake_images = output.images\nfake_images = torch.tensor(fake_images)\nfake_images = fake_images.permute(0, 3, 1, 2)\nprint(fake_images.shape)\n# torch.Size([10, 3, 256, 256])\n```\n\nNow, we can compute the FID using [`torchmetrics`](https://torchmetrics.readthedocs.io/).\n\n```python\nfrom torchmetrics.image.fid import FrechetInceptionDistance\n\nfid = FrechetInceptionDistance(normalize=True)\nfid.update(real_images, real=True)\nfid.update(fake_images, real=False)\n\nprint(f\"FID: {float(fid.compute())}\")\n# FID: 177.7147216796875\n```\n\nThe lower the FID, the better it is. 
Several things can influence FID here:\n\n- Number of images (both real and fake)\n- Randomness induced in the diffusion process\n- Number of inference steps in the diffusion process\n- The scheduler being used in the diffusion process\n\nBecause of the randomness and the dependence on the number of inference steps, it is good practice to run the evaluation across different seeds and inference steps, and then report an average result.\n\n> [!WARNING]\n> FID results tend to be fragile as they depend on a lot of factors:\n>\n> * The specific Inception model used during computation.\n> * The implementation accuracy of the computation.\n> * The image format (not the same if we start from PNGs vs JPGs).\n>\n> Keeping that in mind, FID is often most useful when comparing similar runs, but it is\n> hard to reproduce paper results unless the authors carefully disclose the FID\n> measurement code.\n>\n> These points apply to other related metrics too, such as KID and IS.\n\nAs a final step, let's visually inspect the `fake_images`.\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/fake-images.png\" alt=\"fake-images\"><br>\n    <em>Fake images.</em>\n</p>\n"
  },
  {
    "path": "docs/source/en/conceptual/philosophy.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Philosophy\n\n🧨 Diffusers provides **state-of-the-art** pretrained diffusion models across multiple modalities.\nIts purpose is to serve as a **modular toolbox** for both inference and training.\n\nWe aim at building a library that stands the test of time and therefore take API design very seriously.\n\nIn a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefore, most of our design choices are based on [PyTorch's Design Principles](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy). Let's go over the most important ones:\n\n## Usability over Performance\n\n- While Diffusers has many built-in performance-enhancing features (see [Memory and Speed](https://huggingface.co/docs/diffusers/optimization/fp16)), models are always loaded with the highest precision and lowest optimization. Therefore, by default diffusion pipelines are always instantiated on CPU with float32 precision if not otherwise defined by the user. This ensures usability across different platforms and accelerators and means that no complex installations are required to run the library.\n- Diffusers aims to be a **light-weight** package and therefore has very few required dependencies, but many soft dependencies that can improve performance (such as `accelerate`, `safetensors`, `onnx`, etc...). 
We strive to keep the library as lightweight as possible so that it can be added without much concern as a dependency on other packages.\n- Diffusers prefers simple, self-explainable code over condensed, magic code. This means that short-hand code syntaxes such as lambda functions and advanced PyTorch operators are often not desired.\n\n## Simple over easy\n\nAs PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:\n- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management.\n- Raising concise error messages is preferred to silently correcting erroneous input. Diffusers aims at teaching the user, rather than making the library as easy to use as possible.\n- Complex model vs. scheduler logic is exposed instead of magically handled inside. Schedulers/Samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation allows for easier debugging and gives the user more control over adapting the denoising process or switching out diffusion models or schedulers.\n- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. 
DreamBooth or Textual Inversion training\nis very simple thanks to Diffusers' ability to separate single components of the diffusion pipeline.\n\n## Tweakable, contributor-friendly over abstraction\n\nFor large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).\nIn short, just like Transformers does for modeling files, Diffusers prefers to keep an extremely low level of abstraction and very self-contained code for pipelines and schedulers.\nFunctions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable.\n**However**, this design has proven to be extremely successful for Transformers and makes a lot of sense for community-driven, open-source machine learning libraries because:\n- Machine Learning is an extremely fast-moving field in which paradigms, model architectures, and algorithms are changing rapidly, which therefore makes it very difficult to define long-lasting code abstractions.\n- Machine Learning practitioners like to be able to quickly tweak existing code for ideation and research and therefore prefer self-contained code over one that contains many abstractions.\n- Open-source libraries rely on community contributions and therefore must build a library that is easy to contribute to. The more abstract the code, the more dependencies, the harder to read, and the harder to contribute to. Contributors simply stop contributing to very abstract libraries out of fear of breaking vital functionality. 
If contributing to a library cannot break other fundamental code, not only is it more inviting for potential new contributors, but it is also easier to review and contribute to multiple parts in parallel.\n\nAt Hugging Face, we call this design the **single-file policy** which means that almost all of the code of a certain class should be written in a single, self-contained file. To read more about the philosophy, you can have a look\nat [this blog post](https://huggingface.co/blog/transformers-design-philosophy).\n\nIn Diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such\nas [DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip) and [Imagen](https://imagen.research.google/) all rely on the same diffusion model, the [UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond).\n\nGreat, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗.\nWe try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy or some unlucky design choices. If you have feedback regarding the design, we would ❤️  to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).\n\n## Design Philosophy in Details\n\nNow, let's look a bit into the nitty-gritty details of the design philosophy. 
Diffusers essentially consists of three major classes: [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), and [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).\nLet's walk through the design decisions for each class in more detail.\n\n### Pipelines\n\nPipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.\n\nThe following design principles are followed:\n- Pipelines follow the single-file policy. All pipelines can be found in individual directories under `src/diffusers/pipelines`. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as is done for [`src/diffusers/pipelines/stable_diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). 
If pipelines share similar functionality, one can make use of the [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).\n- Pipelines all inherit from [`DiffusionPipeline`].\n- Every pipeline consists of different model and scheduler components, that are documented in the [`model_index.json` file](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same name as attributes of the pipeline and can be shared between pipelines with [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function.\n- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function.\n- Pipelines should be used **only** for inference.\n- Pipelines should be very readable, self-explanatory, and easy to tweak.\n- Pipelines should be designed to build on top of each other and be easy to integrate into higher-level APIs.\n- Pipelines are **not** intended to be feature-complete user interfaces. For feature-complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).\n- Every pipeline should have one and only one way to run it via a `__call__` method. 
The naming of the `__call__` arguments should be shared across all pipelines.\n- Pipelines should be named after the task they are intended to solve.\n- In almost all cases, novel diffusion pipelines shall be implemented in a new pipeline folder/file.\n\n### Models\n\nModels are designed as configurable toolboxes that are natural extensions of [PyTorch's Module class](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). They only partly follow the **single-file policy**.\n\nThe following design principles are followed:\n- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.\n- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py), [`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py), etc...\n- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... 
**Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.\n- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.\n- Models all inherit from `ModelMixin` and `ConfigMixin`.\n- Models can be optimized for performance when it doesn’t demand major code changes, keeps backward compatibility, and gives significant memory or compute gain.\n- Models should by default have the highest precision and lowest performance setting.\n- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.\n- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and \"foreseeing\" future changes, *e.g.* it is usually better to add `string` \"...type\" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.\n- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. 
For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and\nreadable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n### Schedulers\n\nSchedulers are responsible for guiding the denoising process for inference as well as for defining a noise schedule for training. They are designed as individual classes with loadable configuration files and strongly follow the **single-file policy**.\n\nThe following design principles are followed:\n- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).\n- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.\n- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).\n- If schedulers share similar functionalities, we can make use of the `# Copied from` mechanism.\n- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.\n- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](../using-diffusers/schedulers).\n- Every scheduler has to have a `set_timesteps` and a `step` function. 
`set_timesteps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called.\n- Every scheduler exposes the timesteps to be \"looped over\" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.\n- The `step(...)` function takes a predicted model output and the \"current\" sample (x_t) and returns the \"previous\", slightly more denoised sample (x_t-1).\n- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a \"black box\".\n- In almost all cases, novel schedulers shall be implemented in a new scheduling file."
  },
  {
    "path": "docs/source/en/hybrid_inference/api_reference.md",
    "content": "# Remote inference\n\nRemote inference provides access to an [Inference Endpoint](https://huggingface.co/docs/inference-endpoints/index) to offload local generation requirements for decoding and encoding.\n\n## remote_decode\n\n[[autodoc]] utils.remote_utils.remote_decode\n\n## remote_encode\n\n[[autodoc]] utils.remote_utils.remote_encode\n"
  },
  {
    "path": "docs/source/en/hybrid_inference/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Remote inference\n\n> [!TIP]\n> This is currently an experimental feature, and if you have any feedback, please feel free to leave it [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).\n\nRemote inference offloads the decoding and encoding process to a remote endpoint to relax the memory requirements for local inference with large models. This feature is powered by [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index). 
Refer to the table below for the supported models and endpoints.\n\n| Model | Endpoint | Checkpoint | Support |\n|---|---|---|---|\n| Stable Diffusion v1 | https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud | [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) | encode/decode |\n| Stable Diffusion XL | https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud | [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) | encode/decode |\n| Flux | https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud | [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | encode/decode |\n| HunyuanVideo | https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud | [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | decode |\n\nThis guide will show you how to encode and decode latents with remote inference.\n\n## Encoding\n\nEncoding converts images and videos into latent representations. Refer to the table above for the supported VAEs.\n\nPass an image to [`~utils.remote_encode`] to encode it. 
The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.\n\n```py\nimport torch\nfrom diffusers import FluxPipeline\nfrom diffusers.utils import load_image\nfrom diffusers.utils.remote_utils import remote_encode\n\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-schnell\",\n    torch_dtype=torch.float16,\n    vae=None,\n    device_map=\"cuda\"\n)\n\ninit_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg\"\n)\ninit_image = init_image.resize((768, 512))\n\ninit_latent = remote_encode(\n    endpoint=\"https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud\",\n    image=init_image,\n    scaling_factor=0.3611,\n    shift_factor=0.1159\n)\n```\n\n## Decoding\n\nDecoding converts latent representations back into images or videos. Refer to the table above for the available and supported VAEs.\n\nSet the output type to `\"latent\"` in the pipeline and set the `vae` to `None`. Pass the latents to the [`~utils.remote_decode`] function. For Flux, the latents are packed so the `height` and `width` also need to be passed. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.\n\n<hfoptions id=\"decode\">\n<hfoption id=\"Flux\">\n\n```py\nimport torch\nfrom diffusers import FluxPipeline\nfrom diffusers.utils.remote_utils import remote_decode\n\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-schnell\",\n    torch_dtype=torch.bfloat16,\n    vae=None,\n    device_map=\"cuda\"\n)\n\nprompt = \"\"\"\nA photorealistic Apollo-era photograph of a cat in a small astronaut suit with a bubble helmet, standing on the Moon and holding a flagpole planted in the dusty lunar soil. The flag shows a colorful paw-print emblem. 
Earth glows in the black sky above the stark gray surface, with sharp shadows and high-contrast lighting like vintage NASA photos.\n\"\"\"\n\nlatent = pipeline(\n    prompt=prompt,\n    guidance_scale=0.0,\n    num_inference_steps=4,\n    output_type=\"latent\",\n).images\nimage = remote_decode(\n    endpoint=\"https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/\",\n    tensor=latent,\n    height=1024,\n    width=1024,\n    scaling_factor=0.3611,\n    shift_factor=0.1159,\n)\nimage.save(\"image.jpg\")\n```\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\n```py\nimport torch\nfrom diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel\nfrom diffusers.utils.remote_utils import remote_decode\n\ntransformer = HunyuanVideoTransformer3DModel.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo\", subfolder=\"transformer\", torch_dtype=torch.bfloat16\n)\npipeline = HunyuanVideoPipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo\", transformer=transformer, vae=None, torch_dtype=torch.float16, device_map=\"cuda\"\n)\n\nlatent = pipeline(\n    prompt=\"A cat walks on the grass, realistic\",\n    height=320,\n    width=512,\n    num_frames=61,\n    num_inference_steps=30,\n    output_type=\"latent\",\n).frames\n\nvideo = remote_decode(\n    endpoint=\"https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/\",\n    tensor=latent,\n    output_type=\"mp4\",\n)\n\nif isinstance(video, bytes):\n    with open(\"video.mp4\", \"wb\") as f:\n        f.write(video)\n```\n\n</hfoption>\n</hfoptions>\n\n## Queuing\n\nRemote inference supports queuing to process multiple generation requests. 
While the current latent is being decoded, you can queue the next prompt.\n\n```py\nimport queue\nimport threading\nimport torch\nfrom IPython.display import display\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers.utils.remote_utils import remote_decode\n\ndef decode_worker(q: queue.Queue):\n    while True:\n        item = q.get()\n        if item is None:\n            break\n        image = remote_decode(\n            endpoint=\"https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/\",\n            tensor=item,\n            scaling_factor=0.13025,\n        )\n        display(image)\n        q.task_done()\n\nq = queue.Queue()\nthread = threading.Thread(target=decode_worker, args=(q,), daemon=True)\nthread.start()\n\ndef decode(latent: torch.Tensor):\n    q.put(latent)\n\nprompts = [\n    \"A grainy Apollo-era style photograph of a cat in a snug astronaut suit with a bubble helmet, standing on the lunar surface and gripping a flag with a paw-print emblem. The gray Moon landscape stretches behind it, Earth glowing vividly in the black sky, shadows crisp and high-contrast.\",\n    \"A vintage 1960s sci-fi pulp magazine cover illustration of a heroic cat astronaut planting a flag on the Moon. Bold, saturated colors, exaggerated space gear, playful typography floating in the background, Earth painted in bright blues and greens.\",\n    \"A hyper-detailed cinematic shot of a cat astronaut on the Moon holding a fluttering flag, fur visible through the helmet glass, lunar dust scattering under its feet. The vastness of space and Earth in the distance create an epic, awe-inspiring tone.\",\n    \"A colorful cartoon drawing of a happy cat wearing a chunky, oversized spacesuit, proudly holding a flag with a big paw print on it. The Moon’s surface is simplified with craters drawn like doodles, and Earth in the sky has a smiling face.\",\n    \"A monochrome 1969-style press photo of a “first cat on the Moon” moment. 
The cat, in a tiny astronaut suit, stands by a planted flag, with grainy textures, scratches, and a blurred Earth in the background, mimicking old archival space photos.\"\n]\n\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    vae=None,\n    device_map=\"cuda\"\n)\n\npipeline.unet = pipeline.unet.to(memory_format=torch.channels_last)\npipeline.unet = torch.compile(pipeline.unet, mode=\"reduce-overhead\", fullgraph=True)\n\n_ = pipeline(\n    prompt=prompts[0],\n    output_type=\"latent\",\n)\n\nfor prompt in prompts:\n    latent = pipeline(\n        prompt=prompt,\n        output_type=\"latent\",\n    ).images\n    decode(latent)\n\nq.put(None)\nthread.join()\n```\n\n## Benchmarks\n\nThe tables demonstrate the memory requirements for encoding and decoding with Stable Diffusion v1.5 and SDXL on different GPUs.\n\nFor the majority of these GPUs, the memory usage dictates whether other models (text encoders, UNet/transformer) need to be offloaded or whether tiled encoding is required. 
These two techniques increase inference time and can impact quality.\n\n<details><summary>Encoding - Stable Diffusion v1.5</summary>\n\n| GPU                           | Resolution   |   Time (seconds) |   Memory (%) |   Tiled Time (secs) |   Tiled Memory (%) |\n|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|\n| NVIDIA GeForce RTX 4090       | 512x512      |            0.015 |      3.51901 |               0.015 |            3.51901 |\n| NVIDIA GeForce RTX 4090       | 256x256      |            0.004 |      1.3154  |               0.005 |            1.3154  |\n| NVIDIA GeForce RTX 4090       | 2048x2048    |            0.402 |     47.1852  |               0.496 |            3.51901 |\n| NVIDIA GeForce RTX 4090       | 1024x1024    |            0.078 |     12.2658  |               0.094 |            3.51901 |\n| NVIDIA GeForce RTX 4080 SUPER | 512x512      |            0.023 |      5.30105 |               0.023 |            5.30105 |\n| NVIDIA GeForce RTX 4080 SUPER | 256x256      |            0.006 |      1.98152 |               0.006 |            1.98152 |\n| NVIDIA GeForce RTX 4080 SUPER | 2048x2048    |            0.574 |     71.08    |               0.656 |            5.30105 |\n| NVIDIA GeForce RTX 4080 SUPER | 1024x1024    |            0.111 |     18.4772  |               0.14  |            5.30105 |\n| NVIDIA GeForce RTX 3090       | 512x512      |            0.032 |      3.52782 |               0.032 |            3.52782 |\n| NVIDIA GeForce RTX 3090       | 256x256      |            0.01  |      1.31869 |               0.009 |            1.31869 |\n| NVIDIA GeForce RTX 3090       | 2048x2048    |            0.742 |     47.3033  |               0.954 |            3.52782 |\n| NVIDIA GeForce RTX 3090       | 1024x1024    |            0.136 |     12.2965  |               0.207 |            3.52782 |\n| NVIDIA GeForce RTX 3080       | 512x512      |            0.036 |      8.51761 | 
              0.036 |            8.51761 |\n| NVIDIA GeForce RTX 3080       | 256x256      |            0.01  |      3.18387 |               0.01  |            3.18387 |\n| NVIDIA GeForce RTX 3080       | 2048x2048    |            0.863 |     86.7424  |               1.191 |            8.51761 |\n| NVIDIA GeForce RTX 3080       | 1024x1024    |            0.157 |     29.6888  |               0.227 |            8.51761 |\n| NVIDIA GeForce RTX 3070       | 512x512      |            0.051 |     10.6941  |               0.051 |           10.6941  |\n| NVIDIA GeForce RTX 3070       | 256x256      |            0.015 |      3.99743 |               0.015 |            3.99743 |\n| NVIDIA GeForce RTX 3070       | 2048x2048    |            1.217 |     96.054   |               1.482 |           10.6941  |\n| NVIDIA GeForce RTX 3070       | 1024x1024    |            0.223 |     37.2751  |               0.327 |           10.6941  |\n\n</details>\n\n<details><summary>Encoding SDXL</summary>\n\n| GPU                           | Resolution   |   Time (seconds) |   Memory Consumed (%) |   Tiled Time (seconds) |   Tiled Memory (%) |\n|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|\n| NVIDIA GeForce RTX 4090       | 512x512      |            0.029 |               4.95707 |                  0.029 |            4.95707 |\n| NVIDIA GeForce RTX 4090       | 256x256      |            0.007 |               2.29666 |                  0.007 |            2.29666 |\n| NVIDIA GeForce RTX 4090       | 2048x2048    |            0.873 |              66.3452  |                  0.863 |           15.5649  |\n| NVIDIA GeForce RTX 4090       | 1024x1024    |            0.142 |              15.5479  |                  0.143 |           15.5479  |\n| NVIDIA GeForce RTX 4080 SUPER | 512x512      |            0.044 |               7.46735 |                  0.044 |            7.46735 |\n| NVIDIA GeForce RTX 4080 SUPER 
| 256x256      |            0.01  |               3.4597  |                  0.01  |            3.4597  |\n| NVIDIA GeForce RTX 4080 SUPER | 2048x2048    |            1.317 |              87.1615  |                  1.291 |           23.447   |\n| NVIDIA GeForce RTX 4080 SUPER | 1024x1024    |            0.213 |              23.4215  |                  0.214 |           23.4215  |\n| NVIDIA GeForce RTX 3090       | 512x512      |            0.058 |               5.65638 |                  0.058 |            5.65638 |\n| NVIDIA GeForce RTX 3090       | 256x256      |            0.016 |               2.45081 |                  0.016 |            2.45081 |\n| NVIDIA GeForce RTX 3090       | 2048x2048    |            1.755 |              77.8239  |                  1.614 |           18.4193  |\n| NVIDIA GeForce RTX 3090       | 1024x1024    |            0.265 |              18.4023  |                  0.265 |           18.4023  |\n| NVIDIA GeForce RTX 3080       | 512x512      |            0.064 |              13.6568  |                  0.064 |           13.6568  |\n| NVIDIA GeForce RTX 3080       | 256x256      |            0.018 |               5.91728 |                  0.018 |            5.91728 |\n| NVIDIA GeForce RTX 3080       | 2048x2048    |          OOM     |             OOM       |                  1.866 |           44.4717  |\n| NVIDIA GeForce RTX 3080       | 1024x1024    |            0.302 |              44.4308  |                  0.302 |           44.4308  |\n| NVIDIA GeForce RTX 3070       | 512x512      |            0.093 |              17.1465  |                  0.093 |           17.1465  |\n| NVIDIA GeForce RTX 3070       | 256x256      |            0.025 |               7.42931 |                  0.026 |            7.42931 |\n| NVIDIA GeForce RTX 3070       | 2048x2048    |          OOM     |             OOM       |                  2.674 |           55.8355  |\n| NVIDIA GeForce RTX 3070       | 1024x1024    |            0.443 |              
55.7841  |                  0.443 |           55.7841  |\n\n</details>\n\n<details><summary>Decoding - Stable Diffusion v1.5</summary>\n\n| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |\n| --- | --- | --- | --- | --- | --- |\n| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |\n| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |\n| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |\n| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |\n| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |\n| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |\n| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |\n| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |\n| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |\n| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |\n| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |\n| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |\n\n</details>\n\n<details><summary>Decoding SDXL</summary>\n\n| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |\n| --- | --- | --- | --- | --- | --- |\n| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |\n| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |\n| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |\n| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |\n| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |\n| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |\n| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |\n| NVIDIA 
GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |\n| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |\n| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |\n| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |\n| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |\n\n</details>\n\n\n## Resources\n\n- Remote inference is also supported in [SD.Next](https://github.com/vladmandic/sdnext) and [ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae).\n- Refer to the [Remote VAEs for decoding with Inference Endpoints](https://huggingface.co/blog/remote_vae) blog post to learn more."
  },
  {
    "path": "docs/source/en/index.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://raw.githubusercontent.com/huggingface/diffusers/77aadfee6a891ab9fcfb780f87c693f7a5beeb8e/docs/source/imgs/diffusers_library.jpg\" width=\"400\" style=\"border: none;\"/>\n    <br>\n</p>\n\n# Diffusers\n\nDiffusers is a library of state-of-the-art pretrained diffusion models for generating videos, images, and audio.\n\nThe library revolves around the [`DiffusionPipeline`], an API designed for:\n\n- easy inference with only a few lines of code\n- flexibility to mix-and-match pipeline components (models, schedulers)\n- loading and using adapters like LoRA\n\nDiffusers also comes with optimizations - such as offloading and quantization - to ensure even the largest models are accessible on memory-constrained devices. If memory is not an issue, Diffusers supports torch.compile to boost inference speed.\n\nGet started right away with a Diffusers model on the [Hub](https://huggingface.co/models?library=diffusers&sort=trending) today!\n\n## Learn\n\nIf you're a beginner, we recommend starting with the [Hugging Face Diffusion Models Course](https://huggingface.co/learn/diffusion-course/unit0/1). You'll learn the theory behind diffusion models, and learn how to use the Diffusers library to generate images, fine-tune your own models, and more.\n"
  },
  {
    "path": "docs/source/en/installation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Installation\n\nDiffusers is tested on Python 3.8+ and PyTorch 1.4+. Install [PyTorch](https://pytorch.org/get-started/locally/) according to your system and setup.\n\nCreate a [virtual environment](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) for easier management of separate projects and to avoid compatibility issues between dependencies. Use [uv](https://docs.astral.sh/uv/), a Rust-based Python package and project manager, to create a virtual environment and install Diffusers.\n\n```bash\nuv venv my-env\nsource my-env/bin/activate\n```\n\nInstall Diffusers with one of the following methods.\n\n<hfoptions id=\"install\">\n<hfoption id=\"pip\">\n\nPyTorch only supports Python 3.8 - 3.11 on Windows.\n\n```bash\nuv pip install diffusers[\"torch\"] transformers\n```\n\n</hfoption>\n<hfoption id=\"conda\">\n\n```bash\nconda install -c conda-forge diffusers\n```\n\n</hfoption>\n<hfoption id=\"source\">\n\nA source install installs the `main` version instead of the latest `stable` version. The `main` version is useful for staying updated with the latest changes but it may not always be stable. 
If you run into a problem, open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and we will try to resolve it as soon as possible.\n\nMake sure [Accelerate](https://huggingface.co/docs/accelerate/index) is installed.\n\n```bash\nuv pip install accelerate\n```\n\nInstall Diffusers from source with the command below.\n\n```bash\nuv pip install git+https://github.com/huggingface/diffusers\n```\n\n</hfoption>\n</hfoptions>\n\n## Editable install\n\nAn editable install is recommended for development workflows or if you're using the `main` version of the source code. A special link is created between the cloned repository and the Python library paths. This avoids reinstalling a package after every change.\n\nClone the repository and install Diffusers with the following commands.\n\n```bash\ngit clone https://github.com/huggingface/diffusers.git\ncd diffusers\nuv pip install -e \".[torch]\"\n```\n\n> [!WARNING]\n> You must keep the `diffusers` folder if you want to keep using the library with the editable install.\n\nUpdate your cloned repository to the latest version of Diffusers with the command below.\n\n```bash\ncd ~/diffusers/\ngit pull\n```\n\n## Cache\n\nModel weights and files are downloaded from the Hub to a cache, which is usually your home directory. 
Change the cache location with the [HF_HOME](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome) or [HF_HUB_CACHE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubcache) environment variables, or by configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].\n\n<hfoptions id=\"cache\">\n<hfoption id=\"env variable\">\n\n```bash\nexport HF_HOME=\"/path/to/your/cache\"\nexport HF_HUB_CACHE=\"/path/to/your/hub/cache\"\n```\n\n</hfoption>\n<hfoption id=\"from_pretrained\">\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    cache_dir=\"/path/to/your/cache\"\n)\n```\n\n</hfoption>\n</hfoptions>\n\nCached files allow you to use Diffusers offline. Set the [HF_HUB_OFFLINE](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhuboffline) environment variable to `1` to prevent Diffusers from connecting to the internet.\n\n```shell\nexport HF_HUB_OFFLINE=1\n```\n\nFor more details about managing and cleaning the cache, take a look at the [Understand caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.\n\n## Telemetry logging\n\nDiffusers gathers telemetry information during [`~DiffusionPipeline.from_pretrained`] requests.\nThe data gathered includes the Diffusers and PyTorch version, the requested model or pipeline class,\nand the path to a pretrained checkpoint if it is hosted on the Hub.\n\nThis usage data helps us debug issues and prioritize new features.\nTelemetry is only sent when loading models and pipelines from the Hub,\nand it is not collected if you're loading local files.\n\nOpt out and disable telemetry collection with the [HF_HUB_DISABLE_TELEMETRY](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubdisabletelemetry) environment variable.\n\n<hfoptions 
id=\"telemetry\">\n<hfoption id=\"Linux/macOS\">\n\n```bash\nexport HF_HUB_DISABLE_TELEMETRY=1\n```\n\n</hfoption>\n<hfoption id=\"Windows\">\n\n```bash\nset HF_HUB_DISABLE_TELEMETRY=1\n```\n\n</hfoption>\n</hfoptions>\n"
  },
  {
    "path": "docs/source/en/modular_diffusers/auto_pipeline_blocks.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoPipelineBlocks\n\n[`~modular_pipelines.AutoPipelineBlocks`] are a multi-block type containing blocks that support different workflows. It automatically selects which sub-blocks to run based on the input provided at runtime. This is typically used to package multiple workflows - text-to-image, image-to-image, inpaint - into a single pipeline for convenience.\n\nThis guide shows how to create [`~modular_pipelines.AutoPipelineBlocks`].\n\nCreate three [`~modular_pipelines.ModularPipelineBlocks`] for text-to-image, image-to-image, and inpainting. 
These represent the different workflows available in the pipeline.\n\n<hfoptions id=\"auto\">\n<hfoption id=\"text-to-image\">\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam\n\nclass TextToImageBlock(ModularPipelineBlocks):\n    model_name = \"text2img\"\n\n    @property\n    def inputs(self):\n        return [InputParam(name=\"prompt\")]\n\n    @property\n    def intermediate_outputs(self):\n        return []\n\n    @property\n    def description(self):\n        return \"I'm a text-to-image workflow!\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        print(\"running the text-to-image workflow\")\n        # Add your text-to-image logic here\n        # For example: generate image from prompt\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n\n</hfoption>\n<hfoption id=\"image-to-image\">\n\n```py\nclass ImageToImageBlock(ModularPipelineBlocks):\n    model_name = \"img2img\"\n\n    @property\n    def inputs(self):\n        return [InputParam(name=\"prompt\"), InputParam(name=\"image\")]\n\n    @property\n    def intermediate_outputs(self):\n        return []\n\n    @property\n    def description(self):\n        return \"I'm an image-to-image workflow!\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        print(\"running the image-to-image workflow\")\n        # Add your image-to-image logic here\n        # For example: transform input image based on prompt\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n\n</hfoption>\n<hfoption id=\"inpaint\">\n\n```py\nclass InpaintBlock(ModularPipelineBlocks):\n    model_name = \"inpaint\"\n\n    @property\n    def inputs(self):\n        return [InputParam(name=\"prompt\"), InputParam(name=\"image\"), InputParam(name=\"mask\")]\n\n    @property\n    def 
intermediate_outputs(self):\n        return []\n\n    @property\n    def description(self):\n        return \"I'm an inpaint workflow!\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        print(\"running the inpaint workflow\")\n        # Add your inpainting logic here\n        # For example: fill masked areas based on prompt\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n</hfoption>\n</hfoptions>\n\nCreate an [`~modular_pipelines.AutoPipelineBlocks`] class that includes a list of the sub-block classes and their corresponding block names.\n\nYou also need to include `block_trigger_inputs`, a list of input names that trigger the corresponding block. If a trigger input is provided at runtime, then that block is selected to run. Use `None` to specify the default block to run if no trigger inputs are detected.\n\nLastly, it is important to include a `description` that clearly explains which inputs trigger which workflow. 
This helps users understand how to run specific workflows.\n\n```py\nfrom diffusers.modular_pipelines import AutoPipelineBlocks\n\nclass AutoImageBlocks(AutoPipelineBlocks):\n    # List of sub-block classes to choose from\n    block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]\n    # Names for each block in the same order\n    block_names = [\"inpaint\", \"img2img\", \"text2img\"]\n    # Trigger inputs that determine which block to run\n    # - \"mask\" triggers inpaint workflow\n    # - \"image\" triggers img2img workflow (but only if mask is not provided)\n    # - if none of above, runs the text2img workflow (default)\n    block_trigger_inputs = [\"mask\", \"image\", None]\n\n    @property\n    def description(self):\n        return (\n            \"Pipeline generates images given different types of conditions!\\n\"\n            + \"This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\\n\"\n            + \" - inpaint workflow is run when `mask` is provided.\\n\"\n            + \" - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\\n\"\n            + \" - text2img workflow is run when neither `image` nor `mask` is provided.\\n\"\n        )\n```\n\nIt is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. 
While [`~modular_pipelines.AutoPipelineBlocks`] is convenient, its conditional logic may be difficult to figure out if it isn't properly explained.\n\nCreate an instance of `AutoImageBlocks`.\n\n```py\nauto_blocks = AutoImageBlocks()\n```\n\nFor more complex compositions, such as [`~modular_pipelines.AutoPipelineBlocks`] nested as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the block that is actually run based on your input.\n\n```py\nauto_blocks.get_execution_blocks(mask=True)\n```\n\n## ConditionalPipelineBlocks\n\n[`~modular_pipelines.AutoPipelineBlocks`] is a special case of [`~modular_pipelines.ConditionalPipelineBlocks`]. While [`~modular_pipelines.AutoPipelineBlocks`] selects blocks based on whether a trigger input is provided or not, [`~modular_pipelines.ConditionalPipelineBlocks`] is able to select a block based on custom selection logic provided in the `select_block` method.\n\nHere is the same example written using [`~modular_pipelines.ConditionalPipelineBlocks`] directly:\n\n```py\nfrom diffusers.modular_pipelines import ConditionalPipelineBlocks\n\nclass AutoImageBlocks(ConditionalPipelineBlocks):\n    block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]\n    block_names = [\"inpaint\", \"img2img\", \"text2img\"]\n    block_trigger_inputs = [\"mask\", \"image\"]\n    default_block_name = \"text2img\"\n\n    @property\n    def description(self):\n        return (\n            \"Pipeline generates images given different types of conditions!\\n\"\n            + \"This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\\n\"\n            + \" - inpaint workflow is run when `mask` is provided.\\n\"\n            + \" - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\\n\"\n            + \" - text2img workflow is run when neither `image` nor `mask` is provided.\\n\"\n 
       )\n\n    def select_block(self, mask=None, image=None) -> str | None:\n        if mask is not None:\n            return \"inpaint\"\n        if image is not None:\n            return \"img2img\"\n        return None  # falls back to default_block_name (\"text2img\")\n```\n\nThe inputs listed in `block_trigger_inputs` are passed as keyword arguments to `select_block()`. When `select_block` returns `None`, the pipeline falls back to `default_block_name`. If `default_block_name` is also `None`, the entire conditional block is skipped, which is useful for optional processing steps that should only run when specific inputs are provided.\n\n## Workflows\n\nPipelines that contain conditional blocks ([`~modular_pipelines.AutoPipelineBlocks`] or [`~modular_pipelines.ConditionalPipelineBlocks`]) can support multiple workflows. For example, our SDXL modular pipeline supports a dozen workflows in one pipeline. This also means it can be hard for users to know which workflows are supported and how to run them. For pipeline builders, it's useful to be able to extract only the blocks relevant to a specific workflow.\n\nWe recommend defining a `_workflow_map` to give each workflow a name and explicitly list the inputs it requires.\n\n```py\nfrom diffusers.modular_pipelines import SequentialPipelineBlocks\n\nclass MyPipelineBlocks(SequentialPipelineBlocks):\n    block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock]\n    block_names = [\"text_encoder\", \"auto_image\", \"decode\"]\n\n    _workflow_map = {\n        \"text2image\": {\"prompt\": True},\n        \"image2image\": {\"image\": True, \"prompt\": True},\n        \"inpaint\": {\"mask\": True, \"image\": True, \"prompt\": True},\n    }\n```\n\nAll of our built-in modular pipelines come with pre-defined workflows. 
The `available_workflows` property lists all supported workflows:\n\n```py\npipeline_blocks = MyPipelineBlocks()\npipeline_blocks.available_workflows\n# ['text2image', 'image2image', 'inpaint']\n```\n\nRetrieve a specific workflow with `get_workflow` to inspect and debug a specific block that executes the workflow.\n\n```py\npipeline_blocks.get_workflow(\"inpaint\")\n```"
  },
  {
    "path": "docs/source/en/modular_diffusers/components_manager.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ComponentsManager\n\nThe [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), and supports offloading.\n\nThis guide will show you how to use [`ComponentsManager`] to manage components and device memory.\n\n## Connect to a pipeline\n\nCreate a [`ComponentsManager`] and pass it to a [`ModularPipeline`] with either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`]. \n\n\n<hfoptions id=\"create\">\n<hfoption id=\"from_pretrained\">\n\n```py\nfrom diffusers import ModularPipeline, ComponentsManager\nimport torch\n\nmanager = ComponentsManager()\npipe = ModularPipeline.from_pretrained(\"Tongyi-MAI/Z-Image-Turbo\", components_manager=manager)\npipe.load_components(torch_dtype=torch.bfloat16)\n```\n\n</hfoption>\n<hfoption id=\"init_pipeline\">\n\n```py\nfrom diffusers import ModularPipelineBlocks, ComponentsManager\nimport torch\nmanager = ComponentsManager()\nblocks = ModularPipelineBlocks.from_pretrained(\"diffusers/Florence2-image-Annotator\", trust_remote_code=True)\npipe= blocks.init_pipeline(components_manager=manager)\npipe.load_components(torch_dtype=torch.bfloat16)\n```\n\n</hfoption>\n</hfoptions>\n\nComponents loaded by the pipeline are automatically registered in the manager. 
You can inspect them right away.\n\n## Inspect components\n\nPrint the [`ComponentsManager`] to see all registered components, including their class, device placement, dtype, memory size, and load ID.\n\nThe output below corresponds to the `from_pretrained` example above.\n\n```py\nComponents:\n=============================================================================================================================\nModels:\n-----------------------------------------------------------------------------------------------------------------------------\nName_ID                      | Class                    | Device: act(exec) | Dtype          | Size (GB) | Load ID\n-----------------------------------------------------------------------------------------------------------------------------\ntext_encoder_140458257514752 | Qwen3Model               | cpu               | torch.bfloat16 | 7.49      | Tongyi-MAI/Z-Image-Turbo|text_encoder|null|null\nvae_140458257515376          | AutoencoderKL            | cpu               | torch.bfloat16 | 0.16      | Tongyi-MAI/Z-Image-Turbo|vae|null|null\ntransformer_140458257515616  | ZImageTransformer2DModel | cpu               | torch.bfloat16 | 11.46     | Tongyi-MAI/Z-Image-Turbo|transformer|null|null\n-----------------------------------------------------------------------------------------------------------------------------\n\nOther Components:\n-----------------------------------------------------------------------------------------------------------------------------\nID                           | Class                           | Collection\n-----------------------------------------------------------------------------------------------------------------------------\nscheduler_140461023555264    | FlowMatchEulerDiscreteScheduler | N/A\ntokenizer_140458256346432    | Qwen2Tokenizer                  | 
N/A\n-----------------------------------------------------------------------------------------------------------------------------\n```\n\nThe table shows models (with device, dtype, and memory info) separately from other components like schedulers and tokenizers. If any models have LoRA adapters, IP-Adapters, or quantization applied, that information is displayed in an additional section at the bottom.\n\n## Offloading\n\nThe [`~ComponentsManager.enable_auto_cpu_offload`] method is a global offloading strategy that works across all models regardless of which pipeline is using them. Once enabled, you don't need to worry about device placement if you add or remove components.\n\n```py\nmanager.enable_auto_cpu_offload(device=\"cuda\")\n```\n\nAll models begin on the CPU and [`ComponentsManager`] moves them to the appropriate device right before they're needed, and moves other models back to the CPU when GPU memory is low.\n\nCall [`~ComponentsManager.disable_auto_cpu_offload`] to disable offloading.\n\n```py\nmanager.disable_auto_cpu_offload()\n```\n"
  },
  {
    "path": "docs/source/en/modular_diffusers/custom_blocks.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n\n# Building Custom Blocks\n\n[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.\n\n> [!TIP]\n> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom blocks.\n\n## Project Structure\n\nYour custom block project should use the following structure:\n\n```shell\n.\n├── block.py\n└── modular_config.json\n```\n\n- `block.py` contains the custom block implementation\n- `modular_config.json` contains the metadata needed to load the block\n\n## Quick Start with Template\n\nThe fastest way to create a custom block is to start from our template. 
The template provides a pre-configured project structure with `block.py` and `modular_config.json` files, plus commented examples showing how to define components, inputs, outputs, and the `__call__` method, so you can focus on your custom logic instead of boilerplate setup.\n\n### Download the template\n\n```python\nfrom diffusers import ModularPipelineBlocks\n\nmodel_id = \"diffusers/custom-block-template\"\nlocal_dir = model_id.split(\"/\")[-1]\n\nblocks = ModularPipelineBlocks.from_pretrained(\n    model_id,\n    trust_remote_code=True,\n    local_dir=local_dir\n)\n```\n\nThis saves the template files locally to `custom-block-template/`. Pass a different `local_dir` to save them to another location.\n\n### Edit locally\n\nOpen `block.py` and implement your custom block. The template includes commented examples showing how to define each property. See the [Florence-2 example](#example-florence-2-image-annotator) below for a complete implementation.\n\n### Test your block\n\n```python\nfrom diffusers import ModularPipelineBlocks\n\nblocks = ModularPipelineBlocks.from_pretrained(local_dir, trust_remote_code=True)\npipeline = blocks.init_pipeline()\noutput = pipeline(...)  # your inputs here\n```\n\n### Upload to the Hub\n\n```python\npipeline.save_pretrained(local_dir, repo_id=\"your-username/your-block-name\", push_to_hub=True)\n```\n\n## Example: Florence-2 Image Annotator\n\nThis example creates a custom block with [Florence-2](https://huggingface.co/docs/transformers/model_doc/florence2) to process an input image and generate a mask for inpainting.\n\n### Define components\n\nDefine the components the block needs, `Florence2ForConditionalGeneration` and its processor. 
When defining components, specify the `name` (how you'll access it in code), `type_hint` (the model class), and `pretrained_model_name_or_path` (where to load weights from).\n\n```python\n# Inside block.py\nfrom diffusers.modular_pipelines import ModularPipelineBlocks, ComponentSpec\nfrom transformers import AutoProcessor, Florence2ForConditionalGeneration\n\n\nclass Florence2ImageAnnotatorBlock(ModularPipelineBlocks):\n\n    @property\n    def expected_components(self):\n        return [\n            ComponentSpec(\n                name=\"image_annotator\",\n                type_hint=Florence2ForConditionalGeneration,\n                pretrained_model_name_or_path=\"florence-community/Florence-2-base-ft\",\n            ),\n            ComponentSpec(\n                name=\"image_annotator_processor\",\n                type_hint=AutoProcessor,\n                pretrained_model_name_or_path=\"florence-community/Florence-2-base-ft\",\n            ),\n        ]\n```\n\n### Define inputs and outputs\n\nInputs include the image, annotation task, and prompt. Outputs include the generated mask and annotations.\n\n```python\nfrom typing import List, Union\nfrom PIL import Image\nfrom diffusers.modular_pipelines import InputParam, OutputParam\n\n\nclass Florence2ImageAnnotatorBlock(ModularPipelineBlocks):\n\n    # ... 
expected_components from above ...\n\n    @property\n    def inputs(self) -> List[InputParam]:\n        return [\n            InputParam(\n                \"image\",\n                type_hint=Union[Image.Image, List[Image.Image]],\n                required=True,\n                description=\"Image(s) to annotate\",\n            ),\n            InputParam(\n                \"annotation_task\",\n                type_hint=str,\n                default=\"<REFERRING_EXPRESSION_SEGMENTATION>\",\n                description=\"Annotation task to perform (e.g., <OD>, <CAPTION>, <REFERRING_EXPRESSION_SEGMENTATION>)\",\n            ),\n            InputParam(\n                \"annotation_prompt\",\n                type_hint=str,\n                required=True,\n                description=\"Prompt to provide context for the annotation task\",\n            ),\n            InputParam(\n                \"annotation_output_type\",\n                type_hint=str,\n                default=\"mask_image\",\n                description=\"Output type: 'mask_image', 'mask_overlay', or 'bounding_box'\",\n            ),\n        ]\n\n    @property\n    def intermediate_outputs(self) -> List[OutputParam]:\n        return [\n            OutputParam(\n                \"mask_image\",\n                type_hint=Image.Image,\n                description=\"Inpainting mask for the input image\",\n            ),\n            OutputParam(\n                \"annotations\",\n                type_hint=dict,\n                description=\"Raw annotation predictions\",\n            ),\n            OutputParam(\n                \"image\",\n                type_hint=Image.Image,\n                description=\"Annotated image\",\n            ),\n        ]\n```\n\n### Implement the `__call__` method\n\nThe `__call__` method contains the block's logic. 
Access inputs via `block_state`, run your computation, and set outputs back to `block_state`.\n\n```python\nimport torch\nfrom diffusers.modular_pipelines import PipelineState\n\n\nclass Florence2ImageAnnotatorBlock(ModularPipelineBlocks):\n\n    # ... expected_components, inputs, intermediate_outputs from above ...\n\n    @torch.no_grad()\n    def __call__(self, components, state: PipelineState) -> PipelineState:\n        block_state = self.get_block_state(state)\n        \n        images, annotation_task_prompt = self.prepare_inputs(\n            block_state.image, block_state.annotation_prompt\n        )\n        task = block_state.annotation_task\n        fill = block_state.fill\n        \n        annotations = self.get_annotations(\n            components, images, annotation_task_prompt, task\n        )\n        block_state.annotations = annotations\n        if block_state.annotation_output_type == \"mask_image\":\n            block_state.mask_image = self.prepare_mask(images, annotations)\n        else:\n            block_state.mask_image = None\n\n        if block_state.annotation_output_type == \"mask_overlay\":\n            block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)\n\n        elif block_state.annotation_output_type == \"bounding_box\":\n            block_state.image = self.prepare_bounding_boxes(images, annotations)\n\n        self.set_block_state(state, block_state)\n\n        return components, state\n    \n    # Helper methods for mask/bounding box generation...\n```\n\n> [!TIP]\n> See the complete implementation at [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator).\n\n## Using Custom Blocks\n\nLoad a custom block with [`~ModularPipeline.from_pretrained`] and set `trust_remote_code=True`.\n\n```py\nimport torch\nfrom diffusers import ModularPipeline\nfrom diffusers.utils import load_image\n\n# Load the Florence-2 annotator pipeline\nimage_annotator = 
ModularPipeline.from_pretrained(\n    \"diffusers/Florence2-image-Annotator\",\n    trust_remote_code=True\n)\n\n# Check the docstring to see inputs/outputs\nprint(image_annotator.blocks.doc)\n```\n\nUse the block to generate a mask:\n\n```python\nimage_annotator.load_components(torch_dtype=torch.bfloat16)\nimage_annotator.to(\"cuda\")\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg\")\nimage = image.resize((1024, 1024))\nprompt = [\"A red car\"]\nannotation_task = \"<REFERRING_EXPRESSION_SEGMENTATION>\"\nannotation_prompt = [\"the car\"]\n\nmask_image = image_annotator(\n    prompt=prompt,\n    image=image,\n    annotation_task=annotation_task,\n    annotation_prompt=annotation_prompt,\n    annotation_output_type=\"mask_image\",\n).images\nmask_image[0].save(\"car-mask.png\")\n```\n\nCompose it with other blocks to create a new pipeline:\n\n```python\n# Get the annotator block\nannotator_block = image_annotator.blocks\n\n# Get an inpainting workflow and insert the annotator at the beginning\ninpaint_blocks = ModularPipeline.from_pretrained(\"Qwen/Qwen-Image\").blocks.get_workflow(\"inpainting\")\ninpaint_blocks.sub_blocks.insert(\"image_annotator\", annotator_block, 0)\n\n# Initialize the combined pipeline\npipe = inpaint_blocks.init_pipeline()\npipe.load_components(torch_dtype=torch.float16, device=\"cuda\")\n\n# Now the pipeline automatically generates masks from prompts\noutput = pipe(\n    prompt=prompt,\n    image=image,\n    annotation_task=annotation_task,\n    annotation_prompt=annotation_prompt,\n    annotation_output_type=\"mask_image\",\n    num_inference_steps=35,\n    guidance_scale=7.5,\n    strength=0.95,\n    output=\"images\"\n)\noutput[0].save(\"florence-inpainting.png\")\n```\n\n## Editing custom blocks\n\nEdit a custom block by downloading it locally. 
This is the same workflow as the [Quick Start with Template](#quick-start-with-template), but starting from an existing block instead of the template.\n\nUse the `local_dir` argument to download a custom block to a specific folder:\n\n```python\nfrom diffusers import ModularPipelineBlocks\n\n# Download to a local folder for editing\nannotator_block = ModularPipelineBlocks.from_pretrained(\n    \"diffusers/Florence2-image-Annotator\",\n    trust_remote_code=True,\n    local_dir=\"./my-florence-block\"\n)\n```\n\nAny changes made to the block files in this folder are reflected the next time you load the block. When you're ready to share your changes, upload them to a new repository:\n\n```python\npipeline = annotator_block.init_pipeline()\npipeline.save_pretrained(\"./my-florence-block\", repo_id=\"your-username/my-custom-florence\", push_to_hub=True)\n```\n\n## Next Steps\n\n<hfoptions id=\"next\">\n<hfoption id=\"Learn block types\">\n\nThis guide covered creating a single custom block. Learn how to compose multiple blocks together:\n\n- [SequentialPipelineBlocks](./sequential_pipeline_blocks): Chain blocks to execute in sequence\n- [ConditionalPipelineBlocks](./auto_pipeline_blocks): Create conditional blocks that select different execution paths\n- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks): Define iterative workflows like the denoising loop\n\n</hfoption>\n<hfoption id=\"Use in Mellon\">\n\nMake your custom block work with Mellon's visual interface. See the [Mellon Custom Blocks](./mellon) guide.\n\n</hfoption>\n<hfoption id=\"Explore existing blocks\">\n\nBrowse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.\n\n</hfoption>\n</hfoptions>\n\n## Dependencies\n\nDeclaring package dependencies in custom blocks prevents runtime import errors later on. 
Diffusers validates the dependencies and emits a warning if a package is missing or incompatible.\n\nSet a `_requirements` attribute in your block class, mapping package names to version specifiers.\n\n```py\nfrom diffusers.modular_pipelines import ModularPipelineBlocks\n\nclass MyCustomBlock(ModularPipelineBlocks):\n    _requirements = {\n        \"transformers\": \">=4.44.0\",\n        \"sentencepiece\": \">=0.2.0\"\n    }\n```\n\nWhen blocks with different requirements are combined, Diffusers merges their requirements.\n\n```py\nfrom diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks\n\nclass BlockA(ModularPipelineBlocks):\n    _requirements = {\"transformers\": \">=4.44.0\"}\n    # ...\n\nclass BlockB(ModularPipelineBlocks):\n    _requirements = {\"sentencepiece\": \">=0.2.0\"}\n    # ...\n\npipe = SequentialPipelineBlocks.from_blocks_dict({\n    \"block_a\": BlockA,\n    \"block_b\": BlockB,\n})\n```\n\nWhen this block is saved with [`~ModularPipeline.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers emits one of the following warnings.\n\n```md\n# missing package\nxyz-package was specified in the requirements but wasn't found in the current environment.\n\n# version mismatch\nxyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected.\n```\n"
  },
  {
    "path": "docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LoopSequentialPipelineBlocks\n\n[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.\n\nThis guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].\n\n## Loop wrapper\n\n[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.\n\n- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].\n- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. 
It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].\n- `__call__` method defines the loop structure and iteration logic.\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import LoopSequentialPipelineBlocks, ModularPipelineBlocks, InputParam, OutputParam\n\nclass LoopWrapper(LoopSequentialPipelineBlocks):\n    model_name = \"test\"\n    @property\n    def description(self):\n        return \"I'm a loop!!\"\n    @property\n    def loop_inputs(self):\n        return [InputParam(name=\"num_steps\")]\n    @torch.no_grad()\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        # Loop structure - can be customized to your needs\n        for i in range(block_state.num_steps):\n            # loop_step executes all registered blocks in sequence\n            components, block_state = self.loop_step(components, block_state, i=i)\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\nThe loop wrapper can pass additional arguments, like current iteration index, to the loop blocks.\n\n## Loop blocks\n\nA loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.\n\n- It receives the iteration variable from the loop wrapper.\n- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].\n- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].\n\nLoop blocks share the same [`~modular_pipelines.BlockState`] to allow values to accumulate and change for each iteration in the loop.\n\n```py\nclass LoopBlock(ModularPipelineBlocks):\n    model_name = \"test\"\n    @property\n    def inputs(self):\n        return [InputParam(name=\"x\")]\n    @property\n    def intermediate_outputs(self):\n        # outputs produced by this block\n        return [OutputParam(name=\"x\")]\n    @property\n    def description(self):\n        return 
\"I'm a block used inside the `LoopWrapper` class\"\n    def __call__(self, components, block_state, i: int):\n        block_state.x += 1\n        return components, block_state\n```\n\n## LoopSequentialPipelineBlocks\n\nUse the [`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`] method to add the loop block to the loop wrapper to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].\n\n```py\nloop = LoopWrapper.from_blocks_dict({\"block1\": LoopBlock})\n```\n\nAdd more loop blocks to run within each iteration with [`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`]. This allows you to modify the blocks without changing the loop logic itself.\n\n```py\nloop = LoopWrapper.from_blocks_dict({\"block1\": LoopBlock(), \"block2\": LoopBlock})\n```\n"
  },
  {
    "path": "docs/source/en/modular_diffusers/mellon.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n\n## Using Custom Blocks with Mellon\n\n[Mellon](https://github.com/cubiq/Mellon) is a visual workflow interface that integrates with Modular Diffusers and is designed for node-based workflows.\n\n> [!WARNING]\n> Mellon is in early development and not ready for production use yet. Consider this a sneak peek of how the integration works!\n\n\nCustom blocks work in Mellon out of the box - just need to add a `mellon_pipeline_config.json` to your repository. This config file tells Mellon how to render your block's parameters as UI components.\n\nHere's what it looks like in action with the [Gemini Prompt Expander](https://huggingface.co/diffusers/gemini-prompt-expander-mellon) block:\n\n![Mellon custom block demo](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/modular_demo_dynamic.gif)\n\nTo use a modular diffusers custom block in Mellon:\n1. Drag a **Dynamic Block Node** from the ModularDiffusers section\n2. Enter the `repo_id` (e.g., `diffusers/gemini-prompt-expander-mellon`)\n3. Click **Load Custom Block**\n4. The node transforms to show your block's inputs and outputs\n\nNow let's walk through how to create this config for your own custom block.\n\n## Steps to create a Mellon config\n\n1. 
**Specify Mellon types for your parameters** - Each `InputParam`/`OutputParam` needs a type that tells Mellon what UI component to render (e.g., `\"textbox\"`, `\"dropdown\"`, `\"image\"`).\n2. **Generate `mellon_pipeline_config.json`** - Use our utility to generate a config template and push it to your Hub repository.\n3. **(Optional) Manually adjust the config** - Fine-tune the generated config for your specific needs.\n\n## Specify Mellon types for parameters\n\nMellon types determine how each parameter renders in the UI. If you don't specify a type for a parameter, it will default to `\"custom\"`, which renders as a simple connection dot. You can always adjust this later in the generated config.\n\n\n| Type | Input/Output | Description |\n|------|--------------|-------------|\n| `image` | Both | Image (PIL Image) |\n| `video` | Both | Video |\n| `text` | Both | Text display |\n| `textbox` | Input | Text input |\n| `dropdown` | Input | Dropdown selection menu |\n| `slider` | Input | Slider for numeric values |\n| `number` | Input | Numeric input |\n| `checkbox` | Input | Boolean toggle |\n\nFor parameters that need more configuration (like dropdowns with options, or sliders with min/max values), pass a `MellonParam` instance directly instead of a string. 
You can use one of the class methods below, or create a fully custom one with `MellonParam(name, label, type, ...)`.\n\n| Method | Description |\n|--------|-------------|\n| `MellonParam.Input.image(name)` | Image input |\n| `MellonParam.Input.textbox(name, default)` | Text input as textarea |\n| `MellonParam.Input.dropdown(name, options, default)` | Dropdown selection |\n| `MellonParam.Input.slider(name, default, min, max, step)` | Slider for numeric values |\n| `MellonParam.Input.number(name, default, min, max, step)` | Numeric input (no slider) |\n| `MellonParam.Input.seed(name, default)` | Seed input with randomize button |\n| `MellonParam.Input.checkbox(name, default)` | Boolean checkbox |\n| `MellonParam.Input.model(name)` | Model input for diffusers components |\n| `MellonParam.Output.image(name)` | Image output |\n| `MellonParam.Output.video(name)` | Video output |\n| `MellonParam.Output.text(name)` | Text output |\n| `MellonParam.Output.model(name)` | Model output for diffusers components |\n\nChoose one of the methods below to specify a Mellon type.\n\n### Using `metadata` in block definitions\n\nIf you're defining a custom block from scratch, add `metadata={\"mellon\": \"<type>\"}` directly to your `InputParam` and `OutputParam` definitions. 
If you're editing an existing custom block from the Hub, see [Editing custom blocks](./custom_blocks#editing-custom-blocks) for how to download it locally.\n\n```python\nclass GeminiPromptExpander(ModularPipelineBlocks):\n    \n    @property\n    def inputs(self) -> List[InputParam]:\n        return [\n            InputParam(\n                \"prompt\",\n                type_hint=str,\n                required=True,\n                description=\"Prompt to use\",\n                metadata={\"mellon\": \"textbox\"},  # Text input\n            )\n        ]\n    \n    @property\n    def intermediate_outputs(self) -> List[OutputParam]:\n        return [\n            OutputParam(\n                \"prompt\",\n                type_hint=str,\n                description=\"Expanded prompt by the LLM\",\n                metadata={\"mellon\": \"text\"},  # Text output\n            ),\n            OutputParam(\n                \"old_prompt\",\n                type_hint=str,\n                description=\"Old prompt provided by the user\",\n                # No metadata - we don't want to render this in UI\n            )\n        ]\n```\n\nFor full control over UI configuration, pass a `MellonParam` instance directly:\n```python\nfrom diffusers.modular_pipelines.mellon_node_utils import MellonParam\n\nInputParam(\n    \"mode\",\n    type_hint=str,\n    default=\"balanced\",\n    metadata={\"mellon\": MellonParam.Input.dropdown(\"mode\", options=[\"fast\", \"balanced\", \"quality\"])},\n)\n```\n\n### Using `input_types` and `output_types` when Generating Config\n\nIf you're working with an existing pipeline or prefer to keep your block definitions clean, specify types when generating the config using the `input_types/output_types` argument:\n```python\nfrom diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig\n\nmellon_config = MellonPipelineConfig.from_custom_block(\n    blocks,\n    input_types={\"prompt\": \"textbox\"},\n    output_types={\"prompt\": 
\"text\"}\n)\n```\n\n> [!NOTE]\n> When both `metadata` and `input_types`/`output_types` are specified, the arguments overrides `metadata`.\n\n## Generate and push the Mellon config\n\nAfter adding metadata to your block, generate the default Mellon configuration template and push it to the Hub:\n\n```python\nfrom diffusers import ModularPipelineBlocks\nfrom diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig\n\n# load your custom blocks from your local dir\nblocks = ModularPipelineBlocks.from_pretrained(\"/path/local/folder\", trust_remote_code=True)\n\n# Generate the default config template\nmellon_config = MellonPipelineConfig.from_custom_block(blocks)\n# push the default template to `repo_id`, you will need to pass the same local folder path so that it will save the config locally first\nmellon_config.save(\n    local_dir=\"/path/local/folder\",\n    repo_id= repo_id,\n    push_to_hub=True\n)\n```\n\nThis creates a `mellon_pipeline_config.json` file in your repository.\n\n## Review and adjust the config\n\nThe generated template is a starting point - you may want to adjust it for your needs. 
Let's walk through the generated config for the Gemini Prompt Expander:\n\n```json\n{\n  \"label\": \"Gemini Prompt Expander\",\n  \"default_repo\": \"\",\n  \"default_dtype\": \"\",\n  \"node_params\": {\n    \"custom\": {\n      \"params\": {\n        \"prompt\": {\n          \"label\": \"Prompt\",\n          \"type\": \"string\",\n          \"display\": \"textarea\",\n          \"default\": \"\"\n        },\n        \"out_prompt\": {\n          \"label\": \"Prompt\",\n          \"type\": \"string\",\n          \"display\": \"output\"\n        },\n        \"old_prompt\": {\n          \"label\": \"Old Prompt\",\n          \"type\": \"custom\",\n          \"display\": \"output\"\n        },\n        \"doc\": {\n          \"label\": \"Doc\",\n          \"type\": \"string\",\n          \"display\": \"output\"\n        }\n      },\n      \"input_names\": [\"prompt\"],\n      \"model_input_names\": [],\n      \"output_names\": [\"out_prompt\", \"old_prompt\", \"doc\"],\n      \"block_name\": \"custom\",\n      \"node_type\": \"custom\"\n    }\n  }\n}\n```\n\n### Understanding the Structure\n\nThe `params` dict defines how each UI element renders. The `input_names`, `model_input_names`, and `output_names` lists map these UI elements to the underlying [`ModularPipelineBlocks`]'s I/O interface:\n\n| Mellon Config | ModularPipelineBlocks |\n|---------------|----------------------|\n| `input_names` | `inputs` property |\n| `model_input_names` | `expected_components` property |\n| `output_names` | `intermediate_outputs` property |\n\nIn this example: `prompt` is the only input. There are no model components, and outputs include `out_prompt`, `old_prompt`, and `doc`.\n\nNow let's look at the `params` dict:\n\n- **`prompt`**: An input parameter with `display: \"textarea\"` (renders as a text input box), `label: \"Prompt\"` (shown in the UI), and `default: \"\"` (starts empty). 
The `type: \"string\"` field is important in Mellon because it determines which nodes can connect together - only matching types can be linked with \"noodles\".\n\n- **`out_prompt`**: The expanded prompt output. The `out_` prefix was automatically added because the input and output share the same name (`prompt`), avoiding naming conflicts in the config. It has `display: \"output\"` which renders as an output socket.\n\n- **`old_prompt`**: Has `type: \"custom\"` because we didn't specify metadata. This renders as a simple dot in the UI. Since we don't actually want to expose this in the UI, we can remove it.\n\n- **`doc`**: The documentation output, automatically added to all custom blocks.\n\n### Making Adjustments\n\nRemove `old_prompt` from both `params` and `output_names` because you won't need to use it.\n\n```json\n{\n  \"label\": \"Gemini Prompt Expander\",\n  \"default_repo\": \"\",\n  \"default_dtype\": \"\",\n  \"node_params\": {\n    \"custom\": {\n      \"params\": {\n        \"prompt\": {\n          \"label\": \"Prompt\",\n          \"type\": \"string\",\n          \"display\": \"textarea\",\n          \"default\": \"\"\n        },\n        \"out_prompt\": {\n          \"label\": \"Prompt\",\n          \"type\": \"string\",\n          \"display\": \"output\"\n        },\n        \"doc\": {\n          \"label\": \"Doc\",\n          \"type\": \"string\",\n          \"display\": \"output\"\n        }\n      },\n      \"input_names\": [\"prompt\"],\n      \"model_input_names\": [],\n      \"output_names\": [\"out_prompt\", \"doc\"],\n      \"block_name\": \"custom\",\n      \"node_type\": \"custom\"\n    }\n  }\n}\n```\n\nSee the final config at [diffusers/gemini-prompt-expander-mellon](https://huggingface.co/diffusers/gemini-prompt-expander-mellon)."
  },
  {
    "path": "docs/source/en/modular_diffusers/modular_diffusers_states.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# States\n\nBlocks rely on the [`~modular_pipelines.PipelineState`] and [`~modular_pipelines.BlockState`] data structures for communicating and sharing data.\n\n| State | Description |\n|-------|-------------|\n| [`~modular_pipelines.PipelineState`] | Maintains the overall data required for a pipeline's execution and allows blocks to read and update its data. |\n| [`~modular_pipelines.BlockState`] | Allows each block to perform its computation with the necessary data from `inputs`|\n\nThis guide explains how states work and how they connect blocks.\n\n## PipelineState\n\nThe [`~modular_pipelines.PipelineState`] is a global state container for all blocks. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.\n\n[`~modular_pipelines.PipelineState`] stores all data in a `values` dict, which is a **mutable** state containing user provided input values and intermediate output values generated by blocks. 
If a block modifies an `input`, the change is reflected in the `values` dict after calling `set_block_state`.\n\n```py\nPipelineState(\n  values={\n    'prompt': 'a cat'\n    'guidance_scale': 7.0\n    'num_inference_steps': 25\n    'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))\n    'negative_prompt_embeds': None\n  },\n)\n```\n\n## BlockState\n\nThe [`~modular_pipelines.BlockState`] is a local view of the relevant variables an individual block needs from [`~modular_pipelines.PipelineState`] to perform its computations.\n\nAccess these variables directly as attributes like `block_state.image`.\n\n```py\nBlockState(\n    image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494640>\n)\n```\n\nWhen a block's `__call__` method is executed, it retrieves the [`BlockState`] with `self.get_block_state(state)`, performs its operations, and updates [`~modular_pipelines.PipelineState`] with `self.set_block_state(state, block_state)`.\n\n```py\ndef __call__(self, components, state):\n    # retrieve BlockState\n    block_state = self.get_block_state(state)\n\n    # computation logic on inputs\n\n    # update PipelineState\n    self.set_block_state(state, block_state)\n    return components, state\n```\n\n## State interaction\n\nThe interaction between [`~modular_pipelines.PipelineState`] and [`~modular_pipelines.BlockState`] is defined by a block's `inputs` and `intermediate_outputs`.\n\n- `inputs`: a block can modify an input, like `block_state.image`, and this change is propagated globally to [`~modular_pipelines.PipelineState`] by calling `set_block_state`.\n- `intermediate_outputs`: a new variable that a block creates. It is added to the [`~modular_pipelines.PipelineState`]'s `values` dict and is available to subsequent blocks, or to users as a final output from the pipeline.\n"
  },
  {
    "path": "docs/source/en/modular_diffusers/modular_pipeline.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ModularPipeline\n\n[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`] into an executable pipeline that loads models and performs the computation steps defined in the blocks. It is the main interface for running a pipeline and the API is very similar to [`DiffusionPipeline`] but with a few key differences.\n\n- **Loading is lazy.** With [`DiffusionPipeline`], [`~DiffusionPipeline.from_pretrained`] creates the pipeline and loads all models at the same time. With [`ModularPipeline`], creating and loading are two separate steps: [`~ModularPipeline.from_pretrained`] reads the configuration and knows where to load each component from, but doesn't actually load the model weights. You load the models later with [`~ModularPipeline.load_components`], which is where you pass loading arguments like `torch_dtype` and `quantization_config`.\n\n- **Two ways to create a pipeline.** You can use [`~ModularPipeline.from_pretrained`] with an existing diffusers model repository — it automatically maps to the default pipeline blocks and then converts to a [`ModularPipeline`] with no extra setup. You can check the [modular_pipelines_directory](https://github.com/huggingface/diffusers/tree/main/src/diffusers/modular_pipelines) to see which models are currently supported. 
You can also assemble your own pipeline from [`ModularPipelineBlocks`] and convert it with the [`~ModularPipelineBlocks.init_pipeline`] method (see [Creating a pipeline](#creating-a-pipeline) for more details).\n\n- **Running the pipeline is the same.** Once loaded, you call the pipeline with the same arguments you're used to. A single [`ModularPipeline`] can support multiple workflows (text-to-image, image-to-image, inpainting, etc.) when the pipeline blocks use [`AutoPipelineBlocks`](./auto_pipeline_blocks) to automatically select the workflow based on your inputs.\n\nBelow are complete examples for text-to-image, image-to-image, and inpainting with SDXL.\n\n<hfoptions id=\"example\">\n<hfoption id=\"text-to-image\">\n\n```py\nimport torch\nfrom diffusers import ModularPipeline\n\npipeline = ModularPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\")\npipeline.load_components(torch_dtype=torch.float16)\npipeline.to(\"cuda\")\n\nimage = pipeline(prompt=\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\").images[0]\nimage.save(\"modular_t2i_out.png\")\n```\n\n</hfoption>\n<hfoption id=\"image-to-image\">\n\n```py\nimport torch\nfrom diffusers import ModularPipeline\nfrom diffusers.utils import load_image\n\npipeline = ModularPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\")\npipeline.load_components(torch_dtype=torch.float16)\npipeline.to(\"cuda\")\n\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\"\ninit_image = load_image(url)\nprompt = \"a dog catching a frisbee in the jungle\"\nimage = pipeline(prompt=prompt, image=init_image, strength=0.8).images[0]\nimage.save(\"modular_i2i_out.png\")\n```\n\n</hfoption>\n<hfoption id=\"inpainting\">\n\n```py\nimport torch\nfrom diffusers import ModularPipeline\nfrom diffusers.utils import load_image\n\npipeline = 
ModularPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\")\npipeline.load_components(torch_dtype=torch.float16)\npipeline.to(\"cuda\")\n\nimg_url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\"\nmask_url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png\"\n\ninit_image = load_image(img_url)\nmask_image = load_image(mask_url)\n\nprompt = \"A deep sea diver floating\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85).images[0]\nimage.save(\"modular_inpaint_out.png\")\n```\n\n</hfoption>\n</hfoptions>\n\nThis guide will show you how to create a [`ModularPipeline`], manage its components, and run the pipeline.\n\n## Creating a pipeline\n\nThere are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] with [`~ModularPipelineBlocks.init_pipeline`], or load an existing pipeline with [`~ModularPipeline.from_pretrained`].\n\nYou can also initialize a [`ComponentsManager`](./components_manager) to handle device placement and memory management. 
If you don't need automatic offloading, you can skip this and move the pipeline to your device manually with `pipeline.to(\"cuda\")`.\n\n> [!TIP]\n> Refer to the [ComponentsManager](./components_manager) doc for more details about how it can help manage components across different workflows.\n\n### init_pipeline\n\n[`~ModularPipelineBlocks.init_pipeline`] converts any [`ModularPipelineBlocks`] into a [`ModularPipeline`].\n\nLet's define a minimal block to see how it works:\n\n```py\nfrom transformers import CLIPTextModel\nfrom diffusers.modular_pipelines import (\n    ComponentSpec,\n    ModularPipelineBlocks,\n    PipelineState,\n)\n\nclass MyBlock(ModularPipelineBlocks):\n    @property\n    def expected_components(self):\n        return [\n            ComponentSpec(\n                name=\"text_encoder\",\n                type_hint=CLIPTextModel,\n                pretrained_model_name_or_path=\"openai/clip-vit-large-patch14\",\n            ),\n        ]\n\n    def __call__(self, components, state: PipelineState) -> PipelineState:\n        return components, state\n```\n\nCall [`~ModularPipelineBlocks.init_pipeline`] to convert it into a pipeline. The `blocks` attribute on the pipeline is the blocks it was created from — it determines the expected inputs, outputs, and computation logic.\n\n```py\nblock = MyBlock()\npipe = block.init_pipeline()\npipe.blocks\n```\n\n```\nMyBlock {\n  \"_class_name\": \"MyBlock\",\n  \"_diffusers_version\": \"0.37.0.dev0\"\n}\n```\n\n> [!WARNING]\n> Blocks are mutable — you can freely add, remove, or swap blocks before creating a pipeline. However, once a pipeline is created, modifying `pipeline.blocks` won't affect the pipeline because it returns a copy. 
If you want a different block structure, create a new pipeline after modifying the blocks.\n\nWhen you call [`~ModularPipelineBlocks.init_pipeline`] without a repository, it uses the `pretrained_model_name_or_path` defined in the block's [`ComponentSpec`] to determine where to load each component from. Printing the pipeline shows the component loading configuration.\n\n```py\npipe\nModularPipeline {\n  \"_blocks_class_name\": \"MyBlock\",\n  \"_class_name\": \"ModularPipeline\",\n  \"_diffusers_version\": \"0.37.0.dev0\",\n  \"text_encoder\": [\n    null,\n    null,\n    {\n      \"pretrained_model_name_or_path\": \"openai/clip-vit-large-patch14\",\n      \"revision\": null,\n      \"subfolder\": \"\",\n      \"type_hint\": [\n        \"transformers\",\n        \"CLIPTextModel\"\n      ],\n      \"variant\": null\n    }\n  ]\n}\n```\n\nIf you pass a repository to [`~ModularPipelineBlocks.init_pipeline`], it overrides the loading path by matching your block's components against the pipeline config in that repository (`model_index.json` or `modular_model_index.json`).\n\nIn the example below, the `pretrained_model_name_or_path` will be updated to `\"stabilityai/stable-diffusion-xl-base-1.0\"`.\n\n```py\npipe = block.init_pipeline(\"stabilityai/stable-diffusion-xl-base-1.0\")\npipe\nModularPipeline {\n  \"_blocks_class_name\": \"MyBlock\",\n  \"_class_name\": \"ModularPipeline\",\n  \"_diffusers_version\": \"0.37.0.dev0\",\n  \"text_encoder\": [\n    null,\n    null,\n    {\n      \"pretrained_model_name_or_path\": \"stabilityai/stable-diffusion-xl-base-1.0\",\n      \"revision\": null,\n      \"subfolder\": \"text_encoder\",\n      \"type_hint\": [\n        \"transformers\",\n        \"CLIPTextModel\"\n      ],\n      \"variant\": null\n    }\n  ]\n}\n```\n\nIf a component in your block doesn't exist in the repository, it remains `null` and is skipped during [`~ModularPipeline.load_components`].\n\n### from_pretrained\n\n[`~ModularPipeline.from_pretrained`] is a 
convenient way to create a [`ModularPipeline`] without defining blocks yourself.\n\nIt works with three types of repositories.\n\n**A regular diffusers repository.** Pass any supported model repository and it automatically maps to the default pipeline blocks. Currently supported models include SDXL, Wan, Qwen, Z-Image, Flux, and Flux2.\n\n```py\nfrom diffusers import ModularPipeline, ComponentsManager\n\ncomponents = ComponentsManager()\npipeline = ModularPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", components_manager=components\n)\n```\n\n**A modular repository.** These repositories contain a `modular_model_index.json` that specifies where to load each component from — the components can come from different repositories and the modular repository itself may not contain any model weights. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from one repository and the remaining components from another. See [Modular repository](#modular-repository) for more details on the format.\n\n```py\nfrom diffusers import ModularPipeline, ComponentsManager\n\ncomponents = ComponentsManager()\npipeline = ModularPipeline.from_pretrained(\n    \"diffusers/flux2-bnb-4bit-modular\", components_manager=components\n)\n```\n\n**A modular repository with custom code.** Some repositories include custom pipeline blocks alongside the loading configuration. Add `trust_remote_code=True` to load them. See [Custom blocks](./custom_blocks) for how to create your own.\n\n```py\nfrom diffusers import ModularPipeline, ComponentsManager\n\ncomponents = ComponentsManager()\npipeline = ModularPipeline.from_pretrained(\n    \"diffusers/Florence2-image-Annotator\", trust_remote_code=True, components_manager=components\n)\n```\n\n## Loading components\n\nA [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. 
You can load components with [`~ModularPipeline.load_components`].\n\nThis will load all the components that have a valid loading spec.\n\n```py\nimport torch\n\npipeline.load_components(torch_dtype=torch.float16)\n```\n\nYou can also load specific components by name. The example below only loads the `text_encoder`.\n\n```py\npipeline.load_components(names=[\"text_encoder\"], torch_dtype=torch.float16)\n```\n\nAfter loading, printing the pipeline shows which components are loaded — the first two fields change from `null` to the component's library and class.\n\n```py\npipeline\n```\n\n```\n# text_encoder is loaded - shows library and class\n\"text_encoder\": [\n  \"transformers\",\n  \"CLIPTextModel\",\n  { ... }\n]\n\n# unet is not loaded yet - still null\n\"unet\": [\n  null,\n  null,\n  { ... }\n]\n```\n\nLoading keyword arguments like `torch_dtype`, `variant`, `revision`, and `quantization_config` are passed through to `from_pretrained()` for each component. You can pass a single value to apply to all components, or a dict to set per-component values.\n\n```py\n# apply bfloat16 to all components\npipeline.load_components(torch_dtype=torch.bfloat16)\n\n# different dtypes per component\npipeline.load_components(torch_dtype={\"transformer\": torch.bfloat16, \"default\": torch.float32})\n```\n\n[`~ModularPipeline.load_components`] only loads components that haven't been loaded yet and have a valid loading spec. This means if you've already set a component on the pipeline, calling [`~ModularPipeline.load_components`] again won't reload it.\n\n## Updating components\n\n[`~ModularPipeline.update_components`] replaces a component on the pipeline with a new one. When a component is updated, the loading specifications are also updated in the pipeline config and [`~ModularPipeline.load_components`] will skip it on subsequent calls.\n\n### From AutoModel\n\nYou can pass a model object loaded with `AutoModel.from_pretrained()`. 
Models loaded this way are automatically tagged with their loading information.\n\n```py\nfrom diffusers import AutoModel\n\nunet = AutoModel.from_pretrained(\n    \"RunDiffusion/Juggernaut-XL-v9\", subfolder=\"unet\", variant=\"fp16\", torch_dtype=torch.float16\n)\npipeline.update_components(unet=unet)\n```\n\n### From ComponentSpec\n\nUse [`~ModularPipeline.get_component_spec`] to get a copy of the current component specification, modify it, and load a new component.\n\n```py\nunet_spec = pipeline.get_component_spec(\"unet\")\n\n# modify to load from a different repository\nunet_spec.pretrained_model_name_or_path = \"RunDiffusion/Juggernaut-XL-v9\"\n\n# load and update\nunet = unet_spec.load(torch_dtype=torch.float16)\npipeline.update_components(unet=unet)\n```\n\nYou can also create a [`ComponentSpec`] from scratch.\n\nNot all components are loaded from pretrained weights — some are created from a config (listed under `pipeline.config_component_names`). For these, use [`~ComponentSpec.create`] instead of [`~ComponentSpec.load`].\n\n```py\nguider_spec = pipeline.get_component_spec(\"guider\")\nguider_spec.config = {\"guidance_scale\": 5.0}\nguider = guider_spec.create()\npipeline.update_components(guider=guider)\n```\n\nOr simply pass the object directly.\n\n```py\nfrom diffusers.guiders import ClassifierFreeGuidance\n\nguider = ClassifierFreeGuidance(guidance_scale=5.0)\npipeline.update_components(guider=guider)\n```\n\nSee the [Guiders](../using-diffusers/guiders) guide for more details on available guiders and how to configure them.\n\n## Splitting a pipeline into stages\n\nSince blocks are composable, you can take a pipeline apart and reconstruct it into separate pipelines for each stage. 
The example below shows how we can separate the text encoder block from the rest of the pipeline, so you can encode the prompt independently and pass the embeddings to the main pipeline.\n\n```py\nfrom diffusers import ModularPipeline, ComponentsManager\nimport torch\n\ndevice = \"cuda\"\ndtype = torch.bfloat16\nrepo_id = \"black-forest-labs/FLUX.2-klein-4B\"\n\n# get the blocks and separate out the text encoder\nblocks = ModularPipeline.from_pretrained(repo_id).blocks\ntext_block = blocks.sub_blocks.pop(\"text_encoder\")\n\n# use ComponentsManager to handle offloading across multiple pipelines\nmanager = ComponentsManager()\nmanager.enable_auto_cpu_offload(device=device)\n\n# create separate pipelines for each stage\ntext_encoder_pipeline = text_block.init_pipeline(repo_id, components_manager=manager)\npipeline = blocks.init_pipeline(repo_id, components_manager=manager)\n\n# encode text\ntext_encoder_pipeline.load_components(torch_dtype=dtype)\ntext_embeddings = text_encoder_pipeline(prompt=\"a cat\").get_by_kwargs(\"denoiser_input_fields\")\n\n# denoise and decode\npipeline.load_components(torch_dtype=dtype)\noutput = pipeline(\n    **text_embeddings,\n    num_inference_steps=4,\n).images[0]\n```\n\n[`ComponentsManager`] handles memory across multiple pipelines. Unlike the offloading strategies in [`DiffusionPipeline`] that follow a fixed order, [`ComponentsManager`] makes offloading decisions dynamically each time a model forward pass runs, based on the current memory situation. This means it works regardless of how many pipelines you create or what order you run them in. 
See the [ComponentsManager](./components_manager) guide for more details.\n\nIf pipeline stages share components (e.g., the same VAE used for encoding and decoding), you can use [`~ModularPipeline.update_components`] to pass an already-loaded component to another pipeline instead of loading it again.\n\n## Modular repository\n\nA repository is required if the pipeline blocks use *pretrained components*. The repository supplies loading specifications and metadata.\n\n[`ModularPipeline`] works with regular diffusers repositories out of the box. However, you can also create a *modular repository* for more flexibility. A modular repository contains a `modular_model_index.json` file that stores the following 3 elements for each component.\n\n- `library` and `class` show which library the component was loaded from and its class. If `null`, the component hasn't been loaded yet.\n- `loading_specs_dict` contains the information required to load the component, such as the repository and subfolder it is loaded from.\n\nThe key advantage of a modular repository is that components can be loaded from different repositories. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from `diffusers/FLUX.2-dev-bnb-4bit` while loading the remaining components from `black-forest-labs/FLUX.2-dev`.\n\nTo convert a regular diffusers repository into a modular one, create the pipeline using the regular repository, and then push to the Hub. The saved repository will contain a `modular_model_index.json` with all the loading specifications.\n\n```py\nfrom diffusers import ModularPipeline\n\n# load from a regular repo\npipeline = ModularPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\")\n\n# push as a modular repository\npipeline.save_pretrained(\"local/path\", repo_id=\"my-username/sdxl-modular\", push_to_hub=True)\n```\n\nA modular repository can also include custom pipeline blocks as Python code. 
This allows you to share specialized blocks that aren't native to Diffusers. For example, [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator) contains custom blocks alongside the loading configuration:\n\n```\nFlorence2-image-Annotator/\n├── block.py                    # Custom pipeline blocks implementation\n├── config.json                 # Pipeline configuration and auto_map\n├── mellon_config.json          # UI configuration for Mellon\n└── modular_model_index.json    # Component loading specifications\n```\n\nThe `config.json` file contains an `auto_map` key that tells [`ModularPipeline`] where to find the custom blocks:\n\n```json\n{\n  \"_class_name\": \"Florence2AnnotatorBlocks\",\n  \"auto_map\": {\n    \"ModularPipelineBlocks\": \"block.Florence2AnnotatorBlocks\"\n  }\n}\n```\n\nLoad custom code repositories with `trust_remote_code=True` as shown in [from_pretrained](#from_pretrained). See [Custom blocks](./custom_blocks) for how to create and share your own."
  },
  {
    "path": "docs/source/en/modular_diffusers/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Overview\n\n> [!WARNING]\n> Modular Diffusers is under active development and it's API may change.\n\nModular Diffusers is a unified pipeline system that simplifies your workflow with *pipeline blocks*.\n\n- Blocks are reusable and you only need to create new blocks that are unique to your pipeline.\n- Blocks can be mixed and matched to adapt to or create a pipeline for a specific workflow or multiple workflows.\n\nThe Modular Diffusers docs are organized as shown below.\n\n## Quickstart\n\n- The [quickstart](./quickstart) shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.\n\n## ModularPipelineBlocks\n\n- [States](./modular_diffusers_states) explains how data is shared and communicated between blocks and [`ModularPipeline`].\n- [ModularPipelineBlocks](./pipeline_block) is the most basic unit of a [`ModularPipeline`] and this guide shows you how to create one.\n- [SequentialPipelineBlocks](./sequential_pipeline_blocks) is a type of block that chains multiple blocks so they run one after another, passing data along the chain. This guide shows you how to create [`~modular_pipelines.SequentialPipelineBlocks`] and how they connect and work together.\n- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks) is a type of block that runs a series of blocks in a loop. 
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].\n- [AutoPipelineBlocks](./auto_pipeline_blocks) is a type of block that automatically chooses which blocks to run based on the input. This guide shows you how to create [`~modular_pipelines.AutoPipelineBlocks`].\n- [Building Custom Blocks](./custom_blocks) shows you how to create your own custom blocks and share them on the Hub.\n\n## ModularPipeline\n\n- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].\n- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.\n- [Guiders](../using-diffusers/guiders) shows you how to use different guidance methods in the pipeline.\n\n## Mellon Integration\n\n- [Using Custom Blocks with Mellon](./mellon) shows you how to make your custom blocks work with [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows."
  },
  {
    "path": "docs/source/en/modular_diffusers/pipeline_block.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ModularPipelineBlocks\n\n[`~modular_pipelines.ModularPipelineBlocks`] is the basic block for building a [`ModularPipeline`]. It defines what components, inputs/outputs, and computation a block should perform for a specific step in a pipeline. A [`~modular_pipelines.ModularPipelineBlocks`] connects with other blocks, using [state](./modular_diffusers_states), to enable the modular construction of workflows.\n\nA [`~modular_pipelines.ModularPipelineBlocks`] on it's own can't be executed. It is a blueprint for what a step should do in a pipeline. To actually run and execute a pipeline, the [`~modular_pipelines.ModularPipelineBlocks`] needs to be converted into a [`ModularPipeline`].\n\nThis guide will show you how to create a [`~modular_pipelines.ModularPipelineBlocks`].\n\n## Inputs and outputs\n\n> [!TIP]\n> Refer to the [States](./modular_diffusers_states) guide if you aren't familiar with how state works in Modular Diffusers.\n\nA [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermediate_outputs`.\n\n- `inputs` are values a block reads from the [`~modular_pipelines.PipelineState`] to perform its computation. These can be values provided by a user (like a prompt or image) or values produced by a previous block (like encoded `image_latents`). 
\n\n    Use `InputParam` to define `inputs`.\n\n```py\nclass ImageEncodeStep(ModularPipelineBlocks):\n    ...\n\n    @property\n    def inputs(self):\n        return [\n            InputParam(name=\"image\", type_hint=\"PIL.Image\", required=True, description=\"raw input image to process\"),\n        ]\n    ...\n```\n\n- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.\n\n    Use `OutputParam` to define `intermediate_outputs`.\n\n```py\nclass ImageEncodeStep(ModularPipelineBlocks):\n    ...\n\n    @property\n    def intermediate_outputs(self):\n        return [\n            OutputParam(name=\"image_latents\", description=\"latents representing the image\"),\n        ]\n\n    ...\n```\n\nThe intermediate inputs and outputs share data to connect blocks. They are accessible at any point, allowing you to track the workflow's progress.\n\n## Components and configs\n\nThe components and pipeline-level configs a block needs are specified in [`ComponentSpec`] and [`~modular_pipelines.ConfigSpec`].\n\n- [`ComponentSpec`] contains the expected components used by a block. 
You need the `name` of the component and ideally a `type_hint` that specifies exactly what the component is.\n- [`~modular_pipelines.ConfigSpec`] contains pipeline-level settings that control behavior across all blocks.\n\n```py\nclass ImageEncodeStep(ModularPipelineBlocks):\n    ...\n\n    @property\n    def expected_components(self):\n        return [\n            ComponentSpec(name=\"vae\", type_hint=AutoencoderKL),\n        ]\n\n    @property\n    def expected_configs(self):\n        return [\n            ConfigSpec(\"force_zeros_for_empty_prompt\", True),\n        ]\n\n    ...\n```\n\nWhen the blocks are converted into a pipeline, the components become available to the block as the first argument in `__call__`.\n\n## Computation logic\n\nThe computation a block performs is defined in the `__call__` method and it follows a specific structure.\n\n1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`.\n2. Implement the computation logic on the `inputs`.\n3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].\n4. 
Return the components and state which becomes available to the next block.\n\n```py\nclass ImageEncodeStep(ModularPipelineBlocks):\n\n    def __call__(self, components, state):\n        # Get a local view of the state variables this block needs\n        block_state = self.get_block_state(state)\n\n        # Your computation logic here\n        # block_state contains all your inputs\n        # Access them like: block_state.image, block_state.processed_image\n\n        # Update the pipeline state with your updated block_states\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n## Putting it all together\n\nHere is the complete block with all the pieces connected.\n\n```py\nfrom diffusers import ComponentSpec, AutoencoderKL\nfrom diffusers.modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam\n\n\nclass ImageEncodeStep(ModularPipelineBlocks):\n\n    @property\n    def description(self):\n        return \"Encode an image into latent space.\"\n\n    @property\n    def expected_components(self):\n        return [\n            ComponentSpec(name=\"vae\", type_hint=AutoencoderKL),\n        ]\n\n    @property\n    def inputs(self):\n        return [\n            InputParam(name=\"image\", type_hint=\"PIL.Image\", required=True, description=\"raw input image to process\"),\n        ]\n\n    @property\n    def intermediate_outputs(self):\n        return [\n            OutputParam(name=\"image_latents\", type_hint=\"torch.Tensor\", description=\"latents representing the image\"),\n        ]\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        block_state.image_latents = components.vae.encode(block_state.image)\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\nEvery block has a `doc` property that is automatically generated from the properties you defined above. 
It provides a summary of the block's description, components, inputs, and outputs.\n\n```py\nblock = ImageEncodeStep()\nprint(block.doc)\nclass ImageEncodeStep\n\n  Encode an image into latent space.\n\n  Components:\n      vae (`AutoencoderKL`)\n\n  Inputs:\n      image (`PIL.Image`):\n          raw input image to process\n\n  Outputs:\n      image_latents (`torch.Tensor`):\n          latents representing the image\n```"
  },
  {
    "path": "docs/source/en/modular_diffusers/quickstart.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Quickstart\n\nModular Diffusers is a framework for quickly building flexible and customizable pipelines. These pipelines can go beyond what standard `DiffusionPipeline`s can do. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface for running generation tasks.\n\nThis guide shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it.\n\n## Run a pipeline\n\n[`ModularPipeline`] is the main interface for loading, running, and managing modular pipelines.\n```py\nimport torch\nfrom diffusers import ModularPipeline, ComponentsManager\n\n# Use ComponentsManager to enable auto CPU offloading for memory efficiency\nmanager = ComponentsManager()\nmanager.enable_auto_cpu_offload(device=\"cuda:0\")\n\npipe = ModularPipeline.from_pretrained(\"Qwen/Qwen-Image\", components_manager=manager)\npipe.load_components(torch_dtype=torch.bfloat16)\n\nimage = pipe(\n    prompt=\"cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney\",\n).images[0]\nimage\n```\n\n[`~ModularPipeline.from_pretrained`] uses lazy loading - it reads the configuration to learn where to load each component from, but doesn't actually 
load the model weights until you call [`~ModularPipeline.load_components`]. This gives you control over when and how components are loaded.\n\n> [!TIP]\n> `ComponentsManager` with `enable_auto_cpu_offload` automatically moves models between CPU and GPU as needed, reducing memory usage for large models like Qwen-Image. Learn more in the [ComponentsManager](./components_manager) guide.\n>\n> If you don't need offloading, remove the `components_manager` argument and move the pipeline to your device manually with `to(\"cuda\")`.\n\nLearn more about creating and loading pipelines in the [Creating a pipeline](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#creating-a-pipeline) and [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guides.\n\n## Understand the structure\n\nA [`ModularPipeline`] has two parts: a **definition** (the blocks) and a **state** (the loaded components and configs).\n\nPrint the pipeline to see its state — the components and their loading status and configuration.\n```py\nprint(pipe)\n```\n```\nQwenImageModularPipeline {\n  \"_blocks_class_name\": \"QwenImageAutoBlocks\",\n  \"_class_name\": \"QwenImageModularPipeline\",\n  \"_diffusers_version\": \"0.37.0.dev0\",\n  \"transformer\": [\n    \"diffusers\",\n    \"QwenImageTransformer2DModel\",\n    {\n      \"pretrained_model_name_or_path\": \"Qwen/Qwen-Image\",\n      \"revision\": null,\n      \"subfolder\": \"transformer\",\n      \"type_hint\": [\n        \"diffusers\",\n        \"QwenImageTransformer2DModel\"\n      ],\n      \"variant\": null\n    }\n  ],\n  ...\n}\n```\n\nAccess the definition through `pipe.blocks` — this is the [`~modular_pipelines.ModularPipelineBlocks`] that defines the pipeline's workflows, inputs, outputs, and computation logic.\n```py\nprint(pipe.blocks)\n```\n```\nQwenImageAutoBlocks(\n  Class: SequentialPipelineBlocks\n\n  Description: Auto Modular pipeline for text-to-image, 
image-to-image, inpainting, and controlnet tasks using QwenImage.\n      \n      Supported workflows:\n        - `text2image`: requires `prompt`\n        - `image2image`: requires `prompt`, `image`\n        - `inpainting`: requires `prompt`, `mask_image`, `image`\n        - `controlnet_text2image`: requires `prompt`, `control_image`\n        ...\n\n  Components:\n      text_encoder (`Qwen2_5_VLForConditionalGeneration`)\n      vae (`AutoencoderKLQwenImage`)\n      transformer (`QwenImageTransformer2DModel`)\n      ...\n\n  Sub-Blocks:\n    [0] text_encoder (QwenImageAutoTextEncoderStep)\n    [1] vae_encoder (QwenImageAutoVaeEncoderStep)\n    [2] controlnet_vae_encoder (QwenImageOptionalControlNetVaeEncoderStep)\n    [3] denoise (QwenImageAutoCoreDenoiseStep)\n    [4] decode (QwenImageAutoDecodeStep)\n)\n```\n\nThe output returns:\n- The supported workflows (text2image, image2image, inpainting, etc.)\n- The Sub-Blocks it's composed of (text_encoder, vae_encoder, denoise, decode)\n\n### Workflows\n\nThis pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image. Learn more about how this works under the hood in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.\n\n```py\nfrom diffusers.utils import load_image\n\ninput_image = load_image(\"https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true\")\n\nimage = pipe(\n    prompt=\"cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney\",\n    image=input_image,\n).images[0]\n```\n\nUse `get_workflow()` to extract the blocks for a specific workflow. Pass the workflow name (e.g., `\"image2image\"`, `\"inpainting\"`, `\"controlnet_text2image\"`) to get only the blocks relevant to that workflow. 
This is useful when you want to customize or debug a specific workflow. You can check `pipe.blocks.available_workflows` to see all available workflows.\n```py\nimg2img_blocks = pipe.blocks.get_workflow(\"image2image\")\n```\n\n\n### Sub-blocks\n\nBlocks can contain other blocks. `pipe.blocks` gives you the top-level block definition (here, `QwenImageAutoBlocks`), while `sub_blocks` lets you access the smaller blocks inside it.\n\n`QwenImageAutoBlocks` is composed of: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`.\n\nThese sub-blocks run one after another and data flows linearly from one block to the next: each block's `intermediate_outputs` become available as `inputs` to the next block. This is how [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) work.\n\nYou can access them through the `sub_blocks` property. The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.\n```py\nvae_encoder_block = pipe.blocks.sub_blocks[\"vae_encoder\"]\nprint(vae_encoder_block.doc)\n```\n\nThis block can be converted to a pipeline so that it can run on its own with [`~ModularPipelineBlocks.init_pipeline`].\n```py\nvae_encoder_pipe = vae_encoder_block.init_pipeline()\n\n# Reuse the VAE we already loaded by passing it to update_components()\nvae_encoder_pipe.update_components(vae=pipe.vae)\n\n# Run just this block\nimage_latents = vae_encoder_pipe(image=input_image).image_latents\nprint(image_latents.shape)\n```\n\nIt reuses the VAE from our original pipeline instead of reloading it, keeping memory usage efficient. Learn more in the [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guide.\n\nSince blocks are composable, you can modify the pipeline's definition by adding, removing, or swapping blocks to create new workflows. 
In the next section, we'll add a canny edge detection block to a ControlNet pipeline, so you can pass a regular image instead of a pre-processed canny edge map.\n\n## Compose new workflows\n\nLet's add a canny edge detection block to a ControlNet pipeline. First, load a pre-built canny block from the Hub (see [Building Custom Blocks](https://huggingface.co/docs/diffusers/modular_diffusers/custom_blocks) to create your own).\n```py\nfrom diffusers.modular_pipelines import ModularPipelineBlocks\n\n# Load a canny block from the Hub\ncanny_block = ModularPipelineBlocks.from_pretrained(\n    \"diffusers-internal-dev/canny-filtering\",\n    trust_remote_code=True,\n)\n\nprint(canny_block.doc)\n```\n```\nclass CannyBlock\n\n  Inputs:\n      image (`Union[Image, ndarray]`):\n          Image to compute canny filter on\n      low_threshold (`int`, *optional*, defaults to 50):\n          Low threshold for the canny filter.\n      high_threshold (`int`, *optional*, defaults to 200):\n          High threshold for the canny filter.\n      ...\n\n  Outputs:\n      control_image (`PIL.Image`):\n          Canny map for input image\n```\n\nUse `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].\n```py\n# Get the controlnet workflow that we want to work with\nblocks = pipe.blocks.get_workflow(\"controlnet_text2image\")\nprint(blocks.doc)\n```\n```\nclass SequentialPipelineBlocks\n\n  Inputs:\n      prompt (`str`):\n          The prompt or prompts to guide image generation.\n      control_image (`Image`):\n          Control image for ControlNet conditioning.\n      ...\n```\n\n\nThe extracted workflow is a [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) and it currently requires `control_image` as input. 
Insert the canny block at the beginning so the pipeline accepts a regular image instead.\n```py\n# Insert canny at the beginning\nblocks.sub_blocks.insert(\"canny\", canny_block, 0)\n\n# Check the updated structure: CannyBlock is now listed as first sub-block\nprint(blocks)\n# Check the updated doc\nprint(blocks.doc)\n```\n```\nclass SequentialPipelineBlocks\n\n  Inputs:\n      image (`Union[Image, ndarray]`):\n          Image to compute canny filter on\n      low_threshold (`int`, *optional*, defaults to 50):\n          Low threshold for the canny filter.\n      high_threshold (`int`, *optional*, defaults to 200):\n          High threshold for the canny filter.\n      prompt (`str`):\n          The prompt or prompts to guide image generation.\n      ...\n```\n\nNow the pipeline takes `image` as input instead of `control_image`. Because blocks in a sequence share data automatically, the canny block's output (`control_image`) flows to the denoise block that needs it, and the canny block's input (`image`) becomes a pipeline input since no earlier block provides it.\n\nCreate a pipeline from the modified blocks and load a ControlNet model. 
The ControlNet isn't part of the original model repository, so load it separately and add it with [`~ModularPipeline.update_components`].\n```py\npipeline = blocks.init_pipeline(\"Qwen/Qwen-Image\", components_manager=manager)\n\npipeline.load_components(torch_dtype=torch.bfloat16)\n\n# Load the ControlNet model\ncontrolnet_spec = pipeline.get_component_spec(\"controlnet\")\ncontrolnet_spec.pretrained_model_name_or_path = \"InstantX/Qwen-Image-ControlNet-Union\"\ncontrolnet = controlnet_spec.load(torch_dtype=torch.bfloat16)\npipeline.update_components(controlnet=controlnet)\n```\n\nNow run the pipeline - the canny block preprocesses the image for ControlNet.\n```py\nfrom diffusers.utils import load_image\n\nprompt = \"cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney\"\nimage = load_image(\"https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true\")\n\noutput = pipeline(\n    prompt=prompt,\n    image=image,\n).images[0]\noutput\n```\n\n## Next steps\n\n<hfoptions id=\"next\">\n<hfoption id=\"Learn the basics\">\n\nUnderstand the core building blocks of Modular Diffusers:\n\n- [ModularPipelineBlocks](./pipeline_block): The basic unit for defining a step in a pipeline.\n- [SequentialPipelineBlocks](./sequential_pipeline_blocks): Chain blocks to run in sequence.\n- [AutoPipelineBlocks](./auto_pipeline_blocks): Create pipelines that support multiple workflows.\n- [States](./modular_diffusers_states): How data is shared between blocks.\n\n</hfoption>\n<hfoption id=\"Build custom blocks\">\n\nLearn how to create your own blocks with custom logic in the [Building Custom Blocks](./custom_blocks) guide.\n\n</hfoption>\n<hfoption id=\"Share components\">\n\nUse [`ComponentsManager`](./components_manager) to share models across multiple pipelines and manage memory efficiently.\n\n</hfoption>\n<hfoption id=\"Visual interface\">\n\nConnect modular pipelines to [Mellon](https://github.com/cubiq/Mellon), a 
visual node-based interface for building workflows. Custom blocks built with Modular Diffusers work out of the box with Mellon - no UI code required. Read more in the Mellon guide.\n\n</hfoption>\n</hfoptions>"
  },
  {
    "path": "docs/source/en/modular_diffusers/sequential_pipeline_blocks.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# SequentialPipelineBlocks\n\n[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.\n\nThis guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].\n\nCreate two [`~modular_pipelines.ModularPipelineBlocks`]. 
The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.\n\n<hfoptions id=\"sequential\">\n<hfoption id=\"InputBlock\">\n\n```py\nfrom diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam\n\nclass InputBlock(ModularPipelineBlocks):\n\n    @property\n    def inputs(self):\n        return [\n            InputParam(name=\"prompt\", type_hint=list, description=\"list of text prompts\"),\n            InputParam(name=\"num_images_per_prompt\", type_hint=int, description=\"number of images per prompt\"),\n        ]\n\n    @property\n    def intermediate_outputs(self):\n        return [\n            OutputParam(name=\"batch_size\", description=\"calculated batch size\"),\n        ]\n\n    @property\n    def description(self):\n        return \"A block that determines batch_size based on the number of prompts and num_images_per_prompt argument.\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        batch_size = len(block_state.prompt)\n        block_state.batch_size = batch_size * block_state.num_images_per_prompt\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n</hfoption>\n<hfoption id=\"ImageEncoderBlock\">\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam\n\nclass ImageEncoderBlock(ModularPipelineBlocks):\n\n    @property\n    def inputs(self):\n        return [\n            InputParam(name=\"image\", type_hint=\"PIL.Image\", description=\"raw input image to process\"),\n            InputParam(name=\"batch_size\", type_hint=int),\n        ]\n\n    @property\n    def intermediate_outputs(self):\n        return [\n            OutputParam(name=\"image_latents\", description=\"latents representing the image\"),\n        ]\n\n    @property\n    def description(self):\n        return \"Encode raw image into its latent 
representation\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        # Simulate processing the image\n        # This will change the state of the image from a PIL image to a tensor for all blocks\n        block_state.image = torch.randn(1, 3, 512, 512)\n        block_state.batch_size = block_state.batch_size * 2\n        block_state.image_latents = torch.randn(1, 4, 64, 64)\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n</hfoption>\n</hfoptions>\n\nConnect the two blocks by defining a [`~modular_pipelines.SequentialPipelineBlocks`]. List the block instances in `block_classes` and their corresponding names in `block_names`. The blocks are executed in the order they appear in `block_classes`, and data flows from one block to the next through [`~modular_pipelines.PipelineState`].\n\n```py\nfrom diffusers.modular_pipelines import SequentialPipelineBlocks\n\nclass ImageProcessingStep(SequentialPipelineBlocks):\n    \"\"\"\n    # auto_docstring\n    \"\"\"\n    model_name = \"my_model\"\n    block_classes = [InputBlock(), ImageEncoderBlock()]\n    block_names = [\"input\", \"image_encoder\"]\n\n    @property\n    def description(self):\n        return (\n            \"Process text prompts and images for the pipeline. 
It:\\n\"\n            \" - Determines the batch size from the prompts.\\n\"\n            \" - Encodes the image into latent space.\"\n        )\n```\n\nWhen you create a [`~modular_pipelines.SequentialPipelineBlocks`], properties like `inputs`, `intermediate_outputs`, and `expected_components` are automatically aggregated from the sub-blocks, so there is no need to define them again.\n\nThere are a few properties you should set:\n\n- `description`: We recommend adding a description for the assembled block to explain what the combined step does.\n- `model_name`: This is automatically derived from the sub-blocks but isn't always correct, so you may need to override it.\n- `outputs`: By default this is the same as `intermediate_outputs`, but you can manually set it to control which values appear in the doc. This is useful for showing only the final outputs instead of all intermediate values.\n\nThese properties, together with the aggregated `inputs`, `intermediate_outputs`, and `expected_components`, are used to automatically generate the `doc` property.\n\n\nPrint the `ImageProcessingStep` block to inspect its sub-blocks, and use `doc` for a full summary of the block's inputs, outputs, and components.\n\n\n```py\nblocks = ImageProcessingStep()\nprint(blocks)\nprint(blocks.doc)\n```"
  },
  {
    "path": "docs/source/en/optimization/attention_backends.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# Attention backends\n\n> [!NOTE]\n> The attention dispatcher is an experimental feature. Please open an issue if you have any feedback or encounter any problems.\n\nDiffusers provides several optimized attention algorithms that are more memory and computationally efficient through it's *attention dispatcher*. The dispatcher acts as a router for managing and switching between different attention implementations and provides a unified interface for interacting with them.\n\nRefer to the table below for an overview of the available attention families and to the [Available backends](#available-backends) section for a more complete list.\n\n| attention family | main feature |\n|---|---|\n| FlashAttention | minimizes memory reads/writes through tiling and recomputation |\n| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |\n| SageAttention | quantizes attention to int8 |\n| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |\n| xFormers | memory-efficient attention with support for various attention kernels |\n\nThis guide will show you how to set and use the different attention backends.\n\n## set_attention_backend\n\nThe [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate 
attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.\n\nThe example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [`kernels`](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.\n\n> [!NOTE]\n> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend(\"flash\")`.\n\n```py\nimport torch\nfrom diffusers import QwenImagePipeline\n\npipeline = QwenImagePipeline.from_pretrained(\n    \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\npipeline.transformer.set_attention_backend(\"_flash_3_hub\")\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\npipeline(prompt).images[0]\n```\n\nTo restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].\n\n```py\npipeline.transformer.reset_attention_backend()\n```\n\n## attention_backend context manager\n\nThe [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager temporarily sets an attention backend for a model within the context. Outside the context, the default attention (PyTorch's native scaled dot product attention) is used. 
This is useful if you want to use different backends for different parts of a pipeline or if you want to test the different backends.\n\n```py\nimport torch\nfrom diffusers import QwenImagePipeline\nfrom diffusers.models.attention_dispatch import attention_backend\n\npipeline = QwenImagePipeline.from_pretrained(\n    \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\n\nwith attention_backend(\"_flash_3_hub\"):\n    image = pipeline(prompt).images[0]\n```\n\n> [!TIP]\n> Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.\n\n## Checks\n\nThe attention dispatcher includes debugging checks that catch common errors before they cause problems.\n\n1. Device checks verify that query, key, and value tensors live on the same device.\n2. Data type checks confirm tensors have matching dtypes and use either bfloat16 or float16.\n3. Shape checks validate tensor dimensions and prevent mixing attention masks with causal flags.\n\nEnable these checks by setting the `DIFFUSERS_ATTN_CHECKS` environment variable. Checks add overhead to every attention operation, so they're disabled by default. 
\n\n```bash\nexport DIFFUSERS_ATTN_CHECKS=yes\n```\n\nWith the variable set, the checks run before every attention operation.\n\n```py\nimport torch\nfrom diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn\n\nquery = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device=\"cuda\")\nkey = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device=\"cuda\")\nvalue = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device=\"cuda\")\n\ntry:\n    with attention_backend(\"flash\"):\n        output = dispatch_attention_fn(query, key, value)\n        print(\"✓ Flash Attention works with checks enabled\")\nexcept Exception as e:\n    print(f\"✗ Flash Attention failed: {e}\")\n```\n\nYou can also configure the registry directly.\n\n```py\nfrom diffusers.models.attention_dispatch import _AttentionBackendRegistry\n\n_AttentionBackendRegistry._checks_enabled = True\n```\n\n## Available backends\n\nRefer to the table below for a complete list of available attention backends and their variants.\n\n<details>\n<summary>Expand</summary>\n\n| Backend Name | Family | Description |\n|--------------|--------|-------------|\n| `native` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Default backend using PyTorch's scaled_dot_product_attention |\n| `flex` | [FlexAttention](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) | PyTorch FlexAttention implementation |\n| `_native_cudnn` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | CuDNN-optimized attention |\n| `_native_efficient` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Memory-efficient attention |\n| `_native_flash` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | PyTorch's FlashAttention |\n| `_native_math` | 
[PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Math-based attention (fallback) |\n| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |\n| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |\n| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |\n| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |\n| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |\n| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |\n| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |\n| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |\n| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |\n| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |\n| `_flash_3_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 from kernels |\n| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |\n| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |\n| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |\n| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 
QK + FP8 PV (CUDA) |\n| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |\n| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |\n| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |\n| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |\n\n</details>\n"
  },
  {
    "path": "docs/source/en/optimization/cache.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# Caching\n\nCaching accelerates inference by storing and reusing intermediate outputs of different layers, such as attention and feedforward layers, instead of performing the entire computation at each inference step. It significantly improves generation speed at the expense of more memory and doesn't require additional training.\n\nThis guide shows you how to use the caching methods supported in Diffusers.\n\n## Pyramid Attention Broadcast\n\n[Pyramid Attention Broadcast (PAB)](https://huggingface.co/papers/2408.12588) is based on the observation that attention outputs aren't that different between successive timesteps of the generation process. The attention differences are smallest in the cross attention layers and are generally cached over a longer timestep range. This is followed by temporal attention and spatial attention layers.\n\n> [!TIP]\n> Not all video models have three types of attention (cross, temporal, and spatial)!\n\nPAB can be combined with other techniques like sequence parallelism and classifier-free guidance parallelism (data parallelism) for near real-time video generation.\n\nSet up and pass a [`PyramidAttentionBroadcastConfig`] to a pipeline's transformer to enable it. 
The `spatial_attention_block_skip_range` controls how often to skip attention calculations in the spatial attention blocks and the `spatial_attention_timestep_skip_range` is the range of timesteps to skip. Take care to choose an appropriate range because a smaller interval can lead to slower inference speeds and a larger interval can result in lower generation quality.\n\n```python\nimport torch\nfrom diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig\n\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.bfloat16)\npipeline.to(\"cuda\")\n\nconfig = PyramidAttentionBroadcastConfig(\n    spatial_attention_block_skip_range=2,\n    spatial_attention_timestep_skip_range=(100, 800),\n    current_timestep_callback=lambda: pipeline.current_timestep,\n)\npipeline.transformer.enable_cache(config)\n```\n\n## FasterCache\n\n[FasterCache](https://huggingface.co/papers/2410.19355) caches and reuses attention features similar to [PAB](#pyramid-attention-broadcast) since output differences are small for each successive timestep.\n\nThis method may also choose to skip the unconditional branch prediction, when using classifier-free guidance for sampling (common in most base models), and estimate it from the conditional branch prediction if there is significant redundancy in the predicted latent outputs between successive timesteps.\n\nSet up and pass a [`FasterCacheConfig`] to a pipeline's transformer to enable it.\n\n```python\nimport torch\nfrom diffusers import CogVideoXPipeline, FasterCacheConfig\n\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.bfloat16)\npipeline.to(\"cuda\")\n\nconfig = FasterCacheConfig(\n    spatial_attention_block_skip_range=2,\n    spatial_attention_timestep_skip_range=(-1, 681),\n    current_timestep_callback=lambda: pipeline.current_timestep,\n    attention_weight_callback=lambda _: 0.3,\n    unconditional_batch_skip_range=5,\n    
unconditional_batch_timestep_skip_range=(-1, 781),\n    tensor_format=\"BFCHW\",\n)\npipeline.transformer.enable_cache(config)\n```\n\n## FirstBlockCache\n\n[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser change from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16\n)\napply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))\n```\n\n## TaylorSeer Cache\n\n[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.\n\nThis caching mechanism delivers strong results with minimal additional memory overhead. 
For detailed performance analysis, see [our findings here](https://github.com/huggingface/diffusers/pull/12648#issuecomment-3610615080).\n\nTo enable TaylorSeer Cache, create a [`TaylorSeerCacheConfig`] and pass it to your pipeline's transformer:\n\n- `cache_interval`: Number of steps to reuse cached outputs before performing a full forward pass\n- `disable_cache_before_step`: Initial steps that use full computations to gather data for approximations\n- `max_order`: Approximation accuracy (in theory, higher values improve quality at the cost of more memory, but we recommend setting it to `1`)\n\n```python\nimport torch\nfrom diffusers import FluxPipeline, TaylorSeerCacheConfig\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nconfig = TaylorSeerCacheConfig(\n    cache_interval=5,\n    max_order=1,\n    disable_cache_before_step=10,\n    taylor_factors_dtype=torch.bfloat16,\n)\npipe.transformer.enable_cache(config)\n```\n\n## MagCache\n\n[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an \"error budget\" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.\n\nMagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.\n\n### Usage\n\nTo use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.\n\n1.  **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.\n2.  
**Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.\n\n```python\nimport torch\nfrom diffusers import FluxPipeline, MagCacheConfig\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-schnell\",\n    torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\n# 1. Calibration Step\n# Run full inference to measure model behavior.\ncalib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)\npipe.transformer.enable_cache(calib_config)\n\n# Run a prompt to trigger calibration\npipe(\"A cat playing chess\", num_inference_steps=4)\n# Logs will print something like: \"MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]\"\n\n# 2. Inference Step\n# Apply the specific ratios obtained from calibration for optimized speed.\n# Note: For Flux models, you can also import defaults: \n# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS\nmag_config = MagCacheConfig(\n    mag_ratios=[1.0, 1.37, 0.97, 0.87],\n    num_inference_steps=4\n)\n\npipe.transformer.enable_cache(mag_config) \n\nimage = pipe(\"A cat playing chess\", num_inference_steps=4).images[0]\n```\n\n> [!NOTE]\n> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.\n\n> [!TIP]\n> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).\n\n> [!TIP]\n> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. 
The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.\n"
  },
  {
    "path": "docs/source/en/optimization/cache_dit.md",
    "content": "## CacheDiT  \n\nCacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all Diffusers' DiT-based pipelines. It provides a unified cache API that supports automatic block adapter, DBCache, and more.\n\nTo learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.\n\nInstall a stable release of CacheDiT from PyPI or you can install the latest version from GitHub.\n\n<hfoptions id=\"install\">\n<hfoption id=\"PyPI\">\n\n```bash\npip3 install -U cache-dit\n```\n\n</hfoption>\n<hfoption id=\"source\">\n\n```bash\npip3 install git+https://github.com/vipshop/cache-dit.git\n```\n\n</hfoption>\n</hfoptions>\n\nRun the command below to view supported DiT pipelines.\n\n```python\n>>> import cache_dit\n>>> cache_dit.supported_pipelines()\n(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',\n'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',\n'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',\n'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])\n```\n\nFor a complete benchmark, please refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).\n\n\n## Unified Cache API\n\nCacheDiT works by matching specific input/output patterns as shown below.\n\n![](https://github.com/vipshop/cache-dit/raw/main/assets/patterns-v1.png)\n\nCall the `enable_cache()` function on a pipeline to enable cache acceleration. 
This function is the entry point to many of CacheDiT's features.\n\n```python\nimport cache_dit\nfrom diffusers import DiffusionPipeline\n\n# Can be any diffusion pipeline\npipe = DiffusionPipeline.from_pretrained(\"Qwen/Qwen-Image\")\n\n# One-line code with default cache options.\ncache_dit.enable_cache(pipe)\n\n# Just call the pipe as normal.\noutput = pipe(...)\n\n# Disable cache and run original pipe.\ncache_dit.disable_cache(pipe)\n```\n\n## Automatic Block Adapter\n\nFor custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.\n\n\n```python\nimport cache_dit\nfrom cache_dit import ForwardPattern, BlockAdapter\n\n# Use 🔥BlockAdapter with `auto` mode.\ncache_dit.enable_cache(\n    BlockAdapter(\n        # Any DiffusionPipeline, Qwen-Image, etc.\n        pipe=pipe, auto=True,\n        # See the `📚Forward Pattern Matching` documentation. Inspecting the\n        # Qwen-Image code shows that it satisfies `FORWARD_PATTERN_1`.\n        forward_pattern=ForwardPattern.Pattern_1,\n    ),\n)\n\n# Or, manually set up the transformer configuration.\ncache_dit.enable_cache(\n    BlockAdapter(\n        pipe=pipe, # Qwen-Image, etc.\n        transformer=pipe.transformer,\n        blocks=pipe.transformer.transformer_blocks,\n        forward_pattern=ForwardPattern.Pattern_1,\n    ),\n)\n```\n\nSometimes, a Transformer class contains more than one list of transformer blocks. For example, FLUX.1 (as well as HiDream, Chroma, etc.) contains `transformer_blocks` and `single_transformer_blocks` (with different forward patterns). The BlockAdapter is able to detect this hybrid pattern type as well. 
\nRefer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.\n\n```python\n# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and \n# single_transformer_blocks have different forward patterns.\ncache_dit.enable_cache(\n    BlockAdapter(\n        pipe=pipe, # FLUX.1, etc.\n        transformer=pipe.transformer,\n        blocks=[\n            pipe.transformer.transformer_blocks,\n            pipe.transformer.single_transformer_blocks,\n        ],\n        forward_pattern=[\n            ForwardPattern.Pattern_1,\n            ForwardPattern.Pattern_3,\n        ],\n    ),\n)\n```\n\nThis also works if there is more than one transformer (namely `transformer` and `transformer_2`) in its structure. Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.\n\n## Patch Functor\n\nFor any pattern not included in CacheDiT, use the Patch Functor to convert the pattern into a known pattern. You need to subclass the Patch Functor and may also need to fuse the operations within the blocks for loop into block `forward`. 
After implementing a Patch Functor, set the `patch_functor` property in `BlockAdapter`.\n\n![](https://github.com/vipshop/cache-dit/raw/main/assets/patch-functor.png)\n\nSome Patch Functors are already provided in CacheDiT, such as [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py) and [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py).\n\n```python\n@BlockAdapterRegistry.register(\"HiDream\")\ndef hidream_adapter(pipe, **kwargs) -> BlockAdapter:\n    from diffusers import HiDreamImageTransformer2DModel\n    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor\n\n    assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)\n    return BlockAdapter(\n        pipe=pipe,\n        transformer=pipe.transformer,\n        blocks=[\n            pipe.transformer.double_stream_blocks,\n            pipe.transformer.single_stream_blocks,\n        ],\n        forward_pattern=[\n            ForwardPattern.Pattern_0,\n            ForwardPattern.Pattern_3,\n        ],\n        # NOTE: Setup your custom patch functor here.\n        patch_functor=HiDreamPatchFunctor(),\n        **kwargs,\n    )\n```\n\nFinally, you can call the `cache_dit.summary()` function on a pipeline after it has completed inference to get the cache acceleration details.\n\n```python\nstats = cache_dit.summary(pipe)\n```\n\n```python\n⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline\n\n| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |\n|-------------|-----------|-----------|-----------|-----------|-----------|-----------|\n| 23          | 0.045     | 0.084     | 0.114     | 0.147     | 0.241     | 0.297     |\n```\n\n## DBCache: Dual Block Cache  \n\n![](https://github.com/vipshop/cache-dit/raw/main/assets/dbcache-v1.png)\n\nDBCache (Dual Block Caching) supports different 
configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.\n- Fn_compute_blocks: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.\n- Bn_compute_blocks: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.\n\n\n```python\nimport cache_dit\nfrom diffusers import FluxPipeline\n\npipe_or_adapter = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\n# Default options, F8B0, 8 warmup steps, and unlimited cached \n# steps for good balance between performance and precision\ncache_dit.enable_cache(pipe_or_adapter)\n\n# Custom options, F8B8, higher precision\nfrom cache_dit import BasicCacheConfig\n\ncache_dit.enable_cache(\n    pipe_or_adapter,\n    cache_config=BasicCacheConfig(\n        max_warmup_steps=8,  # steps do not cache\n        max_cached_steps=-1, # -1 means no limit\n        Fn_compute_blocks=8, # Fn, F8, etc.\n        Bn_compute_blocks=8, # Bn, B8, etc.\n        residual_diff_threshold=0.12,\n    ),\n)\n```  \nCheck the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.\n\n## TaylorSeer Calibrator\n\nThe [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache in cases where the cached steps are large (Hybrid TaylorSeer + DBCache). At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality. 
\n\nTaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. F_pred can be a residual cache or a hidden-state cache.\n\n```python\nimport cache_dit\nfrom cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig\n\ncache_dit.enable_cache(\n    pipe_or_adapter,\n    # Basic DBCache w/ FnBn configurations\n    cache_config=BasicCacheConfig(\n        max_warmup_steps=8,  # steps do not cache\n        max_cached_steps=-1, # -1 means no limit\n        Fn_compute_blocks=8, # Fn, F8, etc.\n        Bn_compute_blocks=8, # Bn, B8, etc.\n        residual_diff_threshold=0.12,\n    ),\n    # Then, you can use the TaylorSeer Calibrator to approximate\n    # the values in cached steps, taylorseer_order default is 1.\n    calibrator_config=TaylorSeerCalibratorConfig(\n        taylorseer_order=1,\n    ),\n)\n```\n\n> [!TIP]  \n> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so you can choose either `Bn_compute_blocks` > 0 or TaylorSeer. We recommend using the configuration scheme of TaylorSeer + DBCache FnB0.\n\n## Hybrid Cache CFG\n\nCacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG in the forward step, set the `enable_separate_cfg` parameter to `False` (or leave it at its default, `None`). Otherwise, set it to `True`. 
\n\n```python\nfrom cache_dit import BasicCacheConfig\n\ncache_dit.enable_cache(\n    pipe_or_adapter, \n    cache_config=BasicCacheConfig(\n        ...,\n        # For example, set it as True for Wan 2.1, Qwen-Image \n        # and set it as False for FLUX.1, HunyuanVideo, etc.\n        enable_separate_cfg=True,\n    ),\n)\n```\n\n## torch.compile\n\nCacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.\n\n\n```python\ncache_dit.enable_cache(pipe)\n\n# Compile the Transformer module\npipe.transformer = torch.compile(pipe.transformer)\n```\n\nIf you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode. \n\n```python\ntorch._dynamo.config.recompile_limit = 96  # default is 8\ntorch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256\n```\n\nPlease check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.\n"
  },
  {
    "path": "docs/source/en/optimization/coreml.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# How to run Stable Diffusion with Core ML\n\n[Core ML](https://developer.apple.com/documentation/coreml) is the model format and machine learning library supported by Apple frameworks. If you are interested in running Stable Diffusion models inside your macOS or iOS/iPadOS apps, this guide will show you how to convert existing PyTorch checkpoints into the Core ML format and use them for inference with Python or Swift.\n\nCore ML models can leverage all the compute engines available in Apple devices: the CPU, the GPU, and the Apple Neural Engine (or ANE, a tensor-optimized accelerator available in Apple Silicon Macs and modern iPhones/iPads). Depending on the model and the device it's running on, Core ML can mix and match compute engines too, so some portions of the model may run on the CPU while others run on GPU, for example.\n\n> [!TIP]\n> You can also run the `diffusers` Python codebase on Apple Silicon Macs using the `mps` accelerator built into PyTorch. 
This approach is explained in depth in [the mps guide](mps), but it is not compatible with native apps.\n\n## Stable Diffusion Core ML Checkpoints\n\nStable Diffusion weights (or checkpoints) are stored in the PyTorch format, so you need to convert them to the Core ML format before you can use them inside native apps.\n\nThankfully, Apple engineers developed [a conversion tool](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) based on `diffusers` to convert the PyTorch checkpoints to Core ML.\n\nBefore you convert a model, though, take a moment to explore the Hugging Face Hub – chances are the model you're interested in is already available in Core ML format:\n\n- the [Apple](https://huggingface.co/apple) organization includes Stable Diffusion versions 1.4, 1.5, 2.0 base, and 2.1 base\n- [coreml community](https://huggingface.co/coreml-community) includes custom finetuned models\n- use this [filter](https://huggingface.co/models?pipeline_tag=text-to-image&library=coreml&p=2&sort=likes) to return all available Core ML checkpoints\n\nIf you can't find the model you're interested in, we recommend you follow the instructions for [Converting Models to Core ML](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) by Apple.\n\n## Selecting the Core ML Variant to Use\n\nStable Diffusion models can be converted to different Core ML variants intended for different purposes:\n\n- The type of attention blocks used. The attention operation is used to \"pay attention\" to the relationship between different areas in the image representations and to understand how the image and text representations are related. Attention is compute- and memory-intensive, so different implementations exist that consider the hardware characteristics of different devices. 
For Core ML Stable Diffusion models, there are two attention variants:\n    * `split_einsum` ([introduced by Apple](https://machinelearning.apple.com/research/neural-engine-transformers)) is optimized for the ANE, which is available in modern iPhones, iPads, and M-series computers.\n    * The \"original\" attention (the base implementation used in `diffusers`) is only compatible with CPU/GPU and not ANE. It can be *faster* to run your model on CPU + GPU using `original` attention than on the ANE. See [this performance benchmark](https://huggingface.co/blog/fast-mac-diffusers#performance-benchmarks) as well as some [additional measures provided by the community](https://github.com/huggingface/swift-coreml-diffusers/issues/31) for more details.\n\n- The supported inference framework.\n    * `packages` are suitable for Python inference. This can be used to test converted Core ML models before attempting to integrate them inside native apps, or if you want to explore Core ML performance but don't need to support native apps. For example, an application with a web UI could use a Python Core ML backend.\n    * `compiled` models are required for Swift code. The `compiled` models in the Hub split the large UNet model weights into several files for compatibility with iOS and iPadOS devices. This corresponds to the [`--chunk-unet` conversion option](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml). 
If you want to support native apps, then you need to select the `compiled` variant.\n\nThe official Core ML Stable Diffusion [models](https://huggingface.co/apple/coreml-stable-diffusion-v1-4/tree/main) include these variants, but the community ones may vary:\n\n```\ncoreml-stable-diffusion-v1-4\n├── README.md\n├── original\n│   ├── compiled\n│   └── packages\n└── split_einsum\n    ├── compiled\n    └── packages\n```\n\nYou can download and use the variant you need as shown below.\n\n## Core ML Inference in Python\n\nInstall the following libraries to run Core ML inference in Python:\n\n```bash\npip install huggingface_hub\npip install git+https://github.com/apple/ml-stable-diffusion\n```\n\n### Download the Model Checkpoints\n\nTo run inference in Python, use one of the versions stored in the `packages` folders because the `compiled` ones are only compatible with Swift. You may choose whether you want to use `original` or `split_einsum` attention.\n\nThis is how you'd download the `original` attention variant from the Hub to a directory called `models`:\n\n```Python\nfrom huggingface_hub import snapshot_download\nfrom pathlib import Path\n\nrepo_id = \"apple/coreml-stable-diffusion-v1-4\"\nvariant = \"original/packages\"\n\nmodel_path = Path(\"./models\") / (repo_id.split(\"/\")[-1] + \"_\" + variant.replace(\"/\", \"_\"))\nsnapshot_download(repo_id, allow_patterns=f\"{variant}/*\", local_dir=model_path, local_dir_use_symlinks=False)\nprint(f\"Model downloaded at {model_path}\")\n```\n\n### Inference[[python-inference]]\n\nOnce you have downloaded a snapshot of the model, you can test it using Apple's Python script.\n\n```shell\npython -m python_coreml_stable_diffusion.pipeline --prompt \"a photo of an astronaut riding a horse on mars\" -i ./models/coreml-stable-diffusion-v1-4_original_packages/original/packages -o </path/to/output/image> --compute-unit CPU_AND_GPU --seed 93\n```\n\nPass the path of the downloaded checkpoint with `-i` flag to the script. 
`--compute-unit` indicates the hardware you want to allow for inference. It must be one of the following options: `ALL`, `CPU_AND_GPU`, `CPU_ONLY`, `CPU_AND_NE`. You may also provide an optional output path, and a seed for reproducibility.\n\nThe inference script assumes you're using the original version of the Stable Diffusion model, `CompVis/stable-diffusion-v1-4`. If you use another model, you *have* to specify its Hub id in the inference command line, using the `--model-version` option. This works for models already supported and custom models you trained or fine-tuned yourself.\n\nFor example, if you want to use [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5):\n\n```shell\npython -m python_coreml_stable_diffusion.pipeline --prompt \"a photo of an astronaut riding a horse on mars\" --compute-unit ALL -o output --seed 93 -i models/coreml-stable-diffusion-v1-5_original_packages --model-version stable-diffusion-v1-5/stable-diffusion-v1-5\n```\n\n## Core ML inference in Swift\n\nRunning inference in Swift is slightly faster than in Python because the models are already compiled in the `mlmodelc` format. This is noticeable on app startup when the model is loaded but shouldn’t be noticeable if you run several generations afterward.\n\n### Download\n\nTo run inference in Swift on your Mac, you need one of the `compiled` checkpoint versions. 
We recommend you download them locally using Python code similar to the previous example, but with one of the `compiled` variants:\n\n```Python\nfrom huggingface_hub import snapshot_download\nfrom pathlib import Path\n\nrepo_id = \"apple/coreml-stable-diffusion-v1-4\"\nvariant = \"original/compiled\"\n\nmodel_path = Path(\"./models\") / (repo_id.split(\"/\")[-1] + \"_\" + variant.replace(\"/\", \"_\"))\nsnapshot_download(repo_id, allow_patterns=f\"{variant}/*\", local_dir=model_path, local_dir_use_symlinks=False)\nprint(f\"Model downloaded at {model_path}\")\n```\n\n### Inference[[swift-inference]]\n\nTo run inference, please clone Apple's repo:\n\n```bash\ngit clone https://github.com/apple/ml-stable-diffusion\ncd ml-stable-diffusion\n```\n\nAnd then use Apple's command line tool, [Swift Package Manager](https://www.swift.org/package-manager/#):\n\n```bash\nswift run StableDiffusionSample --resource-path models/coreml-stable-diffusion-v1-4_original_compiled --compute-units all \"a photo of an astronaut riding a horse on mars\"\n```\n\nYou have to specify in `--resource-path` one of the checkpoints downloaded in the previous step, so please make sure it contains compiled Core ML bundles with the extension `.mlmodelc`. The `--compute-units` has to be one of these values: `all`, `cpuOnly`, `cpuAndGPU`, `cpuAndNeuralEngine`.\n\nFor more details, please refer to the [instructions in Apple's repo](https://github.com/apple/ml-stable-diffusion).\n\n## Supported Diffusers Features\n\nThe Core ML models and inference code don't support many of the features, options, and flexibility of 🧨 Diffusers. These are some of the limitations to keep in mind:\n\n- Core ML models are only suitable for inference. They can't be used for training or fine-tuning.\n- Only two schedulers have been ported to Swift, the default one used by Stable Diffusion and `DPMSolverMultistepScheduler`, which we ported to Swift from our `diffusers` implementation. 
We recommend you use `DPMSolverMultistepScheduler`, since it produces the same quality in about half the steps.\n- Negative prompts, classifier-free guidance scale, and image-to-image tasks are available in the inference code. Advanced features such as depth guidance, ControlNet, and latent upscalers are not available yet.\n\nApple's [conversion and inference repo](https://github.com/apple/ml-stable-diffusion) and our own [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) repos are intended as technology demonstrators to enable other developers to build upon.\n\nIf you feel strongly about any missing features, please feel free to open a feature request or, better yet, a contribution PR 🙂.\n\n## Native Diffusers Swift app\n\nOne easy way to run Stable Diffusion on your own Apple hardware is to use [our open-source Swift repo](https://github.com/huggingface/swift-coreml-diffusers), based on `diffusers` and Apple's conversion and inference repo. You can study the code, compile it with [Xcode](https://developer.apple.com/xcode/) and adapt it for your own needs. For your convenience, there's also a [standalone Mac app in the App Store](https://apps.apple.com/app/diffusers/id1666309574), so you can play with it without having to deal with the code or IDE. If you are a developer and have determined that Core ML is the best solution to build your Stable Diffusion app, then you can use the rest of this guide to get started with your project. We can't wait to see what you'll build 🙂.\n"
  },
  {
    "path": "docs/source/en/optimization/deepcache.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DeepCache\n[DeepCache](https://huggingface.co/papers/2312.00858) accelerates [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`] by strategically caching and reusing high-level features while efficiently updating low-level features by taking advantage of the U-Net architecture.\n\nStart by installing [DeepCache](https://github.com/horseee/DeepCache):\n```bash\npip install DeepCache\n```\n\nThen load and enable the [`DeepCacheSDHelper`](https://github.com/horseee/DeepCache#usage):\n\n```diff\n  import torch\n  from diffusers import StableDiffusionPipeline\n  pipe = StableDiffusionPipeline.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5', torch_dtype=torch.float16).to(\"cuda\")\n\n+ from DeepCache import DeepCacheSDHelper\n+ helper = DeepCacheSDHelper(pipe=pipe)\n+ helper.set_params(\n+     cache_interval=3,\n+     cache_branch_id=0,\n+ )\n+ helper.enable()\n\n  image = pipe(\"a photo of an astronaut on a moon\").images[0]\n```\n\nThe `set_params` method accepts two arguments: `cache_interval` and `cache_branch_id`. `cache_interval` means the frequency of feature caching, specified as the number of steps between each cache operation. 
`cache_branch_id` identifies which branch of the network (ordered from the shallowest to the deepest layer) is responsible for executing the caching processes.\nOpting for a lower `cache_branch_id` or a larger `cache_interval` can lead to faster inference speed at the expense of reduced image quality (ablation experiments of these two hyperparameters can be found in the [paper](https://huggingface.co/papers/2312.00858)). Once those arguments are set, use the `enable` or `disable` methods to activate or deactivate the `DeepCacheSDHelper`.\n\n<div class=\"flex justify-center\">\n    <img src=\"https://github.com/horseee/Diffusion_DeepCache/raw/master/static/images/example.png\">\n</div>\n\nYou can find more generated samples (original pipeline vs DeepCache) and the corresponding inference latency in the [WandB report](https://wandb.ai/horseee/DeepCache/runs/jwlsqqgt?workspace=user-horseee). The prompts are randomly selected from the [MS-COCO 2017](https://cocodataset.org/#home) dataset.\n\n## Benchmark\n\nWe tested how much faster DeepCache accelerates [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) with 50 inference steps on an NVIDIA RTX A5000, using different configurations for resolution, batch size, cache interval (I), and cache branch (B).\n\n| **Resolution** | **Batch size** | **Original** | **DeepCache(I=3, B=0)** | **DeepCache(I=5, B=0)** | **DeepCache(I=5, B=1)** |\n|----------------|----------------|--------------|-------------------------|-------------------------|-------------------------|\n|             512|               8|         15.96|              6.88(2.32x)|              5.03(3.18x)|              7.27(2.20x)|\n|                |               4|          8.39|              3.60(2.33x)|              2.62(3.21x)|              3.75(2.24x)|\n|                |               1|          2.61|              1.12(2.33x)|              0.81(3.24x)|              1.11(2.35x)|\n|             768|               8|         
43.58|             18.99(2.29x)|             13.96(3.12x)|             21.27(2.05x)|\n|                |               4|         22.24|              9.67(2.30x)|              7.10(3.13x)|             10.74(2.07x)|\n|                |               1|          6.33|              2.72(2.33x)|              1.97(3.21x)|              2.98(2.12x)|\n|            1024|               8|        101.95|             45.57(2.24x)|             33.72(3.02x)|             53.00(1.92x)|\n|                |               4|         49.25|             21.86(2.25x)|             16.19(3.04x)|             25.78(1.91x)|\n|                |               1|         13.83|              6.07(2.28x)|              4.43(3.12x)|              7.15(1.93x)|\n"
  },
  {
    "path": "docs/source/en/optimization/fp16.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Accelerate inference\n\nDiffusion models are slow at inference because generation is an iterative process where noise is gradually refined into an image or video over a certain number of \"steps\". To speedup this process, you can try experimenting with different [schedulers](../api/schedulers/overview), reduce the precision of the model weights for faster computations, use more memory-efficient attention mechanisms, and more.\n\nCombine and use these techniques together to make inference faster than using any single technique on its own.\n\nThis guide will go over how to accelerate inference.\n\n## Model data type\n\nThe precision and data type of the model weights affect inference speed because a higher precision requires more memory to load and more time to perform the computations. PyTorch loads model weights in float32 or full precision by default, so changing the data type is a simple way to quickly get faster inference.\n\n<hfoptions id=\"dtypes\">\n<hfoption id=\"bfloat16\">\n\nbfloat16 is similar to float16 but it is more robust to numerical errors. 
Hardware support for bfloat16 varies, but most modern GPUs are capable of supporting bfloat16.\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\n</hfoption>\n<hfoption id=\"float16\">\n\nfloat16 is similar to bfloat16 but may be more prone to numerical errors.\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\n</hfoption>\n<hfoption id=\"TensorFloat-32\">\n\n[TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode is supported on NVIDIA Ampere GPUs and it computes the convolution and matrix multiplication operations in tf32. Storage and other operations are kept in float32. 
This enables significantly faster computations when combined with bfloat16 or float16.\n\nPyTorch only enables tf32 mode for convolutions by default and you'll need to explicitly enable it for matrix multiplications.\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\ntorch.backends.cuda.matmul.allow_tf32 = True\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\nRefer to the [mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#mixed-precision) docs for more details.\n\n</hfoption>\n</hfoptions>\n\n## Scaled dot product attention\n\n> [!TIP]\n> Memory-efficient attention optimizes for inference speed *and* [memory usage](./memory#memory-efficient-attention)!\n\n[Scaled dot product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) implements several attention backends, [FlashAttention](https://github.com/Dao-AILab/flash-attention), [xFormers](https://github.com/facebookresearch/xformers), and a native C++ implementation. It automatically selects the most optimal backend for your hardware.\n\nSDPA is enabled by default if you're using PyTorch >= 2.0 and no additional changes are required to your code. You could try experimenting with other attention backends though if you'd like to choose your own. 
The example below uses the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable efficient attention.\n\n```py\nimport torch\nfrom torch.nn.attention import SDPBackend, sdpa_kernel\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\nwith sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):\n    image = pipeline(prompt, num_inference_steps=30).images[0]\n```\n\n## torch.compile\n\n[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) accelerates inference by compiling PyTorch code and operations into optimized kernels. In Diffusers, you typically compile the more compute-intensive models like the UNet, transformer, or VAE.\n\nEnable the following compiler settings for maximum speed (refer to the [full list](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py) for more options).\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\ntorch._inductor.config.conv_1x1_as_mm = True\ntorch._inductor.config.coordinate_descent_tuning = True\ntorch._inductor.config.epilogue_fusion = False\ntorch._inductor.config.coordinate_descent_check_all_directions = True\n```\n\nLoad and compile the UNet and VAE. There are several different modes you can choose from, but `\"max-autotune\"` optimizes for the fastest speed by compiling to a CUDA graph. CUDA graphs reduce overhead by launching multiple GPU operations through a single CPU operation.\n\n> [!TIP]\n> With PyTorch 2.3.1, you can control the caching behavior of torch.compile. 
This is particularly beneficial for compilation modes like `\"max-autotune\"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.\n\nChanging the memory layout to [channels_last](./memory#torchchannels_last) also optimizes memory and inference speed.\n\n```py\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.unet.to(memory_format=torch.channels_last)\npipeline.vae.to(memory_format=torch.channels_last)\npipeline.unet = torch.compile(\n    pipeline.unet, mode=\"max-autotune\", fullgraph=True\n)\npipeline.vae.decode = torch.compile(\n    pipeline.vae.decode,\n    mode=\"max-autotune\",\n    fullgraph=True\n)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\nCompilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.\n\n### Dynamic shape compilation\n\n> [!TIP]\n> Make sure to always use the nightly version of PyTorch for better support.\n\n`torch.compile` keeps track of input shapes and conditions, and if these are different, it recompiles the model. 
For example, if a model is compiled on a 1024x1024 resolution image and used on an image with a different resolution, it triggers recompilation.\n\nAdd `dynamic=True` to try to generate a more dynamic kernel that avoids recompilation when conditions change.\n\n```diff\n+ torch.fx.experimental._config.use_duck_shape = False\n+ pipeline.unet = torch.compile(\n    pipeline.unet, fullgraph=True, dynamic=True\n)\n```\n\nSetting `use_duck_shape=False` tells the compiler not to reuse the same symbolic variable for input sizes that happen to be equal. For more details, check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).\n\nNot all models benefit from dynamic compilation out of the box, and some may require changes. Refer to this [PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the [`AuraFlowPipeline`] implementation to benefit from dynamic compilation.\n\nFeel free to open an issue if dynamic compilation doesn't work as expected for a Diffusers model.\n\n### Regional compilation\n\n[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.\nFor many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8–10x.\n\nUse the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.\n\n```py\n# pip install -U diffusers\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\n# compile only the repeated transformer layers 
inside the UNet\npipeline.unet.compile_repeated_blocks(fullgraph=True)\n```\n\nTo enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.\n\n```py\nclass MyUNet(ModelMixin):\n    _repeated_blocks = (\"Transformer2DModel\",)  # ← compiled by default\n```\n\n> [!TIP]\n> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).\n\nThere is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.\n\n```py\n# pip install -U accelerate\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom accelerate.utils import compile_regions\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.unet = compile_regions(pipeline.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.\n\n### Graph breaks\n\nIt is important to specify `fullgraph=True` in torch.compile to ensure there are no graph breaks in the underlying model. This allows you to take advantage of torch.compile without any performance degradation. 
For the UNet and VAE, this changes how you access the return variables.\n\n```diff\n- latents = unet(\n-   latents, timestep=timestep, encoder_hidden_states=prompt_embeds\n-).sample\n\n+ latents = unet(\n+   latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False\n+)[0]\n```\n\n### GPU sync\n\nThe `step()` function is [called](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) on the scheduler each time after the denoiser makes a prediction, and the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). When placed on the GPU, it introduces latency because of the communication sync between the CPU and GPU. It becomes more evident when the denoiser has already been compiled.\n\nIn general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.\n\n> [!TIP]\n> Refer to the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post for maximizing performance with `torch.compile` for diffusion models.\n\n### Benchmarks\n\nRefer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/benchmarks) dataset to see inference latency and memory usage data for compiled pipelines.\n\nThe [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.\n\n## Dynamic quantization\n\n[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by 
reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.\n\nThe example below applies [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to the UNet and VAE with the [torchao](../quantization/torchao) library.\n\n> [!TIP]\n> Refer to our [torchao](../quantization/torchao) docs to learn more about how to use the Diffusers torchao integration.\n\nConfigure the compiler tags for maximum speed.\n\n```py\nimport torch\nfrom torchao import apply_dynamic_quant\nfrom diffusers import StableDiffusionXLPipeline\n\ntorch._inductor.config.conv_1x1_as_mm = True\ntorch._inductor.config.coordinate_descent_tuning = True\ntorch._inductor.config.epilogue_fusion = False\ntorch._inductor.config.coordinate_descent_check_all_directions = True\ntorch._inductor.config.force_fuse_int_mm_with_mul = True\ntorch._inductor.config.use_mixed_mm = True\n```\n\nFilter out some linear layers in the UNet and VAE which don't benefit from dynamic quantization with the [dynamic_quant_filter_fn](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16).\n\n```py\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\napply_dynamic_quant(pipeline.unet, dynamic_quant_filter_fn)\napply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\n## Fused projection matrices\n\n> [!WARNING]\n> The [fuse_qkv_projections](https://github.com/huggingface/diffusers/blob/58431f102cf39c3c8a569f32d71b2ea8caa461e1/src/diffusers/pipelines/pipeline_utils.py#L2034) 
method is experimental and support is mostly limited to Stable Diffusion pipelines. Take a look at this [PR](https://github.com/huggingface/diffusers/pull/6179) to learn more about how to enable it for other pipelines.\n\nAn input is projected into three subspaces, represented by the projection matrices Q, K, and V, in an attention block. These projections are typically calculated separately, but you can horizontally combine these into a single matrix and perform the projection in a single step. It increases the size of the matrix multiplications of the input projections and also improves the impact of quantization.\n\n```py\npipeline.fuse_qkv_projections()\n```\n\n## Resources\n\n- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).\n\n    These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).\n- Read the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post to maximize performance when using `torch.compile`."
  },
  {
    "path": "docs/source/en/optimization/habana.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Intel Gaudi\n\nThe Intel Gaudi AI accelerator family includes [Intel Gaudi 1](https://habana.ai/products/gaudi/), [Intel Gaudi 2](https://habana.ai/products/gaudi2/), and [Intel Gaudi 3](https://habana.ai/products/gaudi3/). Each server is equipped with 8 devices, known as Habana Processing Units (HPUs), providing 128GB of memory on Gaudi 3, 96GB on Gaudi 2, and 32GB on the first-gen Gaudi. 
For more details on the underlying hardware architecture, check out the [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html) overview.\n\nDiffusers pipelines can take advantage of HPU acceleration, even if a pipeline hasn't been added to [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index) yet, with the [GPU Migration Toolkit](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Model_Porting/GPU_Migration_Toolkit/GPU_Migration_Toolkit.html).\n\nCall `.to(\"hpu\")` on your pipeline to move it to an HPU device as shown below for Flux:\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16)\npipeline.to(\"hpu\")\n\nimage = pipeline(\"An image of a squirrel in Picasso style\").images[0]\n```\n\n> [!TIP]\n> For Gaudi-optimized diffusion pipeline implementations, we recommend using [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index).\n"
  },
  {
    "path": "docs/source/en/optimization/memory.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Reduce memory usage\n\nModern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipelines/wan) have billions of parameters that take up a lot of memory on your hardware for inference. This is challenging because common GPUs often don't have sufficient memory. To overcome the memory limitations, you can use more than one GPU (if available), offload some of the pipeline components to the CPU, and more.\n\nThis guide will show you how to reduce your memory usage.\n\n> [!TIP]\n> Keep in mind these techniques may need to be adjusted depending on the model. For example, a transformer-based diffusion model may not benefit from these memory optimizations as much as a UNet-based model.\n\n## Multiple GPUs\n\nIf you have access to more than one GPU, there are a few options for efficiently loading and distributing a large model across your hardware. These features are supported by the [Accelerate](https://huggingface.co/docs/accelerate/index) library, so make sure it is installed first.\n\n```bash\npip install -U accelerate\n```\n\n### Sharded checkpoints\n\nLoading large checkpoints in several shards is useful because the shards are loaded one at a time. This keeps memory usage low, only requiring enough memory for the model size and the largest shard size. We recommend sharding when the fp32 checkpoint is greater than 5GB. 
The default shard size is 5GB.\n\nShard a checkpoint in [`~DiffusionPipeline.save_pretrained`] with the `max_shard_size` parameter.\n\n```py\nfrom diffusers import AutoModel\n\nunet = AutoModel.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"unet\"\n)\nunet.save_pretrained(\"sdxl-unet-sharded\", max_shard_size=\"5GB\")\n```\n\nNow you can use the sharded checkpoint, instead of the regular checkpoint, to save memory.\n\n```py\nimport torch\nfrom diffusers import AutoModel, StableDiffusionXLPipeline\n\nunet = AutoModel.from_pretrained(\n    \"username/sdxl-unet-sharded\", torch_dtype=torch.float16\n)\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    unet=unet,\n    torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\n### Device placement\n\n> [!WARNING]\n> Device placement is an experimental feature and the API may change. Only the `balanced` strategy is supported at the moment. We plan to support additional mapping strategies in the future.\n\nThe `device_map` parameter controls how the model components in a pipeline or the layers in an individual model are distributed across devices. \n\n<hfoptions id=\"device-map\">\n<hfoption id=\"pipeline level\">\n\nThe `balanced` device placement strategy evenly splits the pipeline across all available devices.\n\n```py\nimport torch\nfrom diffusers import AutoModel, StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"balanced\"\n)\n```\n\nYou can inspect a pipeline's device map with `hf_device_map`.\n\n```py\nprint(pipeline.hf_device_map)\n{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}\n```\n\n</hfoption>\n<hfoption id=\"model level\">\n\nThe `device_map` is useful for loading large models, such as the Flux diffusion transformer which has 12.5B parameters. 
Set it to `\"auto\"` to automatically distribute a model across the fastest device first before moving to slower devices. Refer to the [Model sharding](../training/distributed_inference#model-sharding) docs for more details.\n\n```py\nimport torch\nfrom diffusers import AutoModel\n\ntransformer = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", \n    subfolder=\"transformer\",\n    device_map=\"auto\",\n    torch_dtype=torch.bfloat16\n)\n```\n\nYou can inspect a model's device map with `hf_device_map`.\n\n```py\nprint(transformer.hf_device_map)\n```\n\n</hfoption>\n</hfoptions>\n\nWhen designing your own `device_map`, it should be a dictionary of a model's specific module name or layer and a device identifier (an integer for GPUs, `cpu` for CPUs, and `disk` for disk).\n\nCall `hf_device_map` on a model to see how model layers are distributed and then design your own.\n\n```py\nprint(transformer.hf_device_map)\n{'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 'cpu', 'single_transformer_blocks.11': 'cpu', 'single_transformer_blocks.12': 'cpu', 'single_transformer_blocks.13': 'cpu', 'single_transformer_blocks.14': 'cpu', 'single_transformer_blocks.15': 'cpu', 'single_transformer_blocks.16': 'cpu', 'single_transformer_blocks.17': 'cpu', 'single_transformer_blocks.18': 'cpu', 'single_transformer_blocks.19': 'cpu', 'single_transformer_blocks.20': 'cpu', 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 
'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'}\n```\n\nFor example, the `device_map` below places `single_transformer_blocks.10` through `single_transformer_blocks.20` on a second GPU (`1`).\n\n```py\nimport torch\nfrom diffusers import AutoModel\n\ndevice_map = {\n    'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 1, 'single_transformer_blocks.11': 1, 'single_transformer_blocks.12': 1, 'single_transformer_blocks.13': 1, 'single_transformer_blocks.14': 1, 'single_transformer_blocks.15': 1, 'single_transformer_blocks.16': 1, 'single_transformer_blocks.17': 1, 'single_transformer_blocks.18': 1, 'single_transformer_blocks.19': 1, 'single_transformer_blocks.20': 1, 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 
'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'\n}\n\ntransformer = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", \n    subfolder=\"transformer\",\n    device_map=device_map,\n    torch_dtype=torch.bfloat16\n)\n```\n\nPass a dictionary mapping maximum memory usage to each device to enforce a limit. If a device is not in `max_memory`, it is ignored and pipeline components won't be distributed to it.\n\n```py\nimport torch\nfrom diffusers import AutoModel, StableDiffusionXLPipeline\n\nmax_memory = {0: \"1GB\", 1: \"1GB\"}\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n    max_memory=max_memory\n)\n```\n\nDiffusers uses the maximum memory of all devices by default, but if the models don't fit on the GPUs, then you'll need to use a single GPU and offload to the CPU with the methods below.\n\n- [`~DiffusionPipeline.enable_model_cpu_offload`] only works on a single GPU but a very large model may not fit on it\n- [`~DiffusionPipeline.enable_sequential_cpu_offload`] may work but it is extremely slow and also limited to a single GPU\n\nUse the [`~DiffusionPipeline.reset_device_map`] method to reset the `device_map`. This is necessary if you want to use methods like `.to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.\n\n```py\npipeline.reset_device_map()\n```\n\n## VAE slicing\n\nVAE slicing saves memory by splitting a large batch of inputs into single-item batches and processing them separately. 
This method works best when generating more than one image at a time.\n\nFor example, if you're generating 4 images at once, decoding would increase peak activation memory by 4x. VAE slicing reduces this by only decoding 1 image at a time instead of all 4 images at once.\n\nCall [`~StableDiffusionPipeline.enable_vae_slicing`] to enable sliced VAE. You can expect a small increase in performance when decoding multi-image batches and no performance impact for single-image batches.\n\n```py\nimport torch\nfrom diffusers import AutoModel, StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline.enable_vae_slicing()\npipeline([\"An astronaut riding a horse on Mars\"]*32).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n> [!WARNING]\n> The [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] classes don't support slicing.\n\n## VAE tiling\n\nVAE tiling saves memory by dividing an image into smaller overlapping tiles instead of processing the entire image at once. This also reduces peak memory usage because the GPU is only processing a tile at a time.\n\nCall [`~StableDiffusionPipeline.enable_vae_tiling`] to enable VAE tiling. The generated image may have some tone variation from tile-to-tile because they're decoded separately, but there shouldn't be any obvious seams between the tiles. Tiling is disabled for resolutions lower than a pre-specified (but configurable) limit. 
For example, this limit is 512x512 for the VAE in [`StableDiffusionPipeline`].\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.enable_vae_tiling()\n\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png\")\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, image=init_image, strength=0.5).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n> [!WARNING]\n> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.\n\n## Offloading\n\nOffloading strategies move layers or models that aren't currently active to the CPU to avoid increasing GPU memory usage. These strategies can be combined with quantization and torch.compile to balance inference speed and memory usage.\n\nRefer to the [Compile and offloading quantized models](./speed-memory-optims) guide for more details.\n\n### CPU offloading\n\nCPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.\n\nCPU offloading dramatically reduces memory usage, but it is also **extremely slow** because submodules are passed back and forth multiple times between devices. 
It can often be impractical due to how slow it is.\n\n> [!WARNING]\n> Don't move the pipeline to CUDA before calling [`~DiffusionPipeline.enable_sequential_cpu_offload`], otherwise the memory savings are minimal (refer to this [issue](https://github.com/huggingface/diffusers/issues/1934) for more details). This is a stateful operation that installs hooks on the model.\n\nCall [`~DiffusionPipeline.enable_sequential_cpu_offload`] to enable it on a pipeline.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n)\npipeline.enable_sequential_cpu_offload()\n\npipeline(\n    prompt=\"An astronaut riding a horse on Mars\",\n    guidance_scale=0.,\n    height=768,\n    width=1360,\n    num_inference_steps=4,\n    max_sequence_length=256,\n).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n### Model offloading\n\nModel offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline models, usually the text encoder, UNet, or VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stay on the GPU until they're completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is that the memory savings aren't as large.\n\n> [!WARNING]\n> Keep in mind that if models are reused outside the pipeline after hooks have been installed (see [Removing Hooks](https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module) for more details), you need to run the entire pipeline and models in the expected order to properly offload them. 
This is a stateful operation that installs hooks on the model.\n\nCall [`~DiffusionPipeline.enable_model_cpu_offload`] to enable it on a pipeline.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n)\npipeline.enable_model_cpu_offload()\n\npipeline(\n    prompt=\"An astronaut riding a horse on Mars\",\n    guidance_scale=0.,\n    height=768,\n    width=1360,\n    num_inference_steps=4,\n    max_sequence_length=256,\n).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n[`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoder's hidden states.\n\n### Group offloading\n\nGroup offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and it is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.\n\n> [!WARNING]\n> Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.\n\nEnable group offloading by setting the `offload_type` parameter to `block_level` or `leaf_level`.\n\n- `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). 
This drastically reduces memory requirements.\n- `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading). But it can be made faster if you use streams without giving up inference speed.\n\nGroup offloading is supported for entire pipelines or individual models. Applying group offloading to the entire pipeline is the easiest option while selectively applying it to individual models gives users more flexibility to use different offloading techniques for different models.\n\n<hfoptions id=\"group-offloading\">\n<hfoption id=\"pipeline\">\n\nCall [`~DiffusionPipeline.enable_group_offload`] on a pipeline.\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline\nfrom diffusers.hooks import apply_group_offloading\nfrom diffusers.utils import export_to_video\n\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\n\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.bfloat16)\npipeline.enable_group_offload(\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True\n)\n\nprompt = (\n    \"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. \"\n    \"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other \"\n    \"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, \"\n    \"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. 
\"\n    \"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical \"\n    \"atmosphere of this unique musical performance.\"\n)\nvideo = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\n</hfoption>\n<hfoption id=\"model\">\n\nCall [`~ModelMixin.enable_group_offload`] on standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline\nfrom diffusers.hooks import apply_group_offloading\nfrom diffusers.utils import export_to_video\n\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.bfloat16)\n\n# Use the enable_group_offload method for Diffusers model implementations\npipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=\"leaf_level\")\npipeline.vae.enable_group_offload(onload_device=onload_device, offload_type=\"leaf_level\")\n\n# Use the apply_group_offloading method for other model components\napply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=\"block_level\", num_blocks_per_group=2)\n\nprompt = (\n    \"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. \"\n    \"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other \"\n    \"pandas gather, watching curiously and some clapping in rhythm. 
Sunlight filters through the tall bamboo, \"\n    \"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. \"\n    \"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical \"\n    \"atmosphere of this unique musical performance.\"\n)\nvideo = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\n</hfoption>\n</hfoptions>\n\n#### CUDA stream\n\nThe `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.\n\nSet `record_stream=True` for more of a speedup at the cost of slightly increased memory usage. Refer to the [torch.Tensor.record_stream](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) docs to learn more.\n\n> [!TIP]\n> When `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possible with dummy inputs as well) before inference to avoid device mismatch errors. 
This may not work on all implementations, so feel free to open an issue if you encounter any problems.\n\nIf you're using `block_level` group offloading with `use_stream` enabled, the `num_blocks_per_group` parameter should be set to `1`, otherwise a warning will be raised.\n\n```py\npipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=\"leaf_level\", use_stream=True, record_stream=True)\n```\n\nThe `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.\n\n#### Offloading to disk\n\nGroup offloading can consume significant system memory depending on the model size. On systems with limited memory, try group offloading onto the disk as a secondary memory.\n\nSet the `offload_to_disk_path` argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`] to offload the model to the disk.\n\n```py\npipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=\"leaf_level\", offload_to_disk_path=\"path/to/disk\")\n\napply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=\"block_level\", num_blocks_per_group=2, offload_to_disk_path=\"path/to/disk\")\n```\n\nRefer to these [two](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) [tables](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) to compare the speed and memory trade-offs.\n\n## Layerwise casting\n\n> [!TIP]\n> Combine layerwise casting with [group offloading](#group-offloading) for even more memory savings.\n\nLayerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and 
`torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.\n\n> [!WARNING]\n> Layerwise casting may not work with all models if the forward implementation contains internal typecasting of weights. The current implementation of layerwise casting assumes the forward pass is independent of the weight precision and the input datatypes are always specified in `compute_dtype` (see [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299) for an incompatible implementation).\n>\n> Layerwise casting may also fail on custom modeling implementations with [PEFT](https://huggingface.co/docs/peft/index) layers. There are some checks available but they are not extensively tested or guaranteed to work in all cases.\n\nCall [`~ModelMixin.enable_layerwise_casting`] to set the storage and computation datatypes.\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel\nfrom diffusers.utils import export_to_video\n\ntransformer = CogVideoXTransformer3DModel.from_pretrained(\n    \"THUDM/CogVideoX-5b\",\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16\n)\ntransformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)\n\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\",\n    transformer=transformer,\n    torch_dtype=torch.bfloat16\n).to(\"cuda\")\nprompt = (\n    \"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. \"\n    \"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other \"\n    \"pandas gather, watching curiously and some clapping in rhythm. 
Sunlight filters through the tall bamboo, \"\n    \"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. \"\n    \"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical \"\n    \"atmosphere of this unique musical performance.\"\n)\nvideo = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\nThe [`~hooks.apply_layerwise_casting`] method can also be used if you need more control and flexibility. It can be partially applied to model layers by calling it on specific internal modules. Use the `skip_modules_pattern` or `skip_modules_classes` parameters to specify modules to avoid, such as the normalization and modulation layers.\n\n```python\nimport torch\nfrom diffusers import CogVideoXTransformer3DModel\nfrom diffusers.hooks import apply_layerwise_casting\n\ntransformer = CogVideoXTransformer3DModel.from_pretrained(\n    \"THUDM/CogVideoX-5b\",\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16\n)\n\n# skip modules whose names match \"norm\" (the normalization layers)\napply_layerwise_casting(\n    transformer,\n    storage_dtype=torch.float8_e4m3fn,\n    compute_dtype=torch.bfloat16,\n    skip_modules_pattern=[\"norm\"],\n    non_blocking=True,\n)\n```\n\n## torch.channels_last\n\n[torch.channels_last](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) flips how tensors are stored from `(batch size, channels, height, width)` to `(batch size, height, width, channels)`. 
This aligns the tensors with how the hardware sequentially accesses the tensors stored in memory and avoids skipping around in memory to access the pixel values.\n\nNot all operators currently support the channels-last format, so it may result in worse performance, but it is still worth trying.\n\n```py\nprint(pipeline.unet.conv_out.state_dict()[\"weight\"].stride())  # (2880, 9, 3, 1)\npipeline.unet.to(memory_format=torch.channels_last)  # in-place operation\nprint(\n    pipeline.unet.conv_out.state_dict()[\"weight\"].stride()\n)  # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works\n```\n\n## Memory-efficient attention\n\nDiffusers supports multiple memory-efficient attention backends (FlashAttention, xFormers, SageAttention, and more) through [`~ModelMixin.set_attention_backend`]. Refer to the [Attention backends](./attention_backends) guide to learn how to switch between them.\n"
  },
  {
    "path": "docs/source/en/optimization/mps.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Metal Performance Shaders (MPS)\n\n> [!TIP]\n> Pipelines with a <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\"> badge indicate a model can take advantage of the MPS backend on Apple silicon devices for faster inference. Feel free to open a [Pull Request](https://github.com/huggingface/diffusers/compare) to add this badge to pipelines that are missing it.\n\n🤗 Diffusers is compatible with Apple silicon (M1/M2 chips) using the PyTorch [`mps`](https://pytorch.org/docs/stable/notes/mps.html) device, which uses the Metal framework to leverage the GPU on MacOS devices. 
You'll need to have:\n\n- macOS computer with Apple silicon (M1/M2) hardware\n- macOS 12.6 or later (13.0 or later recommended)\n- arm64 version of Python\n- [PyTorch 2.0](https://pytorch.org/get-started/locally/) (recommended) or 1.13 (minimum version supported for `mps`)\n\nUse PyTorch's `.to()` interface to move the Stable Diffusion pipeline onto the `mps` device of your M1 or M2 machine:\n\n```python\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\npipe = pipe.to(\"mps\")\n\n# Recommended if your computer has < 64 GB of RAM\npipe.enable_attention_slicing()\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\nimage = pipe(prompt).images[0]\nimage\n```\n\n> [!WARNING]\n> The PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) backend does not support NDArray sizes greater than `2**32`. Please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) if you encounter this problem so we can investigate.\n\nIf you're using **PyTorch 1.13**, you need to \"prime\" the pipeline with an additional one-time pass through it. This is a temporary workaround for an issue where the first inference pass produces slightly different results than subsequent ones. 
You only need to do this pass once, with a single inference step, and you can discard the result.\n\n```diff\n  from diffusers import DiffusionPipeline\n\n  pipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\").to(\"mps\")\n  pipe.enable_attention_slicing()\n\n  prompt = \"a photo of an astronaut riding a horse on mars\"\n  # First-time \"warmup\" pass if PyTorch version is 1.13\n+ _ = pipe(prompt, num_inference_steps=1)\n\n  # Results match those from the CPU device after the warmup pass.\n  image = pipe(prompt).images[0]\n```\n\n## Troubleshoot\n\nThis section lists some common issues with using the `mps` backend and how to solve them.\n\n### Attention slicing\n\nM1/M2 performance is very sensitive to memory pressure. When memory runs low, the system automatically swaps to disk, which significantly degrades performance.\n\nTo prevent this from happening, we recommend *attention slicing* to reduce memory pressure during inference and prevent swapping. This is especially relevant if your computer has less than 64GB of system RAM, or if you generate images at non-standard resolutions larger than 512×512 pixels. Call the [`~DiffusionPipeline.enable_attention_slicing`] function on your pipeline:\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True).to(\"mps\")\npipeline.enable_attention_slicing()\n```\n\nAttention slicing performs the costly attention operation in multiple steps instead of all at once. It usually improves performance by ~20% on computers without unified memory, but we've observed *better performance* on most Apple silicon computers unless you have 64GB of RAM or more.\n\n### Batch inference\n\nGenerating multiple prompts in a batch can crash or fail to work reliably. If this is the case, try iterating instead of batching."
  },
  {
    "path": "docs/source/en/optimization/neuron.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AWS Neuron\n\nDiffusers functionalities are available on [AWS Inf2 instances](https://aws.amazon.com/ec2/instance-types/inf2/), which are EC2 instances powered by [Neuron machine learning accelerators](https://aws.amazon.com/machine-learning/inferentia/). These instances aim to provide better compute performance (higher throughput, lower latency) with good cost-efficiency, making them good candidates for AWS users to deploy diffusion models to production.\n\n[Optimum Neuron](https://huggingface.co/docs/optimum-neuron/en/index) is the interface between Hugging Face libraries and AWS Accelerators, including AWS [Trainium](https://aws.amazon.com/machine-learning/trainium/) and AWS [Inferentia](https://aws.amazon.com/machine-learning/inferentia/). It supports many of the features in Diffusers with similar APIs, so it is easier to learn if you're already familiar with Diffusers. Once you have created an AWS Inf2 instance, install Optimum Neuron.\n\n```bash\npython -m pip install --upgrade-strategy eager optimum[neuronx]\n```\n\n> [!TIP]\n> We provide pre-built [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) (DLAMI) and Optimum Neuron containers for Amazon SageMaker. 
We recommend using them to set up your environment correctly.\n\nThe example below demonstrates how to generate images with the Stable Diffusion XL model on an inf2.8xlarge instance (you can switch to cheaper inf2.xlarge instances once the model is compiled). To generate some images, use the [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] class, which is similar to the [`StableDiffusionXLPipeline`] class in Diffusers.\n\nUnlike in Diffusers, you need to compile the models in the pipeline to the Neuron format, `.neuron`. Run the following command to export the model to the `.neuron` format.\n\n```bash\noptimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \\\n  --batch_size 1 \\\n  --height 1024 `# height in pixels of generated image, e.g. 768, 1024` \\\n  --width 1024 `# width in pixels of generated image, e.g. 768, 1024` \\\n  --num_images_per_prompt 1 `# number of images to generate per prompt, defaults to 1` \\\n  --auto_cast matmul `# cast only matrix multiplication operations` \\\n  --auto_cast_type bf16 `# cast operations from FP32 to BF16` \\\n  sd_neuron_xl/\n```\n\nNow generate some images with the pre-compiled SDXL model.\n\n```python\n>>> from optimum.neuron import NeuronStableDiffusionXLPipeline\n\n>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained(\"sd_neuron_xl/\")\n>>> prompt = \"a pig with wings flying in floating US dollar banknotes in the air, skyscrapers behind, warm color palette, muted colors, detailed, 8k\"\n>>> image = stable_diffusion_xl(prompt).images[0]\n```\n\n<img\n  src=\"https://huggingface.co/datasets/Jingya/document_images/resolve/main/optimum/neuron/sdxl_pig.png\"\n  width=\"256\"\n  height=\"256\"\n  alt=\"peggy generated by sdxl on inf2\"\n/>\n\nFeel free to check out more guides and examples on different use cases from the Optimum Neuron 
[documentation](https://huggingface.co/docs/optimum-neuron/en/inference_tutorials/stable_diffusion#generate-images-with-stable-diffusion-models-on-aws-inferentia)!\n"
  },
  {
    "path": "docs/source/en/optimization/onnx.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ONNX Runtime\n\n🤗 [Optimum](https://github.com/huggingface/optimum) provides a Stable Diffusion pipeline compatible with ONNX Runtime. You'll need to install 🤗 Optimum with the following command for ONNX Runtime support:\n\n```bash\npip install -q optimum[\"onnxruntime\"]\n```\n\nThis guide will show you how to use the Stable Diffusion and Stable Diffusion XL (SDXL) pipelines with ONNX Runtime.\n\n## Stable Diffusion\n\nTo load and run inference, use the [`~optimum.onnxruntime.ORTStableDiffusionPipeline`]. If you want to load a PyTorch model and convert it to the ONNX format on-the-fly, set `export=True`:\n\n```python\nfrom optimum.onnxruntime import ORTStableDiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipeline = ORTStableDiffusionPipeline.from_pretrained(model_id, export=True)\nprompt = \"sailing ship in storm by Leonardo da Vinci\"\nimage = pipeline(prompt).images[0]\npipeline.save_pretrained(\"./onnx-stable-diffusion-v1-5\")\n```\n\n> [!WARNING]\n> Generating multiple prompts in a batch seems to take too much memory. 
While we look into it, you may need to iterate instead of batching.\n\nTo export the pipeline in the ONNX format offline and use it later for inference,\nuse the [`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) command:\n\n```bash\noptimum-cli export onnx --model stable-diffusion-v1-5/stable-diffusion-v1-5 sd_v15_onnx/\n```\n\nThen to perform inference (you don't have to specify `export=True` again):\n\n```python\nfrom optimum.onnxruntime import ORTStableDiffusionPipeline\n\nmodel_id = \"sd_v15_onnx\"\npipeline = ORTStableDiffusionPipeline.from_pretrained(model_id)\nprompt = \"sailing ship in storm by Leonardo da Vinci\"\nimage = pipeline(prompt).images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/optimum/documentation-images/resolve/main/onnxruntime/stable_diffusion_v1_5_ort_sail_boat.png\">\n</div>\n\nYou can find more examples in 🤗 Optimum [documentation](https://huggingface.co/docs/optimum/), and Stable Diffusion is supported for text-to-image, image-to-image, and inpainting.\n\n## Stable Diffusion XL\n\nTo load and run inference with SDXL, use the [`~optimum.onnxruntime.ORTStableDiffusionXLPipeline`]:\n\n```python\nfrom optimum.onnxruntime import ORTStableDiffusionXLPipeline\n\nmodel_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\npipeline = ORTStableDiffusionXLPipeline.from_pretrained(model_id)\nprompt = \"sailing ship in storm by Leonardo da Vinci\"\nimage = pipeline(prompt).images[0]\n```\n\nTo export the pipeline in the ONNX format and use it later for inference, use the [`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) command:\n\n```bash\noptimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl sd_xl_onnx/\n```\n\nSDXL in the ONNX format is supported for 
text-to-image and image-to-image.\n"
  },
  {
    "path": "docs/source/en/optimization/open_vino.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# OpenVINO\n\n🤗 [Optimum](https://github.com/huggingface/optimum-intel) provides Stable Diffusion pipelines compatible with OpenVINO to perform inference on a variety of Intel processors (see the [full list](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) of supported devices).\n\nYou'll need to install 🤗 Optimum Intel with the `--upgrade-strategy eager` option to ensure [`optimum-intel`](https://github.com/huggingface/optimum-intel) is using the latest version:\n\n```bash\npip install --upgrade-strategy eager optimum[\"openvino\"]\n```\n\nThis guide will show you how to use the Stable Diffusion and Stable Diffusion XL (SDXL) pipelines with OpenVINO.\n\n## Stable Diffusion\n\nTo load and run inference, use the [`~optimum.intel.OVStableDiffusionPipeline`]. If you want to load a PyTorch model and convert it to the OpenVINO format on-the-fly, set `export=True`:\n\n```python\nfrom optimum.intel import OVStableDiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipeline = OVStableDiffusionPipeline.from_pretrained(model_id, export=True)\nprompt = \"sailing ship in storm by Rembrandt\"\nimage = pipeline(prompt).images[0]\n\n# Don't forget to save the exported model\npipeline.save_pretrained(\"openvino-sd-v1-5\")\n```\n\nTo further speed-up inference, statically reshape the model. 
If you change any parameters such as the output height or width, you’ll need to statically reshape your model again.\n\n```python\n# Define the shapes related to the inputs and desired outputs\nbatch_size, num_images, height, width = 1, 1, 512, 512\n\n# Statically reshape the model\npipeline.reshape(batch_size, height, width, num_images)\n# Compile the model before inference\npipeline.compile()\n\nimage = pipeline(\n    prompt,\n    height=height,\n    width=width,\n    num_images_per_prompt=num_images,\n).images[0]\n```\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/stable_diffusion_v1_5_sail_boat_rembrandt.png\">\n</div>\n\nYou can find more examples in the 🤗 Optimum [documentation](https://huggingface.co/docs/optimum/intel/inference#stable-diffusion), and Stable Diffusion is supported for text-to-image, image-to-image, and inpainting.\n\n## Stable Diffusion XL\n\nTo load and run inference with SDXL, use the [`~optimum.intel.OVStableDiffusionXLPipeline`]:\n\n```python\nfrom optimum.intel import OVStableDiffusionXLPipeline\n\nmodel_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\npipeline = OVStableDiffusionXLPipeline.from_pretrained(model_id)\nprompt = \"sailing ship in storm by Rembrandt\"\nimage = pipeline(prompt).images[0]\n```\n\nTo further speed up inference, [statically reshape](#stable-diffusion) the model as shown in the Stable Diffusion section.\n\nYou can find more examples in the 🤗 Optimum [documentation](https://huggingface.co/docs/optimum/intel/inference#stable-diffusion-xl), and running SDXL in OpenVINO is supported for text-to-image and image-to-image.\n"
  },
  {
    "path": "docs/source/en/optimization/para_attn.md",
    "content": "# ParaAttention\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-performance.png\">\n</div>\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/hunyuan-video-performance.png\">\n</div>\n\n\nLarge image and video generation models, such as [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) and [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo), can be an inference challenge for real-time applications and deployment because of their size.\n\n[ParaAttention](https://github.com/chengzeyi/ParaAttention) is a library that implements **context parallelism** and **first block cache**, and can be combined with other techniques (torch.compile, fp8 dynamic quantization), to accelerate inference.\n\nThis guide will show you how to apply ParaAttention to FLUX.1-dev and HunyuanVideo on NVIDIA L20 GPUs.\nNo optimizations are applied for our baseline benchmark, except for HunyuanVideo to avoid out-of-memory errors.\n\nOur baseline benchmark shows that FLUX.1-dev is able to generate a 1024x1024 resolution image in 28 steps in 26.36 seconds, and HunyuanVideo is able to generate 129 frames at 720p resolution in 30 steps in 3675.71 seconds.\n\n> [!TIP]\n> For even faster inference with context parallelism, try using NVIDIA A100 or H100 GPUs (if available) with NVLink support, especially when there is a large number of GPUs.\n\n## First Block Cache\n\nCaching the output of the transformers blocks in the model and reusing them in the next inference steps reduces the computation cost and makes inference faster.\n\nHowever, it is hard to decide when to reuse the cache to ensure quality generated images or videos. 
ParaAttention directly uses the **residual difference of the first transformer block output** to approximate the difference among model outputs. When the difference is small enough, the residual difference of previous inference steps is reused. In other words, the denoising step is skipped.\n\nThis achieves a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality.\n\n<figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png\" alt=\"Cache in Diffusion Transformer\" />\n    <figcaption>How AdaCache works, First Block Cache is a variant of it</figcaption>\n</figure>\n\n<hfoptions id=\"first-block-cache\">\n<hfoption id=\"FLUX-1.dev\">\n\nTo apply first block cache on FLUX.1-dev, call `apply_cache_on_pipe` as shown below. 0.08 is the default residual difference value for FLUX models.\n\n```python\nimport time\nimport torch\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(pipe, residual_diff_threshold=0.08)\n\n# Enable memory savings\n# pipe.enable_model_cpu_offload()\n# pipe.enable_sequential_cpu_offload()\n\nbegin = time.time()\nimage = pipe(\n    \"A cat holding a sign that says hello world\",\n    num_inference_steps=28,\n).images[0]\nend = time.time()\nprint(f\"Time: {end - begin:.2f}s\")\n\nprint(\"Saving image to flux.png\")\nimage.save(\"flux.png\")\n```\n\n| Optimizations | Original | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 |\n| - | - | - | - | - | - |\n| Preview | ![Original](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png) | ![FBCache 
rdt=0.06](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.06.png) | ![FBCache rdt=0.08](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.08.png) | ![FBCache rdt=0.10](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.10.png) | ![FBCache rdt=0.12](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.12.png) |\n| Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 |\n\nFirst Block Cache reduced the inference time to 17.01 seconds from the 26.36-second baseline, or 1.55x faster, with nearly zero quality loss.\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\nTo apply First Block Cache on HunyuanVideo, call `apply_cache_on_pipe` as shown below. 0.06 is the default residual difference value for HunyuanVideo models.\n\n```python\nimport time\nimport torch\nfrom diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel\nfrom diffusers.utils import export_to_video\n\nmodel_id = \"tencent/HunyuanVideo\"\ntransformer = HunyuanVideoTransformer3DModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16,\n    revision=\"refs/pr/18\",\n)\npipe = HunyuanVideoPipeline.from_pretrained(\n    model_id,\n    transformer=transformer,\n    torch_dtype=torch.float16,\n    revision=\"refs/pr/18\",\n).to(\"cuda\")\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(pipe, residual_diff_threshold=0.06)\n\npipe.vae.enable_tiling()\n\nbegin = time.time()\noutput = pipe(\n    prompt=\"A cat walks on the grass, realistic\",\n    height=720,\n    width=1280,\n    num_frames=129,\n    num_inference_steps=30,\n).frames[0]\nend = time.time()\nprint(f\"Time: {end - begin:.2f}s\")\n\nprint(\"Saving video to hunyuan_video.mp4\")\nexport_to_video(output, 
\"hunyuan_video.mp4\", fps=15)\n```\n\n<video controls>\n  <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/hunyuan-video-original.mp4\" type=\"video/mp4\">\n  Your browser does not support the video tag.\n</video>\n\n<small> HunyuanVideo without FBCache </small>\n\n<video controls>\n  <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/hunyuan-video-fbc.mp4\" type=\"video/mp4\">\n  Your browser does not support the video tag.\n</video>\n\n<small> HunyuanVideo with FBCache </small>\n\nFirst Block Cache reduced the inference speed to 2271.06 seconds compared to the baseline, or 1.62x faster, while maintaining nearly zero quality loss.\n\n</hfoption>\n</hfoptions>\n\n## fp8 quantization\n\nfp8 with dynamic quantization further speeds up inference and reduces memory usage. Both the activations and weights must be quantized in order to use the 8-bit [NVIDIA Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/).\n\nUse `float8_weight_only` and `float8_dynamic_activation_float8_weight` to quantize the text encoder and transformer model.\n\nThe default quantization method is per tensor quantization, but if your GPU supports row-wise quantization, you can also try it for better accuracy.\n\nInstall [torchao](https://github.com/pytorch/ao/tree/main) with the command below.\n\n```bash\npip3 install -U torch torchao\n```\n\n[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) with `mode=\"max-autotune-no-cudagraphs\"` or `mode=\"max-autotune\"` selects the best kernel for performance. 
Compilation can take a long time if it's the first time the model is called, but it is worth it once the model has been compiled.\n\nThis example quantizes the weights of the text encoder and both the weights and activations of the transformer model.\n\n> [!TIP]\n> Dynamic quantization can significantly change the distribution of the model output, so you need to change the `residual_diff_threshold` to a larger value for the cache to take effect.\n\n<hfoptions id=\"fp8-quantization\">\n<hfoption id=\"FLUX-1.dev\">\n\n```python\nimport time\nimport torch\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(\n    pipe,\n    residual_diff_threshold=0.12,  # Use a larger value to make the cache take effect\n)\n\nfrom torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only\n\nquantize_(pipe.text_encoder, float8_weight_only())\nquantize_(pipe.transformer, float8_dynamic_activation_float8_weight())\npipe.transformer = torch.compile(\n   pipe.transformer, mode=\"max-autotune-no-cudagraphs\",\n)\n\n# Enable memory savings\n# pipe.enable_model_cpu_offload()\n# pipe.enable_sequential_cpu_offload()\n\nfor i in range(2):\n    begin = time.time()\n    image = pipe(\n        \"A cat holding a sign that says hello world\",\n        num_inference_steps=28,\n    ).images[0]\n    end = time.time()\n    if i == 0:\n        print(f\"Warm up time: {end - begin:.2f}s\")\n    else:\n        print(f\"Time: {end - begin:.2f}s\")\n\nprint(\"Saving image to flux.png\")\nimage.save(\"flux.png\")\n```\n\nfp8 dynamic quantization and torch.compile reduced the inference time to 7.56 seconds from the 26.36-second baseline, or 3.48x faster.\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\n```python\nimport time\nimport torch\nfrom diffusers 
import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel\nfrom diffusers.utils import export_to_video\n\nmodel_id = \"tencent/HunyuanVideo\"\ntransformer = HunyuanVideoTransformer3DModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16,\n    revision=\"refs/pr/18\",\n)\npipe = HunyuanVideoPipeline.from_pretrained(\n    model_id,\n    transformer=transformer,\n    torch_dtype=torch.float16,\n    revision=\"refs/pr/18\",\n).to(\"cuda\")\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(pipe)\n\nfrom torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only\n\nquantize_(pipe.text_encoder, float8_weight_only())\nquantize_(pipe.transformer, float8_dynamic_activation_float8_weight())\npipe.transformer = torch.compile(\n   pipe.transformer, mode=\"max-autotune-no-cudagraphs\",\n)\n\n# Enable memory savings\npipe.vae.enable_tiling()\n# pipe.enable_model_cpu_offload()\n# pipe.enable_sequential_cpu_offload()\n\nfor i in range(2):\n    begin = time.time()\n    output = pipe(\n        prompt=\"A cat walks on the grass, realistic\",\n        height=720,\n        width=1280,\n        num_frames=129,\n        num_inference_steps=1 if i == 0 else 30,\n    ).frames[0]\n    end = time.time()\n    if i == 0:\n        print(f\"Warm up time: {end - begin:.2f}s\")\n    else:\n        print(f\"Time: {end - begin:.2f}s\")\n\nprint(\"Saving video to hunyuan_video.mp4\")\nexport_to_video(output, \"hunyuan_video.mp4\", fps=15)\n```\n\nA NVIDIA L20 GPU only has 48GB memory and could face out-of-memory (OOM) errors after compilation and if `enable_model_cpu_offload` isn't called because HunyuanVideo has very large activation tensors when running with high resolution and large number of frames. 
For GPUs with less than 80GB of memory, you can try reducing the resolution and number of frames to avoid OOM errors.\n\nLarge video generation models are usually bottlenecked by the attention computations rather than the fully connected layers. These models don't significantly benefit from quantization and torch.compile.\n\n</hfoption>\n</hfoptions>\n\n## Context Parallelism\n\nContext Parallelism parallelizes inference and scales with multiple GPUs. The ParaAttention compositional design allows you to combine Context Parallelism with First Block Cache and dynamic quantization.\n\n> [!TIP]\n> Refer to the [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main) repository for detailed instructions and examples of how to scale inference with multiple GPUs.\n\nIf the inference process needs to be persistent and serviceable, it is suggested to use [torch.multiprocessing](https://pytorch.org/docs/stable/multiprocessing.html) to write your own inference processor. This can eliminate the overhead of launching the process and loading and recompiling the model.\n\n<hfoptions id=\"context-parallelism\">\n<hfoption id=\"FLUX-1.dev\">\n\nThe code sample below combines First Block Cache, fp8 dynamic quantization, torch.compile, and Context Parallelism for the fastest inference speed.\n\n```python\nimport time\nimport torch\nimport torch.distributed as dist\nfrom diffusers import FluxPipeline\n\ndist.init_process_group()\n\ntorch.cuda.set_device(dist.get_rank())\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nfrom para_attn.context_parallel import init_context_parallel_mesh\nfrom para_attn.context_parallel.diffusers_adapters import parallelize_pipe\nfrom para_attn.parallel_vae.diffusers_adapters import parallelize_vae\n\nmesh = init_context_parallel_mesh(\n    pipe.device.type,\n    max_ring_dim_size=2,\n)\nparallelize_pipe(\n    pipe,\n    mesh=mesh,\n)\nparallelize_vae(pipe.vae, 
mesh=mesh._flatten())\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(\n    pipe,\n    residual_diff_threshold=0.12,  # Use a larger value to make the cache take effect\n)\n\nfrom torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only\n\nquantize_(pipe.text_encoder, float8_weight_only())\nquantize_(pipe.transformer, float8_dynamic_activation_float8_weight())\ntorch._inductor.config.reorder_for_compute_comm_overlap = True\npipe.transformer = torch.compile(\n   pipe.transformer, mode=\"max-autotune-no-cudagraphs\",\n)\n\n# Enable memory savings\n# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())\n# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())\n\nfor i in range(2):\n    begin = time.time()\n    image = pipe(\n        \"A cat holding a sign that says hello world\",\n        num_inference_steps=28,\n        output_type=\"pil\" if dist.get_rank() == 0 else \"pt\",\n    ).images[0]\n    end = time.time()\n    if dist.get_rank() == 0:\n        if i == 0:\n            print(f\"Warm up time: {end - begin:.2f}s\")\n        else:\n            print(f\"Time: {end - begin:.2f}s\")\n\nif dist.get_rank() == 0:\n    print(\"Saving image to flux.png\")\n    image.save(\"flux.png\")\n\ndist.destroy_process_group()\n```\n\nSave to `run_flux.py` and launch it with [torchrun](https://pytorch.org/docs/stable/elastic/run.html).\n\n```bash\n# Use --nproc_per_node to specify the number of GPUs\ntorchrun --nproc_per_node=2 run_flux.py\n```\n\nInference speed is reduced to 8.20 seconds compared to the baseline, or 3.21x faster, with 2 NVIDIA L20 GPUs. 
On 4 L20s, inference speed is 3.90 seconds, or 6.75x faster.\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\nThe code sample below combines First Block Cache and Context Parallelism for the fastest inference speed.\n\n```python\nimport time\nimport torch\nimport torch.distributed as dist\nfrom diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel\nfrom diffusers.utils import export_to_video\n\ndist.init_process_group()\n\ntorch.cuda.set_device(dist.get_rank())\n\nmodel_id = \"tencent/HunyuanVideo\"\ntransformer = HunyuanVideoTransformer3DModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16,\n    revision=\"refs/pr/18\",\n)\npipe = HunyuanVideoPipeline.from_pretrained(\n    model_id,\n    transformer=transformer,\n    torch_dtype=torch.float16,\n    revision=\"refs/pr/18\",\n).to(\"cuda\")\n\nfrom para_attn.context_parallel import init_context_parallel_mesh\nfrom para_attn.context_parallel.diffusers_adapters import parallelize_pipe\nfrom para_attn.parallel_vae.diffusers_adapters import parallelize_vae\n\nmesh = init_context_parallel_mesh(\n    pipe.device.type,\n)\nparallelize_pipe(\n    pipe,\n    mesh=mesh,\n)\nparallelize_vae(pipe.vae, mesh=mesh._flatten())\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(pipe)\n\n# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only\n#\n# torch._inductor.config.reorder_for_compute_comm_overlap = True\n#\n# quantize_(pipe.text_encoder, float8_weight_only())\n# quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())\n# pipe.transformer = torch.compile(\n#    pipe.transformer, mode=\"max-autotune-no-cudagraphs\",\n# )\n\n# Enable memory savings\npipe.vae.enable_tiling()\n# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())\n# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())\n\nfor i in range(2):\n    begin = time.time()\n    output 
= pipe(\n        prompt=\"A cat walks on the grass, realistic\",\n        height=720,\n        width=1280,\n        num_frames=129,\n        num_inference_steps=1 if i == 0 else 30,\n        output_type=\"pil\" if dist.get_rank() == 0 else \"pt\",\n    ).frames[0]\n    end = time.time()\n    if dist.get_rank() == 0:\n        if i == 0:\n            print(f\"Warm up time: {end - begin:.2f}s\")\n        else:\n            print(f\"Time: {end - begin:.2f}s\")\n\nif dist.get_rank() == 0:\n    print(\"Saving video to hunyuan_video.mp4\")\n    export_to_video(output, \"hunyuan_video.mp4\", fps=15)\n\ndist.destroy_process_group()\n```\n\nSave to `run_hunyuan_video.py` and launch it with [torchrun](https://pytorch.org/docs/stable/elastic/run.html).\n\n```bash\n# Use --nproc_per_node to specify the number of GPUs\ntorchrun --nproc_per_node=8 run_hunyuan_video.py\n```\n\nWith 8 NVIDIA L20 GPUs, inference time drops to 649.23 seconds, which is 5.66x faster than the baseline.\n\n</hfoption>\n</hfoptions>\n\n## Benchmarks\n\n<hfoptions id=\"conclusion\">\n<hfoption id=\"FLUX-1.dev\">\n\n| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup |\n| - | - | - | - | - |\n| NVIDIA L20 | 1 | Baseline | 26.36 | 1.00x |\n| NVIDIA L20 | 1 | FBCache (rdt=0.08) | 17.01 | 1.55x |\n| NVIDIA L20 | 1 | FP8 DQ | 13.40 | 1.96x |\n| NVIDIA L20 | 1 | FBCache (rdt=0.12) + FP8 DQ | 7.56 | 3.48x |\n| NVIDIA L20 | 2 | FBCache (rdt=0.12) + FP8 DQ + CP | 4.92 | 5.35x |\n| NVIDIA L20 | 4 | FBCache (rdt=0.12) + FP8 DQ + CP | 3.90 | 6.75x |\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\n| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup |\n| - | - | - | - | - |\n| NVIDIA L20 | 1 | Baseline | 3675.71 | 1.00x |\n| NVIDIA L20 | 1 | FBCache | 2271.06 | 1.62x |\n| NVIDIA L20 | 2 | FBCache + CP | 1132.90 | 3.24x |\n| NVIDIA L20 | 4 | FBCache + CP | 718.15 | 5.12x |\n| NVIDIA L20 | 8 | FBCache + CP | 649.23 | 5.66x |\n\n</hfoption>\n</hfoptions>\n"
  },
  {
    "path": "docs/source/en/optimization/pruna.md",
    "content": "# Pruna\n\n[Pruna](https://github.com/PrunaAI/pruna) is a model optimization framework that offers various optimization methods - quantization, pruning, caching, compilation - for accelerating inference and reducing memory usage. A general overview of the optimization methods are shown below.\n\n\n| Technique    | Description                                                                                   | Speed | Memory | Quality |\n|--------------|-----------------------------------------------------------------------------------------------|:-----:|:------:|:-------:|\n| `batcher`    | Groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing processing time. | ✅    | ❌     | ➖      |\n| `cacher`     | Stores intermediate results of computations to speed up subsequent operations.               | ✅    | ➖     | ➖      |\n| `compiler`   | Optimises the model with instructions for specific hardware.                                 | ✅    | ➖     | ➖      |\n| `distiller`  | Trains a smaller, simpler model to mimic a larger, more complex model.                       | ✅    | ✅     | ❌      |\n| `quantizer`  | Reduces the precision of weights and activations, lowering memory requirements.              | ✅    | ✅     | ❌      |\n| `pruner`     | Removes less important or redundant connections and neurons, resulting in a sparser, more efficient network. | ✅    | ✅     | ❌      |\n| `recoverer`  | Restores the performance of a model after compression.                                       | ➖    | ➖     | ✅      |\n| `factorizer` | Factorization batches several small matrix multiplications into one large fused operation. | ✅ | ➖ | ➖ |\n| `enhancer`   | Enhances the model output by applying post-processing algorithms such as denoising or upscaling. | ❌ | - | ✅ |\n\n✅ (improves), ➖ (approx. 
the same), ❌ (worsens)\n\nExplore the full range of optimization methods in the [Pruna documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms).\n\n## Installation\n\nInstall Pruna with the following command.\n\n```bash\npip install pruna\n```\n\n\n## Optimize Diffusers models\n\nA broad range of optimization algorithms are supported for Diffusers models as shown below.\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/diffusers_combinations.png\" alt=\"Overview of the supported optimization algorithms for diffusers models\">\n</div>\n\nThe example below optimizes [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)\nwith a combination of factorizer, compiler, and cacher algorithms. This combination accelerates inference by up to 4.2x and cuts peak GPU memory usage from 34.7GB to 28.0GB, all while maintaining virtually the same output quality.\n\n> [!TIP]\n> Refer to the [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) docs to learn more about the optimization techniques used in this example.\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_combination.png\" alt=\"Optimization techniques used for FLUX.1-dev showing the combination of factorizer, compiler, and cacher algorithms\">\n</div>\n\nStart by defining a `SmashConfig` with the optimization algorithms to use. 
To optimize the model, pass the pipeline and the `SmashConfig` to `smash`, then use the returned pipeline as normal for inference.\n\n```python\nimport torch\nfrom diffusers import FluxPipeline\n\nfrom pruna import PrunaModel, SmashConfig, smash\n\n# load the model\n# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell if you have limited GPU memory\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\n# define the configuration\nsmash_config = SmashConfig()\nsmash_config[\"factorizer\"] = \"qkv_diffusers\"\nsmash_config[\"compiler\"] = \"torch_compile\"\nsmash_config[\"torch_compile_target\"] = \"module_list\"\nsmash_config[\"cacher\"] = \"fora\"\nsmash_config[\"fora_interval\"] = 2\n\n# for the best results in terms of speed you can add these configs\n# however they will increase your warmup time from 1.5 min to 10 min\n# smash_config[\"torch_compile_mode\"] = \"max-autotune-no-cudagraphs\"\n# smash_config[\"quantizer\"] = \"torchao\"\n# smash_config[\"torchao_quant_type\"] = \"fp8dq\"\n# smash_config[\"torchao_excluded_modules\"] = \"norm+embedding\"\n\n# optimize the model\nsmashed_pipe = smash(pipe, smash_config)\n\n# run the model\nsmashed_pipe(\"a knitted purple prune\").images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_smashed_comparison.png\">\n</div>\n\nAfter optimization, we can share and load the optimized model using the Hugging Face Hub.\n\n```python\n# save the model\nsmashed_pipe.save_to_hub(\"<username>/FLUX.1-dev-smashed\")\n\n# load the model\nsmashed_pipe = PrunaModel.from_hub(\"<username>/FLUX.1-dev-smashed\")\n```\n\n## Evaluate and benchmark Diffusers models\n\nPruna provides the [EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) to evaluate the quality of your optimized models.\n\nWe can define the metrics we care about, such as total 
time and throughput, and the dataset to evaluate on. We can then load a model and pass it to the `EvaluationAgent`.\n\n<hfoptions id=\"eval\">\n<hfoption id=\"optimized model\">\n\nWe can load an optimized model and evaluate it with an `EvaluationAgent` that is built from a `Task`.\n\n```python\nimport torch\n\nfrom pruna import PrunaModel\nfrom pruna.data.pruna_datamodule import PrunaDataModule\nfrom pruna.evaluation.evaluation_agent import EvaluationAgent\nfrom pruna.evaluation.metrics import (\n    ThroughputMetric,\n    TorchMetricWrapper,\n    TotalTimeMetric,\n)\nfrom pruna.evaluation.task import Task\n\n# define the device\ndevice = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\n# load the model\n# Try PrunaAI/Segmind-Vega-smashed if you have limited GPU memory\nsmashed_pipe = PrunaModel.from_hub(\"PrunaAI/FLUX.1-dev-smashed\")\n\n# Define the metrics\nmetrics = [\n    TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),\n    ThroughputMetric(n_iterations=20, n_warmup_iterations=5),\n    TorchMetricWrapper(\"clip\"),\n]\n\n# Define the datamodule\ndatamodule = PrunaDataModule.from_string(\"LAION256\")\ndatamodule.limit_datasets(10)\n\n# Define the task and evaluation agent\ntask = Task(metrics, datamodule=datamodule, device=device)\neval_agent = EvaluationAgent(task)\n\n# Evaluate smashed model and offload it to CPU\nsmashed_pipe.move_to_device(device)\nsmashed_pipe_results = eval_agent.evaluate(smashed_pipe)\nsmashed_pipe.move_to_device(\"cpu\")\n```\n\n</hfoption>\n<hfoption id=\"standalone model\">\n\nInstead of comparing the optimized model to the base model, you can also evaluate the standalone `diffusers` model. This is useful if you want to evaluate the performance of the model without the optimization. 
We can do so by wrapping the pipeline in `PrunaModel` and running the `EvaluationAgent` on it.\n\n```python\nimport torch\nfrom diffusers import FluxPipeline\n\nfrom pruna import PrunaModel\n\n# load the model\n# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell if you have limited GPU memory\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16\n).to(\"cpu\")\nwrapped_pipe = PrunaModel(model=pipe)\n```\n\n</hfoption>\n</hfoptions>\n\nNow that you have seen how to optimize and evaluate your models, you can start using Pruna to optimize your own models. The Pruna tutorials listed in the reference section below provide many examples to help you get started.\n\n> [!TIP]\n> For more details about benchmarking Flux, check out the [Announcing FLUX-Juiced: The Fastest Image Generation Endpoint (2.6 times faster)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) blog post and the [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) Space.\n\n## Reference\n\n- [Pruna](https://github.com/pruna-ai/pruna)\n- [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)\n- [Pruna evaluation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)\n- [Pruna tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)\n\n"
  },
  {
    "path": "docs/source/en/optimization/speed-memory-optims.md",
    "content": "<!--Copyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Compiling and offloading quantized models\n\nOptimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).\n\n> [!TIP]\n> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how they can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups. \n\nFor image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.\n\nFor video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound. 
\n\nThe table below provides a comparison of optimization strategy combinations and their impact on latency and memory-usage for Flux.\n\n| combination | latency (s) | memory-usage (GB) |\n|---|---|---|\n| quantization  | 32.602 | 14.9453 |\n| quantization, torch.compile  | 25.847 | 14.9448 |\n| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |\n\n<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href=\"https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d\">benchmarking script</a> if you're interested in evaluating your own model.</small>\n\nThis guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.\n\n```bash\npip install -U bitsandbytes\n```\n\n## Quantization and torch.compile\n\nStart by [quantizing](../quantization/overview) a model to reduce the memory required for storage and [compiling](./fp16#torchcompile) it to accelerate inference.\n\nConfigure the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\n\ntorch._dynamo.config.capture_dynamic_output_shape_ops = True\n\n# quantize\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": \"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n    components_to_quantize=[\"transformer\", \"text_encoder_2\"],\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    
quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\n# compile\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.transformer.compile(mode=\"max-autotune\", fullgraph=True)\npipeline(\"\"\"\n    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\n    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\n).images[0]\n```\n\n## Quantization, torch.compile, and offloading\n\nIn addition to quantization and torch.compile, try offloading if you need to reduce memory-usage further. Offloading moves various layers or model components from the CPU to the GPU as needed for computations.\n\nConfigure the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `cache_size_limit` during offloading to avoid excessive recompilation and set `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.\n\n<hfoptions id=\"offloading\">\n<hfoption id=\"model CPU offloading\">\n\n[Model CPU offloading](./memory#model-offloading) moves an individual pipeline component, like the transformer model, to the GPU when it is needed for computation. 
Otherwise, it is offloaded to the CPU.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\n\ntorch._dynamo.config.cache_size_limit = 1000\ntorch._dynamo.config.capture_dynamic_output_shape_ops = True\n\n# quantize\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": \"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n    components_to_quantize=[\"transformer\", \"text_encoder_2\"],\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\n# model CPU offloading\npipeline.enable_model_cpu_offload()\n\n# compile\npipeline.transformer.compile()\npipeline(\n    \"cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\"\n).images[0]\n```\n\n</hfoption>\n<hfoption id=\"group offloading\">\n\n[Group offloading](./memory#group-offloading) moves the internal layers of an individual pipeline component, like the transformer model, to the GPU for computation and offloads them when they are not required. At the same time, it uses the [CUDA stream](./memory#cuda-stream) feature to prefetch the next layer for execution.\n\nBecause it overlaps computation and data transfer, group offloading is faster than model CPU offloading while also saving memory. 
\n\n```py\n# pip install ftfy\nimport torch\nfrom diffusers import AutoModel, DiffusionPipeline\nfrom diffusers.hooks import apply_group_offloading\nfrom diffusers.utils import export_to_video\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom transformers import UMT5EncoderModel\n\ntorch._dynamo.config.cache_size_limit = 1000\ntorch._dynamo.config.capture_dynamic_output_shape_ops = True\n\n# quantize\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": \"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n    components_to_quantize=[\"transformer\", \"text_encoder\"],\n)\n\ntext_encoder = UMT5EncoderModel.from_pretrained(\n    \"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"text_encoder\", torch_dtype=torch.bfloat16\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"Wan-AI/Wan2.1-T2V-14B-Diffusers\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\n# group offloading\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\n\npipeline.transformer.enable_group_offload(\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True,\n    non_blocking=True\n)\npipeline.vae.enable_group_offload(\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True,\n    non_blocking=True\n)\napply_group_offloading(\n    pipeline.text_encoder,\n    onload_device=onload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True,\n    non_blocking=True\n)\n\n# compile\npipeline.transformer.compile()\n\nprompt = \"\"\"\nThe camera rushes from far to near in a low-angle shot, \nrevealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in \nfor a close-up. 
Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. \nBirch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic \nshadows and warm highlights. Medium composition, front view, low angle, with depth of field.\n\"\"\"\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, \nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, \nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=81,\n    guidance_scale=5.0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\n</hfoption>\n</hfoptions>"
  },
  {
    "path": "docs/source/en/optimization/tgate.md",
    "content": "# T-GATE\n\n[T-GATE](https://github.com/HaozheLiu-ST/T-GATE/tree/main) accelerates inference for [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [PixArt](../api/pipelines/pixart), and [Latency Consistency Model](../api/pipelines/latent_consistency_models.md) pipelines by skipping the cross-attention calculation once it converges. This method doesn't require any additional training and it can speed up inference from 10-50%. T-GATE is also compatible with other optimization methods like [DeepCache](./deepcache).\n\nBefore you begin, make sure you install T-GATE.\n\n```bash\npip install tgate\npip install -U torch diffusers transformers accelerate DeepCache\n```\n\n\nTo use T-GATE with a pipeline, you need to use its corresponding loader.\n\n| Pipeline | T-GATE Loader |\n|---|---|\n| PixArt | TgatePixArtLoader |\n| Stable Diffusion XL | TgateSDXLLoader |\n| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader |\n| Stable Diffusion | TgateSDLoader |\n| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader |\n\nNext, create a `TgateLoader` with a pipeline, the gate step (the time step to stop calculating the cross attention), and the number of inference steps. 
Then call the `tgate` method on the pipeline with a prompt, gate step, and the number of inference steps.\n\nLet's see how to enable this for several different pipelines.\n\n<hfoptions id=\"pipelines\">\n<hfoption id=\"PixArt\">\n\nAccelerate `PixArtAlphaPipeline` with T-GATE:\n\n```py\nimport torch\nfrom diffusers import PixArtAlphaPipeline\nfrom tgate import TgatePixArtLoader\n\npipe = PixArtAlphaPipeline.from_pretrained(\"PixArt-alpha/PixArt-XL-2-1024-MS\", torch_dtype=torch.float16)\n\ngate_step = 8\ninference_step = 25\npipe = TgatePixArtLoader(\n       pipe,\n       gate_step=gate_step,\n       num_inference_steps=inference_step,\n).to(\"cuda\")\n\nimage = pipe.tgate(\n       \"An alpaca made of colorful building blocks, cyberpunk.\",\n       gate_step=gate_step,\n       num_inference_steps=inference_step,\n).images[0]\n```\n</hfoption>\n<hfoption id=\"Stable Diffusion XL\">\n\nAccelerate `StableDiffusionXLPipeline` with T-GATE:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers import DPMSolverMultistepScheduler\nfrom tgate import TgateSDXLLoader\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\",\n            torch_dtype=torch.float16,\n            variant=\"fp16\",\n            use_safetensors=True,\n)\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n\ngate_step = 10\ninference_step = 25\npipe = TgateSDXLLoader(\n       pipe,\n       gate_step=gate_step,\n       num_inference_steps=inference_step,\n).to(\"cuda\")\n\nimage = pipe.tgate(\n       \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.\",\n       gate_step=gate_step,\n       num_inference_steps=inference_step\n).images[0]\n```\n</hfoption>\n<hfoption id=\"StableDiffusionXL with DeepCache\">\n\nAccelerate `StableDiffusionXLPipeline` with [DeepCache](https://github.com/horseee/DeepCache) and T-GATE:\n\n```py\nimport torch\nfrom diffusers import 
StableDiffusionXLPipeline\nfrom diffusers import DPMSolverMultistepScheduler\nfrom tgate import TgateSDXLDeepCacheLoader\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\",\n            torch_dtype=torch.float16,\n            variant=\"fp16\",\n            use_safetensors=True,\n)\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n\ngate_step = 10\ninference_step = 25\npipe = TgateSDXLDeepCacheLoader(\n       pipe,\n       cache_interval=3,\n       cache_branch_id=0,\n).to(\"cuda\")\n\nimage = pipe.tgate(\n       \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.\",\n       gate_step=gate_step,\n       num_inference_steps=inference_step\n).images[0]\n```\n</hfoption>\n<hfoption id=\"Latent Consistency Model\">\n\nAccelerate `latent-consistency/lcm-sdxl` with T-GATE:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers import UNet2DConditionModel, LCMScheduler\nfrom diffusers import DPMSolverMultistepScheduler\nfrom tgate import TgateSDXLLoader\n\nunet = UNet2DConditionModel.from_pretrained(\n    \"latent-consistency/lcm-sdxl\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    unet=unet,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\ngate_step = 1\ninference_step = 4\npipe = TgateSDXLLoader(\n       pipe,\n       gate_step=gate_step,\n       num_inference_steps=inference_step,\n       lcm=True\n).to(\"cuda\")\n\nimage = pipe.tgate(\n       \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.\",\n       gate_step=gate_step,\n       num_inference_steps=inference_step\n).images[0]\n```\n</hfoption>\n</hfoptions>\n\nT-GATE also supports [`StableDiffusionPipeline`] and 
[PixArt-alpha/PixArt-LCM-XL-2-1024-MS](https://hf.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS).\n\n## Benchmarks\n\n| Model                 | MACs     | Param     | Latency | Zero-shot 10K-FID on MS-COCO |\n|-----------------------|----------|-----------|---------|---------------------------|\n| SD-1.5                | 16.938T  | 859.520M  | 7.032s  | 23.927                    |\n| SD-1.5 w/ T-GATE       | 9.875T   | 815.557M  | 4.313s  | 20.789                    |\n| SD-2.1                | 38.041T  | 865.785M  | 16.121s | 22.609                    |\n| SD-2.1 w/ T-GATE       | 22.208T  | 815.433M  | 9.878s  | 19.940                    |\n| SD-XL                 | 149.438T | 2.570B    | 53.187s | 24.628                    |\n| SD-XL w/ T-GATE        | 84.438T  | 2.024B    | 27.932s | 22.738                    |\n| Pixart-Alpha          | 107.031T | 611.350M  | 61.502s | 38.669                    |\n| Pixart-Alpha w/ T-GATE | 65.318T  | 462.585M  | 37.867s | 35.825                    |\n| DeepCache (SD-XL)     | 57.888T  | -         | 19.931s | 23.755                    |\n| DeepCache w/ T-GATE    | 43.868T  | -         | 14.666s | 23.999                    |\n| LCM (SD-XL)           | 11.955T  | 2.570B    | 3.805s  | 25.044                    |\n| LCM w/ T-GATE          | 11.171T  | 2.024B    | 3.533s  | 25.028                    |\n| LCM (Pixart-Alpha)    | 8.563T   | 611.350M  | 4.733s  | 36.086                    |\n| LCM w/ T-GATE          | 7.623T   | 462.585M  | 4.543s  | 37.048                    |\n\nLatency was measured on an NVIDIA 1080 Ti. MACs and Params are calculated with [calflops](https://github.com/MrYxJ/calculate-flops.pytorch), and the FID is calculated with [PytorchFID](https://github.com/mseitzer/pytorch-fid).\n"
  },
  {
    "path": "docs/source/en/optimization/tome.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Token merging\n\n[Token merging](https://huggingface.co/papers/2303.17604) (ToMe) merges redundant tokens/patches progressively in the forward pass of a Transformer-based network which can speed-up the inference latency of [`StableDiffusionPipeline`].\n\nInstall ToMe from `pip`:\n\n```bash\npip install tomesd\n```\n\nYou can use ToMe from the [`tomesd`](https://github.com/dbolya/tomesd) library with the [`apply_patch`](https://github.com/dbolya/tomesd?tab=readme-ov-file#usage) function:\n\n```diff\n  from diffusers import StableDiffusionPipeline\n  import torch\n  import tomesd\n\n  pipeline = StableDiffusionPipeline.from_pretrained(\n        \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True,\n  ).to(\"cuda\")\n+ tomesd.apply_patch(pipeline, ratio=0.5)\n\n  image = pipeline(\"a photo of an astronaut riding a horse on mars\").images[0]\n```\n\nThe `apply_patch` function exposes a number of [arguments](https://github.com/dbolya/tomesd#usage) to help strike a balance between pipeline inference speed and the quality of the generated tokens. 
The most important argument is `ratio`, which controls the number of tokens that are merged during the forward pass.\n\nAs reported in the [paper](https://huggingface.co/papers/2303.17604), ToMe can largely preserve the quality of the generated images while boosting inference speed. By increasing the `ratio`, you can speed up inference even further, but at the cost of some degraded image quality.\n\nTo test the quality of the generated images, we sampled a few prompts from [Parti Prompts](https://parti.research.google/) and performed inference with the [`StableDiffusionPipeline`] with the following settings:\n\n<div class=\"flex justify-center\">\n      <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/tome/tome_samples.png\">\n</div>\n\nWe didn’t notice any significant decrease in the quality of the generated samples, and you can check out the generated samples in this [WandB report](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=). If you're interested in reproducing this experiment, use this [script](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd).\n\n## Benchmarks\n\nWe also benchmarked the impact of `tomesd` on the [`StableDiffusionPipeline`] with [xFormers](https://huggingface.co/docs/diffusers/optimization/xformers) enabled across several image resolutions. The results are obtained from A100 and V100 GPUs in the following development environment:\n\n```bash\n- `diffusers` version: 0.15.1\n- Python version: 3.8.16\n- PyTorch version (GPU?): 1.13.1+cu116 (True)\n- Huggingface_hub version: 0.13.2\n- Transformers version: 4.27.2\n- Accelerate version: 0.18.0\n- xFormers version: 0.0.16\n- tomesd version: 0.1.2\n```\n\nTo reproduce this benchmark, feel free to use this [script](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335). 
The results are reported in seconds, and where applicable we report the speed-up percentage over the vanilla pipeline when using ToMe and ToMe + xFormers.\n\n| **GPU**  | **Resolution** | **Batch size** | **Vanilla** | **ToMe**       | **ToMe + xFormers** |\n|----------|----------------|----------------|-------------|----------------|---------------------|\n| **A100** |            512 |             10 |        6.88 | 5.26 (+23.55%) |      4.69 (+31.83%) |\n|          |            768 |             10 |         OOM |          14.71 |                  11 |\n|          |                |              8 |         OOM |          11.56 |                8.84 |\n|          |                |              4 |         OOM |           5.98 |                4.66 |\n|          |                |              2 |        4.99 | 3.24 (+35.07%) |       2.1 (+37.88%) |\n|          |                |              1 |        3.29 | 2.24 (+31.91%) |       2.03 (+38.3%) |\n|          |           1024 |             10 |         OOM |            OOM |                 OOM |\n|          |                |              8 |         OOM |            OOM |                 OOM |\n|          |                |              4 |         OOM |          12.51 |                9.09 |\n|          |                |              2 |         OOM |           6.52 |                4.96 |\n|          |                |              1 |         6.4 | 3.61 (+43.59%) |      2.81 (+56.09%) |\n| **V100** |            512 |             10 |         OOM |          10.03 |                9.29 |\n|          |                |              8 |         OOM |           8.05 |                7.47 |\n|          |                |              4 |         5.7 |  4.3 (+24.56%) |      3.98 (+30.18%) |\n|          |                |              2 |        3.14 | 2.43 (+22.61%) |      2.27 (+27.71%) |\n|          |                |              1 |        1.88 | 1.57 (+16.49%) |      1.57 (+16.49%) |\n|          |            
768 |             10 |         OOM |            OOM |               23.67 |\n|          |                |              8 |         OOM |            OOM |               18.81 |\n|          |                |              4 |         OOM |          11.81 |                 9.7 |\n|          |                |              2 |         OOM |           6.27 |                 5.2 |\n|          |                |              1 |        5.43 | 3.38 (+37.75%) |      2.82 (+48.07%) |\n|          |           1024 |             10 |         OOM |            OOM |                 OOM |\n|          |                |              8 |         OOM |            OOM |                 OOM |\n|          |                |              4 |         OOM |            OOM |               19.35 |\n|          |                |              2 |         OOM |             13 |               10.78 |\n|          |                |              1 |         OOM |           6.66 |                5.54 |\n\nAs seen in the tables above, the speed-up from `tomesd` becomes more pronounced for larger image resolutions. It is also interesting to note that with `tomesd`, it is possible to run the pipeline on a higher resolution like 1024x1024. You may be able to speed-up inference even more with [`torch.compile`](fp16#torchcompile).\n"
  },
  {
    "path": "docs/source/en/optimization/xdit.md",
    "content": "# xDiT\n\n[xDiT](https://github.com/xdit-project/xDiT) is an inference engine designed for the large scale parallel deployment of Diffusion Transformers (DiTs). xDiT provides a suite of efficient parallel approaches for Diffusion Models, as well as GPU kernel accelerations.\n\nThere are four parallel methods supported in xDiT, including [Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719), [PipeFusion](https://huggingface.co/papers/2405.14430), CFG parallelism and data parallelism. The four parallel methods in xDiT can be configured in a hybrid manner, optimizing communication patterns to best suit the underlying network hardware.\n\nOptimization orthogonal to parallelization focuses on accelerating single GPU performance. In addition to utilizing well-known Attention optimization libraries, we leverage compilation acceleration technologies such as torch.compile and onediff.\n\nThe overview of xDiT is shown as follows.\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/methods/xdit_overview.png\">\n</div>\nYou can install xDiT using the following command:\n\n\n```bash\npip install xfuser\n```\n\nHere's an example of using xDiT to accelerate inference of a Diffusers model.\n\n```diff\n import torch\n from diffusers import StableDiffusion3Pipeline\n\n from xfuser import xFuserArgs, xDiTParallel\n from xfuser.config import FlexibleArgumentParser\n from xfuser.core.distributed import get_world_group\n\n def main():\n+    parser = FlexibleArgumentParser(description=\"xFuser Arguments\")\n+    args = xFuserArgs.add_cli_args(parser).parse_args()\n+    engine_args = xFuserArgs.from_cli_args(args)\n+    engine_config, input_config = engine_args.create_config()\n\n     local_rank = get_world_group().local_rank\n     pipe = StableDiffusion3Pipeline.from_pretrained(\n         pretrained_model_name_or_path=engine_config.model_config.model,\n         
torch_dtype=torch.float16,\n     ).to(f\"cuda:{local_rank}\")\n\n     # do anything you want with the pipeline here\n\n+    pipe = xDiTParallel(pipe, engine_config, input_config)\n\n     pipe(\n         height=input_config.height,\n         width=input_config.width,\n         prompt=input_config.prompt,\n         num_inference_steps=input_config.num_inference_steps,\n         output_type=input_config.output_type,\n         generator=torch.Generator(device=\"cuda\").manual_seed(input_config.seed),\n     )\n\n+    if input_config.output_type == \"pil\":\n+        pipe.save(\"results\", \"stable_diffusion_3\")\n\n if __name__ == \"__main__\":\n     main()\n```\n\nAs you can see, we only need to use `xFuserArgs` from xDiT to get the configuration parameters, and pass these parameters along with the pipeline object from the Diffusers library into `xDiTParallel` to complete the parallelization of a specific pipeline in Diffusers.\n\nxDiT runtime parameters can be viewed in the command line using `-h`, and you can refer to this [usage](https://github.com/xdit-project/xDiT?tab=readme-ov-file#2-usage) example for more details.\n\nxDiT needs to be launched using `torchrun` to support its multi-node, multi-GPU parallel capabilities. For example, the following command can be used for 8-GPU parallel inference:\n\n```bash\ntorchrun --nproc_per_node=8 ./inference.py --model models/FLUX.1-dev --data_parallel_degree 2 --ulysses_degree 2 --ring_degree 2 --prompt \"A snowy mountain\" \"A small dog\" --num_inference_steps 50\n```\n\n## Supported models\n\nxDiT supports a subset of Diffusers models, such as Flux.1 and Stable Diffusion 3. 
The latest supported models can be found [here](https://github.com/xdit-project/xDiT?tab=readme-ov-file#-supported-dits).\n\n## Benchmark\n\nWe tested different models on various machines, and here is some of the benchmark data.\n\n### Flux.1-schnell\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/flux/Flux-2k-L40.png\">\n</div>\n\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/flux/Flux-2K-A100.png\">\n</div>\n\n### Stable Diffusion 3\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/sd3/L40-SD3.png\">\n</div>\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/sd3/A100-SD3.png\">\n</div>\n\n### HunyuanDiT\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/L40-HunyuanDiT.png\">\n</div>\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/V100-HunyuanDiT.png\">\n</div>\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/T4-HunyuanDiT.png\">\n</div>\n\nMore detailed performance metrics can be found on our [GitHub page](https://github.com/xdit-project/xDiT?tab=readme-ov-file#perf).\n\n## Reference\n\n[xDiT-project](https://github.com/xdit-project/xDiT)\n\n[USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://huggingface.co/papers/2405.07719)\n\n[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://huggingface.co/papers/2405.14430)"
  },
  {
    "path": "docs/source/en/optimization/xformers.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# xFormers\n\nWe recommend [xFormers](https://github.com/facebookresearch/xformers) for both inference and training. In our tests, the optimizations performed in the attention blocks allow for both faster speed and reduced memory consumption.\n\nInstall xFormers from `pip`:\n\n```bash\npip install xformers\n```\n\n> [!TIP]\n> The xFormers `pip` package requires the latest version of PyTorch. If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers).\n\nAfter xFormers is installed, you can use it with [`~ModelMixin.set_attention_backend`] as shown in the [Attention backends](./attention_backends) guide.\n\n> [!WARNING]\n> According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments.\n"
  },
  {
    "path": "docs/source/en/quantization/bitsandbytes.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n-->\n\n# bitsandbytes\n\n[bitsandbytes](https://huggingface.co/docs/bitsandbytes/index) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance.\n\n4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.\n\nThis guide demonstrates how quantization can enable running\n[FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)\non less than 16GB of VRAM and even on a free Google\nColab instance.\n\n![comparison image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/quant-bnb/comparison.png)\n\nTo use bitsandbytes, make sure you have the following libraries installed:\n\n```bash\npip install diffusers transformers accelerate bitsandbytes -U\n```\n\nNow you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixin.from_pretrained`]. 
This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.\n\n<hfoptions id=\"bnb\">\n<hfoption id=\"8-bit\">\n\nQuantizing a model in 8-bit halves the memory usage:\n\nbitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the\n[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].\n\nFor Ada and higher-series GPUs, we recommend changing `torch_dtype` to `torch.bfloat16`.\n\n> [!TIP]\n> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.\n\n```py\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig\nfrom transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig\nimport torch\nfrom diffusers import AutoModel\nfrom transformers import T5EncoderModel\n\nquant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)\n\ntext_encoder_2_8bit = T5EncoderModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"text_encoder_2\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)\n\ntransformer_8bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n```\n\nBy default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. 
You can change the data type of these modules with the `torch_dtype` parameter.\n\n```diff\ntransformer_8bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n+   torch_dtype=torch.float32,\n)\n```\n\nLet's generate an image using our quantized models.\n\nSetting `device_map=\"auto\"` automatically fills all available space on the GPU(s) first, then the\nCPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.\n\n```py\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    transformer=transformer_8bit,\n    text_encoder_2=text_encoder_2_8bit,\n    torch_dtype=torch.float16,\n    device_map=\"auto\",\n)\n\npipe_kwargs = {\n    \"prompt\": \"A cat holding a sign that says hello world\",\n    \"height\": 1024,\n    \"width\": 1024,\n    \"guidance_scale\": 3.5,\n    \"num_inference_steps\": 50,\n    \"max_sequence_length\": 512,\n}\n\nimage = pipe(**pipe_kwargs, generator=torch.manual_seed(0),).images[0]\n```\n\n<div class=\"flex justify-center\">\n   <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/quant-bnb/8bit.png\"/>\n</div>\n\nWhen there is enough memory, you can also directly move the pipeline to the GPU with `.to(\"cuda\")` and apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.\n\nOnce a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. 
You can also save the serialized 8-bit models locally with [`~ModelMixin.save_pretrained`].\n\n</hfoption>\n<hfoption id=\"4-bit\">\n\nQuantizing a model in 4-bit reduces your memory usage by 4x:\n\nbitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the\n[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].\n\nFor Ada and higher-series GPUs, we recommend changing `torch_dtype` to `torch.bfloat16`.\n\n> [!TIP]\n> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.\n\n```py\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig\nfrom transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig\nimport torch\nfrom diffusers import AutoModel\nfrom transformers import T5EncoderModel\n\nquant_config = TransformersBitsAndBytesConfig(load_in_4bit=True,)\n\ntext_encoder_2_4bit = T5EncoderModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"text_encoder_2\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True,)\n\ntransformer_4bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n```\n\nBy default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. 
You can change the data type of these modules with the `torch_dtype` parameter.\n\n```diff\ntransformer_4bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n+   torch_dtype=torch.float32,\n)\n```\n\nLet's generate an image using our quantized models.\n\nSetting `device_map=\"auto\"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.\n\n```py\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    transformer=transformer_4bit,\n    text_encoder_2=text_encoder_2_4bit,\n    torch_dtype=torch.float16,\n    device_map=\"auto\",\n)\n\npipe_kwargs = {\n    \"prompt\": \"A cat holding a sign that says hello world\",\n    \"height\": 1024,\n    \"width\": 1024,\n    \"guidance_scale\": 3.5,\n    \"num_inference_steps\": 50,\n    \"max_sequence_length\": 512,\n}\n\nimage = pipe(**pipe_kwargs, generator=torch.manual_seed(0),).images[0]\n```\n\n<div class=\"flex justify-center\">\n   <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/quant-bnb/4bit.png\"/>\n</div>\n\nWhen there is enough memory, you can also directly move the pipeline to the GPU with `.to(\"cuda\")` and apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.\n\nOnce a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. 
You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].\n\n</hfoption>\n</hfoptions>\n\n> [!WARNING]\n> Training with 8-bit and 4-bit weights is only supported for training *extra* parameters.\n\nCheck your memory footprint with the `get_memory_footprint` method:\n\n```py\nprint(model.get_memory_footprint())\n```\n\nNote that this only tells you the memory footprint of the model params and does _not_ estimate the inference memory requirements.\n\nQuantized models can be loaded with the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters:\n\n```py\nfrom diffusers import AutoModel\n\nmodel_4bit = AutoModel.from_pretrained(\n    \"hf-internal-testing/flux.1-dev-nf4-pkg\", subfolder=\"transformer\"\n)\n```\n\n## 8-bit (LLM.int8() algorithm)\n\n> [!TIP]\n> Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)!\n\nThis section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.\n\n### Outlier threshold\n\nAn \"outlier\" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. 
A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).\n\nTo find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]:\n\n```py\nfrom diffusers import AutoModel, BitsAndBytesConfig\n\nquantization_config = BitsAndBytesConfig(\n    load_in_8bit=True, llm_int8_threshold=10,\n)\n\nmodel_8bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quantization_config,\n)\n```\n\n### Skip module conversion\n\nFor some models, you don't need to quantize every module to 8-bit, which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]:\n\n```py\nfrom diffusers import SD3Transformer2DModel, BitsAndBytesConfig\n\nquantization_config = BitsAndBytesConfig(\n    load_in_8bit=True, llm_int8_skip_modules=[\"proj_out\"],\n)\n\nmodel_8bit = SD3Transformer2DModel.from_pretrained(\n    \"stabilityai/stable-diffusion-3-medium-diffusers\",\n    subfolder=\"transformer\",\n    quantization_config=quantization_config,\n)\n```\n\n\n## 4-bit (QLoRA algorithm)\n\n> [!TIP]\n> Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).\n\nThis section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization.\n\n\n### Compute data type\n\nTo speed up computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]:\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig\n\nquantization_config = 
BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)\n```\n\n### Normal Float 4 (NF4)\n\nNF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig\nfrom transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig\n\nfrom diffusers import AutoModel\nfrom transformers import T5EncoderModel\n\nquant_config = TransformersBitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_quant_type=\"nf4\",\n)\n\ntext_encoder_2_4bit = T5EncoderModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"text_encoder_2\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_quant_type=\"nf4\",\n)\n\ntransformer_4bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n```\n\nFor inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use matching `bnb_4bit_compute_dtype` and `torch_dtype` values.\n\n### Nested quantization\n\nNested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter. 
\n\n```py\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig\nfrom transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig\n\nfrom diffusers import AutoModel\nfrom transformers import T5EncoderModel\n\nquant_config = TransformersBitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_use_double_quant=True,\n)\n\ntext_encoder_2_4bit = T5EncoderModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"text_encoder_2\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_use_double_quant=True,\n)\n\ntransformer_4bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n```\n\n## Dequantizing `bitsandbytes` models\n\nOnce quantized, you can dequantize a model to its original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model. 
\n\n```python\nimport torch\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig\nfrom transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig\n\nfrom diffusers import AutoModel\nfrom transformers import T5EncoderModel\n\nquant_config = TransformersBitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_use_double_quant=True,\n)\n\ntext_encoder_2_4bit = T5EncoderModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"text_encoder_2\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\nquant_config = DiffusersBitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_use_double_quant=True,\n)\n\ntransformer_4bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\n\ntext_encoder_2_4bit.dequantize()\ntransformer_4bit.dequantize()\n```\n\n## torch.compile\n\nSpeed up inference with `torch.compile`. 
Make sure you have the latest `bitsandbytes` installed. We also recommend installing [PyTorch nightly](https://pytorch.org/get-started/locally/).\n\n<hfoptions id=\"bnb\">\n<hfoption id=\"8-bit\">\n\n```py\nimport torch\nfrom diffusers import AutoModel\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig\n\ntorch._dynamo.config.capture_dynamic_output_shape_ops = True\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)\ntransformer_8bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\ntransformer_8bit.compile(fullgraph=True)\n```\n\n</hfoption>\n<hfoption id=\"4-bit\">\n\n```py\nimport torch\nfrom diffusers import AutoModel\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig\n\nquant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)\ntransformer_4bit = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    quantization_config=quant_config,\n    torch_dtype=torch.float16,\n)\ntransformer_4bit.compile(fullgraph=True)\n```\n</hfoption>\n</hfoptions>\n\nOn an RTX 4090 with compilation, 4-bit Flux generation completed in 25.809 seconds versus 32.570 seconds without.\n\nCheck out the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) for more details.\n\n## Resources\n\n* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)\n* [Training](https://github.com/huggingface/diffusers/blob/8c661ea586bf11cb2440da740dd3c4cf84679b85/examples/dreambooth/README_hidream.md#using-quantization)"
  },
  {
    "path": "docs/source/en/quantization/gguf.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n-->\n\n# GGUF\n\nThe GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.\n\nThe following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.\n\nBefore starting please install gguf in your environment\n\n```shell\npip install -U gguf\n```\n\nSince GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].\n\nWhen using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. 
The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.\n\nThe functions used for dynamic dequantization are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the PyTorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).\n\n```python\nimport torch\n\nfrom diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig\n\nckpt_path = (\n    \"https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf\"\n)\ntransformer = FluxTransformer2DModel.from_single_file(\n    ckpt_path,\n    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),\n    torch_dtype=torch.bfloat16,\n)\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    transformer=transformer,\n    torch_dtype=torch.bfloat16,\n)\npipe.enable_model_cpu_offload()\nprompt = \"A cat holding a sign that says hello world\"\nimage = pipe(prompt, generator=torch.manual_seed(0)).images[0]\nimage.save(\"flux-gguf.png\")\n```\n\n## Using Optimized CUDA Kernels with GGUF\n\nOptimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:\n\n```shell\npip install -U kernels\n```\n\nOnce installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. 
To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.\n\n## Supported Quantization Types\n\n- BF16\n- Q4_0\n- Q4_1\n- Q5_0\n- Q5_1\n- Q8_0\n- Q2_K\n- Q3_K\n- Q4_K\n- Q5_K\n- Q6_K\n\n## Convert to GGUF\n\nUse the Space below to convert a Diffusers checkpoint into the GGUF format for inference.\n\n<iframe\n\tsrc=\"https://diffusers-internal-dev-diffusers-to-gguf.hf.space\"\n\tframeborder=\"0\"\n\twidth=\"850\"\n\theight=\"450\"\n></iframe>\n\n\n```py\nimport torch\n\nfrom diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig\n\nckpt_path = (\n    \"https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf\"\n)\ntransformer = FluxTransformer2DModel.from_single_file(\n    ckpt_path,\n    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),\n    config=\"black-forest-labs/FLUX.1-dev\",\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16,\n)\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    transformer=transformer,\n    torch_dtype=torch.bfloat16,\n)\npipe.enable_model_cpu_offload()\nprompt = \"A cat holding a sign that says hello world\"\nimage = pipe(prompt, generator=torch.manual_seed(0)).images[0]\nimage.save(\"flux-gguf.png\")\n```\n\nWhen using Diffusers-format GGUF checkpoints, you must provide the model `config` path. If the\nmodel config resides in a `subfolder`, specify that too."
  },
  {
    "path": "docs/source/en/quantization/modelopt.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# NVIDIA ModelOpt\n\n[NVIDIA-ModelOpt](https://github.com/NVIDIA/Model-Optimizer) is a unified library of state-of-the-art model optimization techniques like quantization, pruning, distillation, speculative decoding, etc. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed.\n\nBefore you begin, make sure you have nvidia_modelopt installed.\n\n```bash\npip install -U \"nvidia_modelopt[hf]\"\n```\n\nQuantize a model by passing [`NVIDIAModelOptConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). 
This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.\n\nThe example below only quantizes the weights to FP8.\n\n```python\nimport torch\nfrom diffusers import AutoModel, SanaPipeline, NVIDIAModelOptConfig\n\nmodel_id = \"Efficient-Large-Model/Sana_600M_1024px_diffusers\"\ndtype = torch.bfloat16\n\nquantization_config = NVIDIAModelOptConfig(quant_type=\"FP8\", quant_method=\"modelopt\")\ntransformer = AutoModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    quantization_config=quantization_config,\n    torch_dtype=dtype,\n)\npipe = SanaPipeline.from_pretrained(\n    model_id,\n    transformer=transformer,\n    torch_dtype=dtype,\n)\npipe.to(\"cuda\")\n\nprint(f\"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB\")\n\nprompt = \"A cat holding a sign that says hello world\"\nimage = pipe(\n    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512\n).images[0]\nimage.save(\"output.png\")\n```\n\n> **Note:**\n>\n> The quantization methods in NVIDIA-ModelOpt are designed to reduce the memory footprint of model weights using various QAT (Quantization-Aware Training) and PTQ (Post-Training Quantization) techniques while maintaining model performance. However, the actual performance gain during inference depends on the deployment framework (e.g., TRT-LLM, TensorRT) and the specific hardware configuration.  \n> \n> More details can be found [here](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples).\n\n## NVIDIAModelOptConfig\n\nThe `NVIDIAModelOptConfig` class accepts the following parameters:\n- `quant_type`: A string naming one of the quantization types below.\n- `modules_to_not_convert`: A list of full or partial module names for which quantization should not be performed. 
For example, to not perform any quantization of the [`SD3Transformer2DModel`]'s pos_embed projection blocks, one would specify: `modules_to_not_convert=[\"pos_embed.proj.weight\"]`.\n- `disable_conv_quantization`: A boolean value which when set to `True` disables quantization for all convolutional layers in the model. This is useful as channel and block quantization generally don't work well with convolutional layers (used with INT4, NF4, NVFP4). If you want to disable quantization for specific convolutional layers, use `modules_to_not_convert` instead.\n- `algorithm`: The algorithm to use for determining scale, defaults to `\"max\"`. You can check modelopt documentation for more algorithms and details.\n- `forward_loop`: The forward loop function to use for calibrating activation during quantization. If not provided, it relies on static scale values computed using the weights only.\n- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.\n\n## Supported quantization types\n\nModelOpt supports weight-only, channel and block quantization int8, fp8, int4, nf4, and nvfp4. The quantization methods are designed to reduce the memory footprint of the model weights while maintaining the performance of the model during inference.\n\nWeight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. 
This lowers the memory requirements from model weights but retains the memory peaks for activation computation.\n\nThe quantization methods supported are as follows:\n\n| **Quantization Type** | **Supported Schemes** | **Required Kwargs** | **Additional Notes** |\n|-----------------------|-----------------------|---------------------|----------------------|\n| **INT8** | `int8 weight only`, `int8 channel quantization`, `int8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |\n| **FP8** | `fp8 weight only`, `fp8 channel quantization`, `fp8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |\n| **INT4** | `int4 weight only`, `int4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now` |\n| **NF4** | `nf4 weight only`, `nf4 double block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize + scale_channel_quantize + scale_block_quantize` | `channel_quantize = -1 and scale_channel_quantize = -1 are only supported for now` |\n| **NVFP4** | `nvfp4 weight only`, `nvfp4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now` |\n\n\nRefer to the [official modelopt documentation](https://nvidia.github.io/Model-Optimizer/) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.\n\n## Serializing and Deserializing quantized models\n\nTo serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.\n\n```python\nimport torch\nfrom diffusers import AutoModel, NVIDIAModelOptConfig\nfrom modelopt.torch.opt import 
enable_huggingface_checkpointing\n\nenable_huggingface_checkpointing()\n\nmodel_id = \"Efficient-Large-Model/Sana_600M_1024px_diffusers\"\nquant_config_fp8 = {\"quant_type\": \"FP8\", \"quant_method\": \"modelopt\"}\nquant_config_fp8 = NVIDIAModelOptConfig(**quant_config_fp8)\nmodel = AutoModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    quantization_config=quant_config_fp8,\n    torch_dtype=torch.bfloat16,\n)\nmodel.save_pretrained('path/to/sana_fp8', safe_serialization=False)\n```\n\nTo load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.\n\n```python\nimport torch\nfrom diffusers import AutoModel, NVIDIAModelOptConfig, SanaPipeline\nfrom modelopt.torch.opt import enable_huggingface_checkpointing\n\nenable_huggingface_checkpointing()\n\nquantization_config = NVIDIAModelOptConfig(quant_type=\"FP8\", quant_method=\"modelopt\")\ntransformer = AutoModel.from_pretrained(\n    \"path/to/sana_fp8\",\n    subfolder=\"transformer\",\n    quantization_config=quantization_config,\n    torch_dtype=torch.bfloat16,\n)\npipe = SanaPipeline.from_pretrained(\n    \"Efficient-Large-Model/Sana_600M_1024px_diffusers\",\n    transformer=transformer,\n    torch_dtype=torch.bfloat16,\n)\npipe.to(\"cuda\")\nprompt = \"A cat holding a sign that says hello world\"\nimage = pipe(\n    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512\n).images[0]\nimage.save(\"output.png\")\n```\n"
  },
  {
    "path": "docs/source/en/quantization/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n-->\n\n# Getting started\n\nQuantization focuses on representing data with fewer bits while also trying to preserve the precision of the original data. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits.\n\nDiffusers supports multiple quantization backends to make large diffusion models like [Flux](../api/pipelines/flux) more accessible. This guide shows how to use the [`~quantizers.PipelineQuantizationConfig`] class to quantize a pipeline during its initialization from a pretrained or non-quantized checkpoint.\n\n## Pipeline-level quantization\n\nThere are two ways to use [`~quantizers.PipelineQuantizationConfig`] depending on how much customization you want to apply to the quantization configuration. 
\n\n- for basic use cases, define the `quant_backend`, `quant_kwargs`, and `components_to_quantize` arguments\n- for granular quantization control, define a `quant_mapping` that provides the quantization configuration for individual model components\n\n### Basic quantization\n\nInitialize [`~quantizers.PipelineQuantizationConfig`] with the following parameters.\n\n- `quant_backend` specifies which quantization backend to use. Currently supported backends include: `bitsandbytes_4bit`, `bitsandbytes_8bit`, `gguf`, `quanto`, and `torchao`.\n- `quant_kwargs` specifies the quantization arguments to use.\n\n> [!TIP]\n> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.\n\n- `components_to_quantize` specifies which component(s) of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. 
The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.\n\n   `components_to_quantize` accepts either a list for multiple models or a string for a single model.\n\nThe example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\n\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": \"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n    components_to_quantize=[\"transformer\", \"text_encoder_2\"],\n)\n```\n\nPass the `pipeline_quant_config` to [`~DiffusionPipeline.from_pretrained`] to quantize the pipeline.\n\n```py\npipe = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nimage = pipe(\"photo of a cute dog\").images[0]\n```\n\n\n### Advanced quantization\n\nThe `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.\n\nInitialize [`~quantizers.PipelineQuantizationConfig`] and pass a `quant_mapping` to it. 
The `quant_mapping` allows you to specify the quantization options for each component in the pipeline such as the transformer and text encoder.\n\nThe example below uses two quantization backends, [`~quantizers.quantization_config.QuantoConfig`] and [`transformers.BitsAndBytesConfig`], for the transformer and text encoder.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.quantizers.quantization_config import QuantoConfig\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig\n\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_mapping={\n        \"transformer\": QuantoConfig(weights_dtype=\"int8\"),\n        \"text_encoder_2\": TransformersBitsAndBytesConfig(\n            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16\n        ),\n    }\n)\n```\n\nThere is a separate bitsandbytes backend in [Transformers](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig). You need to import and use [`transformers.BitsAndBytesConfig`] for components that come from Transformers. 
For example, `text_encoder_2` in [`FluxPipeline`] is a [`~transformers.T5EncoderModel`] from Transformers so you need to use [`transformers.BitsAndBytesConfig`] instead of [`diffusers.BitsAndBytesConfig`].\n\n> [!TIP]\n> Use the [basic quantization](#basic-quantization) method above if you don't want to manage these distinct imports or aren't sure where each pipeline component comes from.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig\n\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_mapping={\n        \"transformer\": DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),\n        \"text_encoder_2\": TransformersBitsAndBytesConfig(\n            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16\n        ),\n    }\n)\n```\n\nPass the `pipeline_quant_config` to [`~DiffusionPipeline.from_pretrained`] to quantize the pipeline.\n\n```py\npipe = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nimage = pipe(\"photo of a cute dog\").images[0]\n```\n\n## Resources\n\nCheck out the resources below to learn more about quantization.\n\n- If you are new to quantization, we recommend checking out the following beginner-friendly courses in collaboration with DeepLearning.AI.\n\n    - [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)\n    - [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)\n\n- Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) if you're interested in adding a new 
quantization method.\n\n- The Transformers quantization [Overview](https://huggingface.co/docs/transformers/quantization/overview#when-to-use-what) provides an overview of the pros and cons of different quantization backends.\n\n- Read the [Exploring Quantization Backends in Diffusers](https://huggingface.co/blog/diffusers-quantization) blog post for a brief introduction to each quantization backend, how to choose a backend, and combining quantization with other memory optimizations.\n"
  },
  {
    "path": "docs/source/en/quantization/quanto.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n\n-->\n\n# Quanto\n\n[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:\n\n- All features are available in eager mode (works with non-traceable models)\n- Supports quantization aware training\n- Quantized models are compatible with `torch.compile`\n- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU)\n\nIn order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`\n\n```shell\npip install optimum-quanto accelerate\n```\n\nNow you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.   
\n\n```python\nimport torch\nfrom diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig\n\nmodel_id = \"black-forest-labs/FLUX.1-dev\"\nquantization_config = QuantoConfig(weights_dtype=\"float8\")\ntransformer = FluxTransformer2DModel.from_pretrained(\n      model_id,\n      subfolder=\"transformer\",\n      quantization_config=quantization_config,\n      torch_dtype=torch.bfloat16,\n)\n\npipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\nprompt = \"A cat holding a sign that says hello world\"\nimage = pipe(\n    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512\n).images[0]\nimage.save(\"output.png\")\n```\n\n## Skipping Quantization on specific modules\n\nIt is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict`.\n\n```python\nimport torch\nfrom diffusers import FluxTransformer2DModel, QuantoConfig\n\nmodel_id = \"black-forest-labs/FLUX.1-dev\"\nquantization_config = QuantoConfig(weights_dtype=\"float8\", modules_to_not_convert=[\"proj_out\"])\ntransformer = FluxTransformer2DModel.from_pretrained(\n      model_id,\n      subfolder=\"transformer\",\n      quantization_config=quantization_config,\n      torch_dtype=torch.bfloat16,\n)\n```\n\n## Using `from_single_file` with the Quanto Backend\n\n`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`. 
\n\n```python\nimport torch\nfrom diffusers import FluxTransformer2DModel, QuantoConfig\n\nckpt_path = \"https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors\"\nquantization_config = QuantoConfig(weights_dtype=\"float8\")\ntransformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)\n```\n\n## Saving Quantized models\n\nDiffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.\n\nThe serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized\nwith Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`\n\n```python\nimport torch\nfrom diffusers import FluxTransformer2DModel, QuantoConfig\n\nmodel_id = \"black-forest-labs/FLUX.1-dev\"\nquantization_config = QuantoConfig(weights_dtype=\"float8\")\ntransformer = FluxTransformer2DModel.from_pretrained(\n      model_id,\n      subfolder=\"transformer\",\n      quantization_config=quantization_config,\n      torch_dtype=torch.bfloat16,\n)\n# save quantized model to reuse\ntransformer.save_pretrained(\"<your quantized model save path>\")\n\n# you can reload your quantized model with\nmodel = FluxTransformer2DModel.from_pretrained(\"<your quantized model save path>\")\n```\n\n## Using `torch.compile` with Quanto\n\nCurrently the Quanto backend supports `torch.compile` for the following quantization types:\n\n- `int8` weights \n\n```python\nimport torch\nfrom diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig\n\nmodel_id = \"black-forest-labs/FLUX.1-dev\"\nquantization_config = QuantoConfig(weights_dtype=\"int8\")\ntransformer = FluxTransformer2DModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    quantization_config=quantization_config,\n    
torch_dtype=torch.bfloat16,\n)\ntransformer = torch.compile(transformer, mode=\"max-autotune\", fullgraph=True)\n\npipe = FluxPipeline.from_pretrained(\n    model_id, transformer=transformer, torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\nimage = pipe(\"A cat holding a sign that says hello\").images[0]\nimage.save(\"flux-quanto-compile.png\")\n```\n\n## Supported Quantization Types\n\n### Weights\n\n- float8\n- int8\n- int4\n- int2\n\n\n"
  },
  {
    "path": "docs/source/en/quantization/torchao.md",
    "content": "<!-- Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License. -->\n\n# torchao\n\n[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.\n\nMake sure Pytorch 2.5+ and torchao are installed with the command below.\n\n```bash\nuv pip install -U torch torchao\n```\n\nEach quantization dtype is available as a separate instance of a [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. 
This provides more flexible configuration options by exposing more available arguments.\n\nPass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig) to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`].\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig\nfrom torchao.quantization import Int8WeightOnlyConfig\n\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_mapping={\"transformer\": TorchAoConfig(Int8WeightOnlyConfig(group_size=128))}\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n    device_map=\"cuda\"\n)\n```\n\nFor simple use cases, you could also provide a string identifier in [`TorchAoConfig`] as shown below.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig\n\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_mapping={\"transformer\": TorchAoConfig(\"int8wo\")}\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n    device_map=\"cuda\"\n)\n```\n\n## torch.compile\n\ntorchao supports [torch.compile](../optimization/fp16#torchcompile) which can speed up inference with one line of code.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig\nfrom torchao.quantization import Int4WeightOnlyConfig\n\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_mapping={\"transformer\": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n    
device_map=\"cuda\"\n)\n\npipeline.transformer.compile(mode=\"max-autotune\", fullgraph=True)\n```\n\nRefer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).\n\n> [!TIP]\n> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.\n\n## Supported quantization types\n\ntorchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.\n\nWeight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.\n\nDynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. 
However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.\n\nThe quantization methods supported are as follows:\n\n| **Category** | **Full Function Names** | **Shorthands** |\n|--------------|-------------------------|----------------|\n| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |\n| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |\n| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |\n| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |\n\nSome quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). 
This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.\n\nRefer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.\n\n## Serializing and Deserializing quantized models\n\nTo serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.\n\n```python\nimport torch\nfrom diffusers import AutoModel, TorchAoConfig\n\nquantization_config = TorchAoConfig(\"int8wo\")\ntransformer = AutoModel.from_pretrained(\n    \"black-forest-labs/Flux.1-Dev\",\n    subfolder=\"transformer\",\n    quantization_config=quantization_config,\n    torch_dtype=torch.bfloat16,\n)\ntransformer.save_pretrained(\"/path/to/flux_int8wo\", safe_serialization=False)\n```\n\nTo load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.\n\n```python\nimport torch\nfrom diffusers import FluxPipeline, AutoModel\n\ntransformer = AutoModel.from_pretrained(\"/path/to/flux_int8wo\", torch_dtype=torch.bfloat16, use_safetensors=False)\npipe = FluxPipeline.from_pretrained(\"black-forest-labs/Flux.1-Dev\", transformer=transformer, torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\nprompt = \"A cat holding a sign that says hello world\"\nimage = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]\nimage.save(\"output.png\")\n```\n\nIf you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. 
Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trusted source.\n\n```python\nimport torch\nfrom accelerate import init_empty_weights\nfrom diffusers import FluxPipeline, AutoModel, TorchAoConfig\n\n# Serialize the model\ntransformer = AutoModel.from_pretrained(\n    \"black-forest-labs/Flux.1-Dev\",\n    subfolder=\"transformer\",\n    quantization_config=TorchAoConfig(\"uint4wo\"),\n    torch_dtype=torch.bfloat16,\n)\ntransformer.save_pretrained(\"/path/to/flux_uint4wo\", safe_serialization=False, max_shard_size=\"50GB\")\n# ...\n\n# Load the model\nstate_dict = torch.load(\"/path/to/flux_uint4wo/diffusion_pytorch_model.bin\", weights_only=False, map_location=\"cpu\")\nwith init_empty_weights():\n    transformer = AutoModel.from_config(\"/path/to/flux_uint4wo/config.json\")\ntransformer.load_state_dict(state_dict, strict=True, assign=True)\n```\n\n> [!TIP]\n> The [`AutoModel`] API is supported for PyTorch >= 2.6 as shown in the examples above.\n\n## Resources\n\n- [TorchAO Quantization API](https://docs.pytorch.org/ao/stable/index.html)\n- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)\n"
  },
  {
    "path": "docs/source/en/quicktour.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Quickstart\n\nDiffusers is a library for developers and researchers that provides an easy inference API for generating images, videos and audio, as well as the building blocks for implementing new workflows.\n\nDiffusers provides many optimizations out-of-the-box that make it possible to load and run large models on setups with limited memory or to accelerate inference.\n\nThis Quickstart will give you an overview of Diffusers and get you up and generating quickly.\n\n> [!TIP]\n> Before you begin, make sure you have a Hugging Face [account](https://huggingface.co/join) in order to use gated models like [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev).\n\nFollow the [Installation](./installation) guide to install Diffusers if it's not already installed.\n\n## DiffusionPipeline\n\nA diffusion model combines multiple components to generate outputs in any modality based on an input, such as a text description, image or both.\n\nFor a standard text-to-image model:\n\n1. A text encoder turns a prompt into embeddings that guide the denoising process. Some models have more than one text encoder.\n2. A scheduler contains the algorithmic specifics for gradually denoising initial random noise into clean outputs. Different schedulers affect generation speed and quality.\n3. 
A UNet or diffusion transformer (DiT) is the workhorse of a diffusion model.\n\n  At each step, it performs the denoising predictions, such as how much noise to remove or the general direction in which to steer the noise to generate better quality outputs.\n\n  The UNet or DiT repeats this loop for a set amount of steps to generate the final output.\n  \n4. A variational autoencoder (VAE) encodes and decodes pixels to a spatially compressed latent-space. *Latents* are compressed representations of an image and are more efficient to work with. The UNet or DiT operates on latents, and the clean latents at the end are decoded back into images.\n\nThe [`DiffusionPipeline`] packages all these components into a single class for inference. There are several arguments in [`~DiffusionPipeline.__call__`] you can change, such as `num_inference_steps`, that affect the diffusion process. Try different values and arguments to see how they change generation quality or speed.\n\nLoad a model with [`~DiffusionPipeline.from_pretrained`] and describe what you'd like to generate. 
The example below uses the default argument values.\n\n<hfoptions id=\"diffusionpipeline\">\n<hfoption id=\"text-to-image\">\n\nUse `.images[0]` to access the generated image output.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\npipeline(prompt).images[0]\n```\n\n</hfoption>\n<hfoption id=\"text-to-video\">\n\nUse `.frames[0]` to access the generated video output and [`~utils.export_to_video`] to save the video.\n\n```py\nimport torch\nfrom diffusers import AutoencoderKLWan, DiffusionPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom diffusers.utils import export_to_video\n\nvae = AutoencoderKLWan.from_pretrained(\n  \"Wan-AI/Wan2.2-T2V-A14B-Diffusers\",\n  subfolder=\"vae\",\n  torch_dtype=torch.float32\n)\npipeline = DiffusionPipeline.from_pretrained(\n  \"Wan-AI/Wan2.2-T2V-A14B-Diffusers\",\n  vae=vae,\n  torch_dtype=torch.bfloat16,\n  device_map=\"cuda\"\n)\n\nprompt = \"\"\"\nCinematic video of a sleek cat lounging on a colorful inflatable in a crystal-clear turquoise pool in Palm Springs, \nsipping a salt-rimmed margarita through a straw. Golden-hour sunlight glows over mid-century modern homes and swaying palms. \nShot in rich Sony a7S III: with moody, glamorous color grading, subtle lens flares, and soft vintage film grain. 
\nRipples shimmer as a warm desert breeze stirs the water, blending luxury and playful charm in an epic, gorgeously composed frame.\n\"\"\"\nvideo = pipeline(prompt=prompt, num_frames=81, num_inference_steps=40).frames[0]\nexport_to_video(video, \"output.mp4\", fps=16)\n```\n\n</hfoption>\n</hfoptions>\n\n## LoRA\n\nAdapters insert a small number of trainable parameters into the original base model. Only the inserted parameters are fine-tuned while the rest of the model weights remain frozen. This makes it fast and cheap to fine-tune a model on a new style. Among adapters, [LoRAs](./tutorials/using_peft_for_inference) are the most popular.\n\nAdd a LoRA to a pipeline with the [`~loaders.QwenImageLoraLoaderMixin.load_lora_weights`] method. Some LoRAs require a special trigger word, such as `Realism` in the example below. Check a LoRA's model card to see if it requires a trigger word.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\npipeline.load_lora_weights(\n  \"flymy-ai/qwen-image-realism-lora\",\n)\n\nprompt = \"\"\"\nsuper Realism cinematic film still of a cat sipping a margarita in a pool in Palm Springs in the style of umempart, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\npipeline(prompt).images[0]\n```\n\nCheck out the [LoRA](./tutorials/using_peft_for_inference) docs or Adapters section to learn more.\n\n## Quantization\n\n[Quantization](./quantization/overview) stores data in fewer bits to reduce memory usage. It may also speed up inference because it takes less time to perform calculations with fewer bits.\n\nDiffusers provides several quantization backends and picking one depends on your use case. 
For example, [bitsandbytes](./quantization/bitsandbytes) and [torchao](./quantization/torchao) are both simple and easy to use for inference, but torchao supports more [quantization types](./quantization/torchao#supported-quantization-types) like fp8.\n\nConfigure [`PipelineQuantizationConfig`] with the backend to use, the specific arguments (refer to the [API](./api/quantization) reference for available arguments) for that backend, and which components to quantize. The example below quantizes the model to 4-bits and only uses 14.93GB of memory.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\n\nquant_config = PipelineQuantizationConfig(\n  quant_backend=\"bitsandbytes_4bit\",\n  quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": \"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n  components_to_quantize=[\"transformer\", \"text_encoder\"],\n)\npipeline = DiffusionPipeline.from_pretrained(\n  \"Qwen/Qwen-Image\",\n  torch_dtype=torch.bfloat16,\n  quantization_config=quant_config,\n  device_map=\"cuda\"\n)\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\npipeline(prompt).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\nTake a look at the [Quantization](./quantization/overview) section for more details.\n\n## Optimizations\n\n> [!TIP]\n> Optimization is dependent on hardware specs such as memory. Use this [Space](https://huggingface.co/spaces/diffusers/optimized-diffusers-code) to generate code examples that include all of Diffusers' available memory and speed optimization techniques for any model you're using.\n\nModern diffusion models are very large and have billions of parameters. The iterative denoising process is also computationally intensive and slow. 
Diffusers provides techniques for reducing memory usage and boosting inference speed. These techniques can be combined with quantization to optimize for both memory usage and inference speed.\n\n### Memory usage\n\nThe text encoders and UNet or DiT can use up as much as ~30GB of memory, exceeding the amount available on many free-tier or consumer GPUs.\n\nOffloading stores weights that aren't currently used on the CPU and only moves them to the GPU when they're needed. There are a few offloading types and the example below uses [model offloading](./optimization/memory#model-offloading). This moves an entire model, like a text encoder or transformer, to the CPU when it isn't actively being used.\n\nCall [`~DiffusionPipeline.enable_model_cpu_offload`] to activate it. By combining quantization and offloading, the following example only requires ~12.54GB of memory.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\n\nquant_config = PipelineQuantizationConfig(\n  quant_backend=\"bitsandbytes_4bit\",\n  quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": \"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n  components_to_quantize=[\"transformer\", \"text_encoder\"],\n)\npipeline = DiffusionPipeline.from_pretrained(\n  \"Qwen/Qwen-Image\",\n  torch_dtype=torch.bfloat16,\n  quantization_config=quant_config,\n  device_map=\"cuda\"\n)\npipeline.enable_model_cpu_offload()\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\npipeline(prompt).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\nRefer to the [Reduce memory usage](./optimization/memory) docs to learn more about other memory reducing techniques.\n\n### Inference speed\n\nThe denoising loop performs a lot of computations 
and can be slow. Methods like [torch.compile](./optimization/fp16#torchcompile) increase inference speed by compiling the computations into an optimized kernel. Compilation is slow for the first generation but successive generations should be much faster.\n\nThe example below uses [regional compilation](./optimization/fp16#regional-compilation) to only compile small regions of a model. It reduces cold-start latency while also providing a runtime speed up.\n\nCall [`~ModelMixin.compile_repeated_blocks`] on the model to activate it.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n\npipeline.transformer.compile_repeated_blocks(\n    fullgraph=True,\n)\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\npipeline(prompt).images[0]\n```\n\nCheck out the [Accelerate inference](./optimization/fp16) or [Caching](./optimization/cache) docs for more methods that speed up inference."
  },
  {
    "path": "docs/source/en/stable_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Basic performance\n\nDiffusion is a random process that is computationally demanding. You may need to run the [`DiffusionPipeline`] several times before getting a desired output. That's why it's important to carefully balance generation speed and memory usage in order to iterate faster.\n\nThis guide recommends some basic performance tips for using the [`DiffusionPipeline`]. 
Refer to the Inference Optimization section docs such as [Accelerate inference](./optimization/fp16) or [Reduce memory usage](./optimization/memory) for more detailed performance guides.\n\n## Memory usage\n\nReducing the amount of memory used indirectly speeds up generation and can help a model fit on device.\n\nThe [`~DiffusionPipeline.enable_model_cpu_offload`] method moves a model to the CPU when it is not in use to save GPU memory.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.bfloat16,\n  device_map=\"cuda\"\n)\npipeline.enable_model_cpu_offload()\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\npipeline(prompt).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n## Inference speed\n\nDenoising is the most computationally demanding process during diffusion. Methods that optimize this process accelerate inference. Try the following methods for a speed up.\n\n- Add `device_map=\"cuda\"` to place the pipeline on a GPU. Placing a model on an accelerator, like a GPU, increases speed because it performs computations in parallel.\n- Set `torch_dtype=torch.bfloat16` to execute the pipeline in half-precision. 
Reducing the data type precision increases speed because it takes less time to perform computations in a lower precision.\n\n```py\nimport torch\nimport time\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.bfloat16,\n  device_map=\"cuda\"\n)\n```\n\n- Use a faster scheduler, such as [`DPMSolverMultistepScheduler`], which only requires ~20-25 steps.\n- Set `num_inference_steps` to a lower value. Reducing the number of inference steps reduces the overall number of computations. However, this can result in lower generation quality.\n\n```py\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\n\nstart_time = time.perf_counter()\nimage = pipeline(prompt).images[0]\nend_time = time.perf_counter()\n\nprint(f\"Image generation took {end_time - start_time:.3f} seconds\")\n```\n\n## Generation quality\n\nMany modern diffusion models deliver high-quality images out-of-the-box. However, you can still improve generation quality by trying the following.\n\n- Try a more detailed and descriptive prompt. Include details such as the image medium, subject, style, and aesthetic. 
A negative prompt may also help by guiding a model away from undesirable features by using words like low quality or blurry.\n\n    ```py\n    import torch\n    from diffusers import DiffusionPipeline\n\n    pipeline = DiffusionPipeline.from_pretrained(\n        \"stabilityai/stable-diffusion-xl-base-1.0\",\n        torch_dtype=torch.bfloat16,\n        device_map=\"cuda\"\n    )\n\n    prompt = \"\"\"\n    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\n    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n    \"\"\"\n    negative_prompt = \"low quality, blurry, ugly, poor details\"\n    pipeline(prompt, negative_prompt=negative_prompt).images[0]\n    ```\n\n    For more details about creating better prompts, take a look at the [Prompt techniques](./using-diffusers/weighted_prompts) doc.\n\n- Try a different scheduler, like [`HeunDiscreteScheduler`] or [`LMSDiscreteScheduler`], that gives up generation speed for quality.\n\n    ```py\n    import torch\n    from diffusers import DiffusionPipeline, HeunDiscreteScheduler\n\n    pipeline = DiffusionPipeline.from_pretrained(\n        \"stabilityai/stable-diffusion-xl-base-1.0\",\n        torch_dtype=torch.bfloat16,\n        device_map=\"cuda\"\n    )\n    pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)\n\n    prompt = \"\"\"\n    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\n    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n    \"\"\"\n    negative_prompt = \"low quality, blurry, ugly, poor details\"\n    pipeline(prompt, negative_prompt=negative_prompt).images[0]\n    ```\n\n## Next steps\n\nDiffusers offers more advanced and powerful optimizations such as [group-offloading](./optimization/memory#group-offloading) and [regional compilation](./optimization/fp16#regional-compilation). 
To learn more about how to maximize performance, take a look at the Inference Optimization section."
  },
  {
    "path": "docs/source/en/training/adapt_a_model.md",
    "content": "# Adapt a model to a new task\n\nMany diffusion systems share the same components, allowing you to adapt a pretrained model for one task to an entirely different task.\n\nThis guide will show you how to adapt a pretrained text-to-image model for inpainting by initializing and modifying the architecture of a pretrained [`UNet2DConditionModel`].\n\n## Configure UNet2DConditionModel parameters\n\nA [`UNet2DConditionModel`] by default accepts 4 channels in the [input sample](https://huggingface.co/docs/diffusers/v0.16.0/en/api/models#diffusers.UNet2DConditionModel.in_channels). For example, load a pretrained text-to-image model like [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) and take a look at the number of `in_channels`:\n\n```py\nfrom diffusers import StableDiffusionPipeline\n\npipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\npipeline.unet.config[\"in_channels\"]\n4\n```\n\nInpainting requires 9 channels in the input sample. You can check this value in a pretrained inpainting model like [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting):\n\n```py\nfrom diffusers import StableDiffusionPipeline\n\npipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-inpainting\", use_safetensors=True)\npipeline.unet.config[\"in_channels\"]\n9\n```\n\nTo adapt your text-to-image model for inpainting, you'll need to change the number of `in_channels` from 4 to 9.\n\nInitialize a [`UNet2DConditionModel`] with the pretrained text-to-image model weights, and change `in_channels` to 9. 
Changing the number of `in_channels` means you need to set `ignore_mismatched_sizes=True` and `low_cpu_mem_usage=False` to avoid a size mismatch error because the shape is different now.\n\n```py\nfrom diffusers import AutoModel\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nunet = AutoModel.from_pretrained(\n    model_id,\n    subfolder=\"unet\",\n    in_channels=9,\n    low_cpu_mem_usage=False,\n    ignore_mismatched_sizes=True,\n    use_safetensors=True,\n)\n```\n\nThe pretrained weights of the other components from the text-to-image model are initialized from their checkpoints, but the input channel weights (`conv_in.weight`) of the `unet` are randomly initialized. It is important to finetune the model for inpainting because otherwise the model returns noise.\n"
  },
  {
    "path": "docs/source/en/training/cogvideox.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n# CogVideoX\n\nCogVideoX is a text-to-video generation model focused on creating more coherent videos aligned with a prompt. It achieves this using several methods.\n\n- a 3D variational autoencoder that compresses videos spatially and temporally, improving compression rate and video accuracy.\n\n- an expert transformer block to help align text and video, and a 3D full attention module for capturing and creating spatially and temporally accurate videos.\n\nTesting across the video instruction dimensions found that CogVideoX performs well on consistent theme, dynamic information, consistent background, object information, smooth motion, color, scene, appearance style, and temporal style, but performs poorly on human action, spatial relationship, and multiple objects.\n\nFinetuning with Diffusers can help make up for these poor results.\n\n## Data Preparation\n\nThe training script accepts data in two formats.\n\nThe first format is suited for small-scale training, and the second format uses a CSV file, which is more appropriate for streaming data for large-scale training. 
In the future, Diffusers will support the `<Video>` tag.\n\n### Small format\n\nTwo files where one file contains line-separated prompts and another file contains line-separated paths to video data (the path to video files must be relative to the path you pass when specifying `--instance_data_root`). Let's take a look at an example to understand this better!\n\nAssume you've specified `--instance_data_root` as `/dataset`, and that this directory contains the files: `prompts.txt` and `videos.txt`.\n\nThe `prompts.txt` file should contain line-separated prompts:\n\n```\nA black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.\nA black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.\n...\n```\n\nThe `videos.txt` file should contain line-separated paths to video files. Note that the path should be _relative_ to the `--instance_data_root` directory.\n\n```\nvideos/00000.mp4\nvideos/00001.mp4\n...\n```\n\nOverall, this is what your dataset would look like if you ran the `tree` command on the dataset root directory:\n\n```\n/dataset\n├── prompts.txt\n├── videos.txt\n├── videos\n    ├── videos/00000.mp4\n    ├── videos/00001.mp4\n    ├── ...\n```\n\nWhen using this format, the `--caption_column` must be `prompts.txt` and `--video_column` must be `videos.txt`.\n\n### Stream format\n\nYou could use a single CSV file. 
For the sake of this example, assume you have a `metadata.csv` file. The expected format is:\n\n```\n<CAPTION_COLUMN>,<PATH_TO_VIDEO_COLUMN>\n\"\"\"A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.\"\"\",\"\"\"00000.mp4\"\"\"\n\"\"\"A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.\"\"\",\"\"\"00001.mp4\"\"\"\n...\n```\n\nIn this case, the `--instance_data_root` should be the location where the videos are stored and `--dataset_name` should be either a path to local folder or a [`~datasets.load_dataset`] compatible dataset hosted on the Hub. Assuming you have videos of Minecraft gameplay at `https://huggingface.co/datasets/my-awesome-username/minecraft-videos`, you would have to specify `my-awesome-username/minecraft-videos`.\n\nWhen using this format, the `--caption_column` must be `<CAPTION_COLUMN>` and `--video_column` must be `<PATH_TO_VIDEO_COLUMN>`.\n\nYou are not strictly restricted to the CSV format. Any format works as long as the `load_dataset` method supports the file format to load a basic `<PATH_TO_VIDEO_COLUMN>` and `<CAPTION_COLUMN>`. 
The reason for going through these dataset organization gymnastics for loading video data is that `load_dataset` does not fully support all kinds of video formats.\n\n> [!NOTE]\n> CogVideoX works best with long and descriptive LLM-augmented prompts for video generation. We recommend pre-processing your videos by first generating a summary using a VLM and then augmenting the prompts with an LLM. To generate the above captions, we use [MiniCPM-V-26](https://huggingface.co/openbmb/MiniCPM-V-2_6) and [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct). A very barebones and no-frills example for this is available [here](https://gist.github.com/a-r-r-o-w/4dee20250e82f4e44690a02351324a4a). The official recommendation for augmenting prompts is [ChatGLM](https://huggingface.co/THUDM?search_models=chatglm) and a length of 50-100 words is considered good.\n\n> [!NOTE]\n> It is expected that your dataset is already pre-processed. If not, some basic pre-processing can be done by playing with the following parameters:\n> `--height`, `--width`, `--fps`, `--max_num_frames`, `--skip_frames_start` and `--skip_frames_end`.\n> Presently, all videos in your dataset should contain the same number of video frames when using a training batch size > 1.\n\n<!-- TODO: Implement frame packing in future to address above issue. -->\n\n## Training\n\nYou need to set up your development environment by installing the necessary requirements. 
The following packages are required:\n- Torch 2.0 or above based on the training features you are utilizing (might require latest or nightly versions for quantized/deepspeed training)\n- `pip install diffusers transformers accelerate peft huggingface_hub` for all things modeling and training related\n- `pip install datasets decord` for loading video training data\n- `pip install bitsandbytes` for using 8-bit Adam or AdamW optimizers for memory-optimized training\n- `pip install wandb` optionally for monitoring training logs\n- `pip install deepspeed` optionally for [DeepSpeed](https://github.com/microsoft/DeepSpeed) training\n- `pip install prodigyopt` optionally if you would like to use the Prodigy optimizer for training\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n- PyTorch\n\n```bash\ncd examples/cogvideo\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook):\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, enabling torch.compile can yield dramatic speedups. 
The PEFT library is used as a backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.\n\nIf you would like to push your model to the Hub after training is completed with a neat model card, make sure you're logged in:\n\n```bash\nhf auth login\n\n# Alternatively, you could upload your model manually using:\n# hf upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora\n```\n\nMake sure your data is prepared as described in [Data Preparation](#data-preparation). When ready, you can begin training!\n\nAssuming you are training on 50 videos of a similar concept, we have found 1500-2000 steps to work well. The official recommendation, however, is 100 videos with a total of 4000 steps. Assuming you are training on a single GPU with a `--train_batch_size` of `1`:\n- 1500 steps on 50 videos would correspond to `30` training epochs\n- 4000 steps on 100 videos would correspond to `40` training epochs\n\n```bash\n#!/bin/bash\n\nGPU_IDS=\"0\"\n\naccelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \\\n  --pretrained_model_name_or_path THUDM/CogVideoX-2b \\\n  --cache_dir <CACHE_DIR> \\\n  --instance_data_root <PATH_TO_WHERE_VIDEO_FILES_ARE_STORED> \\\n  --dataset_name my-awesome-name/my-awesome-dataset \\\n  --caption_column <CAPTION_COLUMN> \\\n  --video_column <PATH_TO_VIDEO_COLUMN> \\\n  --id_token <ID_TOKEN> \\\n  --validation_prompt \"<ID_TOKEN> Spiderman swinging over buildings:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. 
The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \\\n  --validation_prompt_separator ::: \\\n  --num_validation_videos 1 \\\n  --validation_epochs 10 \\\n  --seed 42 \\\n  --rank 64 \\\n  --lora_alpha 64 \\\n  --mixed_precision fp16 \\\n  --output_dir /raid/aryan/cogvideox-lora \\\n  --height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \\\n  --train_batch_size 1 \\\n  --num_train_epochs 30 \\\n  --checkpointing_steps 1000 \\\n  --gradient_accumulation_steps 1 \\\n  --learning_rate 1e-3 \\\n  --lr_scheduler cosine_with_restarts \\\n  --lr_warmup_steps 200 \\\n  --lr_num_cycles 1 \\\n  --enable_slicing \\\n  --enable_tiling \\\n  --optimizer Adam \\\n  --adam_beta1 0.9 \\\n  --adam_beta2 0.95 \\\n  --max_grad_norm 1.0 \\\n  --report_to wandb\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.\n* `--validation_prompt` and `--validation_epochs` allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.\n\nSetting the `<ID_TOKEN>` is not necessary. From some limited experimentation, we found that training with it works better than training without it (as it resembles [Dreambooth](https://huggingface.co/docs/diffusers/en/training/dreambooth) training). When provided, the `<ID_TOKEN>` is prepended to each prompt. So, if your `<ID_TOKEN>` was `\"DISNEY\"` and your prompt was `\"Spiderman swinging over buildings\"`, the effective prompt used in training would be `\"DISNEY Spiderman swinging over buildings\"`. 
When not provided, you would either be training without any additional token or could augment your dataset to apply the token where you wish before starting the training.\n\n> [!NOTE]\n> You can pass `--use_8bit_adam` to reduce the memory requirements of training.\n\n> [!IMPORTANT]\n> The following settings have been tested at the time of adding CogVideoX LoRA training support:\n> - Our testing was primarily done on CogVideoX-2b. We will work on CogVideoX-5b and CogVideoX-5b-I2V soon.\n> - One dataset consisting of 70 training videos of resolution `200 x 480 x 720` (F x H x W). From this, by using frame skipping in data preprocessing, we created two smaller 49-frame and 16-frame datasets for faster experimentation and because the maximum limit recommended by the CogVideoX team is 49 frames. Out of the 70 videos, we created three groups of 10, 25 and 50 videos. All videos were similar in the nature of the concept being trained.\n> - 25+ videos worked best for training new concepts and styles.\n> - We found that it is better to train with an identifier token that can be specified as `--id_token`. This is similar to Dreambooth-like training but normal finetuning without such a token works too.\n> - The trained concept seemed to work decently well when combined with completely unrelated prompts. We expect even better results if CogVideoX-5B is finetuned.\n> - The original repository uses a `lora_alpha` of `1`. We found this not suitable in many runs, possibly due to differences in modeling backends and training settings. Our recommendation is to set `lora_alpha` to either `rank` or `rank // 2`.\n> - If you're training on data whose captions generate bad results with the original model, a `rank` of 64 and above is good and also the recommendation by the team behind CogVideoX. If the generations are already moderately good on your training captions, a `rank` of 16/32 should work. 
We found that setting the rank too low, say `4`, is not ideal and doesn't produce promising results.\n> - The authors of CogVideoX recommend 4000 training steps and 100 training videos overall to achieve the best result. While that might yield the best results, we found from our limited experimentation that 2000 steps and 25 videos could also be sufficient.\n> - When using the Prodigy optimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From my very limited testing, I found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `--prodigy_safeguard_warmup` and `--prodigy_decouple`.\n> - The recommended learning rate by the CogVideoX authors and from our experimentation with Adam/AdamW is between `1e-3` and `1e-4` for a dataset of 25+ videos.\n>\n> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.\n\n<!-- TODO: Test finetuning with CogVideoX-5b and CogVideoX-5b-I2V and update scripts accordingly -->\n\n## Inference\n\nOnce you have trained a LoRA model, inference can be done by simply loading the LoRA weights into the `CogVideoXPipeline`.\n\n```python\nimport torch\nfrom diffusers import CogVideoXPipeline\nfrom diffusers.utils import export_to_video\n\npipe = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-2b\", torch_dtype=torch.float16)\n# pipe.load_lora_weights(\"/path/to/lora/weights\", adapter_name=\"cogvideox-lora\") # Or,\npipe.load_lora_weights(\"my-awesome-hf-username/my-awesome-lora-name\", adapter_name=\"cogvideox-lora\") # If loading from the HF Hub\npipe.to(\"cuda\")\n\n# Assuming lora_alpha=32 and rank=64 for training. 
If different, set accordingly\npipe.set_adapters([\"cogvideox-lora\"], [32 / 64])\n\nprompt = \"A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion.\"\nframes = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]\nexport_to_video(frames, \"output.mp4\", fps=8)\n```\n\n\n## Reduce memory usage\n\nWhile testing using the diffusers library, all optimizations included in the diffusers library were enabled. This\nscheme has not been tested for actual memory usage on devices outside of **NVIDIA A100 / H100** architectures.\nGenerally, this scheme can be adapted to all **NVIDIA Ampere architecture** and above devices. If optimizations are\ndisabled, memory consumption will multiply, with peak memory usage being about 3 times the value in the table.\nHowever, speed will increase by about 3-4 times. 
You can selectively disable some optimizations, including:\n\n```python\npipe.enable_sequential_cpu_offload()\npipe.vae.enable_slicing()\npipe.vae.enable_tiling()\n```\n\n+ For multi-GPU inference, the `enable_sequential_cpu_offload()` optimization needs to be disabled.\n+ INT8 models accommodate lower-memory GPUs with minimal video quality loss, but inference speed\n  will decrease significantly.\n+ The CogVideoX-2B model was trained in `FP16` precision, and all CogVideoX-5B models were trained in `BF16` precision.\n  We recommend using the precision in which the model was trained for inference.\n+ [TorchAO](https://github.com/pytorch/ao) and [Optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be\n  used to quantize the text encoder, transformer, and VAE modules to reduce the memory requirements of CogVideoX. This\n  allows the model to run on free T4 Colabs or GPUs with smaller memory! Also, note that TorchAO quantization is fully\n  compatible with `torch.compile`, which can significantly improve inference speed. FP8 precision must be used on\n  devices with NVIDIA H100 and above, requiring source installation of `torch`, `torchao`, `diffusers`, and `accelerate`\n  Python packages. CUDA 12.4 is recommended.\n+ The inference speed tests also used the above memory optimization scheme. Without memory optimization, inference speed\n  increases by about 10%. Only the `diffusers` version of the model supports quantization.\n+ The model only supports English input; other languages can be translated into English for use via large model\n  refinement.\n+ The memory usage of model fine-tuning is tested in an `8 * H100` environment, and the program automatically\n  uses `Zero 2` optimization. 
If a specific number of GPUs is marked in the table, that number or more GPUs must be used\n  for fine-tuning.\n\n\n | **Attribute**                        | **CogVideoX-2B**                                                       | **CogVideoX-5B**                                                       |\n| ------------------------------------ | ---------------------------------------------------------------------- | ---------------------------------------------------------------------- |\n| **Model Name**                       | CogVideoX-2B                                                           | CogVideoX-5B                                                           |\n| **Inference Precision**              | FP16* (Recommended), BF16, FP32, FP8*, INT8, Not supported INT4         | BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported INT4         |\n| **Single GPU Inference VRAM**        | FP16: Using diffusers 12.5GB* INT8: Using diffusers with torchao 7.8GB* | BF16: Using diffusers 20.7GB* INT8: Using diffusers with torchao 11.4GB* |\n| **Multi GPU Inference VRAM**         | FP16: Using diffusers 10GB*                                             | BF16: Using diffusers 15GB*                                             |\n| **Inference Speed**                  | Single A100: ~90 seconds, Single H100: ~45 seconds                      | Single A100: ~180 seconds, Single H100: ~90 seconds                     |\n| **Fine-tuning Precision**            | FP16                                                                   | BF16                                                                   |\n| **Fine-tuning VRAM Consumption**     | 47 GB (bs=1, LORA) 61 GB (bs=2, LORA) 62GB (bs=1, SFT)                 | 63 GB (bs=1, LORA) 80 GB (bs=2, LORA) 75GB (bs=1, SFT)                 |\n"
  },
  {
    "path": "docs/source/en/training/controlnet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet\n\n[ControlNet](https://hf.co/papers/2302.05543) models are adapters trained on top of another pretrained model. It allows for a greater degree of control over image generation by conditioning the model with an additional input image. The input image can be a canny edge, depth map, human pose, and many more.\n\nIf you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers).\n\nThis guide will explore the [train_controlnet.py](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/controlnet\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. 
It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) and let us know if you have any questions or concerns.\n\n## Script parameters\n\nThe training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L231) function. 
This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.\n\nFor example, to speed up training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:\n\n```bash\naccelerate launch train_controlnet.py \\\n  --mixed_precision=\"fp16\"\n```\n\nMany of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the relevant parameters for ControlNet:\n\n- `--max_train_samples`: the number of training samples; this can be lowered for faster training, but if you want to stream really large datasets, you'll need to include this parameter and the `--streaming` parameter in your training command\n- `--gradient_accumulation_steps`: number of update steps to accumulate before the backward pass; this allows you to train with a bigger batch size than your GPU memory can typically handle\n\n### Min-SNR weighting\n\nThe [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.\n\nAdd the `--snr_gamma` parameter and set it to the recommended value of 5.0:\n\n```bash\naccelerate launch train_controlnet.py \\\n  --snr_gamma=5.0\n```\n\n## Training script\n\nAs with the script parameters, a general walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. 
Instead, this guide takes a look at the relevant parts of the ControlNet script.\n\nThe training script has a [`make_train_dataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L582) function for preprocessing the dataset with image transforms and caption tokenization. You'll see that in addition to the usual caption tokenization and image transforms, the script also includes transforms for the conditioning image.\n\n> [!TIP]\n> If you're streaming a dataset on a TPU, performance may be bottlenecked by the 🤗 Datasets library which is not optimized for images. To ensure maximum throughput, you're encouraged to explore other dataset formats like [WebDataset](https://webdataset.github.io/webdataset/), [TorchData](https://github.com/pytorch/data), and [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds).\n\n```py\nconditioning_image_transforms = transforms.Compose(\n    [\n        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n        transforms.CenterCrop(args.resolution),\n        transforms.ToTensor(),\n    ]\n)\n```\n\nWithin the [`main()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L713) function, you'll find the code for loading the tokenizer, text encoder, scheduler and models. 
This is also where the ControlNet model is loaded either from existing weights or randomly initialized from a UNet:\n\n```py\nif args.controlnet_model_name_or_path:\n    logger.info(\"Loading existing controlnet weights\")\n    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)\nelse:\n    logger.info(\"Initializing controlnet weights from unet\")\n    controlnet = ControlNetModel.from_unet(unet)\n```\n\nThe [optimizer](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L871) is set up to update the ControlNet parameters:\n\n```py\nparams_to_optimize = controlnet.parameters()\noptimizer = optimizer_class(\n    params_to_optimize,\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nFinally, in the [training loop](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L943), the conditioning text embeddings and image are passed to the down and mid-blocks of the ControlNet model:\n\n```py\nencoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\ncontrolnet_image = batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\n\ndown_block_res_samples, mid_block_res_sample = controlnet(\n    noisy_latents,\n    timesteps,\n    encoder_hidden_states=encoder_hidden_states,\n    controlnet_cond=controlnet_image,\n    return_dict=False,\n)\n```\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nNow you're ready to launch the training script! 
🚀\n\nThis guide uses the [fusing/fill50k](https://huggingface.co/datasets/fusing/fill50k) dataset, but remember, you can create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).\n\nSet the environment variable `MODEL_DIR` to a model id on the Hub or a path to a local model and `OUTPUT_DIR` to where you want to save the model.\n\nDownload the following images to condition your training with:\n\n```bash\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\nOne more thing before you launch the script! Depending on the GPU you have, you may need to enable certain optimizations to train a ControlNet. The default configuration in this script requires ~38GB of vRAM. If you're training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.\n\n<hfoptions id=\"gpu-select\">\n<hfoption id=\"16GB\">\n\nOn a 16GB GPU, you can use the bitsandbytes 8-bit optimizer and gradient checkpointing to optimize your training run. 
Install bitsandbytes:\n\n```bash\npip install bitsandbytes\n```\n\nThen, add the following parameters to your training command:\n\n```bash\naccelerate launch train_controlnet.py \\\n  --gradient_checkpointing \\\n  --use_8bit_adam\n```\n\n</hfoption>\n<hfoption id=\"12GB\">\n\nOn a 12GB GPU, you'll need to use the bitsandbytes 8-bit optimizer, gradient checkpointing, and xFormers, and set the gradients to `None` instead of zero to reduce your memory usage.\n\n```bash\naccelerate launch train_controlnet.py \\\n  --use_8bit_adam \\\n  --gradient_checkpointing \\\n  --enable_xformers_memory_efficient_attention \\\n  --set_grads_to_none\n```\n\n</hfoption>\n<hfoption id=\"8GB\">\n\nOn an 8GB GPU, you'll need to use [DeepSpeed](https://www.deepspeed.ai/) to offload some of the tensors from the vRAM to either the CPU or NVMe to allow training with less GPU memory.\n\nRun the following command to configure your 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nDuring configuration, confirm that you want to use DeepSpeed stage 2. Now it should be possible to train on under 8GB vRAM by combining DeepSpeed stage 2, fp16 mixed precision, and offloading the model parameters and the optimizer state to the CPU. The drawback is that this requires more system RAM (~25 GB). See the [DeepSpeed documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more configuration options. Your configuration file should look something like:\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  gradient_accumulation_steps: 4\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\n```\n\nYou should also change the default Adam optimizer to DeepSpeed’s optimized version of Adam [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu) for a substantial speedup. 
Enabling `DeepSpeedCPUAdam` requires your system’s CUDA toolchain version to be the same as the one installed with PyTorch.\n\nbitsandbytes 8-bit optimizers don’t seem to be compatible with DeepSpeed at the moment.\n\nThat's it! You don't need to add any additional parameters to your training command.\n\n</hfoption>\n</hfoptions>\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path/to/save/model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --push_to_hub\n```\n\nOnce training is complete, you can use your newly trained model for inference!\n\n```py\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel\nfrom diffusers.utils import load_image\nimport torch\n\ncontrolnet = ControlNetModel.from_pretrained(\"path/to/controlnet\", torch_dtype=torch.float16)\npipeline = StableDiffusionControlNetPipeline.from_pretrained(\n    \"path/to/base/model\", controlnet=controlnet, torch_dtype=torch.float16\n).to(\"cuda\")\n\ncontrol_image = load_image(\"./conditioning_image_1.png\")\nprompt = \"pale golden rod circle with old lace background\"\n\ngenerator = torch.manual_seed(0)\nimage = pipeline(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]\nimage.save(\"./output.png\")\n```\n\n## Stable Diffusion XL\n\nStable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text-encoder to its architecture. 
Use the [`train_controlnet_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet_sdxl.py) script to train a ControlNet adapter for the SDXL model.\n\nThe SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.\n\n## Next steps\n\nCongratulations on training your own ControlNet! To learn more about how to use your new model, the following guides may be helpful:\n\n- Learn how to [use a ControlNet](../using-diffusers/controlnet) for inference on a variety of tasks.\n"
  },
  {
    "path": "docs/source/en/training/create_dataset.md",
    "content": "# Create a dataset for training\n\nThere are many datasets on the [Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) to train a model on, but if you can't find one you're interested in or want to use your own, you can create a dataset with the 🤗 [Datasets](https://huggingface.co/docs/datasets) library. The dataset structure depends on the task you want to train your model on. The most basic dataset structure is a directory of images for tasks like unconditional image generation. Another dataset structure may be a directory of images and a text file containing their corresponding text captions for tasks like text-to-image generation.\n\nThis guide will show you two ways to create a dataset to finetune on:\n\n- provide a folder of images to the `--train_data_dir` argument\n- upload a dataset to the Hub and pass the dataset repository id to the `--dataset_name` argument\n\n> [!TIP]\n> 💡 Learn more about how to create an image dataset for training in the [Create an image dataset](https://huggingface.co/docs/datasets/image_dataset) guide.\n\n## Provide a dataset as a folder\n\nFor unconditional generation, you can provide your own dataset as a folder of images. The training script uses the [`ImageFolder`](https://huggingface.co/docs/datasets/en/image_dataset#imagefolder) builder from 🤗 Datasets to automatically build a dataset from the folder. 
Your directory structure should look like:\n\n```bash\ndata_dir/xxx.png\ndata_dir/xxy.png\ndata_dir/[...]/xxz.png\n```\n\nPass the path to the dataset directory to the `--train_data_dir` argument, and then you can start training:\n\n```bash\naccelerate launch train_unconditional.py \\\n    --train_data_dir <path-to-train-directory> \\\n    <other-arguments>\n```\n\n## Upload your data to the Hub\n\n> [!TIP]\n> 💡 For more details and context about creating and uploading a dataset to the Hub, take a look at the [Image search with 🤗 Datasets](https://huggingface.co/blog/image-search-datasets) post.\n\nStart by creating a dataset with the [`ImageFolder`](https://huggingface.co/docs/datasets/image_load#imagefolder) feature, which creates an `image` column containing the PIL-encoded images.\n\nYou can use the `data_dir` or `data_files` parameters to specify the location of the dataset. The `data_files` parameter supports mapping specific files to dataset splits like `train` or `test`:\n\n```python\nfrom datasets import load_dataset\n\n# example 1: local folder\ndataset = load_dataset(\"imagefolder\", data_dir=\"path_to_your_folder\")\n\n# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)\ndataset = load_dataset(\"imagefolder\", data_files=\"path_to_zip_file\")\n\n# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)\ndataset = load_dataset(\n    \"imagefolder\",\n    data_files=\"https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip\",\n)\n\n# example 4: providing several splits\ndataset = load_dataset(\n    \"imagefolder\", data_files={\"train\": [\"path/to/file1\", \"path/to/file2\"], \"test\": [\"path/to/file3\", \"path/to/file4\"]}\n)\n```\n\nThen use the [`~datasets.Dataset.push_to_hub`] method to upload the dataset to the Hub:\n\n```python\n# assuming you have run the hf auth login command in a terminal\ndataset.push_to_hub(\"name_of_your_dataset\")\n\n# if 
you want to push to a private repo, simply pass private=True:\ndataset.push_to_hub(\"name_of_your_dataset\", private=True)\n```\n\nNow the dataset is available for training by passing the dataset name to the `--dataset_name` argument:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image.py \\\n  --pretrained_model_name_or_path=\"stable-diffusion-v1-5/stable-diffusion-v1-5\" \\\n  --dataset_name=\"name_of_your_dataset\" \\\n  <other-arguments>\n```\n\n## Next steps\n\nNow that you've created a dataset, you can plug it into the `train_data_dir` (if your dataset is local) or `dataset_name` (if your dataset is on the Hub) arguments of a training script.\n\nFor your next steps, feel free to try and use your dataset to train a model for [unconditional generation](unconditional_training) or [text-to-image generation](text2image)!\n"
  },
  {
    "path": "docs/source/en/training/custom_diffusion.md",
    "content": "<!--Copyright 2025 Custom Diffusion authors The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Custom Diffusion\n\n[Custom Diffusion](https://huggingface.co/papers/2212.04488) is a training technique for personalizing image generation models. Like Textual Inversion, DreamBooth, and LoRA, Custom Diffusion only requires a few (~4-5) example images. This technique works by only training weights in the cross-attention layers, and it uses a special word to represent the newly learned concept. Custom Diffusion is unique because it can also learn multiple concepts at the same time.\n\nIf you're training on a GPU with limited vRAM, you should try enabling xFormers with `--enable_xformers_memory_efficient_attention` for faster training with lower vRAM requirements (16GB). 
To save even more memory, add `--set_grads_to_none` in the training argument to set the gradients to `None` instead of zero (this option can cause some issues, so if you experience any, try removing this parameter).\n\nThis guide will explore the [train_custom_diffusion.py](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion/train_custom_diffusion.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nNavigate to the example folder with the training script and install the required dependencies:\n\n```bash\ncd examples/custom_diffusion\npip install -r requirements.txt\npip install clip-retrieval\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo setup a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. 
If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion/train_custom_diffusion.py) and let us know if you have any questions or concerns.\n\n## Script parameters\n\nThe training script contains all the parameters to help you customize your training run. These are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L319) function. The function comes with default values, but you can also set your own values in the training command if you'd like.\n\nFor example, to change the resolution of the input image:\n\n```bash\naccelerate launch train_custom_diffusion.py \\\n  --resolution=256\n```\n\nMany of the basic parameters are described in the [DreamBooth](dreambooth#script-parameters) training guide, so this guide focuses on the parameters unique to Custom Diffusion:\n\n- `--freeze_model`: freezes the key and value parameters in the cross-attention layer; the default is `crossattn_kv`, but you can set it to `crossattn` to train all the parameters in the cross-attention layer\n- `--concepts_list`: to learn multiple concepts, provide a path to a JSON file containing the concepts\n- `--modifier_token`: a special word used to represent the learned concept\n- `--initializer_token`: a special word used to initialize the embeddings of the `modifier_token`\n\n### Prior preservation loss\n\nPrior preservation loss is a method that uses a model's own generated samples to help it learn how to generate more diverse images. 
Because these generated sample images belong to the same class as the images you provided, they help the model retain what it has learned about the class and how it can use what it already knows about the class to make new compositions.\n\nMany of the parameters for prior preservation loss are described in the [DreamBooth](dreambooth#prior-preservation-loss) training guide.\n\n### Regularization\n\nCustom Diffusion trains on the target images together with a small set of real images to prevent overfitting, which is easy to do when you're only training on a few images! Download 200 real images with `clip_retrieval`. The `class_prompt` should be the same category as the target images. These images are stored in `class_data_dir`.\n\n```bash\npython retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200\n```\n\nTo enable regularization, add the following parameters:\n\n- `--with_prior_preservation`: whether to use prior preservation loss\n- `--prior_loss_weight`: controls the influence of the prior preservation loss on the model\n- `--real_prior`: whether to use a small set of real images to prevent overfitting\n\n```bash\naccelerate launch train_custom_diffusion.py \\\n  --with_prior_preservation \\\n  --prior_loss_weight=1.0 \\\n  --class_data_dir=\"./real_reg/samples_cat\" \\\n  --class_prompt=\"cat\" \\\n  --real_prior\n```\n\n## Training script\n\n> [!TIP]\n> A lot of the code in the Custom Diffusion training script is similar to the [DreamBooth](dreambooth#training-script) script. 
This guide instead focuses on the code that is relevant to Custom Diffusion.\n\nThe Custom Diffusion training script has two dataset classes:\n\n- [`CustomDiffusionDataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L165): preprocesses the images, class images, and prompts for training\n- [`PromptDataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L148): prepares the prompts for generating class images\n\nNext, the `modifier_token` is [added to the tokenizer](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L811), converted to token ids, and the token embeddings are resized to account for the new `modifier_token`. Then the `modifier_token` embeddings are initialized with the embeddings of the `initializer_token`. All parameters in the text encoder are frozen, except for the token embeddings since this is what the model is trying to learn to associate with the concepts.\n\n```py\nparams_to_freeze = itertools.chain(\n    text_encoder.text_model.encoder.parameters(),\n    text_encoder.text_model.final_layer_norm.parameters(),\n    text_encoder.text_model.embeddings.position_embedding.parameters(),\n)\nfreeze_params(params_to_freeze)\n```\n\nNow you'll need to add the [Custom Diffusion weights](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/custom_diffusion/train_custom_diffusion.py#L911C3-L911C3) to the attention layers. 
This is a really important step for getting the shape and size of the attention weights correct, and for setting the appropriate number of attention processors in each UNet block.\n\n```py\nst = unet.state_dict()\nfor name, _ in unet.attn_processors.items():\n    cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n    if name.startswith(\"mid_block\"):\n        hidden_size = unet.config.block_out_channels[-1]\n    elif name.startswith(\"up_blocks\"):\n        block_id = int(name[len(\"up_blocks.\")])\n        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n    elif name.startswith(\"down_blocks\"):\n        block_id = int(name[len(\"down_blocks.\")])\n        hidden_size = unet.config.block_out_channels[block_id]\n    layer_name = name.split(\".processor\")[0]\n    weights = {\n        \"to_k_custom_diffusion.weight\": st[layer_name + \".to_k.weight\"],\n        \"to_v_custom_diffusion.weight\": st[layer_name + \".to_v.weight\"],\n    }\n    if train_q_out:\n        weights[\"to_q_custom_diffusion.weight\"] = st[layer_name + \".to_q.weight\"]\n        weights[\"to_out_custom_diffusion.0.weight\"] = st[layer_name + \".to_out.0.weight\"]\n        weights[\"to_out_custom_diffusion.0.bias\"] = st[layer_name + \".to_out.0.bias\"]\n    if cross_attention_dim is not None:\n        custom_diffusion_attn_procs[name] = attention_class(\n            train_kv=train_kv,\n            train_q_out=train_q_out,\n            hidden_size=hidden_size,\n            cross_attention_dim=cross_attention_dim,\n        ).to(unet.device)\n        custom_diffusion_attn_procs[name].load_state_dict(weights)\n    else:\n        custom_diffusion_attn_procs[name] = attention_class(\n            train_kv=False,\n            train_q_out=False,\n            hidden_size=hidden_size,\n            cross_attention_dim=cross_attention_dim,\n        )\ndel 
st\nunet.set_attn_processor(custom_diffusion_attn_procs)\ncustom_diffusion_layers = AttnProcsLayers(unet.attn_processors)\n```\n\nThe [optimizer](https://github.com/huggingface/diffusers/blob/84cd9e8d01adb47f046b1ee449fc76a0c32dc4e2/examples/custom_diffusion/train_custom_diffusion.py#L982) is initialized to update the cross-attention layer parameters:\n\n```py\noptimizer = optimizer_class(\n    itertools.chain(text_encoder.get_input_embeddings().parameters(), custom_diffusion_layers.parameters())\n    if args.modifier_token is not None\n    else custom_diffusion_layers.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nIn the [training loop](https://github.com/huggingface/diffusers/blob/84cd9e8d01adb47f046b1ee449fc76a0c32dc4e2/examples/custom_diffusion/train_custom_diffusion.py#L1048), it is important to only update the embeddings for the concept you're trying to learn. This means setting the gradients of all the other token embeddings to zero:\n\n```py\nif args.modifier_token is not None:\n    if accelerator.num_processes > 1:\n        grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad\n    else:\n        grads_text_encoder = text_encoder.get_input_embeddings().weight.grad\n    index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]\n    for i in range(len(modifier_token_id[1:])):\n        index_grads_to_zero = index_grads_to_zero & (\n            torch.arange(len(tokenizer)) != modifier_token_id[i]\n        )\n    grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[\n        index_grads_to_zero, :\n    ].fill_(0)\n```\n\n## Launch the script\n\nOnce you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 
🚀\n\nIn this guide, you'll download and use these example [cat images](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip). You can also create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).\n\nSet the environment variable `MODEL_NAME` to a model id on the Hub or a path to a local model, `INSTANCE_DIR`  to the path where you just downloaded the cat images to, and `OUTPUT_DIR` to where you want to save the model. You'll use `<new1>` as the special word to tie the newly learned embeddings to. The script creates and saves model checkpoints and a pytorch_custom_diffusion_weights.bin file to your repository.\n\nTo monitor training progress with Weights and Biases, add the `--report_to=wandb` parameter to the training command and specify a validation prompt with `--validation_prompt`. This is useful for debugging and saving intermediate results.\n\n> [!TIP]\n> If you're training on human faces, the Custom Diffusion team has found the following parameters to work well:\n>\n> - `--learning_rate=5e-6`\n> - `--max_train_steps` can be anywhere between 1000 and 2000\n> - `--freeze_model=crossattn`\n> - use at least 15-20 images to train with\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"single concept\">\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\nexport INSTANCE_DIR=\"./data/cat\"\n\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --class_data_dir=./real_reg/samples_cat/ \\\n  --with_prior_preservation \\\n  --real_prior \\\n  --prior_loss_weight=1.0 \\\n  --class_prompt=\"cat\" \\\n  --num_class_images=200 \\\n  --instance_prompt=\"photo of a <new1> cat\"  \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=1e-5  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=250 \\\n  --scale_lr \\\n  --hflip  
\\\n  --modifier_token \"<new1>\" \\\n  --validation_prompt=\"<new1> cat sitting in a bucket\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub\n```\n\n</hfoption>\n<hfoption id=\"multiple concepts\">\n\nCustom Diffusion can also learn multiple concepts if you provide a [JSON](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with some details about each concept it should learn.\n\nRun clip-retrieval to collect some real images to use for regularization:\n\n```bash\npip install clip-retrieval\npython retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200\n```\n\nThen you can launch the script:\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --output_dir=$OUTPUT_DIR \\\n  --concepts_list=./concept_list.json \\\n  --with_prior_preservation \\\n  --real_prior \\\n  --prior_loss_weight=1.0 \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=1e-5  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --num_class_images=200 \\\n  --scale_lr \\\n  --hflip  \\\n  --modifier_token \"<new1>+<new2>\" \\\n  --push_to_hub\n```\n\n</hfoption>\n</hfoptions>\n\nOnce training is finished, you can use your new Custom Diffusion model for inference.\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"single concept\">\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline.unet.load_attn_procs(\"path-to-save-model\", weight_name=\"pytorch_custom_diffusion_weights.bin\")\npipeline.load_textual_inversion(\"path-to-save-model\", weight_name=\"<new1>.bin\")\n\nimage = pipeline(\n    \"<new1> cat sitting in a bucket\",\n    num_inference_steps=100,\n    guidance_scale=6.0,\n    
eta=1.0,\n).images[0]\nimage.save(\"cat.png\")\n```\n\n</hfoption>\n<hfoption id=\"multiple concepts\">\n\n```py\nimport torch\nfrom huggingface_hub.repocard import RepoCard\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16,\n).to(\"cuda\")\nmodel_id = \"sayakpaul/custom-diffusion-cat-wooden-pot\"\npipeline.unet.load_attn_procs(model_id, weight_name=\"pytorch_custom_diffusion_weights.bin\")\npipeline.load_textual_inversion(model_id, weight_name=\"<new1>.bin\")\npipeline.load_textual_inversion(model_id, weight_name=\"<new2>.bin\")\n\nimage = pipeline(\n    \"the <new1> cat sculpture in the style of a <new2> wooden pot\",\n    num_inference_steps=100,\n    guidance_scale=6.0,\n    eta=1.0,\n).images[0]\nimage.save(\"multi-subject.png\")\n```\n\n</hfoption>\n</hfoptions>\n\n## Next steps\n\nCongratulations on training a model with Custom Diffusion! 🎉 To learn more:\n\n- Read the [Multi-Concept Customization of Text-to-Image Diffusion](https://www.cs.cmu.edu/~custom-diffusion/) blog post to learn more details about the experimental results from the Custom Diffusion team."
  },
  {
    "path": "docs/source/en/training/ddpo.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Reinforcement learning training with DDPO\n\nYou can fine-tune Stable Diffusion on a reward function via reinforcement learning with the 🤗 TRL library and 🤗 Diffusers. This is done with the Denoising Diffusion Policy Optimization (DDPO) algorithm introduced by Black et al. in [Training Diffusion Models with Reinforcement Learning](https://huggingface.co/papers/2305.13301), which is implemented in 🤗 TRL with the [`~trl.DDPOTrainer`].\n\nFor more information, check out the [`~trl.DDPOTrainer`] API reference and the [Finetune Stable Diffusion Models with DDPO via TRL](https://huggingface.co/blog/trl-ddpo) blog post."
  },
  {
    "path": "docs/source/en/training/distributed_inference.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Distributed inference\n\nDistributed inference splits the workload across multiple GPUs. It a useful technique for fitting larger models in memory and can process multiple prompts for higher throughput.\n\nThis guide will show you how to use [Accelerate](https://huggingface.co/docs/accelerate/index) and [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) for distributed inference.\n\n## Accelerate\n\nAccelerate is a library designed to simplify inference and training on multiple accelerators by handling the setup, allowing users to focus on their PyTorch code.\n\nInstall Accelerate with the following command.\n\n```bash\nuv pip install accelerate\n```\n\nInitialize a [`accelerate.PartialState`] class in a Python file to create a distributed environment. 
The [`accelerate.PartialState`] class handles process management, device control and distribution, and process coordination.\n\nMove the [`DiffusionPipeline`] to [`accelerate.PartialState.device`] to assign a GPU to each process.\n\n```py\nimport torch\nfrom accelerate import PartialState\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"Qwen/Qwen-Image\", torch_dtype=torch.float16\n)\ndistributed_state = PartialState()\npipeline.to(distributed_state.device)\n```\n\nUse the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts across the processes.\n\n```py\nwith distributed_state.split_between_processes([\"a dog\", \"a cat\"]) as prompt:\n    result = pipeline(prompt).images[0]\n    result.save(f\"result_{distributed_state.process_index}.png\")\n```\n\nCall `accelerate launch` to run the script and use the `--num_processes` argument to set the number of GPUs to use. Launcher arguments must come before the script name, otherwise they are passed to the script itself.\n\n```bash\naccelerate launch --num_processes=2 run_distributed.py\n```\n\n> [!TIP]\n> Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. 
To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.\n\n## PyTorch Distributed\n\nPyTorch [DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) enables [data parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=data_parallelism), which replicates the same model on each device to process different batches of data in parallel.\n\nImport `torch.distributed` and `torch.multiprocessing` into a Python file to set up the distributed process group and to spawn the processes for inference on each GPU.\n\n```py\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"Qwen/Qwen-Image\", torch_dtype=torch.float16,\n)\n```\n\nCreate a function for inference with [init_process_group](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group). This method creates a distributed environment with the backend type, the `rank` of the current process, and the `world_size` or number of processes participating (for example, 2 GPUs would be `world_size=2`).\n\nMove the pipeline to `rank` to assign a GPU to each process, and use `get_rank` to select the prompt for that process. 
Each process handles a different prompt.\n\n```py\ndef run_inference(rank, world_size):\n    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n\n    pipeline.to(rank)\n\n    if torch.distributed.get_rank() == 0:\n        prompt = \"a dog\"\n    elif torch.distributed.get_rank() == 1:\n        prompt = \"a cat\"\n\n    image = pipeline(prompt).images[0]\n    # Replace spaces so the prompt makes a clean filename, e.g. ./a_dog.png\n    image.save(f\"./{prompt.replace(' ', '_')}.png\")\n```\n\nUse [mp.spawn](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to create the number of processes defined in `world_size`.\n\n```py\ndef main():\n    world_size = 2\n    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)\n\n\nif __name__ == \"__main__\":\n    main()\n```\n\nCall `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.\n\n```bash\ntorchrun --nproc_per_node=2 run_distributed.py\n```\n\n## device_map\n\nThe `device_map` argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn't fit on a single GPU. You can use `device_map` to selectively load and unload the required model components at a given stage as shown in the example below (assumes two GPUs are available).\n\nSet `device_map=\"balanced\"` to evenly distribute the text encoders on all available GPUs. You can use the `max_memory` argument to cap the amount of memory used on each GPU. 
Don't load any other pipeline components to avoid memory usage.\n\n```py\nfrom diffusers import FluxPipeline\nimport torch\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\n\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    transformer=None,\n    vae=None,\n    device_map=\"balanced\",\n    max_memory={0: \"16GB\", 1: \"16GB\"},\n    torch_dtype=torch.bfloat16\n)\nwith torch.no_grad():\n    print(\"Encoding prompts.\")\n    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(\n        prompt=prompt, prompt_2=None, max_sequence_length=512\n    )\n```\n\nAfter the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.\n\n```py\nimport gc \n\ndef flush():\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.reset_max_memory_allocated()\n    torch.cuda.reset_peak_memory_stats()\n\ndel pipeline.text_encoder\ndel pipeline.text_encoder_2\ndel pipeline.tokenizer\ndel pipeline.tokenizer_2\ndel pipeline\n\nflush()\n```\n\nSet `device_map=\"auto\"` to automatically distribute the model on the two GPUs. This strategy places a model on the fastest device first before placing a model on a slower device like a CPU or hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.\n\n```py\nfrom diffusers import AutoModel\nimport torch \n\ntransformer = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", \n    subfolder=\"transformer\",\n    device_map=\"auto\",\n    torch_dtype=torch.bfloat16\n)\n```\n\n> [!TIP]\n> Run `pipeline.hf_device_map` to see how the various models are distributed across devices. This is useful for tracking model device placement. 
You can also call `hf_device_map` on the transformer model to see how it is distributed.\n\nAdd the transformer model to the pipeline and set `output_type=\"latent\"` to generate the latents.\n\n```py\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    text_encoder=None,\n    text_encoder_2=None,\n    tokenizer=None,\n    tokenizer_2=None,\n    vae=None,\n    transformer=transformer,\n    torch_dtype=torch.bfloat16\n)\n\nprint(\"Running denoising.\")\nheight, width = 768, 1360\nlatents = pipeline(\n    prompt_embeds=prompt_embeds,\n    pooled_prompt_embeds=pooled_prompt_embeds,\n    num_inference_steps=50,\n    guidance_scale=3.5,\n    height=height,\n    width=width,\n    output_type=\"latent\",\n).images\n```\n\nRemove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.\n\n```py\nimport torch\nfrom diffusers import AutoencoderKL, FluxPipeline\nfrom diffusers.image_processor import VaeImageProcessor\n\n# Same checkpoint used for the text encoders and transformer above\nckpt_id = \"black-forest-labs/FLUX.1-dev\"\n\nvae = AutoencoderKL.from_pretrained(ckpt_id, subfolder=\"vae\", torch_dtype=torch.bfloat16).to(\"cuda\")\nvae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)\nimage_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)\n\nwith torch.no_grad():\n    print(\"Running decoding.\")\n    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)\n    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor\n\n    image = vae.decode(latents, return_dict=False)[0]\n    image = image_processor.postprocess(image, output_type=\"pil\")\n    image[0].save(\"split_transformer.png\")\n```\n\nBy selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.\n\n## Context parallelism\n\n[Context 
parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism) splits input sequences across multiple GPUs to reduce memory usage. Each GPU processes its own slice of the sequence.\n\nUse [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.\n\nMost attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.\n\n### Ring Attention\n\nKey (K) and value (V) representations are communicated between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which lowers per-GPU memory usage.\n\nPass a [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`] on the transformer model. 
The config supports the `ring_degree` argument that determines how many devices to use for Ring Attention.\n\n```py\nimport torch\nfrom torch import distributed as dist\nfrom diffusers import DiffusionPipeline, ContextParallelConfig\n\ndef setup_distributed():\n    if not dist.is_initialized():\n        dist.init_process_group(backend=\"nccl\")\n    rank = dist.get_rank()\n    device = torch.device(f\"cuda:{rank}\")\n    torch.cuda.set_device(device)\n    return device\n\ndef main():\n    device = setup_distributed()\n    world_size = dist.get_world_size()\n\n    pipeline = DiffusionPipeline.from_pretrained(\n        \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n    ).to(device)\n    pipeline.transformer.set_attention_backend(\"_native_cudnn\")\n\n    cp_config = ContextParallelConfig(ring_degree=world_size)\n    pipeline.transformer.enable_parallelism(config=cp_config)\n\n    prompt = \"\"\"\n    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\n    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n    \"\"\"\n\n    # Must specify generator so all ranks start with same latents (or pass your own)\n    generator = torch.Generator().manual_seed(42)\n    image = pipeline(\n        prompt,\n        guidance_scale=3.5,\n        num_inference_steps=50,\n        generator=generator,\n    ).images[0]\n\n    if dist.get_rank() == 0:\n        image.save(f\"output.png\")\n\n    if dist.is_initialized():\n        dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n```\n\nThe script above needs to be run with a distributed launcher, such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html), that is compatible with PyTorch. 
Set `--nproc-per-node` to the number of GPUs available.\n\n```shell\ntorchrun --nproc-per-node 2 above_script.py\n```\n\n### Ulysses Attention\n\n[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends data to and receives data from every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its heads, then performs another all-to-all to regroup results by tokens for the next layer.\n\n[`ContextParallelConfig`] supports Ulysses Attention through the `ulysses_degree` argument. This determines how many devices to use for Ulysses Attention.\n\nPass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].\n\n```py\n# Depending on the number of GPUs available.\npipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))\n```\n\n### Unified Attention\n\n[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.\n\nThis hybrid approach leverages the strengths of both methods:\n- **Ulysses Attention** efficiently parallelizes across attention heads\n- **Ring Attention** handles very long sequences with minimal memory overhead\n- Together, they enable 2D parallelization across both heads and sequence dimensions\n\n[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. 
The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).\n\nPass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to values greater than 1 to [`~ModelMixin.enable_parallelism`].\n\n```py\npipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))\n```\n\n> [!TIP]\n> Use Unified Attention when there are enough devices to arrange in a 2D grid (at least 4 devices).\n\nWe ran a benchmark with Ulysses, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:\n\n| CP Backend         | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |\n|--------------------|------------------|-------------|------------------|\n| ulysses            | 6670.789         | 7.50        | 33.85            |\n| ring               | 13076.492        | 3.82        | 56.02            |\n| unified_balanced   | 11068.705        | 4.52        | 33.85            |\n\nFrom the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by Unified Attention.\n\n\n### Ulysses Anything Attention\n\nThe default Ulysses Attention mechanism requires the sequence length of the hidden states to be divisible by the number of devices. This imposes significant limitations on the practical application of Ulysses Attention. 
[Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, thereby enhancing the versatility of Ulysses Attention in practical use.\n\n[`ContextParallelConfig`] supports Ulysses Anything Attention by specifying both `ulysses_degree` and `ulysses_anything`. Please note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with `ulysses_degree` set to a value greater than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].\n\n```py\npipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))\n```\n\n> [!TIP]\n> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency.\n\nWe ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. 
The results are summarized as follows:\n\n| CP Backend         | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)|\n|--------------------|------------------|-------------|------------------|------------|\n| ulysses            |   281.07         |    3.56     |     37.11        | 1024x1024  |\n| ring               |   351.34         |    2.85     |     37.01        | 1024x1024  |\n| unified_balanced   |   324.37         |    3.08     |     37.16        | 1024x1024  |\n| ulysses_anything   |   280.94         |    3.56     |     37.11        | 1024x1024  |\n| ulysses            |   failed         |    failed   |     failed       | 1008x1008  |\n| ring               |   failed         |    failed   |     failed       | 1008x1008  |\n| unified_balanced   |   failed         |    failed   |     failed       | 1008x1008  |\n| ulysses_anything   |   278.40         |    3.59     |     36.99        | 1008x1008  |\n\nFrom the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.\n\n### parallel_config\n\nPass `parallel_config` during model initialization to enable context parallelism.\n\n```py\nCKPT_ID = \"black-forest-labs/FLUX.1-dev\"\n\ncp_config = ContextParallelConfig(ring_degree=2)\ntransformer = AutoModel.from_pretrained(\n    CKPT_ID, \n    subfolder=\"transformer\", \n    torch_dtype=torch.bfloat16, \n    parallel_config=cp_config\n)\n\npipeline = DiffusionPipeline.from_pretrained(\n    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,\n).to(device)\n```\n"
  },
  {
    "path": "docs/source/en/training/dreambooth.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DreamBooth\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a training technique that updates the entire diffusion model by training on just a few images of a subject or style. It works by associating a special word in the prompt with the example images.\n\nIf you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing` and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers).\n\nThis guide will explore the [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nNavigate to the example folder with the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/dreambooth\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. 
Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) and let us know if you have any questions or concerns.\n\n## Script parameters\n\n> [!WARNING]\n> DreamBooth is very sensitive to training hyperparameters, and it is easy to overfit. Read the [Training Stable Diffusion with Dreambooth using 🧨 Diffusers](https://huggingface.co/blog/dreambooth) blog post for recommended settings for different subjects to help you choose the appropriate hyperparameters.\n\nThe training script offers many parameters for customizing your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L228) function. 
The parameters are set with default values that should work pretty well out-of-the-box, but you can also set your own values in the training command if you'd like.\n\nFor example, to train in the bf16 format:\n\n```bash\naccelerate launch train_dreambooth.py \\\n    --mixed_precision=\"bf16\"\n```\n\nSome basic and important parameters to know and specify are:\n\n- `--pretrained_model_name_or_path`: the name of the model on the Hub or a local path to the pretrained model\n- `--instance_data_dir`: path to a folder containing the training dataset (example images)\n- `--instance_prompt`: the text prompt that contains the special word for the example images\n- `--train_text_encoder`: whether to also train the text encoder\n- `--output_dir`: where to save the trained model\n- `--push_to_hub`: whether to push the trained model to the Hub\n- `--checkpointing_steps`: frequency of saving a checkpoint as the model trains; this is useful because if training is interrupted for some reason, you can continue training from that checkpoint by adding `--resume_from_checkpoint` to your training command\n\n### Min-SNR weighting\n\nThe [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.\n\nAdd the `--snr_gamma` parameter and set it to the recommended value of 5.0:\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --snr_gamma=5.0\n```\n\n### Prior preservation loss\n\nPrior preservation loss is a method that uses a model's own generated samples to help it learn how to generate more diverse images. 
Because these generated sample images belong to the same class as the images you provided, they help the model retain what it has learned about the class and how it can use what it already knows about the class to make new compositions.\n\n- `--with_prior_preservation`: whether to use prior preservation loss\n- `--prior_loss_weight`: controls the influence of the prior preservation loss on the model\n- `--class_data_dir`: path to a folder containing the generated class sample images\n- `--class_prompt`: the text prompt describing the class of the generated sample images\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --with_prior_preservation \\\n  --prior_loss_weight=1.0 \\\n  --class_data_dir=\"path/to/class/images\" \\\n  --class_prompt=\"text prompt describing class\"\n```\n\n### Train text encoder\n\nTo improve the quality of the generated outputs, you can also train the text encoder in addition to the UNet. This requires additional memory and you'll need a GPU with at least 24GB of vRAM. If you have the necessary hardware, then training the text encoder produces better results, especially when generating images of faces. 
Enable this option by:\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --train_text_encoder\n```\n\n## Training script\n\nDreamBooth comes with its own dataset classes:\n\n- [`DreamBoothDataset`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L604): preprocesses the images and class images, and tokenizes the prompts for training\n- [`PromptDataset`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L738): generates the prompt embeddings to generate the class images\n\nIf you enabled [prior preservation loss](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L842), the class images are generated here:\n\n```py\nsample_dataset = PromptDataset(args.class_prompt, num_new_images)\nsample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\nsample_dataloader = accelerator.prepare(sample_dataloader)\npipeline.to(accelerator.device)\n\nfor example in tqdm(\n    sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n):\n    images = pipeline(example[\"prompt\"]).images\n```\n\nNext is the [`main()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L799) function which handles setting up the dataset for training and the training loop itself. 
The script loads the [tokenizer](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L898), [scheduler and models](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L912C1-L912C1):\n\n```py\n# Load the tokenizer\nif args.tokenizer_name:\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\nelif args.pretrained_model_name_or_path:\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n# Load scheduler and models\nnoise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\ntext_encoder = text_encoder_cls.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n)\n\nif model_has_vae(args):\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision\n    )\nelse:\n    vae = None\n\nunet = UNet2DConditionModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n)\n```\n\nThen, it's time to [create the training dataset](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L1073) and DataLoader from `DreamBoothDataset`:\n\n```py\ntrain_dataset = DreamBoothDataset(\n    instance_data_root=args.instance_data_dir,\n    instance_prompt=args.instance_prompt,\n    class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n    class_prompt=args.class_prompt,\n    class_num=args.num_class_images,\n    tokenizer=tokenizer,\n    size=args.resolution,\n    center_crop=args.center_crop,\n    
encoder_hidden_states=pre_computed_encoder_hidden_states,\n    class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,\n    tokenizer_max_length=args.tokenizer_max_length,\n)\n\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset,\n    batch_size=args.train_batch_size,\n    shuffle=True,\n    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n    num_workers=args.dataloader_num_workers,\n)\n```\n\nLastly, the [training loop](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L1151) takes care of the remaining steps such as converting images to latent space, adding noise to the input, predicting the noise residual, and calculating the loss.\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nYou're now ready to launch the training script! 🚀\n\nFor this guide, you'll download some images of a [dog](https://huggingface.co/datasets/diffusers/dog-example) and store them in a directory. But remember, you can create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir,\n    repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nSet the environment variable `MODEL_NAME` to a model id on the Hub or a path to a local model, `INSTANCE_DIR` to the path where you just downloaded the dog images to, and `OUTPUT_DIR` to where you want to save the model. 
You'll use `sks` as the special word to tie the training to.\n\nIf you're interested in following along with the training process, you can periodically save generated images as training progresses. Add the following parameters to the training command:\n\n```bash\n--validation_prompt=\"a photo of a sks dog\"\n--num_validation_images=4\n--validation_steps=100\n```\n\nOne more thing before you launch the script! Depending on the GPU you have, you may need to enable certain optimizations to train DreamBooth.\n\n<hfoptions id=\"gpu-select\">\n<hfoption id=\"16GB\">\n\nOn a 16GB GPU, you can use the bitsandbytes 8-bit optimizer and gradient checkpointing to help you train a DreamBooth model. Install bitsandbytes:\n\n```bash\npip install bitsandbytes\n```\n\nThen, add the following parameters to your training command:\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --gradient_checkpointing \\\n  --use_8bit_adam\n```\n\n</hfoption>\n<hfoption id=\"12GB\">\n\nOn a 12GB GPU, you'll need the bitsandbytes 8-bit optimizer, gradient checkpointing, and xFormers, and you'll need to set the gradients to `None` instead of zero to reduce your memory usage.\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --use_8bit_adam \\\n  --gradient_checkpointing \\\n  --enable_xformers_memory_efficient_attention \\\n  --set_grads_to_none\n```\n\n</hfoption>\n<hfoption id=\"8GB\">\n\nOn an 8GB GPU, you'll need [DeepSpeed](https://www.deepspeed.ai/) to offload some of the tensors from the vRAM to either the CPU or NVMe to allow training with less GPU memory.\n\nRun the following command to configure your 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nDuring configuration, confirm that you want to use DeepSpeed. Now it should be possible to train on under 8GB vRAM by combining DeepSpeed stage 2, fp16 mixed precision, and offloading the model parameters and the optimizer state to the CPU. The drawback is that this requires more system RAM (~25 GB). 
See the [DeepSpeed documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more configuration options.\n\nYou should also change the default Adam optimizer to DeepSpeed’s optimized version of Adam [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu) for a substantial speedup. Enabling `DeepSpeedCPUAdam` requires your system’s CUDA toolchain version to be the same as the one installed with PyTorch.\n\nbitsandbytes 8-bit optimizers don’t seem to be compatible with DeepSpeed at the moment.\n\nThat's it! You don't need to add any additional parameters to your training command.\n\n</hfoption>\n</hfoptions>\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport INSTANCE_DIR=\"./dog\"\nexport OUTPUT_DIR=\"path_to_saved_model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=400 \\\n  --push_to_hub\n```\n\nOnce training is complete, you can use your newly trained model for inference!\n\n> [!TIP]\n> Can't wait to try your model for inference before training is complete? 
🤭 Make sure you have the latest version of 🤗 Accelerate installed.\n>\n> ```py\n> from diffusers import DiffusionPipeline, UNet2DConditionModel\n> from transformers import CLIPTextModel\n> import torch\n>\n> unet = UNet2DConditionModel.from_pretrained(\"path/to/model/checkpoint-100/unet\")\n>\n> # if you have trained with `--train_text_encoder` make sure to also load the text encoder\n> text_encoder = CLIPTextModel.from_pretrained(\"path/to/model/checkpoint-100/text_encoder\")\n>\n> pipeline = DiffusionPipeline.from_pretrained(\n>     \"stable-diffusion-v1-5/stable-diffusion-v1-5\", unet=unet, text_encoder=text_encoder, dtype=torch.float16,\n> ).to(\"cuda\")\n>\n> image = pipeline(\"A photo of sks dog in a bucket\", num_inference_steps=50, guidance_scale=7.5).images[0]\n> image.save(\"dog-bucket.png\")\n> ```\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"path_to_saved_model\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\nimage = pipeline(\"A photo of sks dog in a bucket\", num_inference_steps=50, guidance_scale=7.5).images[0]\nimage.save(\"dog-bucket.png\")\n```\n\n## LoRA\n\nLoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MB). Use the [train_dreambooth_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py) script to train with LoRA.\n\nThe LoRA training script is discussed in more detail in the [LoRA training](lora) guide.\n\n## Stable Diffusion XL\n\nStable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text encoder to its architecture. 
Use the [train_dreambooth_lora_sdxl.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py) script to train an SDXL model with LoRA.\n\nThe SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.\n\n## DeepFloyd IF\n\nDeepFloyd IF is a cascading pixel diffusion model with three stages. The first stage generates a base image and the second and third stages progressively upscale the base image into a high-resolution 1024x1024 image. Use the [train_dreambooth_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py) or [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) scripts to train a DeepFloyd IF model with LoRA or the full model.\n\nDeepFloyd IF uses predicted variance, but the Diffusers training scripts use predicted error so the trained DeepFloyd IF models are switched to a fixed variance schedule. The training scripts will update the scheduler config of the fully trained model for you. However, when you load the saved LoRA weights you must also update the pipeline's scheduler config.\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", use_safetensors=True)\n\npipe.load_lora_weights(\"<lora weights path>\")\n\n# Update scheduler config to fixed variance schedule\npipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type=\"fixed_small\")\n```\n\nThe stage 2 model requires additional validation images to upscale. 
You can download and use a downsized version of the training images for this.\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog_downsized\"\nsnapshot_download(\n    \"diffusers/dog-example-downsized\",\n    local_dir=local_dir,\n    repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThe code samples below provide a brief overview of how to train a DeepFloyd IF model with a combination of DreamBooth and LoRA. Some important parameters to note are:\n\n* `--resolution=64`, a much smaller resolution is required because DeepFloyd IF is a pixel diffusion model and to work on uncompressed pixels, the input images must be smaller\n* `--pre_compute_text_embeddings`, compute the text embeddings ahead of time to save memory because the [`~transformers.T5Model`] can take up a lot of memory\n* `--tokenizer_max_length=77`, you can use a longer default text length with T5 as the text encoder but the default model encoding procedure uses a shorter text length\n* `--text_encoder_use_attention_mask`, to pass the attention mask to the text encoder\n\n<hfoptions id=\"IF-DreamBooth\">\n<hfoption id=\"Stage 1 LoRA DreamBooth\">\n\nTraining stage 1 of DeepFloyd IF with LoRA and DreamBooth requires ~28GB of memory.\n\n```bash\nexport MODEL_NAME=\"DeepFloyd/IF-I-XL-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_lora\"\n\naccelerate launch train_dreambooth_lora.py \\\n  --report_to wandb \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a sks dog\" \\\n  --resolution=64 \\\n  --train_batch_size=4 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --scale_lr \\\n  --max_train_steps=1200 \\\n  --validation_prompt=\"a sks dog\" \\\n  --validation_epochs=25 \\\n  --checkpointing_steps=100 \\\n  --pre_compute_text_embeddings \\\n  --tokenizer_max_length=77 \\\n  
--text_encoder_use_attention_mask\n```\n\n</hfoption>\n<hfoption id=\"Stage 2 LoRA DreamBooth\">\n\nFor stage 2 of DeepFloyd IF with LoRA and DreamBooth, pay attention to these parameters:\n\n* `--validation_images`, the images to upscale during validation\n* `--class_labels_conditioning=timesteps`, to additionally condition the UNet as needed in stage 2\n* `--learning_rate=1e-6`, a lower learning rate is used compared to stage 1\n* `--resolution=256`, the expected resolution for the upscaler\n\n```bash\nexport MODEL_NAME=\"DeepFloyd/IF-II-L-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_upscale\"\nexport VALIDATION_IMAGES=\"dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png\"\n\npython train_dreambooth_lora.py \\\n    --report_to wandb \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --instance_data_dir=$INSTANCE_DIR \\\n    --output_dir=$OUTPUT_DIR \\\n    --instance_prompt=\"a sks dog\" \\\n    --resolution=256 \\\n    --train_batch_size=4 \\\n    --gradient_accumulation_steps=1 \\\n    --learning_rate=1e-6 \\\n    --max_train_steps=2000 \\\n    --validation_prompt=\"a sks dog\" \\\n    --validation_epochs=100 \\\n    --checkpointing_steps=500 \\\n    --pre_compute_text_embeddings \\\n    --tokenizer_max_length=77 \\\n    --text_encoder_use_attention_mask \\\n    --validation_images $VALIDATION_IMAGES \\\n    --class_labels_conditioning=timesteps\n```\n\n</hfoption>\n<hfoption id=\"Stage 1 DreamBooth\">\n\nFor stage 1 of DeepFloyd IF with DreamBooth, pay attention to these parameters:\n\n* `--skip_save_text_encoder`, to skip saving the full T5 text encoder with the finetuned model\n* `--use_8bit_adam`, to use 8-bit Adam optimizer to save memory due to the size of the optimizer state when training the full model\n* `--learning_rate=1e-7`, a really low learning rate should be used for full model training otherwise the model quality is degraded (you can use a higher learning 
rate with a larger batch size)\n\nTraining with 8-bit Adam and a batch size of 4, the full model can be trained with ~48GB of memory.\n\n```bash\nexport MODEL_NAME=\"DeepFloyd/IF-I-XL-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_if\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=64 \\\n  --train_batch_size=4 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-7 \\\n  --max_train_steps=150 \\\n  --validation_prompt \"a photo of sks dog\" \\\n  --validation_steps 25 \\\n  --text_encoder_use_attention_mask \\\n  --tokenizer_max_length 77 \\\n  --pre_compute_text_embeddings \\\n  --use_8bit_adam \\\n  --set_grads_to_none \\\n  --skip_save_text_encoder \\\n  --push_to_hub\n```\n\n</hfoption>\n<hfoption id=\"Stage 2 DreamBooth\">\n\nFor stage 2 of DeepFloyd IF with DreamBooth, pay attention to these parameters:\n\n* `--learning_rate=5e-6`, use a lower learning rate with a smaller effective batch size\n* `--resolution=256`, the expected resolution for the upscaler\n* `--train_batch_size=2` and `--gradient_accumulation_steps=6`, to effectively train on images with faces requires larger batch sizes\n\n```bash\nexport MODEL_NAME=\"DeepFloyd/IF-II-L-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_upscale\"\nexport VALIDATION_IMAGES=\"dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png\"\n\naccelerate launch train_dreambooth.py \\\n  --report_to wandb \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a sks dog\" \\\n  --resolution=256 \\\n  --train_batch_size=2 \\\n  --gradient_accumulation_steps=6 \\\n  --learning_rate=5e-6 \\\n  --max_train_steps=2000 \\\n  --validation_prompt=\"a sks dog\" 
\\\n  --validation_steps=150 \\\n  --checkpointing_steps=500 \\\n  --pre_compute_text_embeddings \\\n  --tokenizer_max_length=77 \\\n  --text_encoder_use_attention_mask \\\n  --validation_images $VALIDATION_IMAGES \\\n  --class_labels_conditioning timesteps \\\n  --push_to_hub\n```\n\n</hfoption>\n</hfoptions>\n\n### Training tips\n\nTraining the DeepFloyd IF model can be challenging, but here are some tips that we've found helpful:\n\n- LoRA is sufficient for training the stage 1 model because the model's low resolution makes representing finer details difficult regardless.\n- For common or simple objects, you don't necessarily need to finetune the upscaler. Make sure the prompt passed to the upscaler is adjusted to remove the new token from the instance prompt. For example, if your stage 1 prompt is \"a sks dog\" then your stage 2 prompt should be \"a dog\".\n- For finer details like faces, fully training the stage 2 upscaler is better than training the stage 2 model with LoRA. It also helps to use lower learning rates with larger batch sizes.\n- Lower learning rates should be used to train the stage 2 model.\n- The [`DDPMScheduler`] works better than the DPMSolver used in the training scripts.\n\n## Next steps\n\nCongratulations on training your DreamBooth model! To learn more about how to use your new model, the following guide may be helpful:\n\n- Learn how to [load a DreamBooth](../using-diffusers/dreambooth) model for inference if you trained your model with LoRA."
  },
  {
    "path": "docs/source/en/training/instructpix2pix.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# InstructPix2Pix\n\n[InstructPix2Pix](https://hf.co/papers/2211.09800) is a Stable Diffusion model trained to edit images from human-provided instructions. For example, your prompt can be \"turn the clouds rainy\" and the model will edit the input image accordingly. This model is conditioned on the text prompt (or editing instruction) and the input image.\n\nThis guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/instruct_pix2pix\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. 
Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) and let us know if you have any questions or concerns.\n\n## Script parameters\n\nThe training script has many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L65) function. 
Most parameters have default values that work pretty well, but you can also set your own values in the training command if you'd like.\n\nFor example, to increase the resolution of the input image:\n\n```bash\naccelerate launch train_instruct_pix2pix.py \\\n  --resolution=512\n```\n\nMany of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the relevant parameters for InstructPix2Pix:\n\n- `--original_image_column`: the original image before the edits are made\n- `--edited_image_column`: the image after the edits are made\n- `--edit_prompt_column`: the instructions to edit the image\n- `--conditioning_dropout_prob`: the dropout probability for the original image and edit prompts during training, which enables classifier-free guidance (CFG) for one or both conditioning inputs\n\n## Training script\n\nThe dataset preprocessing code and training loop are found in the [`main()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L374) function. This is where you'll make your changes to the training script to adapt it for your own use-case.\n\nAs with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. 
Instead, this guide takes a look at the InstructPix2Pix relevant parts of the script.\n\nThe script begins by modifying the [number of input channels](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L445) in the first convolutional layer of the UNet to account for InstructPix2Pix's additional conditioning image:\n\n```py\nin_channels = 8\nout_channels = unet.conv_in.out_channels\nunet.register_to_config(in_channels=in_channels)\n\nwith torch.no_grad():\n    new_conv_in = nn.Conv2d(\n        in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding\n    )\n    new_conv_in.weight.zero_()\n    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)\n    unet.conv_in = new_conv_in\n```\n\nThese UNet parameters are [updated](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L545C1-L551C6) by the optimizer:\n\n```py\noptimizer = optimizer_cls(\n    unet.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nNext, the edited images and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). 
It is important the same image transformations are applied to the original and edited images.\n\n```py\ndef preprocess_train(examples):\n    preprocessed_images = preprocess_images(examples)\n\n    original_images, edited_images = preprocessed_images.chunk(2)\n    original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)\n    edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)\n\n    examples[\"original_pixel_values\"] = original_images\n    examples[\"edited_pixel_values\"] = edited_images\n\n    captions = list(examples[edit_prompt_column])\n    examples[\"input_ids\"] = tokenize_captions(captions)\n    return examples\n```\n\nFinally, in the [training loop](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L730), it starts by encoding the edited images into latent space:\n\n```py\nlatents = vae.encode(batch[\"edited_pixel_values\"].to(weight_dtype)).latent_dist.sample()\nlatents = latents * vae.config.scaling_factor\n```\n\nThen, the script applies dropout to the original image and edit instruction embeddings to support CFG. 
This is what enables the model to modulate the influence of the edit instruction and original image on the edited image.\n\n```py\nencoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\noriginal_image_embeds = vae.encode(batch[\"original_pixel_values\"].to(weight_dtype)).latent_dist.mode()\n\nif args.conditioning_dropout_prob is not None:\n    random_p = torch.rand(bsz, device=latents.device, generator=generator)\n    prompt_mask = random_p < 2 * args.conditioning_dropout_prob\n    prompt_mask = prompt_mask.reshape(bsz, 1, 1)\n    null_conditioning = text_encoder(tokenize_captions([\"\"]).to(accelerator.device))[0]\n    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)\n\n    image_mask_dtype = original_image_embeds.dtype\n    image_mask = 1 - (\n        (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)\n        * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)\n    )\n    image_mask = image_mask.reshape(bsz, 1, 1, 1)\n    original_image_embeds = image_mask * original_image_embeds\n```\n\nThat's pretty much it! Aside from the differences described here, the rest of the script is very similar to the [Text-to-image](text2image#training-script) training script, so feel free to check it out for more details. If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nOnce you're happy with the changes to your script or if you're okay with the default configuration, you're ready to launch the training script! 🚀\n\nThis guide uses the [fusing/instructpix2pix-1000-samples](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) dataset, which is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered). 
You can also create and use your own dataset if you'd like (see the [Create a dataset for training](create_dataset) guide).\n\nSet the `MODEL_NAME` environment variable to the name of the model (can be a model id on the Hub or a path to a local model), and the `DATASET_ID` to the name of the dataset on the Hub. The script creates and saves all the components (feature extractor, scheduler, text encoder, UNet, etc.) to a subfolder in your repository.\n\n> [!TIP]\n> For better results, try longer training runs with a larger dataset. We've only tested this training script on a smaller-scale dataset.\n>\n> <br>\n>\n> To monitor training progress with Weights and Biases, add the `--report_to=wandb` parameter to the training command and specify a validation image with `--val_image_url` and a validation prompt with `--validation_prompt`. This can be really useful for debugging the model.\n\nIf you’re training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" train_instruct_pix2pix.py \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --dataset_name=$DATASET_ID \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=256 \\\n    --random_flip \\\n    --train_batch_size=4 \\\n    --gradient_accumulation_steps=4 \\\n    --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 \\\n    --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 \\\n    --max_grad_norm=1 \\\n    --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --mixed_precision=fp16 \\\n    --seed=42 \\\n    --push_to_hub\n```\n\nAfter training is finished, you can use your new InstructPix2Pix for inference:\n\n```py\nimport PIL\nimport requests\nimport torch\nfrom diffusers import StableDiffusionInstructPix2PixPipeline\nfrom diffusers.utils import load_image\n\npipeline = 
StableDiffusionInstructPix2PixPipeline.from_pretrained(\"your_cool_model\", torch_dtype=torch.float16).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n\nimage = load_image(\"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png\")\nprompt = \"add some ducks to the lake\"\nnum_inference_steps = 20\nimage_guidance_scale = 1.5\nguidance_scale = 10\n\nedited_image = pipeline(\n   prompt,\n   image=image,\n   num_inference_steps=num_inference_steps,\n   image_guidance_scale=image_guidance_scale,\n   guidance_scale=guidance_scale,\n   generator=generator,\n).images[0]\nedited_image.save(\"edited_image.png\")\n```\n\nYou should experiment with different `num_inference_steps`, `image_guidance_scale`, and `guidance_scale` values to see how they affect inference speed and quality. The guidance scale parameters are especially impactful because they control how much the original image and edit instructions affect the edited image.\n\n## Stable Diffusion XL\n\nStable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text-encoder to its architecture. Use the [`train_instruct_pix2pix_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py) script to train a SDXL model to follow image editing instructions.\n\nThe SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.\n\n## Next steps\n\nCongratulations on training your own InstructPix2Pix model! 🥳 To learn more about the model, it may be helpful to:\n\n- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions.\n"
  },
  {
    "path": "docs/source/en/training/kandinsky.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kandinsky 2.2\n\n> [!WARNING]\n> This script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.\n\nKandinsky 2.2 is a multilingual text-to-image model capable of producing more photorealistic images. The model includes an image prior model for creating image embeddings from text prompts, and a decoder model that generates images based on the prior model's embeddings. That's why you'll find two separate scripts in Diffusers for Kandinsky 2.2, one for training the prior model and one for training the decoder model. You can train both models separately, but to get the best results, you should train both the prior and decoder models.\n\nDepending on your GPU, you may need to enable `gradient_checkpointing` (⚠️ not supported for the prior model!), `mixed_precision`, and `gradient_accumulation_steps` to help fit the model into memory and to speedup training. 
You can reduce your memory-usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) (version [v0.0.16](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212) fails for training on some GPUs so you may need to install a development version instead).\n\nThis guide explores the [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py) and the [train_text_to_image_decoder.py](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py) scripts to help you become more familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the scripts, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/kandinsky2_2/text_to_image\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. 
Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training scripts that are important for understanding how to modify them, but they don't cover every aspect of the scripts in detail. If you're interested in learning more, feel free to read through the scripts and let us know if you have any questions or concerns.\n\n## Script parameters\n\nThe training scripts provide many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L190) function. 
The training scripts provide default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.\n\nFor example, to speed up training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:\n\n```bash\naccelerate launch train_text_to_image_prior.py \\\n  --mixed_precision=\"fp16\"\n```\n\nMost of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so let's get straight to a walkthrough of the Kandinsky training scripts!\n\n### Min-SNR weighting\n\nThe [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.\n\nAdd the `--snr_gamma` parameter and set it to the recommended value of 5.0:\n\n```bash\naccelerate launch train_text_to_image_prior.py \\\n  --snr_gamma=5.0\n```\n\n## Training script\n\nThe training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support training the prior and decoder models. 
This guide focuses on the code that is unique to the Kandinsky 2.2 training scripts.\n\n<hfoptions id=\"script\">\n<hfoption id=\"prior model\">\n\nThe [`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L441) function contains the code for preparing the dataset and training the model.\n\nOne of the main differences you'll notice right away is that the training script also loads a [`~transformers.CLIPImageProcessor`] - in addition to a scheduler and tokenizer - for preprocessing images and a [`~transformers.CLIPVisionModelWithProjection`] model for encoding the images:\n\n```py\nnoise_scheduler = DDPMScheduler(beta_schedule=\"squaredcos_cap_v2\", prediction_type=\"sample\")\nimage_processor = CLIPImageProcessor.from_pretrained(\n    args.pretrained_prior_model_name_or_path, subfolder=\"image_processor\"\n)\ntokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"tokenizer\")\n\nwith ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_encoder\", torch_dtype=weight_dtype\n    ).eval()\n    text_encoder = CLIPTextModelWithProjection.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"text_encoder\", torch_dtype=weight_dtype\n    ).eval()\n```\n\nKandinsky uses a [`PriorTransformer`] to generate the image embeddings, so you'll want to set up the optimizer to learn the prior model's parameters.\n\n```py\nprior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\nprior.train()\noptimizer = optimizer_cls(\n    prior.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nNext, the input 
captions are tokenized, and images are [preprocessed](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L632) by the [`~transformers.CLIPImageProcessor`]:\n\n```py\ndef preprocess_train(examples):\n    images = [image.convert(\"RGB\") for image in examples[image_column]]\n    examples[\"clip_pixel_values\"] = image_processor(images, return_tensors=\"pt\").pixel_values\n    examples[\"text_input_ids\"], examples[\"text_mask\"] = tokenize_captions(examples)\n    return examples\n```\n\nFinally, the [training loop](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L718) converts the input images into latents, adds noise to the image embeddings, and makes a prediction:\n\n```py\nmodel_pred = prior(\n    noisy_latents,\n    timestep=timesteps,\n    proj_embedding=prompt_embeds,\n    encoder_hidden_states=text_encoder_hidden_states,\n    attention_mask=text_mask,\n).predicted_image_embedding\n```\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n</hfoption>\n<hfoption id=\"decoder model\">\n\nThe [`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L440) function contains the code for preparing the dataset and training the model.\n\nUnlike the prior model, the decoder initializes a [`VQModel`] to decode the latents into images and it uses a [`UNet2DConditionModel`]:\n\n```py\nwith ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n    vae = VQModel.from_pretrained(\n        args.pretrained_decoder_model_name_or_path, subfolder=\"movq\", 
torch_dtype=weight_dtype\n    ).eval()\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_encoder\", torch_dtype=weight_dtype\n    ).eval()\nunet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder=\"unet\")\n```\n\nNext, the script includes several image transforms and a [preprocessing](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L622) function for applying the transforms to the images and returning the pixel values:\n\n```py\ndef preprocess_train(examples):\n    images = [image.convert(\"RGB\") for image in examples[image_column]]\n    examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n    examples[\"clip_pixel_values\"] = image_processor(images, return_tensors=\"pt\").pixel_values\n    return examples\n```\n\nLastly, the [training loop](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L706) handles converting the images to latents, adding noise, and predicting the noise residual.\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n```py\nmodel_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]\n```\n\n</hfoption>\n</hfoptions>\n\n## Launch the script\n\nOnce you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 
🚀\n\nYou'll train on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset to generate your own Naruto characters, but you can also create and train on your own dataset by following the [Create a dataset for training](create_dataset) guide. Set the environment variable `DATASET_NAME` to the name of the dataset on the Hub or, if you're training on your own files, set the environment variable `TRAIN_DIR` to a path to your dataset.\n\nIf you’re training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.\n\n> [!TIP]\n> To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompts` parameter to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"prior model\">\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image_prior.py \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"kandi2-prior-naruto-model\"\n```\n\n</hfoption>\n<hfoption id=\"decoder model\">\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image_decoder.py \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  
--max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"kandi2-decoder-naruto-model\"\n```\n\n</hfoption>\n</hfoptions>\n\nOnce training is finished, you can use your newly trained model for inference!\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"prior model\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image, DiffusionPipeline\nimport torch\n\n# path to the trained prior (the --output_dir used in the training command)\noutput_dir = \"kandi2-prior-naruto-model\"\n\nprior_pipeline = DiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)\nprior_components = {\"prior_\" + k: v for k, v in prior_pipeline.components.items()}\npipeline = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", **prior_components, torch_dtype=torch.float16)\n\npipeline.enable_model_cpu_offload()\nprompt = \"A robot naruto, 4k photo\"\nimage = pipeline(prompt=prompt).images[0]\n```\n\n> [!TIP]\n> Feel free to replace `kandinsky-community/kandinsky-2-2-decoder` with your own trained decoder checkpoint!\n\n</hfoption>\n<hfoption id=\"decoder model\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"path/to/saved/model\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A robot naruto, 4k photo\"\nimage = pipeline(prompt=prompt).images[0]\n```\n\nFor the decoder model, you can also perform inference from a saved checkpoint, which can be useful for viewing intermediate results. 
In this case, load the checkpoint into the UNet:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image, UNet2DConditionModel\n\nunet = UNet2DConditionModel.from_pretrained(\"path/to/saved/model\" + \"/checkpoint-<N>/unet\")\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", unet=unet, torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nimage = pipeline(prompt=\"A robot naruto, 4k photo\").images[0]\n```\n\n</hfoption>\n</hfoptions>\n\n## Next steps\n\nCongratulations on training a Kandinsky 2.2 model! To learn more about how to use your new model, the following guides may be helpful:\n\n- Read the [Kandinsky](../using-diffusers/kandinsky) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting, interpolation), and how it can be combined with a ControlNet.\n- Check out the [DreamBooth](dreambooth) and [LoRA](lora) training guides to learn how to train a personalized Kandinsky model with just a few example images. These two training techniques can even be combined!\n"
  },
  {
    "path": "docs/source/en/training/lcm_distill.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Latent Consistency Distillation\n\n[Latent Consistency Models (LCMs)](https://hf.co/papers/2310.04378) are able to generate high-quality images in just a few steps, representing a big leap forward because many pipelines require at least 25+ steps. LCMs are produced by applying the latent consistency distillation method to any Stable Diffusion model. This method works by applying *one-stage guided distillation* to the latent space, and incorporating a *skipping-step* method to consistently skip timesteps to accelerate the distillation process (refer to section 4.1, 4.2, and 4.3 of the paper for more details).\n\nIf you're training on a GPU with limited vRAM, try enabling `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` to reduce memory-usage and speedup training. 
You can reduce your memory-usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) and [bitsandbytes'](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizer.\n\nThis guide will explore the [train_lcm_distill_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/consistency_distillation\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. 
Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment (try enabling `torch.compile` to significantly speedup training):\n\n```bash\naccelerate config\n```\n\nTo setup a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n## Script parameters\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) and let us know if you have any questions or concerns.\n\nThe training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L419) function. 
This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.\n\nFor example, to speed up training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:\n\n```bash\naccelerate launch train_lcm_distill_sd_wds.py \\\n  --mixed_precision=\"fp16\"\n```\n\nMost of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so you'll focus on the parameters that are relevant to latent consistency distillation in this guide.\n\n- `--pretrained_teacher_model`: the path to a pretrained latent diffusion model to use as the teacher model\n- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify an alternative VAE (like this [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) by madebyollin, which works in fp16)\n- `--w_min` and `--w_max`: the minimum and maximum guidance scale values for guidance scale sampling\n- `--num_ddim_timesteps`: the number of timesteps for DDIM sampling\n- `--loss_type`: the type of loss (L2 or Huber) to calculate for latent consistency distillation; Huber loss is generally preferred because it's more robust to outliers\n- `--huber_c`: the Huber loss parameter\n\n## Training script\n\nThe training script starts by creating a dataset class - [`Text2ImageDataset`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L141) - for preprocessing the images and creating a training dataset.\n\n```py\ndef transform(example):\n    image = example[\"image\"]\n    image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)\n\n    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, 
output_size=(resolution, resolution))\n    image = TF.crop(image, c_top, c_left, resolution, resolution)\n    image = TF.to_tensor(image)\n    image = TF.normalize(image, [0.5], [0.5])\n\n    example[\"image\"] = image\n    return example\n```\n\nFor improved performance on reading and writing large datasets stored in the cloud, this script uses the [WebDataset](https://github.com/webdataset/webdataset) format to create a preprocessing pipeline to apply transforms and create a dataset and dataloader for training. Images are processed and fed to the training loop without having to download the full dataset first.\n\n```py\nprocessing_pipeline = [\n    wds.decode(\"pil\", handler=wds.ignore_and_continue),\n    wds.rename(image=\"jpg;png;jpeg;webp\", text=\"text;txt;caption\", handler=wds.warn_and_continue),\n    wds.map(filter_keys({\"image\", \"text\"})),\n    wds.map(transform),\n    wds.to_tuple(\"image\", \"text\"),\n]\n```\n\nIn the [`main()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L768) function, all the necessary components like the noise scheduler, tokenizers, text encoders, and VAE are loaded. The teacher UNet is also loaded here and then you can create a student UNet from the teacher UNet. 
The student UNet is updated by the optimizer during training.\n\n```py\nteacher_unet = UNet2DConditionModel.from_pretrained(\n    args.pretrained_teacher_model, subfolder=\"unet\", revision=args.teacher_revision\n)\n\nunet = UNet2DConditionModel(**teacher_unet.config)\nunet.load_state_dict(teacher_unet.state_dict(), strict=False)\nunet.train()\n```\n\nNow you can create the [optimizer](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L979) to update the UNet parameters:\n\n```py\noptimizer = optimizer_class(\n    unet.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nCreate the [dataset](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L994):\n\n```py\ndataset = Text2ImageDataset(\n    train_shards_path_or_url=args.train_shards_path_or_url,\n    num_train_examples=args.max_train_samples,\n    per_gpu_batch_size=args.train_batch_size,\n    global_batch_size=args.train_batch_size * accelerator.num_processes,\n    num_workers=args.dataloader_num_workers,\n    resolution=args.resolution,\n    shuffle_buffer_size=1000,\n    pin_memory=True,\n    persistent_workers=True,\n)\ntrain_dataloader = dataset.train_dataloader\n```\n\nNext, you're ready to setup the [training loop](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1049) and implement the latent consistency distillation method (see Algorithm 1 in the paper for more details). 
This section of the script takes care of adding noise to the latents, sampling and creating a guidance scale embedding, and predicting the original image from the noise.\n\n```py\npred_x_0 = predicted_origin(\n    noise_pred,\n    start_timesteps,\n    noisy_model_input,\n    noise_scheduler.config.prediction_type,\n    alpha_schedule,\n    sigma_schedule,\n)\n\nmodel_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0\n```\n\nIt gets the [teacher model predictions](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1172) and the [LCM predictions](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1209) next, calculates the loss, and then backpropagates it to the LCM.\n\n```py\nif args.loss_type == \"l2\":\n    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\nelif args.loss_type == \"huber\":\n    loss = torch.mean(\n        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c\n    )\n```\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers tutorial](../using-diffusers/write_own_pipeline) which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nNow you're ready to launch the training script and start distilling!\n\nFor this guide, you'll use the `--train_shards_path_or_url` to specify the path to the [Conceptual Captions 12M](https://github.com/google-research-datasets/conceptual-12m) dataset stored on the Hub [here](https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset). 
Set the `MODEL_DIR` environment variable to the name of the teacher model and `OUTPUT_DIR` to where you want to save the model.\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path/to/saved/model\"\n\naccelerate launch train_lcm_distill_sd_wds.py \\\n    --pretrained_teacher_model=$MODEL_DIR \\\n    --output_dir=$OUTPUT_DIR \\\n    --mixed_precision=fp16 \\\n    --resolution=512 \\\n    --learning_rate=1e-6 --loss_type=\"huber\" --ema_decay=0.95 --adam_weight_decay=0.0 \\\n    --max_train_steps=1000 \\\n    --max_train_samples=4000000 \\\n    --dataloader_num_workers=8 \\\n    --train_shards_path_or_url=\"pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true\" \\\n    --validation_steps=200 \\\n    --checkpointing_steps=200 --checkpoints_total_limit=10 \\\n    --train_batch_size=12 \\\n    --gradient_checkpointing --enable_xformers_memory_efficient_attention \\\n    --gradient_accumulation_steps=1 \\\n    --use_8bit_adam \\\n    --resume_from_checkpoint=latest \\\n    --report_to=wandb \\\n    --seed=453645634 \\\n    --push_to_hub\n```\n\nOnce training is complete, you can use your new LCM for inference.\n\n```py\nfrom diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler\nimport torch\n\nunet = UNet2DConditionModel.from_pretrained(\"your-username/your-model\", torch_dtype=torch.float16, variant=\"fp16\")\npipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", unet=unet, torch_dtype=torch.float16, variant=\"fp16\")\n\npipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)\npipeline.to(\"cuda\")\n\nprompt = \"sushi rolls in the form of panda heads, sushi platter\"\n\nimage = pipeline(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]\n```\n\n## LoRA\n\nLoRA is a training technique for significantly reducing the number of trainable parameters. 
As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MB). Use the [train_lcm_distill_lora_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py) or [train_lcm_distill_lora_sdxl_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py) script to train with LoRA.\n\nThe LoRA training script is discussed in more detail in the [LoRA training](lora) guide.\n\n## Stable Diffusion XL\n\nStable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text-encoder to its architecture. Use the [train_lcm_distill_sdxl_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py) script to train an SDXL model.\n\nThe SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.\n\n## Next steps\n\nCongratulations on distilling an LCM! To learn more about LCM, the following may be helpful:\n\n- Learn how to use [LCMs for inference](../using-diffusers/inference_with_lcm) for text-to-image, image-to-image, and with LoRA checkpoints.\n- Read the [SDXL in 4 steps with Latent Consistency LoRAs](https://huggingface.co/blog/lcm_lora) blog post to learn more about SDXL LCM-LoRAs for super fast inference, quality comparisons, benchmarks, and more.\n"
  },
  {
    "path": "docs/source/en/training/lora.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LoRA\n\n> [!WARNING]\n> This is experimental and the API may change in the future.\n\n[LoRA (Low-Rank Adaptation of Large Language Models)](https://hf.co/papers/2106.09685) is a popular and lightweight training technique that significantly reduces the number of trainable parameters. It works by inserting a smaller number of new weights into the model and only these are trained. This makes training with LoRA much faster, memory-efficient, and produces smaller model weights (a few hundred MBs), which are easier to store and share. 
LoRA can also be combined with other training techniques like DreamBooth to speedup training.\n\n> [!TIP]\n> LoRA is very versatile and supported for [DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py), [Kandinsky 2.2](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py), [Stable Diffusion XL](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py), [text-to-image](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py), and [Wuerstchen](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py).\n\nThis guide will explore the [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nNavigate to the example folder with the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/text_to_image\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. 
Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo setup a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) and let us know if you have any questions or concerns.\n\n## Script parameters\n\nThe training script has many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L85) function. 
Default values that work pretty well are provided for most parameters, but you can also set your own values in the training command if you'd like.\n\nFor example, to increase the number of epochs to train:\n\n```bash\naccelerate launch train_text_to_image_lora.py \\\n  --num_train_epochs=150\n```\n\nMany of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the LoRA-relevant parameters:\n\n- `--rank`: the inner dimension of the low-rank matrices to train; a higher rank means more trainable parameters\n- `--learning_rate`: the default learning rate is 1e-4, but with LoRA, you can use a higher learning rate\n\n## Training script\n\nThe dataset preprocessing code and training loop are found in the [`main()`](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L371) function, and if you need to adapt the training script, this is where you'll make your changes.\n\nAs with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. Instead, this guide takes a look at the LoRA-relevant parts of the script.\n\n<hfoptions id=\"lora\">\n<hfoption id=\"UNet\">\n\nDiffusers uses [`~peft.LoraConfig`] from the [PEFT](https://hf.co/docs/peft) library to set up the parameters of the LoRA adapter such as the rank, alpha, and which modules to insert the LoRA weights into. 
The adapter is added to the UNet, and only the LoRA layers are filtered for optimization in `lora_layers`.\n\n```py\nunet_lora_config = LoraConfig(\n    r=args.rank,\n    lora_alpha=args.rank,\n    init_lora_weights=\"gaussian\",\n    target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n)\n\nunet.add_adapter(unet_lora_config)\nlora_layers = filter(lambda p: p.requires_grad, unet.parameters())\n```\n\n</hfoption>\n<hfoption id=\"text encoder\">\n\nDiffusers also supports finetuning the text encoder with LoRA from the [PEFT](https://hf.co/docs/peft) library when necessary such as finetuning Stable Diffusion XL (SDXL). The [`~peft.LoraConfig`] is used to configure the parameters of the LoRA adapter which are then added to the text encoder, and only the LoRA layers are filtered for training.\n\n```py\ntext_lora_config = LoraConfig(\n    r=args.rank,\n    lora_alpha=args.rank,\n    init_lora_weights=\"gaussian\",\n    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n)\n\ntext_encoder_one.add_adapter(text_lora_config)\ntext_encoder_two.add_adapter(text_lora_config)\ntext_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\ntext_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))\n```\n\n</hfoption>\n</hfoptions>\n\nThe [optimizer](https://github.com/huggingface/diffusers/blob/e4b8f173b97731686e290b2eb98e7f5df2b1b322/examples/text_to_image/train_text_to_image_lora.py#L529) is initialized with the `lora_layers` because these are the only weights that'll be optimized:\n\n```py\noptimizer = optimizer_cls(\n    lora_layers,\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nAside from setting up the LoRA layers, the training script is more or less the same as train_text_to_image.py!\n\n## Launch the script\n\nOnce you've made all your changes or you're okay with 
the default configuration, you're ready to launch the training script! 🚀\n\nLet's train on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset to generate your own Naruto characters. Set the environment variables `MODEL_NAME` and `DATASET_NAME` to the model and dataset respectively. You should also specify where to save the model in `OUTPUT_DIR`, and the name of the model to save to on the Hub with `HUB_MODEL_ID`. The script creates and saves the following files to your repository:\n\n- saved model checkpoints\n- `pytorch_lora_weights.safetensors` (the trained LoRA weights)\n\nIf you're training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.\n\n> [!WARNING]\n> A full training run takes ~5 hours on a 2080 Ti GPU with 11GB of VRAM.\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"/sddata/finetune/lora/naruto\"\nexport HUB_MODEL_ID=\"naruto-lora\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image_lora.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --dataloader_num_workers=8 \\\n  --resolution=512 \\\n  --center_crop \\\n  --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-04 \\\n  --max_grad_norm=1 \\\n  --lr_scheduler=\"cosine\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=${OUTPUT_DIR} \\\n  --push_to_hub \\\n  --hub_model_id=${HUB_MODEL_ID} \\\n  --report_to=wandb \\\n  --checkpointing_steps=500 \\\n  --validation_prompt=\"A naruto with blue eyes.\" \\\n  --seed=1337\n```\n\nOnce training has been completed, you can use your model for inference:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", 
torch_dtype=torch.float16).to(\"cuda\")\npipeline.load_lora_weights(\"path/to/lora/model\", weight_name=\"pytorch_lora_weights.safetensors\")\nimage = pipeline(\"A naruto with blue eyes\").images[0]\n```\n\n## Next steps\n\nCongratulations on training a new model with LoRA! To learn more about how to use your new model, the following guides may be helpful:\n\n- Learn how to [load different LoRA formats](../tutorials/using_peft_for_inference) trained using community trainers like Kohya and TheLastBen.\n- Learn how to use and [combine multiple LoRAs](../tutorials/using_peft_for_inference) with PEFT for inference.\n"
  },
  {
    "path": "docs/source/en/training/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Overview\n\n🤗 Diffusers provides a collection of training scripts for you to train your own diffusion models. You can find all of our training scripts in [diffusers/examples](https://github.com/huggingface/diffusers/tree/main/examples).\n\nEach training script is:\n\n- **Self-contained**: the training script does not depend on any local files, and all packages required to run the script are installed from the `requirements.txt` file.\n- **Easy-to-tweak**: the training scripts are an example of how to train a diffusion model for a specific task and won't work out-of-the-box for every training scenario. You'll likely need to adapt the training script for your specific use-case. To help you with that, we've fully exposed the data preprocessing code and the training loop so you can modify it for your own use.\n- **Beginner-friendly**: the training scripts are designed to be beginner-friendly and easy to understand, rather than including the latest state-of-the-art methods to get the best and most competitive results. 
Any training methods we consider too complex are purposefully left out.\n- **Single-purpose**: each training script is expressly designed for only one task to keep it readable and understandable.\n\nOur current collection of training scripts includes:\n\n| Training | SDXL-support | LoRA-support |\n|---|---|---|\n| [unconditional image generation](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) |  |  |\n| [text-to-image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) | 👍 | 👍 |\n| [textual inversion](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) |  |  |\n| [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) | 👍 | 👍 |\n| [ControlNet](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) | 👍 |  |\n| [InstructPix2Pix](https://github.com/huggingface/diffusers/tree/main/examples/instruct_pix2pix) | 👍 |  |\n| [Custom Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion) |  |  |\n| [T2I-Adapters](https://github.com/huggingface/diffusers/tree/main/examples/t2i_adapter) | 👍 |  |\n| [Kandinsky 2.2](https://github.com/huggingface/diffusers/tree/main/examples/kandinsky2_2/text_to_image) |  | 👍 |\n| [Wuerstchen](https://github.com/huggingface/diffusers/tree/main/examples/wuerstchen/text_to_image) |  | 👍 |\n\nThese examples are 
**actively** maintained, so please feel free to open an issue if they aren't working as expected. If you feel like another training example should be included, you're more than welcome to start a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) to discuss your feature idea with us and whether it meets our criteria of being self-contained, easy-to-tweak, beginner-friendly, and single-purpose.\n\n## Install\n\nMake sure you can successfully run the latest versions of the example scripts by installing the library from source in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the folder of the training script (for example, [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)) and install the `requirements.txt` file. Some training scripts have a specific requirement file for SDXL or LoRA. If you're using one of these scripts, make sure you install its corresponding requirements file.\n\n```bash\ncd examples/dreambooth\npip install -r requirements.txt\n# to train SDXL with DreamBooth\npip install -r requirements_sdxl.txt\n```\n\nTo speedup training and reduce memory-usage, we recommend:\n\n- using PyTorch 2.0 or higher to automatically use [scaled dot product attention](../optimization/fp16#scaled-dot-product-attention) during training (you don't need to make any changes to the training code)\n- installing [xFormers](../optimization/xformers) to enable memory-efficient attention"
  },
  {
    "path": "docs/source/en/training/sdxl.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Diffusion XL\n\n> [!WARNING]\n> This script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.\n\n[Stable Diffusion XL (SDXL)](https://hf.co/papers/2307.01952) is a larger and more powerful iteration of the Stable Diffusion model, capable of producing higher resolution images.\n\nSDXL's UNet is 3x larger and the model adds a second text encoder to the architecture. Depending on the hardware available to you, this can be very computationally intensive and it may not run on a consumer GPU like a Tesla T4. To help fit this larger model into memory and to speedup training, try enabling `gradient_checkpointing`, `mixed_precision`, and `gradient_accumulation_steps`. 
You can reduce your memory-usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) and using [bitsandbytes'](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizer.\n\nThis guide will explore the [train_text_to_image_sdxl.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_sdxl.py) training script to help you become more familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/text_to_image\npip install -r requirements_sdxl.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo setup a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n## Script parameters\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. 
If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_sdxl.py) and let us know if you have any questions or concerns.\n\nThe training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L129) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.\n\nFor example, to speedup training with mixed precision using the bf16 format, add the `--mixed_precision` parameter to the training command:\n\n```bash\naccelerate launch train_text_to_image_sdxl.py \\\n  --mixed_precision=\"bf16\"\n```\n\nMost of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so you'll focus on the parameters that are relevant to training SDXL in this guide.\n\n- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify a better [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)\n- `--proportion_empty_prompts`: the proportion of image prompts to replace with empty strings\n- `--timestep_bias_strategy`: where (earlier vs. 
later) in the timestep to apply a bias, which can encourage the model to either learn low or high frequency details\n- `--timestep_bias_multiplier`: the weight of the bias to apply to the timestep\n- `--timestep_bias_begin`: the timestep to begin applying the bias\n- `--timestep_bias_end`: the timestep to end applying the bias\n- `--timestep_bias_portion`: the proportion of timesteps to apply the bias to\n\n### Min-SNR weighting\n\nThe [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting either `epsilon` (noise) or `v_prediction`, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.\n\nAdd the `--snr_gamma` parameter and set it to the recommended value of 5.0:\n\n```bash\naccelerate launch train_text_to_image_sdxl.py \\\n  --snr_gamma=5.0\n```\n\n## Training script\n\nThe training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support SDXL training. This guide will focus on the code that is unique to the SDXL training script.\n\nIt starts by creating functions to [tokenize the prompts](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L478) to calculate the prompt embeddings, and to compute the image embeddings with the [VAE](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L519). 
Next, you'll find a function to [generate the timestep weights](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L531) depending on the number of timesteps and the timestep bias strategy to apply.\n\nWithin the [`main()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L572) function, in addition to loading a tokenizer, the script loads a second tokenizer and text encoder because the SDXL architecture uses two of each:\n\n```py\ntokenizer_one = AutoTokenizer.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision, use_fast=False\n)\ntokenizer_two = AutoTokenizer.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"tokenizer_2\", revision=args.revision, use_fast=False\n)\n\ntext_encoder_cls_one = import_model_class_from_model_name_or_path(\n    args.pretrained_model_name_or_path, args.revision\n)\ntext_encoder_cls_two = import_model_class_from_model_name_or_path(\n    args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n)\n```\n\nThe [prompt and image embeddings](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L857) are computed first and kept in memory, which isn't typically an issue for a smaller dataset, but for larger datasets it can lead to memory problems. 
If this is the case, you should save the pre-computed embeddings to disk separately and load them into memory during the training process (see this [PR](https://github.com/huggingface/diffusers/pull/4505) for more discussion about this topic).\n\n```py\ntext_encoders = [text_encoder_one, text_encoder_two]\ntokenizers = [tokenizer_one, tokenizer_two]\ncompute_embeddings_fn = functools.partial(\n    encode_prompt,\n    text_encoders=text_encoders,\n    tokenizers=tokenizers,\n    proportion_empty_prompts=args.proportion_empty_prompts,\n    caption_column=args.caption_column,\n)\n\ntrain_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)\ntrain_dataset = train_dataset.map(\n    compute_vae_encodings_fn,\n    batched=True,\n    batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,\n    new_fingerprint=new_fingerprint_for_vae,\n)\n```\n\nAfter calculating the embeddings, the text encoder, VAE, and tokenizer are deleted to free up some memory:\n\n```py\ndel text_encoders, tokenizers, vae\ngc.collect()\ntorch.cuda.empty_cache()\n```\n\nFinally, the [training loop](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L943) takes care of the rest. 
If you chose to apply a timestep bias strategy, you'll see the timestep weights are calculated and used to sample the timesteps at which noise is added:\n\n```py\nweights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(\n    model_input.device\n)\ntimesteps = torch.multinomial(weights, bsz, replacement=True).long()\n\nnoisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n```\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nOnce you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 🚀\n\nLet’s train on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset to generate your own Naruto characters. Set the environment variables `MODEL_NAME` and `DATASET_NAME` to the model and the dataset (either from the Hub or a local path). You should also specify a VAE other than the SDXL VAE (either from the Hub or a local path) with `VAE_NAME` to avoid numerical instabilities.\n\n> [!TIP]\n> To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompt` and `--validation_epochs` to the training command to keep track of results. 
This can be really useful for debugging the model and viewing intermediate results.\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport VAE_NAME=\"madebyollin/sdxl-vae-fp16-fix\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch train_text_to_image_sdxl.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --pretrained_vae_model_name_or_path=$VAE_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --enable_xformers_memory_efficient_attention \\\n  --resolution=512 \\\n  --center_crop \\\n  --random_flip \\\n  --proportion_empty_prompts=0.2 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=10000 \\\n  --use_8bit_adam \\\n  --learning_rate=1e-06 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --mixed_precision=\"fp16\" \\\n  --report_to=\"wandb\" \\\n  --validation_prompt=\"a cute Sundar Pichai creature\" \\\n  --validation_epochs 5 \\\n  --checkpointing_steps=5000 \\\n  --output_dir=\"sdxl-naruto-model\" \\\n  --push_to_hub\n```\n\nAfter you've finished training, you can use your newly trained SDXL model for inference!\n\n<hfoptions id=\"inference\">\n<hfoption id=\"PyTorch\">\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"path/to/your/model\", torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A naruto with green eyes and red legs.\"\nimage = pipeline(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]\nimage.save(\"naruto.png\")\n```\n\n</hfoption>\n<hfoption id=\"PyTorch XLA\">\n\n[PyTorch XLA](https://pytorch.org/xla) allows you to run PyTorch on XLA devices such as TPUs, which can be faster. The initial warmup step takes longer because the model needs to be compiled and optimized. 
However, subsequent calls to the pipeline on an input **with the same length** as the original prompt are much faster because it can reuse the optimized graph.\n\n```py\nfrom time import time\n\nfrom diffusers import DiffusionPipeline\nimport torch\nimport torch_xla.core.xla_model as xm\n\ndevice = xm.xla_device()\npipeline = DiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\").to(device)\n\nprompt = \"A naruto with green eyes and red legs.\"\ninference_steps = 30  # same number of steps as the PyTorch example\n\n# the first call includes graph compilation\nstart = time()\nimage = pipeline(prompt, num_inference_steps=inference_steps).images[0]\nprint(f'Compilation time is {time()-start} sec')\nimage.save(\"naruto.png\")\n\n# later calls reuse the compiled graph\nstart = time()\nimage = pipeline(prompt, num_inference_steps=inference_steps).images[0]\nprint(f'Inference time is {time()-start} sec after compilation')\n```\n\n</hfoption>\n</hfoptions>\n\n## Next steps\n\nCongratulations on training an SDXL model! To learn more about how to use your new model, the following guides may be helpful:\n\n- Read the [Stable Diffusion XL](../using-diffusers/sdxl) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting), how to use its refiner model, and the different types of micro-conditionings.\n- Check out the [DreamBooth](dreambooth) and [LoRA](lora) training guides to learn how to train a personalized SDXL model with just a few example images. These two training techniques can even be combined!"
  },
  {
    "path": "docs/source/en/training/t2i_adapters.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# T2I-Adapter\n\n[T2I-Adapter](https://hf.co/papers/2302.08453) is a lightweight adapter model that provides an additional conditioning input image (line art, canny, sketch, depth, pose) to better control image generation. It is similar to a ControlNet, but it is a lot smaller (~77M parameters and ~300MB file size) because its only inserts weights into the UNet instead of copying and training it.\n\nThe T2I-Adapter is only available for training with the Stable Diffusion XL (SDXL) model.\n\nThis guide will explore the [train_t2i_adapter_sdxl.py](https://github.com/huggingface/diffusers/blob/main/examples/t2i_adapter/train_t2i_adapter_sdxl.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/t2i_adapter\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. 
Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/t2i_adapter/train_t2i_adapter_sdxl.py) and let us know if you have any questions or concerns.\n\n## Script parameters\n\nThe training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L233) function. 
It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.\n\nFor example, to activate gradient accumulation, add the `--gradient_accumulation_steps` parameter to the training command:\n\n```bash\naccelerate launch train_t2i_adapter_sdxl.py \\\n  --gradient_accumulation_steps=4\n```\n\nMany of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the relevant T2I-Adapter parameters:\n\n- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify a better [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)\n- `--crops_coords_top_left_h` and `--crops_coords_top_left_w`: height and width coordinates to include in SDXL's crop coordinate embeddings\n- `--conditioning_image_column`: the column of the conditioning images in the dataset\n- `--proportion_empty_prompts`: the proportion of image prompts to replace with empty strings\n\n## Training script\n\nAs with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. Instead, this guide takes a look at the T2I-Adapter relevant parts of the script.\n\nThe training script begins by preparing the dataset. 
This includes [tokenizing](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L674) the prompt and [applying transforms](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L714) to the images and conditioning images.\n\n```py\nconditioning_image_transforms = transforms.Compose(\n    [\n        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n        transforms.CenterCrop(args.resolution),\n        transforms.ToTensor(),\n    ]\n)\n```\n\nWithin the [`main()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L770) function, the T2I-Adapter is either loaded from a pretrained adapter or it is randomly initialized:\n\n```py\nif args.adapter_model_name_or_path:\n    logger.info(\"Loading existing adapter weights.\")\n    t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)\nelse:\n    logger.info(\"Initializing t2iadapter weights.\")\n    t2iadapter = T2IAdapter(\n        in_channels=3,\n        channels=(320, 640, 1280, 1280),\n        num_res_blocks=2,\n        downscale_factor=16,\n        adapter_type=\"full_adapter_xl\",\n    )\n```\n\nThe [optimizer](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L952) is initialized for the T2I-Adapter parameters:\n\n```py\nparams_to_optimize = t2iadapter.parameters()\noptimizer = optimizer_class(\n    params_to_optimize,\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nLastly, in the [training 
loop](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1086), the adapter conditioning image and the text embeddings are passed to the UNet to predict the noise residual:\n\n```py\nt2iadapter_image = batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\ndown_block_additional_residuals = t2iadapter(t2iadapter_image)\ndown_block_additional_residuals = [\n    sample.to(dtype=weight_dtype) for sample in down_block_additional_residuals\n]\n\nmodel_pred = unet(\n    inp_noisy_latents,\n    timesteps,\n    encoder_hidden_states=batch[\"prompt_ids\"],\n    added_cond_kwargs=batch[\"unet_added_conditions\"],\n    down_block_additional_residuals=down_block_additional_residuals,\n).sample\n```\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nNow you’re ready to launch the training script! 🚀\n\nFor this example training, you'll use the [fusing/fill50k](https://huggingface.co/datasets/fusing/fill50k) dataset. 
You can also create and use your own dataset if you want (see the [Create a dataset for training](https://moon-ci-docs.huggingface.co/docs/diffusers/pr_5512/en/training/create_dataset) guide).\n\nSet the environment variable `MODEL_DIR` to a model id on the Hub or a path to a local model and `OUTPUT_DIR` to where you want to save the model.\n\nDownload the following images to condition your training with:\n\n```bash\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\n> [!TIP]\n> To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You'll also need to add the `--validation_image`, `--validation_prompt`, and `--validation_steps` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.\n\n```bash\nexport MODEL_DIR=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_t2i_adapter_sdxl.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --mixed_precision=\"fp16\" \\\n --resolution=1024 \\\n --learning_rate=1e-5 \\\n --max_train_steps=15000 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --validation_steps=100 \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --report_to=\"wandb\" \\\n --seed=42 \\\n --push_to_hub\n```\n\nOnce training is complete, you can use your T2I-Adapter for inference:\n\n```py\nfrom diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler\nfrom 
diffusers.utils import load_image\nimport torch\n\nadapter = T2IAdapter.from_pretrained(\"path/to/adapter\", torch_dtype=torch.float16)\npipeline = StableDiffusionXLAdapterPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", adapter=adapter, torch_dtype=torch.float16\n)\n\npipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)\npipeline.enable_xformers_memory_efficient_attention()\npipeline.enable_model_cpu_offload()\n\ncontrol_image = load_image(\"./conditioning_image_1.png\")\nprompt = \"pale golden rod circle with old lace background\"\n\ngenerator = torch.manual_seed(0)\nimage = pipeline(\n    prompt, image=control_image, generator=generator\n).images[0]\nimage.save(\"./output.png\")\n```\n\n## Next steps\n\nCongratulations on training a T2I-Adapter model! 🎉 To learn more:\n\n- Read the [Efficient Controllable Generation for SDXL with T2I-Adapters](https://huggingface.co/blog/t2i-sdxl-adapters) blog post to learn more details about the experimental results from the T2I-Adapter team.\n"
  },
  {
    "path": "docs/source/en/training/text2image.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Text-to-image\n\n> [!WARNING]\n> The text-to-image script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.\n\nText-to-image models like Stable Diffusion are conditioned to generate images given a text prompt.\n\nTraining a model can be taxing on your hardware, but if you enable `gradient_checkpointing` and `mixed_precision`, it is possible to train a model on a single 24GB GPU. If you're training with larger batch sizes or want to train faster, it's better to use GPUs with more than 30GB of memory. 
You can reduce your memory footprint by enabling memory-efficient attention with [xFormers](../optimization/xformers).\n\nThis guide will explore the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/text_to_image\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n## Script parameters\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. 
If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) and let us know if you have any questions or concerns.\n\nThe training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L193) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.\n\nFor example, to speed up training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:\n\n```bash\naccelerate launch train_text_to_image.py \\\n  --mixed_precision=\"fp16\"\n```\n\nSome basic and important parameters include:\n\n- `--pretrained_model_name_or_path`: the name of the model on the Hub or a local path to the pretrained model\n- `--dataset_name`: the name of the dataset on the Hub or a local path to the dataset to train on\n- `--image_column`: the name of the image column in the dataset to train on\n- `--caption_column`: the name of the text column in the dataset to train on\n- `--output_dir`: where to save the trained model\n- `--push_to_hub`: whether to push the trained model to the Hub\n- `--checkpointing_steps`: frequency of saving a checkpoint as the model trains; this is useful if training is interrupted for some reason, because you can continue training from that checkpoint by adding `--resume_from_checkpoint` to your training command\n\n### Min-SNR weighting\n\nThe [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. 
The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.\n\nAdd the `--snr_gamma` parameter and set it to the recommended value of 5.0:\n\n```bash\naccelerate launch train_text_to_image.py \\\n  --snr_gamma=5.0\n```\n\nYou can compare the loss surfaces for different `snr_gamma` values in this [Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) report. For smaller datasets, the effects of Min-SNR may not be as obvious compared to larger datasets.\n\n## Training script\n\nThe dataset preprocessing code and training loop are found in the [`main()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L490) function. If you need to adapt the training script, this is where you'll need to make your changes.\n\nThe `train_text_to_image` script starts by [loading a scheduler](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L543) and tokenizer. You can choose to use a different scheduler here if you want:\n\n```py\nnoise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\ntokenizer = CLIPTokenizer.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n)\n```\n\nThen the script [loads the UNet](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L619) model:\n\n```py\nload_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\nmodel.register_to_config(**load_model.config)\n\nmodel.load_state_dict(load_model.state_dict())\n```\n\nNext, the text and image columns of the dataset need to be preprocessed. 
The [`tokenize_captions`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L724) function handles tokenizing the inputs, and the [`train_transforms`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L742) function specifies the type of transforms to apply to the image. Both of these functions are bundled into `preprocess_train`:\n\n```py\ndef preprocess_train(examples):\n    images = [image.convert(\"RGB\") for image in examples[image_column]]\n    examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n    examples[\"input_ids\"] = tokenize_captions(examples)\n    return examples\n```\n\nLastly, the [training loop](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L878) handles everything else. It encodes images into latent space, adds noise to the latents, computes the text embeddings to condition on, updates the model parameters, and saves and pushes the model to the Hub. If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nOnce you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀\n\nLet's train on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset to generate your own Naruto characters. Set the environment variables `MODEL_NAME` and `dataset_name` to the model and the dataset (either from the Hub or a local path). 
If you're training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.\n\n> [!TIP]\n> To train on a local dataset, set the `TRAIN_DIR` and `OUTPUT_DIR` environment variables to the path of the dataset and where to save the model to.\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport dataset_name=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$dataset_name \\\n  --use_ema \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --enable_xformers_memory_efficient_attention \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --output_dir=\"sd-naruto-model\" \\\n  --push_to_hub\n```\n\nOnce training is complete, you can use your newly trained model for inference:\n\n```py\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\npipeline = StableDiffusionPipeline.from_pretrained(\"path/to/saved_model\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n\nimage = pipeline(prompt=\"yoda\").images[0]\nimage.save(\"yoda-naruto.png\")\n```\n\n## Next steps\n\nCongratulations on training your own text-to-image model! To learn more about how to use your new model, the following guides may be helpful:\n\n- Learn how to [load LoRA weights](../tutorials/using_peft_for_inference) for inference if you trained your model with LoRA.\n- Learn more about how certain parameters like guidance scale or techniques such as prompt weighting can help you control inference in the [Text-to-image](../using-diffusers/conditional_image_generation) task guide.\n"
  },
  {
    "path": "docs/source/en/training/text_inversion.md",
    "content": " <!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Textual Inversion\n\n[Textual Inversion](https://hf.co/papers/2208.01618) is a training technique for personalizing image generation models with just a few example images of what you want it to learn. This technique works by learning and updating the text embeddings (the new embeddings are tied to a special word you must use in the prompt) to match the example images you provide.\n\nIf you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing` and `mixed_precision` parameters in the training command. 
You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers).\n\nThis guide will explore the [textual_inversion.py](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nNavigate to the example folder with the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/textual_inversion\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. 
If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py) and let us know if you have any questions or concerns.\n\n## Script parameters\n\nThe training script has many parameters to help you tailor the training run to your needs. All of the parameters and their descriptions are listed in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L176) function. Where applicable, Diffusers provides default values for each parameter such as the training batch size and learning rate, but feel free to change these values in the training command if you'd like.\n\nFor example, to increase the number of gradient accumulation steps above the default value of 1:\n\n```bash\naccelerate launch textual_inversion.py \\\n  --gradient_accumulation_steps=4\n```\n\nSome other basic and important parameters to specify include:\n\n- `--pretrained_model_name_or_path`: the name of the model on the Hub or a local path to the pretrained model\n- `--train_data_dir`: path to a folder containing the training dataset (example images)\n- `--output_dir`: where to save the trained model\n- `--push_to_hub`: whether to push the trained model to the Hub\n- `--checkpointing_steps`: frequency of saving a checkpoint as the model trains; this is useful if training is interrupted for some reason, because you can continue training from that checkpoint by adding `--resume_from_checkpoint` to your training command\n- `--num_vectors`: the number of vectors to learn the embeddings with; increasing this parameter helps the model learn better but it comes with increased training costs\n- `--placeholder_token`: the special word to tie the learned embeddings to (you must use the word in your prompt for inference)\n- `--initializer_token`: a single word that roughly describes the object or style you're 
trying to train on\n- `--learnable_property`: whether you're training the model to learn a new \"style\" (for example, Van Gogh's painting style) or \"object\" (for example, your dog)\n\n## Training script\n\nUnlike some of the other training scripts, textual_inversion.py has a custom dataset class, [`TextualInversionDataset`](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L487) for creating a dataset. You can customize the image size, placeholder token, interpolation method, whether to crop the image, and more. If you need to change how the dataset is created, you can modify `TextualInversionDataset`.\n\nNext, you'll find the dataset preprocessing code and training loop in the [`main()`](https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L573) function.\n\nThe script starts by loading the [tokenizer](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L616), [scheduler and model](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L622):\n\n```py\n# Load tokenizer\nif args.tokenizer_name:\n    tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\nelif args.pretrained_model_name_or_path:\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n# Load scheduler and models\nnoise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\ntext_encoder = CLIPTextModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n)\nvae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\nunet = 
UNet2DConditionModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n)\n```\n\nThe special [placeholder token](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L632) is added next to the tokenizer, and the embedding is readjusted to account for the new token.\n\nThen, the script [creates a dataset](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L716) from the `TextualInversionDataset`:\n\n```py\ntrain_dataset = TextualInversionDataset(\n    data_root=args.train_data_dir,\n    tokenizer=tokenizer,\n    size=args.resolution,\n    placeholder_token=(\" \".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),\n    repeats=args.repeats,\n    learnable_property=args.learnable_property,\n    center_crop=args.center_crop,\n    set=\"train\",\n)\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n)\n```\n\nFinally, the [training loop](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L784) handles everything else from predicting the noisy residual to updating the embedding weights of the special placeholder token.\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nOnce you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀\n\nFor this guide, you'll download some images of a [cat toy](https://huggingface.co/datasets/diffusers/cat_toy_example) and store them in a directory. 
But remember, you can create and use your own dataset if you want (see the [Create a dataset for training](create_dataset) guide).\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./cat\"\nsnapshot_download(\n    \"diffusers/cat_toy_example\", local_dir=local_dir, repo_type=\"dataset\", ignore_patterns=\".gitattributes\"\n)\n```\n\nSet the environment variable `MODEL_NAME` to a model id on the Hub or a path to a local model, and `DATA_DIR`  to the path where you just downloaded the cat images to. The script creates and saves the following files to your repository:\n\n- `learned_embeds.bin`: the learned embedding vectors corresponding to your example images\n- `token_identifier.txt`: the special placeholder token\n- `type_of_concept.txt`: the type of concept you're training on (either \"object\" or \"style\")\n\n> [!WARNING]\n> A full training run takes ~1 hour on a single V100 GPU.\n\nOne more thing before you launch the script. If you're interested in following along with the training process, you can periodically save generated images as training progresses. 
Add the following parameters to the training command:\n\n```bash\n--validation_prompt=\"A <cat-toy> train\"\n--num_validation_images=4\n--validation_steps=100\n```\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport DATA_DIR=\"./cat\"\n\naccelerate launch textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" \\\n  --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 \\\n  --scale_lr \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=\"textual_inversion_cat\" \\\n  --push_to_hub\n```\n\nAfter training is complete, you can use your newly trained model for inference like:\n\n```py\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\npipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.load_textual_inversion(\"sd-concepts-library/cat-toy\")\nimage = pipeline(\"A <cat-toy> train\", num_inference_steps=50).images[0]\nimage.save(\"cat-train.png\")\n```\n\n## Next steps\n\nCongratulations on training your own Textual Inversion model! 🎉 To learn more about how to use your new model, the following guides may be helpful:\n\n- Learn how to [load Textual Inversion embeddings](../using-diffusers/textual_inversion_inference) and also use them as negative embeddings."
  },
  {
    "path": "docs/source/en/training/unconditional_training.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Unconditional image generation\n\nUnconditional image generation models are not conditioned on text or images during training. It only generates images that resemble its training data distribution.\n\nThis guide will explore the [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies:\n\n```bash\ncd examples/unconditional_image_generation\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. 
Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n## Script parameters\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) and let us know if you have any questions or concerns.\n\nThe training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L55) function. 
It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.\n\nFor example, to speed up training with mixed precision using the bf16 format, add the `--mixed_precision` parameter to the training command:\n\n```bash\naccelerate launch train_unconditional.py \\\n  --mixed_precision=\"bf16\"\n```\n\nSome basic and important parameters to specify include:\n\n- `--dataset_name`: the name of the dataset on the Hub or a local path to the dataset to train on\n- `--output_dir`: where to save the trained model\n- `--push_to_hub`: whether to push the trained model to the Hub\n- `--checkpointing_steps`: frequency of saving a checkpoint as the model trains; this is useful if training is interrupted, because you can continue training from that checkpoint by adding `--resume_from_checkpoint` to your training command\n\nBring your dataset, and let the training script handle everything else!\n\n## Training script\n\nThe code for preprocessing the dataset and the training loop is found in the [`main()`](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L275) function. If you need to adapt the training script, this is where you'll need to make your changes.\n\nThe `train_unconditional` script [initializes a `UNet2DModel`](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L356) if you don't provide a model configuration. 
You can configure the UNet here if you'd like:\n\n```py\nmodel = UNet2DModel(\n    sample_size=args.resolution,\n    in_channels=3,\n    out_channels=3,\n    layers_per_block=2,\n    block_out_channels=(128, 128, 256, 256, 512, 512),\n    down_block_types=(\n        \"DownBlock2D\",\n        \"DownBlock2D\",\n        \"DownBlock2D\",\n        \"DownBlock2D\",\n        \"AttnDownBlock2D\",\n        \"DownBlock2D\",\n    ),\n    up_block_types=(\n        \"UpBlock2D\",\n        \"AttnUpBlock2D\",\n        \"UpBlock2D\",\n        \"UpBlock2D\",\n        \"UpBlock2D\",\n        \"UpBlock2D\",\n    ),\n)\n```\n\nNext, the script initializes a [scheduler](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L418) and [optimizer](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L429):\n\n```py\n# Initialize the scheduler\naccepts_prediction_type = \"prediction_type\" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())\nif accepts_prediction_type:\n    noise_scheduler = DDPMScheduler(\n        num_train_timesteps=args.ddpm_num_steps,\n        beta_schedule=args.ddpm_beta_schedule,\n        prediction_type=args.prediction_type,\n    )\nelse:\n    noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)\n\n# Initialize the optimizer\noptimizer = torch.optim.AdamW(\n    model.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nThen it [loads a dataset](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L451) and you can specify how to 
[preprocess](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L455) it:\n\n```py\ndataset = load_dataset(\"imagefolder\", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split=\"train\")\n\naugmentations = transforms.Compose(\n    [\n        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n        transforms.ToTensor(),\n        transforms.Normalize([0.5], [0.5]),\n    ]\n)\n```\n\nFinally, the [training loop](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L540) handles everything else such as adding noise to the images, predicting the noise residual, calculating the loss, saving checkpoints at specified steps, and saving and pushing the model to the Hub. If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nOnce you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 
🚀\n\n> [!WARNING]\n> A full training run takes 2 hours on 4xV100 GPUs.\n\n<hfoptions id=\"launchtraining\">\n<hfoption id=\"single GPU\">\n\n```bash\naccelerate launch train_unconditional.py \\\n  --dataset_name=\"huggan/flowers-102-categories\" \\\n  --output_dir=\"ddpm-ema-flowers-64\" \\\n  --mixed_precision=\"fp16\" \\\n  --push_to_hub\n```\n\n</hfoption>\n<hfoption id=\"multi-GPU\">\n\nIf you're training with more than one GPU, add the `--multi_gpu` parameter to the training command:\n\n```bash\naccelerate launch --multi_gpu train_unconditional.py \\\n  --dataset_name=\"huggan/flowers-102-categories\" \\\n  --output_dir=\"ddpm-ema-flowers-64\" \\\n  --mixed_precision=\"fp16\" \\\n  --push_to_hub\n```\n\n</hfoption>\n</hfoptions>\n\nThe training script creates and saves a checkpoint file in your repository. Now you can load and use your trained model for inference:\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"anton-l/ddpm-butterflies-128\").to(\"cuda\")\nimage = pipeline().images[0]\n```\n"
  },
  {
    "path": "docs/source/en/training/wuerstchen.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Wuerstchen\n\nThe [Wuerstchen](https://hf.co/papers/2306.00637) model drastically reduces computational costs by compressing the latent space by 42x, without compromising image quality and accelerating inference. During training, Wuerstchen uses two models (VQGAN + autoencoder) to compress the latents, and then a third model (text-conditioned latent diffusion model) is conditioned on this highly compressed space to generate an image.\n\nTo fit the prior model into GPU memory and to speedup training, try enabling `gradient_accumulation_steps`, `gradient_checkpointing`, and `mixed_precision` respectively.\n\nThis guide explores the [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.\n\nBefore running the script, make sure you install the library from source:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen navigate to the example folder containing the training script and install the required dependencies for the script you're using:\n\n```bash\ncd examples/wuerstchen/text_to_image\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with 
mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.\n\nInitialize an 🤗 Accelerate environment:\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗 Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell, like a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nLastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.\n\n> [!TIP]\n> The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the [script](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) in detail. If you're interested in learning more, feel free to read through the script and let us know if you have any questions or concerns.\n\n## Script parameters\n\nThe training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L192) function. 
It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.\n\nFor example, to speedup training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:\n\n```bash\naccelerate launch train_text_to_image_prior.py \\\n  --mixed_precision=\"fp16\"\n```\n\nMost of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so let's dive right into the Wuerstchen training script!\n\n## Training script\n\nThe training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support Wuerstchen. This guide focuses on the code that is unique to the Wuerstchen training script.\n\nThe [`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L441) function starts by initializing the image encoder - an [EfficientNet](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py) - in addition to the usual scheduler and tokenizer.\n\n```py\nwith ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n    pretrained_checkpoint_file = hf_hub_download(\"dome272/wuerstchen\", filename=\"model_v2_stage_b.pt\")\n    state_dict = torch.load(pretrained_checkpoint_file, map_location=\"cpu\")\n    image_encoder = EfficientNetEncoder()\n    image_encoder.load_state_dict(state_dict[\"effnet_state_dict\"])\n    image_encoder.eval()\n```\n\nYou'll also load the [`WuerstchenPrior`] model for optimization.\n\n```py\nprior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\n\noptimizer = optimizer_cls(\n    prior.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    
weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\nNext, you'll apply some [transforms](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656) to the images and [tokenize](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L637) the captions:\n\n```py\ndef preprocess_train(examples):\n    images = [image.convert(\"RGB\") for image in examples[image_column]]\n    examples[\"effnet_pixel_values\"] = [effnet_transforms(image) for image in images]\n    examples[\"text_input_ids\"], examples[\"text_mask\"] = tokenize_captions(examples)\n    return examples\n```\n\nFinally, the [training loop](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656) handles compressing the images to latent space with the `EfficientNetEncoder`, adding noise to the latents, and predicting the noise residual with the [`WuerstchenPrior`] model.\n\n```py\npred_noise = prior(noisy_latents, timesteps, prompt_embeds)\n```\n\nIf you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers](../using-diffusers/write_own_pipeline) tutorial which breaks down the basic pattern of the denoising process.\n\n## Launch the script\n\nOnce you’ve made all your changes or you’re okay with the default configuration, you’re ready to launch the training script! 🚀\n\nSet the `DATASET_NAME` environment variable to the dataset name from the Hub. 
This guide uses the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset, but you can create and train on your own datasets as well (see the [Create a dataset for training](create_dataset) guide).\n\n> [!TIP]\n> To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompts` parameter to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch train_text_to_image_prior.py \\\n  --mixed_precision=\"fp16\" \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=4 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --dataloader_num_workers=4 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"wuerstchen-prior-naruto-model\"\n```\n\nOnce training is complete, you can use your newly trained model for inference!\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\nfrom diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"path/to/saved/model\", torch_dtype=torch.float16).to(\"cuda\")\n\ncaption = \"A cute bird naruto holding a shield\"\nimages = pipeline(\n    caption,\n    width=1024,\n    height=1536,\n    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,\n    prior_guidance_scale=4.0,\n    num_images_per_prompt=2,\n).images\n```\n\n## Next steps\n\nCongratulations on training a Wuerstchen model! 
To learn more about how to use your new model, the following may be helpful:\n\n- Take a look at the [Wuerstchen](../api/pipelines/wuerstchen#text-to-image-generation) API documentation to learn more about how to use the pipeline for text-to-image generation and its limitations.\n"
  },
  {
    "path": "docs/source/en/tutorials/autopipeline.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoPipeline\n\n[AutoPipeline](../api/models/auto_model) is a *task-and-model* pipeline that automatically selects the correct pipeline subclass based on the task. It handles the complexity of loading different pipeline subclasses without needing to know the specific pipeline subclass name.\n\nThis is unlike [`DiffusionPipeline`], a *model-only* pipeline that automatically selects the pipeline subclass based on the model.\n\n[`AutoPipelineForImage2Image`] returns a specific pipeline subclass, (for example, [`StableDiffusionXLImg2ImgPipeline`]), which can only be used for image-to-image tasks.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n  \"RunDiffusion/Juggernaut-XL-v9\", torch_dtype=torch.bfloat16, device_map=\"cuda\",\n)\nprint(pipeline)\n\"StableDiffusionXLImg2ImgPipeline {\n  \"_class_name\": \"StableDiffusionXLImg2ImgPipeline\",\n  ...\n\"\n```\n\nLoading the same model with [`DiffusionPipeline`] returns the [`StableDiffusionXLPipeline`] subclass. 
It can be used for text-to-image, image-to-image, or inpainting tasks depending on the inputs.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"RunDiffusion/Juggernaut-XL-v9\", torch_dtype=torch.bfloat16, device_map=\"cuda\",\n)\nprint(pipeline)\n\"StableDiffusionXLPipeline {\n  \"_class_name\": \"StableDiffusionXLPipeline\",\n  ...\n\"\n```\n\nCheck the [mappings](https://github.com/huggingface/diffusers/blob/130fd8df54f24ffb006d84787b598d8adc899f23/src/diffusers/pipelines/auto_pipeline.py#L114) to see whether a model is supported or not.\n\nTrying to load an unsupported model raises an error.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"openai/shap-e-img2img\", torch_dtype=torch.float16,\n)\n\"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None\"\n```\n\nThere are three types of [AutoPipeline](../api/models/auto_model) classes: [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]. Each of these classes has a predefined mapping that links a pipeline to its task-specific subclass.\n\nWhen [`~AutoPipelineForText2Image.from_pretrained`] is called, it extracts the class name from the `model_index.json` file and selects the appropriate pipeline subclass for the task based on the mapping."
  },
  {
    "path": "docs/source/en/tutorials/basic_training.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Train a diffusion model\n\nUnconditional image generation is a popular application of diffusion models that generates images that look like those in the dataset used for training. Typically, the best results are obtained from finetuning a pretrained model on a specific dataset. You can find many of these checkpoints on the [Hub](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model), but if you can't find one you like, you can always train your own!\n\nThis tutorial will teach you how to train a [`UNet2DModel`] from scratch on a subset of the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset to generate your own 🦋 butterflies 🦋.\n\n> [!TIP]\n> 💡 This training tutorial is based on the [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For additional details and context about diffusion models like how they work, check out the notebook!\n\nBefore you begin, make sure you have 🤗 Datasets installed to load and preprocess image datasets, and 🤗 Accelerate, to simplify training on any number of GPUs. 
The following command will also install [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize training metrics (you can also use [Weights & Biases](https://docs.wandb.ai/) to track your training).\n\n```py\n# uncomment to install the necessary libraries in Colab\n#!pip install diffusers[training]\n```\n\nWe encourage you to share your model with the community, and in order to do that, you'll need to log in to your Hugging Face account (create one [here](https://hf.co/join) if you don't already have one!). You can log in from a notebook and enter your token when prompted. Make sure your token has the write role.\n\n```py\n>>> from huggingface_hub import notebook_login\n\n>>> notebook_login()\n```\n\nOr log in from the terminal:\n\n```bash\nhf auth login\n```\n\nSince the model checkpoints are quite large, install [Git-LFS](https://git-lfs.com/) to version these large files:\n\n```bash\n!sudo apt -qq install git-lfs\n!git config --global credential.helper store\n```\n\n## Training configuration\n\nFor convenience, create a `TrainingConfig` class containing the training hyperparameters (feel free to adjust them):\n\n```py\n>>> from dataclasses import dataclass\n\n>>> @dataclass\n... class TrainingConfig:\n...     image_size = 128  # the generated image resolution\n...     train_batch_size = 16\n...     eval_batch_size = 16  # how many images to sample during evaluation\n...     num_epochs = 50\n...     gradient_accumulation_steps = 1\n...     learning_rate = 1e-4\n...     lr_warmup_steps = 500\n...     save_image_epochs = 10\n...     save_model_epochs = 30\n...     mixed_precision = \"fp16\"  # `no` for float32, `fp16` for automatic mixed precision\n...     output_dir = \"ddpm-butterflies-128\"  # the model name locally and on the HF Hub\n\n...     push_to_hub = True  # whether to upload the saved model to the HF Hub\n...     hub_model_id = \"<your-username>/<my-awesome-model>\"  # the name of the repository to create on the HF Hub\n...     
hub_private_repo = None\n...     overwrite_output_dir = True  # overwrite the old model when re-running the notebook\n...     seed = 0\n\n\n>>> config = TrainingConfig()\n```\n\n## Load the dataset\n\nYou can easily load the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset with the 🤗 Datasets library:\n\n```py\n>>> from datasets import load_dataset\n\n>>> config.dataset_name = \"huggan/smithsonian_butterflies_subset\"\n>>> dataset = load_dataset(config.dataset_name, split=\"train\")\n```\n\n> [!TIP]\n> 💡 You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan) or you can use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). Set `config.dataset_name` to the repository id of the dataset if it is from the HugGan Community Event, or `imagefolder` if you're using your own images.\n\n🤗 Datasets uses the [`~datasets.Image`] feature to automatically decode the image data and load it as a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html) which we can visualize:\n\n```py\n>>> import matplotlib.pyplot as plt\n\n>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n>>> for i, image in enumerate(dataset[:4][\"image\"]):\n...     axs[i].imshow(image)\n...     
axs[i].set_axis_off()\n>>> fig.show()\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_ds.png\"/>\n</div>\n\nThe images are all different sizes though, so you'll need to preprocess them first:\n\n* `Resize` changes the image size to the one defined in `config.image_size`.\n* `RandomHorizontalFlip` augments the dataset by randomly mirroring the images.\n* `Normalize` is important to rescale the pixel values into a [-1, 1] range, which is what the model expects.\n\n```py\n>>> from torchvision import transforms\n\n>>> preprocess = transforms.Compose(\n...     [\n...         transforms.Resize((config.image_size, config.image_size)),\n...         transforms.RandomHorizontalFlip(),\n...         transforms.ToTensor(),\n...         transforms.Normalize([0.5], [0.5]),\n...     ]\n... )\n```\n\nUse 🤗 Datasets' [`~datasets.Dataset.set_transform`] method to apply the `preprocess` function on the fly during training:\n\n```py\n>>> def transform(examples):\n...     images = [preprocess(image.convert(\"RGB\")) for image in examples[\"image\"]]\n...     return {\"images\": images}\n\n\n>>> dataset.set_transform(transform)\n```\n\nFeel free to visualize the images again to confirm that they've been resized. Now you're ready to wrap the dataset in a [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader) for training!\n\n```py\n>>> import torch\n\n>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n```\n\n## Create a UNet2DModel\n\nPretrained models in 🧨 Diffusers are easily created from their model class with the parameters you want. For example, to create a [`UNet2DModel`]:\n\n```py\n>>> from diffusers import UNet2DModel\n\n>>> model = UNet2DModel(\n...     sample_size=config.image_size,  # the target image resolution\n...     
in_channels=3,  # the number of input channels, 3 for RGB images\n...     out_channels=3,  # the number of output channels\n...     layers_per_block=2,  # how many ResNet layers to use per UNet block\n...     block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block\n...     down_block_types=(\n...         \"DownBlock2D\",  # a regular ResNet downsampling block\n...         \"DownBlock2D\",\n...         \"DownBlock2D\",\n...         \"DownBlock2D\",\n...         \"AttnDownBlock2D\",  # a ResNet downsampling block with spatial self-attention\n...         \"DownBlock2D\",\n...     ),\n...     up_block_types=(\n...         \"UpBlock2D\",  # a regular ResNet upsampling block\n...         \"AttnUpBlock2D\",  # a ResNet upsampling block with spatial self-attention\n...         \"UpBlock2D\",\n...         \"UpBlock2D\",\n...         \"UpBlock2D\",\n...         \"UpBlock2D\",\n...     ),\n... )\n```\n\nIt is often a good idea to quickly check that the sample image shape matches the model output shape:\n\n```py\n>>> sample_image = dataset[0][\"images\"].unsqueeze(0)\n>>> print(\"Input shape:\", sample_image.shape)\nInput shape: torch.Size([1, 3, 128, 128])\n\n>>> print(\"Output shape:\", model(sample_image, timestep=0).sample.shape)\nOutput shape: torch.Size([1, 3, 128, 128])\n```\n\nGreat! Next, you'll need a scheduler to add some noise to the image.\n\n## Create a scheduler\n\nThe scheduler behaves differently depending on whether you're using the model for training or inference. During inference, the scheduler generates an image from the noise. 
During training, the scheduler takes a model output - or a sample - from a specific point in the diffusion process and applies noise to the image according to a *noise schedule* and an *update rule*.\n\nLet's take a look at the [`DDPMScheduler`] and use the `add_noise` method to add some random noise to the `sample_image` from before:\n\n```py\n>>> import torch\n>>> from PIL import Image\n>>> from diffusers import DDPMScheduler\n\n>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)\n>>> noise = torch.randn(sample_image.shape)\n>>> timesteps = torch.LongTensor([50])\n>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)\n\n>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/noisy_butterfly.png\"/>\n</div>\n\nThe training objective of the model is to predict the noise added to the image. The loss at this step can be calculated by:\n\n```py\n>>> import torch.nn.functional as F\n\n>>> noise_pred = model(noisy_image, timesteps).sample\n>>> loss = F.mse_loss(noise_pred, noise)\n```\n\n## Train the model\n\nBy now, you have most of the pieces to start training the model and all that's left is putting everything together.\n\nFirst, you'll need an optimizer and a learning rate scheduler:\n\n```py\n>>> from diffusers.optimization import get_cosine_schedule_with_warmup\n\n>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)\n>>> lr_scheduler = get_cosine_schedule_with_warmup(\n...     optimizer=optimizer,\n...     num_warmup_steps=config.lr_warmup_steps,\n...     num_training_steps=(len(train_dataloader) * config.num_epochs),\n... )\n```\n\nThen, you'll need a way to evaluate the model. 
For evaluation, you can use the [`DDPMPipeline`] to generate a batch of sample images and save it as a grid:\n\n```py\n>>> from diffusers import DDPMPipeline\n>>> from diffusers.utils import make_image_grid\n>>> import os\n\n>>> def evaluate(config, epoch, pipeline):\n...     # Sample some images from random noise (this is the backward diffusion process).\n...     # The default pipeline output type is `List[PIL.Image]`\n...     images = pipeline(\n...         batch_size=config.eval_batch_size,\n...         generator=torch.Generator(device='cpu').manual_seed(config.seed), # Use a separate torch generator to avoid rewinding the random state of the main training loop\n...     ).images\n\n...     # Make a grid out of the images\n...     image_grid = make_image_grid(images, rows=4, cols=4)\n\n...     # Save the images\n...     test_dir = os.path.join(config.output_dir, \"samples\")\n...     os.makedirs(test_dir, exist_ok=True)\n...     image_grid.save(f\"{test_dir}/{epoch:04d}.png\")\n```\n\nNow you can wrap all these components together in a training loop with 🤗 Accelerate for easy TensorBoard logging, gradient accumulation, and mixed precision training. To upload the model to the Hub, write a function to get your repository name and information and then push it to the Hub.\n\n> [!TIP]\n> 💡 The training loop below may look intimidating and long, but it'll be worth it later when you launch your training in just one line of code! If you can't wait and want to start generating images, feel free to copy and run the code below. You can always come back and examine the training loop more closely later, like when you're waiting for your model to finish training. 🤗\n\n```py\n>>> from accelerate import Accelerator\n>>> from huggingface_hub import create_repo, upload_folder\n>>> from tqdm.auto import tqdm\n>>> from pathlib import Path\n>>> import os\n\n>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):\n...     
# Initialize accelerator and tensorboard logging\n...     accelerator = Accelerator(\n...         mixed_precision=config.mixed_precision,\n...         gradient_accumulation_steps=config.gradient_accumulation_steps,\n...         log_with=\"tensorboard\",\n...         project_dir=os.path.join(config.output_dir, \"logs\"),\n...     )\n...     if accelerator.is_main_process:\n...         if config.output_dir is not None:\n...             os.makedirs(config.output_dir, exist_ok=True)\n...         if config.push_to_hub:\n...             repo_id = create_repo(\n...                 repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n...             ).repo_id\n...         accelerator.init_trackers(\"train_example\")\n\n...     # Prepare everything\n...     # There is no specific order to remember, you just need to unpack the\n...     # objects in the same order you gave them to the prepare method.\n...     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n...         model, optimizer, train_dataloader, lr_scheduler\n...     )\n\n...     global_step = 0\n\n...     # Now you train the model\n...     for epoch in range(config.num_epochs):\n...         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)\n...         progress_bar.set_description(f\"Epoch {epoch}\")\n\n...         for step, batch in enumerate(train_dataloader):\n...             clean_images = batch[\"images\"]\n...             # Sample noise to add to the images\n...             noise = torch.randn(clean_images.shape, device=clean_images.device)\n...             bs = clean_images.shape[0]\n\n...             # Sample a random timestep for each image\n...             timesteps = torch.randint(\n...                 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,\n...                 dtype=torch.int64\n...             )\n\n...             
# Add noise to the clean images according to the noise magnitude at each timestep\n...             # (this is the forward diffusion process)\n...             noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n\n...             with accelerator.accumulate(model):\n...                 # Predict the noise residual\n...                 noise_pred = model(noisy_images, timesteps, return_dict=False)[0]\n...                 loss = F.mse_loss(noise_pred, noise)\n...                 accelerator.backward(loss)\n\n...                 if accelerator.sync_gradients:\n...                     accelerator.clip_grad_norm_(model.parameters(), 1.0)\n...                 optimizer.step()\n...                 lr_scheduler.step()\n...                 optimizer.zero_grad()\n\n...             progress_bar.update(1)\n...             logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n...             progress_bar.set_postfix(**logs)\n...             accelerator.log(logs, step=global_step)\n...             global_step += 1\n\n...         # After each epoch you optionally sample some demo images with evaluate() and save the model\n...         if accelerator.is_main_process:\n...             pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)\n\n...             if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:\n...                 evaluate(config, epoch, pipeline)\n\n...             if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:\n...                 if config.push_to_hub:\n...                     upload_folder(\n...                         repo_id=repo_id,\n...                         folder_path=config.output_dir,\n...                         commit_message=f\"Epoch {epoch}\",\n...                         ignore_patterns=[\"step_*\", \"epoch_*\"],\n...                     )\n...                 else:\n...               
      pipeline.save_pretrained(config.output_dir)\n```\n\nPhew, that was quite a bit of code! But you're finally ready to launch the training with 🤗 Accelerate's [`~accelerate.notebook_launcher`] function. Pass the function the training loop, all the training arguments, and the number of processes (you can change this value to the number of GPUs available to you) to use for training:\n\n```py\n>>> from accelerate import notebook_launcher\n\n>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n\n>>> notebook_launcher(train_loop, args, num_processes=1)\n```\n\nOnce training is complete, take a look at the final 🦋 images 🦋 generated by your diffusion model!\n\n```py\n>>> import glob\n\n>>> sample_images = sorted(glob.glob(f\"{config.output_dir}/samples/*.png\"))\n>>> Image.open(sample_images[-1])\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_final.png\"/>\n</div>\n\n## Next steps\n\nUnconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques by visiting the [🧨 Diffusers Training Examples](../training/overview) page. Here are some examples of what you can learn:\n\n* [Textual Inversion](../training/text_inversion), an algorithm that teaches a model a specific visual concept and integrates it into the generated image.\n* [DreamBooth](../training/dreambooth), a technique for generating personalized images of a subject given several input images of the subject.\n* [Guide](../training/text2image) to finetuning a Stable Diffusion model on your own dataset.\n* [Guide](../training/lora) to using LoRA, a memory-efficient technique for finetuning really large models faster.\n"
  },
  {
    "path": "docs/source/en/tutorials/using_peft_for_inference.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LoRA\n\n[LoRA (Low-Rank Adaptation)](https://huggingface.co/papers/2106.09685) is a method for quickly training a model for a new task. It works by freezing the original model weights and adding a small number of *new* trainable parameters. This means it is significantly faster and cheaper to adapt an existing model to new tasks, such as generating images in a new style.\n\nLoRA checkpoints are typically only a couple hundred MBs in size, so they're very lightweight and easy to store. 
Load this smaller set of weights into an existing base model with [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and specify the file name.\n\n<hfoptions id=\"usage\">\n<hfoption id=\"text-to-image\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/super-cereal-sdxl-lora\",\n    weight_name=\"cereal_box_sdxl_v1.safetensors\",\n    adapter_name=\"cereal\"\n)\npipeline(\"bears, pizza bites\").images[0]\n```\n\n</hfoption>\n<hfoption id=\"text-to-video\">\n\n```py\nimport torch\nfrom diffusers import LTXConditionPipeline\nfrom diffusers.utils import export_to_video, load_image\n\npipeline = LTXConditionPipeline.from_pretrained(\n    \"Lightricks/LTX-Video-0.9.5\", torch_dtype=torch.bfloat16\n)\n\npipeline.load_lora_weights(\n    \"Lightricks/LTX-Video-Cakeify-LoRA\",\n    weight_name=\"ltxv_095_cakeify_lora.safetensors\",\n    adapter_name=\"cakeify\"\n)\npipeline.set_adapters(\"cakeify\")\n\n# use \"CAKEIFY\" to trigger the LoRA\nprompt = \"CAKEIFY a person using a knife to cut a cake shaped like a Pikachu plushie\"\nimage = load_image(\"https://huggingface.co/Lightricks/LTX-Video-Cakeify-LoRA/resolve/main/assets/images/pikachu.png\")\n\nvideo = pipeline(\n    prompt=prompt,\n    image=image,\n    width=576,\n    height=576,\n    num_frames=161,\n    decode_timestep=0.03,\n    decode_noise_scale=0.025,\n    num_inference_steps=50,\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=26)\n```\n\n</hfoption>\n</hfoptions>\n\nThe [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method is the preferred way to load LoRA weights into the UNet and text encoder because it can handle cases where:\n\n- the LoRA weights don't have separate UNet and text encoder identifiers\n- the LoRA weights have separate UNet and text encoder 
identifiers\n\nThe [`~loaders.PeftAdapterMixin.load_lora_adapter`] method is used to directly load a LoRA adapter at the *model-level*, as long as the model is a Diffusers model that is a subclass of [`PeftAdapterMixin`]. It builds and prepares the necessary model configuration for the adapter. This method also loads the LoRA adapter into the UNet.\n\nFor example, if you're only loading a LoRA into the UNet, [`~loaders.PeftAdapterMixin.load_lora_adapter`] ignores the text encoder keys. Use the `prefix` parameter to filter and load the appropriate state dict, which is `\"unet\"` in this case.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.unet.load_lora_adapter(\n    \"jbilcke-hf/sdxl-cinematic-1\",\n    weight_name=\"pytorch_lora_weights.safetensors\",\n    adapter_name=\"cinematic\",\n    prefix=\"unet\"\n)\n# use cnmt in the prompt to trigger the LoRA\npipeline(\"A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration\").images[0]\n```\n\n## torch.compile\n\n[torch.compile](../optimization/fp16#torchcompile) speeds up inference by compiling the PyTorch model to use optimized kernels. 
Before compiling, the LoRA weights need to be fused into the base model and unloaded first.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\n# load base model and LoRA\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/ikea-instructions-lora-sdxl\",\n    weight_name=\"ikea_instructions_xl_v1_5.safetensors\",\n    adapter_name=\"ikea\"\n)\n\n# activate LoRA and set adapter weight\npipeline.set_adapters(\"ikea\", adapter_weights=0.7)\n\n# fuse LoRAs and unload weights\npipeline.fuse_lora(adapter_names=[\"ikea\"], lora_scale=1.0)\npipeline.unload_lora_weights()\n```\n\nTypically, the UNet is compiled because it's the most compute-intensive component of the pipeline.\n\n```py\npipeline.unet.to(memory_format=torch.channels_last)\npipeline.unet = torch.compile(pipeline.unet, mode=\"reduce-overhead\", fullgraph=True)\n\npipeline(\"A bowl of ramen shaped like a cute kawaii bear\").images[0]\n```\n\nRefer to the [hotswapping](#hotswapping) section to learn how to avoid recompilation when working with compiled models and multiple LoRAs.\n\n## Weight scale\n\nThe `scale` parameter is used to control how much of a LoRA to apply. 
A value of `0` is equivalent to only using the base model weights and a value of `1` is equivalent to fully using the LoRA.\n\n<hfoptions id=\"weight-scale\">\n<hfoption id=\"simple use case\">\n\nFor simple use cases, you can pass `cross_attention_kwargs={\"scale\": 1.0}` to the pipeline.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/super-cereal-sdxl-lora\",\n    weight_name=\"cereal_box_sdxl_v1.safetensors\",\n    adapter_name=\"cereal\"\n)\npipeline(\"bears, pizza bites\", cross_attention_kwargs={\"scale\": 1.0}).images[0]\n```\n\n</hfoption>\n<hfoption id=\"finer control\">\n\n> [!WARNING]\n> The [`~loaders.PeftAdapterMixin.set_adapters`] method only scales attention weights. If a LoRA has ResNets or down and upsamplers, these components keep a scale value of `1.0`.\n\nFor finer control over each individual component of the UNet or text encoder, pass a dictionary instead. In the example below, the `\"down\"` block in the UNet is scaled by 0.9 and you can further specify in the `\"up\"` block the scales of the transformers in `\"block_0\"` and `\"block_1\"`. 
If a block like `\"mid\"` isn't specified, the default value `1.0` is used.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/super-cereal-sdxl-lora\",\n    weight_name=\"cereal_box_sdxl_v1.safetensors\",\n    adapter_name=\"cereal\"\n)\nscales = {\n    \"text_encoder\": 0.5,\n    \"text_encoder_2\": 0.5,\n    \"unet\": {\n        \"down\": 0.9,\n        \"up\": {\n            \"block_0\": 0.6,\n            \"block_1\": [0.4, 0.8, 1.0],\n        }\n    }\n}\npipeline.set_adapters(\"cereal\", scales)\npipeline(\"bears, pizza bites\").images[0]\n```\n\n</hfoption>\n</hfoptions>\n\n### Scale scheduling\n\nDynamically adjusting the LoRA scale during sampling gives you better control over the overall composition and layout because certain steps may benefit more from an increased or reduced scale.\n\nThe [character LoRA](https://huggingface.co/alvarobartt/ghibli-characters-flux-lora) in the example below starts with a higher scale that gradually decays over the first 20 steps to establish the character generation. 
In the later steps, only a scale of 0.2 is applied to avoid adding too much of the LoRA features to other parts of the image the LoRA wasn't trained on.\n\n```py\nimport torch\nfrom diffusers import FluxPipeline\n\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\npipeline.load_lora_weights(\"alvarobartt/ghibli-characters-flux-lora\", adapter_name=\"lora\")\n\nnum_inference_steps = 30\nlora_steps = 20\nlora_scales = torch.linspace(1.5, 0.7, lora_steps).tolist()\nlora_scales += [0.2] * (num_inference_steps - lora_steps + 1)\n\npipeline.set_adapters(\"lora\", lora_scales[0])\n\ndef callback(pipeline: FluxPipeline, step: int, timestep: torch.LongTensor, callback_kwargs: dict):\n    pipeline.set_adapters(\"lora\", lora_scales[step + 1])\n    return callback_kwargs\n\nprompt = \"\"\"\nGhibli style The Grinch, a mischievous green creature with a sly grin, peeking out from behind a snow-covered tree while plotting his antics, \nin a quaint snowy village decorated for the holidays, warm light glowing from cozy homes, with playful snowflakes dancing in the air\n\"\"\"\npipeline(\n    prompt=prompt,\n    guidance_scale=3.0,\n    num_inference_steps=num_inference_steps,\n    generator=torch.Generator().manual_seed(42),\n    callback_on_step_end=callback,\n).images[0]\n```\n\n## Hotswapping\n\nHotswapping LoRAs is an efficient way to work with multiple LoRAs while avoiding accumulating memory from multiple calls to [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and in some cases, recompilation, if a model is compiled. 
This workflow requires a loaded LoRA because the new LoRA weights are swapped in place for the existing loaded LoRA.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\n# load base model and LoRAs\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/ikea-instructions-lora-sdxl\",\n    weight_name=\"ikea_instructions_xl_v1_5.safetensors\",\n    adapter_name=\"ikea\"\n)\n```\n\n> [!WARNING]\n> Hotswapping is unsupported for LoRAs that target the text encoder.\n\nSet `hotswap=True` in [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] to swap the second LoRA. Use the `adapter_name` parameter to indicate which LoRA to swap (`default_0` is the default name).\n\n```py\npipeline.load_lora_weights(\n    \"lordjia/by-feng-zikai\",\n    hotswap=True,\n    adapter_name=\"ikea\"\n)\n```\n\n### Compiled models\n\nFor compiled models, use [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation when hotswapping LoRAs. This method should be called *before* loading the first LoRA and `torch.compile` should be called *after* loading the first LoRA.\n\n> [!TIP]\n> The [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method isn't always necessary if the second LoRA targets the identical LoRA ranks and scales as the first LoRA.\n\nWithin [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`], the `target_rank` parameter is important for setting the rank for all LoRA adapters. Setting it to `max_rank` sets it to the highest value. For LoRAs with different ranks, you set it to a higher rank value. The default rank value is 128.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\n# load base model and LoRAs\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\n# 1. 
enable_lora_hotswap\npipeline.enable_lora_hotswap(target_rank=max_rank)\npipeline.load_lora_weights(\n    \"ostris/ikea-instructions-lora-sdxl\",\n    weight_name=\"ikea_instructions_xl_v1_5.safetensors\",\n    adapter_name=\"ikea\"\n)\n# 2. torch.compile\npipeline.unet = torch.compile(pipeline.unet, mode=\"reduce-overhead\", fullgraph=True)\n\n# 3. hotswap\npipeline.load_lora_weights(\n    \"lordjia/by-feng-zikai\",\n    hotswap=True,\n    adapter_name=\"ikea\"\n)\n```\n\n> [!TIP]\n> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.\n\nIf you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.\n\nThere are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.\n\n<details>\n<summary>Technical details of hotswapping</summary>\n\nThe [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method converts the LoRA scaling factor from floats to torch.tensors and pads the shape of the weights to the largest required shape to avoid reassigning the whole attribute when the data in the weights are replaced.\n\nThis is why the `max_rank` argument is important. The results are unchanged even when the values are padded with zeros. 
Computation may be slower though depending on the padding size.\n\nSince no new LoRA attributes are added, each subsequent LoRA is only allowed to target the same layers, or subset of layers, the first LoRA targets. Choosing the LoRA loading order is important because if the LoRAs target disjoint layers, you may end up creating a dummy LoRA that targets the union of all target layers.\n\nFor more implementation details, take a look at the [`hotswap.py`](https://github.com/huggingface/peft/blob/92d65cafa51c829484ad3d95cf71d09de57ff066/src/peft/utils/hotswap.py) file.\n\n</details>\n\n## Merge\n\nThe weights from each LoRA can be merged together to produce a blend of multiple existing styles. There are several methods for merging LoRAs, each of which differs in *how* the weights are merged (which may affect generation quality).\n\n### set_adapters\n\nThe [`~loaders.PeftAdapterMixin.set_adapters`] method merges LoRAs by concatenating their weighted matrices. Pass the LoRA names to [`~loaders.PeftAdapterMixin.set_adapters`] and use the `adapter_weights` parameter to control the scaling of each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, the output is an average of both LoRAs.\n\n> [!TIP]\n> The `\"scale\"` parameter determines how much of the merged LoRA to apply. 
See the [Weight scale](#weight-scale) section for more details.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/ikea-instructions-lora-sdxl\",\n    weight_name=\"ikea_instructions_xl_v1_5.safetensors\",\n    adapter_name=\"ikea\"\n)\npipeline.load_lora_weights(\n    \"lordjia/by-feng-zikai\",\n    weight_name=\"fengzikai_v1.0_XL.safetensors\",\n    adapter_name=\"feng\"\n)\npipeline.set_adapters([\"ikea\", \"feng\"], adapter_weights=[0.7, 0.8])\n# use by Feng Zikai to activate the lordjia/by-feng-zikai LoRA\npipeline(\"A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai\", cross_attention_kwargs={\"scale\": 1.0}).images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_merge_set_adapters.png\"/>\n</div>\n\n### add_weighted_adapter\n\n> [!TIP]\n> This is an experimental method and you can refer to PEFT's [Model merging](https://huggingface.co/docs/peft/developer_guides/model_merging) for more details. Take a look at this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in the motivation and design behind this integration.\n\nThe [`~peft.LoraModel.add_weighted_adapter`] method enables more efficient merging methods like [TIES](https://huggingface.co/papers/2306.01708) or [DARE](https://huggingface.co/papers/2311.03099). These merging methods remove redundant and potentially interfering parameters from merged models. 
Keep in mind the LoRAs need to have identical ranks to be merged.\n\nMake sure the latest stable version of Diffusers and PEFT is installed.\n\n```bash\npip install -U -q diffusers peft\n```\n\nLoad a UNet that corresponds to the LoRA UNet.\n\n```py\nimport copy\nimport torch\nfrom diffusers import AutoModel, DiffusionPipeline\nfrom peft import get_peft_model, LoraConfig, PeftModel\n\nunet = AutoModel.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    variant=\"fp16\",\n    subfolder=\"unet\",\n).to(\"cuda\")\n```\n\nLoad a pipeline, pass the UNet to it, and load a LoRA.\n\n```py\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    variant=\"fp16\",\n    torch_dtype=torch.float16,\n    unet=unet\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/ikea-instructions-lora-sdxl\",\n    weight_name=\"ikea_instructions_xl_v1_5.safetensors\",\n    adapter_name=\"ikea\"\n)\n```\n\nCreate a [`~peft.PeftModel`] from the LoRA checkpoint by combining the first UNet you loaded and the LoRA UNet from the pipeline.\n\n```py\nsdxl_unet = copy.deepcopy(unet)\nikea_peft_model = get_peft_model(\n    sdxl_unet,\n    pipeline.unet.peft_config[\"ikea\"],\n    adapter_name=\"ikea\"\n)\n\noriginal_state_dict = {f\"base_model.model.{k}\": v for k, v in pipeline.unet.state_dict().items()}\nikea_peft_model.load_state_dict(original_state_dict, strict=True)\n```\n\n> [!TIP]\n> You can save and reuse the `ikea_peft_model` by pushing it to the Hub as shown below.\n> ```py\n> ikea_peft_model.push_to_hub(\"ikea_peft_model\", token=TOKEN)\n> ```\n\nRepeat this process and create a [`~peft.PeftModel`] for the second LoRA.\n\n```py\npipeline.delete_adapters(\"ikea\")\nsdxl_unet.delete_adapters(\"ikea\")\n\npipeline.load_lora_weights(\n    \"lordjia/by-feng-zikai\",\n    weight_name=\"fengzikai_v1.0_XL.safetensors\",\n    
adapter_name=\"feng\"\n)\npipeline.set_adapters(adapter_names=\"feng\")\n\nfeng_peft_model = get_peft_model(\n    sdxl_unet,\n    pipeline.unet.peft_config[\"feng\"],\n    adapter_name=\"feng\"\n)\n\noriginal_state_dict = {f\"base_model.model.{k}\": v for k, v in pipeline.unet.state_dict().items()}\nfeng_peft_model.load_state_dict(original_state_dict, strict=True)\n```\n\nLoad a base UNet model and load the adapters.\n\n```py\nbase_unet = AutoModel.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    variant=\"fp16\",\n    subfolder=\"unet\",\n).to(\"cuda\")\n\nmodel = PeftModel.from_pretrained(\n    base_unet,\n    \"stevhliu/ikea_peft_model\",\n    use_safetensors=True,\n    subfolder=\"ikea\",\n    adapter_name=\"ikea\"\n)\nmodel.load_adapter(\n    \"stevhliu/feng_peft_model\",\n    use_safetensors=True,\n    subfolder=\"feng\",\n    adapter_name=\"feng\"\n)\n```\n\nMerge the LoRAs with [`~peft.LoraModel.add_weighted_adapter`] and specify how you want to merge them with `combination_type`. 
The example below uses the `\"dare_linear\"` method (refer to this [blog post](https://huggingface.co/blog/peft_merging) to learn more about these merging methods), which randomly prunes some weights and then performs a weighted sum of the tensors based on the set weightage of each LoRA in `weights`.\n\nActivate the merged LoRAs with [`~loaders.PeftAdapterMixin.set_adapters`].\n\n```py\nmodel.add_weighted_adapter(\n    adapters=[\"ikea\", \"feng\"],\n    combination_type=\"dare_linear\",\n    weights=[1.0, 1.0],\n    adapter_name=\"ikea-feng\"\n)\nmodel.set_adapters(\"ikea-feng\")\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    unet=model,\n    variant=\"fp16\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline(\"A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai\").images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ikea-feng-dare-linear.png\"/>\n</div>\n\n### fuse_lora\n\nThe [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method fuses the LoRA weights directly with the original UNet and text encoder weights of the underlying model. 
This reduces the overhead of loading the underlying model for each LoRA because it only loads the model once, which lowers memory usage and increases inference speed.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/ikea-instructions-lora-sdxl\",\n    weight_name=\"ikea_instructions_xl_v1_5.safetensors\",\n    adapter_name=\"ikea\"\n)\npipeline.load_lora_weights(\n    \"lordjia/by-feng-zikai\",\n    weight_name=\"fengzikai_v1.0_XL.safetensors\",\n    adapter_name=\"feng\"\n)\npipeline.set_adapters([\"ikea\", \"feng\"], adapter_weights=[0.7, 0.8])\n```\n\nCall [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] to fuse them. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make this adjustment now because passing `scale` to `cross_attention_kwargs` won't work in the pipeline.\n\n```py\npipeline.fuse_lora(adapter_names=[\"ikea\", \"feng\"], lora_scale=1.0)\n```\n\nUnload the LoRA weights since they're already fused with the underlying model. 
Save the fused pipeline with either [`~DiffusionPipeline.save_pretrained`] to save it locally or [`~PushToHubMixin.push_to_hub`] to save it to the Hub.\n\n<hfoptions id=\"save\">\n<hfoption id=\"save locally\">\n\n```py\npipeline.unload_lora_weights()\npipeline.save_pretrained(\"path/to/fused-pipeline\")\n```\n\n</hfoption>\n<hfoption id=\"save to Hub\">\n\n```py\npipeline.unload_lora_weights()\npipeline.push_to_hub(\"fused-ikea-feng\")\n```\n\n</hfoption>\n</hfoptions>\n\nThe fused pipeline can now be quickly loaded for inference without requiring each LoRA to be separately loaded.\n\n```py\npipeline = DiffusionPipeline.from_pretrained(\n    \"username/fused-ikea-feng\", torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline(\"A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai\").images[0]\n```\n\nUse [`~loaders.LoraLoaderMixin.unfuse_lora`] to restore the underlying model's weights, for example, if you want to use a different `lora_scale` value. You can only unfuse if there is a single LoRA fused. For example, it won't work with the pipeline from above because there are multiple fused LoRAs. In these cases, you'll need to reload the entire model.\n\n```py\npipeline.unfuse_lora()\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/fuse_lora.png\"/>\n</div>\n\n## Manage\n\nDiffusers provides several methods to help you manage LoRAs. These methods can be especially useful if you're working with multiple LoRAs.\n\n### set_adapters\n\n[`~loaders.PeftAdapterMixin.set_adapters`] also sets which LoRA is active when multiple LoRAs are loaded. 
This allows you to switch between different LoRAs by specifying their name.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_lora_weights(\n    \"ostris/ikea-instructions-lora-sdxl\",\n    weight_name=\"ikea_instructions_xl_v1_5.safetensors\",\n    adapter_name=\"ikea\"\n)\npipeline.load_lora_weights(\n    \"lordjia/by-feng-zikai\",\n    weight_name=\"fengzikai_v1.0_XL.safetensors\",\n    adapter_name=\"feng\"\n)\n# activates the feng LoRA instead of the ikea LoRA\npipeline.set_adapters(\"feng\")\n```\n\n### save_lora_adapter\n\nSave an adapter with [`~loaders.PeftAdapterMixin.save_lora_adapter`].\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.unet.load_lora_adapter(\n    \"jbilcke-hf/sdxl-cinematic-1\",\n    weight_name=\"pytorch_lora_weights.safetensors\",\n    adapter_name=\"cinematic\",\n    prefix=\"unet\"\n)\n# save_lora_adapter is a model-level method, so call it on the UNet that holds the adapter\npipeline.unet.save_lora_adapter(\"path/to/save\", adapter_name=\"cinematic\")\n```\n\n### unload_lora_weights\n\nThe [`~loaders.lora_base.LoraBaseMixin.unload_lora_weights`] method unloads any LoRA weights in the pipeline to restore the underlying model weights.\n\n```py\npipeline.unload_lora_weights()\n```\n\n### disable_lora\n\nThe [`~loaders.PeftAdapterMixin.disable_lora`] method disables all LoRAs (but they're still kept on the pipeline) and restores the pipeline to the underlying model weights.\n\n```py\npipeline.disable_lora()\n```\n\n### get_active_adapters\n\nThe [`~loaders.lora_base.LoraBaseMixin.get_active_adapters`] method returns a list of active LoRAs attached to a pipeline.\n\n```py\npipeline.get_active_adapters()\n[\"cereal\", \"ikea\"]\n```\n\n### get_list_adapters\n\nThe 
[`~loaders.lora_base.LoraBaseMixin.get_list_adapters`] method returns the active LoRAs for each component in the pipeline.\n\n```py\npipeline.get_list_adapters()\n{\"unet\": [\"cereal\", \"ikea\"], \"text_encoder_2\": [\"cereal\"]}\n```\n\n### delete_adapters\n\nThe [`~loaders.PeftAdapterMixin.delete_adapters`] method completely removes a LoRA and its layers from a model.\n\n```py\npipeline.delete_adapters(\"ikea\")\n```\n\n## Resources\n\nBrowse the [LoRA Studio](https://lorastudio.co/models) for different LoRAs to use or you can upload your favorite LoRAs from Civitai to the Hub with the Space below.\n\n<iframe\n\tsrc=\"https://multimodalart-civitai-to-hf.hf.space\"\n\tframeborder=\"0\"\n\twidth=\"850\"\n\theight=\"450\"\n></iframe>\n\nYou can find additional LoRAs in the [FLUX LoRA the Explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer) and [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) Spaces.\n\nCheck out the [Fast LoRA inference for Flux with Diffusers and PEFT](https://huggingface.co/blog/lora-fast) blog post to learn how to optimize LoRA inference with methods like FlashAttention-3 and fp8 quantization.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/automodel.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoModel\n\nThe [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports data types and device placement, and works across model types and libraries.\n\nThe example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from.\n\n```py\nimport torch\nfrom diffusers import AutoModel, DiffusionPipeline\n\ntransformer = AutoModel.from_pretrained(\n    \"Qwen/Qwen-Image\", subfolder=\"transformer\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n\ntext_encoder = AutoModel.from_pretrained(\n    \"Qwen/Qwen-Image\", subfolder=\"text_encoder\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n```\n\n## Custom models\n\n[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. 
Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.\n\nA custom model repository needs a Python module with the model class, and a `config.json` with an `auto_map` entry that maps `\"AutoModel\"` to `\"module_file.ClassName\"`.\n\n```\ncustom/custom-transformer-model/\n├── config.json\n├── my_model.py\n└── diffusion_pytorch_model.safetensors\n```\n\nThe `config.json` includes the `auto_map` field pointing to the custom class.\n\n```json\n{\n  \"auto_map\": {\n    \"AutoModel\": \"my_model.MyCustomModel\"\n  }\n}\n```\n\nThen load it with `trust_remote_code=True`.\n\n```py\nimport torch\nfrom diffusers import AutoModel\n\ntransformer = AutoModel.from_pretrained(\n    \"custom/custom-transformer-model\", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n```\n\nFor a real-world example, [Overworld/Waypoint-1-Small](https://huggingface.co/Overworld/Waypoint-1-Small/tree/main/transformer) hosts a custom `WorldModel` class across several modules in its `transformer` subfolder.\n\n```\ntransformer/\n├── config.json          # auto_map: \"model.WorldModel\"\n├── model.py\n├── attn.py\n├── nn.py\n├── cache.py\n├── quantize.py\n├── __init__.py\n└── diffusion_pytorch_model.safetensors\n```\n\n```py\nimport torch\nfrom diffusers import AutoModel\n\ntransformer = AutoModel.from_pretrained(\n    \"Overworld/Waypoint-1-Small\", subfolder=\"transformer\", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n```\n\nIf the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).\n\n> [!WARNING]\n> As a precaution with `trust_remote_code=True`, pass a commit hash to the `revision` argument in [`AutoModel.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust 
the model owners).\n>\n> ```py\n> transformer = AutoModel.from_pretrained(\n>     \"Overworld/Waypoint-1-Small\", subfolder=\"transformer\", trust_remote_code=True, revision=\"a3d8cb2\"\n> )\n> ```\n\n### Saving custom models\n\nUse [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.\n\n```py\n# my_model.py\nfrom diffusers import ModelMixin, ConfigMixin\n\nclass MyCustomModel(ModelMixin, ConfigMixin):\n    ...\n\nMyCustomModel.register_for_auto_class(\"AutoModel\")\n\nmodel = MyCustomModel(...)\nmodel.save_pretrained(\"./my_model\")\n```\n\nThe saved `config.json` will include the `auto_map` field.\n\n```json\n{\n  \"auto_map\": {\n    \"AutoModel\": \"my_model.MyCustomModel\"\n  }\n}\n```\n\n> [!NOTE]\n> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide."
  },
  {
    "path": "docs/source/en/using-diffusers/batched_inference.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Batch inference\n\nBatch inference processes multiple prompts at a time to increase throughput. It is more efficient because processing multiple prompts at once maximizes GPU usage versus processing a single prompt and underutilizing the GPU.\n\nThe downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.\n\nFor text-to-image, pass a list of prompts to the pipeline and for image-to-image, pass a list of images and prompts to the pipeline. The example below demonstrates batched text-to-image inference.\n\n```py\nimport torch\nimport matplotlib.pyplot as plt\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\n\nprompts = [\n    \"Cinematic shot of a cozy coffee shop interior, warm pastel light streaming through a window where a cat rests. Shallow depth of field, glowing cups in soft focus, dreamy lofi-inspired mood, nostalgic tones, framed like a quiet film scene.\",\n    \"Polaroid-style photograph of a cozy coffee shop interior, bathed in warm pastel light. A cat sits on the windowsill near steaming mugs. 
Soft, slightly faded tones and dreamy blur evoke nostalgia, a lofi mood, and the intimate, imperfect charm of instant film.\",\n    \"Soft watercolor illustration of a cozy coffee shop interior, pastel washes of color filling the space. A cat rests peacefully on the windowsill as warm light glows through. Gentle brushstrokes create a dreamy, lofi-inspired atmosphere with whimsical textures and nostalgic calm.\",\n    \"Isometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the space as a cat rests on the windowsill. Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the nostalgic, lofi-inspired game aesthetic.\"\n]\n\nimages = pipeline(\n    prompt=prompts,\n).images\n\nfig, axes = plt.subplots(2, 2, figsize=(12, 12))\naxes = axes.flatten()\n\nfor i, image in enumerate(images):\n    axes[i].imshow(image)\n    axes[i].set_title(f\"Image {i+1}\")\n    axes[i].axis('off')\n\nplt.tight_layout()\nplt.show()\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/batch-inference.png\"/>\n</div>\n\nTo generate multiple variations of one prompt, use the `num_images_per_prompt` argument.\n\n```py\nimport torch\nimport matplotlib.pyplot as plt\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\n\nprompt=\"\"\"\nIsometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the\nspace as a cat rests on the windowsill. 
Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the\nnostalgic, lofi-inspired game aesthetic.\n\"\"\"\n\nimages = pipeline(\n    prompt=prompt,\n    num_images_per_prompt=4\n).images\n\nfig, axes = plt.subplots(2, 2, figsize=(12, 12))\naxes = axes.flatten()\n\nfor i, image in enumerate(images):\n    axes[i].imshow(image)\n    axes[i].set_title(f\"Image {i+1}\")\n    axes[i].axis('off')\n\nplt.tight_layout()\nplt.show()\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/batch-inference-2.png\"/>\n</div>\n\nCombine both approaches to generate different variations of different prompts.\n\n```py\nimages = pipeline(\n    prompt=prompts,\n    num_images_per_prompt=2,\n).images\n\nfig, axes = plt.subplots(2, 4, figsize=(12, 12))\naxes = axes.flatten()\n\nfor i, image in enumerate(images):\n    axes[i].imshow(image)\n    axes[i].set_title(f\"Image {i+1}\")\n    axes[i].axis('off')\n\nplt.tight_layout()\nplt.show()\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/batch-inference-3.png\"/>\n</div>\n\n## Deterministic generation\n\nEnable reproducible batch generation by passing a list of [Generators](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline, and tie each `Generator` to a seed so you can reuse it.\n\n> [!TIP]\n> Refer to the [Reproducibility](./reusing_seeds) docs to learn more about deterministic algorithms and the `Generator` object.\n\nUse a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. 
Don't multiply the `Generator` by the batch size because that only creates one `Generator` object that is used sequentially for each image in the batch.\n\n```py\ngenerator = [torch.Generator(device=\"cuda\").manual_seed(i) for i in range(4)]\n```\n\nPass the `generator` to the pipeline. The list must contain one `Generator` per image in the batch, which is four in this example.\n\n```py\nimport torch\nimport matplotlib.pyplot as plt\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\n\n# one Generator per prompt, each with its own seed\ngenerator = [torch.Generator(device=\"cuda\").manual_seed(i) for i in range(4)]\nprompts = [\n    \"Cinematic shot of a cozy coffee shop interior, warm pastel light streaming through a window where a cat rests. Shallow depth of field, glowing cups in soft focus, dreamy lofi-inspired mood, nostalgic tones, framed like a quiet film scene.\",\n    \"Polaroid-style photograph of a cozy coffee shop interior, bathed in warm pastel light. A cat sits on the windowsill near steaming mugs. Soft, slightly faded tones and dreamy blur evoke nostalgia, a lofi mood, and the intimate, imperfect charm of instant film.\",\n    \"Soft watercolor illustration of a cozy coffee shop interior, pastel washes of color filling the space. A cat rests peacefully on the windowsill as warm light glows through. Gentle brushstrokes create a dreamy, lofi-inspired atmosphere with whimsical textures and nostalgic calm.\",\n    \"Isometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the space as a cat rests on the windowsill. 
Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the nostalgic, lofi-inspired game aesthetic.\"\n]\n\nimages = pipeline(\n    prompt=prompts,\n    generator=generator\n).images\n\nfig, axes = plt.subplots(2, 2, figsize=(12, 12))\naxes = axes.flatten()\n\nfor i, image in enumerate(images):\n    axes[i].imshow(image)\n    axes[i].set_title(f\"Image {i+1}\")\n    axes[i].axis('off')\n\nplt.tight_layout()\nplt.show()\n```\n\nYou can use this to select an image associated with a seed and iteratively improve on it by crafting a more detailed prompt."
  },
  {
    "path": "docs/source/en/using-diffusers/callback.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Pipeline callbacks\n\nA callback is a function that modifies [`DiffusionPipeline`] behavior and it is executed at the end of a denoising step. The changes are propagated to subsequent steps in the denoising process. It is useful for adjusting pipeline attributes or tensor variables to support new features without rewriting the underlying pipeline code.\n\nDiffusers provides several callbacks in the pipeline [overview](../api/pipelines/overview#callbacks).\n\nTo enable a callback, configure when the callback is executed after a certain number of denoising steps with one of the following arguments.\n\n- `cutoff_step_ratio` specifies when a callback is activated as a percentage of the total denoising steps.\n- `cutoff_step_index` specifies the exact step number a callback is activated.\n\nThe example below uses `cutoff_step_ratio=0.4`, which means the callback is activated once denoising reaches 40% of the total inference steps. 
[`~callbacks.SDXLCFGCutoffCallback`] disables classifier-free guidance (CFG) after a certain number of steps, which can help save compute without significantly affecting performance.\n\nDefine a callback with either of the `cutoff` arguments and pass it to the `callback_on_step_end` parameter in the pipeline.\n\n```py\nimport torch\nfrom diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline\nfrom diffusers.callbacks import SDXLCFGCutoffCallback\n\ncallback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)\n# if using cutoff_step_index\n# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)\n\n# seed chosen arbitrarily for reproducibility\ngenerator = torch.Generator(device=\"cuda\").manual_seed(0)\n\nprompt = \"a sports car at the road, best quality, high quality, high detail, 8k resolution\"\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=\"\",\n    guidance_scale=6.5,\n    num_inference_steps=25,\n    generator=generator,\n    callback_on_step_end=callback,\n)\n```\n\nIf you want to add a new official callback, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) or [submit a PR](https://huggingface.co/docs/diffusers/main/en/conceptual/contribution#how-to-open-a-pr). Otherwise, you can also create your own callback as shown below.\n\n## Early stopping\n\nEarly stopping is useful if you aren't happy with the intermediate results during generation. 
This callback sets a hardcoded stop point after which the pipeline terminates by setting the `_interrupt` attribute to `True`.\n\n```py\nfrom diffusers import StableDiffusionPipeline\n\ndef interrupt_callback(pipeline, i, t, callback_kwargs):\n    stop_idx = 10\n    if i == stop_idx:\n        pipeline._interrupt = True\n\n    return callback_kwargs\n\npipeline = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n)\nnum_inference_steps = 50\n\npipeline(\n    \"A photo of a cat\",\n    num_inference_steps=num_inference_steps,\n    callback_on_step_end=interrupt_callback,\n)\n```\n\n## Display intermediate images\n\nVisualizing the intermediate images is useful for progress monitoring and assessing the quality of the generated content. This callback decodes the latent tensors at each step and converts them to images.\n\n[Convert](https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space) the Stable Diffusion XL latents from latent space (4 channels) to RGB tensors (3 channels).\n\n```py\ndef latents_to_rgb(latents):\n    weights = (\n        (60, -60, 25, -70),\n        (60,  -5, 15, -50),\n        (60,  10, -5, -35),\n    )\n\n    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))\n    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)\n    rgb_tensor = torch.einsum(\"...lxy,lr -> ...rxy\", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)\n    image_array = rgb_tensor.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)\n\n    return Image.fromarray(image_array)\n```\n\nExtract the latents and convert the first image in the batch to RGB. 
Save the image as a PNG file with the step number.\n\n```py\ndef decode_tensors(pipe, step, timestep, callback_kwargs):\n    latents = callback_kwargs[\"latents\"]\n\n    image = latents_to_rgb(latents[0])\n    image.save(f\"{step}.png\")\n\n    return callback_kwargs\n```\n\nUse the `callback_on_step_end_tensor_inputs` parameter to specify what input type to modify, which in this case, are the latents.\n\n```py\nimport torch\nfrom PIL import Image\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\n\nimage = pipeline(\n    prompt=\"A croissant shaped like a cute bear.\",\n    negative_prompt=\"Deformed, ugly, bad anatomy\",\n    callback_on_step_end=decode_tensors,\n    callback_on_step_end_tensor_inputs=[\"latents\"],\n).images[0]\n```\n"
  },
  {
    "path": "docs/source/en/using-diffusers/conditional_image_generation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Text-to-image\n\n[[open-in-colab]]\n\nWhen you think of diffusion models, text-to-image is usually one of the first things that come to mind. Text-to-image generates an image from a text description (for example, \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\") which is also known as a *prompt*.\n\nFrom a very high level, a diffusion model takes a prompt and some random initial noise, and iteratively removes the noise to construct an image. The *denoising* process is guided by the prompt, and once the denoising process ends after a predetermined number of time steps, the image representation is decoded into an image.\n\n> [!TIP]\n> Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog post to learn more about how a latent diffusion model works.\n\nYou can generate images from a prompt in 🤗 Diffusers in two steps:\n\n1. Load a checkpoint into the [`AutoPipelineForText2Image`] class, which automatically detects the appropriate pipeline class to use based on the checkpoint:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\"\n).to(\"cuda\")\n```\n\n2. 
Pass a prompt to the pipeline to generate an image:\n\n```py\nimage = pipeline(\n\t\"stained glass of darth vader, backlight, centered composition, masterpiece, photorealistic, 8k\"\n).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n\t<img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-vader.png\"/>\n</div>\n\n## Popular models\n\nThe most common text-to-image models are [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). There are also ControlNet models or adapters that can be used with text-to-image models for more direct control in generating images. The results from each model are slightly different because of their architecture and training process, but no matter which model you choose, their usage is more or less the same. Let's use the same prompt for each model and compare their results.\n\n### Stable Diffusion v1.5\n\n[Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) is a latent diffusion model initialized from [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4), and finetuned for 595K steps on 512x512 images from the LAION-Aesthetics V2 dataset. 
You can use this model like:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\"\n).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(31)\nimage = pipeline(\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", generator=generator).images[0]\nimage\n```\n\n### Stable Diffusion XL\n\nSDXL is a much larger version of the previous Stable Diffusion models, and involves a two-stage model process that adds even more details to an image. It also includes some additional *micro-conditionings* to generate high-quality images with centered subjects. Take a look at the more comprehensive [SDXL](sdxl) guide to learn more about how to use it. In general, you can use SDXL like:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\"\n).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(31)\nimage = pipeline(\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", generator=generator).images[0]\nimage\n```\n\n### Kandinsky 2.2\n\nThe Kandinsky model is a bit different from the Stable Diffusion models because it also uses an image prior model to create embeddings that are used to better align text and images in the diffusion model.\n\nThe easiest way to use Kandinsky 2.2 is:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16\n).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(31)\nimage = pipeline(\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", generator=generator).images[0]\nimage\n```\n\n### 
ControlNet\n\nControlNet models are auxiliary models or adapters that are finetuned on top of text-to-image models, such as [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). Using ControlNet models in combination with text-to-image models offers diverse options for more explicit control over how to generate an image. With ControlNet, you add an additional conditioning input image to the model. For example, if you provide an image of a human pose (usually represented as multiple keypoints that are connected into a skeleton) as a conditioning input, the model generates an image that follows the pose of the image. Check out the more in-depth [ControlNet](controlnet) guide to learn more about other conditioning inputs and how to use them.\n\nIn this example, let's condition the ControlNet with a human pose estimation image. Load the ControlNet model pretrained on human pose estimations:\n\n```py\nfrom diffusers import ControlNetModel, AutoPipelineForText2Image\nfrom diffusers.utils import load_image\nimport torch\n\ncontrolnet = ControlNetModel.from_pretrained(\n\t\"lllyasviel/control_v11p_sd15_openpose\", torch_dtype=torch.float16, variant=\"fp16\"\n).to(\"cuda\")\npose_image = load_image(\"https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png\")\n```\n\nPass the `controlnet` to the [`AutoPipelineForText2Image`], and provide the prompt and pose estimation image:\n\n```py\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"stable-diffusion-v1-5/stable-diffusion-v1-5\", controlnet=controlnet, torch_dtype=torch.float16, variant=\"fp16\"\n).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(31)\nimage = pipeline(\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", image=pose_image, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-1.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Stable Diffusion v1.5</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Stable Diffusion XL</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-2.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Kandinsky 2.2</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-3.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">ControlNet (pose conditioning)</figcaption>\n  </div>\n</div>\n\n## Configure pipeline parameters\n\nThere are a number of parameters that can be configured in the pipeline that affect how an image is generated. You can change the image's output size, specify a negative prompt to improve image quality, and more. This section dives deeper into how to use these parameters.\n\n### Height and width\n\nThe `height` and `width` parameters control the height and width (in pixels) of the generated image. By default, the Stable Diffusion v1.5 model outputs 512x512 images, but you can change this to any size that is a multiple of 8. 
For example, to create a rectangular image:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\"\n).to(\"cuda\")\nimage = pipeline(\n\t\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", height=768, width=512\n).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n\t<img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-hw.png\"/>\n</div>\n\n> [!WARNING]\n> Other models may have different default image sizes depending on the image sizes in the training dataset. For example, SDXL's default image size is 1024x1024 and using lower `height` and `width` values may result in lower quality images. Make sure you check the model's API reference first!\n\n### Guidance scale\n\nThe `guidance_scale` parameter affects how much the prompt influences image generation. A lower value gives the model \"creativity\" to generate images that are more loosely related to the prompt. 
Higher `guidance_scale` values push the model to follow the prompt more closely, and if this value is too high, you may observe some artifacts in the generated image.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16\n).to(\"cuda\")\nimage = pipeline(\n\t\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", guidance_scale=3.5\n).images[0]\nimage\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-guidance-scale-2.5.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">guidance_scale = 2.5</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-guidance-scale-7.5.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">guidance_scale = 7.5</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-guidance-scale-10.5.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">guidance_scale = 10.5</figcaption>\n  </div>\n</div>\n\n### Negative prompt\n\nJust like how a prompt guides generation, a *negative prompt* steers the model away from things you don't want the model to generate. This is commonly used to improve overall image quality by removing poor or bad image features such as \"low resolution\" or \"bad details\". 
You can also use a negative prompt to remove or modify the content and style of an image.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16\n).to(\"cuda\")\nimage = pipeline(\n\tprompt=\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n\tnegative_prompt=\"ugly, deformed, disfigured, poor details, bad anatomy\",\n).images[0]\nimage\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-neg-prompt-1.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">negative_prompt = \"ugly, deformed, disfigured, poor details, bad anatomy\"</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-neg-prompt-2.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">negative_prompt = \"astronaut\"</figcaption>\n  </div>\n</div>\n\n### Generator\n\nA [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator) object enables reproducibility in a pipeline by setting a manual seed. You can use a `Generator` to generate batches of images and iteratively improve on an image generated from a seed as detailed in the [Improve image quality with deterministic generation](reusing_seeds) guide.\n\nYou can set a seed and `Generator` as shown below. 
Creating an image with a `Generator` should return the same result each time instead of randomly generating a new image.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16\n).to(\"cuda\")\ngenerator = torch.Generator(device=\"cuda\").manual_seed(30)\nimage = pipeline(\n\t\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n\tgenerator=generator,\n).images[0]\nimage\n```\n\n## Control image generation\n\nThere are several ways to exert more control over how an image is generated outside of configuring a pipeline's parameters, such as prompt weighting and ControlNet models.\n\n### Prompt weighting\n\nPrompt weighting is a technique for increasing or decreasing the importance of concepts in a prompt to emphasize or minimize certain features in an image. We recommend using the [Compel](https://github.com/damian0815/compel) library to help you generate the weighted prompt embeddings.\n\n> [!TIP]\n> Learn how to create the prompt embeddings in the [Prompt weighting](weighted_prompts) guide. 
This example focuses on how to use the prompt embeddings in the pipeline.\n\nOnce you've created the embeddings, you can pass them to the `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter in the pipeline.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n\t\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16\n).to(\"cuda\")\nimage = pipeline(\n\tprompt_embeds=prompt_embeds, # generated from Compel\n\tnegative_prompt_embeds=negative_prompt_embeds, # generated from Compel\n).images[0]\n```\n\n### ControlNet\n\nAs you saw in the [ControlNet](#controlnet) section, these models offer a more flexible and accurate way to generate images by incorporating an additional conditioning image input. Each ControlNet model is pretrained on a particular type of conditioning image to generate new images that resemble it. For example, if you take a ControlNet model pretrained on depth maps, you can give the model a depth map as a conditioning input and it'll generate an image that preserves the spatial information in it. This is quicker and easier than specifying the depth information in a prompt. You can even combine multiple conditioning inputs with a [MultiControlNet](controlnet#multicontrolnet)!\n\nThere are many types of conditioning inputs you can use, and 🤗 Diffusers supports ControlNet for Stable Diffusion and SDXL models. Take a look at the more comprehensive [ControlNet](controlnet) guide to learn how you can use these models.\n\n## Optimize\n\nDiffusion models are large, and the iterative nature of denoising an image is computationally expensive and intensive. But this doesn't mean you need access to powerful - or even many - GPUs to use them. There are many optimization techniques for running diffusion models on consumer and free-tier resources. 
For example, you can load model weights in half-precision to save GPU memory and increase speed, or offload the entire model to the CPU to save even more memory.\n\nPyTorch 2.0 also supports a more memory-efficient attention mechanism called [*scaled dot product attention*](../optimization/fp16#scaled-dot-product-attention) that is automatically enabled if you're using PyTorch 2.0. You can combine this with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) to speed your code up even more:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\").to(\"cuda\")\npipeline.unet = torch.compile(pipeline.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\nFor more tips on how to optimize your code to save memory and speed up inference, read the [Accelerate inference](../optimization/fp16) and [Reduce memory usage](../optimization/memory) guides.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/consisid.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n# ConsisID\n\n[ConsisID](https://github.com/PKU-YuanGroup/ConsisID) is an identity-preserving text-to-video generation model that keeps the face consistent in the generated video by frequency decomposition. The main features of ConsisID are:\n\n- Frequency decomposition: The characteristics of the DiT architecture are analyzed from the frequency domain perspective, and based on these characteristics, a reasonable control information injection method is designed.\n- Consistency training strategy: A coarse-to-fine training strategy, dynamic masking loss, and dynamic cross-face loss further enhance the model's generalization ability and identity preservation performance.\n- Inference without finetuning: Previous methods required case-by-case finetuning of the input ID before inference, leading to significant time and computational costs. 
In contrast, ConsisID is tuning-free.\n\nThis guide walks you through using ConsisID for identity-preserving text-to-video generation.\n\n## Load Model Checkpoints\n\nModel weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.\n\n```python\n# !pip install consisid_eva_clip insightface facexlib\nimport torch\nfrom diffusers import ConsisIDPipeline\nfrom diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer\nfrom huggingface_hub import snapshot_download\n\n# Download ckpts\nsnapshot_download(repo_id=\"BestWishYsh/ConsisID-preview\", local_dir=\"BestWishYsh/ConsisID-preview\")\n\n# Load face helper model to preprocess input face image\nface_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models(\"BestWishYsh/ConsisID-preview\", device=\"cuda\", dtype=torch.bfloat16)\n\n# Load consisid base model\npipe = ConsisIDPipeline.from_pretrained(\"BestWishYsh/ConsisID-preview\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n```\n\n## Identity-Preserving Text-to-Video\n\nFor identity-preserving text-to-video, pass a text prompt and an image that contains a clear face (preferably half-body or full-body). By default, ConsisID generates a 720x480 video for the best results.\n\n```python\nfrom diffusers.utils import export_to_video\n\nprompt = \"The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. 
The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.\"\nimage = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true\"\n\nid_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2, eva_transform_mean, eva_transform_std, face_main_model, \"cuda\", torch.bfloat16, image, is_align_face=True)\n\nvideo = pipe(image=image, prompt=prompt, num_inference_steps=50, guidance_scale=6.0, use_dynamic_cfg=False, id_vit_hidden=id_vit_hidden, id_cond=id_cond, kps_cond=face_kps, generator=torch.Generator(\"cuda\").manual_seed(42))\nexport_to_video(video.frames[0], \"output.mp4\", fps=8)\n```\n<table>\n  <tr>\n    <th style=\"text-align: center;\">Face Image</th>\n    <th style=\"text-align: center;\">Video</th>\n    <th style=\"text-align: center;\">Description</th>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_0.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_0.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The video, in a beautifully crafted animated style, features a confident woman riding a horse through a lush forest clearing. Her expression is focused yet serene as she adjusts her wide-brimmed hat with a practiced hand. She wears a flowy bohemian dress, which moves gracefully with the rhythm of the horse, the fabric flowing fluidly in the animated motion. The dappled sunlight filters through the trees, casting soft, painterly patterns on the forest floor. 
Her posture is poised, showing both control and elegance as she guides the horse with ease. The animation's gentle, fluid style adds a dreamlike quality to the scene, with the woman’s calm demeanor and the peaceful surroundings evoking a sense of freedom and harmony.</td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_1.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_1.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The video, in a captivating animated style, shows a woman standing in the center of a snowy forest, her eyes narrowed in concentration as she extends her hand forward. She is dressed in a deep blue cloak, her breath visible in the cold air, which is rendered with soft, ethereal strokes. A faint smile plays on her lips as she summons a wisp of ice magic, watching with focus as the surrounding trees and ground begin to shimmer and freeze, covered in delicate ice crystals. The animation’s fluid motion brings the magic to life, with the frost spreading outward in intricate, sparkling patterns. The environment is painted with soft, watercolor-like hues, enhancing the magical, dreamlike atmosphere. 
The overall mood is serene yet powerful, with the quiet winter air amplifying the delicate beauty of the frozen scene.</td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_2.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_2.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The animation features a whimsical portrait of a balloon seller standing in a gentle breeze, captured with soft, hazy brushstrokes that evoke the feel of a serene spring day. His face is framed by a gentle smile, his eyes squinting slightly against the sun, while a few wisps of hair flutter in the wind. He is dressed in a light, pastel-colored shirt, and the balloons around him sway with the wind, adding a sense of playfulness to the scene. The background blurs softly, with hints of a vibrant market or park, enhancing the light-hearted, yet tender mood of the moment.</td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_3.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_3.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. 
Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.</td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_4.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_4.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.</td>\n  </tr>\n</table>\n\n## Resources\n\nLearn more about ConsisID with the following resources.\n- A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features.\n- The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/controlling_generation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Controlled generation\n\nControlling outputs generated by diffusion models has long been pursued by the community and is now an active research topic. In many popular diffusion models, subtle changes in inputs, both images and text prompts, can drastically change outputs. In an ideal world we want to be able to control how semantics are preserved and changed.\n\nMost examples of preserving semantics reduce to being able to accurately map a change in input to a change in output. For example, adding an adjective to a subject in a prompt should preserve the rest of the image and modify only that subject. Similarly, an image variation of a particular subject should preserve the subject's pose.\n\nAdditionally, there are qualities of generated images that we would like to influence beyond semantic preservation. For example, we generally want our outputs to be of good quality, adhere to a particular style, or be realistic.\n\nWe will document some of the techniques `diffusers` supports to control generation of diffusion models. Much of this is cutting-edge research and can be quite nuanced. 
If something needs clarifying or you have a suggestion, don't hesitate to open a discussion on the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or a [GitHub issue](https://github.com/huggingface/diffusers/issues).\n\nWe provide a high-level explanation of how generation can be controlled, along with a brief look at the technical details. For more in-depth explanations, the original papers linked from the pipelines are always the best resources.\n\nChoose a technique according to your use case. In many cases, these techniques can be combined. For example, one can combine Textual Inversion with SEGA to provide more semantic guidance to the outputs generated using Textual Inversion.\n\nUnless otherwise mentioned, these are techniques that work with existing models and don't require their own weights.\n\n1. [InstructPix2Pix](#instructpix2pix)\n2. [Pix2Pix Zero](#pix2pix-zero)\n3. [Attend and Excite](#attend-and-excite)\n4. [Semantic Guidance](#semantic-guidance-sega)\n5. [Self-attention Guidance](#self-attention-guidance-sag)\n6. [Depth2Image](#depth2image)\n7. [MultiDiffusion Panorama](#multidiffusion-panorama)\n8. [DreamBooth](#dreambooth)\n9. [Textual Inversion](#textual-inversion)\n10. [ControlNet](#controlnet)\n11. [Prompt Weighting](#prompt-weighting)\n12. [Custom Diffusion](#custom-diffusion)\n13. [Model Editing](#model-editing)\n14. [DiffEdit](#diffedit)\n15. [T2I-Adapter](#t2i-adapter)\n16. 
[FABRIC](#fabric)\n\nFor convenience, the table below shows which methods are inference-only and which require fine-tuning/training.\n\n|                     **Method**                      | **Inference only** | **Requires training /<br> fine-tuning** |                                          **Comments**                                           |\n| :-------------------------------------------------: | :----------------: | :-------------------------------------: | :---------------------------------------------------------------------------------------------: |\n|        [InstructPix2Pix](#instructpix2pix)         |         ✅         |                   ❌                    | Can additionally be<br>fine-tuned for better <br>performance on specific <br>edit instructions. |\n|            [Pix2Pix Zero](#pix2pix-zero)            |         ✅         |                   ❌                    |                                                                                                 |\n|       [Attend and Excite](#attend-and-excite)       |         ✅         |                   ❌                    |                                                                                                 |\n|       [Semantic Guidance](#semantic-guidance-sega)       |         ✅         |                   ❌                    |                                                                                                 |\n| [Self-attention Guidance](#self-attention-guidance-sag) |         ✅         |                   ❌                    |                                                                                                 |\n|             [Depth2Image](#depth2image)             |         ✅         |                   ❌                    |                                                                                                 |\n| [MultiDiffusion Panorama](#multidiffusion-panorama) |         ✅         |                   ❌                    | 
                                                                                                |\n|              [DreamBooth](#dreambooth)              |         ❌         |                   ✅                    |                                                                                                 |\n|       [Textual Inversion](#textual-inversion)       |         ❌         |                   ✅                    |                                                                                                 |\n|              [ControlNet](#controlnet)              |         ✅         |                   ❌                    |             A ControlNet can be <br>trained/fine-tuned on<br>a custom conditioning.             |\n|        [Prompt Weighting](#prompt-weighting)        |         ✅         |                   ❌                    |                                                                                                 |\n|        [Custom Diffusion](#custom-diffusion)        |         ❌         |                   ✅                    |                                                                                                 |\n|           [Model Editing](#model-editing)           |         ✅         |                   ❌                    |                                                                                                 |\n|                [DiffEdit](#diffedit)                |         ✅         |                   ❌                    |                                                                                                 |\n|             [T2I-Adapter](#t2i-adapter)             |         ✅         |                   ❌                    |                                                                                                 |\n|                [Fabric](#fabric)                    |         ✅         |                   ❌                    |                                                         
                                        |\n\n## InstructPix2Pix\n\n[Paper](https://huggingface.co/papers/2211.09800)\n\n[InstructPix2Pix](../api/pipelines/pix2pix) is fine-tuned from Stable Diffusion to support editing input images. It takes an image and a prompt describing an edit as inputs, and it outputs the edited image.\nInstructPix2Pix has been explicitly trained to work well with [InstructGPT](https://openai.com/blog/instruction-following/)-like prompts.\n\n## Attend and Excite\n\n[Paper](https://huggingface.co/papers/2301.13826)\n\n[Attend and Excite](../api/pipelines/attend_and_excite) allows subjects in the prompt to be faithfully represented in the final image.\n\nA set of token indices is given as input, corresponding to the subjects in the prompt that need to be present in the image. During denoising, each token index is guaranteed to have a minimum attention threshold for at least one patch of the image. The intermediate latents are iteratively optimized during the denoising process to strengthen the attention of the most neglected subject token until the attention threshold is passed for all subject tokens.\n\nLike Pix2Pix Zero, Attend and Excite also involves a mini optimization loop (leaving the pre-trained weights untouched) in its pipeline and can require more memory than the usual [StableDiffusionPipeline](../api/pipelines/stable_diffusion/text2img).\n\n## Semantic Guidance (SEGA)\n\n[Paper](https://huggingface.co/papers/2301.12247)\n\n[SEGA](../api/pipelines/semantic_stable_diffusion) allows applying or removing one or more concepts from an image. The strength of the concept can also be controlled. For example, the smile concept can be used to incrementally increase or decrease the smile of a portrait.\n\nSimilar to how classifier-free guidance provides guidance via empty prompt inputs, SEGA provides guidance on conceptual prompts. Several of these conceptual prompts can be applied simultaneously. 
Each conceptual prompt can either add or remove its concept depending on whether the guidance is applied positively or negatively.\n\nUnlike Pix2Pix Zero or Attend and Excite, SEGA directly interacts with the diffusion process instead of performing any explicit gradient-based optimization.\n\n## Self-attention Guidance (SAG)\n\n[Paper](https://huggingface.co/papers/2210.00939)\n\n[Self-attention Guidance](../api/pipelines/self_attention_guidance) improves the general quality of images.\n\nSAG provides guidance from predictions not conditioned on high-frequency details to fully conditioned images. The high-frequency details are extracted from the UNet self-attention maps.\n\n## Depth2Image\n\n[Project](https://huggingface.co/stabilityai/stable-diffusion-2-depth)\n\n[Depth2Image](../api/pipelines/stable_diffusion/depth2img) is fine-tuned from Stable Diffusion to better preserve semantics for text-guided image variation.\n\nIt conditions on a monocular depth estimate of the original image.\n\n## MultiDiffusion Panorama\n\n[Paper](https://huggingface.co/papers/2302.08113)\n\n[MultiDiffusion Panorama](../api/pipelines/panorama) defines a new generation process over a pre-trained diffusion model. This process binds together multiple diffusion generation methods that can be readily applied to generate high-quality, diverse images. Results adhere to user-provided controls, such as desired aspect ratio (e.g., panorama), and spatial guiding signals, ranging from tight segmentation masks to bounding boxes.\nMultiDiffusion Panorama lets you generate high-quality images at arbitrary aspect ratios (e.g., panoramas).\n\n## Fine-tuning your own models\n\nIn addition to pre-trained models, Diffusers has training scripts for fine-tuning models on user-provided data.\n\n## DreamBooth\n\n[Project](https://dreambooth.github.io/)\n\n[DreamBooth](../training/dreambooth) fine-tunes a model to teach it about a new subject. For example, 
a few pictures of a person can be used to generate images of that person in different styles.\n\n## Textual Inversion\n\n[Paper](https://huggingface.co/papers/2208.01618)\n\n[Textual Inversion](../training/text_inversion) fine-tunes a model to teach it about a new concept. For example, a few pictures of a style of artwork can be used to generate images in that style.\n\n## ControlNet\n\n[Paper](https://huggingface.co/papers/2302.05543)\n\n[ControlNet](../api/pipelines/controlnet) is an auxiliary network which adds an extra condition.\nThere are 8 canonical pre-trained ControlNets trained on different conditionings such as edge detection, scribbles,\ndepth maps, and semantic segmentations.\n\n## Prompt Weighting\n\n[Prompt weighting](../using-diffusers/weighted_prompts) is a simple technique that puts more attention weight on certain parts of the text\ninput.\n\n## Custom Diffusion\n\n[Paper](https://huggingface.co/papers/2212.04488)\n\n[Custom Diffusion](../training/custom_diffusion) only fine-tunes the cross-attention maps of a pre-trained\ntext-to-image diffusion model. It can additionally perform Textual Inversion. It supports\nmulti-concept training by design. 
Like DreamBooth and Textual Inversion, Custom Diffusion is also used to\nteach a pre-trained text-to-image diffusion model about new concepts to generate outputs involving the\nconcept(s) of interest.\n\n## DiffEdit\n\n[Paper](https://huggingface.co/papers/2210.11427)\n\n[DiffEdit](../api/pipelines/diffedit) allows for semantic editing of input images along with\ninput prompts while preserving the original input images as much as possible.\n\n## T2I-Adapter\n\n[Paper](https://huggingface.co/papers/2302.08453)\n\n[T2I-Adapter](../api/pipelines/stable_diffusion/adapter) is an auxiliary network which adds an extra condition.\nThere are 8 canonical pre-trained adapters trained on different conditionings such as edge detection, sketch,\ndepth maps, and semantic segmentations.\n\n## Fabric\n\n[Paper](https://huggingface.co/papers/2307.10159)\n\n[Fabric](https://github.com/huggingface/diffusers/tree/442017ccc877279bcf24fbe92f92d3d0def191b6/examples/community#stable-diffusion-fabric-pipeline) is a training-free\napproach applicable to a wide range of popular diffusion models, which exploits\nthe self-attention layer present in the most widely used architectures to condition\nthe diffusion process on a set of feedback images.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/controlnet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet\n\n[ControlNet](https://huggingface.co/papers/2302.05543) is an adapter that enables controllable generation such as generating an image of a cat in a *specific pose* or following the lines in a sketch of a *specific* cat. It works by adding a smaller network of \"zero convolution\" layers and progressively training these to avoid disrupting the original model. The original model parameters are frozen to avoid retraining it.\n\nA ControlNet is conditioned on extra visual information or \"structural controls\" (canny edge, depth maps, human pose, etc.) that can be combined with text prompts to generate images that are guided by the visual input.\n\n> [!TIP]\n> ControlNets are available for many models such as [Flux](../api/pipelines/controlnet_flux), [Hunyuan-DiT](../api/pipelines/controlnet_hunyuandit), [Stable Diffusion 3](../api/pipelines/controlnet_sd3), and more. 
The examples in this guide use Flux and Stable Diffusion XL.\n\nLoad a ControlNet conditioned on a specific control, such as canny edge, and pass it to the pipeline in [`~DiffusionPipeline.from_pretrained`].\n\n<hfoptions id=\"usage\">\n<hfoption id=\"text-to-image\">\n\nGenerate a canny image with [opencv-python](https://github.com/opencv/opencv-python).\n\n```py\nimport cv2\nimport numpy as np\nfrom PIL import Image\nfrom diffusers.utils import load_image\n\noriginal_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\"\n)\n\nimage = np.array(original_image)\n\nlow_threshold = 100\nhigh_threshold = 200\n\nimage = cv2.Canny(image, low_threshold, high_threshold)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], axis=2)\ncanny_image = Image.fromarray(image)\n```\n\nPass the canny image to the pipeline. Use the `controlnet_conditioning_scale` parameter to determine how much weight to assign to the control.\n\n```py\nimport torch\nfrom diffusers.utils import load_image\nfrom diffusers import FluxControlNetPipeline, FluxControlNetModel\n\ncontrolnet = FluxControlNetModel.from_pretrained(\n    \"InstantX/FLUX.1-dev-Controlnet-Canny\", torch_dtype=torch.bfloat16\n)\npipeline = FluxControlNetPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", controlnet=controlnet, torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nprompt = \"\"\"\nA photorealistic overhead image of a cat reclining sideways in a flamingo pool floatie holding a margarita. 
\nThe cat is floating leisurely in the pool and completely relaxed and happy.\n\"\"\"\n\npipeline(\n    prompt, \n    control_image=canny_image,\n    controlnet_conditioning_scale=0.5,\n    num_inference_steps=50, \n    guidance_scale=3.5,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\" width=\"300\" alt=\"Generated image (prompt only)\"/>\n    <figcaption style=\"text-align: center;\">original image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png\" width=\"300\" alt=\"Control image (Canny edges)\"/>\n    <figcaption style=\"text-align: center;\">canny image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat-generated.png\" width=\"300\" alt=\"Generated image (ControlNet + prompt)\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n\n</hfoption>\n<hfoption id=\"image-to-image\">\n\nGenerate a depth map with a depth estimation pipeline from Transformers.\n\n```py\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom transformers import DPTImageProcessor, DPTForDepthEstimation\nfrom diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL\nfrom diffusers.utils import load_image\n\n\ndepth_estimator = DPTForDepthEstimation.from_pretrained(\"Intel/dpt-hybrid-midas\").to(\"cuda\")\nfeature_extractor = DPTImageProcessor.from_pretrained(\"Intel/dpt-hybrid-midas\")\n\ndef get_depth_map(image):\n    image = feature_extractor(images=image, return_tensors=\"pt\").pixel_values.to(\"cuda\")\n    with torch.no_grad(), torch.autocast(\"cuda\"):\n        depth_map = 
depth_estimator(image).predicted_depth\n\n    depth_map = torch.nn.functional.interpolate(\n        depth_map.unsqueeze(1),\n        size=(1024, 1024),\n        mode=\"bicubic\",\n        align_corners=False,\n    )\n    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)\n    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)\n    depth_map = (depth_map - depth_min) / (depth_max - depth_min)\n    image = torch.cat([depth_map] * 3, dim=1)\n    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]\n    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))\n    return image\n\n# load the source image before estimating its depth\nimage = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\"\n).resize((1024, 1024))\ndepth_image = get_depth_map(image)\n```\n\nPass the depth map to the pipeline. Use the `controlnet_conditioning_scale` parameter to determine how much weight to assign to the control.\n\n```py\ncontrolnet = ControlNetModel.from_pretrained(\n    \"diffusers/controlnet-depth-sdxl-1.0-small\",\n    torch_dtype=torch.float16,\n)\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\npipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    controlnet=controlnet,\n    vae=vae,\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\nprompt = \"\"\"\nA photorealistic overhead image of a cat reclining sideways in a flamingo pool floatie holding a margarita. 
\nThe cat is floating leisurely in the pool and completely relaxed and happy.\n\"\"\"\nimage = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\"\n).resize((1024, 1024))\ncontrolnet_conditioning_scale = 0.5 \npipeline(\n    prompt,\n    image=image,\n    control_image=depth_image,\n    controlnet_conditioning_scale=controlnet_conditioning_scale,\n    strength=0.99,\n    num_inference_steps=100,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\" width=\"300\" alt=\"Generated image (prompt only)\"/>\n    <figcaption style=\"text-align: center;\">original image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_image.png\" width=\"300\" alt=\"Control image (Canny edges)\"/>\n    <figcaption style=\"text-align: center;\">depth map</figcaption>\n  </figure>\n  <figure> \n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_cat.png\" width=\"300\" alt=\"Generated image (ControlNet + prompt)\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n</hfoption>\n<hfoption id=\"inpainting\">\n\nGenerate a mask image and convert it to a tensor to mark the pixels in the original image as masked if the corresponding pixel in the mask image is over a certain threshold.\n\n```py\nimport cv2\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom diffusers.utils import load_image\nfrom diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel\n\ninit_image = load_image(\n    
\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\"\n)\ninit_image = init_image.resize((1024, 1024))\nmask_image = load_image(\n    \"/content/cat_mask.png\"\n)\nmask_image = mask_image.resize((1024, 1024))\n\ndef make_canny_condition(image):\n    image = np.array(image)\n    image = cv2.Canny(image, 100, 200)\n    image = image[:, :, None]\n    image = np.concatenate([image, image, image], axis=2)\n    image = Image.fromarray(image)\n    return image\n\ncontrol_image = make_canny_condition(init_image)\n```\n\nPass the mask and control image to the pipeline. Use the `controlnet_conditioning_scale` parameter to determine how much weight to assign to the control.\n\n```py\ncontrolnet = ControlNetModel.from_pretrained(\n    \"diffusers/controlnet-canny-sdxl-1.0\", torch_dtype=torch.float16\n)\npipeline = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", controlnet=controlnet, torch_dtype=torch.float16\n)\npipeline(\n    \"a cute and fluffy bunny rabbit\",\n    num_inference_steps=100,\n    strength=0.99,\n    controlnet_conditioning_scale=0.5,\n    image=init_image,\n    mask_image=mask_image,\n    control_image=control_image,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\" width=\"300\" alt=\"Generated image (prompt only)\"/>\n    <figcaption style=\"text-align: center;\">original image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat_mask.png\" width=\"300\" alt=\"Control image (Canny edges)\"/>\n    <figcaption style=\"text-align: center;\">mask image</figcaption>\n  </figure>\n  <figure> \n    <img 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_rabbit_inpaint.png\" width=\"300\" alt=\"Generated image (ControlNet + prompt)\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n</hfoption>\n</hfoptions>\n\n## Multi-ControlNet\n\nYou can compose multiple ControlNet conditionings, such as a canny image and a depth map, to create a *MultiControlNet*. For the best results, you should mask conditionings so they don't overlap and experiment with different `controlnet_conditioning_scale` parameters to adjust how much weight is assigned to each control input.\n\nThe example below composes a canny image and depth map.\n\nPass the ControlNets as a list to the pipeline and resize the images to the expected input size.\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL\n\ncontrolnets = [\n    ControlNetModel.from_pretrained(\n        \"diffusers/controlnet-depth-sdxl-1.0-small\", torch_dtype=torch.float16\n    ),\n    ControlNetModel.from_pretrained(\n        \"diffusers/controlnet-canny-sdxl-1.0\", torch_dtype=torch.float16,\n    ),\n]\n\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\npipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", controlnet=controlnets, vae=vae, torch_dtype=torch.float16\n).to(\"cuda\")\n\nprompt = \"\"\"\na relaxed rabbit sitting on a striped towel next to a pool with a tropical drink nearby, \nbright sunny day, vacation scene, 35mm photograph, film, professional, 4k, highly detailed\n\"\"\"\nnegative_prompt = \"lowres, bad anatomy, worst quality, low quality, deformed, ugly\"\n\n# the control images must follow the same order as the ControlNets above (depth, then canny)\nimages = [depth_image.resize((1024, 1024)), canny_image.resize((1024, 1024))]\n\npipeline(\n    prompt,\n    negative_prompt=negative_prompt,\n    image=images,\n    num_inference_steps=100,\n    
controlnet_conditioning_scale=[0.5, 0.5],\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png\" width=\"300\" alt=\"Control image (Canny edges)\"/>\n    <figcaption style=\"text-align: center;\">canny image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/multicontrolnet_depth.png\" width=\"300\" alt=\"Control image (depth map)\"/>\n    <figcaption style=\"text-align: center;\">depth map</figcaption>\n  </figure>\n  <figure> \n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_multi_controlnet.png\" width=\"300\" alt=\"Generated image (ControlNet + prompt)\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n## guess_mode\n\n[Guess mode](https://github.com/lllyasviel/ControlNet/discussions/188) generates an image from **only** the control input (canny edge, depth map, pose, etc.) and without guidance from a prompt. It adjusts the scale of the ControlNet's output residuals by a fixed ratio depending on block depth. 
The earlier `DownBlock` is only scaled by `0.1` and the `MidBlock` is fully scaled by `1.0`.\n\n```py\nimport torch\nfrom diffusers.utils import load_image\nfrom diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel\n\ncontrolnet = ControlNetModel.from_pretrained(\n  \"diffusers/controlnet-canny-sdxl-1.0\", torch_dtype=torch.float16\n)\npipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  controlnet=controlnet,\n  torch_dtype=torch.float16\n).to(\"cuda\")\n\ncanny_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png\")\npipeline(\n  \"\",\n  image=canny_image,\n  guess_mode=True\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png\" width=\"300\" alt=\"Control image (Canny edges)\"/>\n    <figcaption style=\"text-align: center;\">canny image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guess_mode.png\" width=\"300\" alt=\"Generated image (Guess mode)\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>"
  },
  {
    "path": "docs/source/en/using-diffusers/create_a_server.md",
    "content": "\n# Create a server\n\nDiffusers' pipelines can be used as an inference engine for a server. It supports concurrent and multithreaded requests to generate images that may be requested by multiple users at the same time.\n\nThis guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.\n\n\nStart by navigating to the `examples/server` folder and installing all of the dependencies.\n\n```bash\npip install .\npip install -r requirements.txt\n```\n\nLaunch the server with the following command.\n\n```bash\npython server.py\n```\n\nThe server is accessed at http://localhost:8000. You can curl this model with the following command.\n```\ncurl -X POST -H \"Content-Type: application/json\" --data '{\"model\": \"something\", \"prompt\": \"a kitten in front of a fireplace\"}' http://localhost:8000/v1/images/generations\n```\n\nIf you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.\n\n```\nuv pip compile requirements.in -o requirements.txt\n```\n\n\nThe server is built with [FastAPI](https://fastapi.tiangolo.com/async/). 
The endpoint for `v1/images/generations` is shown below.\n```py\n@app.post(\"/v1/images/generations\")\nasync def generate_image(image_input: TextToImageInput):\n    try:\n        loop = asyncio.get_event_loop()\n        scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)\n        pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)\n        generator = torch.Generator(device=\"cuda\")\n        generator.manual_seed(random.randint(0, 10000000))\n        output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))\n        logger.info(f\"output: {output}\")\n        image_url = save_image(output.images[0])\n        return {\"data\": [{\"url\": image_url}]}\n    except Exception as e:\n        if isinstance(e, HTTPException):\n            raise e\n        elif hasattr(e, 'message'):\n            raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())\n        raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())\n```\nThe `generate_image` function is defined as asynchronous with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows that whatever is happening in this function won't necessarily return a result right away. Once it hits some point in the function that it needs to await some other [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the main thread goes back to answering other HTTP requests. 
This is shown in the code below with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword.\n```py\noutput = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))\n```\nAt this point, the execution of the pipeline function is placed onto a [new thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor), and the main thread performs other things until a result is returned from the `pipeline`.\n\nAnother important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal behind this is to avoid loading the underlying model more than once onto the GPU while still allowing for each new request that is running on a separate thread to have its own generator and scheduler. The scheduler, in particular, is not thread-safe, and it will cause errors like: `IndexError: index 21 is out of bounds for dimension 0 with size 21` if you try to use the same scheduler across multiple threads.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/custom_pipeline_overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Community pipelines and components\n\nCommunity pipelines are [`DiffusionPipeline`] classes that are different from the original paper implementation. They provide additional functionality or extend the original pipeline implementation.\n\n> [!TIP]\n> Check out the community pipelines in [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) with inference and training examples for how to use them.\n\nCommunity pipelines are either stored on the Hub or the Diffusers' GitHub repository. Hub pipelines are completely customizable (scheduler, models, pipeline code, etc.) while GitHub pipelines are limited to only the custom pipeline code. Further compare the two community pipeline types in the table below.\n\n|  | GitHub | Hub |\n|---|---|---|\n| Usage | Same. | Same. |\n| Review process | Open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging. This option is slower. | Upload directly to a Hub repository without a review. This is the fastest option. |\n| Visibility | Included in the official Diffusers repository and docs. | Included on your Hub profile and relies on your own usage and promotion to gain visibility. 
|\n\n## custom_pipeline\n\nLoad either community pipeline type by passing the `custom_pipeline` argument to [`~DiffusionPipeline.from_pretrained`].\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3-medium-diffusers\",\n    custom_pipeline=\"pipeline_stable_diffusion_3_instruct_pix2pix\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\n```\n\nAdd the `custom_revision` argument to [`~DiffusionPipeline.from_pretrained`] to load a community pipeline from a specific version (for example, `v0.30.0` or `main`). By default, community pipelines are loaded from the latest stable version of Diffusers.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3-medium-diffusers\",\n    custom_pipeline=\"pipeline_stable_diffusion_3_instruct_pix2pix\",\n    custom_revision=\"main\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\n```\n\n> [!WARNING]\n> While the Hugging Face Hub [scans](https://huggingface.co/docs/hub/security-malware) files, you should still inspect the Hub pipeline code and make sure it is safe.\n\nThere are a few ways to load a community pipeline.\n\n- Pass a path to `custom_pipeline` to load a local community pipeline. The directory must contain a `pipeline.py` file containing the pipeline class.\n\n  ```py\n  import torch\n  from diffusers import DiffusionPipeline\n\n  pipeline = DiffusionPipeline.from_pretrained(\n      \"stabilityai/stable-diffusion-3-medium-diffusers\",\n      custom_pipeline=\"path/to/pipeline_directory\",\n      torch_dtype=torch.float16,\n      device_map=\"cuda\"\n  )\n  ```\n\n- The `custom_pipeline` argument is also supported by [`~DiffusionPipeline.from_pipe`], which is useful for [reusing pipelines](./loading#reuse-a-pipeline) without using additional memory. 
It limits the memory usage to only the largest pipeline loaded.\n\n  ```py\n  import torch\n  from diffusers import DiffusionPipeline\n\n  pipeline_sd = DiffusionPipeline.from_pretrained(\"emilianJR/CyberRealistic_V3\", torch_dtype=torch.float16, device_map=\"cuda\")\n  pipeline_lpw = DiffusionPipeline.from_pipe(\n      pipeline_sd, custom_pipeline=\"lpw_stable_diffusion\", device_map=\"cuda\"\n  )\n  ```\n\n  The [`~DiffusionPipeline.from_pipe`] method is especially useful for loading community pipelines because many of them don't have pretrained weights. Community pipelines generally add a feature on top of an existing pipeline.\n\n## Community components\n\nCommunity components let users build pipelines with custom transformers, UNets, VAEs, and schedulers not supported by Diffusers. These components require Python module implementations. \n\nThis section shows how users can use community components to build a community pipeline using [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) as an example.\n\n1. Load the required components, the scheduler and image processor. 
The text encoder is generally imported from [Transformers](https://huggingface.co/docs/transformers/index).\n\n```python\nfrom transformers import T5Tokenizer, T5EncoderModel, CLIPImageProcessor\nfrom diffusers import DPMSolverMultistepScheduler\n\npipeline_id = \"showlab/show-1-base\"\ntokenizer = T5Tokenizer.from_pretrained(pipeline_id, subfolder=\"tokenizer\")\ntext_encoder = T5EncoderModel.from_pretrained(pipeline_id, subfolder=\"text_encoder\")\nscheduler = DPMSolverMultistepScheduler.from_pretrained(pipeline_id, subfolder=\"scheduler\")\nfeature_extractor = CLIPImageProcessor.from_pretrained(pipeline_id, subfolder=\"feature_extractor\")\n```\n\n> [!WARNING]\n> In steps 2 and 3, the custom [UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) and [pipeline](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) implementation must match the format shown in their files for this example to work.\n\n2. Load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) which is already implemented in [showone_unet_3d_condition.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py). The [`UNet3DConditionModel`] class name is renamed to the custom implementation, `ShowOneUNet3DConditionModel`, because [`UNet3DConditionModel`] already exists in Diffusers. Any components required for `ShowOneUNet3DConditionModel` class should be placed in `showone_unet_3d_condition.py`.\n\n```python\nfrom showone_unet_3d_condition import ShowOneUNet3DConditionModel\n\nunet = ShowOneUNet3DConditionModel.from_pretrained(pipeline_id, subfolder=\"unet\")\n```\n\n3. Load the custom pipeline code (already implemented in [pipeline_t2v_base_pixel.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py)). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. 
Like the custom UNet, any code required for `TextToVideoIFPipeline` should be placed in `pipeline_t2v_base_pixel.py`.\n\nInitialize `TextToVideoIFPipeline` with `ShowOneUNet3DConditionModel`.\n\n```python\nimport torch\nfrom pipeline_t2v_base_pixel import TextToVideoIFPipeline\n\npipeline = TextToVideoIFPipeline(\n    unet=unet,\n    text_encoder=text_encoder,\n    tokenizer=tokenizer,\n    scheduler=scheduler,\n    feature_extractor=feature_extractor,\n    device_map=\"cuda\",\n    torch_dtype=torch.float16\n)\n```\n\n4. Push the pipeline to the Hub to share with the community.\n\n```python\npipeline.push_to_hub(\"custom-t2v-pipeline\")\n```\n\nAfter the pipeline is successfully pushed, make the following changes.\n\n- Change the `_class_name` attribute in [model_index.json](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `\"pipeline_t2v_base_pixel\"` and `\"TextToVideoIFPipeline\"`.\n- Upload `showone_unet_3d_condition.py` to the [unet](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) subfolder.\n- Upload `pipeline_t2v_base_pixel.py` to the pipeline [repository](https://huggingface.co/sayakpaul/show-1-base-with-code/tree/main).\n\nTo run inference, add the `trust_remote_code` argument while initializing the pipeline to handle all the \"magic\" behind the scenes.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"<change-username>/<change-id>\", trust_remote_code=True, torch_dtype=torch.float16\n)\n```\n\n> [!WARNING]\n> As an additional precaution with `trust_remote_code=True`, we strongly encourage passing a commit hash to the `revision` argument in [`~DiffusionPipeline.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust the model owners).\n\n## Resources\n\n- Take a look at Issue [#841](https://github.com/huggingface/diffusers/issues/841) for 
more context about why we're adding community pipelines to help everyone easily share their work without being slowed down.\n- Check out the [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/) repository for an additional example of a community pipeline that also uses the `trust_remote_code` feature."
  },
  {
    "path": "docs/source/en/using-diffusers/depth2img.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Text-guided depth-to-image generation\n\n[[open-in-colab]]\n\nThe [`StableDiffusionDepth2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images. In addition, you can also pass a `depth_map` to preserve the image structure. If no `depth_map` is provided, the pipeline automatically predicts the depth via an integrated [depth-estimation model](https://github.com/isl-org/MiDaS).\n\nStart by creating an instance of the [`StableDiffusionDepth2ImgPipeline`]:\n\n```python\nimport torch\nfrom diffusers import StableDiffusionDepth2ImgPipeline\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = StableDiffusionDepth2ImgPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-depth\",\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n).to(\"cuda\")\n```\n\nNow pass your prompt to the pipeline. 
You can also pass a `negative_prompt` to prevent certain words from guiding how an image is generated:\n\n```python\nurl = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\ninit_image = load_image(url)\nprompt = \"two tigers\"\nnegative_prompt = \"bad, deformed, ugly, bad anatomy\"\nimage = pipeline(prompt=prompt, image=init_image, negative_prompt=negative_prompt, strength=0.7).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n| Input                                                                           | Output                                                                                                                                |\n|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|\n| <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/coco-cats.png\" width=\"500\"/> | <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/depth2img-tigers.png\" width=\"500\"/> |\n"
  },
  {
    "path": "docs/source/en/using-diffusers/diffedit.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DiffEdit\n\n[[open-in-colab]]\n\nImage editing typically requires providing a mask of the area to be edited. DiffEdit automatically generates the mask for you based on a text query, making it easier overall to create a mask without image editing software. The DiffEdit algorithm works in three steps:\n\n1. the diffusion model denoises an image conditioned on some query text and reference text which produces different noise estimates for different areas of the image; the difference is used to infer a mask to identify which area of the image needs to be changed to match the query text\n2. the input image is encoded into latent space with DDIM\n3. the latents are decoded with the diffusion model conditioned on the text query, using the mask as a guide such that pixels outside the mask remain the same as in the input image\n\nThis guide will show you how to use DiffEdit to edit images without manually creating a mask.\n\nBefore you begin, make sure you have the following libraries installed:\n\n```py\n# uncomment to install the necessary libraries in Colab\n#!pip install -q diffusers transformers accelerate\n```\n\nThe [`StableDiffusionDiffEditPipeline`] requires an image mask and a set of partially inverted latents. 
The image mask is generated from the [`~StableDiffusionDiffEditPipeline.generate_mask`] function, and includes two parameters, `source_prompt` and `target_prompt`. These parameters determine what to edit in the image. For example, if you want to change a bowl of *fruits* to a bowl of *pears*, then:\n\n```py\nsource_prompt = \"a bowl of fruits\"\ntarget_prompt = \"a bowl of pears\"\n```\n\nThe partially inverted latents are generated from the [`~StableDiffusionDiffEditPipeline.invert`] function, and it is generally a good idea to include a `prompt` or *caption* describing the image to help guide the inverse latent sampling process. The caption can often be your `source_prompt`, but feel free to experiment with other text descriptions!\n\nLet's load the pipeline, scheduler, inverse scheduler, and enable some optimizations to reduce memory usage:\n\n```py\nimport torch\nfrom diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline\n\npipeline = StableDiffusionDiffEditPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-1\",\n    torch_dtype=torch.float16,\n    safety_checker=None,\n    use_safetensors=True,\n)\npipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\npipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)\npipeline.enable_model_cpu_offload()\npipeline.enable_vae_slicing()\n```\n\nLoad the image to edit:\n\n```py\nfrom diffusers.utils import load_image, make_image_grid\n\nimg_url = \"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"\nraw_image = load_image(img_url).resize((768, 768))\nraw_image\n```\n\nUse the [`~StableDiffusionDiffEditPipeline.generate_mask`] function to generate the image mask. 
You'll need to pass it the `source_prompt` and `target_prompt` to specify what to edit in the image:\n\n```py\nfrom PIL import Image\n\nsource_prompt = \"a bowl of fruits\"\ntarget_prompt = \"a basket of pears\"\nmask_image = pipeline.generate_mask(\n    image=raw_image,\n    source_prompt=source_prompt,\n    target_prompt=target_prompt,\n)\nImage.fromarray((mask_image.squeeze()*255).astype(\"uint8\"), \"L\").resize((768, 768))\n```\n\nNext, create the inverted latents and pass it a caption describing the image:\n\n```py\ninv_latents = pipeline.invert(prompt=source_prompt, image=raw_image).latents\n```\n\nFinally, pass the image mask and inverted latents to the pipeline. The `target_prompt` becomes the `prompt` now, and the `source_prompt` is used as the `negative_prompt`:\n\n```py\noutput_image = pipeline(\n    prompt=target_prompt,\n    mask_image=mask_image,\n    image_latents=inv_latents,\n    negative_prompt=source_prompt,\n).images[0]\nmask_image = Image.fromarray((mask_image.squeeze()*255).astype(\"uint8\"), \"L\").resize((768, 768))\nmake_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">original image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/assets/target.png?raw=true\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">edited image</figcaption>\n  </div>\n</div>\n\n## Generate source and target embeddings\n\nThe source and target embeddings can be automatically generated with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model instead of creating them manually.\n\nLoad the Flan-T5 model and tokenizer from the 🤗 Transformers library:\n\n```py\nimport torch\nfrom transformers 
import AutoTokenizer, T5ForConditionalGeneration\n\ntokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-large\")\nmodel = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-large\", device_map=\"auto\", torch_dtype=torch.float16)\n```\n\nProvide some initial text to prompt the model to generate the source and target prompts.\n\n```py\nsource_concept = \"bowl\"\ntarget_concept = \"basket\"\n\n# parentheses are required so the two adjacent string literals are joined into one prompt\nsource_text = (\n    f\"Provide a caption for images containing a {source_concept}. \"\n    \"The captions should be in English and should be no longer than 150 characters.\"\n)\n\ntarget_text = (\n    f\"Provide a caption for images containing a {target_concept}. \"\n    \"The captions should be in English and should be no longer than 150 characters.\"\n)\n```\n\nNext, create a utility function to generate the prompts:\n\n```py\n@torch.no_grad()\ndef generate_prompts(input_prompt):\n    input_ids = tokenizer(input_prompt, return_tensors=\"pt\").input_ids.to(\"cuda\")\n\n    outputs = model.generate(\n        input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10\n    )\n    return tokenizer.batch_decode(outputs, skip_special_tokens=True)\n\nsource_prompts = generate_prompts(source_text)\ntarget_prompts = generate_prompts(target_text)\nprint(source_prompts)\nprint(target_prompts)\n```\n\n> [!TIP]\n> Check out the [generation strategy](https://huggingface.co/docs/transformers/main/en/generation_strategies) guide if you're interested in learning more about strategies for generating different quality text.\n\nLoad the text encoder model used by the [`StableDiffusionDiffEditPipeline`] to encode the text. 
You'll use the text encoder to compute the text embeddings:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionDiffEditPipeline\n\npipeline = StableDiffusionDiffEditPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-1\", torch_dtype=torch.float16, use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\npipeline.enable_vae_slicing()\n\n@torch.no_grad()\ndef embed_prompts(sentences, tokenizer, text_encoder, device=\"cuda\"):\n    embeddings = []\n    for sent in sentences:\n        text_inputs = tokenizer(\n            sent,\n            padding=\"max_length\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]\n        embeddings.append(prompt_embeds)\n    return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)\n\nsource_embeds = embed_prompts(source_prompts, pipeline.tokenizer, pipeline.text_encoder)\ntarget_embeds = embed_prompts(target_prompts, pipeline.tokenizer, pipeline.text_encoder)\n```\n\nFinally, pass the embeddings to the [`~StableDiffusionDiffEditPipeline.generate_mask`] and [`~StableDiffusionDiffEditPipeline.invert`] functions, and pipeline to generate the image:\n\n```diff\n  from diffusers import DDIMInverseScheduler, DDIMScheduler\n  from diffusers.utils import load_image, make_image_grid\n  from PIL import Image\n\n  pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\n  pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)\n\n  img_url = \"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"\n  raw_image = load_image(img_url).resize((768, 768))\n\n  mask_image = pipeline.generate_mask(\n      image=raw_image,\n-     source_prompt=source_prompt,\n-     target_prompt=target_prompt,\n+     
source_prompt_embeds=source_embeds,\n+     target_prompt_embeds=target_embeds,\n  )\n\n  inv_latents = pipeline.invert(\n-     prompt=source_prompt,\n+     prompt_embeds=source_embeds,\n      image=raw_image,\n  ).latents\n\n  output_image = pipeline(\n      mask_image=mask_image,\n      image_latents=inv_latents,\n-     prompt=target_prompt,\n-     negative_prompt=source_prompt,\n+     prompt_embeds=target_embeds,\n+     negative_prompt_embeds=source_embeds,\n  ).images[0]\n  mask_image = Image.fromarray((mask_image.squeeze()*255).astype(\"uint8\"), \"L\")\n  make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)\n```\n\n## Generate a caption for inversion\n\nWhile you can use the `source_prompt` as a caption to help generate the partially inverted latents, you can also use the [BLIP](https://huggingface.co/docs/transformers/model_doc/blip) model to automatically generate a caption.\n\nLoad the BLIP model and processor from the 🤗 Transformers library:\n\n```py\nimport torch\nfrom transformers import BlipForConditionalGeneration, BlipProcessor\n\nprocessor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-base\")\nmodel = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-base\", torch_dtype=torch.float16, low_cpu_mem_usage=True)\n```\n\nCreate a utility function to generate a caption from the input image:\n\n```py\n@torch.no_grad()\ndef generate_caption(images, caption_generator, caption_processor):\n    text = \"a photograph of\"\n\n    inputs = caption_processor(images, text, return_tensors=\"pt\").to(device=\"cuda\", dtype=caption_generator.dtype)\n    caption_generator.to(\"cuda\")\n    outputs = caption_generator.generate(**inputs, max_new_tokens=128)\n\n    # offload caption generator\n    caption_generator.to(\"cpu\")\n\n    caption = caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]\n    return caption\n```\n\nLoad an input image and generate a caption for it using the 
`generate_caption` function:\n\n```py\nfrom diffusers.utils import load_image\n\nimg_url = \"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"\nraw_image = load_image(img_url).resize((768, 768))\ncaption = generate_caption(raw_image, model, processor)\n```\n\n<div class=\"flex justify-center\">\n    <figure>\n        <img class=\"rounded-xl\" src=\"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"/>\n        <figcaption class=\"text-center\">generated caption: \"a photograph of a bowl of fruit on a table\"</figcaption>\n    </figure>\n</div>\n\nNow you can drop the caption into the [`~StableDiffusionDiffEditPipeline.invert`] function to generate the partially inverted latents!\n"
  },
  {
    "path": "docs/source/en/using-diffusers/dreambooth.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DreamBooth\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method for generating personalized images of a specific instance. It works by fine-tuning the model on 3-5 images of the subject (for example, a cat) that is associated with a unique identifier (`sks cat`). This allows you to use `sks cat` in your prompt to trigger the model to generate images of your cat in different settings, lighting, poses, and styles.\n\nDreamBooth checkpoints are typically a few GBs in size because they contain the full model weights.\n\nLoad the DreamBooth checkpoint with [`~DiffusionPipeline.from_pretrained`] and include the unique identifier in the prompt to activate its generation.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"sd-dreambooth-library/herge-style\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\nprompt = \"A cute sks herge_style brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration\"\npipeline(prompt).images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_dreambooth.png\" />\n</div>"
  },
  {
    "path": "docs/source/en/using-diffusers/guiders.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Guiders\n\n[Classifier-free guidance](https://huggingface.co/papers/2207.12598) steers model generation to better match a prompt and is commonly used to improve generation quality, control, and adherence to prompts. There are different types of guidance methods, and in Diffusers, they are known as *guiders*. Like blocks, it is easy to switch and use different guiders for different use cases without rewriting the pipeline.\n\nThis guide will show you how to switch guiders, adjust guider parameters, and load and share them to the Hub.\n\n## Switching guiders\n\n[`ClassifierFreeGuidance`] is the default guider and is created when a pipeline is initialized with [`~ModularPipelineBlocks.init_pipeline`]. It is created by `from_config`, which means it doesn't require loading specifications from a modular repository. 
A guider won't be listed in `modular_model_index.json`.\n\nUse [`~ModularPipeline.get_component_spec`] to inspect a guider.\n\n```py\nt2i_pipeline.get_component_spec(\"guider\")\nComponentSpec(name='guider', type_hint=<class 'diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance'>, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')\n```\n\nSwitch to a different guider by passing the new guider to [`~ModularPipeline.update_components`].\n\n> [!TIP]\n> Changing guiders will return text letting you know you're changing the guider type.\n> ```bash\n> ModularPipeline.update_components: adding guider with new type: PerturbedAttentionGuidance, previous type: ClassifierFreeGuidance\n> ```\n\n```py\nfrom diffusers import LayerSkipConfig, PerturbedAttentionGuidance\n\nconfig = LayerSkipConfig(indices=[2, 9], fqn=\"mid_block.attentions.0.transformer_blocks\", skip_attention=False, skip_attention_scores=True, skip_ff=False)\nguider = PerturbedAttentionGuidance(\n    guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config\n)\nt2i_pipeline.update_components(guider=guider)\n```\n\nUse [`~ModularPipeline.get_component_spec`] again to verify the guider type is different.\n\n```py\nt2i_pipeline.get_component_spec(\"guider\")\nComponentSpec(name='guider', type_hint=<class 'diffusers.guiders.perturbed_attention_guidance.PerturbedAttentionGuidance'>, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], 
fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')\n```\n\n## Loading custom guiders\n\nGuiders that are already saved on the Hub with a `modular_model_index.json` file are considered a `from_pretrained` component now instead of a `from_config` component.\n\n```json\n{\n  \"guider\": [\n    null,\n    null,\n    {\n      \"repo\": \"YiYiXu/modular-loader-t2i-guider\",\n      \"revision\": null,\n      \"subfolder\": \"pag_guider\",\n      \"type_hint\": [\n        \"diffusers\",\n        \"PerturbedAttentionGuidance\"\n      ],\n      \"variant\": null\n    }\n  ]\n}\n```\n\nThe guider is only created after calling [`~ModularPipeline.load_components`] based on the loading specification in `modular_model_index.json`.\n\n```py\nt2i_pipeline = t2i_blocks.init_pipeline(\"YiYiXu/modular-doc-guider\")\n# not created during init\nassert t2i_pipeline.guider is None\nt2i_pipeline.load_components()\n# loaded as PAG guider\nt2i_pipeline.guider\n```\n\n\n## Changing guider parameters\n\nThe guider parameters can be adjusted with the [`~ComponentSpec.create`] method and [`~ModularPipeline.update_components`]. 
The example below changes the `guidance_scale` value.\n\n\n```py\nguider_spec = t2i_pipeline.get_component_spec(\"guider\")\nguider = guider_spec.create(guidance_scale=10)\nt2i_pipeline.update_components(guider=guider)\n```\n\n## Uploading custom guiders\n\nCall the [`~utils.PushToHubMixin.push_to_hub`] method on a custom guider to share it to the Hub.\n\n```py\nguider.push_to_hub(\"YiYiXu/modular-loader-t2i-guider\", subfolder=\"pag_guider\")\n```\n\nTo make this guider available to the pipeline, either modify the `modular_model_index.json` file or use the [`~ModularPipeline.update_components`] method.\n\n<hfoptions id=\"upload\">\n<hfoption id=\"modular_model_index.json\">\n\nEdit the `modular_model_index.json` file and add a loading specification for the guider by pointing to a folder containing the guider config.\n\n```json\n{\n  \"guider\": [\n    \"diffusers\",\n    \"PerturbedAttentionGuidance\",\n    {\n      \"repo\": \"YiYiXu/modular-loader-t2i-guider\",\n      \"revision\": null,\n      \"subfolder\": \"pag_guider\",\n      \"type_hint\": [\n        \"diffusers\",\n        \"PerturbedAttentionGuidance\"\n      ],\n      \"variant\": null\n    }\n  ],\n```\n\n</hfoption>\n<hfoption id=\"update_components\">\n\nChange the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and use [`~ModularPipeline.update_components`] to update the guider and component specifications as well as the pipeline config.\n\n> [!TIP]\n> Changing the creation method will return text letting you know you're changing the creation type to `from_pretrained`.\n> ```bash\n> ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained.\n> ```\n\n```py\nguider_spec = t2i_pipeline.get_component_spec(\"guider\")\nguider_spec.default_creation_method=\"from_pretrained\"\nguider_spec.pretrained_model_name_or_path=\"YiYiXu/modular-loader-t2i-guider\"\nguider_spec.subfolder=\"pag_guider\"\npag_guider = 
guider_spec.load()\nt2i_pipeline.update_components(guider=pag_guider)\n```\n\nTo make it the default guider for a pipeline, call [`~utils.PushToHubMixin.push_to_hub`]. This is an optional step and not necessary if you are only experimenting locally.\n\n```py\nt2i_pipeline.push_to_hub(\"YiYiXu/modular-doc-guider\")\n```\n\n</hfoption>\n</hfoptions>\n"
  },
  {
    "path": "docs/source/en/using-diffusers/helios.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n# Helios\n\n[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are:\n\n- Without commonly used anti-drifting strategies (e.g., self-forcing, error-banks, keyframe sampling, or inverted sampling), Helios generates minute-scale videos with high quality and strong coherence.\n- Without standard acceleration techniques (e.g., KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU.\n- Helios introduces optimizations that improve both training and inference throughput while reducing memory consumption. 
These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models.\n\nThis guide shows how to load the Helios checkpoints and showcases text-to-video, image-to-video, and interactive video generation.\n\n## Load Model Checkpoints\n\nModel weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.\n\n```python\nimport torch\nfrom diffusers import HeliosPipeline, HeliosPyramidPipeline\nfrom huggingface_hub import snapshot_download\n\n# For Best Quality\nsnapshot_download(repo_id=\"BestWishYsh/Helios-Base\", local_dir=\"BestWishYsh/Helios-Base\")\npipe = HeliosPipeline.from_pretrained(\"BestWishYsh/Helios-Base\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# Intermediate Weight\nsnapshot_download(repo_id=\"BestWishYsh/Helios-Mid\", local_dir=\"BestWishYsh/Helios-Mid\")\npipe = HeliosPyramidPipeline.from_pretrained(\"BestWishYsh/Helios-Mid\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# For Best Efficiency\nsnapshot_download(repo_id=\"BestWishYsh/Helios-Distilled\", local_dir=\"BestWishYsh/Helios-Distilled\")\npipe = HeliosPyramidPipeline.from_pretrained(\"BestWishYsh/Helios-Distilled\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n```\n\n## Text-to-Video Showcases\n\n<table>\n  <tr>\n    <th style=\"text-align: center;\">Prompt</th>\n    <th style=\"text-align: center;\">Generated Video</th>\n  </tr>\n  <tr>\n    <td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. 
Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.\n    </small></td>\n    <td>\n      <video width=\"4000\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n  <tr>\n    <td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.\n    </small></td>\n    <td>\n      <video width=\"4000\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n</table>\n\n## Image-to-Video Showcases\n\n<table>\n  <tr>\n    <th style=\"text-align: center;\">Image</th>\n    <th style=\"text-align: center;\">Prompt</th>\n    <th style=\"text-align: center;\">Generated Video</th>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg\" style=\"height: auto; width: 300px;\"></td>\n    <td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. 
The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads \"KIA 626,\" and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.\n    </small></td>\n    <td>\n      <video width=\"2000\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg\" style=\"height: auto; width: 300px;\"></td>\n    <td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. 
A close-up shot from a slightly elevated perspective.\n    </small></td>\n    <td>\n      <video width=\"2000\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n</table>\n\n## Interactive-Video Showcases\n\n<table>\n  <tr>\n    <th style=\"text-align: center;\">Prompt</th>\n    <th style=\"text-align: center;\">Generated Video</th>\n  </tr>\n  <tr>\n    <td><small>The prompt can be found <a href=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt\">here</a></small></td>\n    <td>\n      <video width=\"680\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n  <tr>\n    <td><small>The prompt can be found <a href=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt\">here</a></small></td>\n    <td>\n      <video width=\"680\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n</table>\n\n## Resources\n\nLearn more about Helios with the following resources.\n- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features.\n- Read the research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379), for more details.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/image_quality.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# FreeU\n\n[FreeU](https://hf.co/papers/2309.11497) improves image details by rebalancing the UNet's backbone and skip connection weights. The skip connections can cause the model to overlook some of the backbone semantics which may lead to unnatural image details in the generated image. This technique does not require any additional training and can be applied on the fly during inference for tasks like image-to-image and text-to-video.\n\nUse the [`~pipelines.StableDiffusionMixin.enable_freeu`] method on your pipeline and configure the scaling factors for the backbone (`b1` and `b2`) and skip connections (`s1` and `s2`). The number after each scaling factor corresponds to the stage in the UNet where the factor is applied. 
Take a look at the [FreeU](https://github.com/ChenyangSi/FreeU#parameters) repository for reference hyperparameters for different models.\n\n<hfoptions id=\"freeu\">\n<hfoption id=\"Stable Diffusion v1-5\">\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, safety_checker=None\n).to(\"cuda\")\npipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)\ngenerator = torch.Generator(device=\"cpu\").manual_seed(33)\nprompt = \"\"\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdv15-no-freeu.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">FreeU disabled</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdv15-freeu.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">FreeU enabled</figcaption>\n  </div>\n</div>\n\n</hfoption>\n<hfoption id=\"Stable Diffusion v2-1\">\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-1\", torch_dtype=torch.float16, safety_checker=None\n).to(\"cuda\")\npipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.4, b2=1.6)\ngenerator = torch.Generator(device=\"cpu\").manual_seed(80)\nprompt = \"A squirrel eating a burger\"\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdv21-no-freeu.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">FreeU disabled</figcaption>\n  </div>\n  <div>\n    
<img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdv21-freeu.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">FreeU enabled</figcaption>\n  </div>\n</div>\n\n</hfoption>\n<hfoption id=\"Stable Diffusion XL\">\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)\ngenerator = torch.Generator(device=\"cpu\").manual_seed(13)\nprompt = \"A squirrel eating a burger\"\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-no-freeu.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">FreeU disabled</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-freeu.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">FreeU enabled</figcaption>\n  </div>\n</div>\n\n</hfoption>\n<hfoption id=\"Zeroscope\">\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import export_to_video\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"damo-vilab/text-to-video-ms-1.7b\", torch_dtype=torch.float16\n).to(\"cuda\")\n# values come from https://github.com/lyn-rgb/FreeU_Diffusers#video-pipelines\npipeline.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2)\nprompt = \"Confident teddy bear surfer rides the wave in the tropics\"\ngenerator = torch.Generator(device=\"cpu\").manual_seed(47)\nvideo_frames = pipeline(prompt, generator=generator).frames[0]\nexport_to_video(video_frames, \"teddy_bear.mp4\", fps=10)\n```\n\n<div class=\"flex 
gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/video-no-freeu.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">FreeU disabled</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/video-freeu.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">FreeU enabled</figcaption>\n  </div>\n</div>\n\n</hfoption>\n</hfoptions>\n\nCall the [`~pipelines.StableDiffusionMixin.disable_freeu`] method to disable FreeU.\n\n```py\npipeline.disable_freeu()\n```\n"
  },
  {
    "path": "docs/source/en/using-diffusers/img2img.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Image-to-image\n\n[[open-in-colab]]\n\nImage-to-image is similar to [text-to-image](conditional_image_generation), but in addition to a prompt, you can also pass an initial image as a starting point for the diffusion process. The initial image is encoded to latent space and noise is added to it. Then the latent diffusion model takes a prompt and the noisy latent image, predicts the added noise, and removes the predicted noise from the initial latent image to get the new latent image. Lastly, a decoder decodes the new latent image back into an image.\n\nWith 🤗 Diffusers, this is as easy as 1-2-3:\n\n1. 
Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16, use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n```\n\n> [!TIP]\n> Throughout this guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] to save memory and increase inference speed. If you're using PyTorch 2.0 or higher, you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it already uses PyTorch's native [scaled-dot product attention](../optimization/fp16#scaled-dot-product-attention).\n\n2. Load an image to pass to the pipeline:\n\n```py\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\")\n```\n\n3. 
Pass a prompt and image to the pipeline to generate an image:\n\n```py\nprompt = \"cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k\"\nimage = pipeline(prompt, image=init_image).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n## Popular models\n\nThe most popular image-to-image models are [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). The results from the Stable Diffusion and Kandinsky models vary due to their architecture differences and training process; you can generally expect SDXL to produce higher quality images than Stable Diffusion v1.5. Let's take a quick look at how to use each of these models and compare their results.\n\n### Stable Diffusion v1.5\n\nStable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. 
Then you can pass a prompt and the image to the pipeline to generate a new image:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"\ninit_image = load_image(url)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\n# pass prompt and image to pipeline\nimage = pipeline(prompt, image=init_image).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdv1.5.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n### Stable Diffusion XL (SDXL)\n\nSDXL is a more powerful version of the Stable Diffusion model. It uses a larger base model, and an additional refiner model to increase the quality of the base model's output. 
Read the [SDXL](sdxl) guide for a more detailed walkthrough of how to use this model, and other techniques it uses to produce high quality images.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png\"\ninit_image = load_image(url)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\n# pass prompt and image to pipeline\nimage = pipeline(prompt, image=init_image, strength=0.5).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n### Kandinsky 2.2\n\nThe Kandinsky model is different from the Stable Diffusion models because it uses an image prior model to create image embeddings. 
The embeddings help create a better alignment between text and images, allowing the latent diffusion model to generate better images.\n\nThe simplest way to use Kandinsky 2.2 is:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16, use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"\ninit_image = load_image(url)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\n# pass prompt and image to pipeline\nimage = pipeline(prompt, image=init_image).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-kandinsky.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n## Configure pipeline parameters\n\nThere are several important parameters you can configure in the pipeline that'll affect the image generation process and image quality. 
Let's take a closer look at what these parameters do and how changing them affects the output.\n\n### Strength\n\n`strength` is one of the most important parameters to consider and it'll have a huge impact on your generated image. It determines how much the generated image resembles the initial image. In other words:\n\n- 📈 a higher `strength` value gives the model more \"creativity\" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored\n- 📉 a lower `strength` value means the generated image is more similar to the initial image\n\nThe `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"\ninit_image = load_image(url)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\n# pass prompt and image to pipeline\nimage = pipeline(prompt, image=init_image, strength=0.8).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-0.4.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">strength = 0.4</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-0.6.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">strength = 0.6</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-strength-1.0.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">strength = 1.0</figcaption>\n  </div>\n</div>\n\n### Guidance scale\n\nThe `guidance_scale` parameter is used to control how closely aligned the generated image and text prompt are. A higher `guidance_scale` value means your generated image is more aligned with the prompt, while a lower `guidance_scale` value means your generated image has more space to deviate from the prompt.\n\nYou can combine `guidance_scale` with `strength` for even more precise control over how expressive the model is. 
For example, combine a high `strength` with a high `guidance_scale` for maximum creativity, or combine a low `strength` with a low `guidance_scale` to generate an image that resembles the initial image but is not as strictly bound to the prompt.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"\ninit_image = load_image(url)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\n# pass prompt and image to pipeline\nimage = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-0.1.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">guidance_scale = 0.1</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-3.0.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">guidance_scale = 5.0</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-guidance-7.5.png\"/>\n    <figcaption class=\"mt-2 
text-center text-sm text-gray-500\">guidance_scale = 10.0</figcaption>\n  </div>\n</div>\n\n### Negative prompt\n\nA negative prompt conditions the model to *not* include things in an image, and it can be used to improve image quality or modify an image. For example, you can improve image quality by including negative prompts like \"poor details\" or \"blurry\" to encourage the model to generate a higher quality image. Or you can modify an image by specifying things to exclude from an image.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"\ninit_image = load_image(url)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\nnegative_prompt = \"ugly, deformed, disfigured, poor details, bad anatomy\"\n\n# pass prompt and image to pipeline\nimage = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-negative-1.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">negative_prompt = \"ugly, deformed, disfigured, poor details, bad anatomy\"</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-negative-2.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">negative_prompt = \"jungle\"</figcaption>\n  </div>\n</div>\n\n## Chained image-to-image pipelines\n\nThere are some other interesting ways you can use an image-to-image pipeline aside from just generating an image (although that is pretty cool too). You can take it a step further and chain it with other pipelines.\n\n### Text-to-image-to-image\n\nChaining a text-to-image and image-to-image pipeline allows you to generate an image from text and use the generated image as the initial image for the image-to-image pipeline. This is useful if you want to generate an image entirely from scratch. For example, let's chain a Stable Diffusion and a Kandinsky model.\n\nStart by generating an image with the text-to-image pipeline:\n\n```py\nfrom diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image\nimport torch\nfrom diffusers.utils import make_image_grid\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\ntext2image = pipeline(\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\").images[0]\ntext2image\n```\n\nNow you can pass this generated image to the image-to-image pipeline:\n\n```py\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16, use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher 
installed\npipeline.enable_xformers_memory_efficient_attention()\n\nimage2image = pipeline(\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", image=text2image).images[0]\nmake_image_grid([text2image, image2image], rows=1, cols=2)\n```\n\n### Image-to-image-to-image\n\nYou can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generating short GIFs, restoring color to an image, or restoring missing areas of an image.\n\nStart by generating an image:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"\ninit_image = load_image(url)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\n# pass prompt and image to pipeline\nimage = pipeline(prompt, image=init_image, output_type=\"latent\").images[0]\n```\n\n> [!TIP]\n> It is important to specify `output_type=\"latent\"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. 
This only works if the chained pipelines are using the same VAE.\n\nPass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):\n\n```py\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"ogkalu/Comic-Diffusion\", torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# need to include the token \"charliebo artstyle\" in the prompt to use this checkpoint\nimage = pipeline(\"Astronaut in a jungle, charliebo artstyle\", image=image, output_type=\"latent\").images[0]\n```\n\nRepeat one more time to generate the final image in a [pixel art style](https://huggingface.co/kohbanye/pixel-art-style):\n\n```py\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"kohbanye/pixel-art-style\", torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# need to include the token \"pixelartstyle\" in the prompt to use this checkpoint\nimage = pipeline(\"Astronaut in a jungle, pixelartstyle\", image=image).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n### Image-to-upscaler-to-super-resolution\n\nAnother way you can chain your image-to-image pipeline is with an upscaler and super-resolution pipeline to really increase the level of details in an image.\n\nStart with an image-to-image pipeline:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", 
use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"\ninit_image = load_image(url)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\n# pass prompt and image to pipeline\nimage_1 = pipeline(prompt, image=init_image, output_type=\"latent\").images[0]\n```\n\n> [!TIP]\n> It is important to specify `output_type=\"latent\"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.\n\nChain it to an upscaler pipeline to increase the image resolution:\n\n```py\nfrom diffusers import StableDiffusionLatentUpscalePipeline\n\nupscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(\n    \"stabilityai/sd-x2-latent-upscaler\", torch_dtype=torch.float16, use_safetensors=True\n)\nupscaler.enable_model_cpu_offload()\nupscaler.enable_xformers_memory_efficient_attention()\n\nimage_2 = upscaler(prompt, image=image_1).images[0]\n```\n\nFinally, chain it to a super-resolution pipeline to further enhance the resolution:\n\n```py\nfrom diffusers import StableDiffusionUpscalePipeline\n\nsuper_res = StableDiffusionUpscalePipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-x4-upscaler\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\nsuper_res.enable_model_cpu_offload()\nsuper_res.enable_xformers_memory_efficient_attention()\n\nimage_3 = super_res(prompt, image=image_2).images[0]\nmake_image_grid([init_image, image_3.resize((512, 512))], rows=1, cols=2)\n```\n\n## Control image generation\n\nTrying to generate an image that looks exactly the way you want can be difficult, which is why controlled generation 
techniques and models are so useful. While you can use the `negative_prompt` to partially control image generation, there are more robust methods like prompt weighting and ControlNets.\n\n### Prompt weighting\n\nPrompt weighting allows you to scale the representation of each concept in a prompt. For example, in a prompt like \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", you can choose to increase or decrease the embeddings of \"astronaut\" and \"jungle\". The [Compel](https://github.com/damian0815/compel) library provides a simple syntax for adjusting prompt weights and generating the embeddings. You can learn how to create the embeddings in the [Prompt weighting](weighted_prompts) guide.\n\n[`AutoPipelineForImage2Image`] accepts a `prompt_embeds` parameter (and `negative_prompt_embeds` if you're using a negative prompt) where you pass the embeddings in place of the `prompt` parameter.\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nimport torch\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# init_image is the initial image loaded with load_image earlier in this guide\nimage = pipeline(\n    prompt_embeds=prompt_embeds,  # generated from Compel\n    negative_prompt_embeds=negative_prompt_embeds,  # generated from Compel\n    image=init_image,\n).images[0]\n```\n\n### ControlNet\n\nControlNets provide a more flexible and accurate way to control image generation because you can use an additional conditioning image. The conditioning image can be a canny image, depth map, image segmentation, or even scribbles! 
Whatever type of conditioning image you choose, the ControlNet generates an image that preserves the information in it.\n\nFor example, let's condition an image with a depth map to keep the spatial information in the image.\n\n```py\nfrom diffusers.utils import load_image, make_image_grid\n\n# prepare image\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"\ninit_image = load_image(url)\ninit_image = init_image.resize((958, 960)) # resize to depth image dimensions\ndepth_image = load_image(\"https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png\")\nmake_image_grid([init_image, depth_image], rows=1, cols=2)\n```\n\nLoad a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:\n\n```py\nfrom diffusers import ControlNetModel, AutoPipelineForImage2Image\nimport torch\n\ncontrolnet = ControlNetModel.from_pretrained(\"lllyasviel/control_v11f1p_sd15_depth\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True)\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", controlnet=controlnet, torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n```\n\nNow generate a new image conditioned on the depth map, initial image, and prompt:\n\n```py\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\nimage_control_net = pipeline(prompt, image=init_image, control_image=depth_image).images[0]\nmake_image_grid([init_image, depth_image, image_control_net], rows=1, cols=3)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">depth image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-controlnet.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">ControlNet image</figcaption>\n  </div>\n</div>\n\nLet's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion) to the image generated from the ControlNet by chaining it with an image-to-image pipeline:\n\n```py\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"nitrosocke/elden-ring-diffusion\", torch_dtype=torch.float16,\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\nprompt = \"elden ring style astronaut in a jungle\" # include the token \"elden ring style\" in the prompt\nnegative_prompt = \"ugly, deformed, disfigured, poor details, bad anatomy\"\n\nimage_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image_control_net, strength=0.45, guidance_scale=10.5).images[0]\nmake_image_grid([init_image, depth_image, image_control_net, image_elden_ring], rows=2, cols=2)\n```\n\n<div class=\"flex justify-center\">\n  <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-elden-ring.png\">\n</div>\n\n## Optimize\n\nRunning diffusion models is computationally expensive and intensive, but with a few optimization 
tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled dot-product attention](../optimization/fp16#scaled-dot-product-attention) or [xFormers](../optimization/xformers) (you can use one or the other, but there's no need to use both). You can also keep the pipeline components on the CPU and move each model to the GPU only while it is in use.\n\n```diff\n+ pipeline.enable_model_cpu_offload()\n+ pipeline.enable_xformers_memory_efficient_attention()\n```\n\nWith [`torch.compile`](../optimization/fp16#torchcompile), you can boost your inference speed even more by wrapping your UNet with it:\n\n```py\npipeline.unet = torch.compile(pipeline.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\nTo learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Accelerate inference](../optimization/fp16) guides.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/inference_with_lcm.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Latent Consistency Model\n\n[[open-in-colab]]\n\n[Latent Consistency Models (LCMs)](https://hf.co/papers/2310.04378) enable fast high-quality image generation by directly predicting the reverse diffusion process in the latent rather than pixel space. In other words, LCMs try to predict the noiseless image from the noisy image in contrast to typical diffusion models that iteratively remove noise from the noisy image. By avoiding the iterative sampling process, LCMs are able to generate high-quality images in 2-4 steps instead of 20-30 steps.\n\nLCMs are distilled from pretrained models, which requires ~32 hours of A100 compute. To speed this up, [LCM-LoRAs](https://hf.co/papers/2311.05556) train a [LoRA adapter](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora), which has far fewer parameters to train than the full model. The LCM-LoRA can be plugged into a diffusion model once it has been trained.\n\nThis guide will show you how to use LCMs and LCM-LoRAs for fast inference on tasks like text-to-image, image-to-image, and inpainting, and how to use them with other adapters like ControlNet or T2I-Adapter.\n\n> [!TIP]\n> LCMs and LCM-LoRAs are available for Stable Diffusion v1.5, Stable Diffusion XL, and the SSD-1B model. 
You can find their checkpoints on the [Latent Consistency](https://hf.co/collections/latent-consistency/latent-consistency-models-weights-654ce61a95edd6dffccef6a8) Collections.\n\n## Text-to-image\n\n<hfoptions id=\"lcm-text2img\">\n<hfoption id=\"LCM\">\n\nTo use LCMs, you need to load the LCM checkpoint for your supported model into [`UNet2DConditionModel`] and replace the scheduler with the [`LCMScheduler`]. Then you can use the pipeline as usual, and pass a text prompt to generate an image in just 4 steps.\n\nA couple of notes to keep in mind when using LCMs are:\n\n* Typically, batch size is doubled inside the pipeline for classifier-free guidance. But LCM applies guidance with guidance embeddings and doesn't need to double the batch size, which leads to faster inference. The downside is that negative prompts don't work with LCM because they don't have any effect on the denoising process.\n* The ideal range for `guidance_scale` is [3., 13.] because that is what the UNet was trained with. 
However, disabling `guidance_scale` with a value of 1.0 is also effective in most cases.\n\n```python\nfrom diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, LCMScheduler\nimport torch\n\nunet = UNet2DConditionModel.from_pretrained(\n    \"latent-consistency/lcm-sdxl\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", unet=unet, torch_dtype=torch.float16, variant=\"fp16\",\n).to(\"cuda\")\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\nprompt = \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\"\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=8.0\n).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_full_sdxl_t2i.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"LCM-LoRA\">\n\nTo use LCM-LoRAs, you need to replace the scheduler with the [`LCMScheduler`] and load the LCM-LoRA weights with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. Then you can use the pipeline as usual, and pass a text prompt to generate an image in just 4 steps.\n\nA couple of notes to keep in mind when using LCM-LoRAs are:\n\n* Typically, batch size is doubled inside the pipeline for classifier-free guidance. But LCM applies guidance with guidance embeddings and doesn't need to double the batch size, which leads to faster inference. The downside is that negative prompts don't work with LCM because they don't have any effect on the denoising process.\n* You could use guidance with LCM-LoRAs, but it is very sensitive to high `guidance_scale` values and can lead to artifacts in the generated image. 
The best values we've found are between [1.0, 2.0].\n* Replace [stabilityai/stable-diffusion-xl-base-1.0](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0) with any finetuned model. For example, try using the [animagine-xl](https://huggingface.co/Linaqruf/animagine-xl) checkpoint to generate anime images with SDXL.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, LCMScheduler\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    variant=\"fp16\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\npipe.load_lora_weights(\"latent-consistency/lcm-lora-sdxl\")\n\nprompt = \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\"\ngenerator = torch.manual_seed(42)\nimage = pipe(\n    prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=1.0\n).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdxl_t2i.png\"/>\n</div>\n\n</hfoption>\n</hfoptions>\n\n## Image-to-image\n\n<hfoptions id=\"lcm-img2img\">\n<hfoption id=\"LCM\">\n\nTo use LCMs for image-to-image, you need to load the LCM checkpoint for your supported model into [`UNet2DConditionModel`] and replace the scheduler with the [`LCMScheduler`]. 
Then you can use the pipeline as usual, and pass a text prompt and initial image to generate an image in just 4 steps.\n\n> [!TIP]\n> Experiment with different values for `num_inference_steps`, `strength`, and `guidance_scale` to get the best results.\n\n```python\nimport torch\nfrom diffusers import AutoPipelineForImage2Image, UNet2DConditionModel, LCMScheduler\nfrom diffusers.utils import load_image\n\nunet = UNet2DConditionModel.from_pretrained(\n    \"SimianLuo/LCM_Dreamshaper_v7\",\n    subfolder=\"unet\",\n    torch_dtype=torch.float16,\n)\n\npipe = AutoPipelineForImage2Image.from_pretrained(\n    \"Lykon/dreamshaper-7\",\n    unet=unet,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n).to(\"cuda\")\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\")\nprompt = \"Astronauts in a jungle, cold color palette, muted colors, detailed, 8k\"\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt,\n    image=init_image,\n    num_inference_steps=4,\n    guidance_scale=7.5,\n    strength=0.5,\n    generator=generator\n).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm-img2img.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n</hfoption>\n<hfoption id=\"LCM-LoRA\">\n\nTo use LCM-LoRAs for image-to-image, you need to replace the scheduler with the [`LCMScheduler`] and load the LCM-LoRA weights with the 
[`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. Then you can use the pipeline as usual, and pass a text prompt and initial image to generate an image in just 4 steps.\n\n> [!TIP]\n> Experiment with different values for `num_inference_steps`, `strength`, and `guidance_scale` to get the best results.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image, LCMScheduler\nfrom diffusers.utils import make_image_grid, load_image\n\npipe = AutoPipelineForImage2Image.from_pretrained(\n    \"Lykon/dreamshaper-7\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n).to(\"cuda\")\n\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(\"latent-consistency/lcm-lora-sdv1-5\")\n\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\")\nprompt = \"Astronauts in a jungle, cold color palette, muted colors, detailed, 8k\"\n\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt,\n    image=init_image,\n    num_inference_steps=4,\n    guidance_scale=1,\n    strength=0.6,\n    generator=generator\n).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm-lora-img2img.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n</hfoption>\n</hfoptions>\n\n## Inpainting\n\nTo use LCM-LoRAs for inpainting, you need to replace the scheduler with the [`LCMScheduler`] and load the LCM-LoRA weights with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. 
Then you can use the pipeline as usual, and pass a text prompt, initial image, and mask image to generate an image in just 4 steps.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting, LCMScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\npipe = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n).to(\"cuda\")\n\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(\"latent-consistency/lcm-lora-sdv1-5\")\n\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt=prompt,\n    image=init_image,\n    mask_image=mask_image,\n    generator=generator,\n    num_inference_steps=4,\n    guidance_scale=4,\n).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm-lora-inpaint.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n## Adapters\n\nLCMs are compatible with adapters like LoRA, ControlNet, T2I-Adapter, and AnimateDiff. 
You can bring the speed of LCMs to these adapters to generate images in a certain style or condition the model on another input like a canny image.\n\n### LoRA\n\n[LoRA](../tutorials/using_peft_for_inference) adapters can be rapidly finetuned to learn a new style from just a few images and plugged into a pretrained model to generate images in that style.\n\n<hfoptions id=\"lcm-lora\">\n<hfoption id=\"LCM\">\n\nLoad the LCM checkpoint for your supported model into [`UNet2DConditionModel`] and replace the scheduler with the [`LCMScheduler`]. Then you can use the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method to load the LoRA weights into the LCM and generate a styled image in a few steps.\n\n```python\nfrom diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, LCMScheduler\nimport torch\n\nunet = UNet2DConditionModel.from_pretrained(\n    \"latent-consistency/lcm-sdxl\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", unet=unet, torch_dtype=torch.float16, variant=\"fp16\",\n).to(\"cuda\")\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\npipe.load_lora_weights(\"TheLastBen/Papercut_SDXL\", weight_name=\"papercut.safetensors\", adapter_name=\"papercut\")\n\nprompt = \"papercut, a cute fox\"\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=8.0\n).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_full_sdx_lora_mix.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"LCM-LoRA\">\n\nReplace the scheduler with the [`LCMScheduler`]. Then you can use the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method to load the LCM-LoRA weights and the style LoRA you want to use. 
Combine both LoRA adapters with the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method and generate a styled image in a few steps.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, LCMScheduler\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    variant=\"fp16\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(\"latent-consistency/lcm-lora-sdxl\", adapter_name=\"lcm\")\npipe.load_lora_weights(\"TheLastBen/Papercut_SDXL\", weight_name=\"papercut.safetensors\", adapter_name=\"papercut\")\n\npipe.set_adapters([\"lcm\", \"papercut\"], adapter_weights=[1.0, 0.8])\n\nprompt = \"papercut, a cute fox\"\ngenerator = torch.manual_seed(0)\nimage = pipe(prompt, num_inference_steps=4, guidance_scale=1, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdx_lora_mix.png\"/>\n</div>\n\n</hfoption>\n</hfoptions>\n\n### ControlNet\n\n[ControlNets](./controlnet) are adapters that can be trained on a variety of inputs like canny edges, pose estimation, or depth. A ControlNet can be inserted into the pipeline to provide additional conditioning and control to the model for more accurate generation.\n\nYou can find additional ControlNet models trained on other inputs in [lllyasviel's](https://hf.co/lllyasviel) repository.\n\n<hfoptions id=\"lcm-controlnet\">\n<hfoption id=\"LCM\">\n\nLoad a ControlNet model trained on canny images with [`ControlNetModel`]. Then you can load an LCM model into [`StableDiffusionControlNetPipeline`] and replace the scheduler with the [`LCMScheduler`]. 
Now pass the canny image to the pipeline and generate an image.\n\n> [!TIP]\n> Experiment with different values for `num_inference_steps`, `controlnet_conditioning_scale`, `cross_attention_kwargs`, and `guidance_scale` to get the best results.\n\n```python\nimport torch\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel, LCMScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\nimage = load_image(\n    \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\"\n).resize((512, 512))\n\nimage = np.array(image)\n\nlow_threshold = 100\nhigh_threshold = 200\n\nimage = cv2.Canny(image, low_threshold, high_threshold)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], axis=2)\ncanny_image = Image.fromarray(image)\n\ncontrolnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\npipe = StableDiffusionControlNetPipeline.from_pretrained(\n    \"SimianLuo/LCM_Dreamshaper_v7\",\n    controlnet=controlnet,\n    torch_dtype=torch.float16,\n    safety_checker=None,\n).to(\"cuda\")\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    \"the mona lisa\",\n    image=canny_image,\n    num_inference_steps=4,\n    generator=generator,\n).images[0]\nmake_image_grid([canny_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_full_sdv1-5_controlnet.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"LCM-LoRA\">\n\nLoad a ControlNet model trained on canny images and pass it to the [`ControlNetModel`]. Then you can load a Stable Diffusion v1.5 model into [`StableDiffusionControlNetPipeline`] and replace the scheduler with the [`LCMScheduler`]. 
Use the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method to load the LCM-LoRA weights, and pass the canny image to the pipeline and generate an image.\n\n> [!TIP]\n> Experiment with different values for `num_inference_steps`, `controlnet_conditioning_scale`, `cross_attention_kwargs`, and `guidance_scale` to get the best results.\n\n```py\nimport torch\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel, LCMScheduler\nfrom diffusers.utils import load_image\n\nimage = load_image(\n    \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\"\n).resize((512, 512))\n\nimage = np.array(image)\n\nlow_threshold = 100\nhigh_threshold = 200\n\nimage = cv2.Canny(image, low_threshold, high_threshold)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], axis=2)\ncanny_image = Image.fromarray(image)\n\ncontrolnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\npipe = StableDiffusionControlNetPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    controlnet=controlnet,\n    torch_dtype=torch.float16,\n    safety_checker=None,\n    variant=\"fp16\"\n).to(\"cuda\")\n\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(\"latent-consistency/lcm-lora-sdv1-5\")\n\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    \"the mona lisa\",\n    image=canny_image,\n    num_inference_steps=4,\n    guidance_scale=1.5,\n    controlnet_conditioning_scale=0.8,\n    cross_attention_kwargs={\"scale\": 1},\n    generator=generator,\n).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_sdv1-5_controlnet.png\"/>\n</div>\n\n</hfoption>\n</hfoptions>\n\n### 
T2I-Adapter\n\n[T2I-Adapter](./t2i_adapter) is an even more lightweight adapter than ControlNet that provides an additional input to condition a pretrained model with. It is faster than ControlNet but the results may be slightly worse.\n\nYou can find additional T2I-Adapter checkpoints trained on other inputs in [TencentARC's](https://hf.co/TencentARC) repository.\n\n<hfoptions id=\"lcm-t2i\">\n<hfoption id=\"LCM\">\n\nLoad a T2IAdapter trained on canny images and pass it to the [`StableDiffusionXLAdapterPipeline`]. Then load an LCM checkpoint into [`UNet2DConditionModel`] and replace the scheduler with the [`LCMScheduler`]. Now pass the canny image to the pipeline and generate an image.\n\n```python\nimport torch\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\nfrom diffusers import StableDiffusionXLAdapterPipeline, UNet2DConditionModel, T2IAdapter, LCMScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\n# detect the canny map in low resolution to avoid high-frequency details\nimage = load_image(\n    \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\"\n).resize((384, 384))\n\nimage = np.array(image)\n\nlow_threshold = 100\nhigh_threshold = 200\n\nimage = cv2.Canny(image, low_threshold, high_threshold)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], axis=2)\ncanny_image = Image.fromarray(image).resize((1024, 1216))\n\nadapter = T2IAdapter.from_pretrained(\"TencentARC/t2i-adapter-canny-sdxl-1.0\", torch_dtype=torch.float16, variant=\"fp16\").to(\"cuda\")\n\nunet = UNet2DConditionModel.from_pretrained(\n    \"latent-consistency/lcm-sdxl\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe = StableDiffusionXLAdapterPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    unet=unet,\n    adapter=adapter,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n).to(\"cuda\")\n\npipe.scheduler = 
LCMScheduler.from_config(pipe.scheduler.config)\n\nprompt = \"the mona lisa, 4k picture, high quality\"\nnegative_prompt = \"extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured\"\n\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    image=canny_image,\n    num_inference_steps=4,\n    guidance_scale=5,\n    adapter_conditioning_scale=0.8,\n    adapter_conditioning_factor=1,\n    generator=generator,\n).images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm-t2i.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"LCM-LoRA\">\n\nLoad a T2IAdapter trained on canny images and pass it to the [`StableDiffusionXLAdapterPipeline`]. Replace the scheduler with the [`LCMScheduler`], and use the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method to load the LCM-LoRA weights. 
Pass the canny image to the pipeline and generate an image.\n\n```py\nimport torch\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\nfrom diffusers import StableDiffusionXLAdapterPipeline, UNet2DConditionModel, T2IAdapter, LCMScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\n# detect the canny map in low resolution to avoid high-frequency details\nimage = load_image(\n    \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\"\n).resize((384, 384))\n\nimage = np.array(image)\n\nlow_threshold = 100\nhigh_threshold = 200\n\nimage = cv2.Canny(image, low_threshold, high_threshold)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], axis=2)\ncanny_image = Image.fromarray(image).resize((1024, 1024))\n\nadapter = T2IAdapter.from_pretrained(\"TencentARC/t2i-adapter-canny-sdxl-1.0\", torch_dtype=torch.float16, variant=\"fp16\").to(\"cuda\")\n\npipe = StableDiffusionXLAdapterPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    adapter=adapter,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n).to(\"cuda\")\n\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(\"latent-consistency/lcm-lora-sdxl\")\n\nprompt = \"the mona lisa, 4k picture, high quality\"\nnegative_prompt = \"extra digit, fewer digits, cropped, worst quality, low quality, glitch, deformed, mutated, ugly, disfigured\"\n\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    image=canny_image,\n    num_inference_steps=4,\n    guidance_scale=1.5,\n    adapter_conditioning_scale=0.8,\n    adapter_conditioning_factor=1,\n    generator=generator,\n).images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm-lora-t2i.png\"/>\n</div>\n\n</hfoption>\n</hfoptions>\n\n### 
AnimateDiff\n\n[AnimateDiff](../api/pipelines/animatediff) is an adapter that adds motion to an image. It can be used with most Stable Diffusion models, effectively turning them into \"video generation\" models. Generating good results with a video model usually requires generating multiple frames (16-24), which can be very slow with a regular Stable Diffusion model. LCM-LoRA can speed up this process by only taking 4-8 steps for each frame.\n\nLoad an [`AnimateDiffPipeline`] and pass a [`MotionAdapter`] to it. Then replace the scheduler with the [`LCMScheduler`], load the LCM-LoRA and motion LoRA weights, and combine both LoRA adapters with the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method. Now you can pass a prompt to the pipeline and generate an animated image.\n\n```py\nimport torch\nfrom diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler, LCMScheduler\nfrom diffusers.utils import export_to_gif\n\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5\")\npipe = AnimateDiffPipeline.from_pretrained(\n    \"frankjoshua/toonyou_beta6\",\n    motion_adapter=adapter,\n).to(\"cuda\")\n\n# set scheduler\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\n# load LCM-LoRA\npipe.load_lora_weights(\"latent-consistency/lcm-lora-sdv1-5\", adapter_name=\"lcm\")\npipe.load_lora_weights(\"guoyww/animatediff-motion-lora-zoom-in\", weight_name=\"diffusion_pytorch_model.safetensors\", adapter_name=\"motion-lora\")\n\npipe.set_adapters([\"lcm\", \"motion-lora\"], adapter_weights=[0.55, 1.2])\n\nprompt = \"best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress\"\ngenerator = torch.manual_seed(0)\nframes = pipe(\n    prompt=prompt,\n    num_inference_steps=5,\n    guidance_scale=1.25,\n    cross_attention_kwargs={\"scale\": 1},\n    num_frames=24,\n    generator=generator\n).frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\n<div class=\"flex justify-center\">\n    <img 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm-lora-animatediff.gif\"/>\n</div>\n"
  },
  {
    "path": "docs/source/en/using-diffusers/inference_with_tcd_lora.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Trajectory Consistency Distillation-LoRA\n\nTrajectory Consistency Distillation (TCD) enables a model to generate higher quality and more detailed images with fewer steps. Moreover, owing to the effective error mitigation during the distillation process, TCD demonstrates superior performance even under conditions of large inference steps.\n\nThe major advantages of TCD are:\n\n- Better than Teacher: TCD demonstrates superior generative quality at both small and large inference steps and exceeds the performance of [DPM-Solver++(2S)](../api/schedulers/multistep_dpm_solver) with Stable Diffusion XL (SDXL). There is no additional discriminator or LPIPS supervision included during TCD training.\n\n- Flexible Inference Steps: The inference steps for TCD sampling can be freely adjusted without adversely affecting the image quality.\n\n- Freely change detail level: During inference, the level of detail in the image can be adjusted with a single hyperparameter, *gamma*.\n\n> [!TIP]\n> For more technical details of TCD, please refer to the [paper](https://huggingface.co/papers/2402.19159) or official [project page](https://mhh0318.github.io/tcd/).\n\nFor large models like SDXL, TCD is trained with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) to reduce memory usage. 
This is also useful because you can reuse LoRAs between different finetuned models, as long as they share the same base model, without further training.\n\n\n\nThis guide will show you how to perform inference with TCD-LoRAs for a variety of tasks like text-to-image and inpainting, as well as how you can easily combine TCD-LoRAs with other adapters. Choose one of the supported base models and its corresponding TCD-LoRA checkpoint from the table below to get started.\n\n| Base model                                                                                      | TCD-LoRA checkpoint                                            |\n|-------------------------------------------------------------------------------------------------|----------------------------------------------------------------|\n| [stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)                  | [TCD-SD15](https://huggingface.co/h1t/TCD-SD15-LoRA)           |\n| [stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base)       | [TCD-SD21-base](https://huggingface.co/h1t/TCD-SD21-base-LoRA) |\n| [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | [TCD-SDXL](https://huggingface.co/h1t/TCD-SDXL-LoRA)           |\n\n\nMake sure you have [PEFT](https://github.com/huggingface/peft) installed for better LoRA support.\n\n```bash\npip install -U peft\n```\n\n## General tasks\n\nIn this guide, let's use the [`StableDiffusionXLPipeline`] and the [`TCDScheduler`]. Use the [`~StableDiffusionPipeline.load_lora_weights`] method to load the SDXL-compatible TCD-LoRA weights.\n\nA few tips to keep in mind for TCD-LoRA inference are to:\n\n- Keep the `num_inference_steps` between 4 and 50\n- Set `eta` (used to control stochasticity at each step) between 0 and 1. 
You should use a higher `eta` when increasing the number of inference steps, but the downside is that a larger `eta` in [`TCDScheduler`] leads to blurrier images. A value of 0.3 is recommended to produce good results.\n\n<hfoptions id=\"tasks\">\n<hfoption id=\"text-to-image\">\n\n```python\nimport torch\nfrom diffusers import StableDiffusionXLPipeline, TCDScheduler\n\ndevice = \"cuda\"\nbase_model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\ntcd_lora_id = \"h1t/TCD-SDXL-LoRA\"\n\npipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant=\"fp16\").to(device)\npipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(tcd_lora_id)\npipe.fuse_lora()\n\nprompt = \"Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.\"\n\nimage = pipe(\n    prompt=prompt,\n    num_inference_steps=4,\n    guidance_scale=0,\n    eta=0.3,\n    generator=torch.Generator(device=device).manual_seed(0),\n).images[0]\n```\n\n![](https://github.com/jabir-zheng/TCD/raw/main/assets/demo_image.png)\n\n</hfoption>\n\n<hfoption id=\"inpainting\">\n\n```python\nimport torch\nfrom diffusers import AutoPipelineForInpainting, TCDScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\ndevice = \"cuda\"\nbase_model_id = \"diffusers/stable-diffusion-xl-1.0-inpainting-0.1\"\ntcd_lora_id = \"h1t/TCD-SDXL-LoRA\"\n\npipe = AutoPipelineForInpainting.from_pretrained(base_model_id, torch_dtype=torch.float16, variant=\"fp16\").to(device)\npipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(tcd_lora_id)\npipe.fuse_lora()\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = 
\"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = load_image(img_url).resize((1024, 1024))\nmask_image = load_image(mask_url).resize((1024, 1024))\n\nprompt = \"a tiger sitting on a park bench\"\n\nimage = pipe(\n  prompt=prompt,\n  image=init_image,\n  mask_image=mask_image,\n  num_inference_steps=8,\n  guidance_scale=0,\n  eta=0.3,\n  strength=0.99,  # make sure to use `strength` below 1.0\n  generator=torch.Generator(device=device).manual_seed(0),\n).images[0]\n\ngrid_image = make_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n![](https://github.com/jabir-zheng/TCD/raw/main/assets/inpainting_tcd.png)\n\n\n</hfoption>\n</hfoptions>\n\n## Community models\n\nTCD-LoRA also works with many community finetuned models and plugins. For example, load the [animagine-xl-3.0](https://huggingface.co/cagliostrolab/animagine-xl-3.0) checkpoint which is a community finetuned version of SDXL for generating anime images.\n\n```python\nimport torch\nfrom diffusers import StableDiffusionXLPipeline, TCDScheduler\n\ndevice = \"cuda\"\nbase_model_id = \"cagliostrolab/animagine-xl-3.0\"\ntcd_lora_id = \"h1t/TCD-SDXL-LoRA\"\n\npipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant=\"fp16\").to(device)\npipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(tcd_lora_id)\npipe.fuse_lora()\n\nprompt = \"A man, clad in a meticulously tailored military uniform, stands with unwavering resolve. The uniform boasts intricate details, and his eyes gleam with determination. 
Strands of vibrant, windswept hair peek out from beneath the brim of his cap.\"\n\nimage = pipe(\n    prompt=prompt,\n    num_inference_steps=8,\n    guidance_scale=0,\n    eta=0.3,\n    generator=torch.Generator(device=device).manual_seed(0),\n).images[0]\n```\n\n![](https://github.com/jabir-zheng/TCD/raw/main/assets/animagine_xl.png)\n\nTCD-LoRA also supports other LoRAs trained on different styles. For example, let's load the [TheLastBen/Papercut_SDXL](https://huggingface.co/TheLastBen/Papercut_SDXL) LoRA and combine it with the TCD-LoRA using the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method.\n\n> [!TIP]\n> Check out the [Merge LoRAs](../tutorials/using_peft_for_inference#merge) guide to learn more about efficient merging methods.\n\n```python\nimport torch\nfrom diffusers import StableDiffusionXLPipeline, TCDScheduler\n\ndevice = \"cuda\"\nbase_model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\ntcd_lora_id = \"h1t/TCD-SDXL-LoRA\"\nstyled_lora_id = \"TheLastBen/Papercut_SDXL\"\n\npipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant=\"fp16\").to(device)\npipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(tcd_lora_id, adapter_name=\"tcd\")\npipe.load_lora_weights(styled_lora_id, adapter_name=\"style\")\npipe.set_adapters([\"tcd\", \"style\"], adapter_weights=[1.0, 1.0])\n\nprompt = \"papercut of a winter mountain, snow\"\n\nimage = pipe(\n    prompt=prompt,\n    num_inference_steps=4,\n    guidance_scale=0,\n    eta=0.3,\n    generator=torch.Generator(device=device).manual_seed(0),\n).images[0]\n```\n\n![](https://github.com/jabir-zheng/TCD/raw/main/assets/styled_lora.png)\n\n\n## Adapters\n\nTCD-LoRA is very versatile, and it can be combined with other adapter types like ControlNets, IP-Adapter, and AnimateDiff.\n\n<hfoptions id=\"adapters\">\n<hfoption id=\"ControlNet\">\n\n### Depth ControlNet\n\n```python\nimport torch\nimport 
numpy as np\nfrom PIL import Image\nfrom transformers import DPTImageProcessor, DPTForDepthEstimation\nfrom diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, TCDScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\ndevice = \"cuda\"\ndepth_estimator = DPTForDepthEstimation.from_pretrained(\"Intel/dpt-hybrid-midas\").to(device)\nfeature_extractor = DPTImageProcessor.from_pretrained(\"Intel/dpt-hybrid-midas\")\n\ndef get_depth_map(image):\n    image = feature_extractor(images=image, return_tensors=\"pt\").pixel_values.to(device)\n    with torch.no_grad(), torch.autocast(device):\n        depth_map = depth_estimator(image).predicted_depth\n\n    depth_map = torch.nn.functional.interpolate(\n        depth_map.unsqueeze(1),\n        size=(1024, 1024),\n        mode=\"bicubic\",\n        align_corners=False,\n    )\n    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)\n    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)\n    depth_map = (depth_map - depth_min) / (depth_max - depth_min)\n    image = torch.cat([depth_map] * 3, dim=1)\n\n    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]\n    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))\n    return image\n\nbase_model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\ncontrolnet_id = \"diffusers/controlnet-depth-sdxl-1.0\"\ntcd_lora_id = \"h1t/TCD-SDXL-LoRA\"\n\ncontrolnet = ControlNetModel.from_pretrained(\n    controlnet_id,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n    base_model_id,\n    controlnet=controlnet,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe.enable_model_cpu_offload()\n\npipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(tcd_lora_id)\npipe.fuse_lora()\n\nprompt = \"stormtrooper lecture, photorealistic\"\n\nimage = 
load_image(\"https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png\")\ndepth_image = get_depth_map(image)\n\ncontrolnet_conditioning_scale = 0.5  # recommended for good generalization\n\nimage = pipe(\n    prompt,\n    image=depth_image,\n    num_inference_steps=4,\n    guidance_scale=0,\n    eta=0.3,\n    controlnet_conditioning_scale=controlnet_conditioning_scale,\n    generator=torch.Generator(device=device).manual_seed(0),\n).images[0]\n\ngrid_image = make_image_grid([depth_image, image], rows=1, cols=2)\n```\n\n![](https://github.com/jabir-zheng/TCD/raw/main/assets/controlnet_depth_tcd.png)\n\n### Canny ControlNet\n```python\nimport torch\nfrom diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, TCDScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\ndevice = \"cuda\"\nbase_model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\ncontrolnet_id = \"diffusers/controlnet-canny-sdxl-1.0\"\ntcd_lora_id = \"h1t/TCD-SDXL-LoRA\"\n\ncontrolnet = ControlNetModel.from_pretrained(\n    controlnet_id,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n    base_model_id,\n    controlnet=controlnet,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe.enable_model_cpu_offload()\n\npipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(tcd_lora_id)\npipe.fuse_lora()\n\nprompt = \"ultrarealistic shot of a furry blue bird\"\n\ncanny_image = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png\")\n\ncontrolnet_conditioning_scale = 0.5  # recommended for good generalization\n\nimage = pipe(\n    prompt,\n    image=canny_image,\n    num_inference_steps=4,\n    guidance_scale=0,\n    eta=0.3,\n    controlnet_conditioning_scale=controlnet_conditioning_scale,\n    
generator=torch.Generator(device=device).manual_seed(0),\n).images[0]\n\ngrid_image = make_image_grid([canny_image, image], rows=1, cols=2)\n```\n![](https://github.com/jabir-zheng/TCD/raw/main/assets/controlnet_canny_tcd.png)\n\n> [!TIP]\n> The inference parameters in this example might not work for every input, so we recommend trying different values for `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale`, and `cross_attention_kwargs`, and choosing the combination that works best.\n\n</hfoption>\n<hfoption id=\"IP-Adapter\">\n\nThis example shows how to use the TCD-LoRA with the [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter/tree/main) and SDXL.\n\n```python\nimport torch\nfrom diffusers import StableDiffusionXLPipeline, TCDScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\n# `ip_adapter` comes from the IP-Adapter repository linked above, not from Diffusers\nfrom ip_adapter import IPAdapterXL\n\ndevice = \"cuda\"\nbase_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\nimage_encoder_path = \"sdxl_models/image_encoder\"\nip_ckpt = \"sdxl_models/ip-adapter_sdxl.bin\"\ntcd_lora_id = \"h1t/TCD-SDXL-LoRA\"\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n    base_model_path,\n    torch_dtype=torch.float16,\n    variant=\"fp16\"\n)\npipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)\n\npipe.load_lora_weights(tcd_lora_id)\npipe.fuse_lora()\n\nip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)\n\nref_image = load_image(\"https://raw.githubusercontent.com/tencent-ailab/IP-Adapter/main/assets/images/woman.png\").resize((512, 512))\n\nprompt = \"best quality, high quality, wearing sunglasses\"\n\nimage = ip_model.generate(\n    pil_image=ref_image,\n    prompt=prompt,\n    scale=0.5,\n    num_samples=1,\n    num_inference_steps=4,\n    guidance_scale=0,\n    eta=0.3,\n    seed=0,\n)[0]\n\ngrid_image = make_image_grid([ref_image, image], rows=1, 
cols=2)\n```\n\n![](https://github.com/jabir-zheng/TCD/raw/main/assets/ip_adapter.png)\n\n\n\n</hfoption>\n<hfoption id=\"AnimateDiff\">\n\n[`AnimateDiff`] allows animating images using Stable Diffusion models. TCD-LoRA can substantially accelerate the process without degrading image quality, and the resulting animations tend to look clearer.\n\n```python\nimport torch\nfrom diffusers import AnimateDiffPipeline, MotionAdapter, TCDScheduler\nfrom diffusers.utils import export_to_gif\n\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5\")\npipe = AnimateDiffPipeline.from_pretrained(\n    \"frankjoshua/toonyou_beta6\",\n    motion_adapter=adapter,\n).to(\"cuda\")\n\n# set TCDScheduler\npipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)\n\n# load TCD LoRA\npipe.load_lora_weights(\"h1t/TCD-SD15-LoRA\", adapter_name=\"tcd\")\npipe.load_lora_weights(\"guoyww/animatediff-motion-lora-zoom-in\", weight_name=\"diffusion_pytorch_model.safetensors\", adapter_name=\"motion-lora\")\n\npipe.set_adapters([\"tcd\", \"motion-lora\"], adapter_weights=[1.0, 1.2])\n\nprompt = \"best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress\"\ngenerator = torch.manual_seed(0)\nframes = pipe(\n    prompt=prompt,\n    num_inference_steps=5,\n    guidance_scale=0,\n    cross_attention_kwargs={\"scale\": 1},\n    num_frames=24,\n    eta=0.3,\n    generator=generator\n).frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\n![](https://github.com/jabir-zheng/TCD/raw/main/assets/animation_example.gif)\n\n</hfoption>\n</hfoptions>"
  },
  {
    "path": "docs/source/en/using-diffusers/inpaint.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Inpainting\n\n[[open-in-colab]]\n\nInpainting replaces or edits specific areas of an image. This makes it a useful tool for image restoration like removing defects and artifacts, or even replacing an image area with something entirely new. Inpainting relies on a mask to determine which regions of an image to fill in; the area to inpaint is represented by white pixels and the area to keep is represented by black pixels. The white pixels are filled in by the prompt.\n\nWith 🤗 Diffusers, here is how you can do inpainting:\n\n1. Load an inpainting checkpoint with the [`AutoPipelineForInpainting`] class. 
This'll automatically detect the appropriate pipeline class to load based on the checkpoint:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-decoder-inpaint\", torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n```\n\n> [!TIP]\n> You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, it's not necessary to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/fp16#scaled-dot-product-attention).\n\n2. Load the base and mask images:\n\n```py\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n```\n\n3. 
Create a prompt to inpaint the image with and pass it to the pipeline with the base and mask images:\n\n```py\nprompt = \"a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k\"\nnegative_prompt = \"bad anatomy, deformed, ugly, disfigured\"\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">base image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">mask image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-cat.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n## Create a mask image\n\nThroughout this guide, the mask image is provided in all of the code examples for convenience. You can inpaint on your own images, but you'll need to create a mask image for it. Use the Space below to easily create a mask image.\n\nUpload a base image to inpaint on and use the sketch tool to draw a mask. Once you're done, click **Run** to generate and download the mask image.\n\n<iframe\n  src=\"https://stevhliu-inpaint-mask-maker.hf.space\"\n  frameborder=\"0\"\n  width=\"850\"\n  height=\"450\"\n></iframe>\n\n### Mask blur\n\nThe [`~VaeImageProcessor.blur`] method provides an option for how to blend the original image and inpaint area. The amount of blur is determined by the `blur_factor` parameter. 
Increasing the `blur_factor` increases the amount of blur applied to the mask edges, softening the transition between the original image and inpaint area. A low or zero `blur_factor` preserves the sharper edges of the mask.\n\nTo use this, create a blurred mask with the image processor.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image\nfrom PIL import Image\n\npipeline = AutoPipelineForInpainting.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to('cuda')\n\nmask = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png\")\nblurred_mask = pipeline.mask_processor.blur(mask, blur_factor=33)\nblurred_mask\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">mask with no blur</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/mask_blurred.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">mask with blur applied</figcaption>\n  </div>\n</div>\n\n## Popular models\n\n[Stable Diffusion Inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2 Inpainting](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.\n\n### Stable Diffusion Inpainting\n\nStable Diffusion Inpainting is a latent diffusion model finetuned for inpainting on 512x512 images. 
It is a good starting point because it is relatively fast and generates good quality images. To use this model for inpainting, you'll need to pass a prompt, base and mask image to the pipeline:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\ngenerator = torch.Generator(\"cuda\").manual_seed(92)\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n### Stable Diffusion XL (SDXL) Inpainting\n\nSDXL is a larger and more powerful version of Stable Diffusion v1.5. This model can follow a two-stage model process (though each model can also be used alone); the base model generates an image, and a refiner model takes that image and further enhances its details and quality. 
Take a look at the [SDXL](sdxl) guide for a more comprehensive overview of how to use SDXL and configure its parameters.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"diffusers/stable-diffusion-xl-1.0-inpainting-0.1\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\ngenerator = torch.Generator(\"cuda\").manual_seed(92)\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n### Kandinsky 2.2 Inpainting\n\nThe Kandinsky model family is similar to SDXL because it uses two models as well; the image prior model creates image embeddings, and the diffusion model generates images from them. 
You can load the image prior and diffusion model separately, but the easiest way to use Kandinsky 2.2 is to load it into the [`AutoPipelineForInpainting`] class which uses the [`KandinskyV22InpaintCombinedPipeline`] under the hood.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-decoder-inpaint\", torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\ngenerator = torch.Generator(\"cuda\").manual_seed(92)\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">base image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-sdv1.5.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Stable Diffusion Inpainting</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-sdxl.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Stable Diffusion XL Inpainting</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-kandinsky.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Kandinsky 2.2 Inpainting</figcaption>\n  </div>\n</div>\n\n## Non-inpaint specific checkpoints\n\n\nSo far, this guide has used inpaint specific checkpoints such as [stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting). But you can also use regular checkpoints like [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). Let's compare the results of the two checkpoints.\n\nThe image on the left is generated from a regular checkpoint, and the image on the right is from an inpaint checkpoint. You'll immediately notice the image on the left is not as clean, and you can still see the outline of the area the model is supposed to inpaint. 
The image on the right is much cleaner and the inpainted area appears more natural.\n\n<hfoptions id=\"regular-specific\">\n<hfoption id=\"stable-diffusion-v1-5/stable-diffusion-v1-5\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\ngenerator = torch.Generator(\"cuda\").manual_seed(92)\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n</hfoption>\n<hfoption id=\"stable-diffusion-v1-5/stable-diffusion-inpainting\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = 
load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\ngenerator = torch.Generator(\"cuda\").manual_seed(92)\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n</hfoption>\n</hfoptions>\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-inpaint-specific.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">stable-diffusion-v1-5/stable-diffusion-v1-5</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-specific.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">stable-diffusion-v1-5/stable-diffusion-inpainting</figcaption>\n  </div>\n</div>\n\nHowever, for more basic tasks like erasing an object from an image (like the rocks in the road for example), a regular checkpoint yields pretty good results. 
There isn't as noticeable a difference between the regular and inpaint checkpoints.\n\n<hfoptions id=\"inpaint\">\n<hfoption id=\"stable-diffusion-v1-5/stable-diffusion-v1-5\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/road-mask.png\")\n\nimage = pipeline(prompt=\"road\", image=init_image, mask_image=mask_image).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n</hfoption>\n<hfoption id=\"stable-diffusion-v1-5/stable-diffusion-inpaint\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/road-mask.png\")\n\nimage = pipeline(prompt=\"road\", 
image=init_image, mask_image=mask_image).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n</hfoption>\n</hfoptions>\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/regular-inpaint-basic.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">stable-diffusion-v1-5/stable-diffusion-v1-5</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/specific-inpaint-basic.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">stable-diffusion-v1-5/stable-diffusion-inpainting</figcaption>\n  </div>\n</div>\n\nThe trade-off of using a non-inpaint specific checkpoint is that the overall image quality may be lower, but it generally tends to preserve the unmasked area (that is why you can see the mask outline). The inpaint specific checkpoints are intentionally trained to generate higher quality inpainted images, and that includes creating a more natural transition between the masked and unmasked areas. 
As a result, these checkpoints are more likely to change your unmasked area.\n\nIf preserving the unmasked area is important for your task, you can use the [`VaeImageProcessor.apply_overlay`] method to force the unmasked area of an image to remain the same at the expense of some more unnatural transitions between the masked and unmasked areas.\n\n```py\nimport PIL\nimport numpy as np\nimport torch\n\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\ndevice = \"cuda\"\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\"\n)\npipeline = pipeline.to(device)\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = load_image(img_url).resize((512, 512))\nmask_image = load_image(mask_url).resize((512, 512))\n\nprompt = \"Face of a yellow cat, high resolution, sitting on a park bench\"\nrepainted_image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]\nrepainted_image.save(\"repainted_image.png\")\n\nunmasked_unchanged_image = pipeline.image_processor.apply_overlay(mask_image, init_image, repainted_image)\nunmasked_unchanged_image.save(\"force_unmasked_unchanged.png\")\nmake_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)\n```\n\n## Configure pipeline parameters\n\nImage features - like quality and \"creativity\" - are dependent on pipeline parameters. Knowing what these parameters do is important for getting the results you want. 
Let's take a look at the most important parameters and see how changing them affects the output.\n\n### Strength\n\n`strength` is a measure of how much noise is added to the base image, which influences how similar the output is to the base image.\n\n* 📈 a high `strength` value means more noise is added to an image and the denoising process takes longer, but you'll get higher quality images that are more different from the base image\n* 📉 a low `strength` value means less noise is added to an image and the denoising process is faster, but the image quality may not be as great and the generated image resembles the base image more\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.6).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-strength-0.6.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">strength = 
0.6</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-strength-0.8.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">strength = 0.8</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-strength-1.0.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">strength = 1.0</figcaption>\n  </div>\n</div>\n\n### Guidance scale\n\n`guidance_scale` affects how aligned the text prompt and generated image are.\n\n* 📈 a high `guidance_scale` value means the prompt and generated image are closely aligned, so the output is a stricter interpretation of the prompt\n* 📉 a low `guidance_scale` value means the prompt and generated image are more loosely aligned, so the output may be more varied from the prompt\n\nYou can use `strength` and `guidance_scale` together for more control over how expressive the model is. 
For example, a combination of high `strength` and `guidance_scale` values gives the model the most creative freedom.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, guidance_scale=2.5).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-guidance-2.5.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">guidance_scale = 2.5</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-guidance-7.5.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">guidance_scale = 7.5</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-guidance-12.5.png\"/>\n    <figcaption 
class=\"mt-2 text-center text-sm text-gray-500\">guidance_scale = 12.5</figcaption>\n  </div>\n</div>\n\n### Negative prompt\n\nA negative prompt assumes the opposite role of a prompt; it guides the model away from generating certain things in an image. This is useful for quickly improving image quality and preventing the model from generating things you don't want.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nnegative_prompt = \"bad architecture, unstable, poor details, blurry\"\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n<div class=\"flex justify-center\">\n  <figure>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-negative.png\" />\n    <figcaption class=\"text-center\">negative_prompt = \"bad architecture, unstable, poor details, blurry\"</figcaption>\n  </figure>\n</div>\n\n### Padding mask crop\n\nA method for increasing the inpainting image quality is to use the 
[`padding_mask_crop`](https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline.__call__.padding_mask_crop) parameter. When enabled, this option crops the masked area with some user-specified padding and it'll also crop the same area from the original image. Both the image and mask are upscaled to a higher resolution for inpainting, and then overlaid on the original image. This is a quick and easy way to improve image quality without using a separate pipeline like [`StableDiffusionUpscalePipeline`].\n\nAdd the `padding_mask_crop` parameter to the pipeline call and set it to the desired padding value.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image\nfrom PIL import Image\n\ngenerator = torch.Generator(device='cuda').manual_seed(0)\npipeline = AutoPipelineForInpainting.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to('cuda')\n\nbase = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png\")\nmask = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore_mask.png\")\n\nimage = pipeline(\"boat\", image=base, mask_image=mask, strength=0.75, generator=generator, padding_mask_crop=32).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/baseline_inpaint.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">default inpaint image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/padding_mask_crop_inpaint.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">inpaint image with `padding_mask_crop` enabled</figcaption>\n  </div>\n</div>\n\n## Chained 
inpainting pipelines\n\n[`AutoPipelineForInpainting`] can be chained with other 🤗 Diffusers pipelines to edit their outputs. This is often useful for improving the output quality from your other diffusion pipelines, and if you're using multiple pipelines, it can be more memory-efficient to chain them together to keep the outputs in latent space and reuse the same pipeline components.\n\n### Text-to-image-to-inpaint\n\nChaining a text-to-image and inpainting pipeline allows you to inpaint the generated image, and you don't have to provide a base image to begin with. This makes it convenient to edit your favorite text-to-image outputs without having to generate an entirely new image.\n\nStart with the text-to-image pipeline to create a castle:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\ntext2image = pipeline(\"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\").images[0]\n```\n\nLoad the mask image of the output from above:\n\n```py\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_text-chain-mask.png\")\n```\n\nAnd let's inpaint the masked area with a waterfall:\n\n```py\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-decoder-inpaint\", torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher 
installed\npipeline.enable_xformers_memory_efficient_attention()\n\nprompt = \"digital painting of a fantasy waterfall, cloudy\"\nimage = pipeline(prompt=prompt, image=text2image, mask_image=mask_image).images[0]\nmake_image_grid([text2image, mask_image, image], rows=1, cols=3)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-text-chain.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">text-to-image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-text-chain-out.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">inpaint</figcaption>\n  </div>\n</div>\n\n### Inpaint-to-image-to-image\n\nYou can also chain an inpainting pipeline before another pipeline like image-to-image or an upscaler to improve the quality.\n\nBegin by inpainting an image:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting, AutoPipelineForImage2Image\nfrom diffusers.utils import load_image, make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly 
detailed, 8k\"\nimage_inpainting = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]\n\n# resize image to 1024x1024 for SDXL\nimage_inpainting = image_inpainting.resize((1024, 1024))\n```\n\nNow let's pass the image to another inpainting pipeline with SDXL's refiner model to enhance the image details and quality:\n\n```py\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\nimage = pipeline(prompt=prompt, image=image_inpainting, mask_image=mask_image, output_type=\"latent\").images[0]\n```\n\n> [!TIP]\n> It is important to specify `output_type=\"latent\"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. For example, in the [Text-to-image-to-inpaint](#text-to-image-to-inpaint) section, Kandinsky 2.2 uses a different VAE class than the Stable Diffusion model so it won't work. But if you use Stable Diffusion v1.5 for both pipelines, then you can keep everything in latent space because they both use [`AutoencoderKL`].\n\nFinally, you can pass this image to an image-to-image pipeline to put the finishing touches on it. 
It is more efficient to use the [`~AutoPipelineForImage2Image.from_pipe`] method to reuse the existing pipeline components, and avoid unnecessarily loading all the pipeline components into memory again.\n\n```py\npipeline = AutoPipelineForImage2Image.from_pipe(pipeline)\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\nimage = pipeline(prompt=prompt, image=image).images[0]\nmake_image_grid([init_image, mask_image, image_inpainting, image], rows=2, cols=2)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-to-image-chain.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">inpaint</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-to-image-final.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">image-to-image</figcaption>\n  </div>\n</div>\n\nImage-to-image and inpainting are actually very similar tasks. Image-to-image generates a new image that resembles the existing provided image. Inpainting does the same thing, but it only transforms the image area defined by the mask and the rest of the image is unchanged. 
You can think of inpainting as a more precise tool for making specific changes and image-to-image has a broader scope for making more sweeping changes.\n\n## Control image generation\n\nGetting an image to look exactly the way you want is challenging because the denoising process is random. While you can control certain aspects of generation by configuring parameters like `negative_prompt`, there are better and more efficient methods for controlling image generation.\n\n### Prompt weighting\n\nPrompt weighting provides a quantifiable way to scale the representation of concepts in a prompt. You can use it to increase or decrease the magnitude of the text embedding vector for each concept in the prompt, which subsequently determines how much of each concept is generated. The [Compel](https://github.com/damian0815/compel) library offers an intuitive syntax for scaling the prompt weights and generating the embeddings. Learn how to create the embeddings in the [Prompt weighting](../using-diffusers/weighted_prompts) guide.\n\nOnce you've generated the embeddings, pass them to the `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter in the [`AutoPipelineForInpainting`]. 
The embeddings replace the `prompt` parameter:\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import make_image_grid\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16,\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\nimage = pipeline(prompt_embeds=prompt_embeds, # generated from Compel\n    negative_prompt_embeds=negative_prompt_embeds, # generated from Compel\n    image=init_image,\n    mask_image=mask_image\n).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n### ControlNet\n\nControlNet models are used with other diffusion models like Stable Diffusion, and they provide an even more flexible and accurate way to control how an image is generated. A ControlNet accepts an additional conditioning image input that guides the diffusion model to preserve the features in it.\n\nFor example, let's condition an image with a ControlNet pretrained on inpaint images:\n\n```py\nimport torch\nimport numpy as np\nfrom diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline\nfrom diffusers.utils import load_image, make_image_grid\n\n# load ControlNet\ncontrolnet = ControlNetModel.from_pretrained(\"lllyasviel/control_v11p_sd15_inpaint\", torch_dtype=torch.float16, variant=\"fp16\")\n\n# pass ControlNet to the pipeline\npipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\", controlnet=controlnet, torch_dtype=torch.float16, variant=\"fp16\"\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\n# load base and mask image\ninit_image = 
load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\")\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png\")\n\n# prepare control image\ndef make_inpaint_condition(init_image, mask_image):\n    init_image = np.array(init_image.convert(\"RGB\")).astype(np.float32) / 255.0\n    mask_image = np.array(mask_image.convert(\"L\")).astype(np.float32) / 255.0\n\n    assert init_image.shape[0:2] == mask_image.shape[0:2], \"image and image_mask must have the same image size\"\n    init_image[mask_image > 0.5] = -1.0  # set as masked pixel\n    init_image = np.expand_dims(init_image, 0).transpose(0, 3, 1, 2)\n    init_image = torch.from_numpy(init_image)\n    return init_image\n\ncontrol_image = make_inpaint_condition(init_image, mask_image)\n```\n\nNow generate an image from the base, mask and control images. You'll notice features of the base image are strongly preserved in the generated image.\n\n```py\nimport PIL.Image\n\nprompt = \"concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image).images[0]\nmake_image_grid([init_image, mask_image, PIL.Image.fromarray(np.uint8(control_image[0][0])).convert('RGB'), image], rows=2, cols=2)\n```\n\nYou can take this a step further and chain it with an image-to-image pipeline to apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion):\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"nitrosocke/elden-ring-diffusion\", torch_dtype=torch.float16,\n)\npipeline.enable_model_cpu_offload()\n# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed\npipeline.enable_xformers_memory_efficient_attention()\n\nprompt = \"elden ring style castle\" # 
include the token \"elden ring style\" in the prompt\nnegative_prompt = \"bad architecture, deformed, disfigured, poor details\"\n\nimage_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]\nmake_image_grid([init_image, mask_image, image, image_elden_ring], rows=2, cols=2)\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">initial image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-controlnet.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">ControlNet inpaint</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-img2img.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">image-to-image</figcaption>\n  </div>\n</div>\n\n## Optimize\n\nIt can be difficult and slow to run diffusion models if you're resource constrained, but it doesn't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/fp16#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. 
Both options reduce memory usage and accelerate inference.\n\nYou can also offload the model to the CPU to save even more memory:\n\n```diff\n+ pipeline.enable_xformers_memory_efficient_attention()\n+ pipeline.enable_model_cpu_offload()\n```\n\nTo speed up your inference code even more, use [`torch.compile`](../optimization/fp16#torchcompile). Wrap `torch.compile` around the most compute-intensive component in the pipeline, which is typically the UNet:\n\n```py\npipeline.unet = torch.compile(pipeline.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\nLearn more in the [Reduce memory usage](../optimization/memory) and [Accelerate inference](../optimization/fp16) guides.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/ip_adapter.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# IP-Adapter\n\n[IP-Adapter](https://huggingface.co/papers/2308.06721) is a lightweight adapter designed to integrate image-based guidance with text-to-image diffusion models. The adapter uses an image encoder to extract image features that are passed to the newly added cross-attention layers in the UNet and fine-tuned. The original UNet model and the existing cross-attention layers corresponding to text features is frozen. Decoupling the cross-attention for image and text features enables more fine-grained and controllable generation.\n\nIP-Adapter files are typically ~100MBs because they only contain the image embeddings. This means you need to load a model first, and then load the IP-Adapter with [`~loaders.IPAdapterMixin.load_ip_adapter`].\n\n> [!TIP]\n> IP-Adapters are available to many models such as [Flux](../api/pipelines/flux#ip-adapter) and [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), and more. The examples in this guide use Stable Diffusion and Stable Diffusion XL.\n\nUse the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] parameter to scale the influence of the IP-Adapter during generation. 
A value of `1.0` means the model is only conditioned on the image prompt, and `0.5` typically produces balanced results between the text and image prompt.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\nfrom diffusers.utils import load_image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"sdxl_models\",\n  weight_name=\"ip-adapter_sdxl.bin\"\n)\npipeline.set_ip_adapter_scale(0.8)\n```\n\nPass an image to `ip_adapter_image` along with a text prompt to generate an image.\n\n```py\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png\")\npipeline(\n    prompt=\"a polar bear sitting in a chair drinking a milkshake\",\n    ip_adapter_image=image,\n    negative_prompt=\"deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality\",\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png\" width=\"400\" alt=\"IP-Adapter image\"/>\n    <figcaption style=\"text-align: center;\">IP-Adapter image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner_2.png\" width=\"400\" alt=\"generated image\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\nTake a look at the examples below to learn how to use IP-Adapter for other tasks.\n\n<hfoptions id=\"usage\">\n<hfoption id=\"image-to-image\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import load_image\n\npipeline = 
AutoPipelineForImage2Image.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"sdxl_models\",\n  weight_name=\"ip-adapter_sdxl.bin\"\n)\npipeline.set_ip_adapter_scale(0.8)\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_bear_1.png\")\nip_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_gummy.png\")\npipeline(\n    prompt=\"best quality, high quality\",\n    image=image,\n    ip_adapter_image=ip_image,\n    strength=0.5,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_bear_1.png\" width=\"300\" alt=\"input image\"/>\n    <figcaption style=\"text-align: center;\">input image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_gummy.png\" width=\"300\" alt=\"IP-Adapter image\"/>\n    <figcaption style=\"text-align: center;\">IP-Adapter image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_bear_3.png\" width=\"300\" alt=\"generated image\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n</hfoption>\n<hfoption id=\"inpainting\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  
subfolder=\"sdxl_models\",\n  weight_name=\"ip-adapter_sdxl.bin\"\n)\npipeline.set_ip_adapter_scale(0.6)\n\nmask_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_mask.png\")\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_bear_1.png\")\nip_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_gummy.png\")\npipeline(\n    prompt=\"a cute gummy bear waving\",\n    image=image,\n    mask_image=mask_image,\n    ip_adapter_image=ip_image,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_bear_1.png\" width=\"300\" alt=\"input image\"/>\n    <figcaption style=\"text-align: center;\">input image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_gummy.png\" width=\"300\" alt=\"IP-Adapter image\"/>\n    <figcaption style=\"text-align: center;\">IP-Adapter image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_inpaint.png\" width=\"300\" alt=\"generated image\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n</hfoption>\n<hfoption id=\"video\">\n\nThe [`~DiffusionPipeline.enable_model_cpu_offload`] method is useful for reducing memory and it should be enabled **after** the IP-Adapter is loaded. 
Otherwise, the IP-Adapter's image encoder is also offloaded to the CPU and returns an error.\n\n```py\nimport torch\nfrom diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter\nfrom diffusers.utils import export_to_gif\nfrom diffusers.utils import load_image\n\nadapter = MotionAdapter.from_pretrained(\n  \"guoyww/animatediff-motion-adapter-v1-5-2\",\n  torch_dtype=torch.float16\n)\npipeline = AnimateDiffPipeline.from_pretrained(\n  \"emilianJR/epiCRealism\",\n  motion_adapter=adapter,\n  torch_dtype=torch.float16\n)\nscheduler = DDIMScheduler.from_pretrained(\n    \"emilianJR/epiCRealism\",\n    subfolder=\"scheduler\",\n    clip_sample=False,\n    timestep_spacing=\"linspace\",\n    beta_schedule=\"linear\",\n    steps_offset=1,\n)\npipeline.scheduler = scheduler\npipeline.enable_vae_slicing()\npipeline.load_ip_adapter(\"h94/IP-Adapter\", subfolder=\"models\", weight_name=\"ip-adapter_sd15.bin\")\npipeline.enable_model_cpu_offload()\n\nip_adapter_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_inpaint.png\")\npipeline(\n    prompt=\"A cute gummy bear waving\",\n    negative_prompt=\"bad quality, worse quality, low resolution\",\n    ip_adapter_image=ip_adapter_image,\n    num_frames=16,\n    guidance_scale=7.5,\n    num_inference_steps=50,\n).frames[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_inpaint.png\" width=\"400\" alt=\"IP-Adapter image\"/>\n    <figcaption style=\"text-align: center;\">IP-Adapter image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gummy_bear.gif\" width=\"400\" alt=\"generated video\"/>\n    <figcaption style=\"text-align: center;\">generated video</figcaption>\n  
</figure>\n</div>\n\n</hfoption>\n</hfoptions>\n\n## Model variants\n\nThere are two variants of IP-Adapter, Plus and FaceID. The Plus variant uses patch embeddings and the ViT-H image encoder. The FaceID variant uses face embeddings generated from InsightFace.\n\n<hfoptions id=\"ipadapter-variants\">\n<hfoption id=\"IP-Adapter Plus\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\nfrom transformers import CLIPVisionModelWithProjection\n\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(\n    \"h94/IP-Adapter\",\n    subfolder=\"models/image_encoder\",\n    torch_dtype=torch.float16\n)\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    image_encoder=image_encoder,\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"sdxl_models\",\n  weight_name=\"ip-adapter-plus_sdxl_vit-h.safetensors\"\n)\n```\n\n</hfoption>\n<hfoption id=\"IP-Adapter FaceID\">\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter-FaceID\",\n  subfolder=None,\n  weight_name=\"ip-adapter-faceid_sdxl.bin\",\n  image_encoder_folder=None\n)\n```\n\nTo use an IP-Adapter FaceID Plus model, also load the CLIP image encoder with [`~transformers.CLIPVisionModelWithProjection`].\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\nfrom transformers import CLIPVisionModelWithProjection\n\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(\n    \"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\",\n    torch_dtype=torch.float16,\n)\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    image_encoder=image_encoder,\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipeline.load_ip_adapter(\n  
\"h94/IP-Adapter-FaceID\",\n  subfolder=None,\n  weight_name=\"ip-adapter-faceid-plus_sd15.bin\"\n)\n```\n\n</hfoption>\n</hfoptions>\n\n## Image embeddings\n\nThe `prepare_ip_adapter_image_embeds` method generates image embeddings that you can reuse when you run the pipeline multiple times with the same images. Loading and encoding the images every time you call the pipeline is inefficient. It is more efficient to precompute the image embeddings once, save them to disk, and load them when you need them.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.float16\n).to(\"cuda\")\n\nimage_embeds = pipeline.prepare_ip_adapter_image_embeds(\n    ip_adapter_image=image,\n    ip_adapter_image_embeds=None,\n    device=\"cuda\",\n    num_images_per_prompt=1,\n    do_classifier_free_guidance=True,\n)\n\ntorch.save(image_embeds, \"image_embeds.ipadpt\")\n```\n\nReload the image embeddings by passing them to the `ip_adapter_image_embeds` parameter. 
Set `image_encoder_folder` to `None` because you don't need the image encoder anymore to generate the image embeddings.\n\n> [!TIP]\n> You can also load image embeddings from other sources such as ComfyUI.\n\n```py\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"sdxl_models\",\n  image_encoder_folder=None,\n  weight_name=\"ip-adapter_sdxl.bin\"\n)\npipeline.set_ip_adapter_scale(0.8)\nimage_embeds = torch.load(\"image_embeds.ipadpt\")\npipeline(\n    prompt=\"a polar bear sitting in a chair drinking a milkshake\",\n    ip_adapter_image_embeds=image_embeds,\n    negative_prompt=\"deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality\",\n    num_inference_steps=100,\n    generator=generator,\n).images[0]\n```\n\n## Masking\n\nBinary masking enables assigning an IP-Adapter image to a specific area of the output image, making it useful for composing multiple IP-Adapter images. Each IP-Adapter image requires a binary mask.\n\nLoad the [`~image_processor.IPAdapterMaskProcessor`] to preprocess the image masks. For the best results, provide the output `height` and `width` to ensure masks with different aspect ratios are appropriately sized. 
If the input masks already match the aspect ratio of the generated image, you don't need to set the `height` and `width`.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\nfrom diffusers.image_processor import IPAdapterMaskProcessor\nfrom diffusers.utils import load_image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.float16\n).to(\"cuda\")\n\nmask1 = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_mask1.png\")\nmask2 = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_mask2.png\")\n\nprocessor = IPAdapterMaskProcessor()\nmasks = processor.preprocess([mask1, mask2], height=1024, width=1024)\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask1.png\" width=\"200\" alt=\"mask 1\"/>\n    <figcaption style=\"text-align: center;\">mask 1</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_mask2.png\" width=\"200\" alt=\"mask 2\"/>\n    <figcaption style=\"text-align: center;\">mask 2</figcaption>\n  </figure>\n</div>\n\nProvide both the IP-Adapter images and their scales as a list. 
Pass the preprocessed masks to `cross_attention_kwargs` in the pipeline.\n\n```py\nface_image1 = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png\")\nface_image2 = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl2.png\")\n\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"sdxl_models\",\n  weight_name=[\"ip-adapter-plus-face_sdxl_vit-h.safetensors\"]\n)\npipeline.set_ip_adapter_scale([[0.7, 0.7]])\n\nip_images = [[face_image1, face_image2]]\nmasks = [masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])]\n\npipeline(\n  prompt=\"2 girls\",\n  ip_adapter_image=ip_images,\n  negative_prompt=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n  cross_attention_kwargs={\"ip_adapter_masks\": masks}\n).images[0]\n```\n\n<div style=\"display: flex; flex-direction: column; gap: 10px;\">\n  <div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n    <figure>\n      <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl1.png\" width=\"400\" alt=\"IP-Adapter image 1\"/>\n      <figcaption style=\"text-align: center;\">IP-Adapter image 1</figcaption>\n    </figure>\n    <figure>\n      <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_mask_girl2.png\" width=\"400\" alt=\"IP-Adapter image 2\"/>\n      <figcaption style=\"text-align: center;\">IP-Adapter image 2</figcaption>\n    </figure>\n  </div>\n  <div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n    <figure>\n      <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_attention_mask_result_seed_0.png\" width=\"400\" alt=\"Generated image with mask\"/>\n      <figcaption style=\"text-align: center;\">generated with mask</figcaption>\n    
</figure>\n    <figure>\n      <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_no_attention_mask_result_seed_0.png\" width=\"400\" alt=\"Generated image without mask\"/>\n      <figcaption style=\"text-align: center;\">generated without mask</figcaption>\n    </figure>\n  </div>\n</div>\n\n## Applications\n\nThe section below covers some popular applications of IP-Adapter.\n\n### Face models\n\nFace generation and preserving its details can be challenging. To help generate more accurate faces, there are checkpoints specifically conditioned on images of cropped faces. You can find the face models in the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository or the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository. The FaceID checkpoints use the FaceID embeddings from [InsightFace](https://github.com/deepinsight/insightface) instead of CLIP image embeddings.\n\nWe recommend using the [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.\n\n<hfoptions id=\"usage\">\n<hfoption id=\"h94/IP-Adapter\">\n\n```py\nimport torch\nfrom diffusers import StableDiffusionPipeline, DDIMScheduler\nfrom diffusers.utils import load_image\n\npipeline = StableDiffusionPipeline.from_pretrained(\n  \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n  torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"models\", \n  weight_name=\"ip-adapter-full-face_sd15.bin\"\n)\n\npipeline.set_ip_adapter_scale(0.5)\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_einstein_base.png\")\n\npipeline(\n    prompt=\"A photo of Einstein as a chef, wearing an apron, cooking in a French restaurant\",\n    ip_adapter_image=image,\n    negative_prompt=\"lowres, bad anatomy, worst quality, low 
quality\",\n    num_inference_steps=100,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_einstein_base.png\" width=\"400\" alt=\"IP-Adapter image\"/>\n    <figcaption style=\"text-align: center;\">IP-Adapter image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_einstein.png\" width=\"400\" alt=\"generated image\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n</hfoption>\n<hfoption id=\"h94/IP-Adapter-FaceID\">\n\nFor FaceID models, extract the face embeddings and pass them as a list of tensors to `ip_adapter_image_embeds`.\n\n```py\n# pip install insightface opencv-python\nimport cv2\nimport numpy as np\nimport torch\nfrom diffusers import StableDiffusionPipeline, DDIMScheduler\nfrom diffusers.utils import load_image\nfrom insightface.app import FaceAnalysis\n\npipeline = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter-FaceID\",\n  subfolder=None,\n  weight_name=\"ip-adapter-faceid_sd15.bin\",\n  image_encoder_folder=None\n)\npipeline.set_ip_adapter_scale(0.6)\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png\")\n\nref_images_embeds = []\napp = FaceAnalysis(name=\"buffalo_l\", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\napp.prepare(ctx_id=0, det_size=(640, 640))\n# load_image returns an RGB image, InsightFace expects BGR\nimage = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)\nfaces = app.get(image)\nimage = 
torch.from_numpy(faces[0].normed_embedding)\nref_images_embeds.append(image.unsqueeze(0))\nref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)\nneg_ref_images_embeds = torch.zeros_like(ref_images_embeds)\nid_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device=\"cuda\")\n\npipeline(\n    prompt=\"A photo of a girl\",\n    ip_adapter_image_embeds=[id_embeds],\n    negative_prompt=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n).images[0]\n```\n\nThe IP-Adapter FaceID Plus and Plus v2 models require CLIP image embeddings. Prepare the face embeddings and then extract and pass the CLIP embeddings to the hidden image projection layers.\n\n```py\nclip_embeds = pipeline.prepare_ip_adapter_image_embeds(\n  [ip_adapter_images], None, torch.device(\"cuda\"), num_images, True)[0]\n\npipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)\n# set to True if using IP-Adapter FaceID Plus v2\npipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False\n```\n\n</hfoption>\n</hfoptions>\n\n### Multiple IP-Adapters\n\nCombine multiple IP-Adapters to generate images in more diverse styles. 
For example, you can use IP-Adapter Face to generate consistent faces and characters and IP-Adapter Plus to generate those faces in specific styles.\n\nLoad an image encoder with [`~transformers.CLIPVisionModelWithProjection`].\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image, DDIMScheduler\nfrom transformers import CLIPVisionModelWithProjection\nfrom diffusers.utils import load_image\n\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(\n    \"h94/IP-Adapter\",\n    subfolder=\"models/image_encoder\",\n    torch_dtype=torch.float16,\n)\n```\n\nLoad a base model, scheduler and the following IP-Adapters.\n\n- [ip-adapter-plus_sdxl_vit-h](https://huggingface.co/h94/IP-Adapter#ip-adapter-for-sdxl-10) uses patch embeddings and a ViT-H image encoder\n- [ip-adapter-plus-face_sdxl_vit-h](https://huggingface.co/h94/IP-Adapter#ip-adapter-for-sdxl-10) uses patch embeddings and a ViT-H image encoder but it is conditioned on images of cropped faces\n\n```py\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    image_encoder=image_encoder,\n)\npipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"sdxl_models\",\n  weight_name=[\"ip-adapter-plus_sdxl_vit-h.safetensors\", \"ip-adapter-plus-face_sdxl_vit-h.safetensors\"]\n)\npipeline.set_ip_adapter_scale([0.7, 0.3])\n# enable_model_cpu_offload to reduce memory usage\npipeline.enable_model_cpu_offload()\n```\n\nLoad an image and a folder containing images of a certain style to apply.\n\n```py\nface_image = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png\")\nstyle_folder = \"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy\"\nstyle_images = [load_image(f\"{style_folder}/img{i}.png\") for i in range(10)]\n```\n\n<div style=\"display: flex; gap: 10px; 
justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png\" width=\"400\" alt=\"Face image\"/>\n    <figcaption style=\"text-align: center;\">face image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_style_grid.png\" width=\"400\" alt=\"Style images\"/>\n    <figcaption style=\"text-align: center;\">style images</figcaption>\n  </figure>\n</div>\n\nPass style and face images as a list to `ip_adapter_image`.\n\n```py\ngenerator = torch.Generator(device=\"cpu\").manual_seed(0)\n\npipeline(\n    prompt=\"wonderwoman\",\n    ip_adapter_image=[style_images, face_image],\n    negative_prompt=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n    generator=generator,\n).images[0]\n```\n\n<div style=\"display: flex; justify-content: center;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip_multi_out.png\" width=\"400\" alt=\"Generated image\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n### Instant generation\n\n[Latent Consistency Models (LCM)](../api/pipelines/latent_consistency_models) can generate images in 4 steps or fewer, unlike other diffusion models which require many more steps, so generation feels \"instantaneous\". 
IP-Adapters are compatible with LCM models to instantly generate images.\n\nLoad the IP-Adapter weights and load the LoRA weights with [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, LCMScheduler\nfrom diffusers.utils import load_image\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"sd-dreambooth-library/herge-style\",\n  torch_dtype=torch.float16\n)\n\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"models\",\n  weight_name=\"ip-adapter_sd15.bin\"\n)\npipeline.load_lora_weights(\"latent-consistency/lcm-lora-sdv1-5\")\npipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)\n# enable_model_cpu_offload to reduce memory usage\npipeline.enable_model_cpu_offload()\n```\n\nTry using a lower IP-Adapter scale to condition generation more on the style you want to apply and remember to use the special token in your prompt to trigger its generation.\n\n```py\npipeline.set_ip_adapter_scale(0.4)\n\nprompt = \"herge_style woman in armor, best quality, high quality\"\n\nip_adapter_image = load_image(\"https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png\")\npipeline(\n    prompt=prompt,\n    ip_adapter_image=ip_adapter_image,\n    num_inference_steps=4,\n    guidance_scale=1,\n).images[0]\n```\n\n<div style=\"display: flex; justify-content: center;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_herge.png\" width=\"400\" alt=\"Generated image\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n### Structural control\n\nFor structural control, combine IP-Adapter with [ControlNet](../api/pipelines/controlnet) conditioned on depth maps, edge maps, pose estimations, and more.\n\nThe example below loads a [`ControlNetModel`] checkpoint conditioned on depth maps and combines it with an 
IP-Adapter.\n\n```py\nimport torch\nfrom diffusers.utils import load_image\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel\n\ncontrolnet = ControlNetModel.from_pretrained(\n  \"lllyasviel/control_v11f1p_sd15_depth\",\n  torch_dtype=torch.float16\n)\n\npipeline = StableDiffusionControlNetPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    controlnet=controlnet,\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"models\",\n  weight_name=\"ip-adapter_sd15.bin\"\n)\n```\n\nPass the depth map and IP-Adapter image to the pipeline.\n\n```py\npipeline(\n  prompt=\"best quality, high quality\",\n  image=depth_map,\n  ip_adapter_image=ip_adapter_image,\n  negative_prompt=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png\" width=\"300\" alt=\"IP-Adapter image\"/>\n    <figcaption style=\"text-align: center;\">IP-Adapter image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png\" width=\"300\" alt=\"Depth map\"/>\n    <figcaption style=\"text-align: center;\">depth map</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png\" width=\"300\" alt=\"Generated image\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n### Style and layout control\n\nFor style and layout control, combine IP-Adapter with [InstantStyle](https://huggingface.co/papers/2404.02733). InstantStyle separates *style* (color, texture, overall feel) and *content* from each other. 
It only applies the style in style-specific blocks of the model to prevent it from distorting other areas of an image. This generates images with stronger and more consistent styles and better control over the layout.\n\nThe IP-Adapter is only activated for specific parts of the model. Use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to scale the influence of the IP-Adapter in different layers. The example below activates the IP-Adapter in the second layer of the model's down `block_2` and up `block_0`. Down `block_2` is where the IP-Adapter injects layout information and up `block_0` is where style is injected.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\nfrom diffusers.utils import load_image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_ip_adapter(\n  \"h94/IP-Adapter\",\n  subfolder=\"sdxl_models\",\n  weight_name=\"ip-adapter_sdxl.bin\"\n)\n\nscale = {\n    \"down\": {\"block_2\": [0.0, 1.0]},\n    \"up\": {\"block_0\": [0.0, 1.0, 0.0]},\n}\npipeline.set_ip_adapter_scale(scale)\n```\n\nLoad the style image and generate an image.\n\n```py\nstyle_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg\")\n\npipeline(\n    prompt=\"a cat, masterpiece, best quality, high quality\",\n    ip_adapter_image=style_image,\n    negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n    guidance_scale=5,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg\" width=\"400\" alt=\"Style image\"/>\n    <figcaption 
style=\"text-align: center;\">style image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png\" width=\"400\" alt=\"Generated image\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\nYou can also insert the IP-Adapter in all the model layers. This tends to generate images that focus more on the image prompt and may reduce the diversity of generated images. Only activate the IP-Adapter in up `block_0` or the style layer.\n\n> [!TIP]\n> You don't need to specify all the layers in the `scale` dictionary. Layers not included are set to 0, which means the IP-Adapter is disabled.\n\n```py\nscale = {\n    \"up\": {\"block_0\": [0.0, 1.0, 0.0]},\n}\npipeline.set_ip_adapter_scale(scale)\n\npipeline(\n    prompt=\"a cat, masterpiece, best quality, high quality\",\n    ip_adapter_image=style_image,\n    negative_prompt=\"text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry\",\n    guidance_scale=5,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_only.png\" width=\"400\" alt=\"Generated image (style only)\"/>\n    <figcaption style=\"text-align: center;\">style-layer generated image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_ip_adapter.png\" width=\"400\" alt=\"Generated image (IP-Adapter only)\"/>\n    <figcaption style=\"text-align: center;\">all layers generated image</figcaption>\n  </figure>\n</div>"
  },
  {
    "path": "docs/source/en/using-diffusers/kandinsky.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kandinsky\n\n[[open-in-colab]]\n\nThe Kandinsky models are a series of multilingual text-to-image generation models. The Kandinsky 2.0 model uses two multilingual text encoders and concatenates those results for the UNet.\n\n[Kandinsky 2.1](../api/pipelines/kandinsky) changes the architecture to include an image prior model ([`CLIP`](https://huggingface.co/docs/transformers/model_doc/clip)) to generate a mapping between text and image embeddings. The mapping provides better text-image alignment and it is used with the text embeddings during training, leading to higher quality results. Finally, Kandinsky 2.1 uses a [Modulating Quantized Vectors (MoVQ)](https://huggingface.co/papers/2209.09002) decoder - which adds a spatial conditional normalization layer to increase photorealism - to decode the latents into images.\n\n[Kandinsky 2.2](../api/pipelines/kandinsky_v22) improves on the previous model by replacing the image encoder of the image prior model with a larger CLIP-ViT-G model to improve quality. The image prior model was also retrained on images with different resolutions and aspect ratios to generate higher-resolution images and different image sizes.\n\n[Kandinsky 3](../api/pipelines/kandinsky3) simplifies the architecture and shifts away from the two-stage generation process involving the prior model and diffusion model. 
Instead, Kandinsky 3 uses [Flan-UL2](https://huggingface.co/google/flan-ul2) to encode text, a UNet with [BigGan-deep](https://hf.co/papers/1809.11096) blocks, and [Sber-MoVQGAN](https://github.com/ai-forever/MoVQGAN) to decode the latents into images. Text understanding and generated image quality are primarily achieved by using a larger text encoder and UNet.\n\nThis guide will show you how to use the Kandinsky models for text-to-image, image-to-image, inpainting, interpolation, and more.\n\nBefore you begin, make sure you have the following libraries installed:\n\n```py\n# uncomment to install the necessary libraries in Colab\n#!pip install -q diffusers transformers accelerate\n```\n\n> [!WARNING]\n> Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.\n>\n> <br>\n>\n> Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means its usage is identical to other diffusion models like [Stable Diffusion XL](sdxl).\n\n## Text-to-image\n\nTo use the Kandinsky models for any task, you always start by setting up the prior pipeline to encode the prompt and generate the image embeddings. The prior pipeline also generates `negative_image_embeds` that correspond to the negative prompt `\"\"`. 
For better results, you can pass an actual `negative_prompt` to the prior pipeline, but this'll increase the effective batch size of the prior pipeline by 2x.\n\n<hfoptions id=\"text-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import KandinskyPriorPipeline, KandinskyPipeline\nimport torch\n\nprior_pipeline = KandinskyPriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16).to(\"cuda\")\npipeline = KandinskyPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nnegative_prompt = \"low quality, bad quality\" # optional to include a negative prompt, but results are usually better\nimage_embeds, negative_image_embeds = prior_pipeline(prompt, negative_prompt, guidance_scale=1.0).to_tuple()\n```\n\nNow pass all the prompts and embeddings to the [`KandinskyPipeline`] to generate an image:\n\n```py\nimage = pipeline(prompt, image_embeds=image_embeds, negative_prompt=negative_prompt, negative_image_embeds=negative_image_embeds, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/cheeseburger.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline\nimport torch\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16).to(\"cuda\")\npipeline = KandinskyV22Pipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nnegative_prompt = \"low quality, bad quality\" # optional to include 
a negative prompt, but results are usually better\nimage_embeds, negative_image_embeds = prior_pipeline(prompt, guidance_scale=1.0).to_tuple()\n```\n\nPass the `image_embeds` and `negative_image_embeds` to the [`KandinskyV22Pipeline`] to generate an image:\n\n```py\nimage = pipeline(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-text-to-image.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 3\">\n\nKandinsky 3 doesn't require a prior model so you can directly load the [`Kandinsky3Pipeline`] and pass a prompt to generate an image:\n\n```py\nfrom diffusers import Kandinsky3Pipeline\nimport torch\n\npipeline = Kandinsky3Pipeline.from_pretrained(\"kandinsky-community/kandinsky-3\", variant=\"fp16\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nimage = pipeline(prompt).images[0]\nimage\n```\n\n</hfoption>\n</hfoptions>\n\n🤗 Diffusers also provides an end-to-end API with the [`KandinskyCombinedPipeline`] and [`KandinskyV22CombinedPipeline`], meaning you don't have to separately load the prior and text-to-image pipeline. The combined pipeline automatically loads both the prior model and the decoder. 
You can still set different values for the prior pipeline with the `prior_guidance_scale` and `prior_num_inference_steps` parameters if you want.\n\nUse the [`AutoPipelineForText2Image`] to automatically call the combined pipelines under the hood:\n\n<hfoptions id=\"text-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale=1.0, guidance_scale=4.0, height=768, width=768).images[0]\nimage\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale=1.0, guidance_scale=4.0, height=768, width=768).images[0]\nimage\n```\n\n</hfoption>\n</hfoptions>\n\n## Image-to-image\n\nFor image-to-image, pass the initial image and text prompt to condition the image to the pipeline. 
Start by loading the prior pipeline:\n\n<hfoptions id=\"image-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nimport torch\nfrom diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline\n\nprior_pipeline = KandinskyPriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = KandinskyImg2ImgPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nimport torch\nfrom diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = KandinskyV22Img2ImgPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 3\">\n\nKandinsky 3 doesn't require a prior model so you can directly load the image-to-image pipeline:\n\n```py\nfrom diffusers import Kandinsky3Img2ImgPipeline\nfrom diffusers.utils import load_image\nimport torch\n\npipeline = Kandinsky3Img2ImgPipeline.from_pretrained(\"kandinsky-community/kandinsky-3\", variant=\"fp16\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n```\n\n</hfoption>\n</hfoptions>\n\nDownload an image to condition on:\n\n```py\nfrom diffusers.utils import load_image\n\n# download image\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\noriginal_image = load_image(url)\noriginal_image = original_image.resize((768, 512))\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" 
src=\"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"/>\n</div>\n\nGenerate the `image_embeds` and `negative_image_embeds` with the prior pipeline:\n\n```py\nprompt = \"A fantasy landscape, Cinematic lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nimage_embeds, negative_image_embeds = prior_pipeline(prompt, negative_prompt).to_tuple()\n```\n\nNow pass the original image, and all the prompts and embeddings to the pipeline to generate an image:\n\n<hfoptions id=\"image-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers.utils import make_image_grid\n\nimage = pipeline(prompt, negative_prompt=negative_prompt, image=original_image, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768, strength=0.3).images[0]\nmake_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/img2img_fantasyland.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers.utils import make_image_grid\n\nimage = pipeline(image=original_image, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768, strength=0.3).images[0]\nmake_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-image-to-image.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 3\">\n\n```py\nimage = pipeline(prompt, negative_prompt=negative_prompt, image=original_image, strength=0.75, num_inference_steps=25).images[0]\nimage\n```\n\n</hfoption>\n</hfoptions>\n\n🤗 Diffusers also provides an 
end-to-end API with the [`KandinskyImg2ImgCombinedPipeline`] and [`KandinskyV22Img2ImgCombinedPipeline`], meaning you don't have to separately load the prior and image-to-image pipeline. The combined pipeline automatically loads both the prior model and the decoder. You can still set different values for the prior pipeline with the `prior_guidance_scale` and `prior_num_inference_steps` parameters if you want.\n\nUse the [`AutoPipelineForImage2Image`] to automatically call the combined pipelines under the hood:\n\n<hfoptions id=\"image-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\nimport torch\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16, use_safetensors=True)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A fantasy landscape, Cinematic lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\noriginal_image = load_image(url)\n\noriginal_image.thumbnail((768, 768))\n\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=original_image, strength=0.3).images[0]\nmake_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\nimport torch\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A fantasy landscape, Cinematic lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nurl = 
\"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\noriginal_image = load_image(url)\n\noriginal_image.thumbnail((768, 768))\n\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=original_image, strength=0.3).images[0]\nmake_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n</hfoption>\n</hfoptions>\n\n## Inpainting\n\n> [!WARNING]\n> ⚠️ The Kandinsky models use ⬜️ **white pixels** to represent the masked area now instead of black pixels. If you are using [`KandinskyInpaintPipeline`] in production, you need to change the mask to use white pixels:\n>\n> ```py\n> # For PIL input\n> import PIL.ImageOps\n> mask = PIL.ImageOps.invert(mask)\n>\n> # For PyTorch and NumPy input\n> mask = 1 - mask\n> ```\n\nFor inpainting, you'll need the original image, a mask of the area to replace in the original image, and a text prompt of what to inpaint. Load the prior pipeline:\n\n<hfoptions id=\"inpaint\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\nimport numpy as np\nfrom PIL import Image\n\nprior_pipeline = KandinskyPriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = KandinskyInpaintPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-inpaint\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\nimport numpy as np\nfrom PIL import Image\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, 
use_safetensors=True).to(\"cuda\")\npipeline = KandinskyV22InpaintPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder-inpaint\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n\n</hfoption>\n</hfoptions>\n\nLoad an initial image and create a mask:\n\n```py\ninit_image = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nmask = np.zeros((768, 768), dtype=np.float32)\n# mask area above cat's head\nmask[:250, 250:-250] = 1\n```\n\nGenerate the embeddings with the prior pipeline:\n\n```py\nprompt = \"a hat\"\nprior_output = prior_pipeline(prompt)\n```\n\nNow pass the initial image, mask, and prompt and embeddings to the pipeline to generate an image:\n\n<hfoptions id=\"inpaint\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\noutput_image = pipeline(prompt, image=init_image, mask_image=mask, **prior_output, height=768, width=768, num_inference_steps=150).images[0]\nmask = Image.fromarray((mask*255).astype('uint8'), 'L')\nmake_image_grid([init_image, mask, output_image], rows=1, cols=3)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/inpaint_cat_hat.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\noutput_image = pipeline(image=init_image, mask_image=mask, **prior_output, height=768, width=768, num_inference_steps=150).images[0]\nmask = Image.fromarray((mask*255).astype('uint8'), 'L')\nmake_image_grid([init_image, mask, output_image], rows=1, cols=3)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinskyv22-inpaint.png\"/>\n</div>\n\n</hfoption>\n</hfoptions>\n\nYou can also use the end-to-end [`KandinskyInpaintCombinedPipeline`] and [`KandinskyV22InpaintCombinedPipeline`] to call the prior 
and decoder pipelines together under the hood. Use the [`AutoPipelineForInpainting`] for this:\n\n<hfoptions id=\"inpaint\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipe = AutoPipelineForInpainting.from_pretrained(\"kandinsky-community/kandinsky-2-1-inpaint\", torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n\ninit_image = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nmask = np.zeros((768, 768), dtype=np.float32)\n# mask area above cat's head\nmask[:250, 250:-250] = 1\nprompt = \"a hat\"\n\noutput_image = pipe(prompt=prompt, image=init_image, mask_image=mask).images[0]\nmask = Image.fromarray((mask*255).astype('uint8'), 'L')\nmake_image_grid([init_image, mask, output_image], rows=1, cols=3)\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipe = AutoPipelineForInpainting.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder-inpaint\", torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n\ninit_image = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nmask = np.zeros((768, 768), dtype=np.float32)\n# mask area above cat's head\nmask[:250, 250:-250] = 1\nprompt = \"a hat\"\n\noutput_image = pipe(prompt=prompt, image=init_image, mask_image=mask).images[0]\nmask = Image.fromarray((mask*255).astype('uint8'), 'L')\nmake_image_grid([init_image, mask, output_image], rows=1, cols=3)\n```\n\n</hfoption>\n</hfoptions>\n\n## Interpolation\n\nInterpolation allows you to explore the latent space between the image and text embeddings which is a cool way to see some of the 
prior model's intermediate outputs. Load the prior pipeline and two images you'd like to interpolate:\n\n<hfoptions id=\"interpolate\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import KandinskyPriorPipeline, KandinskyPipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\n\nprior_pipeline = KandinskyPriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\nimg_1 = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nimg_2 = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg\")\nmake_image_grid([img_1.resize((512,512)), img_2.resize((512,512))], rows=1, cols=2)\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\nimg_1 = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nimg_2 = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg\")\nmake_image_grid([img_1.resize((512,512)), img_2.resize((512,512))], rows=1, cols=2)\n```\n\n</hfoption>\n</hfoptions>\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">a cat</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Van Gogh's Starry Night painting</figcaption>\n  </div>\n</div>\n\nSpecify the text or images to interpolate, and set the weights for each text or image. Experiment with the weights to see how they affect the interpolation!\n\n```py\nimages_texts = [\"a cat\", img_1, img_2]\nweights = [0.3, 0.3, 0.4]\n```\n\nCall the `interpolate` function to generate the embeddings, and then pass them to the pipeline to generate the image:\n\n<hfoptions id=\"interpolate\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\n# prompt can be left empty\nprompt = \"\"\nprior_out = prior_pipeline.interpolate(images_texts, weights)\n\npipeline = KandinskyPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n\nimage = pipeline(prompt, **prior_out, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\n# prompt can be left empty\nprompt = \"\"\nprior_out = prior_pipeline.interpolate(images_texts, weights)\n\npipeline = KandinskyV22Pipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n\nimage = pipeline(prompt, **prior_out, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinskyv22-interpolate.png\"/>\n</div>\n\n</hfoption>\n</hfoptions>\n\n## ControlNet\n\n> [!WARNING]\n> ⚠️ ControlNet is only supported for Kandinsky 2.2!\n\nControlNet enables 
conditioning large pretrained diffusion models with additional inputs such as a depth map or edge detection. For example, you can condition Kandinsky 2.2 with a depth map so the model understands and preserves the structure of the depth image.\n\nLet's load an image and extract its depth map:\n\n```py\nfrom diffusers.utils import load_image\n\nimg = load_image(\n    \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png\"\n).resize((768, 768))\nimg\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png\"/>\n</div>\n\nThen you can use the `depth-estimation` [`~transformers.Pipeline`] from 🤗 Transformers to process the image and retrieve the depth map:\n\n```py\nimport torch\nimport numpy as np\n\nfrom transformers import pipeline\n\ndef make_hint(image, depth_estimator):\n    image = depth_estimator(image)[\"depth\"]\n    image = np.array(image)\n    image = image[:, :, None]\n    image = np.concatenate([image, image, image], axis=2)\n    detected_map = torch.from_numpy(image).float() / 255.0\n    hint = detected_map.permute(2, 0, 1)\n    return hint\n\ndepth_estimator = pipeline(\"depth-estimation\")\nhint = make_hint(img, depth_estimator).unsqueeze(0).half().to(\"cuda\")\n```\n\n### Text-to-image [[controlnet-text-to-image]]\n\nLoad the prior pipeline and the [`KandinskyV22ControlnetPipeline`]:\n\n```py\nfrom diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True\n).to(\"cuda\")\n\npipeline = KandinskyV22ControlnetPipeline.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-controlnet-depth\", torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\nGenerate the image embeddings from a prompt and negative 
prompt:\n\n```py\nprompt = \"A robot, 4k photo\"\nnegative_prior_prompt = \"lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature\"\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(43)\n\nimage_emb, zero_image_emb = prior_pipeline(\n    prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator\n).to_tuple()\n```\n\nFinally, pass the image embeddings and the depth image to the [`KandinskyV22ControlnetPipeline`] to generate an image:\n\n```py\nimage = pipeline(image_embeds=image_emb, negative_image_embeds=zero_image_emb, hint=hint, num_inference_steps=50, generator=generator, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/robot_cat_text2img.png\"/>\n</div>\n\n### Image-to-image [[controlnet-image-to-image]]\n\nFor image-to-image with ControlNet, you'll need to use the:\n\n- [`KandinskyV22PriorEmb2EmbPipeline`] to generate the image embeddings from a text prompt and an image\n- [`KandinskyV22ControlnetImg2ImgPipeline`] to generate an image from the initial image and the image embeddings\n\nProcess and extract a depth map of an initial image of a cat with the `depth-estimation` [`~transformers.Pipeline`] from 🤗 Transformers:\n\n```py\nimport torch\nimport numpy as np\n\nfrom diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline\nfrom diffusers.utils import load_image\nfrom transformers import pipeline\n\nimg = load_image(\n    
\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png\"\n).resize((768, 768))\n\ndef make_hint(image, depth_estimator):\n    image = depth_estimator(image)[\"depth\"]\n    image = np.array(image)\n    image = image[:, :, None]\n    image = np.concatenate([image, image, image], axis=2)\n    detected_map = torch.from_numpy(image).float() / 255.0\n    hint = detected_map.permute(2, 0, 1)\n    return hint\n\ndepth_estimator = pipeline(\"depth-estimation\")\nhint = make_hint(img, depth_estimator).unsqueeze(0).half().to(\"cuda\")\n```\n\nLoad the prior pipeline and the [`KandinskyV22ControlnetImg2ImgPipeline`]:\n\n```py\nprior_pipeline = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True\n).to(\"cuda\")\n\npipeline = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-controlnet-depth\", torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\nPass a text prompt and the initial image to the prior pipeline to generate the image embeddings:\n\n```py\nprompt = \"A robot, 4k photo\"\nnegative_prior_prompt = \"lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature\"\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(43)\n\nimg_emb = prior_pipeline(prompt=prompt, image=img, strength=0.85, generator=generator)\nnegative_emb = prior_pipeline(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)\n```\n\nNow you can run the [`KandinskyV22ControlnetImg2ImgPipeline`] to 
generate an image from the initial image and the image embeddings:\n\n```py\nfrom diffusers.utils import make_image_grid\n\nimage = pipeline(image=img, strength=0.5, image_embeds=img_emb.image_embeds, negative_image_embeds=negative_emb.image_embeds, hint=hint, num_inference_steps=50, generator=generator, height=768, width=768).images[0]\nmake_image_grid([img.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/robot_cat.png\"/>\n</div>\n\n## Optimizations\n\nKandinsky is unique because it requires a prior pipeline to generate the mappings, and a second pipeline to decode the latents into an image. Optimization efforts should be focused on the second pipeline because that is where the bulk of the computation is done. Here are some tips to improve Kandinsky during inference.\n\n1. Enable [xFormers](../optimization/xformers) if you're using PyTorch < 2.0:\n\n```diff\n  from diffusers import DiffusionPipeline\n  import torch\n\n  pipe = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16)\n+ pipe.enable_xformers_memory_efficient_attention()\n```\n\n2. Enable `torch.compile` if you're using PyTorch >= 2.0 to automatically use scaled dot-product attention (SDPA):\n\n```diff\n  pipe.unet.to(memory_format=torch.channels_last)\n+ pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\nThis is the same as explicitly setting the attention processor to use [`~models.attention_processor.AttnAddedKVProcessor2_0`]:\n\n```py\nfrom diffusers.models.attention_processor import AttnAddedKVProcessor2_0\n\npipe.unet.set_attn_processor(AttnAddedKVProcessor2_0())\n```\n\n3. 
Offload the model to the CPU with [`~KandinskyPriorPipeline.enable_model_cpu_offload`] to avoid out-of-memory errors:\n\n```diff\n  from diffusers import DiffusionPipeline\n  import torch\n\n  pipe = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16)\n+ pipe.enable_model_cpu_offload()\n```\n\n4. By default, the text-to-image pipeline uses the [`DDIMScheduler`] but you can replace it with another scheduler like [`DDPMScheduler`] to see how that affects the tradeoff between inference speed and image quality:\n\n```py\nimport torch\nfrom diffusers import DDPMScheduler\nfrom diffusers import DiffusionPipeline\n\nscheduler = DDPMScheduler.from_pretrained(\"kandinsky-community/kandinsky-2-1\", subfolder=\"ddpm_scheduler\")\npipe = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", scheduler=scheduler, torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n"
  },
  {
    "path": "docs/source/en/using-diffusers/loading.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# DiffusionPipeline\n\nDiffusion models consists of multiple components like UNets or diffusion transformers (DiTs), text encoders, variational autoencoders (VAEs), and schedulers. The [`DiffusionPipeline`] wraps all of these components into a single easy-to-use API without giving up the flexibility to modify it's components.\n\nThis guide will show you how to load a [`DiffusionPipeline`].\n\n## Loading a pipeline\n\n[`DiffusionPipeline`] is a base pipeline class that automatically selects and returns an instance of a model's pipeline subclass, like [`QwenImagePipeline`], by scanning the `model_index.json` file for the class name.\n\nPass a model id to [`~DiffusionPipeline.from_pretrained`] to load a pipeline.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n```\n\nEvery model has a specific pipeline subclass that inherits from [`DiffusionPipeline`]. A subclass usually has a narrow focus and are task-specific. 
See the table below for an example.\n\n| pipeline subclass | task |\n|---|---|\n| [`QwenImagePipeline`] | text-to-image |\n| [`QwenImageImg2ImgPipeline`] | image-to-image |\n| [`QwenImageInpaintPipeline`] | inpaint |\n\nYou could use the subclass directly by passing a model id to [`~QwenImagePipeline.from_pretrained`].\n\n```py\nimport torch\nfrom diffusers import QwenImagePipeline\n\npipeline = QwenImagePipeline.from_pretrained(\n  \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n```\n\n> [!TIP]\n> Refer to the [Single file format](./other-formats#single-file-format) docs to learn how to load single file models.\n\n### Local pipelines\n\nPipelines can also be run locally. Use [`~huggingface_hub.snapshot_download`] to download a model repository.\n\n```py\nfrom huggingface_hub import snapshot_download\n\nsnapshot_download(repo_id=\"Qwen/Qwen-Image\")\n```\n\nThe model is downloaded to your [cache](../installation#cache). Pass the folder path to [`~QwenImagePipeline.from_pretrained`] to load it.\n\n```py\nimport torch\nfrom diffusers import QwenImagePipeline\n\npipeline = QwenImagePipeline.from_pretrained(\n  \"path/to/your/cache\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n```\n\nThe [`~QwenImagePipeline.from_pretrained`] method won't download files from the Hub when it detects a local path. But this also means it won't download and cache any updates that have been made to the model either.\n\n## Pipeline data types\n\nUse the `torch_dtype` argument in [`~DiffusionPipeline.from_pretrained`] to load a model with a specific data type. This allows you to load different models in different precisions. For example, loading a large transformer model in half-precision reduces the memory required.\n\nPass the data type for each model as a dictionary to `torch_dtype`. Use the `default` key to set the default data type. 
If a model isn't in the dictionary and `default` isn't provided, it is loaded in full precision (`torch.float32`).\n\n```py\nimport torch\nfrom diffusers import QwenImagePipeline\n\npipeline = QwenImagePipeline.from_pretrained(\n  \"Qwen/Qwen-Image\",\n  torch_dtype={\"transformer\": torch.bfloat16, \"default\": torch.float16},\n)\nprint(pipeline.transformer.dtype, pipeline.vae.dtype)\n```\n\nYou don't need to use a dictionary if you're loading all the models in the same data type.\n\n```py\nimport torch\nfrom diffusers import QwenImagePipeline\n\npipeline = QwenImagePipeline.from_pretrained(\n  \"Qwen/Qwen-Image\", torch_dtype=torch.bfloat16\n)\nprint(pipeline.transformer.dtype, pipeline.vae.dtype)\n```\n\n## Device placement\n\nThe `device_map` argument determines individual model or pipeline placement on an accelerator like a GPU. It is especially helpful when there are multiple GPUs.\n\nA pipeline supports two options for `device_map`, `\"cuda\"` and `\"balanced\"`. Refer to the table below to compare the placement strategies.\n\n| parameter | description |\n|---|---|\n| `\"cuda\"` | places pipeline on a supported accelerator device like CUDA |\n| `\"balanced\"` | evenly distributes pipeline on all GPUs |\n\nUse the `max_memory` argument in [`~DiffusionPipeline.from_pretrained`] to allocate a maximum amount of memory to use on each device. By default, Diffusers uses the maximum amount available.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\nmax_memory = {0: \"16GB\", 1: \"16GB\"}\npipeline = DiffusionPipeline.from_pretrained(\n  \"Qwen/Qwen-Image\", \n  torch_dtype=torch.bfloat16,\n  device_map=\"balanced\",\n  max_memory=max_memory,\n)\n```\n\nThe `hf_device_map` attribute allows you to access and view the `device_map`.\n\n```py\nprint(pipeline.hf_device_map)\n# {'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}\n```\n\nReset a pipeline's `device_map` with the [`~DiffusionPipeline.reset_device_map`] method. 
This is necessary if you want to use methods such as `.to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`].\n\n```py\npipeline.reset_device_map()\n```\n\n## Parallel loading\n\nLarge models are often [sharded](../training/distributed_inference#model-sharding) into smaller files so that they are easier to load. Diffusers supports loading shards in parallel to speed up the loading process.\n\nSet `HF_ENABLE_PARALLEL_LOADING` to `\"YES\"` to enable parallel loading of shards.\n\nThe `device_map` argument should be set to `\"cuda\"` to pre-allocate a large chunk of memory based on the model size. This substantially reduces model load time because warming up the memory allocator now avoids many smaller calls to the allocator later.\n\n```py\nimport os\nimport torch\nfrom diffusers import DiffusionPipeline\n\nos.environ[\"HF_ENABLE_PARALLEL_LOADING\"] = \"YES\"\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"Wan-AI/Wan2.2-I2V-A14B-Diffusers\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n```\n\n## Replacing models in a pipeline\n\n[`DiffusionPipeline`] is flexible and accommodates loading different models or schedulers. You can experiment with different schedulers to optimize for generation speed or quality, and you can replace models with more performant ones.\n\nThe example below uses a more stable VAE version.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, AutoModel\n\nvae = AutoModel.from_pretrained(\n  \"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16\n)\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  vae=vae,\n  torch_dtype=torch.float16,\n  device_map=\"cuda\"\n)\n```\n\n## Reusing models in multiple pipelines\n\nWhen working with multiple pipelines that use the same model, the [`~DiffusionPipeline.from_pipe`] method enables reusing a model instead of reloading it each time. 
This allows you to use multiple pipelines without increasing memory usage.\n\nMemory usage is determined by the pipeline with the highest memory requirement regardless of the number of pipelines.\n\nThe example below loads a pipeline and then loads a second pipeline with [`~DiffusionPipeline.from_pipe`] to use [perturbed-attention guidance (PAG)](../api/pipelines/pag) to improve generation quality.\n\n> [!WARNING]\n> Use [`AutoPipelineForText2Image`] because [`DiffusionPipeline`] doesn't support PAG. Refer to the [AutoPipeline](../tutorials/autopipeline) docs to learn more. \n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline_sdxl = AutoPipelineForText2Image.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, device_map=\"cuda\"\n)\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\nimage = pipeline_sdxl(prompt).images[0]\nprint(f\"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n# Max memory allocated: 10.47 GB\n```\n\nSet `enable_pag=True` in the second pipeline to enable PAG. The second pipeline uses the same amount of memory because it shares model weights with the first one.\n\n```py\npipeline = AutoPipelineForText2Image.from_pipe(\n  pipeline_sdxl, enable_pag=True\n)\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\nimage = pipeline(prompt).images[0]\nprint(f\"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n# Max memory allocated: 10.47 GB\n```\n\n> [!WARNING]\n> Pipelines created by [`~DiffusionPipeline.from_pipe`] share the same models and *state*. 
Modifying the state of a model in one pipeline affects all the other pipelines that share the same model.\n\nSome methods may not work correctly on pipelines created with [`~DiffusionPipeline.from_pipe`]. For example, [`~DiffusionPipeline.enable_model_cpu_offload`] relies on a unique model execution order, which may differ in the new pipeline. To ensure proper functionality, reapply these methods on the new pipeline.\n\n## Safety checker\n\nDiffusers provides a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) for older Stable Diffusion models to prevent generating harmful content. It screens the generated output against a set of hardcoded harmful concepts.\n\nIf you want to disable the safety checker, pass `safety_checker=None` in [`~DiffusionPipeline.from_pretrained`] as shown below.\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"stable-diffusion-v1-5/stable-diffusion-v1-5\", safety_checker=None\n)\n\"\"\"\nYou have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling it only for use cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\n\"\"\"\n```"
  },
  {
    "path": "docs/source/en/using-diffusers/marigold_usage.md",
    "content": "<!--\nCopyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.\nCopyright 2024-2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Marigold Computer Vision\n\n**Marigold** is a diffusion-based [method](https://huggingface.co/papers/2312.02145) and a collection of [pipelines](../api/pipelines/marigold) designed for \ndense computer vision tasks, including **monocular depth prediction**, **surface normals estimation**, and **intrinsic \nimage decomposition**.\n\nThis guide will walk you through using Marigold to generate fast and high-quality predictions for images and videos.\n\nEach pipeline is tailored for a specific computer vision task, processing an input RGB image and generating a \ncorresponding prediction.\nCurrently, the following computer vision tasks are implemented:\n\n| Pipeline                                                                                                                                          | Recommended Model Checkpoints                                                                                                                                                                           |                              Spaces (Interactive Apps)                               | Predicted Modalities                                                                                                                                                               
|\n|---------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py)           | [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1)                                                                                                                       |          [Depth Estimation](https://huggingface.co/spaces/prs-eth/marigold)          | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity)                                                                   |\n| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py)       | [prs-eth/marigold-normals-v1-1](https://huggingface.co/prs-eth/marigold-normals-v1-1)                                                                                                                   | [Surface Normals Estimation](https://huggingface.co/spaces/prs-eth/marigold-normals) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping)                                                                                                                    |\n| 
[MarigoldIntrinsicsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py) | [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1),<br>[prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | [Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid)  | [Albedo](https://en.wikipedia.org/wiki/Albedo), [Materials](https://www.n.aiq3d.com/wiki/roughnessmetalnessao-map), [Lighting](https://en.wikipedia.org/wiki/Diffuse_reflection)   |\n\nAll original checkpoints are available under the [PRS-ETH](https://huggingface.co/prs-eth/) organization on Hugging Face.\nThey are designed for use with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold), which can also be used to train \nnew model checkpoints. \nThe following is a summary of the recommended checkpoints, all of which produce reliable results with 1 to 4 steps. \n\n| Checkpoint                                                                                          | Modality     | Comment                                                                                                                                                           |\n|-----------------------------------------------------------------------------------------------------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1)                   | Depth        | Affine-invariant depth prediction assigns each pixel a value between 0 (near plane) and 1 (far plane), with both planes determined by the model during inference. 
|\n| [prs-eth/marigold-normals-v1-1](https://huggingface.co/prs-eth/marigold-normals-v1-1)               | Normals      | The surface normals predictions are unit-length 3D vectors in the screen space camera, with values in the range from -1 to 1.                                     |\n| [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1) | Intrinsics   | InteriorVerse decomposition is comprised of Albedo and two BRDF material properties: Roughness and Metallicity.                                                   | \n| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1)     | Intrinsics   | HyperSim decomposition of an image \\\\(I\\\\) is comprised of Albedo \\\\(A\\\\), Diffuse shading \\\\(S\\\\), and Non-diffuse residual \\\\(R\\\\): \\\\(I = A*S+R\\\\).            | \n\nThe examples below are mostly given for depth prediction, but they can be universally applied to other supported \nmodalities.\nWe showcase the predictions using the same input image of Albert Einstein generated by Midjourney.\nThis makes it easier to compare visualizations of the predictions across various modalities and checkpoints.\n\n<div class=\"flex gap-4\" style=\"justify-content: center; width: 100%;\">\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://marigoldmonodepth.github.io/images/einstein.jpg\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Example input image for all Marigold pipelines\n    </figcaption>\n  </div>\n</div>\n\n## Depth Prediction\n\nTo get a depth prediction, load the `prs-eth/marigold-depth-v1-1` checkpoint into [`MarigoldDepthPipeline`], \nput the image through the pipeline, and save the predictions:\n\n```python\nimport diffusers\nimport torch\n\npipe = diffusers.MarigoldDepthPipeline.from_pretrained(\n    \"prs-eth/marigold-depth-v1-1\", variant=\"fp16\", 
torch_dtype=torch.float16\n).to(\"cuda\")\n\nimage = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\ndepth = pipe(image)\n\nvis = pipe.image_processor.visualize_depth(depth.prediction)\nvis[0].save(\"einstein_depth.png\")\n\ndepth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)\ndepth_16bit[0].save(\"einstein_depth_16bit.png\")\n```\n\nThe [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth`] function applies one of \n[matplotlib's colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html) (`Spectral` by default) to map the predicted pixel values from a single-channel `[0, 1]` \ndepth range into an RGB image.\nWith the `Spectral` colormap, pixels with near depth are painted red, and far pixels are blue.\nThe 16-bit PNG file stores the single channel values mapped linearly from the `[0, 1]` range into `[0, 65535]`.\nBelow are the raw and the visualized predictions. The darker and closer areas (mustache) are easier to distinguish in \nthe visualization.\n\n<div class=\"flex gap-4\">\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth_16bit.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Predicted depth (16-bit PNG)\n    </figcaption>\n  </div>\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Predicted depth visualization (Spectral)\n    </figcaption>\n  </div>\n</div>\n\n## Surface Normals Estimation\n\nLoad the `prs-eth/marigold-normals-v1-1` checkpoint into [`MarigoldNormalsPipeline`], put the image through the \npipeline, and save 
the predictions:\n\n```python\nimport diffusers\nimport torch\n\npipe = diffusers.MarigoldNormalsPipeline.from_pretrained(\n    \"prs-eth/marigold-normals-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n).to(\"cuda\")\n\nimage = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\nnormals = pipe(image)\n\nvis = pipe.image_processor.visualize_normals(normals.prediction)\nvis[0].save(\"einstein_normals.png\")\n```\n\nThe [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] function maps the three-dimensional \nprediction with pixel values in the range `[-1, 1]` into an RGB image.\nThe visualization function supports flipping surface normals axes to make the visualization compatible with other \nchoices of the frame of reference.\nConceptually, each pixel is painted according to the surface normal vector in the frame of reference, where the `X` axis \npoints right, the `Y` axis points up, and the `Z` axis points at the viewer.\nBelow is the visualized prediction:\n\n<div class=\"flex gap-4\" style=\"justify-content: center; width: 100%;\">\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_normals.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Predicted surface normals visualization\n    </figcaption>\n  </div>\n</div>\n\nIn this example, the nose tip almost certainly has a point on the surface at which the surface normal vector points \nstraight at the viewer, meaning that its coordinates are `[0, 0, 1]`.\nThis vector maps to the RGB `[128, 128, 255]`, which corresponds to the violet-blue color.\nSimilarly, a surface normal on the cheek in the right part of the image has a large `X` component, which increases the \nred hue.\nPoints on the shoulders, whose normals point up with a large `Y` component, shift the color toward green.\n\n## Intrinsic Image 
Decomposition\n\nMarigold provides two models for Intrinsic Image Decomposition (IID): \"Appearance\" and \"Lighting\". \nEach model produces Albedo maps, derived from InteriorVerse and Hypersim annotations, respectively.\n\n- The \"Appearance\" model also estimates Material properties: Roughness and Metallicity.\n- The \"Lighting\" model generates Diffuse Shading and Non-diffuse Residual.\n\nHere is the sample code saving predictions made by the \"Appearance\" model:\n\n```python\nimport diffusers\nimport torch\n\npipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained(\n    \"prs-eth/marigold-iid-appearance-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n).to(\"cuda\")\n\nimage = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\nintrinsics = pipe(image)\n\nvis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties)\nvis[0][\"albedo\"].save(\"einstein_albedo.png\")\nvis[0][\"roughness\"].save(\"einstein_roughness.png\")\nvis[0][\"metallicity\"].save(\"einstein_metallicity.png\")\n```\n\nAnother example demonstrating the predictions made by the \"Lighting\" model:\n\n```python\nimport diffusers\nimport torch\n\npipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained(\n    \"prs-eth/marigold-iid-lighting-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n).to(\"cuda\")\n\nimage = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\nintrinsics = pipe(image)\n\nvis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties)\nvis[0][\"albedo\"].save(\"einstein_albedo.png\")\nvis[0][\"shading\"].save(\"einstein_shading.png\")\nvis[0][\"residual\"].save(\"einstein_residual.png\")\n```\n\nBoth models share the same pipeline while supporting different decomposition types.\nThe exact decomposition parameterization (e.g., sRGB vs. 
linear space) is stored in the \n`pipe.target_properties` dictionary, which is passed into the \n[`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_intrinsics`] function.\n\nBelow are some examples showcasing the predicted decomposition outputs. \nAll modalities can be inspected in the \n[Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid) Space.\n\n<div class=\"flex gap-4\">\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/8c7986eaaab5eb9604eb88336311f46a7b0ff5ab/marigold/marigold_einstein_albedo.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Predicted albedo (\"Appearance\" model)\n    </figcaption>\n  </div>\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/8c7986eaaab5eb9604eb88336311f46a7b0ff5ab/marigold/marigold_einstein_diffuse.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Predicted diffuse shading (\"Lighting\" model)\n    </figcaption>\n  </div>\n</div>\n\n## Speeding up inference\n\nThe above quick start snippets are already optimized for quality and speed, loading the checkpoint, utilizing the \n`fp16` variant of weights and computation, and performing the default number (4) of denoising diffusion steps.\nThe first step to accelerate inference, at the expense of prediction quality, is to reduce the denoising diffusion \nsteps to the minimum:\n\n```diff\n  import diffusers\n  import torch\n\n  pipe = diffusers.MarigoldDepthPipeline.from_pretrained(\n      \"prs-eth/marigold-depth-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n  ).to(\"cuda\")\n\n  image = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n  \n- depth = pipe(image)\n+ depth = pipe(image, 
num_inference_steps=1)\n```\n\nWith this change, the `pipe` call completes in 280ms on an RTX 3090 GPU.\nInternally, the input image is first encoded using the Stable Diffusion VAE encoder, followed by a single denoising \nstep performed by the U-Net. \nFinally, the prediction latent is decoded with the VAE decoder into pixel space.\nIn this setup, two out of three module calls are dedicated to converting between the pixel and latent spaces of the latent diffusion model (LDM).\nSince Marigold's latent space is compatible with Stable Diffusion 2.0, inference can be accelerated by more than 3x, \nreducing the call time to 85ms on an RTX 3090, by using a [lightweight replacement of the SD VAE](../api/models/autoencoder_tiny). \nNote that using a lightweight VAE may slightly reduce the visual quality of the predictions.\n\n```diff\n  import diffusers\n  import torch\n\n  pipe = diffusers.MarigoldDepthPipeline.from_pretrained(\n      \"prs-eth/marigold-depth-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n  ).to(\"cuda\")\n\n+ pipe.vae = diffusers.AutoencoderTiny.from_pretrained(\n+     \"madebyollin/taesd\", torch_dtype=torch.float16\n+ ).cuda()\n\n  image = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\n  depth = pipe(image, num_inference_steps=1)\n```\n\nSo far, we have optimized the number of diffusion steps and model components. Self-attention operations account for a \nsignificant portion of computations. 
\nSpeeding them up can be achieved by using a more efficient attention processor:\n\n```diff\n  import diffusers\n  import torch\n+ from diffusers.models.attention_processor import AttnProcessor2_0\n\n  pipe = diffusers.MarigoldDepthPipeline.from_pretrained(\n      \"prs-eth/marigold-depth-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n  ).to(\"cuda\")\n\n+ pipe.vae.set_attn_processor(AttnProcessor2_0()) \n+ pipe.unet.set_attn_processor(AttnProcessor2_0())\n\n  image = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\n  depth = pipe(image, num_inference_steps=1)\n```\n\nFinally, as suggested in [Optimizations](../optimization/fp16#torchcompile), enabling `torch.compile` can further enhance performance depending on \nthe target hardware.\nHowever, compilation incurs a significant overhead during the first pipeline invocation, making it beneficial only when \nthe same pipeline instance is called repeatedly, such as within a loop.\n\n```diff\n  import diffusers\n  import torch\n  from diffusers.models.attention_processor import AttnProcessor2_0\n\n  pipe = diffusers.MarigoldDepthPipeline.from_pretrained(\n      \"prs-eth/marigold-depth-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n  ).to(\"cuda\")\n\n  pipe.vae.set_attn_processor(AttnProcessor2_0()) \n  pipe.unet.set_attn_processor(AttnProcessor2_0())\n\n+ pipe.vae = torch.compile(pipe.vae, mode=\"reduce-overhead\", fullgraph=True)\n+ pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n\n  image = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\n  depth = pipe(image, num_inference_steps=1)\n```\n\n## Maximizing Precision and Ensembling\n\nMarigold pipelines have a built-in ensembling mechanism combining multiple predictions from different random latents.\nThis is a brute-force way of improving the precision of predictions, capitalizing on the generative nature of diffusion.\nThe ensembling path 
is activated automatically when the `ensemble_size` argument is set to a value greater than or equal to `3`.\nWhen aiming for maximum precision, it makes sense to adjust `num_inference_steps` simultaneously with `ensemble_size`.\nThe recommended values vary across checkpoints but primarily depend on the scheduler type.\nThe effect of ensembling is especially visible with surface normals:\n\n```diff\n  import diffusers\n\n  pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(\"prs-eth/marigold-normals-v1-1\").to(\"cuda\")\n\n  image = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\n- normals = pipe(image)\n+ normals = pipe(image, num_inference_steps=10, ensemble_size=5)\n\n  vis = pipe.image_processor.visualize_normals(normals.prediction)\n  vis[0].save(\"einstein_normals.png\")\n```\n\n<div class=\"flex gap-4\">\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_normals.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Surface normals, no ensembling\n    </figcaption>\n  </div>\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_normals.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Surface normals, with ensembling\n    </figcaption>\n  </div>\n</div>\n\nAs can be seen, all areas with fine-grained structures, such as hair, receive more conservative and, on average, more \ncorrect predictions.\nSuch a result is more suitable for precision-sensitive downstream tasks, such as 3D reconstruction.\n\n## Frame-by-frame Video Processing with Temporal Consistency\n\nDue to Marigold's generative nature, each prediction is unique and defined by the random noise sampled for the latent 
\ninitialization.\nThis becomes an obvious drawback compared to traditional end-to-end dense regression networks, as exemplified in the \nfollowing videos:\n\n<div class=\"flex gap-4\">\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_obama.gif\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">Input video</figcaption>\n  </div>\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_obama_depth_independent.gif\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">Marigold Depth applied to input video frames independently</figcaption>\n  </div>\n</div>\n\nTo address this issue, it is possible to pass the `latents` argument to the pipelines, which defines the starting point of \ndiffusion.\nEmpirically, we found that a convex combination of the same starting-point noise latent and the latent \ncorresponding to the previous frame prediction gives sufficiently smooth results, as implemented in the snippet below:\n\n```python\nimport imageio\nimport diffusers\nimport torch\nfrom diffusers.models.attention_processor import AttnProcessor2_0\nfrom PIL import Image\nfrom tqdm import tqdm\n\ndevice = \"cuda\"\npath_in = \"https://huggingface.co/spaces/prs-eth/marigold-lcm/resolve/c7adb5427947d2680944f898cd91d386bf0d4924/files/video/obama.mp4\"\npath_out = \"obama_depth.gif\"\n\npipe = diffusers.MarigoldDepthPipeline.from_pretrained(\n    \"prs-eth/marigold-depth-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n).to(device)\npipe.vae = diffusers.AutoencoderTiny.from_pretrained(\n    \"madebyollin/taesd\", torch_dtype=torch.float16\n).to(device)\npipe.unet.set_attn_processor(AttnProcessor2_0())\npipe.vae = torch.compile(pipe.vae, mode=\"reduce-overhead\", fullgraph=True)\npipe.unet 
= torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\npipe.set_progress_bar_config(disable=True)\n\nwith imageio.get_reader(path_in) as reader:\n    size = reader.get_meta_data()['size']\n    last_frame_latent = None\n    latent_common = torch.randn(\n        (1, 4, 768 * size[1] // (8 * max(size)), 768 * size[0] // (8 * max(size)))\n    ).to(device=device, dtype=torch.float16)\n\n    out = []\n    for frame_id, frame in tqdm(enumerate(reader), desc=\"Processing Video\"):\n        frame = Image.fromarray(frame)\n        latents = latent_common\n        if last_frame_latent is not None:\n            # convex combination of the shared noise and the previous frame's latent\n            latents = 0.9 * latents + 0.1 * last_frame_latent\n\n        depth = pipe(\n            frame,\n            num_inference_steps=1,\n            match_input_resolution=False, \n            latents=latents, \n            output_latent=True,\n        )\n        last_frame_latent = depth.latent\n        out.append(pipe.image_processor.visualize_depth(depth.prediction)[0])\n\n    diffusers.utils.export_to_gif(out, path_out, fps=reader.get_meta_data()['fps'])\n```\n\nHere, the diffusion process starts from the given computed latent.\nThe pipeline call sets `output_latent=True` to access `depth.latent`, which is blended with a weight of 0.1 into the next frame's latent \ninitialization.\nThe result is much more stable now:\n\n<div class=\"flex gap-4\">\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_obama_depth_independent.gif\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">Marigold Depth applied to input video frames independently</figcaption>\n  </div>\n  <div style=\"flex: 1 1 50%; max-width: 50%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_obama_depth_consistent.gif\"/>\n    <figcaption class=\"mt-1 text-center text-sm 
text-gray-500\">Marigold Depth with forced latents initialization</figcaption>\n  </div>\n</div>\n\n## Marigold for ControlNet\n\nA very common application for depth prediction with diffusion models comes in conjunction with ControlNet.\nDepth crispness plays a crucial role in obtaining high-quality results from ControlNet.\nAs seen in comparisons with other methods above, Marigold excels at that task.\nThe snippet below demonstrates how to load an image, compute depth, and pass it into ControlNet in a compatible format:\n\n```python\nimport torch\nimport diffusers\n\ndevice = \"cuda\"\ngenerator = torch.Generator(device=device).manual_seed(2024)\nimage = diffusers.utils.load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_source.png\"\n)\n\npipe = diffusers.MarigoldDepthPipeline.from_pretrained(\n    \"prs-eth/marigold-depth-v1-1\", torch_dtype=torch.float16, variant=\"fp16\"\n).to(device)\n\ndepth_image = pipe(image, generator=generator).prediction\ndepth_image = pipe.image_processor.visualize_depth(depth_image, color_map=\"binary\")\ndepth_image[0].save(\"motorcycle_controlnet_depth.png\")\n\ncontrolnet = diffusers.ControlNetModel.from_pretrained(\n    \"diffusers/controlnet-depth-sdxl-1.0\", torch_dtype=torch.float16, variant=\"fp16\"\n).to(device)\npipe = diffusers.StableDiffusionXLControlNetPipeline.from_pretrained(\n    \"SG161222/RealVisXL_V4.0\", torch_dtype=torch.float16, variant=\"fp16\", controlnet=controlnet\n).to(device)\npipe.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)\n\ncontrolnet_out = pipe(\n    prompt=\"high quality photo of a sports bike, city\",\n    negative_prompt=\"\",\n    guidance_scale=6.5,\n    num_inference_steps=25,\n    image=depth_image,\n    controlnet_conditioning_scale=0.7,\n    control_guidance_end=0.7,\n    
generator=generator,\n).images\ncontrolnet_out[0].save(\"motorcycle_controlnet_out.png\")\n```\n\n<div class=\"flex gap-4\">\n  <div style=\"flex: 1 1 33%; max-width: 33%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_depth_source.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Input image\n    </figcaption>\n  </div>\n  <div style=\"flex: 1 1 33%; max-width: 33%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/motorcycle_controlnet_depth.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Depth in the format compatible with ControlNet\n    </figcaption>\n  </div>\n  <div style=\"flex: 1 1 33%; max-width: 33%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/motorcycle_controlnet_out.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      ControlNet generation, conditioned on depth and prompt: \"high quality photo of a sports bike, city\"\n    </figcaption>\n  </div>\n</div>\n\n## Quantitative Evaluation\n\nTo evaluate Marigold quantitatively in standard leaderboards and benchmarks (such as NYU, KITTI, and other datasets), \nfollow the evaluation protocol outlined in the paper: load the full precision fp32 model and use appropriate values \nfor `num_inference_steps` and `ensemble_size`.\nOptionally seed randomness to ensure reproducibility. 
\nMaximizing `batch_size` will deliver maximum device utilization.\n\n```python\nimport diffusers\nimport torch\n\ndevice = \"cuda\"\nseed = 2024\n\ngenerator = torch.Generator(device=device).manual_seed(seed)\npipe = diffusers.MarigoldDepthPipeline.from_pretrained(\"prs-eth/marigold-depth-v1-1\").to(device)\n\nimage = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\ndepth = pipe(\n    image,\n    num_inference_steps=4,  # set according to the evaluation protocol from the paper\n    ensemble_size=10,       # set according to the evaluation protocol from the paper\n    generator=generator,\n)\n\n# evaluate metrics\n```\n\n## Using Predictive Uncertainty\n\nThe ensembling mechanism built into Marigold pipelines combines multiple predictions obtained from different random \nlatents.\nAs a side effect, it can be used to quantify epistemic (model) uncertainty; simply specify an `ensemble_size` greater \nthan or equal to 3 and set `output_uncertainty=True`.\nThe resulting uncertainty will be available in the `uncertainty` field of the output.\nIt can be visualized as follows:\n\n```python\nimport diffusers\nimport torch\n\npipe = diffusers.MarigoldDepthPipeline.from_pretrained(\n    \"prs-eth/marigold-depth-v1-1\", variant=\"fp16\", torch_dtype=torch.float16\n).to(\"cuda\")\n\nimage = diffusers.utils.load_image(\"https://marigoldmonodepth.github.io/images/einstein.jpg\")\n\ndepth = pipe(\n    image,\n    ensemble_size=10,  # any number >= 3\n    output_uncertainty=True,\n)\n\nuncertainty = pipe.image_processor.visualize_uncertainty(depth.uncertainty)\nuncertainty[0].save(\"einstein_depth_uncertainty.png\")\n```\n\n<div class=\"flex gap-4\">\n  <div style=\"flex: 1 1 33%; max-width: 33%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_depth_uncertainty.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Depth 
uncertainty\n    </figcaption>\n  </div>\n  <div style=\"flex: 1 1 33%; max-width: 33%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_normals_uncertainty.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Surface normals uncertainty\n    </figcaption>\n  </div>\n  <div style=\"flex: 1 1 33%; max-width: 33%;\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/4f83035d84a24e5ec44fdda129b1d51eba12ce04/marigold/marigold_einstein_albedo_uncertainty.png\"/>\n    <figcaption class=\"mt-1 text-center text-sm text-gray-500\">\n      Albedo uncertainty\n    </figcaption>\n  </div>\n</div>\n\nInterpreting uncertainty is straightforward: higher values (white) correspond to pixels where the model struggles to \nmake consistent predictions.\n- The depth model exhibits the most uncertainty around discontinuities, where object depth changes abruptly.\n- The surface normals model is least confident in fine-grained structures like hair and in dark regions such as the \ncollar area.\n- Albedo uncertainty is represented as an RGB image, as it captures uncertainty independently for each color channel, \nunlike depth and surface normals. It is also higher in shaded regions and at discontinuities.\n\n## Conclusion\n\nWe hope Marigold proves valuable for your downstream tasks, whether as part of a broader generative workflow or for \nperception-based applications like 3D reconstruction."
  },
  {
    "path": "docs/source/en/using-diffusers/omnigen.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n# OmniGen\n\nOmniGen is an image generation model. Unlike existing text-to-image models, OmniGen is a single model designed to handle a variety of tasks (e.g., text-to-image, image editing, controllable generation). It has the following features:\n- Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images.\n- Support for multimodal inputs. It can process any text-image mixed data as instructions for image generation, rather than relying solely on text.\n\nFor more information, please refer to the [paper](https://huggingface.co/papers/2409.11340).\nThis guide will walk you through using OmniGen for various tasks and use cases.\n\n## Load model checkpoints\n\nModel weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.\n\n```python\nimport torch\nfrom diffusers import OmniGenPipeline\n\npipe = OmniGenPipeline.from_pretrained(\"Shitao/OmniGen-v1-diffusers\", torch_dtype=torch.bfloat16)\n```\n\n## Text-to-image\n\nFor text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image. 
\nSet the `height` and `width` parameters to generate images of a different size.\n\n```python\nimport torch\nfrom diffusers import OmniGenPipeline\n\npipe = OmniGenPipeline.from_pretrained(\n    \"Shitao/OmniGen-v1-diffusers\",\n    torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nprompt = \"Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD.\"\nimage = pipe(\n    prompt=prompt,\n    height=1024,\n    width=1024,\n    guidance_scale=3,\n    generator=torch.Generator(device=\"cpu\").manual_seed(111),\n).images[0]\nimage.save(\"output.png\")\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png\" alt=\"generated image\"/>\n</div>\n\n## Image edit\n\nOmniGen supports multimodal inputs. \nWhen the input includes an image, you need to add a placeholder `<img><|image_1|></img>` in the text prompt to represent the image. \nIt is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.\n\n```python\nimport torch\nfrom diffusers import OmniGenPipeline\nfrom diffusers.utils import load_image \n\npipe = OmniGenPipeline.from_pretrained(\n    \"Shitao/OmniGen-v1-diffusers\",\n    torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nprompt=\"<img><|image_1|></img> Remove the woman's earrings. 
Replace the mug with a clear glass filled with sparkling iced cola.\"\ninput_images=[load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png\")]\nimage = pipe(\n    prompt=prompt, \n    input_images=input_images, \n    guidance_scale=2, \n    img_guidance_scale=1.6,\n    use_input_image_size_as_output=True,\n    generator=torch.Generator(device=\"cpu\").manual_seed(222)\n).images[0]\nimage.save(\"output.png\")\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">original image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">edited image</figcaption>\n  </div>\n</div>\n\nOmniGen has some interesting features, such as visual reasoning, as shown in the example below.\n\n```python\nprompt=\"If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <img><|image_1|></img>\"\ninput_images=[load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png\")]\nimage = pipe(\n    prompt=prompt, \n    input_images=input_images, \n    guidance_scale=2, \n    img_guidance_scale=1.6,\n    use_input_image_size_as_output=True,\n    generator=torch.Generator(device=\"cpu\").manual_seed(0)\n).images[0]\nimage.save(\"output.png\")\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/reasoning.png\" alt=\"generated image\"/>\n</div>\n\n## Controllable generation\n\nOmniGen can handle several classic computer vision tasks. 
As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.\n\n```python\nimport torch\nfrom diffusers import OmniGenPipeline\nfrom diffusers.utils import load_image \n\npipe = OmniGenPipeline.from_pretrained(\n    \"Shitao/OmniGen-v1-diffusers\",\n    torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nprompt=\"Detect the skeleton of human in this image: <img><|image_1|></img>\"\ninput_images=[load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png\")]\nimage1 = pipe(\n    prompt=prompt, \n    input_images=input_images, \n    guidance_scale=2, \n    img_guidance_scale=1.6,\n    use_input_image_size_as_output=True,\n    generator=torch.Generator(device=\"cpu\").manual_seed(333)\n).images[0]\nimage1.save(\"image1.png\")\n\nprompt=\"Generate a new photo using the following picture and text as conditions: <img><|image_1|></img>\\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. 
The library is quiet, with rows of shelves filled with books stretching out behind him.\"\ninput_images=[load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png\")]\nimage2 = pipe(\n    prompt=prompt, \n    input_images=input_images, \n    guidance_scale=2, \n    img_guidance_scale=1.6,\n    use_input_image_size_as_output=True,\n    generator=torch.Generator(device=\"cpu\").manual_seed(333)\n).images[0]\nimage2.save(\"image2.png\")\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">original image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">detected skeleton</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal2img.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">skeleton to image</figcaption>\n  </div>\n</div>\n\n\nOmniGen can also directly use relevant information from input images to generate new images.\n\n```python\nimport torch\nfrom diffusers import OmniGenPipeline\nfrom diffusers.utils import load_image \n\npipe = OmniGenPipeline.from_pretrained(\n    \"Shitao/OmniGen-v1-diffusers\",\n    torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nprompt=\"Following the pose of this image <img><|image_1|></img>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. 
The library is quiet, with rows of shelves filled with books stretching out behind him.\"\ninput_images=[load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png\")]\nimage = pipe(\n    prompt=prompt, \n    input_images=input_images, \n    guidance_scale=2, \n    img_guidance_scale=1.6,\n    use_input_image_size_as_output=True,\n    generator=torch.Generator(device=\"cpu\").manual_seed(0)\n).images[0]\nimage.save(\"output.png\")\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/same_pose.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n## ID and object preserving\n\nOmniGen can generate multiple images based on the people and objects in the input image and supports inputting multiple images simultaneously. \nAdditionally, OmniGen can extract desired objects from an image containing multiple objects based on instructions.\n\n```python\nimport torch\nfrom diffusers import OmniGenPipeline\nfrom diffusers.utils import load_image \n\npipe = OmniGenPipeline.from_pretrained(\n    \"Shitao/OmniGen-v1-diffusers\",\n    torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nprompt=\"A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <img><|image_1|></img>. 
The woman is the woman on the left of <img><|image_2|></img>\"\ninput_image_1 = load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png\")\ninput_image_2 = load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png\")\ninput_images=[input_image_1, input_image_2]\nimage = pipe(\n    prompt=prompt, \n    input_images=input_images, \n    height=1024,\n    width=1024,\n    guidance_scale=2.5, \n    img_guidance_scale=1.6,\n    generator=torch.Generator(device=\"cpu\").manual_seed(666)\n).images[0]\nimage.save(\"output.png\")\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">input_image_1</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">input_image_2</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/id2.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n```py\nimport torch\nfrom diffusers import OmniGenPipeline\nfrom diffusers.utils import load_image \n\npipe = OmniGenPipeline.from_pretrained(\n    \"Shitao/OmniGen-v1-diffusers\",\n    torch_dtype=torch.bfloat16\n)\npipe.to(\"cuda\")\n\nprompt=\"A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <img><|image_1|></img>. 
The long-sleeve blouse and a pleated skirt are <img><|image_2|></img>.\"\ninput_image_1 = load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg\")\ninput_image_2 = load_image(\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg\")\ninput_images=[input_image_1, input_image_2]\nimage = pipe(\n    prompt=prompt, \n    input_images=input_images, \n    height=1024,\n    width=1024,\n    guidance_scale=2.5, \n    img_guidance_scale=1.6,\n    generator=torch.Generator(device=\"cpu\").manual_seed(666)\n).images[0]\nimage.save(\"output.png\")\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">person image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">clothing image</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/tryon.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image</figcaption>\n  </div>\n</div>\n\n## Optimization when using multiple images \n\nFor the text-to-image task, OmniGen requires little memory and time (9GB of memory and 31s for a 1024x1024 image on an A800 GPU). \nHowever, when using input images, the computational cost increases. \n\nHere are some guidelines to help you reduce computational costs when using multiple images. 
The experiments are conducted on an A800 GPU with two input images.\n\nAs with other pipelines, you can reduce memory usage by offloading the model with `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload()`. \nIn OmniGen, you can also decrease computational overhead by reducing the `max_input_image_size`. \nThe memory consumption for different image sizes is shown in the table below:\n\n| Method                    | Memory Usage |\n|---------------------------|--------------|\n| max_input_image_size=1024 | 40GB         |\n| max_input_image_size=512  | 17GB         |\n| max_input_image_size=256  | 14GB         |\n\n"
  },
  {
    "path": "docs/source/en/using-diffusers/other-formats.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Model formats\n\nDiffusion models are typically stored in the Diffusers format or single-file format. Model files can be stored in various file types such as safetensors, dduf, or ckpt.\n\n> [!TIP]\n> Format refers to whether the weights are stored in a directory structure and file refers to the file type.\n\nThis guide will show you how to load pipelines and models from these formats and files.\n\n## Diffusers format\n\nThe Diffusers format stores each model (UNet, transformer, text encoder) in a separate subfolder. There are several benefits to storing models separately.\n\n- Faster overall pipeline initialization because you can load the individual model you need or load them all in parallel.\n- Reduced memory usage because you don't need to load all the pipeline components if you only need one model. [Reuse](./loading#reusing-models-in-multiple-pipelines) a model that is shared between multiple pipelines.\n- Lower storage requirements because common models shared between multiple pipelines are only downloaded once.\n- Flexibility to use new or improved models in a pipeline.\n\n## Single file format\n\nA single-file format stores *all* the model (UNet, transformer, text encoder) weights in a single file. 
Benefits of single-file formats include the following.\n\n- Greater compatibility with [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui).\n- Easier to download and share a single file.\n\nUse [`~loaders.FromSingleFileMixin.from_single_file`] to load a single file.\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\n```\n\nThe [`~loaders.FromSingleFileMixin.from_single_file`] method also supports passing new models or schedulers.\n\n```py\nimport torch\nfrom diffusers import FluxPipeline, FluxTransformer2DModel\n\ntransformer = FluxTransformer2DModel.from_single_file(\n    \"https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors\", torch_dtype=torch.bfloat16\n)\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    transformer=transformer,\n    torch_dtype=torch.bfloat16,\n    device_map=\"cuda\"\n)\n```\n\n### Configuration options\n\nDiffusers format models have a `config.json` file in their repositories with important attributes such as the number of layers and attention heads. The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically determines the appropriate config to use from `config.json`. 
This may fail in a few rare instances though, in which case, you should use the `config` argument.\n\nYou should also use the `config` argument if the models in a pipeline are different from the original implementation or if it doesn't have the necessary metadata to determine the correct config.\n\n```py\nfrom diffusers import StableDiffusionXLPipeline\n\nckpt_path = \"https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors\"\n\npipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, config=\"segmind/SSD-1B\")\n```\n\nDiffusers attempts to infer the pipeline components based on the signature types of the pipeline class when using `original_config` with `local_files_only=True`. It won't download the config files from a Hub repository to avoid backward breaking changes when you can't connect to the internet. This method isn't as reliable as providing a path to a local model with the `config` argument and may lead to errors. You should run the pipeline with `local_files_only=False` to download the config files to the local cache to avoid errors.\n\nOverride default configs by passing the arguments directly to [`~loaders.FromSingleFileMixin.from_single_file`]. 
The examples below demonstrate how to override the configs in a pipeline or model.\n\n```py\nfrom diffusers import StableDiffusionXLInstructPix2PixPipeline\n\nckpt_path = \"https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors\"\npipeline = StableDiffusionXLInstructPix2PixPipeline.from_single_file(\n    ckpt_path, config=\"diffusers/sdxl-instructpix2pix-768\", is_cosxl_edit=True\n)\n```\n\n```py\nfrom diffusers import UNet2DConditionModel\n\nckpt_path = \"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors\"\nmodel = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)\n```\n\n### Local files\n\nThe [`~loaders.FromSingleFileMixin.from_single_file`] method attempts to configure a pipeline or model by inferring the model type from the keys in the checkpoint file. For example, any single file checkpoint based on the Stable Diffusion XL base model is configured from [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).\n\nIf you're working with local files, download the config files with the [`~huggingface_hub.snapshot_download`] method and the model checkpoint with [`~huggingface_hub.hf_hub_download`]. 
These files are downloaded to your [cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache), but you can download them to a specific directory with the `local_dir` argument.\n\n```py\nfrom huggingface_hub import hf_hub_download, snapshot_download\nfrom diffusers import StableDiffusionXLPipeline\n\nmy_local_checkpoint_path = hf_hub_download(\n    repo_id=\"segmind/SSD-1B\",\n    filename=\"SSD-1B.safetensors\"\n)\n\nmy_local_config_path = snapshot_download(\n    repo_id=\"segmind/SSD-1B\",\n    allow_patterns=[\"*.json\", \"**/*.json\", \"*.txt\", \"**/*.txt\"]\n)\n\npipeline = StableDiffusionXLPipeline.from_single_file(\n    my_local_checkpoint_path, config=my_local_config_path, local_files_only=True\n)\n```\n\n### Symlink\n\nIf you're working with a file system that does not support symlinking, download the checkpoint file to a local directory first with the `local_dir` parameter. Using the `local_dir` parameter automatically disables symlinks.\n\n```py\nfrom huggingface_hub import hf_hub_download, snapshot_download\nfrom diffusers import StableDiffusionXLPipeline\n\nmy_local_checkpoint_path = hf_hub_download(\n    repo_id=\"segmind/SSD-1B\",\n    filename=\"SSD-1B.safetensors\",\n    local_dir=\"my_local_checkpoints\",\n)\nprint(\"My local checkpoint: \", my_local_checkpoint_path)\n\nmy_local_config_path = snapshot_download(\n    repo_id=\"segmind/SSD-1B\",\n    allow_patterns=[\"*.json\", \"**/*.json\", \"*.txt\", \"**/*.txt\"]\n)\nprint(\"My local config: \", my_local_config_path)\n```\n\nPass these paths to [`~loaders.FromSingleFileMixin.from_single_file`].\n\n```py\npipeline = StableDiffusionXLPipeline.from_single_file(\n    my_local_checkpoint_path, config=my_local_config_path, local_files_only=True\n)\n```\n\n## File types\n\nModels can be stored in several file types. 
Safetensors is the most common file type but you may encounter other file types on the Hub or in the diffusion community.\n\n### safetensors\n\n[Safetensors](https://hf.co/docs/safetensors) is a safe and fast file type for securely storing and loading tensors. It restricts the header size to limit certain types of attacks, supports lazy loading (useful for distributed setups), and generally loads faster.\n\nDiffusers loads safetensors files by default when they are available; the Safetensors library is a required dependency.\n\nUse [`~DiffusionPipeline.from_pretrained`] or [`~loaders.FromSingleFileMixin.from_single_file`] to load safetensors files.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\n\npipeline = DiffusionPipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors\",\n    torch_dtype=torch.float16,\n)\n```\n\nIf you're using a checkpoint trained with a Diffusers training script, metadata such as the LoRA configuration is automatically saved. When the file is loaded, the metadata is parsed to correctly configure the LoRA and avoid missing or incorrect LoRA configs. Inspect the metadata of a safetensors file by clicking on the <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/safetensors/logo.png\" alt=\"safetensors logo\" style=\"vertical-align: middle; display: inline-block; max-height: 0.8em; max-width: 0.8em; margin: 0; padding: 0; line-height: 1;\"> logo next to the file on the Hub.\n\nSave the metadata for LoRAs that aren't trained with Diffusers with either `transformer_lora_adapter_metadata` or `unet_lora_adapter_metadata` depending on your model. 
For the text encoder, use the `text_encoder_lora_adapter_metadata` and `text_encoder_2_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`]. This is only supported for safetensors files.\n\n```py\nimport torch\nfrom diffusers import FluxPipeline\n\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\npipeline.load_lora_weights(\"linoyts/yarn_art_Flux_LoRA\")\npipeline.save_lora_weights(\n    text_encoder_lora_adapter_metadata={\"r\": 8, \"lora_alpha\": 8},\n    text_encoder_2_lora_adapter_metadata={\"r\": 8, \"lora_alpha\": 8}\n)\n```\n\n### ckpt\n\nOlder model weights are commonly saved with Python's [pickle](https://docs.python.org/3/library/pickle.html) utility in a ckpt file.\n\nPickled files may be unsafe because they can be exploited to execute malicious code. It is recommended to use safetensors files or convert the weights to safetensors files.\n\nUse [`~loaders.FromSingleFileMixin.from_single_file`] to load a ckpt file.\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_single_file(\n    \"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt\"\n)\n```\n\n### dduf\n\n> [!TIP]\n> DDUF is an experimental file type and the API may change. Refer to the DDUF [docs](https://huggingface.co/docs/hub/dduf) to learn more.\n\nDDUF is a file type designed to unify different diffusion model distribution methods and weight-saving formats. It is a standardized and flexible method to package all components of a diffusion model into a single file, providing a balance between the Diffusers and single-file formats.\n\nUse the `dduf_file` argument in [`~DiffusionPipeline.from_pretrained`] to load a DDUF file. 
You can also load quantized dduf files as long as they are stored in the Diffusers format.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"DDUF/FLUX.1-dev-DDUF\",\n    dduf_file=\"FLUX.1-dev.dduf\",\n    torch_dtype=torch.bfloat16,\n    device_map=\"cuda\"\n)\n```\n\nTo save a pipeline as a dduf file, use the [`~huggingface_hub.export_folder_as_dduf`] utility.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom huggingface_hub import export_folder_as_dduf\n\npipeline = DiffusionPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16)\n\nsave_folder = \"flux-dev\"\npipeline.save_pretrained(save_folder)\nexport_folder_as_dduf(\"flux-dev.dduf\", folder_path=save_folder)\n```\n\n## Converting formats and files\n\nDiffusers provides scripts and methods to convert formats and files to enable broader support across the diffusion ecosystem.\n\nTake a look at the [diffusers/scripts](https://github.com/huggingface/diffusers/tree/main/scripts) folder to find a conversion script. Scripts with `to_diffusers` appended at the end convert a model to the Diffusers format. Each script has a specific set of arguments for configuring the conversion. Make sure you check what arguments are available.\n\nThe example below converts a model stored in Diffusers format to a single-file format. Provide the path to the model to convert and where to save the converted model. You can optionally specify what file type and data type to save the model as.\n\n```bash\npython convert_diffusers_to_original_sdxl.py --model_path path/to/model/to/convert --checkpoint_path path/to/save/model/to --use_safetensors\n```\n\nThe [`~DiffusionPipeline.save_pretrained`] method also saves a model in Diffusers format and takes care of creating subfolders for each model. 
It saves the files as safetensors files by default.\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors\",\n)\n# placeholder output directory; replace with your own path\npipeline.save_pretrained(\"path/to/save/model/to\")\n```\n\nFinally, you can use a Space like [SD To Diffusers](https://hf.co/spaces/diffusers/sd-to-diffusers) or [SD-XL To Diffusers](https://hf.co/spaces/diffusers/sdxl-to-diffusers) to convert models to the Diffusers format. It'll open a PR on your model repository with the converted files. This is the easiest way to convert a model, but it may fail for more complicated models. Using a conversion script is more reliable.\n\n## Resources\n\n- Learn more about the design decisions and why safetensors files are preferred for saving and loading model weights in the [Safetensors audited as really safe and becoming the default](https://blog.eleuther.ai/safetensors-security-audit/) blog post.\n\n"
  },
  {
    "path": "docs/source/en/using-diffusers/pag.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Perturbed-Attention Guidance\n\n[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules. PAG is designed to progressively enhance the structure of synthesized samples throughout the denoising process by considering the self-attention mechanisms' ability to capture structural information. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, and guiding the denoising process away from these degraded samples.\n\nThis guide will show you how to use PAG for various tasks and use cases.\n\n\n## General tasks\n\nYou can apply PAG to the [`StableDiffusionXLPipeline`] for tasks such as text-to-image, image-to-image, and inpainting. To enable PAG for a specific task, load the pipeline using the [AutoPipeline](../api/pipelines/auto_pipeline) API with the `enable_pag=True` flag and the `pag_applied_layers` argument.\n\n> [!TIP]\n> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines and [`PixArtSigmaPAGPipeline`]. 
But feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!\n\n<hfoptions id=\"tasks\">\n<hfoption id=\"Text-to-image\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nfrom diffusers.utils import load_image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    enable_pag=True,\n    pag_applied_layers=[\"mid\"],\n    torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n```\n\n> [!TIP]\n> The `pag_applied_layers` argument allows you to specify which layers PAG is applied to. Additionally, you can use the `set_pag_applied_layers` method to update these layers after the pipeline has been created. Check out the [pag_applied_layers](#pag_applied_layers) section to learn more about applying PAG to other layers.\n\nIf you already have a pipeline created and loaded, you can enable PAG on it using the `from_pipe` API with the `enable_pag` flag. Internally, a PAG pipeline is created based on the pipeline and task you specified. In the example below, since we used `AutoPipelineForText2Image` and passed a `StableDiffusionXLPipeline`, a `StableDiffusionXLPAGPipeline` is created accordingly. Note that this does not require additional memory, and you will have both `StableDiffusionXLPipeline` and `StableDiffusionXLPAGPipeline` loaded and ready to use. You can read more about the `from_pipe` API and how to reuse pipelines in Diffusers [here](https://huggingface.co/docs/diffusers/using-diffusers/loading#reuse-a-pipeline).\n\n```py\npipeline_sdxl = AutoPipelineForText2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16)\npipeline = AutoPipelineForText2Image.from_pipe(pipeline_sdxl, enable_pag=True)\n```\n\nTo generate an image, you will also need to pass a `pag_scale`. 
When `pag_scale` increases, images gain more semantically coherent structures and exhibit fewer artifacts. However, an overly large guidance scale can lead to smoother textures and slight saturation in the images, similar to CFG. `pag_scale=3.0` is used in the official demo and works well in most of the use cases, but feel free to experiment and select the appropriate value according to your needs! PAG is disabled when `pag_scale=0`.\n\n```py\nprompt = \"an insect robot preparing a delicious meal, anime style\"\n\nfor pag_scale in [0.0, 3.0]:\n    generator = torch.Generator(device=\"cpu\").manual_seed(0)\n    images = pipeline(\n        prompt=prompt,\n        num_inference_steps=25,\n        guidance_scale=7.0,\n        generator=generator,\n        pag_scale=pag_scale,\n    ).images\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_cfg_7.0_mid.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image without PAG</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_mid.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image with PAG</figcaption>\n  </div>\n</div>\n\n</hfoption>\n<hfoption id=\"Image-to-image\">\n\nYou can use PAG with image-to-image pipelines.\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import load_image\nimport torch\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    enable_pag=True,\n    pag_applied_layers=[\"mid\"],\n    torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n```\n\nIf you already have an image-to-image pipeline and would like to enable PAG on it, you can run this:\n\n```py\npipeline_t2i = 
AutoPipelineForImage2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16)\npipeline = AutoPipelineForImage2Image.from_pipe(pipeline_t2i, enable_pag=True)\n```\n\nIt is also very easy to switch directly from a text-to-image pipeline to a PAG-enabled image-to-image pipeline:\n\n```py\npipeline_t2i = AutoPipelineForText2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16)\npipeline = AutoPipelineForImage2Image.from_pipe(pipeline_t2i, enable_pag=True)\n```\n\nIf you have a PAG-enabled text-to-image pipeline, you can switch directly to an image-to-image pipeline with PAG still enabled:\n\n```py\npipeline_pag = AutoPipelineForText2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", enable_pag=True, torch_dtype=torch.float16)\npipeline = AutoPipelineForImage2Image.from_pipe(pipeline_pag)\n```\n\nNow let's generate an image!\n\n```py\npag_scale = 4.0\nguidance_scale = 7.0\n\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\"\ninit_image = load_image(url)\nprompt = \"a dog catching a frisbee in the jungle\"\n\ngenerator = torch.Generator(device=\"cpu\").manual_seed(0)\nimage = pipeline(\n    prompt,\n    image=init_image,\n    strength=0.8,\n    guidance_scale=guidance_scale,\n    pag_scale=pag_scale,\n    generator=generator).images[0]\n```\n\n</hfoption>\n<hfoption id=\"Inpainting\">\n\n```py\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image\nimport torch\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    enable_pag=True,\n    torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n```\n\nYou can enable PAG on an existing inpainting pipeline like this\n\n```py\npipeline_inpaint = AutoPipelineForInpainting.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16)\npipeline 
= AutoPipelineForInpainting.from_pipe(pipeline_inpaint, enable_pag=True)\n```\n\nThis still works when your pipeline has a different task:\n\n```py\npipeline_t2i = AutoPipelineForText2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16)\npipeline = AutoPipelineForInpainting.from_pipe(pipeline_t2i, enable_pag=True)\n```\n\nLet's generate an image!\n\n```py\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\ninit_image = load_image(img_url).convert(\"RGB\")\nmask_image = load_image(mask_url).convert(\"RGB\")\n\nprompt = \"A majestic tiger sitting on a bench\"\n\npag_scale = 3.0\nguidance_scale = 7.5\n\ngenerator = torch.Generator(device=\"cpu\").manual_seed(1)\nimages = pipeline(\n    prompt=prompt,\n    image=init_image,\n    mask_image=mask_image,\n    strength=0.8,\n    num_inference_steps=50,\n    guidance_scale=guidance_scale,\n    generator=generator,\n    pag_scale=pag_scale,\n).images\nimages[0]\n```\n</hfoption>\n</hfoptions>\n\n## PAG with ControlNet\n\nTo use PAG with ControlNet, first create a `controlnet`. 
Then, pass the `controlnet` and other PAG arguments to the `from_pretrained` method of the AutoPipeline for the specified task.\n\n```py\nfrom diffusers import AutoPipelineForText2Image, ControlNetModel\nimport torch\n\ncontrolnet = ControlNetModel.from_pretrained(\n    \"diffusers/controlnet-canny-sdxl-1.0\", torch_dtype=torch.float16\n)\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    controlnet=controlnet,\n    enable_pag=True,\n    pag_applied_layers=\"mid\",\n    torch_dtype=torch.float16\n)\npipeline.enable_model_cpu_offload()\n```\n\n> [!TIP]\n> If you already have a controlnet pipeline and want to enable PAG, you can use the `from_pipe` API: `AutoPipelineForText2Image.from_pipe(pipeline_controlnet, enable_pag=True)`\n\nYou can use the pipeline in the same way you normally use ControlNet pipelines, with the added option to specify a `pag_scale` parameter. Note that PAG works well for unconditional generation. In this example, we will generate an image without a prompt.\n\n```py\nfrom diffusers.utils import load_image\ncanny_image = load_image(\n    \"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_control_input.png\"\n)\n\n# example value (an assumption); tune it for your ControlNet\ncontrolnet_conditioning_scale = 0.5\n\nfor pag_scale in [0.0, 3.0]:\n    generator = torch.Generator(device=\"cpu\").manual_seed(1)\n    images = pipeline(\n        prompt=\"\",\n        controlnet_conditioning_scale=controlnet_conditioning_scale,\n        image=canny_image,\n        num_inference_steps=50,\n        guidance_scale=0,\n        generator=generator,\n        pag_scale=pag_scale,\n    ).images\n    images[0]\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_controlnet.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image without PAG</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_controlnet.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image with PAG</figcaption>\n  </div>\n</div>\n\n## PAG with IP-Adapter\n\n[IP-Adapter](https://hf.co/papers/2308.06721) is a popular model that can be plugged into diffusion models to enable image prompting without any changes to the underlying model. You can enable PAG on a pipeline with IP-Adapter loaded.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nfrom diffusers.utils import load_image\nfrom transformers import CLIPVisionModelWithProjection\nimport torch\n\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(\n    \"h94/IP-Adapter\",\n    subfolder=\"models/image_encoder\",\n    torch_dtype=torch.float16\n)\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    image_encoder=image_encoder,\n    enable_pag=True,\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipeline.load_ip_adapter(\"h94/IP-Adapter\", subfolder=\"sdxl_models\", weight_name=\"ip-adapter-plus_sdxl_vit-h.bin\")\n\npag_scale = 5.0\nip_adapter_scale = 0.8\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png\")\n\npipeline.set_ip_adapter_scale(ip_adapter_scale)\ngenerator = torch.Generator(device=\"cpu\").manual_seed(0)\nimages = pipeline(\n    prompt=\"a polar bear sitting in a chair drinking a milkshake\",\n    ip_adapter_image=image,\n    negative_prompt=\"deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality\",\n    num_inference_steps=25,\n    guidance_scale=3.0,\n    generator=generator,\n    pag_scale=pag_scale,\n).images\nimages[0]\n\n```\n\nPAG reduces artifacts and improves the overall composition.\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_ipa_0.8.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image without PAG</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_5.0_ipa_0.8.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">generated image with PAG</figcaption>\n  </div>\n</div>\n\n\n## Configure parameters\n\n### pag_applied_layers\n\nThe `pag_applied_layers` argument allows you to specify which layers PAG is applied to. By default, it applies only to the mid blocks. Changing this setting will significantly impact the output. You can use the `set_pag_applied_layers` method to adjust the PAG layers after the pipeline is created, helping you find the optimal layers for your model.\n\nAs an example, here are the images generated with `pag_layers = [\"down.block_2\"]` and `pag_layers = [\"down.block_2\", \"up.block_1.attentions_0\"]`.\n\n```py\nprompt = \"an insect robot preparing a delicious meal, anime style\"\n# swap in [\"down.block_2\"] to reproduce the second image\npag_layers = [\"down.block_2\", \"up.block_1.attentions_0\"]\n# scales inferred from the image file names below (pag_3.0_cfg_7.0)\nguidance_scale = 7.0\npag_scale = 3.0\n\npipeline.set_pag_applied_layers(pag_layers)\ngenerator = torch.Generator(device=\"cpu\").manual_seed(0)\nimages = pipeline(\n    prompt=prompt,\n    num_inference_steps=25,\n    guidance_scale=guidance_scale,\n    generator=generator,\n    pag_scale=pag_scale,\n).images\nimages[0]\n```\n\n<div class=\"flex flex-row gap-4\">\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_down2_up1a0.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">down.block_2 + up.block_1.attentions_0</figcaption>\n  </div>\n  <div class=\"flex-1\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_down2.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm 
text-gray-500\">down.block_2</figcaption>\n  </div>\n</div>\n"
  },
  {
    "path": "docs/source/en/using-diffusers/push_to_hub.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Sharing pipelines and models\n\nShare your pipeline or models and schedulers on the Hub with the [`~diffusers.utils.PushToHubMixin`] class. This class:\n\n1. creates a repository on the Hub\n2. saves your model, scheduler, or pipeline files so they can be reloaded later\n3. uploads folder containing these files to the Hub\n\nThis guide will show you how to upload your files to the Hub with the [`~diffusers.utils.PushToHubMixin`] class.\n\nLog in to your Hugging Face account with your access [token](https://huggingface.co/settings/tokens).\n\n<hfoptions id=\"login\">\n<hfoption id=\"notebook\">\n\n```py\nfrom huggingface_hub import notebook_login\n\nnotebook_login()\n```\n\n</hfoption>\n<hfoption id=\"hf CLI\">\n\n```bash\nhf auth login\n```\n\n</hfoption>\n</hfoptions>\n\n## Models\n\nTo push a model to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the model.\n\n```py\nfrom diffusers import ControlNetModel\n\ncontrolnet = ControlNetModel(\n    block_out_channels=(32, 64),\n    layers_per_block=2,\n    in_channels=4,\n    down_block_types=(\"DownBlock2D\", \"CrossAttnDownBlock2D\"),\n    cross_attention_dim=32,\n    conditioning_embedding_out_channels=(16, 32),\n)\ncontrolnet.push_to_hub(\"my-controlnet-model\")\n```\n\nThe [`~diffusers.utils.PushToHubMixin.push_to_hub`] method saves the 
model's `config.json` file and the weights are automatically saved as safetensors files.\n\nLoad the model again with [`~ModelMixin.from_pretrained`].\n\n```py\nmodel = ControlNetModel.from_pretrained(\"your-namespace/my-controlnet-model\")\n```\n\n## Scheduler\n\nTo push a scheduler to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the scheduler.\n\n```py\nfrom diffusers import DDIMScheduler\n\nscheduler = DDIMScheduler(\n    beta_start=0.00085,\n    beta_end=0.012,\n    beta_schedule=\"scaled_linear\",\n    clip_sample=False,\n    set_alpha_to_one=False,\n)\nscheduler.push_to_hub(\"my-controlnet-scheduler\")\n```\n\nThe [`~diffusers.utils.PushToHubMixin.push_to_hub`] method saves the scheduler's `scheduler_config.json` file to the specified repository.\n\nLoad the scheduler again with [`~SchedulerMixin.from_pretrained`].\n\n```py\nscheduler = DDIMScheduler.from_pretrained(\"your-namespace/my-controlnet-scheduler\")\n```\n\n## Pipeline\n\nTo push a pipeline to the Hub, initialize the pipeline components with your desired parameters.\n\n```py\nfrom diffusers import (\n    UNet2DConditionModel,\n    AutoencoderKL,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\nfrom transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer\n\nunet = UNet2DConditionModel(\n    block_out_channels=(32, 64),\n    layers_per_block=2,\n    sample_size=32,\n    in_channels=4,\n    out_channels=4,\n    down_block_types=(\"DownBlock2D\", \"CrossAttnDownBlock2D\"),\n    up_block_types=(\"CrossAttnUpBlock2D\", \"UpBlock2D\"),\n    cross_attention_dim=32,\n)\n\nscheduler = DDIMScheduler(\n    beta_start=0.00085,\n    beta_end=0.012,\n    beta_schedule=\"scaled_linear\",\n    clip_sample=False,\n    set_alpha_to_one=False,\n)\n\nvae = AutoencoderKL(\n    block_out_channels=[32, 64],\n    in_channels=3,\n    out_channels=3,\n    down_block_types=[\"DownEncoderBlock2D\", \"DownEncoderBlock2D\"],\n    
up_block_types=[\"UpDecoderBlock2D\", \"UpDecoderBlock2D\"],\n    latent_channels=4,\n)\n\ntext_encoder_config = CLIPTextConfig(\n    bos_token_id=0,\n    eos_token_id=2,\n    hidden_size=32,\n    intermediate_size=37,\n    layer_norm_eps=1e-05,\n    num_attention_heads=4,\n    num_hidden_layers=5,\n    pad_token_id=1,\n    vocab_size=1000,\n)\ntext_encoder = CLIPTextModel(text_encoder_config)\ntokenizer = CLIPTokenizer.from_pretrained(\"hf-internal-testing/tiny-random-clip\")\n```\n\nPass all components to the pipeline and call [`~diffusers.utils.PushToHubMixin.push_to_hub`].\n\n```py\ncomponents = {\n    \"unet\": unet,\n    \"scheduler\": scheduler,\n    \"vae\": vae,\n    \"text_encoder\": text_encoder,\n    \"tokenizer\": tokenizer,\n    \"safety_checker\": None,\n    \"feature_extractor\": None,\n}\n\npipeline = StableDiffusionPipeline(**components)\npipeline.push_to_hub(\"my-pipeline\")\n```\n\nThe [`~diffusers.utils.PushToHubMixin.push_to_hub`] method saves each component to a subfolder in the repository. Load the pipeline again with [`~DiffusionPipeline.from_pretrained`].\n\n```py\npipeline = StableDiffusionPipeline.from_pretrained(\"your-namespace/my-pipeline\")\n```\n\n## Privacy\n\nSet `private=True` in [`~diffusers.utils.PushToHubMixin.push_to_hub`] to keep your model, scheduler, or pipeline files private.\n\n```py\ncontrolnet.push_to_hub(\"my-controlnet-model-private\", private=True)\n```\n\nPrivate repositories are only visible to you. Other users won't be able to clone the repository and it won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository."
  },
  {
    "path": "docs/source/en/using-diffusers/reusing_seeds.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Reproducibility\n\nDiffusion is a random process that generates a different output every time. For certain situations like testing and replicating results, you want to generate the same result each time, across releases and platforms within a certain tolerance range.\n\nThis guide will show you how to control sources of randomness and enable deterministic algorithms.\n\n## Generator\n\nPipelines rely on [torch.randn](https://pytorch.org/docs/stable/generated/torch.randn.html), which uses a different random seed each time, to create the initial noisy tensors. To generate the same output on a CPU or GPU, use a [Generator](https://docs.pytorch.org/docs/stable/generated/torch.Generator.html) to manage how random values are generated.\n\n> [!TIP]\n> If reproducibility is important to your use case, we recommend always using a CPU `Generator`. The performance loss is often negligible and you'll generate more similar values.\n\n<hfoptions id=\"generator\">\n<hfoption id=\"GPU\">\n\nThe GPU uses a different random number generator than the CPU. Diffusers solves this issue with the [`~utils.torch_utils.randn_tensor`] function to create the random tensor on a CPU and then moving it to the GPU. 
This function is used everywhere inside the pipeline and you don't need to explicitly call it.\n\nUse [manual_seed](https://docs.pytorch.org/docs/stable/generated/torch.manual_seed.html) as shown below to set a seed.\n\n```py\nimport torch\nimport numpy as np\nfrom diffusers import DDIMPipeline\n\nddim = DDIMPipeline.from_pretrained(\"google/ddpm-cifar10-32\", device_map=\"cuda\")\ngenerator = torch.manual_seed(0)\nimage = ddim(num_inference_steps=2, output_type=\"np\", generator=generator).images\nprint(np.abs(image).sum())\n```\n\n</hfoption>\n<hfoption id=\"CPU\">\n\nSet `device=\"cpu\"` in the `Generator` and use [manual_seed](https://docs.pytorch.org/docs/stable/generated/torch.manual_seed.html) to set a seed for generating random numbers.\n\n```py\nimport torch\nimport numpy as np\nfrom diffusers import DDIMPipeline\n\nddim = DDIMPipeline.from_pretrained(\"google/ddpm-cifar10-32\")\ngenerator = torch.Generator(device=\"cpu\").manual_seed(0)\nimage = ddim(num_inference_steps=2, output_type=\"np\", generator=generator).images\nprint(np.abs(image).sum())\n```\n\n</hfoption>\n</hfoptions>\n\nThe `Generator` object should be passed to the pipeline instead of an integer seed. `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because its *state* has changed. To get the same output on every call, create a new `Generator` with the same seed each time.\n\n```diff\ngenerator = torch.manual_seed(0)\n\nfor _ in range(5):\n-    image = pipeline(prompt, generator=generator)\n+    image = pipeline(prompt, generator=torch.manual_seed(0))\n```\n\n## Deterministic algorithms\n\nPyTorch supports [deterministic algorithms](https://docs.pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms) - where available - for certain operations so they produce the same results. 
Deterministic algorithms may be slower and decrease performance.\n\nUse Diffusers' [enable_full_determinism](https://github.com/huggingface/diffusers/blob/142f353e1c638ff1d20bd798402b68f72c1ebbdd/src/diffusers/utils/testing_utils.py#L861) function to enable deterministic algorithms.\n\n```py\nimport torch\n# import path follows the linked source file (src/diffusers/utils/testing_utils.py)\nfrom diffusers.utils.testing_utils import enable_full_determinism\n\nenable_full_determinism()\n```\n\nUnder the hood, `enable_full_determinism` works by:\n\n- Setting the environment variable [CUBLAS_WORKSPACE_CONFIG](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime. Non-deterministic behavior occurs when operations are used in more than one CUDA stream.\n- Disabling benchmarking to find the fastest convolution operation by setting `torch.backends.cudnn.benchmark=False`. Non-deterministic behavior occurs because the benchmark may select different algorithms each time depending on hardware or benchmarking noise.\n- Disabling TensorFloat32 (TF32) operations in favor of more precise and consistent full-precision operations.\n\n\n## Resources\n\nWe strongly recommend reading PyTorch's developer notes about [Reproducibility](https://docs.pytorch.org/docs/stable/notes/randomness.html). You can try to limit randomness, but it is not *guaranteed* even with an identical seed."
  },
  {
    "path": "docs/source/en/using-diffusers/schedulers.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Schedulers\n\nA scheduler is an algorithm that provides instructions to the denoising process such as how much noise to remove at a certain step. It takes the model prediction from step *t* and applies an update for how to compute the next sample at step *t-1*. Different schedulers produce different results; some are faster while others are more accurate.\n\nDiffusers supports many schedulers and allows you to modify their timestep schedules, timestep spacing, and more, to generate high-quality images in fewer steps.\n\nThis guide will show you how to load and customize schedulers.\n\n## Loading schedulers\n\nSchedulers don't have any parameters and are defined in a configuration file. Access the `.scheduler` attribute of a pipeline to view the configuration.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, device_map=\"cuda\"\n)\npipeline.scheduler\n```\n\nLoad a different scheduler with [`~SchedulerMixin.from_pretrained`] and specify the `subfolder` argument to load the configuration file into the correct subfolder of the pipeline repository. 
Pass the new scheduler to the existing pipeline.\n\n```py\nfrom diffusers import DPMSolverMultistepScheduler\n\ndpm = DPMSolverMultistepScheduler.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"scheduler\"\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    scheduler=dpm,\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\npipeline.scheduler\n```\n\n## Timestep schedules\n\nA timestep or noise schedule determines how noise is distributed over the denoising process. The schedule can be linear or more concentrated toward the beginning or end. It is a precomputed sequence of noise levels generated from the scheduler's default configuration, but it can be customized to use other schedules.\n\n> [!TIP]\n> The `timesteps` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!\n\nThe example below uses the [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) schedule which can generate a high-quality image in 10 steps, significantly speeding up generation and reducing computation time.\n\nImport the schedule and pass it to the `timesteps` argument in the pipeline.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\nfrom diffusers.schedulers import AysSchedules\n\nsampling_schedule = AysSchedules[\"StableDiffusionXLTimesteps\"]\nprint(sampling_schedule)\n# [999, 845, 730, 587, 443, 310, 193, 116, 53, 13]\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"SG161222/RealVisXL_V4.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(\n  pipeline.scheduler.config, algorithm_type=\"sde-dpmsolver++\"\n)\n\nprompt = \"A cinematic shot of a cute little rabbit wearing a jacket and doing 
a thumbs up\"\nimage = pipeline(\n    prompt=prompt,\n    negative_prompt=\"\",\n    timesteps=sampling_schedule,\n).images[0]\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ays.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">AYS timestep schedule 10 steps</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/10.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Linearly-spaced timestep schedule 10 steps</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/25.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Linearly-spaced timestep schedule 25 steps</figcaption>\n  </div>\n</div>\n\n### Rescaling schedules\n\nDenoising should begin with pure noise, and the signal-to-noise ratio (SNR) should be zero. However, some models don't actually start from pure noise, which makes it difficult to generate images at brightness extremes.\n\n> [!TIP]\n> Train your own model with `v_prediction` by adding the `--prediction_type=\"v_prediction\"` flag to your training script. You can also [search](https://huggingface.co/search/full-text?q=v_prediction&type=model) for existing models trained with `v_prediction`.\n\nTo fix this, a model must be trained with `v_prediction`. 
If a model is trained with `v_prediction`, then enable the following arguments in the scheduler.\n\n- Set `rescale_betas_zero_snr=True` to rescale the noise schedule so the very last timestep has exactly zero SNR\n- Set `timestep_spacing=\"trailing\"` to force sampling from the last timestep with pure noise\n\n```py\nfrom diffusers import DiffusionPipeline, DDIMScheduler\n\npipeline = DiffusionPipeline.from_pretrained(\"ptx0/pseudo-journey-v2\", device_map=\"cuda\")\n\npipeline.scheduler = DDIMScheduler.from_config(\n    pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing=\"trailing\"\n)\n```\n\nSet `guidance_rescale` in the pipeline to avoid overexposed images. A lower value increases brightness, but some details may appear washed out.\n\n```py\nprompt = \"\"\"\ncinematic photo of a snowy mountain at night with the northern lights aurora borealis\noverhead, 35mm photograph, film, professional, 4k, highly detailed\n\"\"\"\nimage = pipeline(prompt, guidance_rescale=0.7).images[0]\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/no-zero-snr.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">default Stable Diffusion v2-1 image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/zero-snr.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">image with zero SNR and trailing timestep spacing enabled</figcaption>\n  </div>\n</div>\n\n## Timestep spacing\n\nTimestep spacing refers to the specific steps *t* to sample from the schedule. 
Diffusers provides three spacing types as shown below.\n\n| spacing strategy | spacing calculation | example timesteps |\n|---|---|---|\n| `leading` | evenly spaced steps | `[900, 800, 700, ..., 100, 0]` |\n| `linspace` | include first and last steps and evenly divide remaining intermediate steps | `[1000, 888.89, 777.78, ..., 111.11, 0]` |\n| `trailing` | include last step and evenly divide remaining intermediate steps beginning from the end | `[999, 899, 799, 699, 599, 499, 399, 299, 199, 99]` |\n\nPass the spacing strategy to the `timestep_spacing` argument in the scheduler.\n\n> [!TIP]\n> The `trailing` strategy typically produces higher quality images with more details with fewer steps, but the difference in quality is not as obvious for more standard step values.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"SG161222/RealVisXL_V4.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(\n  pipeline.scheduler.config, timestep_spacing=\"trailing\"\n)\n\nprompt = \"A cinematic shot of a cute little black cat sitting on a pumpkin at night\"\nimage = pipeline(\n    prompt=prompt,\n    negative_prompt=\"\",\n    num_inference_steps=5,\n).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/trailing_spacing.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">trailing spacing after 5 steps</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/leading_spacing.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">leading spacing after 5 steps</figcaption>\n  </div>\n</div>\n\n## Sigmas\n\nSigmas is a measure of how noisy a sample is at a certain step as defined by 
the schedule. When using custom `sigmas`, the `timesteps` are calculated from these values instead of the default scheduler configuration.\n\n> [!TIP]\n> The `sigmas` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!\n\nPass the custom sigmas to the `sigmas` argument in the pipeline. The example below uses the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) from the 10-step AYS schedule.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"SG161222/RealVisXL_V4.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(\n  pipeline.scheduler.config, algorithm_type=\"sde-dpmsolver++\"\n)\n\nsigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]\nprompt = \"A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up\"\nimage = pipeline(\n    prompt=prompt,\n    negative_prompt=\"\",\n    sigmas=sigmas,\n).images[0]\n```\n\n### Karras sigmas\n\n[Karras sigmas](https://huggingface.co/papers/2206.00364) resamples the noise schedule for more efficient sampling by clustering sigmas more densely in the middle of the sequence where structure reconstruction is critical, while using fewer sigmas at the beginning and end where noise changes have less impact. 
This can increase the level of details in a generated image.\n\nSet `use_karras_sigmas=True` in the scheduler to enable it.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"SG161222/RealVisXL_V4.0\",\n    torch_dtype=torch.float16,\n    device_map=\"cuda\"\n)\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(\n  pipeline.scheduler.config,\n  algorithm_type=\"sde-dpmsolver++\",\n  use_karras_sigmas=True,\n)\n\nprompt = \"A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up\"\nimage = pipeline(\n    prompt=prompt,\n    negative_prompt=\"\",\n    sigmas=sigmas,\n).images[0]\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_true.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Karras sigmas enabled</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/stevhliu/testing-images/resolve/main/karras_sigmas_false.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">Karras sigmas disabled</figcaption>\n  </div>\n</div>\n\nRefer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas. It should only be used for models trained with Karras sigmas.\n\n## Choosing a scheduler\n\nIt's important to try different schedulers to find the best one for your use case. 
Here are a few recommendations to help you get started.\n\n- DPM++ 2M SDE Karras is generally a good all-purpose option.\n- [`TCDScheduler`] works well for distilled models.\n- [`FlowMatchEulerDiscreteScheduler`] and [`FlowMatchHeunDiscreteScheduler`] for FlowMatch models.\n- [`EulerDiscreteScheduler`] or [`EulerAncestralDiscreteScheduler`] for generating anime style images.\n- DPM++ 2M paired with [`LCMScheduler`] on SDXL for generating realistic images.\n\n## Resources\n\n- Read the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more details about rescaling the noise schedule to enforce zero SNR."
  },
  {
    "path": "docs/source/en/using-diffusers/sdxl.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Diffusion XL\n\n[[open-in-colab]]\n\n[Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (SDXL) is a powerful text-to-image generation model that iterates on the previous Stable Diffusion models in three key ways:\n\n1. the UNet is 3x larger and SDXL combines a second text encoder (OpenCLIP ViT-bigG/14) with the original text encoder to significantly increase the number of parameters\n2. introduces size and crop-conditioning to preserve training data from being discarded and gain more control over how a generated image should be cropped\n3. introduces a two-stage model process; the *base* model (can also be run as a standalone model) generates an image as an input to the *refiner* model which adds additional high-quality details\n\nThis guide will show you how to use SDXL for text-to-image, image-to-image, and inpainting.\n\nBefore you begin, make sure you have the following libraries installed:\n\n```py\n# uncomment to install the necessary libraries in Colab\n#!pip install -q diffusers transformers accelerate invisible-watermark>=0.2.0\n```\n\n> [!WARNING]\n> We recommend installing the [invisible-watermark](https://pypi.org/project/invisible-watermark/) library to help identify images that are generated. If the invisible-watermark library is installed, it is used by default. 
To disable the watermarker:\n>\n> ```py\n> pipeline = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)\n> ```\n\n## Load model checkpoints\n\nModel weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~StableDiffusionXLPipeline.from_pretrained`] method:\n\n```py\nfrom diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline\nimport torch\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\nrefiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\", torch_dtype=torch.float16, use_safetensors=True, variant=\"fp16\"\n).to(\"cuda\")\n```\n\nYou can also use the [`~StableDiffusionXLPipeline.from_single_file`] method to load a model checkpoint stored in a single file format (`.ckpt` or `.safetensors`) from the Hub or locally:\n\n```py\nfrom diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline\nimport torch\n\npipeline = StableDiffusionXLPipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\nrefiner = StableDiffusionXLImg2ImgPipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors\", torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\n## Text-to-image\n\nFor text-to-image, pass a text prompt. By default, SDXL generates a 1024x1024 image for the best results. 
You can try setting the `height` and `width` parameters to 768x768 or 512x512, but anything below 512x512 is not likely to work.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline_text2image = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\nimage = pipeline_text2image(prompt=prompt).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\" alt=\"generated image of an astronaut in a jungle\"/>\n</div>\n\n## Image-to-image\n\nFor image-to-image, SDXL works especially well with image sizes between 768x768 and 1024x1024. Pass an initial image, and a text prompt to condition the image with:\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import load_image, make_image_grid\n\n# use from_pipe to avoid consuming additional memory when loading a checkpoint\npipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to(\"cuda\")\n\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\"\ninit_image = load_image(url)\nprompt = \"a dog catching a frisbee in the jungle\"\nimage = pipeline(prompt, image=init_image, strength=0.8, guidance_scale=10.5).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-img2img.png\" alt=\"generated image of a dog catching a frisbee in a jungle\"/>\n</div>\n\n## Inpainting\n\nFor inpainting, you'll need the original image and a mask of what you want to replace in the original image. 
Create a prompt to describe what you want to replace the masked area with.\n\n```py\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\n# use from_pipe to avoid consuming additional memory when loading a checkpoint\npipeline = AutoPipelineForInpainting.from_pipe(pipeline_text2image).to(\"cuda\")\n\nimg_url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\"\nmask_url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png\"\n\ninit_image = load_image(img_url)\nmask_image = load_image(mask_url)\n\nprompt = \"A deep sea diver floating\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, guidance_scale=12.5).images[0]\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint.png\" alt=\"generated image of a deep sea diver in a jungle\"/>\n</div>\n\n## Refine image quality\n\nSDXL includes a [refiner model](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) specialized in denoising low-noise stage images to generate higher-quality images from the base model. There are two ways to use the refiner:\n\n1. use the base and refiner models together to produce a refined image\n2. use the base model to produce an image, and subsequently use the refiner model to add more details to the image (this is how SDXL was originally trained)\n\n### Base + refiner model\n\nWhen you use the base and refiner model together to generate an image, this is known as an [*ensemble of expert denoisers*](https://research.nvidia.com/labs/dir/eDiff-I/). 
The ensemble of expert denoisers approach requires fewer overall denoising steps versus passing the base model's output to the refiner model, so it should be significantly faster to run. However, you won't be able to inspect the base model's output because it still contains a large amount of noise.\n\nAs an ensemble of expert denoisers, the base model serves as the expert during the high-noise diffusion stage and the refiner model serves as the expert during the low-noise diffusion stage. Load the base and refiner model:\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nbase = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\nrefiner = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\",\n    text_encoder_2=base.text_encoder_2,\n    vae=base.vae,\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    variant=\"fp16\",\n).to(\"cuda\")\n```\n\nTo use this approach, you need to define the number of timesteps for each model to run through their respective stages. For the base model, this is controlled by the [`denoising_end`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.denoising_end) parameter and for the refiner model, it is controlled by the [`denoising_start`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__.denoising_start) parameter.\n\n> [!TIP]\n> The `denoising_end` and `denoising_start` parameters should be a float between 0 and 1. These parameters are represented as a proportion of discrete timesteps as defined by the scheduler. 
If you're also using the `strength` parameter, it'll be ignored because the number of denoising steps is determined by the discrete timesteps the model is trained on and the declared fractional cutoff.\n\nLet's set `denoising_end=0.8` so the base model performs the first 80% of denoising the **high-noise** timesteps and set `denoising_start=0.8` so the refiner model performs the last 20% of denoising the **low-noise** timesteps. The base model output should be in **latent** space instead of a PIL image.\n\n```py\nprompt = \"A majestic lion jumping from a big stone at night\"\n\nimage = base(\n    prompt=prompt,\n    num_inference_steps=40,\n    denoising_end=0.8,\n    output_type=\"latent\",\n).images\nimage = refiner(\n    prompt=prompt,\n    num_inference_steps=40,\n    denoising_start=0.8,\n    image=image,\n).images[0]\nimage\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_base.png\" alt=\"generated image of a lion on a rock at night\" />\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">default base model</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_refined.png\" alt=\"generated image of a lion on a rock at night in higher quality\" />\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">ensemble of expert denoisers</figcaption>\n  </div>\n</div>\n\nThe refiner model can also be used for inpainting in the [`StableDiffusionXLInpaintPipeline`]:\n\n```py\nfrom diffusers import StableDiffusionXLInpaintPipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\n\nbase = StableDiffusionXLInpaintPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\nrefiner = 
StableDiffusionXLInpaintPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\",\n    text_encoder_2=base.text_encoder_2,\n    vae=base.vae,\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    variant=\"fp16\",\n).to(\"cuda\")\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = load_image(img_url)\nmask_image = load_image(mask_url)\n\nprompt = \"A majestic tiger sitting on a bench\"\nnum_inference_steps = 75\nhigh_noise_frac = 0.7\n\nimage = base(\n    prompt=prompt,\n    image=init_image,\n    mask_image=mask_image,\n    num_inference_steps=num_inference_steps,\n    denoising_end=high_noise_frac,\n    output_type=\"latent\",\n).images\nimage = refiner(\n    prompt=prompt,\n    image=image,\n    mask_image=mask_image,\n    num_inference_steps=num_inference_steps,\n    denoising_start=high_noise_frac,\n).images[0]\nmake_image_grid([init_image, mask_image, image.resize((512, 512))], rows=1, cols=3)\n```\n\nThis ensemble of expert denoisers method works well for all available schedulers!\n\n### Base to refiner model\n\nSDXL gets a boost in image quality by using the refiner model to add additional high-quality details to the fully-denoised image from the base model, in an image-to-image setting.\n\nLoad the base and refiner models:\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nbase = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\nrefiner = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\",\n    text_encoder_2=base.text_encoder_2,\n    vae=base.vae,\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n  
  variant=\"fp16\",\n).to(\"cuda\")\n```\n\n> [!WARNING]\n> You can use SDXL refiner with a different base model. For example, you can use the [Hunyuan-DiT](../api/pipelines/hunyuandit) or [PixArt-Sigma](../api/pipelines/pixart_sigma) pipelines to generate images with better prompt adherence. Once you have generated an image, you can pass it to the SDXL refiner model to enhance final generation quality.\n\nGenerate an image from the base model, and set the model output to **latent** space:\n\n```py\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\nimage = base(prompt=prompt, output_type=\"latent\").images[0]\n```\n\nPass the generated image to the refiner model:\n\n```py\nimage = refiner(prompt=prompt, image=image[None, :]).images[0]\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/init_image.png\" alt=\"generated image of an astronaut riding a green horse on Mars\" />\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">base model</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_image.png\" alt=\"higher quality generated image of an astronaut riding a green horse on Mars\" />\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">base model + refiner model</figcaption>\n  </div>\n</div>\n\nFor inpainting, load the base and the refiner model in the [`StableDiffusionXLInpaintPipeline`], remove the `denoising_end` and `denoising_start` parameters, and choose a smaller number of inference steps for the refiner.\n\n## Micro-conditioning\n\nSDXL training involves several additional conditioning techniques, which are referred to as *micro-conditioning*. These include original image size, target image size, and cropping parameters. 
The micro-conditionings can be used at inference time to create high-quality, centered images.\n\n> [!TIP]\n> You can use both micro-conditioning and negative micro-conditioning parameters thanks to classifier-free guidance. They are available in the [`StableDiffusionXLPipeline`], [`StableDiffusionXLImg2ImgPipeline`], [`StableDiffusionXLInpaintPipeline`], and [`StableDiffusionXLControlNetPipeline`].\n\n### Size conditioning\n\nThere are two types of size conditioning:\n\n- [`original_size`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.original_size) conditioning comes from upscaled images in the training batch (because it would be wasteful to discard the smaller images which make up almost 40% of the total training data). This way, SDXL learns that upscaling artifacts are not supposed to be present in high-resolution images. During inference, you can use `original_size` to indicate the original image resolution. Using the default value of `(1024, 1024)` produces higher-quality images that resemble the 1024x1024 images in the dataset. If you choose to use a lower resolution, such as `(256, 256)`, the model still generates 1024x1024 images, but they'll look like the low resolution images (simpler patterns, blurring) in the dataset.\n\n- [`target_size`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.target_size) conditioning comes from finetuning SDXL to support different image aspect ratios. During inference, if you use the default value of `(1024, 1024)`, you'll get an image that resembles the composition of square images in the dataset. 
We recommend using the same value for `target_size` and `original_size`, but feel free to experiment with other options!\n\n🤗 Diffusers also lets you specify negative conditions about an image's size to steer generation away from certain image resolutions:\n\n```py\nfrom diffusers import StableDiffusionXLPipeline\nimport torch\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\nimage = pipe(\n    prompt=prompt,\n    negative_original_size=(512, 512),\n    negative_target_size=(1024, 1024),\n).images[0]\n```\n\n<div class=\"flex flex-col justify-center\">\n  <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/negative_conditions.png\"/>\n  <figcaption class=\"text-center\">Images negatively conditioned on image resolutions of (128, 128), (256, 256), and (512, 512).</figcaption>\n</div>\n\n### Crop conditioning\n\nImages generated by previous Stable Diffusion models may sometimes appear to be cropped. This is because images are actually cropped during training so that all the images in a batch have the same size. By conditioning on crop coordinates, SDXL *learns* that no cropping - coordinates `(0, 0)` - usually correlates with centered subjects and complete faces (this is the default value in 🤗 Diffusers). 
You can experiment with different coordinates if you want to generate off-centered compositions!\n\n```py\nfrom diffusers import StableDiffusionXLPipeline\nimport torch\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\nimage = pipeline(prompt=prompt, crops_coords_top_left=(256, 0)).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-cropped.png\" alt=\"generated image of an astronaut in a jungle, slightly cropped\"/>\n</div>\n\nYou can also specify negative cropping coordinates to steer generation away from certain cropping parameters:\n\n```py\nfrom diffusers import StableDiffusionXLPipeline\nimport torch\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\nimage = pipe(\n    prompt=prompt,\n    negative_original_size=(512, 512),\n    negative_crops_coords_top_left=(0, 0),\n    negative_target_size=(1024, 1024),\n).images[0]\nimage\n```\n\n## Use a different prompt for each text-encoder\n\nSDXL uses two text-encoders, so it is possible to pass a different prompt to each text-encoder, which can [improve quality](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201). 
Pass your original prompt to `prompt` and the second prompt to `prompt_2` (use `negative_prompt` and `negative_prompt_2` if you're using negative prompts):\n\n```py\nfrom diffusers import StableDiffusionXLPipeline\nimport torch\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n).to(\"cuda\")\n\n# prompt is passed to OAI CLIP-ViT/L-14\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n# prompt_2 is passed to OpenCLIP-ViT/bigG-14\nprompt_2 = \"Van Gogh painting\"\nimage = pipeline(prompt=prompt, prompt_2=prompt_2).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-double-prompt.png\" alt=\"generated image of an astronaut in a jungle in the style of a van gogh painting\"/>\n</div>\n\nThe dual text-encoders also support textual inversion embeddings that need to be loaded separately as explained in the [SDXL textual inversion](textual_inversion_inference#stable-diffusion-xl) section.\n\n## Optimizations\n\nSDXL is a large model, and you may need to optimize memory to get it to run on your hardware. Here are some tips to save memory and speed up inference.\n\n1. Offload the model to the CPU with [`~StableDiffusionXLPipeline.enable_model_cpu_offload`] for out-of-memory errors:\n\n```diff\n- base.to(\"cuda\")\n- refiner.to(\"cuda\")\n+ base.enable_model_cpu_offload()\n+ refiner.enable_model_cpu_offload()\n```\n\n2. Use `torch.compile` for ~20% speed-up (you need `torch>=2.0`):\n\n```diff\n+ base.unet = torch.compile(base.unet, mode=\"reduce-overhead\", fullgraph=True)\n+ refiner.unet = torch.compile(refiner.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n3. 
Enable [xFormers](../optimization/xformers) to run SDXL if `torch<2.0`:\n\n```diff\n+ base.enable_xformers_memory_efficient_attention()\n+ refiner.enable_xformers_memory_efficient_attention()\n```\n\n## Other resources\n\nIf you're interested in experimenting with a minimal version of the [`UNet2DConditionModel`] used in SDXL, take a look at the [minSDXL](https://github.com/cloneofsimo/minSDXL) implementation which is written in PyTorch and directly compatible with 🤗 Diffusers.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/sdxl_turbo.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Diffusion XL Turbo\n\n[[open-in-colab]]\n\nSDXL Turbo is an adversarial time-distilled [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (SDXL) model capable\nof running inference in as little as 1 step.\n\nThis guide will show you how to use SDXL-Turbo for text-to-image and image-to-image.\n\nBefore you begin, make sure you have the following libraries installed:\n\n```py\n# uncomment to install the necessary libraries in Colab\n#!pip install -q diffusers transformers accelerate\n```\n\n## Load model checkpoints\n\nModel weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~StableDiffusionXLPipeline.from_pretrained`] method:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stabilityai/sdxl-turbo\", torch_dtype=torch.float16, variant=\"fp16\")\npipeline = pipeline.to(\"cuda\")\n```\n\nYou can also use the [`~StableDiffusionXLPipeline.from_single_file`] method to load a model checkpoint stored in a single file format (`.ckpt` or `.safetensors`) from the Hub or locally. 
For this loading method, you need to set `timestep_spacing=\"trailing\"` (feel free to experiment with the other scheduler config values to get better results):\n\n```py\nfrom diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler\nimport torch\n\npipeline = StableDiffusionXLPipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/sdxl-turbo/blob/main/sd_xl_turbo_1.0_fp16.safetensors\",\n    torch_dtype=torch.float16, variant=\"fp16\")\npipeline = pipeline.to(\"cuda\")\npipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing=\"trailing\")\n```\n\n## Text-to-image\n\nFor text-to-image, pass a text prompt. By default, SDXL Turbo generates a 512x512 image, and that resolution gives the best results. You can try setting the `height` and `width` parameters to 768x768 or 1024x1024, but you should expect quality degradations when doing so.\n\nMake sure to set `guidance_scale` to 0.0 to disable classifier-free guidance, as the model was trained without it. 
A single inference step is enough to generate high quality images.\nIncreasing the number of steps to 2, 3 or 4 should improve image quality.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline_text2image = AutoPipelineForText2Image.from_pretrained(\"stabilityai/sdxl-turbo\", torch_dtype=torch.float16, variant=\"fp16\")\npipeline_text2image = pipeline_text2image.to(\"cuda\")\n\nprompt = \"A cinematic shot of a baby racoon wearing an intricate italian priest robe.\"\n\nimage = pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sdxl-turbo-text2img.png\" alt=\"generated image of a racoon in a robe\"/>\n</div>\n\n## Image-to-image\n\nFor image-to-image generation, make sure that `num_inference_steps * strength` is greater than or equal to 1.\nThe image-to-image pipeline will run for `int(num_inference_steps * strength)` steps, e.g. 
`0.5 * 2.0 = 1` step in\nour example below.\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import load_image, make_image_grid\n\n# use from_pipe to avoid consuming additional memory when loading a checkpoint\npipeline_image2image = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to(\"cuda\")\n\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\")\ninit_image = init_image.resize((512, 512))\n\nprompt = \"cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k\"\n\nimage = pipeline_image2image(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sdxl-turbo-img2img.png\" alt=\"Image-to-image generation sample using SDXL Turbo\"/>\n</div>\n\n## Speed-up SDXL Turbo even more\n\n- Compile the UNet if you are using PyTorch version 2.0 or higher. The first inference run will be very slow, but subsequent ones will be much faster.\n\n```py\npipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n- When using the default VAE, keep it in `float32` to avoid costly `dtype` conversions before and after each generation. You only need to do this once before your first generation:\n\n```py\npipe.upcast_vae()\n```\n\nAs an alternative, you can also use a [16-bit VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) created by community member [`@madebyollin`](https://huggingface.co/madebyollin) that does not need to be upcast to `float32`.\n"
  },
  {
    "path": "docs/source/en/using-diffusers/shap-e.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Shap-E\n\n[[open-in-colab]]\n\nShap-E is a conditional model for generating 3D assets which could be used for video game development, interior design, and architecture. It is trained on a large dataset of 3D assets, and post-processed to render more views of each object and produce 16K instead of 4K point clouds. The Shap-E model is trained in two steps:\n\n1. an encoder accepts the point clouds and rendered views of a 3D asset and outputs the parameters of implicit functions that represent the asset\n2. a diffusion model is trained on the latents produced by the encoder to generate either neural radiance fields (NeRFs) or a textured 3D mesh, making it easier to render and use the 3D asset in downstream applications\n\nThis guide will show you how to use Shap-E to start generating your own 3D assets!\n\nBefore you begin, make sure you have the following libraries installed:\n\n```py\n# uncomment to install the necessary libraries in Colab\n#!pip install -q diffusers transformers accelerate trimesh\n```\n\n## Text-to-3D\n\nTo generate a gif of a 3D object, pass a text prompt to the [`ShapEPipeline`]. 
The pipeline generates a list of image frames which are used to create the 3D object.\n\n```py\nimport torch\nfrom diffusers import ShapEPipeline\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\npipe = ShapEPipeline.from_pretrained(\"openai/shap-e\", torch_dtype=torch.float16, variant=\"fp16\")\npipe = pipe.to(device)\n\nguidance_scale = 15.0\nprompt = [\"A firecracker\", \"A birthday cupcake\"]\n\nimages = pipe(\n    prompt,\n    guidance_scale=guidance_scale,\n    num_inference_steps=64,\n    frame_size=256,\n).images\n```\n\nNow use the [`~utils.export_to_gif`] function to turn the list of image frames into a gif of the 3D object.\n\n```py\nfrom diffusers.utils import export_to_gif\n\nexport_to_gif(images[0], \"firecracker_3d.gif\")\nexport_to_gif(images[1], \"cake_3d.gif\")\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/firecracker_out.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">prompt = \"A firecracker\"</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/cake_out.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">prompt = \"A birthday cupcake\"</figcaption>\n  </div>\n</div>\n\n## Image-to-3D\n\nTo generate a 3D object from another image, use the [`ShapEImg2ImgPipeline`]. You can use an existing image or generate an entirely new one. 
Let's use the [Kandinsky 2.1](../api/pipelines/kandinsky) model to generate a new image.\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nprior_pipeline = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n\nprompt = \"A cheeseburger, white background\"\n\nimage_embeds, negative_image_embeds = prior_pipeline(prompt, guidance_scale=1.0).to_tuple()\nimage = pipeline(\n    prompt,\n    image_embeds=image_embeds,\n    negative_image_embeds=negative_image_embeds,\n).images[0]\n\nimage.save(\"burger.png\")\n```\n\nPass the cheeseburger to the [`ShapEImg2ImgPipeline`] to generate a 3D representation of it.\n\n```py\nfrom PIL import Image\nfrom diffusers import ShapEImg2ImgPipeline\nfrom diffusers.utils import export_to_gif\n\npipe = ShapEImg2ImgPipeline.from_pretrained(\"openai/shap-e-img2img\", torch_dtype=torch.float16, variant=\"fp16\").to(\"cuda\")\n\nguidance_scale = 3.0\nimage = Image.open(\"burger.png\").resize((256, 256))\n\nimages = pipe(\n    image,\n    guidance_scale=guidance_scale,\n    num_inference_steps=64,\n    frame_size=256,\n).images\n\ngif_path = export_to_gif(images[0], \"burger_3d.gif\")\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_in.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">cheeseburger</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_out.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">3D cheeseburger</figcaption>\n  </div>\n</div>\n\n## Generate mesh\n\nShap-E is a flexible model that can also 
generate textured mesh outputs to be rendered for downstream applications. In this example, you'll convert the output into a `glb` file because the 🤗 Datasets library supports mesh visualization of `glb` files which can be rendered by the [Dataset viewer](https://huggingface.co/docs/hub/datasets-viewer#dataset-preview).\n\nYou can generate mesh outputs for both the [`ShapEPipeline`] and [`ShapEImg2ImgPipeline`] by specifying the `output_type` parameter as `\"mesh\"`:\n\n```py\nimport torch\nfrom diffusers import ShapEPipeline\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\npipe = ShapEPipeline.from_pretrained(\"openai/shap-e\", torch_dtype=torch.float16, variant=\"fp16\")\npipe = pipe.to(device)\n\nguidance_scale = 15.0\nprompt = \"A birthday cupcake\"\n\nimages = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, frame_size=256, output_type=\"mesh\").images\n```\n\nUse the [`~utils.export_to_ply`] function to save the mesh output as a `ply` file:\n\n> [!TIP]\n> You can optionally save the mesh output as an `obj` file with the [`~utils.export_to_obj`] function. 
The ability to save the mesh output in a variety of formats makes it more flexible for downstream usage!\n\n```py\nfrom diffusers.utils import export_to_ply\n\nply_path = export_to_ply(images[0], \"3d_cake.ply\")\nprint(f\"Saved to folder: {ply_path}\")\n```\n\nThen you can convert the `ply` file to a `glb` file with the trimesh library:\n\n```py\nimport trimesh\n\nmesh = trimesh.load(\"3d_cake.ply\")\nmesh_export = mesh.export(\"3d_cake.glb\", file_type=\"glb\")\n```\n\nBy default, the mesh output is focused from the bottom viewpoint but you can change the default viewpoint by applying a rotation transform:\n\n```py\nimport trimesh\nimport numpy as np\n\nmesh = trimesh.load(\"3d_cake.ply\")\nrot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])\nmesh = mesh.apply_transform(rot)\nmesh_export = mesh.export(\"3d_cake.glb\", file_type=\"glb\")\n```\n\nUpload the mesh file to your dataset repository to visualize it with the Dataset viewer!\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/3D-cake.gif\"/>\n</div>\n"
  },
  {
    "path": "docs/source/en/using-diffusers/svd.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Video Diffusion\n\n[[open-in-colab]]\n\n[Stable Video Diffusion (SVD)](https://huggingface.co/papers/2311.15127) is a powerful image-to-video generation model that can generate 2-4 second high resolution (576x1024) videos conditioned on an input image.\n\nThis guide will show you how to use SVD to generate short videos from images.\n\nBefore you begin, make sure you have the following libraries installed:\n\n```py\n# Colab에서 필요한 라이브러리를 설치하기 위해 주석을 제외하세요\n!pip install -q -U diffusers transformers accelerate\n```\n\nThe are two variants of this model, [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid) and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). 
The SVD checkpoint is trained to generate 14 frames and the SVD-XT checkpoint is further finetuned to generate 25 frames.\n\nYou'll use the SVD-XT checkpoint for this guide.\n\n```python\nimport torch\n\nfrom diffusers import StableVideoDiffusionPipeline\nfrom diffusers.utils import load_image, export_to_video\n\npipe = StableVideoDiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-video-diffusion-img2vid-xt\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipe.enable_model_cpu_offload()\n\n# Load the conditioning image\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png\")\nimage = image.resize((1024, 576))\n\ngenerator = torch.manual_seed(42)\nframes = pipe(image, decode_chunk_size=8, generator=generator).frames[0]\n\nexport_to_video(frames, \"generated.mp4\", fps=7)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">\"source image of a rocket\"</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/output_rocket.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">\"generated video from source image\"</figcaption>\n  </div>\n</div>\n\n## torch.compile\n\nYou can gain a 20-25% speedup at the expense of slightly increased memory by [compiling](../optimization/fp16#torchcompile) the UNet.\n\n```diff\n- pipe.enable_model_cpu_offload()\n+ pipe.to(\"cuda\")\n+ pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n## Reduce memory usage\n\nVideo generation is very memory intensive because you're essentially generating `num_frames` all at once, similar to text-to-image generation with a high batch size. 
To reduce the memory requirement, there are multiple options that trade inference speed for lower memory usage:\n\n- enable model offloading: each component of the pipeline is offloaded to the CPU once it's not needed anymore.\n- enable feed-forward chunking: the feed-forward layer runs in a loop instead of running a single feed-forward with a huge batch size.\n- reduce `decode_chunk_size`: the VAE decodes frames in chunks instead of decoding them all together. Setting `decode_chunk_size=1` decodes one frame at a time and uses the least amount of memory (we recommend adjusting this value based on your GPU memory) but the video might have some flickering.\n\n```diff\n- pipe.enable_model_cpu_offload()\n- frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]\n+ pipe.enable_model_cpu_offload()\n+ pipe.unet.enable_forward_chunking()\n+ frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]\n```\n\nUsing all these tricks together should lower the memory requirement to less than 8GB VRAM.\n\n## Micro-conditioning\n\nStable Video Diffusion also accepts micro-conditioning, in addition to the conditioning image, which allows more control over the generated video:\n\n- `fps`: the frames per second of the generated video.\n- `motion_bucket_id`: the motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id increases the motion of the generated video.\n- `noise_aug_strength`: the amount of noise added to the conditioning image. The higher the value, the less the video resembles the conditioning image. 
Increasing this value also increases the motion of the generated video.\n\nFor example, to generate a video with more motion, use the `motion_bucket_id` and `noise_aug_strength` micro-conditioning parameters:\n\n```python\nimport torch\n\nfrom diffusers import StableVideoDiffusionPipeline\nfrom diffusers.utils import load_image, export_to_video\n\npipe = StableVideoDiffusionPipeline.from_pretrained(\n  \"stabilityai/stable-video-diffusion-img2vid-xt\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipe.enable_model_cpu_offload()\n\n# Load the conditioning image\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png\")\nimage = image.resize((1024, 576))\n\ngenerator = torch.manual_seed(42)\nframes = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0]\nexport_to_video(frames, \"generated.mp4\", fps=7)\n```\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/output_rocket_with_conditions.gif)\n"
  },
  {
    "path": "docs/source/en/using-diffusers/t2i_adapter.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# T2I-Adapter\n\n[T2I-Adapter](https://huggingface.co/papers/2302.08453) is an adapter that enables controllable generation like [ControlNet](./controlnet). A T2I-Adapter works by learning a *mapping* between a control signal (for example, a depth map) and a pretrained model's internal knowledge. The adapter is plugged in to the base model to provide extra guidance based on the control signal during generation.\n\nLoad a T2I-Adapter conditioned on a specific control, such as canny edge, and pass it to the pipeline in [`~DiffusionPipeline.from_pretrained`].\n\n```py\nimport torch\nfrom diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, AutoencoderKL\n\nt2i_adapter = T2IAdapter.from_pretrained(\n    \"TencentARC/t2i-adapter-canny-sdxl-1.0\",\n    torch_dtype=torch.float16,\n)\n```\n\nGenerate a canny image with [opencv-python](https://github.com/opencv/opencv-python).\n\n```py\nimport cv2\nimport numpy as np\nfrom PIL import Image\nfrom diffusers.utils import load_image\n\noriginal_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\"\n)\n\nimage = np.array(original_image)\n\nlow_threshold = 100\nhigh_threshold = 200\n\nimage = cv2.Canny(image, low_threshold, high_threshold)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], 
axis=2)\ncanny_image = Image.fromarray(image)\n```\n\nPass the canny image to the pipeline to generate an image.\n\n```py\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\npipeline = StableDiffusionXLAdapterPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    adapter=t2i_adapter,\n    vae=vae,\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\nprompt = \"\"\"\nA photorealistic overhead image of a cat reclining sideways in a flamingo pool floatie holding a margarita. \nThe cat is floating leisurely in the pool and completely relaxed and happy.\n\"\"\"\n\npipeline(\n    prompt, \n    image=canny_image,\n    num_inference_steps=100, \n    guidance_scale=10,\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png\" width=\"300\" alt=\"Generated image (prompt only)\"/>\n    <figcaption style=\"text-align: center;\">original image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png\" width=\"300\" alt=\"Control image (Canny edges)\"/>\n    <figcaption style=\"text-align: center;\">canny image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-canny-cat-generated.png\" width=\"300\" alt=\"Generated image (T2I-Adapter + prompt)\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>\n\n## MultiAdapter\n\nYou can compose multiple controls, such as a canny image and a depth map, with the [`MultiAdapter`] class.\n\nThe example below composes a canny image and depth map.\n\nLoad the control images and T2I-Adapters as a list.\n\n```py\nimport torch\nfrom 
diffusers.utils import load_image\nfrom diffusers import StableDiffusionXLAdapterPipeline, AutoencoderKL, MultiAdapter, T2IAdapter\n\ncanny_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png\"\n)\ndepth_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_image.png\"\n)\ncontrols = [canny_image, depth_image]\nprompt = [\"\"\"\na relaxed rabbit sitting on a striped towel next to a pool with a tropical drink nearby, \nbright sunny day, vacation scene, 35mm photograph, film, professional, 4k, highly detailed\n\"\"\"]\n\nadapters = MultiAdapter(\n    [\n        T2IAdapter.from_pretrained(\"TencentARC/t2i-adapter-canny-sdxl-1.0\", torch_dtype=torch.float16),\n        T2IAdapter.from_pretrained(\"TencentARC/t2i-adapter-depth-midas-sdxl-1.0\", torch_dtype=torch.float16),\n    ]\n)\n```\n\nPass the adapters, prompt, and control images to [`StableDiffusionXLAdapterPipeline`]. 
Use the `adapter_conditioning_scale` parameter to determine how much weight to assign to each control.\n\n```py\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\npipeline = StableDiffusionXLAdapterPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    vae=vae,\n    adapter=adapters,\n).to(\"cuda\")\n\npipeline(\n    prompt,\n    image=controls,\n    height=1024,\n    width=1024,\n    adapter_conditioning_scale=[0.7, 0.7]\n).images[0]\n```\n\n<div style=\"display: flex; gap: 10px; justify-content: space-around; align-items: flex-end;\">\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/canny-cat.png\" width=\"300\" alt=\"Control image (Canny edges)\"/>\n    <figcaption style=\"text-align: center;\">canny image</figcaption>\n  </figure>\n  <figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_depth_image.png\" width=\"300\" alt=\"Control image (depth map)\"/>\n    <figcaption style=\"text-align: center;\">depth map</figcaption>\n  </figure>\n  <figure> \n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-multi-rabbit.png\" width=\"300\" alt=\"Generated image (T2I-Adapters + prompt)\"/>\n    <figcaption style=\"text-align: center;\">generated image</figcaption>\n  </figure>\n</div>"
  },
  {
    "path": "docs/source/en/using-diffusers/text-img2vid.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Video generation\n\nVideo generation models extend image generation (can be considered a 1-frame video) to also process data related to space and time. Making sure all this data - text, space, time - remain consistent and aligned from frame-to-frame is a big challenge in generating long and high-resolution videos.\n\nModern video models tackle this challenge with the diffusion transformer (DiT) architecture. 
This reduces computational costs and allows more efficient scaling to larger and higher-quality image and video data.\n\nCheck out what some of these video models are capable of below.\n\n<hfoptions id=\"popular models\">\n<hfoption id=\"Wan2.1\">\n\n```py\n# pip install ftfy\nimport torch\nimport numpy as np\nfrom diffusers import AutoModel, WanPipeline\nfrom diffusers.hooks.group_offloading import apply_group_offloading\nfrom diffusers.utils import export_to_video, load_image\nfrom transformers import UMT5EncoderModel\n\ntext_encoder = UMT5EncoderModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"text_encoder\", torch_dtype=torch.bfloat16)\nvae = AutoModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"vae\", torch_dtype=torch.float32)\ntransformer = AutoModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n\n# group-offloading\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\napply_group_offloading(text_encoder,\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"block_level\",\n    num_blocks_per_group=4\n)\ntransformer.enable_group_offload(\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True\n)\n\npipeline = WanPipeline.from_pretrained(\n    \"Wan-AI/Wan2.1-T2V-14B-Diffusers\",\n    vae=vae,\n    transformer=transformer,\n    text_encoder=text_encoder,\n    torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\n\nprompt = \"\"\"\nThe camera rushes from far to near in a low-angle shot, \nrevealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in \nfor a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. \nBirch trees and a light blue sky frame the scene, with ferns in the foreground. 
Side lighting casts dynamic \nshadows and warm highlights. Medium composition, front view, low angle, with depth of field.\n\"\"\"\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, \nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, \nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=81,\n    guidance_scale=5.0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\n```py\nimport torch\nfrom diffusers import AutoModel, HunyuanVideoPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom diffusers.utils import export_to_video\n\n# quantize weights to int4 with bitsandbytes\npipeline_quant_config = PipelineQuantizationConfig(\n  quant_backend=\"bitsandbytes_4bit\",\n  quant_kwargs={\n    \"load_in_4bit\": True,\n    \"bnb_4bit_quant_type\": \"nf4\",\n    \"bnb_4bit_compute_dtype\": torch.bfloat16\n    },\n  components_to_quantize=\"transformer\"\n)\n\npipeline = HunyuanVideoPipeline.from_pretrained(\n    \"hunyuanvideo-community/HunyuanVideo\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n)\n\n# model-offloading and tiling\npipeline.enable_model_cpu_offload()\npipeline.vae.enable_tiling()\n\nprompt = \"A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys.\"\nvideo = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]\nexport_to_video(video, \"output.mp4\", fps=15)\n```\n\n</hfoption>\n<hfoption id=\"LTX-Video\">\n\n```py\nimport torch\nfrom diffusers import LTXPipeline, AutoModel\nfrom diffusers.hooks import apply_group_offloading\nfrom 
diffusers.utils import export_to_video\n\n# fp8 layerwise weight-casting\ntransformer = AutoModel.from_pretrained(\n    \"Lightricks/LTX-Video\",\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16\n)\ntransformer.enable_layerwise_casting(\n    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16\n)\n\npipeline = LTXPipeline.from_pretrained(\"Lightricks/LTX-Video\", transformer=transformer, torch_dtype=torch.bfloat16)\n\n# group-offloading\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\npipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=\"leaf_level\", use_stream=True)\napply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=\"block_level\", num_blocks_per_group=2)\napply_group_offloading(pipeline.vae, onload_device=onload_device, offload_type=\"leaf_level\")\n\nprompt = \"\"\"\nA woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. 
The scene appears to be real-life footage\n\"\"\"\nnegative_prompt = \"worst quality, inconsistent motion, blurry, jittery, distorted\"\n\nvideo = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=768,\n    height=512,\n    num_frames=161,\n    decode_timestep=0.03,\n    decode_noise_scale=0.025,\n    num_inference_steps=50,\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=24)\n```\n\n</hfoption>\n</hfoptions>\n\nThis guide will cover video generation basics such as which parameters to configure and how to reduce memory usage.\n\n> [!TIP]\n> If you're interested in learning more about how to use a specific model, please refer to its pipeline API model card.\n\n## Pipeline parameters\n\nThere are several parameters to configure in the pipeline that'll affect video generation quality or speed. Experimenting with different parameter values is important for discovering the appropriate quality and speed tradeoff.\n\n### num_frames\n\nA frame is a still image that is played in a sequence of other frames to create motion or a video. Control the number of frames generated with `num_frames`. Increasing `num_frames` increases perceived motion smoothness and visual coherence, making it especially important for videos with dynamic content. A higher `num_frames` value also increases video duration.\n\nSome video models require more specific `num_frames` values for inference. For example, [`HunyuanVideoPipeline`] recommends a `num_frames` value of the form `(4 * k) + 1`, where `k` is an integer. Always check a pipeline's API model card to see if there is a recommended value.\n\n```py\nimport torch\nfrom diffusers import LTXPipeline\nfrom diffusers.utils import export_to_video\n\npipeline = LTXPipeline.from_pretrained(\n    \"Lightricks/LTX-Video\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nprompt = \"\"\"\nA woman with long brown hair and light skin smiles at another woman with long blonde hair. 
The woman \nwith brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The \ncamera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and \nnatural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be \nreal-life footage\n\"\"\"\n# negative_prompt must be defined before it is passed to the pipeline\nnegative_prompt = \"worst quality, inconsistent motion, blurry, jittery, distorted\"\n\nvideo = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=768,\n    height=512,\n    num_frames=161,\n    decode_timestep=0.03,\n    decode_noise_scale=0.025,\n    num_inference_steps=50,\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=24)\n```\n\n### guidance_scale\n\nGuidance scale or \"cfg\" controls how closely the generated frames adhere to the input conditioning (text, image or both). Increasing `guidance_scale` generates frames that resemble the input conditions more closely and include finer details, but it risks introducing artifacts and reducing output diversity. Lower `guidance_scale` values encourage looser prompt adherence and increased output variety, but details may not be as fine. If the value is too low, the model may ignore your prompt entirely and generate random noise.\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline\nfrom diffusers.utils import export_to_video\n\npipeline = CogVideoXPipeline.from_pretrained(\n  \"THUDM/CogVideoX-2b\",\n  torch_dtype=torch.float16\n).to(\"cuda\")\n\nprompt = \"\"\"\nA detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over\na plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, \nwith tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an \noceanic expanse. Surrounding the ship are various other toys and children's items, hinting at \na playful environment. 
The scene captures the innocence and imagination of childhood, \nwith the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.\n\"\"\"\n\nvideo = pipeline(\n  prompt=prompt,\n  guidance_scale=6,\n  num_inference_steps=50\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\n### negative_prompt\n\nA negative prompt is useful for excluding things you don't want to see in the generated video. It is commonly used to refine the quality and alignment of the generated video by pushing the model away from undesirable elements like \"blurry, distorted, ugly\". This can create cleaner and more focused videos.\n\n```py\n# pip install ftfy\nimport torch\nfrom diffusers import AutoencoderKLWan, WanPipeline\nfrom diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler\nfrom diffusers.utils import export_to_video\n\nvae = AutoencoderKLWan.from_pretrained(\n  \"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"vae\", torch_dtype=torch.float32\n)\npipeline = WanPipeline.from_pretrained(\n  \"Wan-AI/Wan2.1-T2V-14B-Diffusers\", vae=vae, torch_dtype=torch.bfloat16\n)\npipeline.scheduler = UniPCMultistepScheduler.from_config(\n  pipeline.scheduler.config, flow_shift=5.0\n)\npipeline.to(\"cuda\")\n\npipeline.load_lora_weights(\"benjamin-paine/steamboat-willie-14b\", adapter_name=\"steamboat-willie\")\npipeline.set_adapters(\"steamboat-willie\")\n\npipeline.enable_model_cpu_offload()\n\n# use \"steamboat willie style\" to trigger the LoRA\nprompt = \"\"\"\nsteamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, \nrevealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in \nfor a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. \nBirch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts \ndynamic shadows and warm highlights. 
Medium composition, front view, low angle, with depth of field.\n\"\"\"\n# push the model away from undesirable elements\nnegative_prompt = \"blurry, distorted, ugly\"\n\noutput = pipeline(\n  prompt=prompt,\n  negative_prompt=negative_prompt,\n  num_frames=81,\n  guidance_scale=5.0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\n## Reduce memory usage\n\nRecent video models like [`HunyuanVideoPipeline`] and [`WanPipeline`], which have 10B+ parameters, require a lot of memory, often more than is available on consumer hardware. Diffusers offers several techniques for reducing the memory requirements of these large models.\n\n> [!TIP]\n> Refer to the [Reduce memory usage](../optimization/memory) guide for more details about other memory saving techniques.\n\nOne of these techniques is [group-offloading](../optimization/memory#group-offloading), which offloads groups of internal model layers (such as `torch.nn.Sequential`) to the CPU when they aren't being used. These layers are only loaded when they're needed for computation to avoid storing **all** the model components on the GPU. For a 14B parameter model like [`WanPipeline`], group-offloading can lower the required memory to ~13GB of VRAM.\n\n```py\n# pip install ftfy\nimport torch\nfrom diffusers import AutoModel, WanPipeline\nfrom diffusers.hooks.group_offloading import apply_group_offloading\nfrom diffusers.utils import export_to_video\nfrom transformers import UMT5EncoderModel\n\ntext_encoder = UMT5EncoderModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"text_encoder\", torch_dtype=torch.bfloat16)\nvae = AutoModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"vae\", torch_dtype=torch.float32)\ntransformer = AutoModel.from_pretrained(\"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n\n# group-offloading\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\napply_group_offloading(text_encoder,\n    onload_device=onload_device,\n    offload_device=offload_device,\n    
offload_type=\"block_level\",\n    num_blocks_per_group=4\n)\ntransformer.enable_group_offload(\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True\n)\n\npipeline = WanPipeline.from_pretrained(\n    \"Wan-AI/Wan2.1-T2V-14B-Diffusers\",\n    vae=vae,\n    transformer=transformer,\n    text_encoder=text_encoder,\n    torch_dtype=torch.bfloat16\n)\npipeline.to(\"cuda\")\n\nprompt = \"\"\"\nThe camera rushes from far to near in a low-angle shot, \nrevealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in \nfor a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. \nBirch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic \nshadows and warm highlights. Medium composition, front view, low angle, with depth of field.\n\"\"\"\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, \nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, \nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=81,\n    guidance_scale=5.0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\nAnother option for reducing memory is to consider quantizing a model, which stores the model weights in a lower precision data type. However, quantization may impact video quality depending on the specific video model. 
Refer to the quantization [Overview](../quantization/overview) to learn more about the different supported quantization backends.\n\nThe example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize a model.\n\n```py\n# pip install ftfy\n\nimport torch\nfrom diffusers import AutoModel, WanPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler\nfrom diffusers.utils import export_to_video\n\n# quantize transformer and text encoder weights with bitsandbytes\npipeline_quant_config = PipelineQuantizationConfig(\n  quant_backend=\"bitsandbytes_4bit\",\n  quant_kwargs={\"load_in_4bit\": True},\n  components_to_quantize=[\"transformer\", \"text_encoder\"]\n)\n\nvae = AutoModel.from_pretrained(\n  \"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"vae\", torch_dtype=torch.float32\n)\npipeline = WanPipeline.from_pretrained(\n  \"Wan-AI/Wan2.1-T2V-14B-Diffusers\", vae=vae, quantization_config=pipeline_quant_config, torch_dtype=torch.bfloat16\n)\npipeline.scheduler = UniPCMultistepScheduler.from_config(\n  pipeline.scheduler.config, flow_shift=5.0\n)\npipeline.to(\"cuda\")\n\npipeline.load_lora_weights(\"benjamin-paine/steamboat-willie-14b\", adapter_name=\"steamboat-willie\")\npipeline.set_adapters(\"steamboat-willie\")\n\npipeline.enable_model_cpu_offload()\n\n# use \"steamboat willie style\" to trigger the LoRA\nprompt = \"\"\"\nsteamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, \nrevealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in \nfor a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. \nBirch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts \ndynamic shadows and warm highlights. 
Medium composition, front view, low angle, with depth of field.\n\"\"\"\n\noutput = pipeline(\n  prompt=prompt,\n  num_frames=81,\n  guidance_scale=5.0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\n## Inference speed\n\n[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial_.html) can speed up inference by using optimized kernels. The first run takes longer because the model is compiled, but subsequent runs are much faster. It is best to compile the pipeline once, and then use the pipeline multiple times without changing anything. A change, such as a different image size, triggers recompilation.\n\nThe example below compiles the transformer in the pipeline and uses the `\"max-autotune\"` mode to maximize performance.\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline\nfrom diffusers.utils import export_to_video\n\npipeline = CogVideoXPipeline.from_pretrained(\n  \"THUDM/CogVideoX-2b\",\n  torch_dtype=torch.float16\n).to(\"cuda\")\n\n# torch.compile\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.transformer = torch.compile(\n    pipeline.transformer, mode=\"max-autotune\", fullgraph=True\n)\n\nprompt = \"\"\"\nA detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. \nThe ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. \nSurrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, \nwith the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.\n\"\"\"\n\nvideo = pipeline(\n  prompt=prompt,\n  guidance_scale=6,\n  num_inference_steps=50\n).frames[0]\nexport_to_video(video, \"output.mp4\", fps=8)\n```"
  },
  {
    "path": "docs/source/en/using-diffusers/textual_inversion_inference.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Textual Inversion\n\n[Textual Inversion](https://huggingface.co/papers/2208.01618) is a method for generating personalized images of a concept. It works by fine-tuning a model's word embeddings on 3-5 images of the concept (for example, pixel art) that is associated with a unique token (`<sks>`). This allows you to use the `<sks>` token in your prompt to trigger the model to generate pixel art images.\n\nTextual Inversion weights are very lightweight, typically only a few KB, because they're only word embeddings. 
However, this also means the word embeddings need to be loaded after loading a model with [`~DiffusionPipeline.from_pretrained`].\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\nLoad the word embeddings with [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] and include the unique token in the prompt to activate its generation.\n\n```py\npipeline.load_textual_inversion(\"sd-concepts-library/gta5-artwork\")\nprompt = \"A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, <gta5-artwork> style\"\npipeline(prompt).images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_txt_embed.png\" />\n</div>\n\nTextual Inversion can also be trained to learn *negative embeddings* to steer generation away from unwanted characteristics such as \"blurry\" or \"ugly\". It is useful for improving image quality.\n\nEasyNegative is a widely used negative embedding that contains multiple learned negative concepts. Load the negative embeddings and specify the file name and token associated with the negative embeddings. 
Pass the token to `negative_prompt` in your pipeline to activate it.\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.load_textual_inversion(\n    \"EvilEngine/easynegative\",\n    weight_name=\"easynegative.safetensors\",\n    token=\"easynegative\"\n)\nprompt = \"A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration\"\nnegative_prompt = \"easynegative\"\npipeline(prompt, negative_prompt).images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png\" />\n</div>"
  },
  {
    "path": "docs/source/en/using-diffusers/unconditional_image_generation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Unconditional image generation\n\n[[open-in-colab]]\n\nUnconditional image generation generates images that look like a random sample from the training data the model was trained on because the denoising process is not guided by any additional context like text or image.\n\nTo get started, use the [`DiffusionPipeline`] to load the [anton-l/ddpm-butterflies-128](https://huggingface.co/anton-l/ddpm-butterflies-128) checkpoint to generate images of butterflies. The [`DiffusionPipeline`] downloads and caches all the model components required to generate an image.\n\n```py\nfrom diffusers import DiffusionPipeline\n\ngenerator = DiffusionPipeline.from_pretrained(\"anton-l/ddpm-butterflies-128\").to(\"cuda\")\nimage = generator().images[0]\nimage\n```\n\n> [!TIP]\n> Want to generate images of something else? Take a look at the training [guide](../training/unconditional_training) to learn how to train a model to generate your own images.\n\nThe output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved:\n\n```py\nimage.save(\"generated_image.png\")\n```\n\nYou can also try experimenting with the `num_inference_steps` parameter, which controls the number of denoising steps. More denoising steps typically produce higher quality images, but it'll take longer to generate. 
Feel free to play around with this parameter to see how it affects the image quality.\n\n```py\nimage = generator(num_inference_steps=100).images[0]\nimage\n```\n\nTry out the Space below to generate an image of a butterfly!\n\n<iframe\n\tsrc=\"https://stevhliu-unconditional-image-generation.hf.space\"\n\tframeborder=\"0\"\n\twidth=\"850\"\n\theight=\"500\"\n></iframe>\n"
  },
  {
    "path": "docs/source/en/using-diffusers/weighted_prompts.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Prompting\n\nA prompt describes what a model should generate. Good prompts are detailed, specific, and structured, and they generate better images and videos.\n\nThis guide shows you how to write effective prompts and introduces techniques that make them stronger.\n\n## Writing good prompts\n\nEvery effective prompt needs three core elements.\n\n1. <span class=\"underline decoration-sky-500 decoration-2 underline-offset-4\">Subject</span> - what you want to generate. Start your prompt here.\n2. <span class=\"underline decoration-pink-500 decoration-2 underline-offset-4\">Style</span> - the medium or aesthetic. How should it look?\n3. <span class=\"underline decoration-green-500 decoration-2 underline-offset-4\">Context</span> - details about actions, setting, and mood.\n\nUse these elements as a structured narrative, not a keyword list. Modern models understand natural language and do more than match keywords. Start simple, then add details.\n\nContext is especially important for creating better prompts. 
Try adding lighting, artistic details, and mood.\n\n<div class=\"flex gap-4\">\n  <div class=\"flex-1 text-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ok-prompt.png\" class=\"w-full h-auto object-cover rounded-lg\">\n    <figcaption class=\"mt-2 text-sm text-gray-500\">A <span class=\"underline decoration-sky-500 decoration-2 underline-offset-1\">cute cat</span> <span class=\"underline decoration-green-500 decoration-2 underline-offset-1\">lounges on a leaf in a pool during a peaceful summer afternoon</span>, in <span class=\"underline decoration-pink-500 decoration-2 underline-offset-1\">lofi art style, illustration</span>.</figcaption>\n  </div>\n  <div class=\"flex-1 text-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/better-prompt.png\" class=\"w-full h-auto object-cover rounded-lg\"/>\n    <figcaption class=\"mt-2 text-sm text-gray-500\">A cute cat lounges on a floating leaf in a sparkling pool during a peaceful summer afternoon. Clear reflections ripple across the water, with sunlight casting soft, smooth highlights. The illustration is detailed and polished, with elegant lines and harmonious colors, evoking a relaxing, serene, and whimsical lofi mood, anime-inspired and visually comforting.</figcaption>\n  </div>\n</div>\n\nBe specific and add context. Use photography terms like lens type, focal length, camera angles, and depth of field.\n\n> [!TIP]\n> Try a [prompt enhancer](https://huggingface.co/models?sort=downloads&search=prompt+enhancer) to help improve your prompt structure.\n\n## Prompt weighting\n\nPrompt weighting makes some words stronger and others weaker. It scales attention scores so you control how much influence each concept has.\n\nDiffusers handles this through `prompt_embeds` and `pooled_prompt_embeds` arguments which take scaled text embedding vectors. 
Use the [sd_embed](https://github.com/xhinker/sd_embed) library to generate these embeddings. It also supports longer prompts.\n\n> [!NOTE]\n> The sd_embed library only supports Stable Diffusion, Stable Diffusion XL, Stable Diffusion 3, Stable Cascade, and Flux. Prompt weighting doesn't necessarily help for newer models like Flux, which already have very good prompt adherence.\n\n```py\n!uv pip install git+https://github.com/xhinker/sd_embed.git@main\n```\n\nFormat weighted text with numerical multipliers or parentheses. More parentheses mean stronger weighting.\n\n| format | multiplier |\n|---|---|\n| `(cat)` | increase by 1.1x |\n| `((cat))` | increase by 1.21x |\n| `(cat:1.5)` | increase by 1.5x |\n| `(cat:0.5)` | decrease by 2x |\n\nCreate a weighted prompt and pass it to [get_weighted_text_embeddings_sdxl](https://github.com/xhinker/sd_embed/blob/4a47f71150a22942fa606fb741a1c971d95ba56f/src/sd_embed/embedding_funcs.py#L405) to generate embeddings.\n\n> [!TIP]\n> You could also pass negative prompts to `negative_prompt_embeds` and `negative_pooled_prompt_embeds`.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"Lykon/dreamshaper-xl-1-0\", torch_dtype=torch.bfloat16, device_map=\"cuda\"\n)\n\nprompt = \"\"\"\nA (cute cat:1.4) lounges on a (floating leaf:1.2) in a (sparkling pool:1.1) during a peaceful summer afternoon.\nGentle ripples reflect pastel skies, while (sunlight:1.1) casts soft highlights. 
The illustration is smooth and polished\nwith elegant, sketchy lines and subtle gradients, evoking a ((whimsical, nostalgic, dreamy lofi atmosphere:2.0)), \n(anime-inspired:1.6), calming, comforting, and visually serene.\n\"\"\"\n\nprompt_embeds, _, pooled_prompt_embeds, *_ = get_weighted_text_embeddings_sdxl(pipeline, prompt=prompt)\n```\n\nPass the embeddings to `prompt_embeds` and `pooled_prompt_embeds` to generate your image.\n\n```py\nimage = pipeline(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds).images[0]\n```\n\n<div class=\"flex justify-center\">\n  <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/prompt-embed-sdxl.png\"/>\n</div>\n\nPrompt weighting works with [Textual inversion](./textual_inversion_inference) and [DreamBooth](./dreambooth) adapters too."
  },
  {
    "path": "docs/source/en/using-diffusers/write_own_pipeline.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Understanding pipelines, models and schedulers\n\n[[open-in-colab]]\n\n🧨 Diffusers is designed to be a user-friendly and flexible toolbox for building diffusion systems tailored to your use-case. At the core of the toolbox are models and schedulers. While the [`DiffusionPipeline`] bundles these components together for convenience, you can also unbundle the pipeline and use the models and schedulers separately to create new diffusion systems.\n\nIn this tutorial, you'll learn how to use models and schedulers to assemble a diffusion system for inference, starting with a basic pipeline and then progressing to the Stable Diffusion pipeline.\n\n## Deconstruct a basic pipeline\n\nA pipeline is a quick and easy way to run a model for inference, requiring no more than four lines of code to generate an image:\n\n```py\n>>> from diffusers import DDPMPipeline\n\n>>> ddpm = DDPMPipeline.from_pretrained(\"google/ddpm-cat-256\", use_safetensors=True).to(\"cuda\")\n>>> image = ddpm(num_inference_steps=25).images[0]\n>>> image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ddpm-cat.png\" alt=\"Image of cat created from DDPMPipeline\"/>\n</div>\n\nThat was super easy, but how did the pipeline do that? 
Let's break down the pipeline and take a look at what's happening under the hood.\n\nIn the example above, the pipeline contains a [`UNet2DModel`] model and a [`DDPMScheduler`]. The pipeline denoises an image by taking random noise the size of the desired output and passing it through the model several times. At each timestep, the model predicts the *noise residual* and the scheduler uses it to predict a less noisy image. The pipeline repeats this process until it reaches the end of the specified number of inference steps.\n\nTo recreate the pipeline with the model and scheduler separately, let's write our own denoising process.\n\n1. Load the model and scheduler:\n\n```py\n>>> from diffusers import DDPMScheduler, UNet2DModel\n\n>>> scheduler = DDPMScheduler.from_pretrained(\"google/ddpm-cat-256\")\n>>> model = UNet2DModel.from_pretrained(\"google/ddpm-cat-256\", use_safetensors=True).to(\"cuda\")\n```\n\n2. Set the number of timesteps to run the denoising process for:\n\n```py\n>>> scheduler.set_timesteps(50)\n```\n\n3. Setting the scheduler timesteps creates a tensor with evenly spaced elements in it, 50 in this example. Each element corresponds to a timestep at which the model denoises an image. When you create the denoising loop later, you'll iterate over this tensor to denoise an image:\n\n```py\n>>> scheduler.timesteps\ntensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720,\n    700, 680, 660, 640, 620, 600, 580, 560, 540, 520, 500, 480, 460, 440,\n    420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160,\n    140, 120, 100,  80,  60,  40,  20,   0])\n```\n\n4. Create some random noise with the same shape as the desired output:\n\n```py\n>>> import torch\n\n>>> sample_size = model.config.sample_size\n>>> noise = torch.randn((1, 3, sample_size, sample_size), device=\"cuda\")\n```\n\n5. Now write a loop to iterate over the timesteps. 
At each timestep, the model does a [`UNet2DModel.forward`] pass and returns the noisy residual. The scheduler's [`~DDPMScheduler.step`] method takes the noisy residual, timestep, and input and it predicts the image at the previous timestep. This output becomes the next input to the model in the denoising loop, and it'll repeat until it reaches the end of the `timesteps` array.\n\n```py\n>>> input = noise\n\n>>> for t in scheduler.timesteps:\n...     with torch.no_grad():\n...         noisy_residual = model(input, t).sample\n...     previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample\n...     input = previous_noisy_sample\n```\n\nThis is the entire denoising process, and you can use this same pattern to write any diffusion system.\n\n6. The last step is to convert the denoised output into an image:\n\n```py\n>>> from PIL import Image\n>>> import numpy as np\n\n>>> image = (input / 2 + 0.5).clamp(0, 1).squeeze()\n>>> image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()\n>>> image = Image.fromarray(image)\n>>> image\n```\n\nIn the next section, you'll put your skills to the test and break down the more complex Stable Diffusion pipeline. The steps are more or less the same. You'll initialize the necessary components, and set the number of timesteps to create a `timesteps` array. The denoising loop iterates over the `timesteps` array, and at each timestep, the model outputs a noisy residual and the scheduler uses it to predict a less noisy image at the previous timestep. This process is repeated until you reach the end of the `timesteps` array.\n\nLet's try it out!\n\n## Deconstruct the Stable Diffusion pipeline\n\nStable Diffusion is a text-to-image *latent diffusion* model. 
It is called a latent diffusion model because it works with a lower-dimensional representation of the image instead of the actual pixel space, which makes it more memory efficient. The encoder compresses the image into a smaller representation, and a decoder converts the compressed representation back into an image. For text-to-image models, you'll need a tokenizer and an encoder to generate text embeddings. From the previous example, you already know you need a UNet model and a scheduler.\n\nAs you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.\n\n> [!TIP]\n> 💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models work.\n\nNow that you know what you need for the Stable Diffusion pipeline, load all these components with the [`~ModelMixin.from_pretrained`] method. You can find them in the pretrained [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) checkpoint, and each component is stored in a separate subfolder:\n\n```py\n>>> from PIL import Image\n>>> import torch\n>>> from transformers import CLIPTextModel, CLIPTokenizer\n>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler\n\n>>> vae = AutoencoderKL.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"vae\", use_safetensors=True)\n>>> tokenizer = CLIPTokenizer.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"tokenizer\")\n>>> text_encoder = CLIPTextModel.from_pretrained(\n...     \"CompVis/stable-diffusion-v1-4\", subfolder=\"text_encoder\", use_safetensors=True\n... )\n>>> unet = UNet2DConditionModel.from_pretrained(\n...     \"CompVis/stable-diffusion-v1-4\", subfolder=\"unet\", use_safetensors=True\n... 
)\n```\n\nInstead of the default [`PNDMScheduler`], exchange it for the [`UniPCMultistepScheduler`] to see how easy it is to plug a different scheduler in:\n\n```py\n>>> from diffusers import UniPCMultistepScheduler\n\n>>> scheduler = UniPCMultistepScheduler.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"scheduler\")\n```\n\nTo speed up inference, move the models to a GPU since, unlike the scheduler, they have trainable weights:\n\n```py\n>>> torch_device = \"cuda\"\n>>> vae.to(torch_device)\n>>> text_encoder.to(torch_device)\n>>> unet.to(torch_device)\n```\n\n### Create text embeddings\n\nThe next step is to tokenize the text to generate embeddings. The text is used to condition the UNet model and steer the diffusion process towards something that resembles the input prompt.\n\n> [!TIP]\n> 💡 The `guidance_scale` parameter determines how much weight should be given to the prompt when generating an image.\n\nFeel free to choose any prompt you like if you want to generate something else!\n\n```py\n>>> prompt = [\"a photograph of an astronaut riding a horse\"]\n>>> height = 512  # default height of Stable Diffusion\n>>> width = 512  # default width of Stable Diffusion\n>>> num_inference_steps = 25  # Number of denoising steps\n>>> guidance_scale = 7.5  # Scale for classifier-free guidance\n>>> generator = torch.manual_seed(0)  # Seed generator to create the initial latent noise\n>>> batch_size = len(prompt)\n```\n\nTokenize the text and generate the embeddings from the prompt:\n\n```py\n>>> text_input = tokenizer(\n...     prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\"\n... )\n\n>>> with torch.no_grad():\n...     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]\n```\n\nYou'll also need to generate the *unconditional text embeddings* which are the embeddings for the padding token. 
These need to have the same shape (`batch_size` and `seq_length`) as the conditional `text_embeddings`:\n\n```py\n>>> max_length = text_input.input_ids.shape[-1]\n>>> uncond_input = tokenizer([\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\")\n>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]\n```\n\nLet's concatenate the conditional and unconditional embeddings into a batch to avoid doing two forward passes:\n\n```py\n>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n```\n\n### Create random noise\n\nNext, generate some initial random noise as a starting point for the diffusion process. This is the latent representation of the image, and it'll be gradually denoised. At this point, the latent image is smaller than the final image, but that's okay because the model will transform it into the final 512x512 image dimensions later.\n\n> [!TIP]\n> 💡 The height and width are divided by 8 because the `vae` model has 3 down-sampling layers. You can check by running the following:\n>\n> ```py\n> 2 ** (len(vae.config.block_out_channels) - 1) == 8\n> ```\n\n```py\n>>> latents = torch.randn(\n...     (batch_size, unet.config.in_channels, height // 8, width // 8),\n...     generator=generator,\n...     device=torch_device,\n... )\n```\n\n### Denoise the image\n\nStart by scaling the input with the initial noise distribution, *sigma*, the noise scale value, which is required for improved schedulers like [`UniPCMultistepScheduler`]:\n\n```py\n>>> latents = latents * scheduler.init_noise_sigma\n```\n\nThe last step is to create the denoising loop that'll progressively transform the pure noise in `latents` to an image described by your prompt. Remember, the denoising loop needs to do three things:\n\n1. Set the scheduler's timesteps to use during denoising.\n2. Iterate over the timesteps.\n3. 
At each timestep, call the UNet model to predict the noise residual and pass it to the scheduler to compute the previous noisy sample.\n\n```py\n>>> from tqdm.auto import tqdm\n\n>>> scheduler.set_timesteps(num_inference_steps)\n\n>>> for t in tqdm(scheduler.timesteps):\n...     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.\n...     latent_model_input = torch.cat([latents] * 2)\n\n...     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)\n\n...     # predict the noise residual\n...     with torch.no_grad():\n...         noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n...     # perform guidance\n...     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n...     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n...     # compute the previous noisy sample x_t -> x_t-1\n...     latents = scheduler.step(noise_pred, t, latents).prev_sample\n```\n\n### Decode the image\n\nThe final step is to use the `vae` to decode the latent representation into an image and get the decoded output with `sample`:\n\n```py\n# scale and decode the image latents with vae\nlatents = 1 / 0.18215 * latents\nwith torch.no_grad():\n    image = vae.decode(latents).sample\n```\n\nLastly, convert the image to a `PIL.Image` to see your generated image!\n\n```py\n>>> image = (image / 2 + 0.5).clamp(0, 1).squeeze()\n>>> image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()\n>>> image = Image.fromarray(image)\n>>> image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/blog/assets/98_stable_diffusion/stable_diffusion_k_lms.png\"/>\n</div>\n\n## Next steps\n\nFrom basic to complex pipelines, you've seen that all you really need to write your own diffusion system is a denoising loop. 
The loop should set the scheduler's timesteps, iterate over them, and alternate between calling the UNet model to predict the noise residual and passing it to the scheduler to compute the previous noisy sample.\n\nThis is really what 🧨 Diffusers is designed for: to make it intuitive and easy to write your own diffusion system using models and schedulers.\n\nFor your next steps, feel free to:\n\n* Learn how to [build and contribute a pipeline](../conceptual/contribution) to 🧨 Diffusers. We can't wait to see what you'll come up with!\n* Explore [existing pipelines](../api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately.\n"
  },
  {
    "path": "docs/source/ja/_toctree.yml",
    "content": "- sections:\n  - local: index\n    title: 🧨 Diffusers\n  - local: quicktour\n    title: クイックツアー\n  - local: stable_diffusion\n    title: 有効で効率の良い拡散モデル\n  - local: installation\n    title: インストール\n  title: はじめに\n- sections:\n  - local: tutorials/tutorial_overview\n    title: 概要\n  - local: tutorials/autopipeline\n    title: AutoPipeline\n  title: チュートリアル"
  },
  {
    "path": "docs/source/ja/index.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://raw.githubusercontent.com/huggingface/diffusers/77aadfee6a891ab9fcfb780f87c693f7a5beeb8e/docs/source/imgs/diffusers_library.jpg\" width=\"400\"/>\n    <br>\n</p>\n\n# Diffusers\n\n🤗 Diffusers は、画像や音声、さらには分子の3D構造を生成するための、最先端の事前学習済みDiffusion Model(拡散モデル)を提供するライブラリです。シンプルな生成ソリューションをお探しの場合でも、独自の拡散モデルをトレーニングしたい場合でも、🤗 Diffusers はその両方をサポートするモジュール式のツールボックスです。私たちのライブラリは、[性能より使いやすさ](conceptual/philosophy#usability-over-performance)、[簡単よりシンプル](conceptual/philosophy#simple-over-easy)、[抽象化よりカスタマイズ性](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction)に重点を置いて設計されています。\n\nこのライブラリには3つの主要コンポーネントがあります:\n\n- 数行のコードで推論可能な最先端の[拡散パイプライン](api/pipelines/overview)。Diffusersには多くのパイプラインがあります。利用可能なパイプラインを網羅したリストと、それらが解決するタスクについては、パイプラインの[概要](https://huggingface.co/docs/diffusers/api/pipelines/overview)の表をご覧ください。\n- 生成速度と品質のトレードオフのバランスを取る交換可能な[ノイズスケジューラ](api/schedulers/overview)\n- ビルディングブロックとして使用することができ、スケジューラと組み合わせることで、エンドツーエンドの拡散モデルを構築可能な事前学習済み[モデル](api/models)\n\n<div class=\"mt-10\">\n  <div class=\"w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5\">\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./tutorials/tutorial_overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 
rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">チュートリアル</div>\n      <p class=\"text-gray-700\">出力の生成、独自の拡散システムの構築、拡散モデルのトレーニングを開始するために必要な基本的なスキルを学ぶことができます。初めて 🤗Diffusersを使用する場合は、ここから始めることをおすすめします！</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./using-diffusers/loading_overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-indigo-400 to-indigo-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">ガイド</div>\n      <p class=\"text-gray-700\">パイプライン、モデル、スケジューラの読み込みに役立つ実践的なガイドです。また、特定のタスクにパイプラインを使用する方法、出力の生成方法を制御する方法、生成速度を最適化する方法、さまざまなトレーニング手法についても学ぶことができます。</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./conceptual/philosophy\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-pink-400 to-pink-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Conceptual guides</div>\n      <p class=\"text-gray-700\">ライブラリがなぜこのように設計されたのかを理解し、ライブラリを利用する際の倫理的ガイドラインや安全対策について詳しく学べます。</p>\n   </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./api/models/overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">リファレンス</div>\n      <p class=\"text-gray-700\">🤗 Diffusersのクラスとメソッドがどのように機能するかについての技術的な説明です。</p>\n    </a>\n  </div>\n</div>"
  },
  {
    "path": "docs/source/ja/installation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# インストール\n\nお使いのディープラーニングライブラリに合わせてDiffusersをインストールできます。\n\n🤗 DiffusersはPython 3.8+、PyTorch 1.7.0+、Flaxでテストされています。使用するディープラーニングライブラリの以下のインストール手順に従ってください：\n\n- [PyTorch](https://pytorch.org/get-started/locally/)のインストール手順。\n- [Flax](https://flax.readthedocs.io/en/latest/)のインストール手順。\n\n## pip でインストール\n\nDiffusersは[仮想環境](https://docs.python.org/3/library/venv.html)の中でインストールすることが推奨されています。\nPython の仮想環境についてよく知らない場合は、こちらの [ガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) を参照してください。\n仮想環境は異なるプロジェクトの管理を容易にし、依存関係間の互換性の問題を回避します。\n\nではさっそく、プロジェクトディレクトリに仮想環境を作ってみます：\n\n```bash\npython -m venv .env\n```\n\n仮想環境をアクティブにします：\n\n```bash\nsource .env/bin/activate\n```\n\n🤗 Diffusers もまた 🤗 Transformers ライブラリに依存しており、以下のコマンドで両方をインストールできます：\n\n<frameworkcontent>\n<pt>\n```bash\npip install diffusers[\"torch\"] transformers\n```\n</pt>\n<jax>\n```bash\npip install diffusers[\"flax\"] transformers\n```\n</jax>\n</frameworkcontent>\n\n## ソースからのインストール\n\nソースから🤗 Diffusersをインストールする前に、`torch`と🤗 Accelerateがインストールされていることを確認してください。\n\n`torch`のインストールについては、`torch` [インストール](https://pytorch.org/get-started/locally/#start-locally)ガイドを参照してください。\n\n🤗 Accelerateをインストールするには：\n\n```bash\npip install accelerate\n```\n\n以下のコマンドでソースから🤗 Diffusersをインストールできます：\n\n```bash\npip install git+https://github.com/huggingface/diffusers\n```\n\nこのコマンドは最新の `stable` 
バージョンではなく、最先端の `main` バージョンをインストールします。\n`main`バージョンは最新の開発に対応するのに便利です。\n例えば、前回の公式リリース以降にバグが修正されたが、新しいリリースがまだリリースされていない場合などには都合がいいです。\nしかし、これは `main` バージョンが常に安定しているとは限らないです。\n私たちは `main` バージョンを運用し続けるよう努力しており、ほとんどの問題は通常数時間から1日以内に解決されます。\nもし問題が発生した場合は、[Issue](https://github.com/huggingface/diffusers/issues/new/choose) を開いてください！\n\n## 編集可能なインストール\n\n以下の場合、編集可能なインストールが必要です：\n\n* ソースコードの `main` バージョンを使用する。\n* 🤗 Diffusers に貢献し、コードの変更をテストする必要がある場合。\n\nリポジトリをクローンし、次のコマンドで 🤗 Diffusers をインストールしてください：\n\n```bash\ngit clone https://github.com/huggingface/diffusers.git\ncd diffusers\n```\n\n<frameworkcontent>\n<pt>\n```bash\npip install -e \".[torch]\"\n```\n</pt>\n<jax>\n```bash\npip install -e \".[flax]\"\n```\n</jax>\n</frameworkcontent>\n\nこれらのコマンドは、リポジトリをクローンしたフォルダと Python のライブラリパスをリンクします。\nPython は通常のライブラリパスに加えて、クローンしたフォルダの中を探すようになります。\n例えば、Python パッケージが通常 `~/anaconda3/envs/main/lib/python3.10/site-packages/` にインストールされている場合、Python はクローンした `~/diffusers/` フォルダも同様に参照します。\n\n> [!WARNING]\n> ライブラリを使い続けたい場合は、`diffusers`フォルダを残しておく必要があります。\n\nこれで、以下のコマンドで簡単にクローンを最新版の🤗 Diffusersにアップデートできます：\n\n```bash\ncd ~/diffusers/\ngit pull\n```\n\nPython環境は次の実行時に `main` バージョンの🤗 Diffusersを見つけます。\n\n## テレメトリー・ロギングに関するお知らせ\n\nこのライブラリは `from_pretrained()` リクエスト中にデータを収集します。\nこのデータには Diffusers と PyTorch/Flax のバージョン、要求されたモデルやパイプラインクラスが含まれます。\nまた、Hubでホストされている場合は、事前に学習されたチェックポイントへのパスが含まれます。\nこの使用データは問題のデバッグや新機能の優先順位付けに役立ちます。\nテレメトリーはHuggingFace Hubからモデルやパイプラインをロードするときのみ送信されます。ローカルでの使用中は収集されません。\n\n我々は、すべての人が追加情報を共有したくないことを理解し、あなたのプライバシーを尊重します。\nそのため、ターミナルから `DISABLE_TELEMETRY` 環境変数を設定することで、データ収集を無効にすることができます：\n\nLinux/MacOSの場合\n```bash\nexport DISABLE_TELEMETRY=YES\n```\n\nWindows の場合\n```bash\nset DISABLE_TELEMETRY=YES\n```\n"
  },
  {
    "path": "docs/source/ja/quicktour.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# 簡単な案内\n\n拡散モデル(Diffusion Model)は、ランダムな正規分布から段階的にノイズ除去するように学習され、画像や音声などの目的のものを生成できます。これは生成AIに多大な関心を呼び起こしました。インターネット上で拡散によって生成された画像の例を見たことがあるでしょう。🧨 Diffusersは、誰もが拡散モデルに広くアクセスできるようにすることを目的としたライブラリです。\n\nこの案内では、開発者または日常的なユーザーに関わらず、🧨 Diffusers を紹介し、素早く目的のものを生成できるようにします！このライブラリには3つの主要コンポーネントがあります:\n\n* [`DiffusionPipeline`]は事前に学習された拡散モデルからサンプルを迅速に生成するために設計された高レベルのエンドツーエンドクラス。\n*  拡散システムを作成するためのビルディングブロックとして使用できる、人気のある事前学習された[モデル](./api/models)アーキテクチャとモジュール。\n*  多くの異なる[スケジューラ](./api/schedulers/overview) - ノイズがどのようにトレーニングのために加えられるか、そして生成中にどのようにノイズ除去された画像を生成するかを制御するアルゴリズム。\n\nこの案内では、[`DiffusionPipeline`]を生成に使用する方法を紹介し、モデルとスケジューラを組み合わせて[`DiffusionPipeline`]の内部で起こっていることを再現する方法を説明します。\n\n> [!TIP]\n> この案内は🧨 Diffusers [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)を簡略化したもので、すぐに使い始めることができます。Diffusers 🧨のゴール、設計哲学、コアAPIの詳細についてもっと知りたい方は、ノートブックをご覧ください！\n\n始める前に必要なライブラリーがすべてインストールされていることを確認してください：\n\n```py\n# uncomment to install the necessary libraries in Colab\n#!pip install --upgrade diffusers accelerate transformers\n```\n\n- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index)生成とトレーニングのためのモデルのロードを高速化します\n- [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview)ような最も一般的な拡散モデルを実行するには、[🤗 
Transformers](https://huggingface.co/docs/transformers/index)が必要です。\n\n## 拡散パイプライン\n\n[`DiffusionPipeline`]は事前学習された拡散システムを生成に使用する最も簡単な方法です。これはモデルとスケジューラを含むエンドツーエンドのシステムです。[`DiffusionPipeline`]は多くの作業／タスクにすぐに使用することができます。また、サポートされているタスクの完全なリストについては[🧨Diffusersの概要](./api/pipelines/overview#diffusers-summary)の表を参照してください。\n\n| **タスク**                     | **説明**                                                                                              | **パイプライン**\n|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|\n| Unconditional Image Generation          | 正規分布から画像生成 | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |\n| Text-Guided Image Generation | 文章から画像生成 | [conditional_image_generation](./using-diffusers/conditional_image_generation) |\n| Text-Guided Image-to-Image Translation     | 画像と文章から新たな画像生成 | [img2img](./using-diffusers/img2img) |\n| Text-Guided Image-Inpainting          | 画像、マスク、および文章が指定された場合に、画像のマスクされた部分を文章をもとに修復 | [inpaint](./using-diffusers/inpaint) |\n| Text-Guided Depth-to-Image Translation | 文章と深度推定によって構造を保持しながら画像生成 | [depth2img](./using-diffusers/depth2img) |\n\nまず、[`DiffusionPipeline`]のインスタンスを作成し、ダウンロードしたいパイプラインのチェックポイントを指定します。\nこの[`DiffusionPipeline`]はHugging Face Hubに保存されている任意の[チェックポイント](https://huggingface.co/models?library=diffusers&sort=downloads)を使用することができます。\nこの案内では、[`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)チェックポイントでテキストから画像へ生成します。\n\n> [!WARNING]\n> [Stable Diffusion]モデルについては、モデルを実行する前にまず[ライセンス](https://huggingface.co/spaces/CompVis/stable-diffusion-license)を注意深くお読みください。🧨  Diffusers は、攻撃的または有害なコンテンツを防ぐために [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) 
を実装していますが、モデルの改良された画像生成機能により、潜在的に有害なコンテンツが生成される可能性があります。\n\nモデルを[`~DiffusionPipeline.from_pretrained`]メソッドでロードします：\n\n```python\n>>> from diffusers import DiffusionPipeline\n\n>>> pipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\n```\n[`DiffusionPipeline`]は全てのモデリング、トークン化、スケジューリングコンポーネントをダウンロードしてキャッシュします。Stable Diffusionパイプラインは[`UNet2DConditionModel`]と[`PNDMScheduler`]などで構成されています：\n\n```py\n>>> pipeline\nStableDiffusionPipeline {\n  \"_class_name\": \"StableDiffusionPipeline\",\n  \"_diffusers_version\": \"0.13.1\",\n  ...,\n  \"scheduler\": [\n    \"diffusers\",\n    \"PNDMScheduler\"\n  ],\n  ...,\n  \"unet\": [\n    \"diffusers\",\n    \"UNet2DConditionModel\"\n  ],\n  \"vae\": [\n    \"diffusers\",\n    \"AutoencoderKL\"\n  ]\n}\n```\n\nこのモデルはおよそ14億個のパラメータで構成されているため、GPU上でパイプラインを実行することを強く推奨します。\nPyTorchと同じように、ジェネレータオブジェクトをGPUに移すことができます：\n\n```python\n>>> pipeline.to(\"cuda\")\n```\n\nこれで、文章を `pipeline` に渡して画像を生成し、ノイズ除去された画像にアクセスできるようになりました。デフォルトでは、画像出力は[`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class)オブジェクトでラップされます。\n\n```python\n>>> image = pipeline(\"An image of a squirrel in Picasso style\").images[0]\n>>> image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png\"/>\n</div>\n\n`save`関数で画像を保存できます:\n\n```python\n>>> image.save(\"image_of_squirrel_painting.png\")\n```\n\n### ローカルパイプライン\n\nローカルでパイプラインを使用することもできます。唯一の違いは、最初にウェイトをダウンロードする必要があることです：\n\n```bash\n!git lfs install\n!git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5\n```\n\n保存したウェイトをパイプラインにロードします：\n\n```python\n>>> pipeline = DiffusionPipeline.from_pretrained(\"./stable-diffusion-v1-5\", use_safetensors=True)\n```\n\nこれで、上のセクションと同じようにパイプラインを動かすことができます。\n\n### 
スケジューラの交換\n\nスケジューラーによって、ノイズ除去のスピードや品質のトレードオフが異なります。どれが自分に最適かを知る最善の方法は、実際に試してみることです！Diffusers 🧨の主な機能の1つは、スケジューラを簡単に切り替えることができることです。例えば、デフォルトの[`PNDMScheduler`]を[`EulerDiscreteScheduler`]に置き換えるには、[`~diffusers.ConfigMixin.from_config`]メソッドでロードできます：\n\n```py\n>>> from diffusers import EulerDiscreteScheduler\n\n>>> pipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\n>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)\n```\n\n新しいスケジューラを使って画像を生成し、その違いに気づくかどうか試してみてください！\n\n次のセクションでは、[`DiffusionPipeline`]を構成するコンポーネント（モデルとスケジューラ）を詳しく見て、これらのコンポーネントを使って猫の画像を生成する方法を学びます。\n\n## モデル\n\nほとんどのモデルはノイズの多いサンプルを取り、各タイムステップで*残りのノイズ*を予測します（他のモデルは前のサンプルを直接予測するか、速度または[`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)を予測するように学習します）。モデルを混ぜて他の拡散システムを作ることもできます。\n\nモデルは[`~ModelMixin.from_pretrained`]メソッドで開始されます。このメソッドはモデルをローカルにキャッシュするので、次にモデルをロードするときに高速になります。この案内では、[`UNet2DModel`]をロードします。これは基本的な画像生成モデルであり、猫画像で学習されたチェックポイントを使います：\n\n```py\n>>> from diffusers import UNet2DModel\n\n>>> repo_id = \"google/ddpm-cat-256\"\n>>> model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)\n```\n\nモデルのパラメータにアクセスするには、`model.config` を呼び出せます：\n\n```py\n>>> model.config\n```\n\nモデル構成は🧊凍結🧊されたディクショナリであり、モデル作成後にこれらのパラメータを変更することはできません。これは意図的なもので、最初にモデル・アーキテクチャを定義するために使用されるパラメータが同じままであることを保証します。他のパラメータは生成中に調整することができます。\n\n最も重要なパラメータは以下の通りです：\n\n* `sample_size`: 入力サンプルの高さと幅。\n* `in_channels`: 入力サンプルの入力チャンネル数。\n* `down_block_types` と `up_block_types`: UNet アーキテクチャを作成するために使用されるダウンサンプリングブロックとアップサンプリングブロックのタイプ。\n* `block_out_channels`: ダウンサンプリングブロックの出力チャンネル数。逆順でアップサンプリングブロックの入力チャンネル数にも使用されます。\n* `layers_per_block`: 各 UNet ブロックに含まれる ResNet ブロックの数。\n\nこのモデルを生成に使用するには、ランダムな画像の形の正規分布を作成します。このモデルは複数のランダムな正規分布を受け取ることができるため`batch`軸を入れます。入力チャンネル数に対応する`channel`軸も必要です。画像の高さと幅に対応する`sample_size`軸を持つ必要があります：\n\n```py\n>>> import 
torch\n\n>>> torch.manual_seed(0)\n\n>>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n>>> noisy_sample.shape\ntorch.Size([1, 3, 256, 256])\n```\n\n画像生成には、ノイズの多い画像と `timestep` をモデルに渡します。`timestep`は入力画像がどの程度ノイズが多いかを示します。これは、モデルが拡散プロセスにおける自分の位置を決定するのに役立ちます。モデルの出力を得るには `sample` メソッドを使用します：\n\n```py\n>>> with torch.no_grad():\n...     noisy_residual = model(sample=noisy_sample, timestep=2).sample\n```\n\nしかし、実際の例を生成するには、ノイズ除去プロセスをガイドするスケジューラが必要です。次のセクションでは、モデルをスケジューラと組み合わせる方法を学びます。\n\n## スケジューラ\n\nスケジューラは、モデルの出力（この場合は `noisy_residual` ）が与えられたときに、ノイズの多いサンプルからノイズの少ないサンプルへの移行を管理します。\n\n\n> [!TIP]\n> 🧨 Diffusersは拡散システムを構築するためのツールボックスです。[`DiffusionPipeline`]は事前に構築された拡散システムを使い始めるのに便利な方法ですが、独自のモデルとスケジューラコンポーネントを個別に選択してカスタム拡散システムを構築することもできます。\n\nこの案内では、[`DDPMScheduler`]を[`~diffusers.ConfigMixin.from_config`]メソッドでインスタンス化します：\n\n```py\n>>> from diffusers import DDPMScheduler\n\n>>> scheduler = DDPMScheduler.from_config(repo_id)\n>>> scheduler\nDDPMScheduler {\n  \"_class_name\": \"DDPMScheduler\",\n  \"_diffusers_version\": \"0.13.1\",\n  \"beta_end\": 0.02,\n  \"beta_schedule\": \"linear\",\n  \"beta_start\": 0.0001,\n  \"clip_sample\": true,\n  \"clip_sample_range\": 1.0,\n  \"num_train_timesteps\": 1000,\n  \"prediction_type\": \"epsilon\",\n  \"trained_betas\": null,\n  \"variance_type\": \"fixed_small\"\n}\n```\n\n> [!TIP]\n> 💡 スケジューラがどのようにコンフィギュレーションからインスタンス化されるかに注目してください。モデルとは異なり、スケジューラは学習可能な重みを持たず、パラメーターを持ちません！\n\n最も重要なパラメータは以下の通りです：\n\n* `num_train_timesteps`: ノイズ除去処理の長さ、言い換えれば、ランダムな正規分布をデータサンプルに処理するのに必要なタイムステップ数です。\n* `beta_schedule`: 生成とトレーニングに使用するノイズスケジュールのタイプ。\n* `beta_start` と `beta_end`: ノイズスケジュールの開始値と終了値。\n\n少しノイズの少ない画像を予測するには、スケジューラの [`~diffusers.DDPMScheduler.step`] メソッドに以下を渡します: モデルの出力、`timestep`、現在の `sample`。\n\n```py\n>>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample\n>>> 
less_noisy_sample.shape\n```\n\n`less_noisy_sample`は次の`timestep`に渡すことができ、そこでさらにノイズが少なくなります！\n\nでは、すべてをまとめて、ノイズ除去プロセス全体を視覚化してみましょう。\n\nまず、ノイズ除去された画像を後処理して `PIL.Image` として表示する関数を作成します：\n\n```py\n>>> import PIL.Image\n>>> import numpy as np\n\n\n>>> def display_sample(sample, i):\n...     image_processed = sample.cpu().permute(0, 2, 3, 1)\n...     image_processed = (image_processed + 1.0) * 127.5\n...     image_processed = image_processed.numpy().astype(np.uint8)\n\n...     image_pil = PIL.Image.fromarray(image_processed[0])\n...     display(f\"Image at step {i}\")\n...     display(image_pil)\n```\n\nノイズ除去処理を高速化するために入力とモデルをGPUに移します：\n\n```py\n>>> model.to(\"cuda\")\n>>> noisy_sample = noisy_sample.to(\"cuda\")\n```\n\nここで、ノイズが少なくなったサンプルの残りのノイズを予測するノイズ除去ループを作成し、スケジューラを使ってさらにノイズの少ないサンプルを計算します：\n\n```py\n>>> import tqdm\n\n>>> sample = noisy_sample\n\n>>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):\n...     # 1. predict noise residual\n...     with torch.no_grad():\n...         residual = model(sample, t).sample\n\n...     # 2. compute less noisy image and set x_t -> x_t-1\n...     sample = scheduler.step(residual, t, sample).prev_sample\n\n...     # 3. optionally look at image\n...     if (i + 1) % 50 == 0:\n...         
display_sample(sample, i + 1)\n```\n\n何もないところから猫が生成されるのを、座って見てください！😻\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/diffusion-quicktour.png\"/>\n</div>\n\n## 次のステップ\n\nこのクイックツアーで、🧨ディフューザーを使ったクールな画像をいくつか作成できたと思います！次のステップとして\n\n* モデルをトレーニングまたは微調整については、[training](./tutorials/basic_training)チュートリアルを参照してください。\n* 様々な使用例については、公式およびコミュニティの[training or finetuning scripts](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples)の例を参照してください。\n* スケジューラのロード、アクセス、変更、比較については[Using different Schedulers](./using-diffusers/schedulers)ガイドを参照してください。\n* プロンプトエンジニアリング、スピードとメモリの最適化、より高品質な画像を生成するためのヒントやトリックについては、[Stable Diffusion](./stable_diffusion)ガイドを参照してください。\n* 🧨 Diffusers の高速化については、最適化された [PyTorch on a GPU](./optimization/fp16)のガイド、[Stable Diffusion on Apple Silicon (M1/M2)](./optimization/mps)と[ONNX Runtime](./optimization/onnx)を参照してください。\n"
  },
  {
    "path": "docs/source/ja/stable_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 効果的で効率的な拡散モデル\n\n[[open-in-colab]]\n\n[`DiffusionPipeline`]を使って特定のスタイルで画像を生成したり、希望する画像を生成したりするのは難しいことです。多くの場合、[`DiffusionPipeline`]を何度か実行してからでないと満足のいく画像は得られません。しかし、何もないところから何かを生成するにはたくさんの計算が必要です。生成を何度も何度も実行する場合、特にたくさんの計算量が必要になります。\n\nそのため、パイプラインから*計算*（速度）と*メモリ*（GPU RAM）の効率を最大限に引き出し、生成サイクル間の時間を短縮することで、より高速な反復処理を行えるようにすることが重要です。\n\nこのチュートリアルでは、[`DiffusionPipeline`]を用いて、より速く、より良い計算を行う方法を説明します。\n\nまず、[`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)モデルをロードします：\n\n```python\nfrom diffusers import DiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)\n```\n\nここで使用するプロンプトの例は年老いた戦士の長の肖像画ですが、ご自由に変更してください：\n\n```python\nprompt = \"portrait photo of a old warrior chief\"\n```\n\n## Speed\n\n> [!TIP]\n> 💡 GPUを利用できない場合は、[Colab](https://colab.research.google.com/)のようなGPUプロバイダーから無料で利用できます！\n\n画像生成を高速化する最も簡単な方法の1つは、PyTorchモジュールと同じようにGPU上にパイプラインを配置することです：\n\n```python\npipeline = pipeline.to(\"cuda\")\n```\n\n同じイメージを使って改良できるようにするには、[`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html)を使い、[reproducibility](./using-diffusers/reusing_seeds)の種を設定します：\n\n```python\nimport torch\n\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n```\n\nこれで画像を生成できます：\n\n```python\nimage = 
pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_1.png\">\n</div>\n\nこの処理にはT4 GPUで~30秒かかりました（割り当てられているGPUがT4より優れている場合はもっと速いかもしれません）。デフォルトでは、[`DiffusionPipeline`]は完全な`float32`精度で生成を50ステップ実行します。`float16`のような低い精度に変更するか、推論ステップ数を減らすことで高速化することができます。\n\nまずは `float16` でモデルをロードして画像を生成してみましょう：\n\n```python\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)\npipeline = pipeline.to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_2.png\">\n</div>\n\n今回、画像生成にかかった時間はわずか11秒で、以前より3倍近く速くなりました！\n\n> [!TIP]\n> 💡 パイプラインは常に `float16` で実行することを強くお勧めします。\n\n生成ステップ数を減らすという方法もあります。より効率的なスケジューラを選択することで、出力品質を犠牲にすることなくステップ数を減らすことができます。`compatibles`メソッドを呼び出すことで、[`DiffusionPipeline`]の現在のモデルと互換性のあるスケジューラを見つけることができます：\n\n```python\npipeline.scheduler.compatibles\n[\n    diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,\n    diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,\n    diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,\n    diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,\n    diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,\n    diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,\n    diffusers.schedulers.scheduling_ddpm.DDPMScheduler,\n    diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,\n    diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,\n    diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,\n    
diffusers.schedulers.scheduling_pndm.PNDMScheduler,\n    diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,\n    diffusers.schedulers.scheduling_ddim.DDIMScheduler,\n]\n```\n\nStable Diffusionモデルはデフォルトで[`PNDMScheduler`]を使用します。このスケジューラは通常~50の推論ステップを必要としますが、[`DPMSolverMultistepScheduler`]のような高性能なスケジューラでは~20または25の推論ステップで済みます。[`ConfigMixin.from_config`]メソッドを使用すると、新しいスケジューラをロードすることができます：\n\n```python\nfrom diffusers import DPMSolverMultistepScheduler\n\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n```\n\nここで `num_inference_steps` を20に設定します：\n\n```python\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\nimage = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_3.png\">\n</div>\n\n推論時間をわずか4秒に短縮することに成功した！⚡️\n\n## メモリー\n\nパイプラインのパフォーマンスを向上させるもう1つの鍵は、消費メモリを少なくすることです。一度に生成できる画像の数を確認する最も簡単な方法は、`OutOfMemoryError`（OOM）が発生するまで、さまざまなバッチサイズを試してみることです。\n\n文章と `Generators` のリストから画像のバッチを生成する関数を作成します。各 `Generator` にシードを割り当てて、良い結果が得られた場合に再利用できるようにします。\n\n```python\ndef get_inputs(batch_size=1):\n    generator = [torch.Generator(\"cuda\").manual_seed(i) for i in range(batch_size)]\n    prompts = batch_size * [prompt]\n    num_inference_steps = 20\n\n    return {\"prompt\": prompts, \"generator\": generator, \"num_inference_steps\": num_inference_steps}\n```\n\n`batch_size=4`で開始し、どれだけメモリを消費したかを確認します：\n\n```python\nfrom diffusers.utils import make_image_grid\n\nimages = pipeline(**get_inputs(batch_size=4)).images\nmake_image_grid(images, 2, 
2)\n```\n\n大容量のRAMを搭載したGPUでない限り、上記のコードはおそらく`OOM`エラーを返したはずです！メモリの大半はクロスアテンションレイヤーが占めています。この処理をバッチで実行する代わりに、逐次実行することでメモリを大幅に節約できます。必要なのは、[`~DiffusionPipeline.enable_attention_slicing`]関数を使用することだけです：\n\n```python\npipeline.enable_attention_slicing()\n```\n\n今度は`batch_size`を8にしてみてください！\n\n```python\nimages = pipeline(**get_inputs(batch_size=8)).images\nmake_image_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_5.png\">\n</div>\n\n以前は4枚の画像のバッチを生成することさえできませんでしたが、今では8枚の画像のバッチを1枚あたり～3.5秒で生成できます！これはおそらく、品質を犠牲にすることなくT4 GPUでできる最速の処理速度です。\n\n## 品質\n\n前の2つのセクションでは、`fp16` を使ってパイプラインの速度を最適化する方法、よりパフォーマン スなスケジューラーを使って生成ステップ数を減らす方法、アテンションスライスを有効 にしてメモリ消費量を減らす方法について学びました。今度は、生成される画像の品質を向上させる方法に焦点を当てます。\n\n### より良いチェックポイント\n\n最も単純なステップは、より良いチェックポイントを使うことです。Stable Diffusionモデルは良い出発点であり、公式発表以来、いくつかの改良版もリリースされています。しかし、新しいバージョンを使ったからといって、自動的に良い結果が得られるわけではありません。最良の結果を得るためには、自分でさまざまなチェックポイントを試してみたり、ちょっとした研究（[ネガティブプロンプト](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)の使用など）をしたりする必要があります。\n\nこの分野が成長するにつれて、特定のスタイルを生み出すために微調整された、より質の高いチェックポイントが増えています。[Hub](https://huggingface.co/models?library=diffusers&sort=downloads)や[Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery)を探索して、興味のあるものを見つけてみてください！\n\n### より良いパイプラインコンポーネント\n\n現在のパイプラインコンポーネントを新しいバージョンに置き換えてみることもできます。Stability AIが提供する最新の[autodecoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae)をパイプラインにロードし、画像を生成してみましょう：\n\n```python\nfrom diffusers import AutoencoderKL\n\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.vae = vae\nimages = pipeline(**get_inputs(batch_size=8)).images\nmake_image_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img 
src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_6.png\">\n</div>\n\n### より良いプロンプト・エンジニアリング\n\n画像を生成するために使用する文章は、*プロンプトエンジニアリング*と呼ばれる分野を作られるほど、非常に重要です。プロンプト・エンジニアリングで考慮すべき点は以下の通りです：\n\n- 生成したい画像やその類似画像は、インターネット上にどのように保存されているか？\n- 私が望むスタイルにモデルを誘導するために、どのような追加詳細を与えるべきか？\n\nこのことを念頭に置いて、プロンプトに色やより質の高いディテールを含めるように改良してみましょう：\n\n```python\nprompt += \", tribal panther make up, blue on red, side profile, looking away, serious eyes\"\nprompt += \" 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\"\n```\n\n新しいプロンプトで画像のバッチを生成しましょう：\n\n```python\nimages = pipeline(**get_inputs(batch_size=8)).images\nmake_image_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_7.png\">\n</div>\n\nかなりいいです！種が`1`の`Generator`に対応する2番目の画像に、被写体の年齢に関するテキストを追加して、もう少し手を加えてみましょう：\n\n```python\nprompts = [\n    \"portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n]\n\ngenerator = [torch.Generator(\"cuda\").manual_seed(1) for _ in range(len(prompts))]\nimages = pipeline(prompt=prompts, 
generator=generator, num_inference_steps=25).images\nmake_image_grid(images, 2, 2)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_8.png\">\n</div>\n\n## 次のステップ\n\nこのチュートリアルでは、[`DiffusionPipeline`]を最適化して計算効率とメモリ効率を向上させ、生成される出力の品質を向上させる方法を学びました。パイプラインをさらに高速化することに興味があれば、以下のリソースを参照してください：\n\n- [PyTorch 2.0](./optimization/torch2.0)と[`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html)がどのように生成速度を5-300%高速化できるかを学んでください。A100 GPUの場合、画像生成は最大50%速くなります！\n- PyTorch 2が使えない場合は、[xFormers](./optimization/xformers)をインストールすることをお勧めします。このライブラリのメモリ効率の良いアテンションメカニズムは PyTorch 1.13.1 と相性が良く、高速化とメモリ消費量の削減を同時に実現します。\n- モデルのオフロードなど、その他の最適化テクニックは [this guide](./optimization/fp16) でカバーされています。\n"
  },
  {
    "path": "docs/source/ja/tutorials/autopipeline.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# AutoPipeline\n\nDiffusersは様々なタスクをこなすことができ、テキストから画像、画像から画像、画像の修復など、複数のタスクに対して同じ事前学習済みの重みを再利用することができます。しかし、ライブラリや拡散モデルに慣れていない場合、どのタスクにどのパイプラインを使えばいいのかがわかりにくいかもしれません。例えば、 [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) チェックポイントをテキストから画像に変換するために使用している場合、それぞれ[`StableDiffusionImg2ImgPipeline`]クラスと[`StableDiffusionInpaintPipeline`]クラスでチェックポイントをロードすることで、画像から画像や画像の修復にも使えることを知らない可能性もあります。\n\n`AutoPipeline` クラスは、🤗 Diffusers の様々なパイプラインをよりシンプルにするために設計されています。この汎用的でタスク重視のパイプラインによってタスクそのものに集中することができます。`AutoPipeline` は、使用するべき正しいパイプラインクラスを自動的に検出するため、特定のパイプラインクラス名を知らなくても、タスクのチェックポイントを簡単にロードできます。\n\n> [!TIP]\n> どのタスクがサポートされているかは、[AutoPipeline](../api/pipelines/auto_pipeline) のリファレンスをご覧ください。現在、text-to-image、image-to-image、inpaintingをサポートしています。\n\nこのチュートリアルでは、`AutoPipeline` を使用して、事前に学習された重みが与えられたときに、特定のタスクを読み込むためのパイプラインクラスを自動的に推測する方法を示します。\n\n## タスクに合わせてAutoPipeline を選択する\nまずはチェックポイントを選ぶことから始めましょう。例えば、 [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) チェックポイントでテキストから画像への変換をしたいなら、[`AutoPipelineForText2Image`]を使います:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, 
use_safetensors=True\n).to(\"cuda\")\nprompt = \"peasant and dragon combat, wood cutting style, viking era, bevel with rune\"\n\nimage = pipeline(prompt, num_inference_steps=25).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png\" alt=\"generated image of peasant fighting dragon in wood cutting style\"/>\n</div>\n\n[`AutoPipelineForText2Image`] を具体的に見ていきましょう:\n\n1. [`model_index.json`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json) ファイルから `\"stable-diffusion\"` クラスを自動的に検出します。\n2. `\"stable-diffusion\"` のクラス名に基づいて、テキストから画像へ変換する [`StableDiffusionPipeline`] を読み込みます。\n\n同様に、画像から画像へ変換する場合、[`AutoPipelineForImage2Image`] は `model_index.json` ファイルから `\"stable-diffusion\"` チェックポイントを検出し、対応する [`StableDiffusionImg2ImgPipeline`] を読み込みます。また、入力画像にノイズの量やバリエーションの追加を決めるための強さなど、パイプラインクラスに固有の追加引数を渡すこともできます:\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nimport torch\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n).to(\"cuda\")\nprompt = \"a portrait of a dog wearing a pearl earring\"\n\nurl = \"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0f/1665_Girl_with_a_Pearl_Earring.jpg/800px-1665_Girl_with_a_Pearl_Earring.jpg\"\n\nresponse = requests.get(url)\nimage = Image.open(BytesIO(response.content)).convert(\"RGB\")\nimage.thumbnail((768, 768))\n\nimage = pipeline(prompt, image, num_inference_steps=200, strength=0.75, guidance_scale=10.5).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png\" alt=\"generated image of a vermeer portrait of a dog wearing a pearl 
earring\"/>\n</div>\n\nまた、画像の修復を行いたい場合は、 [`AutoPipelineForInpainting`] が、同様にベースとなる[`StableDiffusionInpaintPipeline`]クラスを読み込みます：\n\n```py\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image\nimport torch\n\npipeline = AutoPipelineForInpainting.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, use_safetensors=True\n).to(\"cuda\")\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = load_image(img_url).convert(\"RGB\")\nmask_image = load_image(mask_url).convert(\"RGB\")\n\nprompt = \"A majestic tiger sitting on a bench\"\nimage = pipeline(prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png\" alt=\"generated image of a tiger sitting on a bench\"/>\n</div>\n\nサポートされていないチェックポイントを読み込もうとすると、エラーになります:\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nimport torch\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"openai/shap-e-img2img\", torch_dtype=torch.float16, use_safetensors=True\n)\n\"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None\"\n```\n\n## 複数のパイプラインを使用する\n\nいくつかのワークフローや多くのパイプラインを読み込む場合、不要なメモリを使ってしまう再読み込みをするよりも、チェックポイントから同じコンポーネントを再利用する方がメモリ効率が良いです。たとえば、テキストから画像への変換にチェックポイントを使い、画像から画像への変換にまたチェックポイントを使いたい場合、[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe) 
メソッドを使用します。このメソッドは、以前読み込まれたパイプラインのコンポーネントを使うことで追加のメモリを消費することなく、新しいパイプラインを作成します。\n\n[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe) メソッドは、元のパイプラインクラスを検出し、実行したいタスクに対応する新しいパイプラインクラスにマッピングします。例えば、テキストから画像への`\"stable-diffusion\"` クラスのパイプラインを読み込む場合：\n\n```py\nfrom diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image\nimport torch\n\npipeline_text2img = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True\n)\nprint(type(pipeline_text2img))\n\"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>\"\n```\n\nそして、[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe)は、元の`\"stable-diffusion\"` パイプラインクラスを [`StableDiffusionImg2ImgPipeline`] にマップします:\n\n```py\npipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)\nprint(type(pipeline_img2img))\n\"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'>\"\n```\n\n元のパイプラインにオプションとして引数（セーフティチェッカーの無効化など）を渡した場合、この引数も新しいパイプラインに渡されます:\n\n```py\nfrom diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image\nimport torch\n\npipeline_text2img = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    requires_safety_checker=False,\n).to(\"cuda\")\n\npipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)\nprint(pipeline_img2img.config.requires_safety_checker)\n\"False\"\n```\n\n新しいパイプラインの動作を変更したい場合は、元のパイプラインの引数や設定を上書きすることができます。例えば、セーフティチェッカーをオンに戻し、`strength` 引数を追加します:\n\n```py\npipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img, requires_safety_checker=True, 
strength=0.3)\nprint(pipeline_img2img.config.requires_safety_checker)\n\"True\"\n```\n"
  },
  {
    "path": "docs/source/ja/tutorials/tutorial_overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Overview\n\nようこそ 🧨Diffusersへ！拡散モデル(diffusion models)や生成AIの初心者で、さらに学びたいのであれば、このチュートリアルが最適です。この初心者向けのチュートリアルは、拡散モデルについて丁寧に解説し、ライブラリの基礎（核となるコンポーネントと 🧨Diffusersの使用方法）を理解することを目的としています。\n\nまず、推論のためのパイプラインを使って、素早く生成する方法を学んでいきます。次に、独自の拡散システムを構築するためのモジュラーツールボックスとしてライブラリをどのように使えば良いかを理解するために、そのパイプラインを分解してみましょう。次のレッスンでは、あなたの欲しいものを生成できるように拡散モデルをトレーニングする方法を学びましょう。\n\nこのチュートリアルがすべて完了したら、ライブラリを自分で調べ、自分のプロジェクトやアプリケーションにどのように使えるかを知るために必要なスキルを身につけることができます。\n\nそして、 [Discord](https://discord.com/invite/JfAtkvEtRb) や [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) でDiffusersコミュニティに参加してユーザーや開発者と繋がって協力していきましょう。\n\nさあ、「拡散」をはじめていきましょう！🧨"
  },
  {
    "path": "docs/source/ko/_toctree.yml",
    "content": "- sections:\n  - local: index\n    title: 🧨 Diffusers\n  - local: quicktour\n    title: \"훑어보기\"\n  - local: stable_diffusion\n    title: Stable Diffusion\n  - local: installation\n    title: 설치\n  title: 시작하기\n- sections:\n  - local: tutorials/tutorial_overview\n    title: 개요\n  - local: using-diffusers/write_own_pipeline\n    title: 모델과 스케줄러 이해하기\n  - local: in_translation # tutorials/autopipeline\n    title: (번역중) AutoPipeline\n  - local: tutorials/basic_training\n    title: Diffusion 모델 학습하기\n  - local: in_translation # tutorials/using_peft_for_inference\n    title: (번역중) 추론을 위한 LoRAs 불러오기\n  - local: in_translation # tutorials/fast_diffusion\n    title: (번역중) Text-to-image diffusion 모델 추론 가속화하기\n  - local: in_translation  # tutorials/inference_with_big_models\n    title: (번역중) 큰 모델로 작업하기\n  title: 튜토리얼\n- sections:\n  - local: using-diffusers/loading\n    title: 파이프라인 불러오기\n  - local: using-diffusers/custom_pipeline_overview\n    title: 커뮤니티 파이프라인과 컴포넌트 불러오기\n  - local: using-diffusers/schedulers\n    title: 스케줄러와 모델 불러오기\n  - local: using-diffusers/other-formats\n    title: 모델 파일과 레이아웃\n  - local: using-diffusers/loading_adapters\n    title: 어댑터 불러오기\n  - local: using-diffusers/push_to_hub\n    title: 파일들을 Hub로 푸시하기\n  title: 파이프라인과 어댑터 불러오기\n- sections:\n  - local: using-diffusers/unconditional_image_generation\n    title: Unconditional 이미지 생성\n  - local: using-diffusers/conditional_image_generation\n    title: Text-to-image\n  - local: using-diffusers/img2img\n    title: Image-to-image\n  - local: using-diffusers/inpaint\n    title: 인페인팅\n  - local: in_translation # using-diffusers/text-img2vid\n    title: (번역중) Text 또는 image-to-video\n  - local: using-diffusers/depth2img\n    title: Depth-to-image\n  title: 생성 태스크\n- sections:\n  - local: in_translation # using-diffusers/overview_techniques\n    title: (번역중) 개요\n  - local: training/distributed_inference\n    title: 여러 GPU를 사용한 분산 추론\n  - local: in_translation # using-diffusers/merge_loras\n 
   title: (번역중) LoRA 병합\n  - local: in_translation # using-diffusers/scheduler_features\n    title: (번역중) 스케줄러 기능\n  - local: in_translation # using-diffusers/callback\n    title: (번역중) 파이프라인 콜백\n  - local: in_translation # using-diffusers/reusing_seeds\n    title: (번역중) 재현 가능한 파이프라인\n  - local: in_translation # using-diffusers/image_quality\n    title: (번역중) 이미지 퀄리티 조절하기\n  - local: using-diffusers/weighted_prompts\n    title: 프롬프트 기술\n  title: 추론 테크닉\n- sections:\n  - local: in_translation # advanced_inference/outpaint\n    title: (번역중) Outpainting\n  title: 추론 심화\n- sections:\n  - local: in_translation # using-diffusers/sdxl\n    title: (번역중) Stable Diffusion XL\n  - local: using-diffusers/sdxl_turbo\n    title: SDXL Turbo\n  - local: using-diffusers/kandinsky\n    title: Kandinsky\n  - local: in_translation # using-diffusers/ip_adapter\n    title: (번역중) IP-Adapter\n  - local: in_translation # using-diffusers/pag\n    title: (번역중) PAG\n  - local: in_translation # using-diffusers/controlnet\n    title: (번역중) ControlNet\n  - local: in_translation # using-diffusers/t2i_adapter\n    title: (번역중) T2I-Adapter\n  - local: in_translation # using-diffusers/inference_with_lcm\n    title: (번역중) Latent Consistency Model\n  - local: using-diffusers/textual_inversion_inference\n    title: Textual inversion\n  - local: using-diffusers/shap-e\n    title: Shap-E\n  - local: using-diffusers/diffedit\n    title: DiffEdit\n  - local: in_translation # using-diffusers/inference_with_tcd_lora\n    title: (번역중) Trajectory Consistency Distillation-LoRA\n  - local: using-diffusers/svd\n    title: Stable Video Diffusion\n  - local: in_translation # using-diffusers/marigold_usage\n    title: (번역중) Marigold 컴퓨터 비전\n  title: 특정 파이프라인 예시\n- sections:\n  - local: training/overview\n    title: 개요\n  - local: training/create_dataset\n    title: 학습을 위한 데이터셋 생성하기\n  - local: training/adapt_a_model\n    title: 새로운 태스크에 모델 적용하기\n  - isExpanded: false\n    sections:\n    - local: 
training/unconditional_training\n      title: Unconditional 이미지 생성\n    - local: training/text2image\n      title: Text-to-image\n    - local: in_translation # training/sdxl\n      title: (번역중) Stable Diffusion XL\n    - local: in_translation # training/kandinsky\n      title: (번역중) Kandinsky 2.2\n    - local: in_translation # training/wuerstchen\n      title: (번역중) Wuerstchen\n    - local: training/controlnet\n      title: ControlNet\n    - local: in_translation # training/t2i_adapters\n      title: (번역중) T2I-Adapters\n    - local: training/instructpix2pix\n      title: InstructPix2Pix\n    title: 모델\n  - isExpanded: false\n    sections:\n    - local: training/text_inversion\n      title: Textual Inversion\n    - local: training/dreambooth\n      title: DreamBooth\n    - local: training/lora\n      title: LoRA\n    - local: training/custom_diffusion\n      title: Custom Diffusion\n    - local: in_translation # training/lcm_distill\n      title: (번역중) Latent Consistency Distillation\n    - local: in_translation # training/ddpo\n      title: (번역중) DDPO 강화학습 훈련\n    title: 메서드\n  title: 학습\n- sections:\n  - local: optimization/fp16\n    title: 추론 스피드업\n  - local: in_translation # optimization/memory\n    title: (번역중) 메모리 사용량 줄이기\n  - local: optimization/torch2.0\n    title: PyTorch 2.0\n  - local: optimization/xformers\n    title: xFormers\n  - local: optimization/tome\n    title: Token merging\n  - local: in_translation # optimization/deepcache\n    title: (번역중) DeepCache\n  - local: in_translation # optimization/tgate\n    title: (번역중) TGATE\n  - sections:\n    - local: using-diffusers/stable_diffusion_jax_how_to\n      title: JAX/Flax\n    - local: optimization/onnx\n      title: ONNX\n    - local: optimization/open_vino\n      title: OpenVINO\n    - local: optimization/coreml\n      title: Core ML\n    title: 최적화된 모델 형식\n  - sections:\n    - local: optimization/mps\n      title: Metal Performance Shaders (MPS)\n    - local: optimization/habana\n      title: Intel 
Gaudi\n    title: 최적화된 하드웨어\n  title: 추론 가속화와 메모리 줄이기\n- sections:\n  - local: conceptual/philosophy\n    title: 철학\n  - local: using-diffusers/controlling_generation\n    title: 제어된 생성\n  - local: conceptual/contribution\n    title: 어떻게 기여하나요?\n  - local: conceptual/ethical_guidelines\n    title: Diffusers의 윤리적 가이드라인\n  - local: conceptual/evaluation\n    title: Diffusion Models 평가하기\n  title: 개념 가이드\n- sections:\n  - sections:\n    - sections:\n      - local: api/pipelines/stable_diffusion/stable_diffusion_xl\n        title: Stable Diffusion XL\n      title: Stable Diffusion\n    title: Pipelines\n  title: API"
  },
  {
    "path": "docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable diffusion XL\n\nStable Diffusion XL은 Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, Robin Rombach에 의해 [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952)에서 제안되었습니다.\n\n논문 초록은 다음을 따릅니다:\n\n*text-to-image의 latent diffusion 모델인 SDXL을 소개합니다. 이전 버전의 Stable Diffusion과 비교하면, SDXL은 세 배 더 큰 규모의 UNet 백본을 포함합니다: 모델 파라미터의 증가는 많은 attention 블럭을 사용하고 더 큰 cross-attention context를 SDXL의 두 번째 텍스트 인코더에 사용하기 때문입니다. 다중 종횡비에 다수의 새로운 conditioning 방법을 구성했습니다. 또한 후에 수정하는 image-to-image 기술을 사용함으로써 SDXL에 의해 생성된 시각적 품질을 향상하기 위해 정제된 모델을 소개합니다. SDXL은 이전 버전의 Stable Diffusion보다 성능이 향상되었고, 이러한 black-box 최신 이미지 생성자와 경쟁력있는 결과를 달성했습니다.*\n\n## 팁\n\n- Stable Diffusion XL은 특히 768과 1024사이의 이미지에 잘 작동합니다.\n- Stable Diffusion XL은 아래와 같이 학습된 각 텍스트 인코더에 대해 서로 다른 프롬프트를 전달할 수 있습니다. 
동일한 프롬프트의 다른 부분을 텍스트 인코더에 전달할 수도 있습니다.\n- Stable Diffusion XL 결과 이미지는 아래에 보여지듯이 정제기(refiner)를 사용함으로써 향상될 수 있습니다.\n\n### 이용가능한 체크포인트:\n\n- *Text-to-Image (1024x1024 해상도)*: [`StableDiffusionXLPipeline`]을 사용한 [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n- *Image-to-Image / 정제기(refiner) (1024x1024 해상도)*: [`StableDiffusionXLImg2ImgPipeline`]를 사용한 [stabilityai/stable-diffusion-xl-refiner-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)\n\n## 사용 예시\n\nSDXL을 사용하기 전에 `transformers`, `accelerate`, `safetensors` 와 `invisible_watermark`를 설치하세요.\n다음과 같이 라이브러리를 설치할 수 있습니다:\n\n```sh\npip install transformers\npip install accelerate\npip install safetensors\npip install \"invisible-watermark>=0.2.0\"\n```\n\n### 워터마커\n\nStable Diffusion XL로 이미지를 생성할 때 보이지 않는 워터마크를 추가하는 것을 권장하는데, 이는 다운스트림(downstream) 어플리케이션에서 기계로 합성된 이미지인지를 식별하는데 도움을 줄 수 있습니다. 그렇게 하려면 [invisible_watermark 라이브러리](https://pypi.org/project/invisible-watermark/)를 설치해주세요:\n\n\n```sh\npip install \"invisible-watermark>=0.2.0\"\n```\n\n`invisible-watermark` 라이브러리가 설치되면 워터마커가 **기본적으로** 사용될 것입니다.\n\n생성 또는 안전하게 이미지를 배포하기 위해 다른 규정이 있다면, 다음과 같이 워터마커를 비활성화할 수 있습니다:\n\n```py\npipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)\n```\n\n### Text-to-Image\n\n*text-to-image*를 위해 다음과 같이 SDXL을 사용할 수 있습니다:\n\n```py\nfrom diffusers import StableDiffusionXLPipeline\nimport torch\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipe.to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\nimage = pipe(prompt=prompt).images[0]\n```\n\n### Image-to-image\n\n*image-to-image*를 위해 다음과 같이 SDXL을 사용할 수 있습니다:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLImg2ImgPipeline\nfrom diffusers.utils import load_image\n\npipe = 
StableDiffusionXLImg2ImgPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipe = pipe.to(\"cuda\")\nurl = \"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png\"\n\ninit_image = load_image(url).convert(\"RGB\")\nprompt = \"a photo of an astronaut riding a horse on mars\"\nimage = pipe(prompt, image=init_image).images[0]\n```\n\n### 인페인팅\n\n*inpainting*를 위해 다음과 같이 SDXL을 사용할 수 있습니다:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLInpaintPipeline\nfrom diffusers.utils import load_image\n\npipe = StableDiffusionXLInpaintPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipe.to(\"cuda\")\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = load_image(img_url).convert(\"RGB\")\nmask_image = load_image(mask_url).convert(\"RGB\")\n\nprompt = \"A majestic tiger sitting on a bench\"\nimage = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]\n```\n\n### 이미지 결과물을 정제하기\n\n[base 모델 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)에서, StableDiffusion-XL 또한 고주파 품질을 향상시키는 이미지를 생성하기 위해 낮은 노이즈 단계 이미지를 제거하는데 특화된 [refiner 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 포함하고 있습니다. 이 refiner 체크포인트는 이미지 품질을 향상시키기 위해 base 체크포인트를 실행한 후 \"두 번째 단계\" 파이프라인에 사용될 수 있습니다.\n\nrefiner를 사용할 때, 쉽게 사용할 수 있습니다\n- 1.) base 모델과 refiner을 사용하는데, 이는 *Denoisers의 앙상블*을 위한 첫 번째 제안된 [eDiff-I](https://research.nvidia.com/labs/dir/eDiff-I/)를 사용하거나\n- 2.) 
base 모델을 거친 후 [SDEdit](https://huggingface.co/papers/2108.01073) 방법으로 단순하게 refiner를 실행시킬 수 있습니다.\n\n**참고**: SD-XL base와 refiner를 앙상블로 사용하는 아이디어는 커뮤니티 기여자들이 처음으로 제안했으며, 다음 분들은 이를 `diffusers`에 구현하는 데도 도움을 주셨습니다.\n- [SytanSD](https://github.com/SytanSD)\n- [bghira](https://github.com/bghira)\n- [Birch-san](https://github.com/Birch-san)\n- [AmericanPresidentJimmyCarter](https://github.com/AmericanPresidentJimmyCarter)\n\n#### 1.) Denoisers의 앙상블\n\nbase와 refiner 모델을 denoiser의 앙상블로 사용할 때, base 모델은 높은 노이즈 diffusion 단계를 위한 전문가의 역할을 해야하고, refiner는 낮은 노이즈 diffusion 단계를 위한 전문가의 역할을 해야 합니다.\n\n2.)에 비해 1.)의 장점은 전체적으로 denoising 단계가 덜 필요하므로 속도가 훨씬 더 빨라진다는 것입니다. 단점은 base 모델의 결과를 검사할 수 없다는 것입니다. base 모델의 출력에는 여전히 노이즈가 많이 남아 있기 때문입니다.\n\nbase 모델과 refiner를 denoiser의 앙상블로 사용하기 위해 각각 고노이즈(high-noise) (*즉* base 모델)와 저노이즈 (*즉* refiner 모델)의 노이즈를 제거하는 단계를 거쳐야하는 타임스텝의 기간을 정의해야 합니다.\nbase 모델의 [`denoising_end`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.denoising_end)와 refiner 모델의 [`denoising_start`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__.denoising_start)를 사용해 간격을 정합니다.\n\n`denoising_end`와 `denoising_start` 모두 0과 1사이의 실수 값으로 전달되어야 합니다.\n전달되면 노이즈 제거의 끝과 시작은 모델 스케줄에 의해 정의된 이산적(discrete) 시간 간격의 비율로 정의됩니다.\n노이즈 제거 단계의 수는 모델이 학습된 불연속적인 시간 간격과 선언된 fractional cutoff에 의해 결정되므로 `strength` 또한 선언된 경우 이 값이 `strength`를 재정의합니다.\n\n예시를 들어보겠습니다.\n우선, 두 개의 파이프라인을 가져옵니다. 
텍스트 인코더와 variational autoencoder는 동일하므로 refiner를 위해 다시 불러오지 않아도 됩니다.\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nbase = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\nbase.to(\"cuda\")\n\nrefiner = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\",\n    text_encoder_2=base.text_encoder_2,\n    vae=base.vae,\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    variant=\"fp16\",\n)\nrefiner.to(\"cuda\")\n```\n\n이제 추론 단계의 수와 고노이즈에서 노이즈를 제거하는 단계(*즉* base 모델)를 거쳐 실행되는 지점을 정의합니다.\n\n```py\nn_steps = 40\nhigh_noise_frac = 0.8\n```\n\nStable Diffusion XL base 모델은 타임스텝 0-999에 학습되며 Stable Diffusion XL refiner는 포괄적인 낮은 노이즈 타임스텝인 0-199에 base 모델로 부터 파인튜닝되어, 첫 800 타임스텝 (높은 노이즈)에 base 모델을 사용하고 마지막 200 타임스텝 (낮은 노이즈)에서 refiner가 사용됩니다. 따라서, `high_noise_frac`는 0.8로 설정하고, 모든 200-999 스텝(노이즈 제거 타임스텝의 첫 80%)은 base 모델에 의해 수행되며 0-199 스텝(노이즈 제거 타임스텝의 마지막 20%)은 refiner 모델에 의해 수행됩니다.\n\n기억하세요, 노이즈 제거 절차는 **높은 값**(높은 노이즈) 타임스텝에서 시작되고, **낮은 값** (낮은 노이즈) 타임스텝에서 끝납니다.\n\n이제 두 파이프라인을 실행해봅시다. `denoising_end`과 `denoising_start`를 같은 값으로 설정하고 `num_inference_steps`는 상수로 유지합니다. 
또한 base 모델의 출력은 잠재 공간에 있어야 한다는 점을 기억하세요:\n\n```py\nprompt = \"A majestic lion jumping from a big stone at night\"\n\nimage = base(\n    prompt=prompt,\n    num_inference_steps=n_steps,\n    denoising_end=high_noise_frac,\n    output_type=\"latent\",\n).images\nimage = refiner(\n    prompt=prompt,\n    num_inference_steps=n_steps,\n    denoising_start=high_noise_frac,\n    image=image,\n).images[0]\n```\n\n이미지를 살펴보겠습니다.\n\n| 원래의 이미지 | Denoiser들의 앙상블 |\n|---|---|\n| ![lion_base](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_base.png) | ![lion_ref](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lion_refined.png) |\n\nbase 모델만 동일한 40 단계로 실행했다면, 위의 원래 이미지처럼 디테일(예: 사자의 눈과 코)이 떨어졌을 것입니다.\n\n> [!TIP]\n> 앙상블 방식은 사용 가능한 모든 스케줄러에서 잘 작동합니다!\n\n#### 2.) 노이즈가 완전히 제거된 기본 이미지에서 이미지 출력을 정제하기\n\n일반적인 [`StableDiffusionImg2ImgPipeline`] 방식에서, 기본 모델에서 생성된 완전히 노이즈가 제거된 이미지는 [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 사용해 더 향상시킬 수 있습니다.\n\n이를 위해, 보통의 \"base\" text-to-image 파이프라인을 수행 후에 image-to-image 파이프라인으로써 refiner를 실행시킬 수 있습니다. 
base 모델의 출력을 잠재 공간에 남겨둘 수 있습니다.\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipe.to(\"cuda\")\n\nrefiner = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\",\n    text_encoder_2=pipe.text_encoder_2,\n    vae=pipe.vae,\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    variant=\"fp16\",\n)\nrefiner.to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\n# base 모델의 출력을 잠재 공간에 남겨둔 채로 refiner에 전달합니다\nimage = pipe(prompt=prompt, output_type=\"latent\").images\nimage = refiner(prompt=prompt, image=image).images[0]\n```\n\n| 원래의 이미지 | 정제된 이미지 |\n|---|---|\n| ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/init_image.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_image.png) |\n\n> [!TIP]\n> refiner는 또한 인페인팅 설정에 잘 사용될 수 있습니다. 
아래에 보여지듯이 [`StableDiffusionXLInpaintPipeline`] 클래스를 사용해서 만들어보세요.\n\nDenoiser 앙상블 설정에서 인페인팅에 refiner를 사용하려면 다음을 수행하면 됩니다:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLInpaintPipeline\nfrom diffusers.utils import load_image\n\npipe = StableDiffusionXLInpaintPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipe.to(\"cuda\")\n\nrefiner = StableDiffusionXLInpaintPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\",\n    text_encoder_2=pipe.text_encoder_2,\n    vae=pipe.vae,\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    variant=\"fp16\",\n)\nrefiner.to(\"cuda\")\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = load_image(img_url).convert(\"RGB\")\nmask_image = load_image(mask_url).convert(\"RGB\")\n\nprompt = \"A majestic tiger sitting on a bench\"\nnum_inference_steps = 75\nhigh_noise_frac = 0.7\n\n# base 모델은 높은 노이즈 구간을 담당하므로 denoising_end를 사용합니다\nimage = pipe(\n    prompt=prompt,\n    image=init_image,\n    mask_image=mask_image,\n    num_inference_steps=num_inference_steps,\n    denoising_end=high_noise_frac,\n    output_type=\"latent\",\n).images\nimage = refiner(\n    prompt=prompt,\n    image=image,\n    mask_image=mask_image,\n    num_inference_steps=num_inference_steps,\n    denoising_start=high_noise_frac,\n).images[0]\n```\n\n일반적인 SDE 설정에서 인페인팅에 refiner를 사용하기 위해, `denoising_end`와 `denoising_start`를 제거하고 refiner의 추론 단계의 수를 적게 선택하세요.\n\n### 단독 체크포인트 파일 / 원래의 파일 형식으로 불러오기\n\n[`~diffusers.loaders.FromSingleFileMixin.from_single_file`]를 사용함으로써 원래의 파일 형식을 `diffusers` 형식으로 불러올 수 있습니다:\n\n```py\nfrom diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline\nimport torch\n\npipe = 
StableDiffusionXLPipeline.from_single_file(\n    \"./sd_xl_base_1.0.safetensors\", torch_dtype=torch.float16\n)\npipe.to(\"cuda\")\n\nrefiner = StableDiffusionXLImg2ImgPipeline.from_single_file(\n    \"./sd_xl_refiner_1.0.safetensors\", torch_dtype=torch.float16\n)\nrefiner.to(\"cuda\")\n```\n\n### 모델 offloading을 통해 메모리 최적화하기\n\nout-of-memory 에러가 난다면, [`StableDiffusionXLPipeline.enable_model_cpu_offload`]을 사용하는 것을 권장합니다.\n\n```diff\n- pipe.to(\"cuda\")\n+ pipe.enable_model_cpu_offload()\n```\n\n그리고\n\n```diff\n- refiner.to(\"cuda\")\n+ refiner.enable_model_cpu_offload()\n```\n\n### `torch.compile`로 추론 속도를 올리기\n\n`torch.compile`를 사용함으로써 추론 속도를 올릴 수 있습니다. 약 20%의 속도 향상을 기대할 수 있습니다.\n\n```diff\n+ pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n+ refiner.unet = torch.compile(refiner.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n### `torch < 2.0`일 때 실행하기\n\n**참고** Stable Diffusion XL을 `torch`가 2.0 버전 미만에서 실행시키고 싶을 때, xformers 어텐션을 사용해주세요:\n\n```sh\npip install xformers\n```\n\n```diff\n+pipe.enable_xformers_memory_efficient_attention()\n+refiner.enable_xformers_memory_efficient_attention()\n```\n\n## StableDiffusionXLPipeline\n\n[[autodoc]] StableDiffusionXLPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLImg2ImgPipeline\n\n[[autodoc]] StableDiffusionXLImg2ImgPipeline\n\t- all\n\t- __call__\n\n## StableDiffusionXLInpaintPipeline\n\n[[autodoc]] StableDiffusionXLInpaintPipeline\n\t- all\n\t- __call__\n\n### 각 텍스트 인코더에 다른 프롬프트를 전달하기\n\nStable Diffusion XL는 두 개의 텍스트 인코더에 학습되었습니다. 기본 동작은 각 텍스트 인코더에 동일한 프롬프트를 전달하는 것입니다. 그러나 [일부 사용자](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201)가 품질을 향상시킬 수 있다고 지적한 것처럼 텍스트 인코더마다 다른 프롬프트를 전달할 수 있습니다. 그렇게 하려면, `prompt`와 `negative_prompt`에 더해 `prompt_2`와 `negative_prompt_2`를 전달하면 됩니다. 그렇게 함으로써, 원래의 프롬프트들(`prompt`)과 부정 프롬프트들(`negative_prompt`)를 `텍스트 인코더`에 전달할 것입니다.(공식 SDXL 0.9/1.0의 [OpenAI CLIP-ViT/L-14](https://huggingface.co/openai/clip-vit-large-patch14)에서 볼 수 있습니다.) 
그리고 `prompt_2`와 `negative_prompt_2`는 `text_encoder_2`에 전달됩니다.(공식 SDXL 0.9/1.0의 [OpenCLIP-ViT/bigG-14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)에서 볼 수 있습니다.)\n\n```py\nfrom diffusers import StableDiffusionXLPipeline\nimport torch\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-0.9\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\npipe.to(\"cuda\")\n\n# OAI CLIP-ViT/L-14에 prompt가 전달됩니다\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n# OpenCLIP-ViT/bigG-14에 prompt_2가 전달됩니다\nprompt_2 = \"monet painting\"\nimage = pipe(prompt=prompt, prompt_2=prompt_2).images[0]\n```"
  },
  {
    "path": "docs/source/ko/conceptual/contribution.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Diffusers에 기여하는 방법 🧨 [[how-to-contribute-to-diffusers-]]\n\n오픈 소스 커뮤니티에서의 기여를 환영합니다! 누구나 참여할 수 있으며, 코드뿐만 아니라 질문에 답변하거나 문서를 개선하는 등 모든 유형의 참여가 가치 있고 감사히 여겨집니다. 질문에 답변하고 다른 사람들을 도와주며 소통하고 문서를 개선하는 것은 모두 커뮤니티에게 큰 도움이 됩니다. 따라서 관심이 있다면 두려워하지 말고 참여해보세요!\n\n누구나 우리의 공개 Discord 채널에서 👋 인사하며 시작할 수 있도록 장려합니다. 우리는 diffusion 모델의 최신 동향을 논의하고 질문을 하며 개인 프로젝트를 자랑하고 기여에 대해 서로 도와주거나 그냥 어울리기 위해 모이는 곳입니다☕. <a href=\"https://Discord.gg/G7tWnz98XR\"><img alt=\"Join us on Discord\" src=\"https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white\"></a>\n\n어떤 방식으로든 기여하려는 경우, 우리는 개방적이고 환영하며 친근한 커뮤니티의 일부가 되기 위해 노력하고 있습니다. 우리의 [행동 강령](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md)을 읽고 상호 작용 중에 이를 존중하도록 주의해주시기 바랍니다. 또한 프로젝트를 안내하는 [윤리 지침](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines)에 익숙해지고 동일한 투명성과 책임성의 원칙을 준수해주시기를 부탁드립니다.\n\n우리는 커뮤니티로부터의 피드백을 매우 중요하게 생각하므로, 라이브러리를 개선하는 데 도움이 될 가치 있는 피드백이 있다고 생각되면 망설이지 말고 의견을 제시해주세요 - 모든 메시지, 댓글, 이슈, Pull Request(PR)는 읽히고 고려됩니다.\n\n## 개요 [[overview]]\n\n이슈에 있는 질문에 답변하는 것에서부터 코어 라이브러리에 새로운 diffusion 모델을 추가하는 것까지 다양한 방법으로 기여를 할 수 있습니다.\n\n이어지는 부분에서 우리는 다양한 방법의 기여에 대한 개요를 난이도에 따라 오름차순으로 정리하였습니다. 모든 기여는 커뮤니티에게 가치가 있습니다.\n\n1. 
[Diffusers 토론 포럼](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers)이나 [Discord](https://discord.gg/G7tWnz98XR)에서 질문에 대답하거나 질문을 할 수 있습니다.\n2. [GitHub Issues 탭](https://github.com/huggingface/diffusers/issues/new/choose)에서 새로운 이슈를 열 수 있습니다.\n3. [GitHub Issues 탭](https://github.com/huggingface/diffusers/issues)에서 이슈에 대답할 수 있습니다.\n4. \"Good first issue\" 라벨이 지정된 간단한 이슈를 수정할 수 있습니다. [여기](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)를 참조하세요.\n5. [문서](https://github.com/huggingface/diffusers/tree/main/docs/source)에 기여할 수 있습니다.\n6. [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples)에 기여할 수 있습니다.\n7. [예제](https://github.com/huggingface/diffusers/tree/main/examples)에 기여할 수 있습니다.\n8. \"Good second issue\" 라벨이 지정된 어려운 이슈를 수정할 수 있습니다. [여기](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22)를 참조하세요.\n9. 새로운 파이프라인, 모델 또는 스케줄러를 추가할 수 있습니다. [\"새로운 파이프라인/모델\"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) 및 [\"새로운 스케줄러\"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) 이슈를 참조하세요. 이 기여에 대해서는 [디자인 철학](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md)을 확인해주세요.\n\n앞서 말한 대로, **모든 기여는 커뮤니티에게 가치가 있습니다**. 이어지는 부분에서 각 기여에 대해 조금 더 자세히 설명하겠습니다.\n\n4부터 9까지의 모든 기여에는 Pull Request을 열어야 합니다. [Pull Request 열기](#how-to-open-a-pr)에서 자세히 설명되어 있습니다.\n\n### 1. Diffusers 토론 포럼이나 Diffusers Discord에서 질문하고 답변하기 [[1-asking-and-answering-questions-on-the-diffusers-discussion-forum-or-on-the-diffusers-discord]]\n\nDiffusers 라이브러리와 관련된 모든 질문이나 의견은 [토론 포럼](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)이나 [Discord](https://discord.gg/G7tWnz98XR)에서 할 수 있습니다. 
이러한 질문과 의견에는 다음과 같은 내용이 포함됩니다(하지만 이에 국한되지는 않습니다):\n- 지식을 공유하기 위해서 훈련 또는 추론 실험에 대한 결과 보고\n- 개인 프로젝트 소개\n- 비공식 훈련 예제에 대한 질문\n- 프로젝트 제안\n- 일반적인 피드백\n- 논문 요약\n- Diffusers 라이브러리를 기반으로 하는 개인 프로젝트에 대한 도움 요청\n- 일반적인 질문\n- Diffusion 모델에 대한 윤리적 질문\n- ...\n\n포럼이나 Discord에서 질문을 하면 커뮤니티가 지식을 공개적으로 공유하도록 장려되며, 향후 동일한 질문을 가진 초보자에게도 도움이 될 수 있습니다. 따라서 궁금한 질문은 언제든지 하시기 바랍니다.\n또한, 이러한 질문에 답변하는 것은 커뮤니티에게 매우 큰 도움이 됩니다. 왜냐하면 이렇게 하면 모두가 학습할 수 있는 공개적인 지식을 문서화하기 때문입니다.\n\n**주의**하십시오. 질문이나 답변에 투자하는 노력이 많을수록 공개적으로 문서화된 지식의 품질이 높아집니다. 마찬가지로, 잘 정의되고 잘 답변된 질문은 모두에게 접근 가능한 고품질 지식 데이터베이스를 만들어줍니다. 반면에 잘못된 질문이나 답변은 공개 지식 데이터베이스의 전반적인 품질을 낮출 수 있습니다.\n간단히 말해서, 고품질의 질문이나 답변은 *명확하고 간결하며 관련성이 있으며 이해하기 쉽고 접근 가능하며 잘 형식화되어 있어야* 합니다. 자세한 내용은 [좋은 이슈 작성 방법](#how-to-write-a-good-issue) 섹션을 참조하십시오.\n\n**채널에 대한 참고사항**:\n[*포럼*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)은 구글과 같은 검색 엔진에서 더 잘 색인화됩니다. 게시물은 인기에 따라 순위가 매겨지며, 시간순으로 정렬되지 않습니다. 따라서 이전에 게시한 질문과 답변을 쉽게 찾을 수 있습니다.\n또한, 포럼에 게시된 질문과 답변은 쉽게 링크할 수 있습니다.\n반면 *Discord*는 채팅 형식으로 되어 있어 빠른 대화를 유도합니다.\n질문에 대한 답변을 빠르게 받을 수는 있겠지만, 시간이 지나면 질문이 더 이상 보이지 않습니다. 또한, Discord에서 이전에 게시된 정보를 찾는 것은 훨씬 어렵습니다. 따라서 포럼을 사용하여 고품질의 질문과 답변을 하여 커뮤니티를 위한 오래 지속되는 지식을 만들기를 권장합니다. Discord에서의 토론이 매우 흥미로운 답변과 결론을 이끌어내는 경우, 해당 정보를 포럼에 게시하여 향후 독자들에게 더 쉽게 액세스할 수 있도록 권장합니다.\n\n### 2. GitHub 이슈 탭에서 새로운 이슈 열기 [[2-opening-new-issues-on-the-github-issues-tab]]\n\n🧨 Diffusers 라이브러리는 사용자들이 마주치는 문제를 알려주는 덕분에 견고하고 신뢰할 수 있습니다. 따라서 이슈를 보고해주셔서 감사합니다.\n\n기억해주세요, GitHub 이슈는 Diffusers 라이브러리와 직접적으로 관련된 기술적인 질문, 버그 리포트, 기능 요청 또는 라이브러리 디자인에 대한 피드백에 사용됩니다.\n\n간단히 말해서, Diffusers 라이브러리의 **코드와 관련되지 않은** 모든 것(문서 포함)은 GitHub가 아닌 [포럼](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)이나 [Discord](https://discord.gg/G7tWnz98XR)에서 질문해야 합니다.\n\n**새로운 이슈를 열 때 다음 가이드라인을 고려해주세요**:\n- 이미 같은 이슈가 있는지 검색했는지 확인해주세요(GitHub의 이슈 탭에서 검색 기능을 사용하세요).\n- 다른(관련된) 이슈에 새로운 이슈를 보고하지 말아주세요. 
다른 이슈와 관련이 높다면, 새로운 이슈를 열고 관련 이슈에 링크를 걸어주세요.\n- 이슈를 영어로 작성해주세요. 영어에 익숙하지 않다면, [DeepL](https://www.deepl.com/translator)과 같은 뛰어난 무료 온라인 번역 서비스를 사용하여 모국어에서 영어로 번역해주세요.\n- 이슈가 최신 Diffusers 버전으로 업데이트하면 해결될 수 있는지 확인해주세요. 이슈를 게시하기 전에 `python -c \"import diffusers; print(diffusers.__version__)\"` 명령을 실행하여 현재 사용 중인 Diffusers 버전이 최신 버전과 일치하거나 더 높은지 확인해주세요.\n- 새로운 이슈를 열 때 투자하는 노력이 많을수록 답변의 품질이 높아지고 Diffusers 이슈 전체의 품질도 향상됩니다.\n\n#### 2.1 재현 가능한 최소한의 버그 리포트 [[21-reproducible-minimal-bug-reports]]\n\n\n버그 리포트는 항상 재현 가능한 코드 조각을 포함하고 가능한 한 최소한이어야 하며 간결해야 합니다.\n자세히 말하면:\n- 버그를 가능한 한 좁혀야 합니다. **전체 코드 파일을 그냥 던지지 마세요**.\n- 코드의 서식을 지정해야 합니다.\n- Diffusers가 의존하는 외부 라이브러리를 제외한 다른 외부 라이브러리는 포함하지 마십시오.\n- **항상** 사용자 환경에 대한 모든 필요한 정보를 제공하세요. 이를 위해 쉘에서 `diffusers-cli env`를 실행하고 표시된 정보를 이슈에 복사하여 붙여넣을 수 있습니다.\n- 이슈를 설명해야 합니다. 독자가 문제가 무엇인지, 왜 문제가 되는지 모른다면 이슈를 해결할 수 없습니다. \n- **항상** 독자가 가능한 한 적은 노력으로 문제를 재현할 수 있어야 합니다. 코드 조각이 라이브러리가 없거나 정의되지 않은 변수 때문에 실행되지 않는 경우 독자가 도움을 줄 수 없습니다. 재현 가능한 코드 조각이 가능한 한 최소화되고 간단한 Python 셸에 복사하여 붙여넣을 수 있도록 해야 합니다.\n- 문제를 재현하기 위해 모델과/또는 데이터셋이 필요한 경우 독자가 해당 모델이나 데이터셋에 접근할 수 있도록 해야 합니다. 모델이나 데이터셋을 [Hub](https://huggingface.co)에 업로드하여 쉽게 다운로드할 수 있도록 할 수 있습니다. 문제 재현을 가능한 한 쉽게하기 위해 모델과 데이터셋을 가능한 한 작게 유지하려고 노력하세요.\n\n자세한 내용은 [좋은 이슈 작성 방법](#how-to-write-a-good-issue) 섹션을 참조하세요.\n\n버그 리포트를 열려면 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml)를 클릭하세요.\n\n\n#### 2.2. 기능 요청 [[22-feature-requests]]\n\n세계적인 기능 요청은 다음 사항을 다룹니다:\n\n1. 먼저 동기부여:\n* 라이브러리와 관련된 문제/불만이 있나요? 그렇다면 왜 그런지 설명해주세요. 문제를 보여주는 코드 조각을 제공하는 것이 가장 좋습니다.\n* 프로젝트에 필요한 기능인가요? 우리는 그에 대해 듣고 싶습니다!\n* 커뮤니티에 도움이 될 수 있는 것을 작업했고 그것에 대해 생각하고 있는가요? 멋지네요! 어떤 문제를 해결했는지 알려주세요.\n2. 기능을 *상세히 설명하는* 문단을 작성해주세요;\n3. 향후 사용을 보여주는 **코드 조각**을 제공해주세요;\n4. 논문과 관련된 내용인 경우 링크를 첨부해주세요;\n5. 
도움이 될 수 있다고 생각되는 추가 정보(그림, 스크린샷 등)를 첨부해주세요.\n\n기능 요청은 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=)에서 열 수 있습니다.\n\n#### 2.3 피드백 [[23-feedback]]\n\n라이브러리 디자인과 그것이 왜 좋은지 또는 나쁜지에 대한 이유에 대한 피드백은 핵심 메인테이너가 사용자 친화적인 라이브러리를 만드는 데 엄청난 도움이 됩니다. 현재 디자인 철학을 이해하려면 [여기](https://huggingface.co/docs/diffusers/conceptual/philosophy)를 참조해 주세요. 특정 디자인 선택이 현재 디자인 철학과 맞지 않는다고 생각되면, 그 이유와 어떻게 변경되어야 하는지 설명해 주세요. 반대로 특정 디자인 선택이 디자인 철학을 너무 따르기 때문에 사용 사례를 제한한다고 생각되면, 그 이유와 어떻게 변경되어야 하는지 설명해 주세요. 특정 디자인 선택이 매우 유용하다고 생각되면, 향후 디자인 결정에 큰 도움이 되므로 이에 대한 의견을 남겨 주세요.\n\n피드백에 관한 이슈는 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=)에서 열 수 있습니다.\n\n#### 2.4 기술적인 질문 [[24-technical-questions]]\n\n기술적인 질문은 주로 라이브러리의 특정 코드가 왜 특정 방식으로 작성되었는지 또는 코드의 특정 부분이 무엇을 하는지에 대한 질문입니다. 질문하신 코드 부분에 대한 링크를 제공하고 해당 코드 부분이 이해하기 어려운 이유에 대한 자세한 설명을 해주시기 바랍니다.\n\n기술적인 질문에 관한 이슈를 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml)에서 열 수 있습니다.\n\n#### 2.5 새로운 모델, 스케줄러 또는 파이프라인 추가 제안 [[25-proposal-to-add-a-new-model-scheduler-or-pipeline]]\n\n만약 diffusion 모델 커뮤니티에서 Diffusers 라이브러리에 추가하고 싶은 새로운 모델, 파이프라인 또는 스케줄러가 있다면, 다음 정보를 제공해주세요:\n\n* Diffusion 파이프라인, 모델 또는 스케줄러에 대한 간단한 설명과 논문 또는 공개된 버전의 링크\n* 해당 모델의 오픈 소스 구현에 대한 링크\n* 모델 가중치가 있는 경우, 가중치의 링크\n\n직접 모델에 기여하고 싶다면, 가장 잘 안내해드릴 수 있습니다. 또한, 가능하다면 구성 요소(모델, 스케줄러, 파이프라인 등)의 원저자를 GitHub 핸들로 태그하는 것을 잊지 마세요.\n\n모델/파이프라인/스케줄러에 대한 요청을 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml)에서 열 수 있습니다.\n\n### 3. GitHub 이슈 탭에서 문제에 대한 답변하기 [[3-answering-issues-on-the-github-issues-tab]]\n\nGitHub에서 이슈에 대한 답변을 하기 위해서는 Diffusers에 대한 기술적인 지식이 필요할 수 있지만, 정확한 답변이 아니더라도 모두가 시도해기를 권장합니다. 이슈에 대한 고품질 답변을 제공하기 위한 몇 가지 팁:\n- 가능한 한 간결하고 최소한으로 유지합니다.\n- 주제에 집중합니다. 
이슈에 대한 답변은 해당 이슈에 관련된 내용에만 집중해야 합니다.\n- 자신의 주장을 뒷받침하거나 보강하는 코드, 논문 또는 기타 출처는 링크를 제공하세요.\n- 코드로 답변합니다. 간단한 코드 조각이 이슈에 대한 답변이거나 이슈를 해결하는 방법을 보여준다면, 완전히 재현 가능한 코드 조각을 제공해주세요.\n\n또한, 많은 이슈들은 단순히 주제와 무관하거나 다른 이슈의 중복이거나 관련이 없는 경우가 많습니다. 이러한 이슈들에 대한 답변을 제공하고, 이슈 작성자에게 더 정확한 정보를 요청하거나, 중복된 이슈에 대한 링크를 제공하거나, [포럼](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)이나 [Discord](https://discord.gg/G7tWnz98XR)로 리디렉션하는 것은 메인테이너에게 큰 도움이 됩니다.\n\n이슈가 올바른 버그 보고서이고 소스 코드에서 수정이 필요하다고 확인한 경우, 다음 섹션을 살펴보세요.\n\n다음 모든 기여에 대해서는 PR을 열어야 합니다. [Pull Request 열기](#how-to-open-a-pr) 섹션에서 자세히 설명되어 있습니다.\n\n### 4. \"Good first issue\" 고치기 [[4-fixing-a-good-first-issue]]\n\n*Good first issues*는 [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) 라벨로 표시됩니다. 일반적으로, 이슈는 이미 잠재적인 해결책이 어떻게 보이는지 설명하고 있어서 수정하기 쉽습니다.\n만약 이슈가 아직 닫히지 않았고 이 문제를 해결해보고 싶다면, \"이 이슈를 해결해보고 싶습니다.\"라는 메시지를 남기면 됩니다. 일반적으로 세 가지 시나리오가 있습니다:\n- a.) 이슈 설명에 이미 수정 사항을 제안하는 경우, 해결책이 이해되고 합리적으로 보인다면, PR 또는 드래프트 PR을 열어서 수정할 수 있습니다.\n- b.) 이슈 설명에 수정 사항이 제안되어 있지 않은 경우, 제안한 수정 사항이 가능할지 물어볼 수 있고, Diffusers 팀의 누군가가 곧 답변해줄 것입니다. 만약 어떻게 수정할지 좋은 아이디어가 있다면, 직접 PR을 열어도 됩니다.\n- c.) 이미 이 문제를 해결하기 위해 열린 PR이 있지만, 이슈가 아직 닫히지 않았습니다. PR이 더 이상 진행되지 않았다면, 새로운 PR을 열고 이전 PR에 링크를 걸면 됩니다. 원래 기여자가 갑자기 시간을 내지 못하게 되어 PR이 중단되는 경우가 종종 있습니다. 이는 오픈 소스에서 자주 발생하는 일이며 매우 정상적인 상황입니다. 이 경우, 기존 PR의 내용을 활용하여 새로 시도해주시면 커뮤니티에 큰 도움이 됩니다. 이미 PR이 있고 활성화되어 있다면, 제안을 해주거나 PR을 검토하거나 PR에 기여할 수 있는지 물어보는 등 작성자를 도와줄 수 있습니다.\n\n\n### 5. 문서에 기여하기 [[5-contribute-to-the-documentation]]\n\n좋은 라이브러리는 항상 좋은 문서를 갖고 있습니다! 
공식 문서는 라이브러리를 처음 사용하는 사용자들에게 첫 번째 접점 중 하나이며, 따라서 문서에 기여하는 것은 매우 가치 있는 기여입니다.\n\n라이브러리에 기여하는 방법은 다양합니다:\n\n- 맞춤법이나 문법 오류를 수정합니다.\n- 공식 문서가 이상하게 표시되거나 링크가 깨진 경우, 올바르게 수정하는 데 시간을 내주시면 매우 기쁠 것입니다.\n- 문서의 입력 또는 출력 텐서의 모양이나 차원을 수정합니다.\n- 이해하기 어렵거나 잘못된 문서를 명확하게 합니다.\n- 오래된 코드 예제를 업데이트합니다.\n- 문서를 다른 언어로 번역합니다.\n\n[공식 Diffusers 문서 페이지](https://huggingface.co/docs/diffusers/index)에 표시된 모든 내용은 공식 문서의 일부이며, 해당 [문서 소스](https://github.com/huggingface/diffusers/tree/main/docs/source)에서 수정할 수 있습니다.\n\n문서에 대한 변경 사항을 로컬에서 확인하는 방법은 [이 페이지](https://github.com/huggingface/diffusers/tree/main/docs)를 참조해주세요.\n\n\n### 6. 커뮤니티 파이프라인에 기여하기 [[6-contribute-a-community-pipeline]]\n\n> [!TIP]\n> 커뮤니티 파이프라인에 대해 자세히 알아보려면 [커뮤니티 파이프라인](../using-diffusers/custom_pipeline_overview#community-pipelines) 가이드를 읽어보세요. 커뮤니티 파이프라인이 왜 필요한지 궁금하다면 GitHub 이슈 [#841](https://github.com/huggingface/diffusers/issues/841)를 확인해보세요 (기본적으로, 우리는 diffusion 모델이 추론에 사용될 수 있는 모든 방법을 유지할 수 없지만 커뮤니티가 이를 구축하는 것을 방해하고 싶지 않습니다).\n\n커뮤니티 파이프라인에 기여하는 것은 창의성과 작업을 커뮤니티와 공유하는 좋은 방법입니다. [`DiffusionPipeline`]을 기반으로 빌드하여 `custom_pipeline` 매개변수를 설정함으로써 누구나 로드하고 사용할 수 있도록 할 수 있습니다. 이 섹션에서는 UNet이 단일 순방향 패스만 수행하고 스케줄러를 한 번 호출하는 간단한 파이프라인 (단계별 파이프라인)을 만드는 방법을 안내합니다.\n\n1. 커뮤니티 파이프라인을 위한 one_step_unet.py 파일을 생성하세요. 이 파일은 사용자에 의해 설치되는 패키지를 포함할 수 있지만, [`DiffusionPipeline`]에서 모델 가중치와 스케줄러 구성을 로드하기 위해 하나의 파이프라인 클래스만 있어야 합니다. `__init__` 함수에 UNet과 스케줄러를 추가하세요.\n\n    또한 [`~DiffusionPipeline.save_pretrained`]를 사용하여 파이프라인과 그 구성 요소를 저장할 수 있도록 `register_modules` 함수를 추가해야 합니다.\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nclass UnetSchedulerOneForwardPipeline(DiffusionPipeline):\n    def __init__(self, unet, scheduler):\n        super().__init__()\n\n        self.register_modules(unet=unet, scheduler=scheduler)\n```\n\n1. forward 패스에서 (`__call__`로 정의하는 것을 추천합니다), 원하는 어떤 기능이든 추가할 수 있습니다. 
\"one-step\" 파이프라인의 경우, 무작위 이미지를 생성하고 `timestep=1`로 설정하여 UNet과 스케줄러를 한 번 호출합니다.\n\n```py\n  from diffusers import DiffusionPipeline\n  import torch\n\n  class UnetSchedulerOneForwardPipeline(DiffusionPipeline):\n      def __init__(self, unet, scheduler):\n          super().__init__()\n\n          self.register_modules(unet=unet, scheduler=scheduler)\n\n      def __call__(self):\n          image = torch.randn(\n              (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),\n          )\n          timestep = 1\n\n          model_output = self.unet(image, timestep).sample\n          scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample\n\n          return scheduler_output\n```\n\n이제 UNet과 스케줄러를 전달하여 파이프라인을 실행하거나, 파이프라인 구조가 동일한 경우 사전 학습된 가중치를 로드할 수 있습니다.\n\n```py\nfrom diffusers import DDPMScheduler, UNet2DModel\n\nscheduler = DDPMScheduler()\nunet = UNet2DModel()\n\npipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)\noutput = pipeline()\n# load pretrained weights\npipeline = UnetSchedulerOneForwardPipeline.from_pretrained(\"google/ddpm-cifar10-32\", use_safetensors=True)\noutput = pipeline()\n```\n\n파이프라인을 GitHub 커뮤니티 파이프라인 또는 Hub 커뮤니티 파이프라인으로 공유할 수 있습니다.\n\n<hfoptions id=\"pipeline type\">\n<hfoption id=\"GitHub pipeline\">\n\nGitHub 파이프라인을 공유하려면 Diffusers [저장소](https://github.com/huggingface/diffusers)에서 Pull Request를 열고 one_step_unet.py 파일을 [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) 하위 폴더에 추가하세요.\n\n</hfoption>\n<hfoption id=\"Hub pipeline\">\n\nHub 파이프라인을 공유하려면, 허브에 모델 저장소를 생성하고 one_step_unet.py 파일을 업로드하세요.\n\n</hfoption>\n</hfoptions>\n\n### 7. 
훈련 예제에 기여하기 [[7-contribute-to-training-examples]]\n\nDiffusers 예제는 [examples](https://github.com/huggingface/diffusers/tree/main/examples) 폴더에 있는 훈련 스크립트의 모음입니다.\n\n두 가지 유형의 훈련 예제를 지원합니다:\n\n- 공식 훈련 예제\n- 연구용 훈련 예제\n\n연구용 훈련 예제는 [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects)에 위치하며, 공식 훈련 예제는 `research_projects` 및 `community` 폴더를 제외한 [examples](https://github.com/huggingface/diffusers/tree/main/examples)의 모든 폴더를 포함합니다.\n공식 훈련 예제는 Diffusers의 핵심 메인테이너가 유지 관리하며, 연구용 훈련 예제는 커뮤니티가 유지 관리합니다.\n이는 공식 파이프라인 vs 커뮤니티 파이프라인에 대한 [6. 커뮤니티 파이프라인 기여하기](#6-contribute-a-community-pipeline)에서 제시한 이유와 동일합니다: 핵심 메인테이너가 diffusion 모델의 모든 가능한 훈련 방법을 유지 관리하는 것은 현실적으로 불가능합니다.\nDiffusers 핵심 메인테이너와 커뮤니티가 특정 훈련 패러다임을 너무 실험적이거나 충분히 대중적이지 않다고 판단한다면, 해당 훈련 코드는 `research_projects` 폴더에 넣고 작성자에 의해 관리되어야 합니다.\n\n공식 훈련 및 연구 예제는 하나 이상의 훈련 스크립트, requirements.txt 파일 및 README.md 파일을 포함하는 디렉토리로 구성됩니다. 사용자가 훈련 예제를 사용하려면 리포지토리를 복제해야 합니다:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\n```\n\n그리고 훈련에 필요한 모든 추가적인 의존성도 설치해야 합니다:\n\n```bash\ncd diffusers\npip install -r examples/<your-example-folder>/requirements.txt\n```\n\n따라서 예제를 추가할 때, `requirements.txt` 파일은 훈련 예제에 필요한 모든 pip 의존성을 정의해야 합니다. 이 의존성들을 모두 설치하면 사용자가 예제의 훈련 스크립트를 실행할 수 있어야 합니다. 예를 들어, [DreamBooth `requirements.txt` 파일](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt)을 참조하세요.\n\nDiffusers 라이브러리의 훈련 예제는 다음 철학을 따라야 합니다:\n- 예제를 실행하는 데 필요한 모든 코드는 하나의 Python 파일에 있어야 합니다.\n- 사용자는 명령 줄에서 `python <your-example>.py --args`와 같이 예제를 실행할 수 있어야 합니다.\n- 예제는 간단하게 유지되어야 하며, Diffusers를 사용한 훈련 방법을 보여주는 **예시**로 사용되어야 합니다. 예제 스크립트의 목적은 최첨단 diffusion 모델을 만드는 것이 아니라, 너무 많은 사용자 정의 로직을 추가하지 않고 이미 알려진 훈련 방법을 재현하는 것입니다. 
이 점의 부산물로서, 예제는 좋은 교육 자료로써의 역할을 하기 위해 노력합니다.\n\n예제에 기여하기 위해서는, 이미 존재하는 예제인 [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)와 같은 예제를 참고하여 어떻게 보여야 하는지에 대한 아이디어를 얻는 것이 매우 권장됩니다.\nDiffusers와 긴밀하게 통합되어 있기 때문에, 기여자들이 [Accelerate 라이브러리](https://github.com/huggingface/accelerate)를 사용하는 것을 강력히 권장합니다.\n예제 스크립트가 작동하는 경우, 반드시 예제를 정확하게 사용하는 방법을 설명하는 포괄적인 `README.md`를 추가해야 합니다. 이 README에는 다음이 포함되어야 합니다:\n- [여기](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch)에 표시된 예제 스크립트를 실행하는 방법에 대한 예제 명령어.\n- [여기](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5)에 표시된 훈련 결과 (로그, 모델 등)에 대한 링크로 사용자가 기대할 수 있는 내용을 보여줍니다.\n- 비공식/연구용 훈련 예제를 추가하는 경우, **반드시** git 핸들을 포함하여 이 훈련 예제를 유지 관리할 것임을 명시하는 문장을 추가해야 합니다. [여기](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations)에 표시된 것과 같습니다.\n\n만약 공식 훈련 예제에 기여하는 경우, [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py)에 테스트를 추가하는 것도 확인해주세요. 비공식 훈련 예제에는 이 작업이 필요하지 않습니다.\n\n### 8. \"Good second issue\" 고치기 [[8-fixing-a-good-second-issue]]\n\n\"Good second issue\"는 [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) 라벨로 표시됩니다. Good second issue는 [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)보다 해결하기가 더 복잡합니다.\n이슈 설명은 일반적으로 이슈를 해결하는 방법에 대해 덜 구체적이며, 관심 있는 기여자는 라이브러리에 대한 꽤 깊은 이해가 필요합니다.\nGood second issue를 해결하고자 하는 경우, 해당 이슈를 해결하기 위해 PR을 열고 PR을 이슈에 링크하세요. 이미 해당 이슈에 대한 PR이 열려있지만 병합되지 않은 경우, 왜 병합되지 않았는지 이해하기 위해 살펴보고 개선된 PR을 열어보세요.\nGood second issue는 일반적으로 Good first issue 이슈보다 병합하기가 더 어려우므로, 핵심 메인테이너에게 도움을 요청하는 것이 좋습니다. PR이 거의 완료된 경우, 핵심 메인테이너는 PR에 참여하여 커밋하고 병합을 진행할 수 있습니다.\n\n### 9. 
파이프라인, 모델, 스케줄러 추가하기 [[9-adding-pipelines-models-schedulers]]\n\n파이프라인, 모델, 스케줄러는 Diffusers 라이브러리에서 가장 중요한 부분입니다.\n이들은 최첨단 diffusion 기술에 쉽게 접근하도록 하며, 따라서 커뮤니티가 강력한 생성형 AI 애플리케이션을 만들 수 있도록 합니다.\n\n새로운 모델, 파이프라인 또는 스케줄러를 추가함으로써, 사용자 인터페이스에 새로운 강력한 사용 사례를 활성화할 수 있으며, 이는 전체 생성형 AI 생태계에 매우 중요한 가치를 제공할 수 있습니다.\n\nDiffusers에는 세 가지 구성 요소에 대한 여러 개발 요청이 있습니다. 특정 구성 요소를 아직 정확히 어떤 것을 추가하고 싶은지 모르는 경우, 다음 링크를 참조하세요:\n- [모델 또는 파이프라인](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)\n- [스케줄러](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)\n\n\n세 가지 구성 요소를 추가하기 전에, [철학 가이드](philosophy)를 읽어보는 것을 강력히 권장합니다. 세 가지 구성 요소 중 어느 것을 추가하든, 디자인 철학과 관련된 API 일관성을 유지하기 위해 우리의 디자인 철학과 크게 다른 구성 요소는 병합할 수 없습니다. 디자인 선택에 근본적으로 동의하지 않는 경우, [피드백 이슈](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=)를 열어 해당 디자인 패턴/선택이 라이브러리 전체에서 변경되어야 하는지, 디자인 철학을 업데이트해야 하는지에 대해 논의할 수 있습니다. 라이브러리 전체의 일관성은 우리에게 매우 중요합니다.\n\nPR에 원본 코드베이스/논문 링크를 추가하고, 가능하면 PR에서 원래 작성자에게 직접 알림을 보내어 진행 상황을 따라갈 수 있도록 해주세요.\n\nPR에서 막힌 경우나 도움이 필요한 경우, 첫 번째 리뷰나 도움을 요청하는 메시지를 남기는 것을 주저하지 마세요.\n\n#### Copied from mechanism [[copied-from-mechanism]]\n\n`# Copied from mechanism` 은 파이프라인, 모델 또는 스케줄러 코드를 추가할 때 이해해야 할 독특하고 중요한 기능입니다. 이것은 Diffusers 코드베이스 전반에서 볼 수 있으며, 이를 사용하는 이유는 코드베이스를 이해하고 유지 관리하기 쉽게 만들기 위해서입니다. `# Copied from mechanism` 으로 표시된 코드는 복사한 코드와 정확히 동일하도록 강제됩니다. 이렇게 하면 `make fix-copies`를 실행할 때마다 여러 파일에 걸쳐 변경 사항을 쉽게 업데이트하고 전파할 수 있습니다.\n\n예를 들어, 아래 코드 예제에서 [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`]은 원래 코드이며, `AltDiffusionPipelineOutput`은 `# Copied from mechanism`을 사용하여 복사합니다. 
유일한 차이점은 클래스 접두사를 `Stable`에서 `Alt`로 변경한 것입니다.\n\n```py\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt\nclass AltDiffusionPipelineOutput(BaseOutput):\n    \"\"\"\n    Output class for Alt Diffusion pipelines.\n\n    Args:\n        images (`List[PIL.Image.Image]` or `np.ndarray`)\n            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,\n            num_channels)`.\n        nsfw_content_detected (`List[bool]`)\n            List indicating whether the corresponding generated image contains \"not-safe-for-work\" (nsfw) content or\n            `None` if safety checking could not be performed.\n    \"\"\"\n```\n\n더 자세히 알고 싶다면 [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) 블로그 포스트의 이 섹션을 읽어보세요.\n\n## 좋은 이슈 작성 방법 [[how-to-write-a-good-issue]]\n\n**이슈를 잘 작성할수록 빠르게 해결될 가능성이 높아집니다.**\n\n1. 이슈에 적절한 템플릿을 사용했는지 확인하세요. [새 이슈를 열 때](https://github.com/huggingface/diffusers/issues/new/choose) 올바른 템플릿을 선택해야 합니다. *버그 보고서*, *기능 요청*, *API 디자인에 대한 피드백*, *새로운 모델/파이프라인/스케줄러 추가*, *포럼*, 또는 빈 이슈 중에서 선택하세요. 이슈를 열 때 올바른 템플릿을 선택하는 것이 중요합니다.\n2. **명확성**: 이슈에 적합한 제목을 지정하세요. 이슈 설명을 가능한 간단하게 작성하세요. 이슈를 이해하고 해결하는 데 걸리는 시간을 줄이기 위해 가능한 한 명확하게 작성하세요. 하나의 이슈에 대해 여러 문제를 포함하지 않도록 주의하세요. 여러 문제를 발견한 경우, 각각의 이슈를 개별적으로 열어주세요. 버그인 경우, 어떤 버그인지 가능한 한 정확하게 설명해야 합니다. \"diffusers에서 오류\"와 같이 간단히 작성하지 마세요.\n3. **재현 가능성**: 재현 가능한 코드 조각이 없으면 해결할 수 없습니다. 버그를 발견한 경우, 유지 관리자는 그 버그를 재현할 수 있어야 합니다. 이슈에 재현 가능한 코드 조각을 포함해야 합니다. 코드 조각은 Python 인터프리터에 복사하여 붙여넣을 수 있는 형태여야 합니다. 코드 조각이 작동해야 합니다. 즉, 누락된 import나 이미지에 대한 링크가 없어야 합니다. 이슈에는 오류 메시지와 정확히 동일한 오류 메시지를 재현하기 위해 수정하지 않고 복사하여 붙여넣을 수 있는 코드 조각이 포함되어야 합니다. 이슈에 사용자의 로컬 모델 가중치나 로컬 데이터를 사용하는 경우, 독자가 액세스할 수 없는 경우 이슈를 해결할 수 없습니다. 데이터나 모델을 공유할 수 없는 경우, 더미 모델이나 더미 데이터를 만들어 사용해보세요.\n4. **간결성**: 가능한 한 간결하게 유지하여 독자가 문제를 빠르게 이해할 수 있도록 도와주세요. 문제와 관련이 없는 코드나 정보는 모두 제거해주세요. 
버그를 발견한 경우, 문제를 설명하는 가장 간단한 코드 예제를 만들어보세요. 버그를 발견한 후에는 작업 흐름 전체를 이슈에 던지는 것이 아니라, 에러가 발생하는 훈련 코드의 어느 부분이 문제인지 먼저 이해하고 몇 줄로 재현해보세요. 전체 데이터셋 대신 더미 데이터를 사용해보세요.\n5. 링크 추가하기. 특정한 이름, 메서드, 또는 모델을 참조하는 경우, 독자가 더 잘 이해할 수 있도록 링크를 제공해주세요. 특정 PR이나 이슈를 참조하는 경우, 해당 이슈에 링크를 걸어주세요. 독자가 무엇을 말하는지 알고 있다고 가정하지 마세요. 이슈에 링크를 많이 추가할수록 좋습니다.\n6. 포맷팅. 코드를 파이썬 코드 구문으로, 에러 메시지를 일반 코드 구문으로 형식화하여 이슈를 깔끔하게 작성하세요. 자세한 내용은 [GitHub 공식 포맷팅 문서](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax)를 참조하세요.\n7. 여러분의 이슈를 단순히 해결해야 할 티켓으로 생각하지 말고, 잘 작성된 백과사전 항목으로 생각해보세요. 추가된 모든 이슈는 공개적으로 이용 가능한 지식에 대한 기여입니다. 잘 작성된 이슈를 추가함으로써 메인테이너가 여러분의 이슈를 더 쉽게 해결할 수 있게 할 뿐만 아니라, 전체 커뮤니티가 라이브러리의 특정 측면을 더 잘 이해할 수 있도록 도움을 주게 됩니다.\n\n## 좋은 PR 작성 방법 [[how-to-write-a-good-pr]]\n\n1. 카멜레온이 되세요. 기존의 디자인 패턴과 구문을 이해하고, 여러분이 추가하는 코드가 기존 코드베이스와 자연스럽게 어우러지도록 해야 합니다. 기존 디자인 패턴이나 사용자 인터페이스와 크게 다른 Pull Request들은 병합되지 않습니다.\n2. 레이저처럼 집중하세요. Pull Request는 하나의 문제, 오직 하나의 문제만 해결해야 합니다. \"이왕 추가하는 김에 다른 문제도 고치자\"는 함정에 빠지지 않도록 주의하세요. 여러 개의 관련 없는 문제를 한 번에 해결하는 Pull Request들은 검토하기가 훨씬 더 어렵습니다.\n3. 도움이 되는 경우, 추가한 내용이 어떻게 사용되는지 예제 코드 조각을 추가해보세요.\n4. Pull Request의 제목은 기여 내용을 요약해야 합니다.\n5. Pull Request가 이슈를 해결하는 경우, Pull Request의 설명에 이슈 번호를 언급하여 연결되도록 해주세요 (이슈를 참조하는 사람들이 작업 중임을 알 수 있도록).\n6. 진행 중인 작업을 나타내려면 제목에 `[WIP]`를 접두사로 붙여주세요. 이는 중복 작업을 피하고, 병합 준비가 된 PR과 구분할 수 있도록 도움이 됩니다.\n7. [좋은 이슈를 작성하는 방법](#how-to-write-a-good-issue)에 설명된 대로 텍스트를 구성하고 형식을 지정해보세요.\n8. 기존 테스트가 통과하는지 확인하세요.\n9. 높은 커버리지를 가진 테스트를 추가하세요. 양질의 테스트가 없으면 병합할 수 없습니다.\n- 새로운 `@slow` 테스트를 추가하는 경우, 다음 명령을 사용하여 통과하는지 확인하세요.\n`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.\nCircleCI는 느린 테스트를 실행하지 않지만, GitHub Actions는 매일 실행합니다!\n10. 모든 공개 메서드는 마크다운과 잘 작동하는 정보성 docstring을 가져야 합니다. 예시로 [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py)를 참조하세요.\n11. 
리포지토리가 빠르게 성장하고 있기 때문에, 리포지토리에 큰 부담을 주는 파일이 추가되지 않도록 주의해야 합니다. 이미지, 비디오 및 기타 텍스트가 아닌 파일을 포함합니다. 이러한 파일을 배치하기 위해 hf.co 호스팅 `dataset`인 [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) 또는 [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images)를 활용하는 것이 우선입니다.\n외부 기여인 경우, 이미지를 PR에 추가하고 Hugging Face 구성원에게 이미지를 이 데이터셋으로 이동하도록 요청하세요.\n\n## PR을 열기 위한 방법 [[how-to-open-a-pr]]\n\n코드를 작성하기 전에, 이미 누군가가 같은 작업을 하고 있는지 확인하기 위해 기존의 PR이나 이슈를 검색하는 것이 좋습니다. 확실하지 않은 경우, 피드백을 받기 위해 이슈를 열어보는 것이 항상 좋은 아이디어입니다.\n\n🧨 Diffusers에 기여하기 위해서는 기본적인 `git` 사용법을 알아야 합니다. `git`은 가장 쉬운 도구는 아니지만, 가장 훌륭한 매뉴얼을 가지고 있습니다. 셸에서 `git --help`을 입력하고 즐기세요. 책을 선호하는 경우, [Pro Git](https://git-scm.com/book/en/v2)은 매우 좋은 참고 자료입니다.\n\n다음 단계를 따라 기여를 시작하세요 ([지원되는 Python 버전](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)):\n\n1. 저장소 페이지에서 'Fork' 버튼을 클릭하여 [저장소](https://github.com/huggingface/diffusers)를 포크합니다. 이렇게 하면 코드의 사본이 GitHub 사용자 계정에 생성됩니다.\n\n2. 포크한 저장소를 로컬 디스크에 클론하고, 기본 저장소를 원격으로 추가하세요:\n\n ```bash\n $ git clone git@github.com:<your GitHub handle>/diffusers.git\n $ cd diffusers\n $ git remote add upstream https://github.com/huggingface/diffusers.git\n ```\n\n3. 개발 변경 사항을 보관할 새로운 브랜치를 생성하세요:\n\n ```bash\n $ git checkout -b a-descriptive-name-for-my-changes\n ```\n\n`main` 브랜치 위에서 **절대** 작업하지 마세요.\n\n4. 가상 환경에서 다음 명령을 실행하여 개발 환경을 설정하세요:\n\n ```bash\n $ pip install -e \".[dev]\"\n ```\n\n만약 저장소를 이미 클론한 경우, 가장 최신 변경 사항을 가져오기 위해 `git pull`을 실행해야 할 수도 있습니다.\n\n5. 기능을 브랜치에서 개발하세요.\n\n기능을 작업하는 동안 테스트 스위트가 통과되는지 확인해야 합니다. 다음과 같이 변경 사항에 영향을 받는 테스트를 실행해야 합니다:\n\n ```bash\n $ pytest tests/<TEST_TO_RUN>.py\n ```\n\n테스트를 실행하기 전에 테스트를 위해 필요한 의존성들을 설치하였는지 확인하세요. 다음의 커맨드를 통해서 확인할 수 있습니다:\n\n ```bash\n $ pip install -e \".[test]\"\n ```\n\n다음 명령어로 전체 테스트 묶음 실행할 수도 있지만, Diffusers가 많이 성장하였기 때문에 결과를 적당한 시간 내에 생성하기 위해서는 강력한 컴퓨터가 필요합니다. 다음은 해당 명령어입니다:\n\n ```bash\n $ make test\n ```\n\n🧨 Diffusers는 소스 코드를 일관되게 포맷팅하기 위해 `black`과 `isort`를 사용합니다. 
변경 사항을 적용한 후에는 다음과 같이 자동 스타일 수정 및 코드 검증을 적용할 수 있습니다:\n\n\n ```bash\n $ make style\n ```\n\n🧨 Diffusers는 `ruff`와 몇 개의 커스텀 스크립트를 이용하여 코딩 실수를 확인합니다. 품질 검사는 CI에서 실행되지만, 다음 명령으로 동일한 검사를 직접 실행할 수도 있습니다:\n\n ```bash\n $ make quality\n ```\n\n변경 사항에 만족한다면 `git add`를 사용하여 변경된 파일을 추가하고 `git commit`을 사용하여 변경 사항을 로컬에 기록하세요:\n\n ```bash\n $ git add modified_file.py\n $ git commit -m \"A descriptive message about your changes.\"\n ```\n\n코드를 정기적으로 원본 저장소와 동기화하는 것은 좋은 아이디어입니다. 이렇게 하면 변경 사항을 빠르게 반영할 수 있습니다:\n\n ```bash\n $ git pull upstream main\n ```\n\n변경 사항을 계정에 푸시하려면 다음을 사용하세요:\n\n ```bash\n $ git push -u origin a-descriptive-name-for-my-changes\n ```\n\n6. 변경 사항에 만족한다면, GitHub에서 포크한 저장소 페이지로 이동하여 'Pull request'를 클릭해 프로젝트 메인테이너에게 변경 사항 검토를 요청하세요.\n\n7. 메인테이너가 변경 사항을 요청하는 것은 괜찮습니다. 핵심 기여자들에게도 일어나는 일입니다! 모든 사람이 Pull request에서 변경 사항을 볼 수 있도록 로컬 브랜치에서 작업한 뒤 변경 사항을 포크에 푸시하세요. 푸시한 내용은 자동으로 Pull request에 나타납니다.\n\n### 테스트 [[tests]]\n\n라이브러리 동작과 여러 예제를 테스트하기 위해 포괄적인 테스트 묶음이 포함되어 있습니다. 라이브러리 테스트는 [tests 폴더](https://github.com/huggingface/diffusers/tree/main/tests)에서 찾을 수 있습니다.\n\n`pytest`와 `pytest-xdist`를 선호하는 이유는 더 빠르기 때문입니다. 루트 디렉토리에서 라이브러리를 위해 `pytest`로 테스트를 실행하는 방법은 다음과 같습니다:\n\n```bash\n$ python -m pytest -n auto --dist=loadfile -s -v ./tests/\n```\n\n사실, `make test`는 이렇게 구현되어 있습니다!\n\n작업 중인 기능만 테스트하기 위해 더 작은 테스트 세트를 지정할 수 있습니다.\n\n기본적으로 느린 테스트는 건너뜁니다. `RUN_SLOW` 환경 변수를 `yes`로 설정하여 실행할 수 있습니다. 이 경우 수 기가바이트에 달하는 모델을 다운로드합니다. 충분한 디스크 공간과 좋은 인터넷 연결 또는 많은 인내심이 필요합니다!\n\n```bash\n$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/\n```\n\n`unittest`도 완전히 지원됩니다. 다음은 `unittest`를 사용하여 테스트를 실행하는 방법입니다:\n\n```bash\n$ python -m unittest discover -s tests -t . -v\n$ python -m unittest discover -s examples -t examples -v\n```\n\n### upstream(HuggingFace) main과 forked main 동기화하기 [[syncing-forked-main-with-upstream-huggingface-main]]\n\nupstream 저장소에 불필요한 참조 노트를 추가하고 관련 개발자에게 알림을 보내는 것을 피하기 위해,\nforked 저장소의 main 브랜치를 동기화할 때 다음 단계를 따르세요:\n1. 가능한 경우, forked 저장소에서 브랜치와 PR을 사용하여 upstream과 동기화하는 것을 피하세요. 
대신 forked main으로 직접 병합하세요.\n2. PR이 절대적으로 필요한 경우, 브랜치를 체크아웃한 후 다음 단계를 사용하세요:\n```bash\n$ git checkout -b your-branch-for-syncing\n$ git pull --squash --no-commit upstream main\n$ git commit -m '<your message without GitHub references>'\n$ git push --set-upstream origin your-branch-for-syncing\n```\n\n### 스타일 가이드 [[style-guide]]\n\ndocstring의 경우, 🧨 Diffusers는 [Google 스타일](https://google.github.io/styleguide/pyguide.html)을 따릅니다.\n"
  },
  {
    "path": "docs/source/ko/conceptual/ethical_guidelines.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 🧨 Diffusers의 윤리 지침 [[-diffusers-ethical-guidelines]]\n\n## 서문 [[preamble]]\n\n[Diffusers](https://huggingface.co/docs/diffusers/index)는 사전 훈련된 diffusion 모델을 제공하며, 추론과 훈련을 위한 모듈형 툴박스로 활용됩니다.\n\n이 기술의 실제 적용 사례와 사회에 미칠 수 있는 잠재적 부정적 영향을 고려할 때, Diffusers 라이브러리의 개발, 사용자 기여, 사용에 윤리 지침을 제공하는 것이 중요하다고 생각합니다.\n\n이 기술 사용과 관련된 위험은 여전히 검토 중이지만, 예를 들면: 예술가의 저작권 문제, 딥페이크 악용, 부적절한 맥락에서의 성적 콘텐츠 생성, 비동의 사칭, 소수자 집단 억압을 영속화하는 유해한 사회적 편견 등이 있습니다.\n우리는 이러한 위험을 지속적으로 추적하고, 커뮤니티의 반응과 소중한 피드백에 따라 아래 지침을 조정할 것입니다.\n\n## 범위 [[scope]]\n\nDiffusers 커뮤니티는 프로젝트 개발에 다음 윤리 지침을 적용하며, 특히 윤리적 문제와 관련된 민감한 주제에 대해 커뮤니티의 기여를 조정하는 데 도움을 줄 것입니다.\n\n## 윤리 지침 [[ethical-guidelines]]\n\n다음 윤리 지침은 일반적으로 적용되지만, 윤리적으로 민감한 문제와 관련된 기술적 선택을 할 때 우선적으로 적용됩니다. 또한, 해당 기술의 최신 동향과 관련된 새로운 위험이 발생함에 따라 이러한 윤리 원칙을 지속적으로 조정할 것을 약속합니다.\n\n- **투명성**: 우리는 PR 관리, 사용자에게 선택의 이유 설명, 기술적 의사결정 과정에서 투명성을 유지할 것을 약속합니다.\n\n- **일관성**: 프로젝트 관리에서 모든 사용자에게 동일한 수준의 관심을 보장하고, 기술적으로 안정적이고 일관된 상태를 유지할 것을 약속합니다.\n\n- **간결성**: Diffusers 라이브러리를 쉽게 사용하고 활용할 수 있도록, 프로젝트의 목표를 간결하고 일관성 있게 유지할 것을 약속합니다.\n\n- **접근성**: Diffusers 프로젝트는 기술적 전문지식이 없어도 기여할 수 있도록 진입장벽을 낮춥니다. 
이를 통해 연구 결과물이 커뮤니티에 더 잘 접근될 수 있습니다.\n\n- **재현성**: 우리는 Diffusers 라이브러리를 통해 제공되는 업스트림 코드, 모델, 데이터셋의 재현성에 대해 투명하게 공개하는 것을 목표로 합니다.\n\n- **책임**: 커뮤니티와 팀워크를 통해, 이 기술의 잠재적 위험을 예측하고 완화하는 데 공동 책임을 집니다.\n\n## 구현 사례: 안전 기능과 메커니즘 [[examples-of-implementations-safety-features-and-mechanisms]]\n\n팀은 diffusion 기술과 관련된 잠재적 윤리 및 사회적 위험에 대응하기 위해 기술적·비기술적 도구를 제공하고자 노력하고 있습니다. 또한, 커뮤니티의 참여는 이러한 기능 구현과 인식 제고에 매우 중요합니다.\n\n- [**커뮤니티 탭**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): 커뮤니티가 프로젝트에 대해 토론하고 더 나은 협업을 할 수 있도록 지원합니다.\n\n- **편향 탐색 및 평가**: Hugging Face 팀은 Stable Diffusion 모델의 편향성을 대화형으로 보여주는 [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)를 제공합니다. 우리는 이러한 편향 탐색과 평가를 지원하고 장려합니다.\n\n- **배포에서의 안전 유도**\n\n  - [**안전한 Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): 필터링되지 않은 웹 크롤링 데이터셋으로 훈련된 Stable Diffusion과 같은 모델이 부적절하게 변질되는 문제를 완화합니다. 관련 논문: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105).\n\n  - [**안전 검사기**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): 생성된 이미지가 임베딩 공간에서 하드코딩된 유해 개념 클래스와 일치할 확률을 확인하고 비교합니다. 유해 개념은 역공학을 방지하기 위해 의도적으로 숨겨져 있습니다.\n\n- **Hub에서의 단계적 배포**: 특히 민감한 상황에서는 일부 리포지토리에 대한 접근을 제한할 수 있습니다. 단계적 배포는 리포지토리 작성자가 사용에 대해 더 많은 통제권을 갖도록 하는 중간 단계입니다.\n\n- **라이선싱**: [OpenRAILs](https://huggingface.co/blog/open_rail)와 같은 새로운 유형의 라이선스를 통해 자유로운 접근을 보장하면서도 보다 책임 있는 사용을 위한 일련의 제한을 둘 수 있습니다.\n"
  },
  {
    "path": "docs/source/ko/conceptual/evaluation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Diffusion 모델 평가하기[[evaluating-diffusion-models]]\n\n<a target=\"_blank\" href=\"https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/evaluation.ipynb\">\n    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n</a>\n\n[Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion)와 같은 생성 모델의 평가는 주관적인 성격을 가지고 있습니다. 그러나 실무자와 연구자로서 우리는 종종 다양한 가능성 중에서 신중한 선택을 해야 합니다. 그래서 다양한 생성 모델 (GAN, Diffusion 등)을 사용할 때 어떻게 선택해야 할까요?\n\n정성적인 평가는 모델의 이미지 품질에 대한 주관적인 평가이므로 오류가 발생할 수 있고 결정에 잘못된 영향을 미칠 수 있습니다. 반면, 정량적인 평가는 이미지 품질과 직접적인 상관관계를 갖지 않을 수 있습니다. 따라서 일반적으로 정성적 평가와 정량적 평가를 모두 고려하는 것이 더 강력한 신호를 제공하여 모델 선택에 도움이 됩니다.\n\n이 문서에서는 Diffusion 모델을 평가하기 위한 정성적 및 정량적 방법에 대해 상세히 설명합니다. 
정량적 방법에 대해서는 특히 `diffusers`와 함께 구현하는 방법에 초점을 맞추었습니다.\n\n이 문서에서 보여진 방법들은 기반 생성 모델을 고정시키고 다양한 [노이즈 스케줄러](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview)를 평가하는 데에도 사용할 수 있습니다.\n\n## 시나리오[[scenarios]]\n다음과 같은 파이프라인을 사용하여 Diffusion 모델을 다룹니다:\n\n- 텍스트로 안내된 이미지 생성 (예: [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img)).\n- 입력 이미지에 추가로 조건을 건 텍스트로 안내된 이미지 생성 (예: [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/img2img) 및 [`StableDiffusionInstructPix2PixPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix)).\n- 클래스 조건화된 이미지 생성 모델 (예: [`DiTPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit)).\n\n## 정성적 평가[[qualitative-evaluation]]\n\n정성적 평가는 일반적으로 생성된 이미지의 인간 평가를 포함합니다. 품질은 구성성, 이미지-텍스트 일치, 공간 관계 등과 같은 측면에서 측정됩니다. 일반적인 프롬프트는 주관적인 지표에 대한 일정한 기준을 제공합니다.\nDrawBench와 PartiPrompts는 정성적인 벤치마킹에 사용되는 프롬프트 데이터셋입니다. DrawBench와 PartiPrompts는 각각 [Imagen](https://imagen.research.google/)과 [Parti](https://parti.research.google/)에서 소개되었습니다.\n\n[Parti 공식 웹사이트](https://parti.research.google/)에서 다음과 같이 설명하고 있습니다:\n\n> PartiPrompts (P2)는 이 작업의 일부로 공개되는 영어로 된 1600개 이상의 다양한 프롬프트 세트입니다. P2는 다양한 범주와 도전 측면에서 모델의 능력을 측정하는 데 사용할 수 있습니다.\n\n![parti-prompts](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts.png)\n\nPartiPrompts는 다음과 같은 열을 가지고 있습니다:\n\n- 프롬프트 (Prompt)\n- 프롬프트의 카테고리 (예: \"Abstract\", \"World Knowledge\" 등)\n- 난이도를 반영한 챌린지 (예: \"Basic\", \"Complex\", \"Writing & Symbols\" 등)\n\n이러한 벤치마크는 서로 다른 이미지 생성 모델을 인간 평가로 비교할 수 있도록 합니다.\n\n이를 위해 🧨 Diffusers 팀은 **Open Parti Prompts**를 구축했습니다. 
이는 Parti Prompts를 기반으로 한 커뮤니티 기반의 질적 벤치마크로, 최첨단 오픈 소스 확산 모델을 비교하는 데 사용됩니다:\n- [Open Parti Prompts 게임](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts): 10개의 parti prompt에 대해 4개의 생성된 이미지가 제시되며, 사용자는 프롬프트에 가장 적합한 이미지를 선택합니다.\n- [Open Parti Prompts 리더보드](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard): 현재 최고의 오픈 소스 diffusion 모델들을 서로 비교하는 리더보드입니다.\n\n이미지를 수동으로 비교하려면, `diffusers`를 사용하여 몇가지 PartiPrompts를 어떻게 활용할 수 있는지 알아봅시다.\n\n다음은 몇 가지 다른 도전에서 샘플링한 프롬프트를 보여줍니다: Basic, Complex, Linguistic Structures, Imagination, Writing & Symbols. 여기서는 PartiPrompts를 [데이터셋](https://huggingface.co/datasets/nateraw/parti-prompts)으로 사용합니다.\n\n```python\nfrom datasets import load_dataset\n\n# prompts = load_dataset(\"nateraw/parti-prompts\", split=\"train\")\n# prompts = prompts.shuffle()\n# sample_prompts = [prompts[i][\"Prompt\"] for i in range(5)]\n\n# Fixing these sample prompts in the interest of reproducibility.\nsample_prompts = [\n    \"a corgi\",\n    \"a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky\",\n    \"a car with no windows\",\n    \"a cube made of porcupine\",\n    'The saying \"BE EXCELLENT TO EACH OTHER\" written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. A yellow fire hydrant is on a sidewalk in the foreground.',\n]\n```\n이제 이런 프롬프트를 사용하여 Stable Diffusion ([v1-4 checkpoint](https://huggingface.co/CompVis/stable-diffusion-v1-4))를 사용한 이미지 생성을 할 수 있습니다 :\n\n```python\nimport torch\n\nseed = 0\ngenerator = torch.manual_seed(seed)\n\nimages = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generator).images\n```\n\n![parti-prompts-14](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-14.png)\n\n\n`num_images_per_prompt`를 설정하여 동일한 프롬프트에 대해 다른 이미지를 비교할 수도 있습니다. 
다른 체크포인트([v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5))로 동일한 파이프라인을 실행하면 다음과 같은 결과가 나옵니다:\n\n![parti-prompts-15](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-15.png)\n\n\n다양한 모델을 사용하여 모든 프롬프트에서 생성된 여러 이미지들이 생성되면 (평가 과정에서) 이러한 결과물들은 사람 평가자들에게 점수를 매기기 위해 제시됩니다. DrawBench와 PartiPrompts 벤치마크에 대한 자세한 내용은 각각의 논문을 참조하십시오.\n\n> [!TIP]\n> 모델이 훈련 중일 때 추론 샘플을 살펴보는 것은 훈련 진행 상황을 측정하는 데 유용합니다. [훈련 스크립트](https://github.com/huggingface/diffusers/tree/main/examples/)에서는 TensorBoard와 Weights & Biases에 대한 추가 지원과 함께 이 유틸리티를 지원합니다.\n\n## 정량적 평가[[quantitative-evaluation]]\n\n이 섹션에서는 세 가지 다른 확산 파이프라인을 평가하는 방법을 안내합니다:\n\n- CLIP 점수\n- CLIP 방향성 유사도\n- FID\n\n### 텍스트 안내 이미지 생성[[text-guided-image-generation]]\n\n[CLIP 점수](https://huggingface.co/papers/2104.08718)는 이미지-캡션 쌍의 호환성을 측정합니다. 높은 CLIP 점수는 높은 호환성🔼을 나타냅니다. CLIP 점수는 이미지와 캡션 사이의 의미적 유사성으로 생각할 수도 있습니다. CLIP 점수는 인간 판단과 높은 상관관계를 가지고 있습니다.\n\n[`StableDiffusionPipeline`]을 일단 로드해봅시다:\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_ckpt = \"CompVis/stable-diffusion-v1-4\"\nsd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to(\"cuda\")\n```\n\n여러 개의 프롬프트를 사용하여 이미지를 생성합니다:\n\n```python\nprompts = [\n    \"a photo of an astronaut riding a horse on mars\",\n    \"A high tech solarpunk utopia in the Amazon rainforest\",\n    \"A pikachu fine dining with a view to the Eiffel Tower\",\n    \"A mecha robot in a favela in expressionist style\",\n    \"an insect robot preparing a delicious meal\",\n    \"A small cabin on top of a snowy mountain in the style of Disney, artstation\",\n]\n\nimages = sd_pipeline(prompts, num_images_per_prompt=1, output_type=\"np\").images\n\nprint(images.shape)\n# (6, 512, 512, 3)\n```\n\n그러고 나서 CLIP 점수를 계산합니다.\n\n```python\nfrom torchmetrics.functional.multimodal import clip_score\nfrom functools import partial\n\nclip_score_fn = partial(clip_score, 
model_name_or_path=\"openai/clip-vit-base-patch16\")\n\ndef calculate_clip_score(images, prompts):\n    images_int = (images * 255).astype(\"uint8\")\n    clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()\n    return round(float(clip_score), 4)\n\nsd_clip_score = calculate_clip_score(images, prompts)\nprint(f\"CLIP score: {sd_clip_score}\")\n# CLIP score: 35.7038\n```\n\n위의 예제에서는 각 프롬프트 당 하나의 이미지를 생성했습니다. 만약 프롬프트 당 여러 이미지를 생성한다면, 프롬프트 당 생성된 이미지의 평균 점수를 사용해야 합니다.\n\n이제 [`StableDiffusionPipeline`]과 호환되는 두 개의 체크포인트를 비교하려면, 파이프라인을 호출할 때 generator를 전달해야 합니다. 먼저, 고정된 시드로 [v1-4 Stable Diffusion 체크포인트](https://huggingface.co/CompVis/stable-diffusion-v1-4)를 사용하여 이미지를 생성합니다:\n\n```python\nseed = 0\ngenerator = torch.manual_seed(seed)\n\nimages = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type=\"np\").images\n```\n\n그런 다음 [v1-5 checkpoint](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)를 로드하여 이미지를 생성합니다:\n\n```python\nmodel_ckpt_1_5 = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nsd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to(\"cuda\")\n\nimages_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type=\"np\").images\n```\n\n그리고 마지막으로 CLIP 점수를 비교합니다:\n\n```python\nsd_clip_score_1_4 = calculate_clip_score(images, prompts)\nprint(f\"CLIP Score with v-1-4: {sd_clip_score_1_4}\")\n# CLIP Score with v-1-4: 34.9102\n\nsd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts)\nprint(f\"CLIP Score with v-1-5: {sd_clip_score_1_5}\")\n# CLIP Score with v-1-5: 36.2137\n```\n\n[v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 체크포인트가 이전 버전보다 더 나은 성능을 보이는 것 같습니다. 그러나 CLIP 점수를 계산하기 위해 사용한 프롬프트의 수가 상당히 적습니다. 보다 실용적인 평가를 위해서는 이 수를 훨씬 높게 설정하고, 프롬프트를 다양하게 사용해야 합니다.\n\n> [!WARNING]\n> 이 점수에는 몇 가지 제한 사항이 있습니다. 훈련 데이터셋의 캡션은 웹에서 크롤링되어 이미지와 관련된 `alt` 및 유사한 태그에서 추출되었습니다. 
이들은 인간이 이미지를 설명하는 데 사용할 수 있는 것과 일치하지 않을 수 있습니다. 따라서 여기서는 몇 가지 프롬프트를 \"엔지니어링\"해야 했습니다.\n\n### 이미지 조건화된 텍스트-이미지 생성[[image-conditioned-text-to-image-generation]]\n\n이 경우, 생성 파이프라인을 입력 이미지와 텍스트 프롬프트로 조건화합니다. [`StableDiffusionInstructPix2PixPipeline`]을 예로 들어보겠습니다. 이는 편집 지시문을 입력 프롬프트로 사용하고 편집할 입력 이미지를 사용합니다.\n\n다음은 하나의 예시입니다:\n\n![edit-instruction](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png)\n\n모델을 평가하는 한 가지 전략은 두 이미지 캡션 간의 변경과([CLIP-Guided Domain Adaptation of Image Generators](https://huggingface.co/papers/2108.00946)에서 보여줍니다) 함께 두 이미지 사이의 변경의 일관성을 측정하는 것입니다 ([CLIP](https://huggingface.co/docs/transformers/model_doc/clip) 공간에서). 이를 \"**CLIP 방향성 유사성**\"이라고 합니다.\n\n- 캡션 1은 편집할 이미지 (이미지 1)에 해당합니다.\n- 캡션 2는 편집된 이미지 (이미지 2)에 해당합니다. 편집 지시를 반영해야 합니다.\n\n다음은 그림으로 된 개요입니다:\n\n![edit-consistency](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-consistency.png)\n\n우리는 이 측정 항목을 구현하기 위해 미니 데이터 세트를 준비했습니다. 먼저 데이터 세트를 로드해 보겠습니다.\n\n```python\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"sayakpaul/instructpix2pix-demo\", split=\"train\")\ndataset.features\n```\n\n```bash\n{'input': Value(dtype='string', id=None),\n 'edit': Value(dtype='string', id=None),\n 'output': Value(dtype='string', id=None),\n 'image': Image(decode=True, id=None)}\n```\n\n여기에는 다음과 같은 항목이 있습니다:\n\n- `input`은 `image`에 해당하는 캡션입니다.\n- `edit`은 편집 지시사항을 나타냅니다.\n- `output`은 `edit` 지시사항을 반영한 수정된 캡션입니다.\n\n샘플을 살펴보겠습니다.\n\n```python\nidx = 0\nprint(f\"Original caption: {dataset[idx]['input']}\")\nprint(f\"Edit instruction: {dataset[idx]['edit']}\")\nprint(f\"Modified caption: {dataset[idx]['output']}\")\n```\n\n```bash\nOriginal caption: 2. FAROE ISLANDS: An archipelago of 18 mountainous isles in the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has 'everything you could hope for', according to Big 7 Travel. 
It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'\nEdit instruction: make the isles all white marble\nModified caption: 2. WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles in the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'\n```\n\n다음은 이미지입니다:\n\n```python\ndataset[idx][\"image\"]\n```\n\n![edit-dataset](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-dataset.png)\n\n먼저 편집 지시사항을 사용하여 데이터 세트의 이미지를 편집하고 방향 유사도를 계산합니다.\n\n[`StableDiffusionInstructPix2PixPipeline`]을 먼저 로드합니다:\n\n```python\nfrom diffusers import StableDiffusionInstructPix2PixPipeline\n\n# 이후 코드 블록에서도 사용하는 장치를 여기서 정의합니다.\ndevice = \"cuda\"\n\ninstruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\n    \"timbrooks/instruct-pix2pix\", torch_dtype=torch.float16\n).to(device)\n```\n\n이제 편집을 수행합니다:\n\n```python\nimport numpy as np\n\n\ndef edit_image(input_image, instruction):\n    image = instruct_pix2pix_pipeline(\n        instruction,\n        image=input_image,\n        output_type=\"np\",\n        generator=generator,\n    ).images[0]\n    return image\n\ninput_images = []\noriginal_captions = []\nmodified_captions = []\nedited_images = []\n\nfor idx in range(len(dataset)):\n    input_image = dataset[idx][\"image\"]\n    edit_instruction = dataset[idx][\"edit\"]\n    edited_image = edit_image(input_image, edit_instruction)\n\n    input_images.append(np.array(input_image))\n    original_captions.append(dataset[idx][\"input\"])\n    modified_captions.append(dataset[idx][\"output\"])\n    edited_images.append(edited_image)\n```\n\n방향 유사도를 계산하기 위해서는 먼저 CLIP의 이미지와 텍스트 인코더를 로드합니다:\n\n```python\nfrom transformers import (\n    CLIPTokenizer,\n    CLIPTextModelWithProjection,\n    CLIPVisionModelWithProjection,\n 
   CLIPImageProcessor,\n)\n\nclip_id = \"openai/clip-vit-large-patch14\"\ntokenizer = CLIPTokenizer.from_pretrained(clip_id)\ntext_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(device)\nimage_processor = CLIPImageProcessor.from_pretrained(clip_id)\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device)\n```\n\n주목할 점은 특정한 CLIP 체크포인트인 `openai/clip-vit-large-patch14`를 사용하고 있다는 것입니다. 이는 Stable Diffusion 사전 훈련이 이 CLIP 변형체와 함께 수행되었기 때문입니다. 자세한 내용은 [문서](https://huggingface.co/docs/transformers/model_doc/clip)를 참조하세요.\n\n다음으로, 방향성 유사도를 계산하기 위해 PyTorch의 `nn.Module`을 준비합니다:\n\n```python\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass DirectionalSimilarity(nn.Module):\n    def __init__(self, tokenizer, text_encoder, image_processor, image_encoder):\n        super().__init__()\n        self.tokenizer = tokenizer\n        self.text_encoder = text_encoder\n        self.image_processor = image_processor\n        self.image_encoder = image_encoder\n\n    def preprocess_image(self, image):\n        image = self.image_processor(image, return_tensors=\"pt\")[\"pixel_values\"]\n        return {\"pixel_values\": image.to(device)}\n\n    def tokenize_text(self, text):\n        inputs = self.tokenizer(\n            text,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        return {\"input_ids\": inputs.input_ids.to(device)}\n\n    def encode_image(self, image):\n        preprocessed_image = self.preprocess_image(image)\n        image_features = self.image_encoder(**preprocessed_image).image_embeds\n        image_features = image_features / image_features.norm(dim=1, keepdim=True)\n        return image_features\n\n    def encode_text(self, text):\n        tokenized_text = self.tokenize_text(text)\n        text_features = self.text_encoder(**tokenized_text).text_embeds\n        text_features = 
text_features / text_features.norm(dim=1, keepdim=True)\n        return text_features\n\n    def compute_directional_similarity(self, img_feat_one, img_feat_two, text_feat_one, text_feat_two):\n        sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one)\n        return sim_direction\n\n    def forward(self, image_one, image_two, caption_one, caption_two):\n        img_feat_one = self.encode_image(image_one)\n        img_feat_two = self.encode_image(image_two)\n        text_feat_one = self.encode_text(caption_one)\n        text_feat_two = self.encode_text(caption_two)\n        directional_similarity = self.compute_directional_similarity(\n            img_feat_one, img_feat_two, text_feat_one, text_feat_two\n        )\n        return directional_similarity\n```\n\n이제 `DirectionalSimilarity`를 사용해 보겠습니다.\n\n```python\ndir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder)\nscores = []\n\nfor i in range(len(input_images)):\n    original_image = input_images[i]\n    original_caption = original_captions[i]\n    edited_image = edited_images[i]\n    modified_caption = modified_captions[i]\n\n    similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption)\n    scores.append(float(similarity_score.detach().cpu()))\n\nprint(f\"CLIP directional similarity: {np.mean(scores)}\")\n# CLIP directional similarity: 0.0797976553440094\n```\n\nCLIP 점수와 마찬가지로, CLIP 방향 유사성이 높을수록 좋습니다.\n\n`StableDiffusionInstructPix2PixPipeline`은 `image_guidance_scale`과 `guidance_scale`이라는 두 가지 인자를 노출시킵니다. 이 두 인자를 조정하여 최종 편집된 이미지의 품질을 제어할 수 있습니다. 이 두 인자의 영향을 실험해보고 방향 유사성에 미치는 영향을 확인해보기를 권장합니다.\n\n이러한 메트릭의 개념을 확장하여 원본 이미지와 편집된 버전의 유사성을 측정할 수 있습니다. 이를 위해 `F.cosine_similarity(img_feat_two, img_feat_one)`을 사용할 수 있습니다. 이러한 종류의 편집에서는 이미지의 주요 의미가 최대한 보존되어야 합니다. 
즉, 높은 유사성 점수를 얻어야 합니다.\n\n[`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline)와 같은 유사한 파이프라인에도 이러한 메트릭을 사용할 수 있습니다.\n\n> [!TIP]\n> CLIP 점수와 CLIP 방향 유사성 모두 CLIP 모델에 의존하기 때문에 평가가 편향될 수 있습니다.\n\n***IS, FID (나중에 설명할 예정), 또는 KID와 같은 메트릭을 확장하는 것은 어려울 수 있습니다***. 평가 중인 모델이 대규모 이미지 캡셔닝 데이터셋 (예: [LAION-5B 데이터셋](https://laion.ai/blog/laion-5b/))에서 사전 훈련되었을 때 이는 문제가 될 수 있습니다. 왜냐하면 이러한 메트릭의 기반에는 중간 이미지 특징을 추출하기 위해 ImageNet-1k 데이터셋에서 사전 훈련된 InceptionNet이 사용되기 때문입니다. Stable Diffusion의 사전 훈련 데이터셋은 InceptionNet의 사전 훈련 데이터셋과 겹치는 부분이 제한적일 수 있으므로, InceptionNet은 여기서 특징 추출에 사용하기에 좋은 후보가 아닙니다.\n\n***위의 메트릭을 사용하면 클래스 조건이 있는 모델을 평가할 수 있습니다. 예를 들어 [DiT](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit)는 ImageNet-1k 클래스에 조건을 걸고 사전 훈련되었습니다.***\n\n### 클래스 조건화 이미지 생성[[class-conditioned-image-generation]]\n\n클래스 조건화 생성 모델은 일반적으로 [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k)와 같은 클래스 레이블이 지정된 데이터셋에서 사전 훈련됩니다. 이러한 모델을 평가하는 인기있는 지표에는 Fréchet Inception Distance (FID), Kernel Inception Distance (KID) 및 Inception Score (IS)가 있습니다. 이 문서에서는 FID ([Heusel et al.](https://huggingface.co/papers/1706.08500))에 초점을 맞추고 있습니다. [`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit)을 사용하여 FID를 계산하는 방법을 보여줍니다. 이는 내부적으로 [DiT 모델](https://huggingface.co/papers/2212.09748)을 사용합니다.\n\nFID는 두 개의 이미지 데이터셋이 얼마나 유사한지를 측정하는 것을 목표로 합니다. [이 자료](https://mmgeneration.readthedocs.io/en/latest/quick_run.html#fid)에 따르면:\n\n> Fréchet Inception Distance는 두 개의 이미지 데이터셋 간의 유사성을 측정하는 지표입니다. 시각적 품질에 대한 인간 판단과 잘 상관되는 것으로 나타났으며, 주로 생성적 적대 신경망의 샘플 품질을 평가하는 데 사용됩니다. FID는 Inception 네트워크의 특징 표현에 피팅한 두 개의 가우시안 분포 사이의 Fréchet 거리를 계산하여 구합니다.\n\n이 두 개의 데이터셋은 실제 이미지 데이터셋과 가짜 이미지 데이터셋(우리의 경우 생성된 이미지)입니다. FID는 일반적으로 두 개의 큰 데이터셋으로 계산됩니다. 
그러나 이 문서에서는 두 개의 미니 데이터셋으로 작업할 것입니다.\n\n먼저 ImageNet-1k 훈련 세트에서 몇 개의 이미지를 다운로드해 봅시다:\n\n```python\nfrom zipfile import ZipFile\nimport requests\n\n\ndef download(url, local_filepath):\n    r = requests.get(url)\n    with open(local_filepath, \"wb\") as f:\n        f.write(r.content)\n    return local_filepath\n\ndummy_dataset_url = \"https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip\"\nlocal_filepath = download(dummy_dataset_url, dummy_dataset_url.split(\"/\")[-1])\n\nwith ZipFile(local_filepath, \"r\") as zipper:\n    zipper.extractall(\".\")\n```\n\n```python\nfrom PIL import Image\nimport os\n\ndataset_path = \"sample-imagenet-images\"\nimage_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])\n\nreal_images = [np.array(Image.open(path).convert(\"RGB\")) for path in image_paths]\n```\n\n다음은 ImageNet-1k classes의 이미지 10개입니다 : \"cassette_player\", \"chain_saw\" (x2), \"church\", \"gas_pump\" (x3), \"parachute\" (x2), 그리고 \"tench\".\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/real-images.png\" alt=\"real-images\"><br>\n    <em>Real images.</em>\n</p>\n\n이제 이미지가 로드되었으므로 이미지에 가벼운 전처리를 적용하여 FID 계산에 사용해 보겠습니다.\n\n```python\nfrom torchvision.transforms import functional as F\n\n\ndef preprocess_image(image):\n    image = torch.tensor(image).unsqueeze(0)\n    image = image.permute(0, 3, 1, 2) / 255.0\n    return F.center_crop(image, (256, 256))\n\nreal_images = torch.cat([preprocess_image(image) for image in real_images])\nprint(real_images.shape)\n# torch.Size([10, 3, 256, 256])\n```\n\n이제 위에서 언급한 클래스에 따라 조건화 된 이미지를 생성하기 위해 [`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit)를 로드합니다.\n\n```python\nfrom diffusers import DiTPipeline, DPMSolverMultistepScheduler\n\ndit_pipeline = DiTPipeline.from_pretrained(\"facebook/DiT-XL-2-256\", torch_dtype=torch.float16)\ndit_pipeline.scheduler = 
DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)\ndit_pipeline = dit_pipeline.to(\"cuda\")\n\nwords = [\n    \"cassette player\",\n    \"chainsaw\",\n    \"chainsaw\",\n    \"church\",\n    \"gas pump\",\n    \"gas pump\",\n    \"gas pump\",\n    \"parachute\",\n    \"parachute\",\n    \"tench\",\n]\n\nclass_ids = dit_pipeline.get_label_ids(words)\noutput = dit_pipeline(class_labels=class_ids, generator=generator, output_type=\"np\")\n\nfake_images = output.images\nfake_images = torch.tensor(fake_images)\nfake_images = fake_images.permute(0, 3, 1, 2)\nprint(fake_images.shape)\n# torch.Size([10, 3, 256, 256])\n```\n\n이제 [`torchmetrics`](https://torchmetrics.readthedocs.io/)를 사용하여 FID를 계산할 수 있습니다.\n\n```python\nfrom torchmetrics.image.fid import FrechetInceptionDistance\n\nfid = FrechetInceptionDistance(normalize=True)\nfid.update(real_images, real=True)\nfid.update(fake_images, real=False)\n\nprint(f\"FID: {float(fid.compute())}\")\n# FID: 177.7147216796875\n```\n\nFID는 낮을수록 좋습니다. 여러 가지 요소가 FID에 영향을 줄 수 있습니다:\n\n- 이미지의 수 (실제 이미지와 가짜 이미지 모두)\n- diffusion 과정에서 발생하는 무작위성\n- diffusion 과정에서의 추론 단계 수\n- diffusion 과정에서 사용되는 스케줄러\n\n마지막 두 가지 요소에 대해서는, 다른 시드와 추론 단계에서 평가를 실행하고 평균 결과를 보고하는 것은 좋은 실천 방법입니다\n\n> [!WARNING]\n> FID 결과는 많은 요소에 의존하기 때문에 취약할 수 있습니다:\n>\n> * 계산 중 사용되는 특정 Inception 모델.\n> * 계산의 구현 정확도.\n> * 이미지 형식 (PNG 또는 JPG에서 시작하는 경우가 다릅니다).\n>\n> 이러한 사항을 염두에 두면, FID는 유사한 실행을 비교할 때 가장 유용하지만, 저자가 FID 측정 코드를 주의 깊게 공개하지 않는 한 논문 결과를 재현하기는 어렵습니다.\n>\n> 이러한 사항은 KID 및 IS와 같은 다른 관련 메트릭에도 적용됩니다.\n\n마지막 단계로, `fake_images`를 시각적으로 검사해 봅시다.\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/fake-images.png\" alt=\"fake-images\"><br>\n    <em>Fake images.</em>\n</p>"
  },
  {
    "path": "docs/source/ko/conceptual/philosophy.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 철학 [[philosophy]]\n\n🧨 Diffusers는 다양한 모달리티에서 **최신의** 사전 훈련된 diffusion 모델을 제공합니다.\n그 목적은 추론과 훈련을 위한 **모듈식 툴박스**로 사용되는 것입니다.\n\n저희는 시간이 지나도 변치 않는 라이브러리를 구축하는 것을 목표로 하기에 API 설계를 매우 중요하게 생각합니다.\n\n간단히 말해서, Diffusers는 PyTorch를 자연스럽게 확장할 수 있도록 만들어졌습니다. 따라서 대부분의 설계 선택은 [PyTorch의 설계 원칙](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy)에 기반합니다. 이제 가장 중요한 것들을 살펴보겠습니다:\n\n## 성능보다는 사용성을 [[usability-over-performance]]\n\n- Diffusers는 다양한 성능 향상 기능이 내장되어 있지만 (자세한 내용은 [메모리와 속도](https://huggingface.co/docs/diffusers/optimization/fp16) 참조), 모델은 항상 가장 높은 정밀도와 최소한의 최적화로 로드됩니다. 따라서 사용자가 별도로 정의하지 않는 한 기본적으로 diffusion 파이프라인은 항상 float32 정밀도로 CPU에 인스턴스화됩니다. 이는 다양한 플랫폼과 가속기에서의 사용성을 보장하며, 라이브러리를 실행하기 위해 복잡한 설치가 필요하지 않다는 것을 의미합니다.\n- Diffusers는 **가벼운** 패키지를 지향하기 때문에 필수 종속성은 거의 없지만 성능을 향상시킬 수 있는 많은 선택적 종속성이 있습니다 (`accelerate`, `safetensors`, `onnx` 등). 저희는 라이브러리를 가능한 한 가볍게 유지하여 다른 패키지에 대한 종속성 걱정이 없도록 노력하고 있습니다.\n- Diffusers는 간결하고 이해하기 쉬운 코드를 선호합니다. 이는 람다 함수나 고급 PyTorch 연산자와 같은 압축된 코드 구문을 자주 사용하지 않는 것을 의미합니다.\n\n## 쉬움보다는 간단함을 [[simple-over-easy]]\n\nPyTorch에서는 **명시적인 것이 암시적인 것보다 낫다**와 **단순한 것이 복잡한 것보다 낫다**라고 말합니다. 
이 설계 철학은 라이브러리의 여러 부분에 반영되어 있습니다:\n- [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to)와 같은 메소드를 사용하여 사용자가 장치 관리를 할 수 있도록 PyTorch의 API를 따릅니다.\n- 잘못된 입력을 조용히 수정하는 대신 간결한 오류 메시지를 발생시키는 것이 우선입니다. Diffusers는 라이브러리를 가능한 한 쉽게 사용할 수 있도록 하는 것보다 사용자를 가르치는 것을 목표로 합니다.\n- 복잡한 모델과 스케줄러 로직이 내부에서 마법처럼 처리하는 대신 노출됩니다. 스케줄러/샘플러는 서로에게 최소한의 종속성을 가지고 분리되어 있습니다. 이로써 사용자는 언롤된 노이즈 제거 루프를 작성해야 합니다. 그러나 이 분리는 디버깅을 더 쉽게하고 노이즈 제거 과정을 조정하거나 diffusers 모델이나 스케줄러를 교체하는 데 사용자에게 더 많은 제어권을 제공합니다.\n- diffusers 파이프라인의 따로 훈련된 구성 요소인 text encoder, unet 및 variational autoencoder는 각각 자체 모델 클래스를 갖습니다. 이로써 사용자는 서로 다른 모델의 구성 요소 간의 상호 작용을 처리해야 하며, 직렬화 형식은 모델 구성 요소를 다른 파일로 분리합니다. 그러나 이는 디버깅과 커스터마이징을 더 쉽게합니다. DreamBooth나 Textual Inversion 훈련은 Diffusers의 'diffusion 파이프라인의 단일 구성 요소들을 분리할 수 있는 능력' 덕분에 매우 간단합니다.\n\n## 추상화보다는 수정 가능하고 기여하기 쉬움을 [[tweakable-contributor-friendly-over-abstraction]]\n\n라이브러리의 대부분에 대해 Diffusers는 [Transformers 라이브러리](https://github.com/huggingface/transformers)의 중요한 설계 원칙을 채택합니다, 바로 성급한 추상화보다는 copy-pasted 코드를 선호한다는 것입니다. 이 설계 원칙은 [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself)와 같은 인기 있는 설계 원칙과는 대조적으로 매우 의견이 분분한데요.\n간단히 말해서, Transformers가 모델링 파일에 대해 수행하는 것처럼, Diffusers는 매우 낮은 수준의 추상화와 매우 독립적인 코드를 유지하는 것을 선호합니다. 함수, 긴 코드 블록, 심지어 클래스도 여러 파일에 복사할 수 있으며, 이는 처음에는 라이브러리를 유지할 수 없게 만드는 나쁜, 서투른 설계 선택으로 보일 수 있습니다. 하지만 이러한 설계는 매우 성공적이며, 커뮤니티 기반의 오픈 소스 기계 학습 라이브러리에 매우 적합합니다. 그 이유는 다음과 같습니다:\n- 기계 학습은 패러다임, 모델 아키텍처 및 알고리즘이 빠르게 변화하는 매우 빠르게 움직이는 분야이기 때문에 오랜 기간 지속되는 코드 추상화를 정의하기가 매우 어렵습니다.\n- 기계 학습 전문가들은 아이디어와 연구를 위해 기존 코드를 빠르게 조정할 수 있어야 하므로, 많은 추상화보다는 독립적인 코드를 선호합니다.\n- 오픈 소스 라이브러리는 커뮤니티 기여에 의존하므로, 기여하기 쉬운 라이브러리를 구축해야 합니다. 코드가 추상화되면 의존성이 많아지고 읽기 어렵고 기여하기 어려워집니다. 기여자들은 중요한 기능을 망가뜨릴까 두려워하여 매우 추상화된 라이브러리에 기여하지 않게 됩니다. 
라이브러리에 기여하는 것이 다른 기본 코드를 망가뜨릴 수 없다면, 잠재적인 새로운 기여자에게 더욱 환영받을 수 있을 뿐만 아니라 여러 부분에 대해 병렬적으로 검토하고 기여하기가 더 쉬워집니다.\n\nHugging Face에서는 이 설계를 **단일 파일 정책**이라고 부르며, 특정 클래스의 대부분의 코드가 단일하고 독립적인 파일에 작성되어야 한다는 의미입니다. 철학에 대해 자세히 알아보려면 [이 블로그 글](https://huggingface.co/blog/transformers-design-philosophy)을 참조할 수 있습니다.\n\nDiffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두 따르지만, diffusion 모델에 대해서는 일부만 따릅니다. 일부만 따르는 이유는 [DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip) 및 [Imagen](https://imagen.research.google/) 등 거의 모든 diffusion 파이프라인이 동일한 diffusion 모델인 [UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond)에 의존하기 때문입니다.\n\n좋아요, 이제 🧨 Diffusers가 설계된 방식을 대략적으로 이해했을 것입니다 🤗.\n우리는 이러한 설계 원칙을 일관되게 라이브러리 전체에 적용하려고 노력하고 있습니다. 그럼에도 불구하고 철학에 대한 일부 예외 사항이나 아쉬운 설계 선택이 있을 수 있습니다. 디자인에 대한 피드백이 있다면 [GitHub에서 직접](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) 알려주시면 감사하겠습니다.\n\n## 디자인 철학 자세히 알아보기 [[design-philosophy-in-details]]\n\n이제 디자인 철학의 세부 사항을 좀 더 자세히 살펴보겠습니다. Diffusers는 주로 세 가지 주요 클래스로 구성됩니다: [파이프라인](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [모델](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), 그리고 [스케줄러](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers). 각 클래스에 대한 더 자세한 설계 결정 사항을 살펴보겠습니다.\n\n### 파이프라인 [[pipelines]]\n\n파이프라인은 사용하기 쉽도록 설계되었으며 (따라서 [*쉬움보다는 간단함을*](#simple-over-easy)을 100% 따르지는 않음), feature-complete하지 않으며, 추론을 위한 [모델](#models)과 [스케줄러](#schedulers)를 사용하는 방법의 예시로 간주될 수 있습니다.\n\n다음과 같은 설계 원칙을 따릅니다:\n- 파이프라인은 단일 파일 정책을 따릅니다. 모든 파이프라인은 src/diffusers/pipelines의 개별 디렉토리에 있습니다. 하나의 파이프라인 폴더는 하나의 diffusion 논문/프로젝트/릴리스에 해당합니다. 여러 파이프라인 파일은 하나의 파이프라인 폴더에 모을 수 있습니다. 
예를 들어 [`src/diffusers/pipelines/stable_diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion)에서 그렇게 하고 있습니다. 파이프라인이 유사한 기능을 공유하는 경우, [`# Copied from` 메커니즘](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251)을 사용할 수 있습니다.\n- 파이프라인은 모두 [`DiffusionPipeline`]을 상속합니다.\n- 각 파이프라인은 서로 다른 모델 및 스케줄러 구성 요소로 구성되어 있으며, 이는 [`model_index.json` 파일](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json)에 문서화되어 있으며, 파이프라인의 속성 이름과 동일한 이름으로 액세스할 수 있으며, [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) 함수를 통해 파이프라인 간에 공유할 수 있습니다.\n- 각 파이프라인은 [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) 함수를 통해 로드할 수 있어야 합니다.\n- 파이프라인은 추론에**만** 사용되어야 합니다.\n- 파이프라인은 매우 가독성이 좋고, 이해하기 쉽고, 쉽게 조정할 수 있도록 설계되어야 합니다.\n- 파이프라인은 서로를 기반으로 구축할 수 있고, 상위 수준 API에 쉽게 통합할 수 있도록 설계되어야 합니다.\n- 파이프라인은 feature-complete한 사용자 인터페이스를 목표로 하지 **않습니다**. feature-complete한 사용자 인터페이스를 원한다면 [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), [lama-cleaner](https://github.com/Sanster/lama-cleaner)를 참조해야 합니다.\n- 모든 파이프라인은 오로지 `__call__` 메소드를 통해 실행할 수 있어야 합니다. `__call__` 인자의 이름은 모든 파이프라인에서 공유되어야 합니다.\n- 파이프라인은 해결하고자 하는 작업의 이름으로 지정되어야 합니다.\n- 대부분의 경우에 새로운 diffusion 파이프라인은 새로운 파이프라인 폴더/파일에 구현되어야 합니다.\n\n### 모델 [[models]]\n\n모델은 [PyTorch의 Module 클래스](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)의 자연스러운 확장이 되도록, 구성 가능한 툴박스로 설계되었습니다. 그리고 모델은 **단일 파일 정책**을 일부만 따릅니다.\n\n다음과 같은 설계 원칙을 따릅니다:\n- 모델은 **모델 아키텍처 유형**에 해당합니다. 
예를 들어 [`UNet2DConditionModel`] 클래스는 2D 이미지 입력을 기대하고 일부 context에 의존하는 모든 UNet 변형들에 사용됩니다.\n- 모든 모델은 [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)에서 찾을 수 있으며, 각 모델 아키텍처는 해당 파일에 정의되어야 합니다. 예를 들어 [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py) 등이 있습니다.\n- 모델은 **단일 파일 정책**을 따르지 않으며, [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py) 등과 같은 작은 모델 구성 요소를 사용해야 합니다. **참고**: 이는 Transformers의 모델링 파일과는 대조적으로 모델이 실제로 단일 파일 정책을 따르지 않음을 보여줍니다.\n- 모델은 PyTorch의 `Module` 클래스와 마찬가지로 복잡성을 노출하고 명확한 오류 메시지를 제공해야 합니다.\n- 모든 모델은 `ModelMixin`과 `ConfigMixin`을 상속합니다.\n- 모델은 주요 코드 변경이 필요하지 않고, 역호환성을 유지하며, 메모리 또는 컴퓨팅과 관련한 중요한 이득을 제공할 때 성능을 위해 최적화할 수 있습니다.\n- 모델은 기본적으로 가장 높은 정밀도와 가장 낮은 성능 설정을 가져야 합니다.\n- Diffusers에 이미 있는 모델 아키텍처로 분류할 수 있는 새로운 모델 체크포인트를 통합할 때는 기존 모델 아키텍처를 새로운 체크포인트와 호환되도록 수정해야 합니다. 새로운 파일을 만들어야 하는 경우는 모델 아키텍처가 근본적으로 다른 경우에만 해당합니다.\n- 모델은 미래의 변경 사항을 쉽게 확장할 수 있도록 설계되어야 합니다. 이는 공개 함수 인수들과 구성 인수들을 제한하고,미래의 변경 사항을 \"예상\"하는 것을 통해 달성할 수 있습니다. 예를 들어, 불리언 `is_..._type` 인수보다는 새로운 미래 유형에 쉽게 확장할 수 있는 문자열 \"...type\" 인수를 추가하는 것이 일반적으로 더 좋습니다. 새로운 모델 체크포인트가 작동하도록 하기 위해 기존 아키텍처에 최소한의 변경만을 가해야 합니다.\n- 모델 디자인은 코드의 가독성과 간결성을 유지하는 것과 많은 모델 체크포인트를 지원하는 것 사이의 어려운 균형 조절입니다. 
모델링 코드의 대부분은 새로운 모델 체크포인트를 위해 클래스를 수정하는 것이 좋지만, [UNet 블록](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) 및 [Attention 프로세서](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)와 같이 코드를 장기적으로 간결하고 읽기 쉽게 유지하기 위해 새로운 클래스를 추가하는 예외도 있습니다.\n\n### 스케줄러 [[schedulers]]\n\n스케줄러는 추론을 위한 노이즈 제거 과정을 안내하고 훈련을 위한 노이즈 스케줄을 정의하는 역할을 합니다. 스케줄러는 개별 클래스로 설계되어 있으며, 로드 가능한 구성 파일과 **단일 파일 정책**을 엄격히 따릅니다.\n\n다음과 같은 설계 원칙을 따릅니다:\n- 모든 스케줄러는 [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)에서 찾을 수 있습니다.\n- 스케줄러는 큰 유틸리티 파일에서 가져오지 **않아야** 하며, 자체 포함성을 유지해야 합니다.\n- 하나의 스케줄러 Python 파일은 하나의 스케줄러 알고리즘(논문에서 정의된 것과 같은)에 해당합니다.\n- 스케줄러가 유사한 기능을 공유하는 경우, `# Copied from` 메커니즘을 사용할 수 있습니다.\n- 모든 스케줄러는 `SchedulerMixin`과 `ConfigMixin`을 상속합니다.\n- [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) 메소드를 사용하여 스케줄러를 쉽게 교체할 수 있습니다. 자세한 내용은 [여기](../using-diffusers/schedulers.md)에서 설명합니다.\n- 모든 스케줄러는 `set_num_inference_steps`와 `step` 함수를 가져야 합니다. `set_num_inference_steps(...)`는 각 노이즈 제거 과정(즉, `step(...)`이 호출되기 전) 이전에 호출되어야 합니다.\n- 각 스케줄러는 모델이 호출될 타임스텝의 배열인 `timesteps` 속성을 통해 루프를 돌 수 있는 타임스텝을 노출합니다.\n- `step(...)` 함수는 예측된 모델 출력과 \"현재\" 샘플(x_t)을 입력으로 받고, \"이전\" 약간 더 노이즈가 제거된 샘플(x_t-1)을 반환합니다.\n- 노이즈 제거 스케줄러의 복잡성을 고려하여, `step` 함수는 모든 복잡성을 노출하지 않으며, \"블랙 박스\"일 수 있습니다.\n- 거의 모든 경우에 새로운 스케줄러는 새로운 스케줄링 파일에 구현되어야 합니다.\n"
  },
  {
    "path": "docs/source/ko/in_translation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 번역중\n\n열심히 번역을 진행중입니다. 조금만 기다려주세요.\n감사합니다!"
  },
  {
    "path": "docs/source/ko/index.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://raw.githubusercontent.com/huggingface/diffusers/77aadfee6a891ab9fcfb780f87c693f7a5beeb8e/docs/source/imgs/diffusers_library.jpg\" width=\"400\"/>\n    <br>\n</p>\n\n\n# Diffusers\n\n🤗 Diffusers는 이미지, 오디오, 심지어 분자의 3D 구조를 생성하기 위한 최첨단 사전 훈련된 diffusion 모델을 위한 라이브러리입니다. 간단한 추론 솔루션을 찾고 있든, 자체 diffusion 모델을 훈련하고 싶든, 🤗 Diffusers는 두 가지 모두를 지원하는 모듈식 툴박스입니다. 
저희 라이브러리는 [성능보다 사용성](conceptual/philosophy#usability-over-performance), [간편함보다 단순함](conceptual/philosophy#simple-over-easy), 그리고 [추상화보다 사용자 지정 가능성](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction)에 중점을 두고 설계되었습니다.\n\n이 라이브러리에는 세 가지 주요 구성 요소가 있습니다:\n\n- 몇 줄의 코드만으로 추론할 수 있는 최첨단 [diffusion 파이프라인](api/pipelines/overview).\n- 생성 속도와 품질 간의 균형을 맞추기 위해 상호교환적으로 사용할 수 있는 [노이즈 스케줄러](api/schedulers/overview).\n- 빌딩 블록으로 사용할 수 있고 스케줄러와 결합하여 자체적인 end-to-end diffusion 시스템을 만들 수 있는 사전 학습된 [모델](api/models).\n\n<div class=\"mt-10\">\n  <div class=\"w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5\">\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./tutorials/tutorial_overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Tutorials</div>\n      <p class=\"text-gray-700\">결과물을 생성하고, 나만의 diffusion 시스템을 구축하고, 확산 모델을 훈련하는 데 필요한 기본 기술을 배워보세요. 🤗 Diffusers를 처음 사용하는 경우 여기에서 시작하는 것이 좋습니다!</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./using-diffusers/loading_overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-indigo-400 to-indigo-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">How-to guides</div>\n      <p class=\"text-gray-700\">파이프라인, 모델, 스케줄러를 로드하는 데 도움이 되는 실용적인 가이드입니다. 
또한 특정 작업에 파이프라인을 사용하고, 출력 생성 방식을 제어하고, 추론 속도에 맞게 최적화하고, 다양한 학습 기법을 사용하는 방법도 배울 수 있습니다.</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./conceptual/philosophy\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-pink-400 to-pink-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Conceptual guides</div>\n      <p class=\"text-gray-700\">라이브러리가 왜 이런 방식으로 설계되었는지 이해하고, 라이브러리 이용에 대한 윤리적 가이드라인과 안전 구현에 대해 자세히 알아보세요.</p>\n   </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./api/models\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Reference</div>\n      <p class=\"text-gray-700\">🤗 Diffusers 클래스 및 메서드의 작동 방식에 대한 기술 설명.</p>\n    </a>\n  </div>\n</div>"
  },
  {
    "path": "docs/source/ko/installation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 설치\n\n사용하시는 라이브러리에 맞는 🤗 Diffusers를 설치하세요.\n\n🤗 Diffusers는 Python 3.8+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요.\n\n- [PyTorch 설치 안내](https://pytorch.org/get-started/locally/)\n- [Flax 설치 안내](https://flax.readthedocs.io/en/latest/)\n\n## pip를 이용한 설치\n\n[가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Diffusers를 설치해야 합니다.\nPython 가상 환경에 익숙하지 않은 경우 [가상환경 pip 설치 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 살펴보세요.\n가상 환경을 사용하면 서로 다른 프로젝트를 더 쉽게 관리하고, 종속성간의 호환성 문제를 피할 수 있습니다.\n\n프로젝트 디렉토리에 가상 환경을 생성하는 것으로 시작하세요:\n\n```bash\npython -m venv .env\n```\n\n그리고 가상 환경을 활성화합니다:\n\n```bash\nsource .env/bin/activate\n```\n\n이제 다음의 명령어로 🤗 Diffusers를 설치할 준비가 되었습니다:\n\n**PyTorch의 경우**\n\n```bash\npip install diffusers[\"torch\"]\n```\n\n**Flax의 경우**\n\n```bash\npip install diffusers[\"flax\"]\n```\n\n## 소스로부터 설치\n\n소스에서 `diffusers`를 설치하기 전에, `torch` 및 `accelerate`이 설치되어 있는지 확인하세요.\n\n`torch` 설치에 대해서는 [torch docs](https://pytorch.org/get-started/locally/#start-locally)를 참고하세요.\n\n다음과 같이 `accelerate`을 설치하세요.\n\n```bash\npip install accelerate\n```\n\n다음 명령어를 사용하여 소스에서 🤗 Diffusers를 설치하세요:\n\n```bash\npip install git+https://github.com/huggingface/diffusers\n```\n\n이 명령어는 최신 `stable` 버전이 아닌 최첨단 `main` 버전을 설치합니다.\n`main` 버전은 최신 개발 정보를 최신 상태로 유지하는 데 유용합니다.\n예를 들어 마지막 공식 릴리즈 이후 버그가 수정되었지만, 새 릴리즈가 아직 
출시되지 않은 경우입니다.\n그러나 이는 `main` 버전이 항상 안정적이지 않을 수 있음을 의미합니다.\n우리는 `main` 버전이 지속적으로 작동하도록 노력하고 있으며, 대부분의 문제는 보통 몇 시간 또는 하루 안에 해결됩니다.\n문제가 발생하면 더 빨리 해결할 수 있도록 [Issue](https://github.com/huggingface/diffusers/issues)를 열어주세요!\n\n\n## 편집 가능한 설치\n\n다음을 수행하려면 편집 가능한 설치가 필요합니다:\n\n* 소스 코드의 `main` 버전을 사용\n* 🤗 Diffusers에 기여 (코드의 변경 사항을 테스트하기 위해 필요)\n\n저장소를 복제하고 다음 명령어를 사용하여 🤗 Diffusers를 설치합니다:\n\n```bash\ngit clone https://github.com/huggingface/diffusers.git\ncd diffusers\n```\n\n**PyTorch의 경우**\n\n```sh\npip install -e \".[torch]\"\n```\n\n**Flax의 경우**\n\n```sh\npip install -e \".[flax]\"\n```\n\n이러한 명령어들은 저장소를 복제한 폴더와 Python 라이브러리 경로를 연결합니다.\nPython은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다.\n예를 들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.10/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.\n\n> [!WARNING]\n> 라이브러리를 계속 사용하려면 `diffusers` 폴더를 유지해야 합니다.\n\n이제 다음 명령어를 사용하여 최신 버전의 🤗 Diffusers로 쉽게 업데이트할 수 있습니다:\n\n```bash\ncd ~/diffusers/\ngit pull\n```\n\n이렇게 하면, 다음에 실행할 때 Python 환경이 🤗 Diffusers의 `main` 버전을 찾게 됩니다.\n\n## 텔레메트리 로깅에 대한 알림\n\n우리 라이브러리는 `from_pretrained()` 요청 중에 텔레메트리 정보를 원격으로 수집합니다.\n이 데이터에는 Diffusers 및 PyTorch/Flax의 버전, 요청된 모델 또는 파이프라인 클래스, 그리고 허브에서 호스팅되는 경우 사전학습된 체크포인트에 대한 경로가 포함됩니다.\n이 사용 데이터는 문제를 디버깅하고 새로운 기능의 우선순위를 지정하는 데 도움이 됩니다.\n텔레메트리는 HuggingFace 허브에서 모델과 파이프라인을 불러올 때만 전송되며, 로컬 사용 중에는 수집되지 않습니다.\n\n우리는 추가 정보를 공유하지 않기를 원하는 사람이 있다는 것을 이해하고 개인 정보를 존중하므로, 터미널에서 `DISABLE_TELEMETRY` 환경 변수를 설정하여 텔레메트리 수집을 비활성화할 수 있습니다.\n\nLinux/macOS에서:\n```bash\nexport DISABLE_TELEMETRY=YES\n```\n\nWindows에서:\n```bash\nset DISABLE_TELEMETRY=YES\n```"
  },
  {
    "path": "docs/source/ko/optimization/coreml.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Core ML로 Stable Diffusion을 실행하는 방법\n\n[Core ML](https://developer.apple.com/documentation/coreml)은 Apple 프레임워크에서 지원하는 모델 형식 및 머신 러닝 라이브러리입니다. macOS 또는 iOS/iPadOS 앱 내에서 Stable Diffusion 모델을 실행하는 데 관심이 있는 경우, 이 가이드에서는 기존 PyTorch 체크포인트를 Core ML 형식으로 변환하고 이를 Python 또는 Swift로 추론에 사용하는 방법을 설명합니다.\n\nCore ML 모델은 Apple 기기에서 사용할 수 있는 모든 컴퓨팅 엔진들, 즉 CPU, GPU, Apple Neural Engine(또는 Apple Silicon Mac 및 최신 iPhone/iPad에서 사용할 수 있는 텐서 최적화 가속기인 ANE)을 활용할 수 있습니다. 모델과 실행 중인 기기에 따라 Core ML은 컴퓨팅 엔진도 혼합하여 사용할 수 있으므로, 예를 들어 모델의 일부가 CPU에서 실행되는 반면 다른 부분은 GPU에서 실행될 수 있습니다.\n\n> [!TIP]\n> PyTorch에 내장된 `mps` 가속기를 사용하여 Apple Silicon Macs에서 `diffusers` Python 코드베이스를 실행할 수도 있습니다. 이 방법은 [mps 가이드]에 자세히 설명되어 있지만 네이티브 앱과 호환되지 않습니다.\n\n## Stable Diffusion Core ML 체크포인트\n\nStable Diffusion 가중치(또는 체크포인트)는 PyTorch 형식으로 저장되기 때문에 네이티브 앱에서 사용하기 위해서는 Core ML 형식으로 변환해야 합니다.\n\n다행히도 Apple 엔지니어들이 `diffusers`를 기반으로 한 [변환 툴](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml)을 개발하여 PyTorch 체크포인트를 Core ML로 변환할 수 있습니다.\n\n모델을 변환하기 전에 잠시 시간을 내어 Hugging Face Hub를 살펴보세요. 
관심 있는 모델이 이미 Core ML 형식으로 제공되고 있을 가능성이 높습니다:\n\n- [Apple](https://huggingface.co/apple) organization에는 Stable Diffusion 버전 1.4, 1.5, 2.0 base 및 2.1 base가 포함되어 있습니다.\n- [coreml](https://huggingface.co/coreml) organization에는 커스텀 DreamBooth가 적용되거나, 파인튜닝된 모델이 포함되어 있습니다.\n- 이 [필터](https://huggingface.co/models?pipeline_tag=text-to-image&library=coreml&p=2&sort=likes)를 사용하면 사용 가능한 모든 Core ML 체크포인트를 확인할 수 있습니다.\n\n원하는 모델을 찾을 수 없는 경우 Apple의 [모델을 Core ML로 변환하기](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) 지침을 따르는 것이 좋습니다.\n\n## 사용할 Core ML 변형(Variant) 선택하기\n\nStable Diffusion 모델은 다양한 목적에 따라 다른 Core ML 변형으로 변환할 수 있습니다:\n\n- 사용되는 어텐션 블록 유형. 어텐션 연산은 이미지 표현의 여러 영역 간의 관계에 '주의를 기울이고' 이미지와 텍스트 표현이 어떻게 연관되어 있는지 이해하는 데 사용됩니다. 어텐션 연산은 컴퓨팅 및 메모리 집약적이므로 다양한 장치의 하드웨어 특성을 고려한 다양한 구현이 존재합니다. Core ML Stable Diffusion 모델의 경우 두 가지 어텐션 변형이 있습니다:\n    * `split_einsum` ([Apple에서 도입](https://machinelearning.apple.com/research/neural-engine-transformers))은 최신 iPhone, iPad 및 M 시리즈 컴퓨터에서 사용할 수 있는 ANE 장치에 최적화되어 있습니다.\n    * \"원본\" 어텐션(`diffusers`에 사용되는 기본 구현)은 CPU/GPU와만 호환되며 ANE와는 호환되지 않습니다. \"원본\" 어텐션을 사용하여 CPU + GPU에서 모델을 실행하는 것이 ANE보다 *더* 빠를 수 있습니다. 자세한 내용은 [이 성능 벤치마크](https://huggingface.co/blog/fast-mac-diffusers#performance-benchmarks)와 커뮤니티에서 제공하는 일부 [추가 측정](https://github.com/huggingface/swift-coreml-diffusers/issues/31)을 참조하십시오.\n\n- 지원되는 추론 프레임워크\n    * `packages`는 Python 추론에 적합합니다. 네이티브 앱에 통합하기 전에 변환된 Core ML 모델을 테스트하거나, Core ML 성능을 알고 싶지만 네이티브 앱을 지원할 필요는 없는 경우에 사용할 수 있습니다. 예를 들어, 웹 UI가 있는 애플리케이션은 Python Core ML 백엔드를 완벽하게 사용할 수 있습니다.\n    * Swift 코드에는 `컴파일된` 모델이 필요합니다. Hub의 `컴파일된` 모델은 iOS 및 iPadOS 기기와의 호환성을 위해 큰 UNet 모델 가중치를 여러 파일로 분할합니다. 이는 [`--chunk-unet` 변환 옵션](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml)에 해당합니다. 
네이티브 앱을 지원하려면 `컴파일된` 변형을 선택해야 합니다.\n\n공식 Core ML Stable Diffusion [모델](https://huggingface.co/apple/coreml-stable-diffusion-v1-4/tree/main)에는 이러한 변형이 포함되어 있지만 커뮤니티 버전은 다를 수 있습니다:\n\n```\ncoreml-stable-diffusion-v1-4\n├── README.md\n├── original\n│   ├── compiled\n│   └── packages\n└── split_einsum\n    ├── compiled\n    └── packages\n```\n\n아래와 같이 필요한 변형을 다운로드하여 사용할 수 있습니다.\n\n## Python에서 Core ML 추론\n\nPython에서 Core ML 추론을 실행하려면 다음 라이브러리를 설치하세요:\n\n```bash\npip install huggingface_hub\npip install git+https://github.com/apple/ml-stable-diffusion\n```\n\n### 모델 체크포인트 다운로드하기\n\n`컴파일된` 버전은 Swift와만 호환되므로 Python에서 추론을 실행하려면 `packages` 폴더에 저장된 버전 중 하나를 사용하세요. `원본` 또는 `split_einsum` 어텐션 중 어느 것을 사용할지 선택할 수 있습니다.\n\n다음은 Hub에서 'models'라는 디렉토리로 'original' 어텐션 변형을 다운로드하는 방법입니다:\n\n```Python\nfrom huggingface_hub import snapshot_download\nfrom pathlib import Path\n\nrepo_id = \"apple/coreml-stable-diffusion-v1-4\"\nvariant = \"original/packages\"\n\nmodel_path = Path(\"./models\") / (repo_id.split(\"/\")[-1] + \"_\" + variant.replace(\"/\", \"_\"))\nsnapshot_download(repo_id, allow_patterns=f\"{variant}/*\", local_dir=model_path, local_dir_use_symlinks=False)\nprint(f\"Model downloaded at {model_path}\")\n```\n\n\n### 추론[[python-inference]]\n\n모델의 snapshot을 다운로드한 후에는 Apple의 Python 스크립트를 사용하여 테스트할 수 있습니다.\n\n```shell\npython -m python_coreml_stable_diffusion.pipeline --prompt \"a photo of an astronaut riding a horse on mars\" -i models/coreml-stable-diffusion-v1-4_original_packages -o </path/to/output/image> --compute-unit CPU_AND_GPU --seed 93\n```\n\n`<output-mlpackages-directory>`는 위 단계에서 다운로드한 체크포인트를 가리켜야 하며, `--compute-unit`은 추론을 허용할 하드웨어를 나타냅니다. 이는 다음 옵션 중 하나이어야 합니다: `ALL`, `CPU_AND_GPU`, `CPU_ONLY`, `CPU_AND_NE`. 선택적 출력 경로와 재현성을 위한 시드를 제공할 수도 있습니다.\n\n추론 스크립트에서는 Stable Diffusion 모델의 원래 버전인 `CompVis/stable-diffusion-v1-4`를 사용한다고 가정합니다. 다른 모델을 사용하는 경우 추론 명령줄에서 `--model-version` 옵션을 사용하여 해당 허브 ID를 *지정*해야 합니다. 
이는 이미 지원되는 모델과 사용자가 직접 학습하거나 파인튜닝한 사용자 지정 모델에 적용됩니다.\n\n예를 들어, [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)를 사용하려는 경우입니다:\n\n```shell\npython -m python_coreml_stable_diffusion.pipeline --prompt \"a photo of an astronaut riding a horse on mars\" --compute-unit ALL -o output --seed 93 -i models/coreml-stable-diffusion-v1-5_original_packages --model-version stable-diffusion-v1-5/stable-diffusion-v1-5\n```\n\n\n## Swift에서 Core ML 추론하기\n\nSwift에서 추론을 실행하는 것은 모델이 이미 `mlmodelc` 형식으로 컴파일되어 있기 때문에 Python보다 약간 빠릅니다. 이는 앱이 시작될 때 모델이 불러와지는 것이 눈에 띄지만, 이후 여러 번 실행하면 눈에 띄지 않을 것입니다.\n\n### 다운로드\n\nMac에서 Swift에서 추론을 실행하려면 `컴파일된` 체크포인트 버전 중 하나가 필요합니다. 이전 예제와 유사하지만 `컴파일된` 변형 중 하나를 사용하여 Python 코드를 로컬로 다운로드하는 것이 좋습니다:\n\n```Python\nfrom huggingface_hub import snapshot_download\nfrom pathlib import Path\n\nrepo_id = \"apple/coreml-stable-diffusion-v1-4\"\nvariant = \"original/compiled\"\n\nmodel_path = Path(\"./models\") / (repo_id.split(\"/\")[-1] + \"_\" + variant.replace(\"/\", \"_\"))\nsnapshot_download(repo_id, allow_patterns=f\"{variant}/*\", local_dir=model_path, local_dir_use_symlinks=False)\nprint(f\"Model downloaded at {model_path}\")\n```\n\n### 추론[[swift-inference]]\n\n추론을 실행하기 위해서, Apple의 리포지토리를 복제하세요:\n\n```bash\ngit clone https://github.com/apple/ml-stable-diffusion\ncd ml-stable-diffusion\n```\n\n그 다음 Apple의 명령어 도구인 [Swift 패키지 관리자](https://www.swift.org/package-manager/#)를 사용합니다:\n\n```bash\nswift run StableDiffusionSample --resource-path models/coreml-stable-diffusion-v1-4_original_compiled --compute-units all \"a photo of an astronaut riding a horse on mars\"\n```\n\n`--resource-path`에 이전 단계에서 다운로드한 체크포인트 중 하나를 지정해야 하므로 확장자가 `.mlmodelc`인 컴파일된 Core ML 번들이 포함되어 있는지 확인하시기 바랍니다. 
`--compute-units`는 다음 값 중 하나이어야 합니다: `all`, `cpuOnly`, `cpuAndGPU`, `cpuAndNeuralEngine`.\n\n자세한 내용은 [Apple의 리포지토리 안의 지침](https://github.com/apple/ml-stable-diffusion)을 참고하시기 바랍니다.\n\n\n## 지원되는 Diffusers 기능\n\nCore ML 모델과 추론 코드는 🧨 Diffusers의 많은 기능, 옵션 및 유연성을 지원하지 않습니다. 다음은 유의해야 할 몇 가지 제한 사항입니다:\n\n- Core ML 모델은 추론에만 적합합니다. 학습이나 파인튜닝에는 사용할 수 없습니다.\n- Swift에 포팅된 스케줄러는 Stable Diffusion에서 사용하는 기본 스케줄러와 `diffusers` 구현에서 Swift로 포팅한 `DPMSolverMultistepScheduler` 두 개뿐입니다. 이들 중 약 절반의 스텝으로 동일한 품질을 생성하는 `DPMSolverMultistepScheduler`를 사용하는 것이 좋습니다.\n- 추론 코드에서 네거티브 프롬프트, classifier-free guidance scale 및 image-to-image 작업을 사용할 수 있습니다. depth guidance, ControlNet, latent upscalers와 같은 고급 기능은 아직 사용할 수 없습니다.\n\nApple의 [변환 및 추론 리포지토리](https://github.com/apple/ml-stable-diffusion)와 자체 [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) 리포지토리는 다른 개발자들이 구축할 수 있는 기술적인 데모입니다.\n\n누락된 기능이 있다고 생각되면 언제든지 기능을 요청하거나, 더 좋은 방법은 기여 PR을 열어주세요. :)\n\n\n## 네이티브 Diffusers Swift 앱\n\n자체 Apple 하드웨어에서 Stable Diffusion을 실행하는 쉬운 방법 중 하나는 `diffusers`와 Apple의 변환 및 추론 리포지토리를 기반으로 하는 [자체 오픈 소스 Swift 리포지토리](https://github.com/huggingface/swift-coreml-diffusers)를 사용하는 것입니다. 코드를 공부하고 [Xcode](https://developer.apple.com/xcode/)로 컴파일하여 필요에 맞게 조정할 수 있습니다. 편의를 위해 앱스토어에 [독립형 Mac 앱](https://apps.apple.com/app/diffusers/id1666309574)도 있으므로 코드나 IDE를 다루지 않고도 사용할 수 있습니다. 개발자로서 Core ML이 Stable Diffusion 앱을 구축하는 데 가장 적합한 솔루션이라고 판단했다면, 이 가이드의 나머지 부분을 사용하여 프로젝트를 시작할 수 있습니다. 여러분이 무엇을 빌드할지 기대됩니다. :)"
  },
  {
    "path": "docs/source/ko/optimization/fp16.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 메모리와 속도\n\n메모리 또는 속도에 대해 🤗 Diffusers *추론*을 최적화하기 위한 몇 가지 기술과 아이디어를 제시합니다.\n일반적으로, memory-efficient attention을 위해 [xFormers](https://github.com/facebookresearch/xformers) 사용을 추천하기 때문에, 추천하는 [설치 방법](xformers)을 보고 설치해 보세요.\n\n다음 설정이 성능과 메모리에 미치는 영향에 대해 설명합니다.\n\n|                  | 지연시간  | 속도 향상 |\n| ---------------- | ------- | ------- |\n| 별도 설정 없음      | 9.50s   | x1      |\n| cuDNN auto-tuner | 9.37s   | x1.01   |\n| fp16             | 3.61s   | x2.63   |\n| Channels Last 메모리 형식     | 3.30s   | x2.88   |\n| traced UNet      | 3.21s   | x2.96   |\n| memory-efficient attention | 2.63s  | x3.61   |\n\n<em>\n   NVIDIA TITAN RTX에서 50 DDIM 스텝의 \"a photo of an astronaut riding a horse on mars\" 프롬프트로 512x512 크기의 단일 이미지를 생성하였습니다.\n</em>\n\n## cuDNN auto-tuner 활성화하기\n\n[NVIDIA cuDNN](https://developer.nvidia.com/cudnn)은 컨볼루션을 계산하는 많은 알고리즘을 지원합니다. 
Autotuner는 짧은 벤치마크를 실행하고 주어진 입력 크기에 대해 주어진 하드웨어에서 최고의 성능을 가진 커널을 선택합니다.\n\n**컨볼루션 네트워크**를 활용하고 있기 때문에 (다른 유형들은 현재 지원되지 않음), 다음 설정을 통해 추론 전에 cuDNN autotuner를 활성화할 수 있습니다:\n\n```python\nimport torch\n\ntorch.backends.cudnn.benchmark = True\n```\n\n### fp32 대신 tf32 사용하기  (Ampere 및 이후 CUDA 장치들에서)\n\nAmpere 및 이후 CUDA 장치에서 행렬곱 및 컨볼루션은 TensorFloat32(TF32) 모드를 사용하여 더 빠르지만 약간 덜 정확할 수 있습니다.\n기본적으로 PyTorch는 컨볼루션에 대해 TF32 모드를 활성화하지만 행렬 곱셈은 활성화하지 않습니다.\n네트워크에 완전한 float32 정밀도가 필요한 경우가 아니면 행렬 곱셈에 대해서도 이 설정을 활성화하는 것이 좋습니다.\n이는 일반적으로 무시할 수 있는 수치의 정확도 손실이 있지만, 계산 속도를 크게 높일 수 있습니다.\n그것에 대해 [여기](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32)서 더 읽을 수 있습니다.\n추론하기 전에 다음을 추가하기만 하면 됩니다:\n\n```python\nimport torch\n\ntorch.backends.cuda.matmul.allow_tf32 = True\n```\n\n## 반정밀도 가중치\n\n더 많은 GPU 메모리를 절약하고 더 빠른 속도를 얻기 위해 모델 가중치를 반정밀도(half precision)로 직접 불러오고 실행할 수 있습니다.\n여기에는 `fp16`이라는 브랜치에 저장된 float16 버전의 가중치를 불러오고, 그 때 `float16` 유형을 사용하도록 PyTorch에 지시하는 작업이 포함됩니다.\n\n```Python\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n\n    torch_dtype=torch.float16,\n)\npipe = pipe.to(\"cuda\")\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\nimage = pipe(prompt).images[0]\n```\n\n> [!WARNING]\n> 어떤 파이프라인에서도 [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) 를 사용하는 것은 검은색 이미지를 생성할 수 있고, 순수한 float16 정밀도를 사용하는 것보다 항상 느리기 때문에 사용하지 않는 것이 좋습니다.\n\n## 추가 메모리 절약을 위한 슬라이스 어텐션\n\n추가 메모리 절약을 위해, 한 번에 모두 계산하는 대신 단계적으로 계산을 수행하는 슬라이스 버전의 어텐션(attention)을 사용할 수 있습니다.\n\n> [!TIP]\n> Attention slicing은 모델이 하나 이상의 어텐션 헤드를 사용하는 한, 배치 크기가 1인 경우에도 유용합니다.\n>   하나 이상의 어텐션 헤드가 있는 경우 *QK^T* 어텐션 매트릭스는 상당한 양의 메모리를 절약할 수 있는 각 헤드에 대해 순차적으로 계산될 수 있습니다.\n\n각 헤드에 대해 순차적으로 어텐션 계산을 수행하려면, 다음과 같이 추론 전에 파이프라인에서 [`~StableDiffusionPipeline.enable_attention_slicing`]를 호출하면 됩니다:\n\n```Python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\npipe = StableDiffusionPipeline.from_pretrained(\n    
\"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n\n    torch_dtype=torch.float16,\n)\npipe = pipe.to(\"cuda\")\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\npipe.enable_attention_slicing()\nimage = pipe(prompt).images[0]\n```\n\n추론 시간이 약 10% 느려지는 약간의 성능 저하가 있지만 이 방법을 사용하면 3.2GB 정도의 작은 VRAM으로도 Stable Diffusion을 사용할 수 있습니다!\n\n\n## 더 큰 배치를 위한 sliced VAE 디코드\n\n제한된 VRAM에서 대규모 이미지 배치를 디코딩하거나 32개 이상의 이미지가 포함된 배치를 활성화하기 위해, 배치의 latent 이미지를 한 번에 하나씩 디코딩하는 슬라이스 VAE 디코드를 사용할 수 있습니다.\n\n이를 [`~StableDiffusionPipeline.enable_attention_slicing`] 또는 [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`]과 결합하여 메모리 사용을 추가로 최소화할 수 있습니다.\n\nVAE 디코드를 한 번에 하나씩 수행하려면 추론 전에 파이프라인에서 [`~StableDiffusionPipeline.enable_vae_slicing`]을 호출합니다. 예를 들어:\n\n```Python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n\n    torch_dtype=torch.float16,\n)\npipe = pipe.to(\"cuda\")\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\npipe.enable_vae_slicing()\nimages = pipe([prompt] * 32).images\n```\n\n다중 이미지 배치에서 VAE 디코드가 약간의 성능 향상이 이루어집니다. 단일 이미지 배치에서는 성능 영향은 없습니다.\n\n\n<a name=\"sequential_offloading\"></a>\n## 메모리 절약을 위해 가속 기능을 사용하여 CPU로 오프로딩\n\n추가 메모리 절약을 위해 가중치를 CPU로 오프로드하고 순방향 전달을 수행할 때만 GPU로 로드할 수 있습니다.\n\nCPU 오프로딩을 수행하려면 [`~StableDiffusionPipeline.enable_sequential_cpu_offload`]를 호출하기만 하면 됩니다:\n\n```Python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n\n    torch_dtype=torch.float16,\n)\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\npipe.enable_sequential_cpu_offload()\nimage = pipe(prompt).images[0]\n```\n\n그러면 메모리 소비를 3GB 미만으로 줄일 수 있습니다.\n\n참고로 이 방법은 전체 모델이 아닌 서브모듈 수준에서 작동합니다. 이는 메모리 소비를 최소화하는 가장 좋은 방법이지만 프로세스의 반복적 특성으로 인해 추론 속도가 훨씬 느립니다. 파이프라인의 UNet 구성 요소는 여러 번 실행됩니다('num_inference_steps' 만큼). 
매번 UNet의 서로 다른 서브모듈이 순차적으로 온로드된 다음 필요에 따라 오프로드되므로 메모리 이동 횟수가 많습니다.\n\n> [!TIP]\n> 또 다른 최적화 방법인 <a href=\"#model_offloading\">모델 오프로딩</a>을 사용하는 것을 고려하십시오. 이는 훨씬 빠르지만 메모리 절약이 크지는 않습니다.\n\n또한 attention slicing과 함께 사용하면 최소한의 메모리(< 2GB)로도 동작할 수 있습니다.\n\n\n```Python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n\n    torch_dtype=torch.float16,\n)\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\npipe.enable_sequential_cpu_offload()\npipe.enable_attention_slicing(1)\n\nimage = pipe(prompt).images[0]\n```\n\n**참고**: `enable_sequential_cpu_offload()`를 사용할 때는 파이프라인을 미리 CUDA로 이동하지 **않는** 것이 중요합니다. 그렇지 않으면 메모리 절약 효과가 매우 작아집니다. 자세한 내용은 [이 이슈](https://github.com/huggingface/diffusers/issues/1934)를 참고하세요.\n\n<a name=\"model_offloading\"></a>\n## 빠른 추론과 메모리 절약을 위한 모델 오프로딩\n\n[순차적 CPU 오프로딩](#sequential_offloading)은 이전 섹션에서 설명한 것처럼 많은 메모리를 절약하지만, 필요에 따라 서브모듈을 GPU로 이동하고 새 모듈이 실행될 때 즉시 CPU로 되돌리기 때문에 추론 속도가 느려집니다.\n\n전체 모델 오프로딩은 각 모델의 구성 요소인 _modules_을 처리하는 대신, 전체 모델을 GPU로 이동하는 대안입니다. 
이로 인해 추론 시간에 미치는 영향은 미미하지만(파이프라인을 `cuda`로 이동하는 것과 비교하여) 여전히 약간의 메모리를 절약할 수 있습니다.\n\n이 시나리오에서는 파이프라인의 주요 구성 요소 중 하나만(일반적으로 텍스트 인코더, unet 및 vae) GPU에 있고, 나머지는 CPU에서 대기할 것입니다.\n여러 번 반복 실행되는 UNet과 같은 구성 요소는 더 이상 필요하지 않을 때까지 GPU에 남아 있습니다.\n\n이 기능은 아래와 같이 파이프라인에서 `enable_model_cpu_offload()`를 호출하여 활성화할 수 있습니다.\n\n```Python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n)\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\npipe.enable_model_cpu_offload()\nimage = pipe(prompt).images[0]\n```\n\n이는 추가적인 메모리 절약을 위한 attention slicing과도 호환됩니다.\n\n```Python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n)\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\npipe.enable_model_cpu_offload()\npipe.enable_attention_slicing(1)\n\nimage = pipe(prompt).images[0]\n```\n\n> [!TIP]\n> 이 기능을 사용하려면 `accelerate` 버전 0.17.0 이상이 필요합니다.\n\n## Channels Last 메모리 형식 사용하기\n\nChannels Last 메모리 형식은 차원 순서를 보존하면서 NCHW 텐서를 메모리에 배열하는 또 다른 방법입니다.\nChannels Last 텐서는 채널이 가장 조밀한 차원이 되도록 정렬됩니다(즉, 이미지를 픽셀 단위로 저장합니다).\n현재 모든 연산자가 Channels Last 형식을 지원하는 것은 아니어서 성능이 저하될 수 있으므로, 사용해보고 모델에 잘 작동하는지 확인하는 것이 좋습니다.\n\n\n예를 들어 파이프라인의 UNet 모델이 Channels Last 형식을 사용하도록 설정하려면 다음을 사용할 수 있습니다:\n\n```python\nprint(pipe.unet.conv_out.state_dict()[\"weight\"].stride())  # (2880, 9, 3, 1)\npipe.unet.to(memory_format=torch.channels_last)  # in-place 연산\n# 2번째 차원에서 스트라이드 1을 가지는 (2880, 1, 960, 320)로, 연산이 작동함을 증명합니다.\nprint(pipe.unet.conv_out.state_dict()[\"weight\"].stride())\n```\n\n## 추적(tracing)\n\n추적은 예제 입력 텐서를 모델에 통과시키면서, 해당 입력이 모델의 레이어를 통과할 때 호출되는 연산을 캡처하여 실행 파일 또는 `ScriptFunction`을 반환합니다. 반환된 결과는 just-in-time 컴파일로 최적화됩니다.\n\nUNet 모델을 추적하기 위해 다음을 사용할 수 있습니다:\n\n```python\nimport time\nimport torch\nfrom 
diffusers import StableDiffusionPipeline\nimport functools\n\n# torch 기울기 비활성화\ntorch.set_grad_enabled(False)\n\n# 변수 설정\nn_experiments = 2\nunet_runs_per_experiment = 50\n\n\n# 입력 불러오기\ndef generate_inputs():\n    sample = torch.randn((2, 4, 64, 64), device=\"cuda\", dtype=torch.float16)\n    timestep = torch.rand(1, device=\"cuda\", dtype=torch.float16) * 999\n    encoder_hidden_states = torch.randn((2, 77, 768), device=\"cuda\", dtype=torch.float16)\n    return sample, timestep, encoder_hidden_states\n\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\nunet = pipe.unet\nunet.eval()\nunet.to(memory_format=torch.channels_last)  # Channels Last 메모리 형식 사용\nunet.forward = functools.partial(unet.forward, return_dict=False)  # return_dict=False을 기본값으로 설정\n\n# 워밍업\nfor _ in range(3):\n    with torch.inference_mode():\n        inputs = generate_inputs()\n        orig_output = unet(*inputs)\n\n# 추적\nprint(\"tracing..\")\nunet_traced = torch.jit.trace(unet, inputs)\nunet_traced.eval()\nprint(\"done tracing\")\n\n\n# 워밍업 및 그래프 최적화\nfor _ in range(5):\n    with torch.inference_mode():\n        inputs = generate_inputs()\n        orig_output = unet_traced(*inputs)\n\n\n# 벤치마킹\nwith torch.inference_mode():\n    for _ in range(n_experiments):\n        torch.cuda.synchronize()\n        start_time = time.time()\n        for _ in range(unet_runs_per_experiment):\n            orig_output = unet_traced(*inputs)\n        torch.cuda.synchronize()\n        print(f\"unet traced inference took {time.time() - start_time:.2f} seconds\")\n    for _ in range(n_experiments):\n        torch.cuda.synchronize()\n        start_time = time.time()\n        for _ in range(unet_runs_per_experiment):\n            orig_output = unet(*inputs)\n        torch.cuda.synchronize()\n        print(f\"unet inference took {time.time() - start_time:.2f} seconds\")\n\n# 모델 
저장\nunet_traced.save(\"unet_traced.pt\")\n```\n\n그 다음, 파이프라인의 `unet` 속성을 다음과 같이 추적된 모델로 바꿀 수 있습니다.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass UNet2DConditionOutput:\n    sample: torch.Tensor\n\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\n# jitted unet 사용\nunet_traced = torch.jit.load(\"unet_traced.pt\")\n\n\n# pipe.unet을 추적된 UNet을 감싼 래퍼로 교체\nclass TracedUNet(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.in_channels = pipe.unet.config.in_channels\n        self.device = pipe.unet.device\n\n    def forward(self, latent_model_input, t, encoder_hidden_states):\n        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]\n        return UNet2DConditionOutput(sample=sample)\n\n\npipe.unet = TracedUNet()\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\n\nwith torch.inference_mode():\n    image = pipe([prompt] * 1, num_inference_steps=50).images[0]\n```\n\n\n## Memory-efficient attention\n\n어텐션 블록의 대역폭을 최적화하는 최근 연구 덕분에 속도가 크게 빨라지고 GPU 메모리 사용량도 크게 줄었습니다.\n가장 최근의 작업은 @tridao의 Flash Attention입니다: [code](https://github.com/HazyResearch/flash-attention), [paper](https://huggingface.co/papers/2205.14135).\n\n배치 크기 1(프롬프트 1개)의 512x512 크기로 추론을 실행할 때 몇 가지 Nvidia GPU에서 얻은 속도 향상은 다음과 같습니다:\n\n| GPU              \t| 기준 어텐션 FP16 \t       | 메모리 효율적인 어텐션 FP16 \t|\n|------------------\t|---------------------\t|---------------------------------\t|\n| NVIDIA Tesla T4  \t| 3.5it/s             \t| 5.5it/s                         \t|\n| NVIDIA 3060 RTX  \t| 4.6it/s             \t| 7.8it/s                         \t|\n| NVIDIA A10G      \t| 8.88it/s            \t| 15.6it/s                        \t|\n| NVIDIA RTX A6000 \t| 11.7it/s            \t| 21.09it/s                       \t|\n| NVIDIA TITAN RTX  | 12.51it/s         \t| 18.22it/s                       \t|\n| A100-SXM4-40GB    \t| 18.6it/s            \t| 29.it/s              
          \t|\n| A100-SXM-80GB    \t| 18.7it/s            \t| 29.5it/s                        \t|\n\n이를 활용하려면 다음을 만족해야 합니다:\n - PyTorch > 1.12\n - Cuda 사용 가능\n - [xformers 라이브러리를 설치함](xformers)\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\npipe.enable_xformers_memory_efficient_attention()\n\nwith torch.inference_mode():\n    sample = pipe(\"a small cat\")\n\n# 선택: 이를 비활성화 하기 위해 다음을 사용할 수 있습니다.\n# pipe.disable_xformers_memory_efficient_attention()\n```\n"
  },
  {
    "path": "docs/source/ko/optimization/habana.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Intel Gaudi에서 Stable Diffusion을 사용하는 방법\n\n🤗 Diffusers는 🤗 [Optimum Habana](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)를 통해서 Habana Gaudi와 호환됩니다.\n\n## 요구 사항\n\n- Optimum Habana 1.4 또는 이후, [여기](https://huggingface.co/docs/optimum/habana/installation)에 설치하는 방법이 있습니다.\n- SynapseAI 1.8.\n\n\n## 추론 파이프라인\n\nGaudi에서 Stable Diffusion 1 및 2로 이미지를 생성하려면 두 인스턴스를 인스턴스화해야 합니다:\n- [`GaudiStableDiffusionPipeline`](https://huggingface.co/docs/optimum/habana/package_reference/stable_diffusion_pipeline)이 포함된 파이프라인. 이 파이프라인은 *텍스트-이미지 생성*을 지원합니다.\n- [`GaudiDDIMScheduler`](https://huggingface.co/docs/optimum/habana/package_reference/stable_diffusion_pipeline#optimum.habana.diffusers.GaudiDDIMScheduler)이 포함된 스케줄러. 
이 스케줄러는 Habana Gaudi에 최적화되어 있습니다.\n\n파이프라인을 초기화할 때, HPU에 배포하기 위해 `use_habana=True`를 지정해야 합니다.\n또한 가능한 가장 빠른 생성을 위해 `use_hpu_graphs=True`로 **HPU 그래프**를 활성화해야 합니다.\n마지막으로, [Hugging Face Hub](https://huggingface.co/Habana)에서 다운로드할 수 있는 [Gaudi configuration](https://huggingface.co/docs/optimum/habana/package_reference/gaudi_config)을 지정해야 합니다.\n\n```python\nfrom optimum.habana import GaudiConfig\nfrom optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline\n\nmodel_name = \"stabilityai/stable-diffusion-2-base\"\nscheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder=\"scheduler\")\npipeline = GaudiStableDiffusionPipeline.from_pretrained(\n    model_name,\n    scheduler=scheduler,\n    use_habana=True,\n    use_hpu_graphs=True,\n    gaudi_config=\"Habana/stable-diffusion\",\n)\n```\n\n파이프라인을 호출하여 하나 이상의 프롬프트에서 배치별로 이미지를 생성할 수 있습니다.\n\n```python\noutputs = pipeline(\n    prompt=[\n        \"High quality photo of an astronaut riding a horse in space\",\n        \"Face of a yellow cat, high resolution, sitting on a park bench\",\n    ],\n    num_images_per_prompt=10,\n    batch_size=4,\n)\n```\n\n더 많은 정보를 얻기 위해, Optimum Habana의 [문서](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)와 공식 GitHub 저장소에 제공된 [예시](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)를 확인하세요.\n\n\n## 벤치마크\n\n다음은 [Habana/stable-diffusion](https://huggingface.co/Habana/stable-diffusion) Gaudi 구성(혼합 정밀도 bf16/fp32)을 사용하는 Habana first-generation Gaudi 및 Gaudi2의 지연 시간입니다:\n\n|                        | Latency (배치 크기 = 1) | Throughput (배치 크기 = 8) |\n| ---------------------- |:------------------------:|:---------------------------:|\n| first-generation Gaudi | 4.29s                    | 0.283 images/s              |\n| Gaudi2                 | 1.54s                    | 0.904 images/s              |\n"
  },
  {
    "path": "docs/source/ko/optimization/mps.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Apple Silicon (M1/M2)에서 Stable Diffusion을 사용하는 방법\n\nDiffusers는 Stable Diffusion 추론을 위해 PyTorch `mps`를 사용해 Apple 실리콘과 호환됩니다. 다음은 Stable Diffusion이 있는 M1 또는 M2 컴퓨터를 사용하기 위해 따라야 하는 단계입니다.\n\n## 요구 사항\n\n- Apple silicon (M1/M2) 하드웨어의 Mac 컴퓨터.\n- macOS 12.6 또는 이후 (13.0 또는 이후 추천).\n- Python arm64 버전\n- PyTorch 2.0(추천) 또는 1.13(`mps`를 지원하는 최소 버전). Yhttps://pytorch.org/get-started/locally/의 지침에 따라 `pip` 또는 `conda`로 설치할 수 있습니다.\n\n\n## 추론 파이프라인\n\n아래 코도는 익숙한 `to()` 인터페이스를 사용하여 `mps` 백엔드로 Stable Diffusion 파이프라인을 M1 또는 M2 장치로 이동하는 방법을 보여줍니다.\n\n\n> [!WARNING]\n> **PyTorch 1.13을 사용 중일 때 ** 추가 일회성 전달을 사용하여 파이프라인을 \"프라이밍\"하는 것을 추천합니다. 이것은 발견한 이상한 문제에 대한 임시 해결 방법입니다. 첫 번째 추론 전달은 후속 전달와 약간 다른 결과를 생성합니다. 이 전달은 한 번만 수행하면 되며 추론 단계를 한 번만 사용하고 결과를 폐기해도 됩니다.\n\n이전 팁에서 설명한 것들을 포함한 여러 문제를 해결하므로 PyTorch 2 이상을 사용하는 것이 좋습니다.\n\n\n```python\n# `hf auth login`에 로그인되어 있음을 확인\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\npipe = pipe.to(\"mps\")\n\n# 컴퓨터가 64GB 이하의 RAM 램일 때 추천\npipe.enable_attention_slicing()\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\n\n# 처음 \"워밍업\" 전달 (위 설명을 보세요)\n_ = pipe(prompt, num_inference_steps=1)\n\n# 결과는 워밍업 전달 후의 CPU 장치의 결과와 일치합니다.\nimage = pipe(prompt).images[0]\n```\n\n## 성능 추천\n\nM1/M2 성능은 메모리 압력에 매우 민감합니다. 
시스템은 필요한 경우 자동으로 스왑되지만 스왑할 때 성능이 크게 저하됩니다.\n\n\n특히 컴퓨터의 시스템 RAM이 64GB 미만이거나 512 × 512픽셀보다 큰 비표준 해상도에서 이미지를 생성하는 경우, 추론 중에 메모리 압력을 줄이고 스와핑을 방지하기 위해 *어텐션 슬라이싱*을 사용하는 것이 좋습니다. 어텐션 슬라이싱은 비용이 많이 드는 어텐션 작업을 한 번에 모두 수행하는 대신 여러 단계로 수행합니다. 일반적으로 범용 메모리가 없는 컴퓨터에서 ~20%의 성능 영향을 미치지만 64GB 이상이 아닌 경우 대부분의 Apple Silicon 컴퓨터에서 *더 나은 성능*이 관찰되었습니다.\n\n```python\npipeline.enable_attention_slicing()\n```\n\n## Known Issues\n\n- 여러 프롬프트를 배치로 생성하는 것은 [충돌이 발생하거나 안정적으로 작동하지 않습니다](https://github.com/huggingface/diffusers/issues/363). 우리는 이것이 [PyTorch의 `mps` 백엔드](https://github.com/pytorch/pytorch/issues/84039)와 관련이 있다고 생각합니다. 이 문제는 해결되고 있지만 지금은 배치 대신 반복 방법을 사용하는 것이 좋습니다."
  },
  {
    "path": "docs/source/ko/optimization/onnx.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n\n# 추론을 위해 ONNX 런타임을 사용하는 방법\n\n🤗 Diffusers는 ONNX Runtime과 호환되는 Stable Diffusion 파이프라인을 제공합니다. 이를 통해 ONNX(CPU 포함)를 지원하고 PyTorch의 가속 버전을 사용할 수 없는 모든 하드웨어에서 Stable Diffusion을 실행할 수 있습니다.\n\n## 설치\n\n다음 명령어로 ONNX Runtime를 지원하는 🤗 Optimum를 설치합니다:\n\n```sh\npip install optimum[\"onnxruntime\"]\n```\n\n## Stable Diffusion 추론\n\n아래 코드는 ONNX 런타임을 사용하는 방법을 보여줍니다. `StableDiffusionPipeline` 대신 `OnnxStableDiffusionPipeline`을 사용해야 합니다.\nPyTorch 모델을 불러오고 즉시 ONNX 형식으로 변환하려는 경우 `export=True`로 설정합니다.\n\n```python\nfrom optimum.onnxruntime import ORTStableDiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipe = ORTStableDiffusionPipeline.from_pretrained(model_id, export=True)\nprompt = \"a photo of an astronaut riding a horse on mars\"\nimages = pipe(prompt).images[0]\npipe.save_pretrained(\"./onnx-stable-diffusion-v1-5\")\n```\n\n파이프라인을 ONNX 형식으로 오프라인으로 내보내고 나중에 추론에 사용하려는 경우,\n[`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) 명령어를 사용할 수 있습니다:\n\n```bash\noptimum-cli export onnx --model stable-diffusion-v1-5/stable-diffusion-v1-5 sd_v15_onnx/\n```\n\n그 다음 추론을 수행합니다:\n\n```python\nfrom optimum.onnxruntime import ORTStableDiffusionPipeline\n\nmodel_id = \"sd_v15_onnx\"\npipe = ORTStableDiffusionPipeline.from_pretrained(model_id)\nprompt = \"a photo 
of an astronaut riding a horse on mars\"\nimages = pipe(prompt).images[0]\n```\n\nNotice that we didn't have to specify `export=True` above.\n\n[Optimum 문서](https://huggingface.co/docs/optimum/)에서 더 많은 예시를 찾을 수 있습니다.\n\n## 알려진 이슈들\n\n- 여러 프롬프트를 배치로 생성하면 너무 많은 메모리가 사용되는 것 같습니다. 이를 조사하는 동안, 배치 대신 반복 방법이 필요할 수도 있습니다.\n"
  },
  {
    "path": "docs/source/ko/optimization/open_vino.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 추론을 위한 OpenVINO 사용 방법\n\n🤗 [Optimum](https://github.com/huggingface/optimum-intel)은 OpenVINO와 호환되는 Stable Diffusion 파이프라인을 제공합니다.\n이제 다양한 Intel 프로세서에서 OpenVINO Runtime으로 쉽게 추론을 수행할 수 있습니다. ([여기](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html)서 지원되는 전 기기 목록을 확인하세요).\n\n## 설치\n\n다음 명령어로 🤗 Optimum을 설치합니다:\n\n```sh\npip install optimum[\"openvino\"]\n```\n\n## Stable Diffusion 추론\n\nOpenVINO 모델을 불러오고 OpenVINO 런타임으로 추론을 실행하려면 `StableDiffusionPipeline`을 `OVStableDiffusionPipeline`으로 교체해야 합니다. PyTorch 모델을 불러오고 즉시 OpenVINO 형식으로 변환하려는 경우 `export=True`로 설정합니다.\n\n```python\nfrom optimum.intel.openvino import OVStableDiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipe = OVStableDiffusionPipeline.from_pretrained(model_id, export=True)\nprompt = \"a photo of an astronaut riding a horse on mars\"\nimages = pipe(prompt).images[0]\n```\n\n[Optimum 문서](https://huggingface.co/docs/optimum/intel/inference#export-and-inference-of-stable-diffusion-models)에서 (정적 reshaping과 모델 컴파일 등의) 더 많은 예시들을 찾을 수 있습니다.\n"
  },
  {
    "path": "docs/source/ko/optimization/tome.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Token Merging (토큰 병합)\n\nToken Merging (introduced in [Token Merging: Your ViT But Faster](https://huggingface.co/papers/2210.09461))은 트랜스포머 기반 네트워크의 forward pass에서 중복 토큰이나 패치를 점진적으로 병합하는 방식으로 작동합니다. 이를 통해 기반 네트워크의 추론 지연 시간을 단축할 수 있습니다.\n\nToken Merging(ToMe)이 출시된 후, 저자들은 [Fast Stable Diffusion을 위한 토큰 병합](https://huggingface.co/papers/2303.17604)을 발표하여 Stable Diffusion과 더 잘 호환되는 ToMe 버전을 소개했습니다. ToMe를 사용하면 [`DiffusionPipeline`]의 추론 지연 시간을 부드럽게 단축할 수 있습니다. 이 문서에서는 ToMe를 [`StableDiffusionPipeline`]에 적용하는 방법, 예상되는 속도 향상, [`StableDiffusionPipeline`]에서 ToMe를 사용할 때의 질적 측면에 대해 설명합니다.\n\n## ToMe 사용하기\n\nToMe의 저자들은 [`tomesd`](https://github.com/dbolya/tomesd)라는 편리한 Python 라이브러리를 공개했는데, 이 라이브러리를 이용하면 [`DiffusionPipeline`]에 ToMe를 다음과 같이 적용할 수 있습니다:\n\n```diff\nfrom diffusers import StableDiffusionPipeline\nimport tomesd\n\npipeline = StableDiffusionPipeline.from_pretrained(\n      \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16\n).to(\"cuda\")\n+ tomesd.apply_patch(pipeline, ratio=0.5)\n\nimage = pipeline(\"a photo of an astronaut riding a horse on mars\").images[0]\n```\n\n이것이 다입니다!\n\n`tomesd.apply_patch()`는 파이프라인 추론 속도와 생성된 토큰의 품질 사이의 균형을 맞출 수 있도록 [여러 개의 인자](https://github.com/dbolya/tomesd#usage)를 노출합니다. 이러한 인수 중 가장 중요한 것은 `ratio(비율)`입니다. `ratio`은 forward pass 중에 병합될 토큰의 수를 제어합니다. 
`tomesd`에 대한 자세한 내용은 해당 리포지토리(https://github.com/dbolya/tomesd) 및 [논문](https://huggingface.co/papers/2303.17604)을 참고하시기 바랍니다.\n\n## `StableDiffusionPipeline`으로 `tomesd` 벤치마킹하기\n\n다양한 이미지 해상도에서 [xformers](https://huggingface.co/docs/diffusers/optimization/xformers)를 적용한 상태에서, [`StableDiffusionPipeline`]에 `tomesd`를 사용했을 때의 영향을 벤치마킹했습니다. 테스트 GPU 장치로 A100과 V100을 사용했으며 개발 환경은 다음과 같습니다(Python 3.8.5 사용):\n\n```bash\n- `diffusers` version: 0.15.1\n- Python version: 3.8.16\n- PyTorch version (GPU?): 1.13.1+cu116 (True)\n- Huggingface_hub version: 0.13.2\n- Transformers version: 4.27.2\n- Accelerate version: 0.18.0\n- xFormers version: 0.0.16\n- tomesd version: 0.1.2\n```\n\n벤치마킹에는 다음 스크립트를 사용했습니다: [https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335). 결과는 다음과 같습니다:\n\n### A100\n\n| 해상도 | 배치 크기 | Vanilla | ToMe | ToMe + xFormers | ToMe 속도 향상 (%) | ToMe + xFormers 속도 향상 (%) |\n| --- | --- | --- | --- | --- | --- | --- |\n| 512 | 10 | 6.88 | 5.26 | 4.69 | 23.54651163 | 31.83139535 |\n|  |  |  |  |  |  |  |\n| 768 | 10 | OOM | 14.71 | 11 |  |  |\n|  | 8 | OOM | 11.56 | 8.84 |  |  |\n|  | 4 | OOM | 5.98 | 4.66 |  |  |\n|  | 2 | 4.99 | 3.24 | 3.1 | 35.07014028 | 37.8757515 |\n|  | 1 | 3.29 | 2.24 | 2.03 | 31.91489362 | 38.29787234 |\n|  |  |  |  |  |  |  |\n| 1024 | 10 | OOM | OOM | OOM |  |  |\n|  | 8 | OOM | OOM | OOM |  |  |\n|  | 4 | OOM | 12.51 | 9.09 |  |  |\n|  | 2 | OOM | 6.52 | 4.96 |  |  |\n|  | 1 | 6.4 | 3.61 | 2.81 | 43.59375 | 56.09375 |\n\n***결과는 초 단위입니다. 
속도 향상은 `Vanilla`와 비교해 계산됩니다.***\n\n### V100\n\n| 해상도 | 배치 크기 | Vanilla | ToMe | ToMe + xFormers | ToMe 속도 향상 (%) | ToMe + xFormers 속도 향상 (%) |\n| --- | --- | --- | --- | --- | --- | --- |\n| 512 | 10 | OOM | 10.03 | 9.29 |  |  |\n|  | 8 | OOM | 8.05 | 7.47 |  |  |\n|  | 4 | 5.7 | 4.3 | 3.98 | 24.56140351 | 30.1754386 |\n|  | 2 | 3.14 | 2.43 | 2.27 | 22.61146497 | 27.70700637 |\n|  | 1 | 1.88 | 1.57 | 1.57 | 16.4893617 | 16.4893617 |\n|  |  |  |  |  |  |  |\n| 768 | 10 | OOM | OOM | 23.67 |  |  |\n|  | 8 | OOM | OOM | 18.81 |  |  |\n|  | 4 | OOM | 11.81 | 9.7 |  |  |\n|  | 2 | OOM | 6.27 | 5.2 |  |  |\n|  | 1 | 5.43 | 3.38 | 2.82 | 37.75322284 | 48.06629834 |\n|  |  |  |  |  |  |  |\n| 1024 | 10 | OOM | OOM | OOM |  |  |\n|  | 8 | OOM | OOM | OOM |  |  |\n|  | 4 | OOM | OOM | 19.35 |  |  |\n|  | 2 | OOM | 13 | 10.78 |  |  |\n|  | 1 | OOM | 6.66 | 5.54 |  |  |\n\n위의 표에서 볼 수 있듯이, 이미지 해상도가 높을수록 `tomesd`를 사용한 속도 향상이 더욱 두드러집니다. 또한 `tomesd`를 사용하면 1024x1024와 같은 더 높은 해상도에서 파이프라인을 실행할 수 있다는 점도 흥미롭습니다.\n\n[`torch.compile()`](https://huggingface.co/docs/diffusers/optimization/torch2.0)을 사용하면 추론 속도를 더욱 높일 수 있습니다.\n\n## 품질\n\n[논문](https://huggingface.co/papers/2303.17604)에 보고된 바와 같이, ToMe는 생성된 이미지의 품질을 상당 부분 보존하면서 추론 속도를 높일 수 있습니다. 
`ratio`를 높이면 추론 속도를 더 높일 수 있지만, 이미지 품질이 저하될 수 있습니다.\n\n해당 설정을 사용하여 생성된 샘플의 품질을 테스트하기 위해, \"Parti 프롬프트\"([Parti](https://parti.research.google/)에서 소개)에서 몇 가지 프롬프트를 샘플링하고 다음 설정에서 [`StableDiffusionPipeline`]을 사용하여 추론을 수행했습니다:\n\n- Vanilla [`StableDiffusionPipeline`]\n- [`StableDiffusionPipeline`] + ToMe\n- [`StableDiffusionPipeline`] + ToMe + xformers\n\n생성된 샘플의 품질이 크게 저하되는 것을 발견하지 못했습니다. 다음은 샘플입니다:\n\n![tome-samples](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/tome/tome_samples.png)\n\n생성된 샘플은 [여기](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=)에서 확인할 수 있습니다. 이 실험을 수행하기 위해 [이 스크립트](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd)를 사용했습니다."
  },
  {
    "path": "docs/source/ko/optimization/torch2.0.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Diffusers에서의 PyTorch 2.0 가속화 지원\n\n`0.13.0` 버전부터 Diffusers는 [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/)에서의 최신 최적화를 지원합니다. 이는 다음을 포함됩니다.\n1. momory-efficient attention을 사용한 가속화된 트랜스포머 지원 - `xformers`같은 추가적인 dependencies 필요 없음\n2. 추가 성능 향상을 위한 개별 모델에 대한 컴파일 기능 [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) 지원\n\n\n## 설치\n가속화된 어텐션 구현과 및 `torch.compile()`을 사용하기 위해, pip에서 최신 버전의 PyTorch 2.0을 설치되어 있고 diffusers 0.13.0. 버전 이상인지 확인하세요. 아래 설명된 바와 같이, PyTorch 2.0이 활성화되어 있을 때 diffusers는 최적화된 어텐션 프로세서([`AttnProcessor2_0`](https://github.com/huggingface/diffusers/blob/1a5797c6d4491a879ea5285c4efc377664e0332d/src/diffusers/models/attention_processor.py#L798))를 사용합니다.\n\n```bash\npip install --upgrade torch diffusers\n```\n\n## 가속화된 트랜스포머와 `torch.compile` 사용하기.\n\n\n1. **가속화된 트랜스포머 구현**\n\n   PyTorch 2.0에는 [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) 함수를 통해 최적화된 memory-efficient attention의 구현이 포함되어 있습니다. 이는 입력 및 GPU 유형에 따라 여러 최적화를 자동으로 활성화합니다. 이는 [xFormers](https://github.com/facebookresearch/xformers)의 `memory_efficient_attention`과 유사하지만 기본적으로 PyTorch에 내장되어 있습니다.\n\n   이러한 최적화는 PyTorch 2.0이 설치되어 있고 `torch.nn.functional.scaled_dot_product_attention`을 사용할 수 있는 경우 Diffusers에서 기본적으로 활성화됩니다. 
이를 사용하려면 `torch 2.0`을 설치하고 파이프라인을 사용하기만 하면 됩니다. 예를 들어:\n\n    ```Python\n    import torch\n    from diffusers import DiffusionPipeline\n\n    pipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\n    pipe = pipe.to(\"cuda\")\n\n    prompt = \"a photo of an astronaut riding a horse on mars\"\n    image = pipe(prompt).images[0]\n    ```\n\n    이를 명시적으로 활성화하려면(필수는 아님) 아래와 같이 수행할 수 있습니다.\n\n    ```diff\n    import torch\n    from diffusers import DiffusionPipeline\n    + from diffusers.models.attention_processor import AttnProcessor2_0\n\n    pipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(\"cuda\")\n    + pipe.unet.set_attn_processor(AttnProcessor2_0())\n\n    prompt = \"a photo of an astronaut riding a horse on mars\"\n    image = pipe(prompt).images[0]\n    ```\n\n    이 실행 과정은 `xFormers`만큼 빠르고 메모리적으로 효율적이어야 합니다. 자세한 내용은 [벤치마크](#benchmark)에서 확인하세요.\n\n    파이프라인을 보다 deterministic으로 만들거나 파인 튜닝된 모델을 [Core ML](https://huggingface.co/docs/diffusers/v0.16.0/en/optimization/coreml#how-to-run-stable-diffusion-with-core-ml)과 같은 다른 형식으로 변환해야 하는 경우 바닐라 어텐션 프로세서 ([`AttnProcessor`](https://github.com/huggingface/diffusers/blob/1a5797c6d4491a879ea5285c4efc377664e0332d/src/diffusers/models/attention_processor.py#L402))로 되돌릴 수 있습니다. 일반 어텐션 프로세서를 사용하려면 [`~diffusers.UNet2DConditionModel.set_default_attn_processor`] 함수를 사용할 수 있습니다:\n\n    ```Python\n    import torch\n    from diffusers import DiffusionPipeline\n    from diffusers.models.attention_processor import AttnProcessor\n\n    pipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(\"cuda\")\n    pipe.unet.set_default_attn_processor()\n\n    prompt = \"a photo of an astronaut riding a horse on mars\"\n    image = pipe(prompt).images[0]\n    ```\n\n2. **torch.compile**\n\n    추가적인 속도 향상을 위해 새로운 `torch.compile` 기능을 사용할 수 있습니다. 
파이프라인의 UNet은 일반적으로 계산 비용이 가장 크기 때문에 나머지 하위 모델(텍스트 인코더와 VAE)은 그대로 두고 `unet`을 `torch.compile`로 래핑합니다. 자세한 내용과 다른 옵션은 [torch 컴파일 문서](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)를 참조하세요.\n\n    ```python\n    pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n    images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images\n    ```\n\n    GPU 유형에 따라 `compile()`은 가속화된 트랜스포머 최적화를 통해 **5% - 300%**의 _추가 성능 향상_을 얻을 수 있습니다. 그러나 컴파일은 Ampere(A100, 3090), Ada(4090) 및 Hopper(H100)와 같은 최신 GPU 아키텍처에서 더 많은 성능 향상을 가져올 수 있음을 참고하세요.\n\n    컴파일은 완료하는 데 약간의 시간이 걸리므로, 파이프라인을 한 번 준비한 다음 동일한 유형의 추론 작업을 여러 번 수행해야 하는 상황에 가장 적합합니다. 다른 이미지 크기에서 컴파일된 파이프라인을 호출하면 시간적 비용이 많이 들 수 있는 컴파일 작업이 다시 트리거됩니다.\n\n\n## 벤치마크\n\nPyTorch 2.0의 효율적인 어텐션 구현과 `torch.compile`을 사용하여 가장 많이 사용되는 5개의 파이프라인에 대해 다양한 GPU와 배치 크기에 걸쳐 포괄적인 벤치마크를 수행했습니다. 여기서는 [`torch.compile()`이 최적으로 활용되도록 하는](https://github.com/huggingface/diffusers/pull/3313) `diffusers 0.17.0.dev0`을 사용했습니다.\n\n### 벤치마킹 코드\n\n#### Stable Diffusion text-to-image\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npath = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n\nrun_compile = True  # Set True / False\n\npipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\npipe.unet.to(memory_format=torch.channels_last)\n\nif run_compile:\n    print(\"Run torch compile\")\n    pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n\nprompt = \"ghibli style, a fantasy landscape with castles\"\n\nfor _ in range(3):\n    images = pipe(prompt=prompt).images\n```\n\n#### Stable Diffusion image-to-image\n\n```python\nfrom diffusers import StableDiffusionImg2ImgPipeline\nimport requests\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\n\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n\nresponse = 
requests.get(url)\ninit_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\ninit_image = init_image.resize((512, 512))\n\npath = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n\nrun_compile = True  # Set True / False\n\npipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\npipe.unet.to(memory_format=torch.channels_last)\n\nif run_compile:\n    print(\"Run torch compile\")\n    pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n\nprompt = \"ghibli style, a fantasy landscape with castles\"\n\nfor _ in range(3):\n    image = pipe(prompt=prompt, image=init_image).images[0]\n```\n\n#### Stable Diffusion - inpainting\n\n```python\nfrom diffusers import StableDiffusionInpaintPipeline\nimport requests\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\n\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n\ndef download_image(url):\n    response = requests.get(url)\n    return Image.open(BytesIO(response.content)).convert(\"RGB\")\n\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = download_image(img_url).resize((512, 512))\nmask_image = download_image(mask_url).resize((512, 512))\n\npath = \"stable-diffusion-v1-5/stable-diffusion-inpainting\"\n\nrun_compile = True  # Set True / False\n\npipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\npipe.unet.to(memory_format=torch.channels_last)\n\nif run_compile:\n    print(\"Run torch compile\")\n    pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n\nprompt = \"ghibli style, a fantasy landscape with 
castles\"\n\nfor _ in range(3):\n    image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]\n```\n\n#### ControlNet\n\n```python\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel\nimport requests\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\n\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n\nresponse = requests.get(url)\ninit_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\ninit_image = init_image.resize((512, 512))\n\npath = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n\nrun_compile = True  # Set True / False\ncontrolnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\npipe = StableDiffusionControlNetPipeline.from_pretrained(\n    path, controlnet=controlnet, torch_dtype=torch.float16\n)\n\npipe = pipe.to(\"cuda\")\npipe.unet.to(memory_format=torch.channels_last)\npipe.controlnet.to(memory_format=torch.channels_last)\n\nif run_compile:\n    print(\"Run torch compile\")\n    pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n    pipe.controlnet = torch.compile(pipe.controlnet, mode=\"reduce-overhead\", fullgraph=True)\n\nprompt = \"ghibli style, a fantasy landscape with castles\"\n\nfor _ in range(3):\n    image = pipe(prompt=prompt, image=init_image).images[0]\n```\n\n#### IF text-to-image + upscaling\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\nrun_compile = True  # Set True / False\n\npipe = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-M-v1.0\", variant=\"fp16\", text_encoder=None, torch_dtype=torch.float16)\npipe.to(\"cuda\")\npipe_2 = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-II-M-v1.0\", variant=\"fp16\", text_encoder=None, torch_dtype=torch.float16)\npipe_2.to(\"cuda\")\npipe_3 = DiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-x4-upscaler\", 
torch_dtype=torch.float16)\npipe_3.to(\"cuda\")\n\n\npipe.unet.to(memory_format=torch.channels_last)\npipe_2.unet.to(memory_format=torch.channels_last)\npipe_3.unet.to(memory_format=torch.channels_last)\n\nif run_compile:\n    pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n    pipe_2.unet = torch.compile(pipe_2.unet, mode=\"reduce-overhead\", fullgraph=True)\n    pipe_3.unet = torch.compile(pipe_3.unet, mode=\"reduce-overhead\", fullgraph=True)\n\nprompt = \"the blue hulk\"\n\nprompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)\nneg_prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)\n\nfor _ in range(3):\n    image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type=\"pt\").images\n    image_2 = pipe_2(image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type=\"pt\").images\n    image_3 = pipe_3(prompt=prompt, image=image, noise_level=100).images\n```\n\nPyTorch 2.0 및 `torch.compile()`로 얻을 수 있는 가능한 속도 향상에 대해, [Stable Diffusion text-to-image pipeline](StableDiffusionPipeline)에 대한 상대적인 속도 향상을 보여주는 차트를 5개의 서로 다른 GPU 제품군(배치 크기 4)에 대해 나타냅니다:\n\n![t2i_speedup](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/pt2_benchmarks/t2i_speedup.png)\n\n이 속도 향상이 위에 제시된 다른 파이프라인에 대해서도 어떻게 유지되는지 더 잘 이해하기 위해, 세 가지의 다른 배치 크기에 걸쳐 A100의 벤치마킹(PyTorch 2.0 nightly 및 `torch.compile()` 사용) 수치를 보여주는 차트를 보입니다:\n\n![a100_numbers](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/pt2_benchmarks/a100_numbers.png)\n\n_(위 차트의 벤치마크 메트릭은 **초당 iteration 수(iterations/second)**입니다)_\n\n그러나 투명성을 위해 모든 벤치마킹 수치를 공개합니다!\n\n다음 표들에서는, **_초당 처리되는 iteration_** 수 측면에서의 결과를 보여줍니다.\n\n### 
A100 (batch size: 1)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 21.66 | 23.13 | 44.03 | 49.74 |\n| SD - img2img | 21.81 | 22.40 | 43.92 | 46.32 |\n| SD - inpaint | 22.24 | 23.23 | 43.76 | 49.25 |\n| SD - controlnet | 15.02 | 15.82 | 32.13 | 36.08 |\n| IF | 20.21 / <br>13.84 / <br>24.00 | 20.12 / <br>13.70 / <br>24.03 | ❌ | 97.34 / <br>27.23 / <br>111.66 |\n\n### A100 (batch size: 4)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 11.6 | 13.12 | 14.62 | 17.27 |\n| SD - img2img | 11.47 | 13.06 | 14.66 | 17.25 |\n| SD - inpaint | 11.67 | 13.31 | 14.88 | 17.48 |\n| SD - controlnet | 8.28 | 9.38 | 10.51 | 12.41 |\n| IF | 25.02 | 18.04 | ❌ | 48.47 |\n\n### A100 (batch size: 16)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 3.04 | 3.6 | 3.83 | 4.68 |\n| SD - img2img | 2.98 | 3.58 | 3.83 | 4.67 |\n| SD - inpaint | 3.04 | 3.66 | 3.9 | 4.76 |\n| SD - controlnet | 2.15 | 2.58 | 2.74 | 3.35 |\n| IF | 8.78 | 9.82 | ❌ | 16.77 |\n\n### V100 (batch size: 1)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 18.99 | 19.14 | 20.95 | 22.17 |\n| SD - img2img | 18.56 | 19.18 | 20.95 | 22.11 |\n| SD - inpaint | 19.14 | 19.06 | 21.08 | 22.20 |\n| SD - controlnet | 13.48 | 13.93 | 15.18 | 15.88 |\n| IF |  20.01 / <br>9.08 / <br>23.34 | 19.79 / <br>8.98 / <br>24.10 | ❌ | 55.75 / <br>11.57 / <br>57.67 |\n\n### V100 (batch size: 4)\n\n| **Pipeline** | **torch 2.0 - <br>no 
compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 5.96 | 5.89 | 6.83 | 6.86 |\n| SD - img2img | 5.90 | 5.91 | 6.81 | 6.82 |\n| SD - inpaint | 5.99 | 6.03 | 6.93 | 6.95 |\n| SD - controlnet | 4.26 | 4.29 | 4.92 | 4.93 |\n| IF | 15.41 | 14.76 | ❌ | 22.95 |\n\n### V100 (batch size: 16)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 1.66 | 1.66 | 1.92 | 1.90 |\n| SD - img2img | 1.65 | 1.65 | 1.91 | 1.89 |\n| SD - inpaint | 1.69 | 1.69 | 1.95 | 1.93 |\n| SD - controlnet | 1.19 | 1.19 | OOM after warmup | 1.36 |\n| IF | 5.43 | 5.29 | ❌ | 7.06 |\n\n### T4 (batch size: 1)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 6.9 | 6.95 | 7.3 | 7.56 |\n| SD - img2img | 6.84 | 6.99 | 7.04 | 7.55 |\n| SD - inpaint | 6.91 | 6.7 | 7.01 | 7.37 |\n| SD - controlnet | 4.89 | 4.86 | 5.35 | 5.48 |\n| IF | 17.42 / <br>2.47 / <br>18.52 | 16.96 / <br>2.45 / <br>18.69 | ❌ | 24.63 / <br>2.47 / <br>23.39 |\n\n### T4 (batch size: 4)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 1.79 | 1.79 | 2.03 | 1.99 |\n| SD - img2img | 1.77 | 1.77 | 2.05 | 2.04 |\n| SD - inpaint | 1.81 | 1.82 | 2.09 | 2.09 |\n| SD - controlnet | 1.34 | 1.27 | 1.47 | 1.46 |\n| IF | 5.79 |  5.61 | ❌ | 7.39 |\n\n### T4 (batch size: 16)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 2.34s | 2.30s 
| OOM after 2nd iteration | 1.99s |\n| SD - img2img | 2.35s | 2.31s | OOM after warmup | 2.00s |\n| SD - inpaint | 2.30s | 2.26s | OOM after 2nd iteration | 1.95s |\n| SD - controlnet | OOM after 2nd iteration | OOM after 2nd iteration | OOM after warmup | OOM after warmup |\n| IF * | 1.44 | 1.44 | ❌ | 1.94 |\n\n### RTX 3090 (batch size: 1)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 22.56 | 22.84 | 23.84 | 25.69 |\n| SD - img2img | 22.25 | 22.61 | 24.1 | 25.83 |\n| SD - inpaint | 22.22 | 22.54 | 24.26 | 26.02 |\n| SD - controlnet | 16.03 | 16.33 | 17.38 | 18.56 |\n| IF | 27.08 / <br>9.07 / <br>31.23 | 26.75 / <br>8.92 / <br>31.47 | ❌ | 68.08 / <br>11.16 / <br>65.29 |\n\n### RTX 3090 (batch size: 4)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 6.46 | 6.35 | 7.29 | 7.3 |\n| SD - img2img | 6.33 | 6.27 | 7.31 | 7.26 |\n| SD - inpaint | 6.47 | 6.4 | 7.44 | 7.39 |\n| SD - controlnet | 4.59 | 4.54 | 5.27 | 5.26 |\n| IF | 16.81 | 16.62 | ❌ | 21.57 |\n\n### RTX 3090 (batch size: 16)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 1.7 | 1.69 | 1.93 | 1.91 |\n| SD - img2img | 1.68 | 1.67 | 1.93 | 1.9 |\n| SD - inpaint | 1.72 | 1.71 | 1.97 | 1.94 |\n| SD - controlnet | 1.23 | 1.22 | 1.4 | 1.38 |\n| IF | 5.01 | 5.00 | ❌ | 6.33 |\n\n### RTX 4090 (batch size: 1)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 40.5 | 41.89 | 44.65 | 49.81 |\n| SD - img2img 
| 40.39 | 41.95 | 44.46 | 49.8 |\n| SD - inpaint | 40.51 | 41.88 | 44.58 | 49.72 |\n| SD - controlnet | 29.27 | 30.29 | 32.26 | 36.03 |\n| IF | 69.71 / <br>18.78 / <br>85.49 | 69.13 / <br>18.80 / <br>85.56 | ❌ | 124.60 / <br>26.37 / <br>138.79 |\n\n### RTX 4090 (batch size: 4)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 12.62 | 12.84 | 15.32 | 15.59 |\n| SD - img2img | 12.61 | 12.79 | 15.35 | 15.66 |\n| SD - inpaint | 12.65 | 12.81 | 15.3 | 15.58 |\n| SD - controlnet | 9.1 | 9.25 | 11.03 | 11.22 |\n| IF | 31.88 | 31.14 | ❌ | 43.92 |\n\n### RTX 4090 (batch size: 16)\n\n| **Pipeline** | **torch 2.0 - <br>no compile** | **torch nightly - <br>no compile** | **torch 2.0 - <br>compile** | **torch nightly - <br>compile** |\n|:---:|:---:|:---:|:---:|:---:|\n| SD - txt2img | 3.17 | 3.2 | 3.84 | 3.85 |\n| SD - img2img | 3.16 | 3.2 | 3.84 | 3.85 |\n| SD - inpaint | 3.17 | 3.2 | 3.85 | 3.85 |\n| SD - controlnet | 2.23 | 2.3 | 2.7 | 2.75 |\n| IF | 9.26 | 9.2 | ❌ | 13.31 |\n\n## 참고\n\n* 벤치마크 수행에 사용된 환경에 대한 자세한 내용은 [이 PR](https://github.com/huggingface/diffusers/pull/3313)을 참조하세요.\n* IF 파이프라인과 배치 크기 > 1의 경우 첫 번째 IF 파이프라인에서 text-to-image 생성을 위한 배치 크기 > 1만 사용했으며 업스케일링에는 사용하지 않았습니다. 
즉, 두 개의 업스케일링 파이프라인이 배치 크기 1임을 의미합니다.\n\n*Diffusers에서 `torch.compile()` 지원을 개선하는 데 도움을 준 PyTorch 팀의 [Horace He](https://github.com/Chillee)에게 감사드립니다.*"
  },
  {
    "path": "docs/source/ko/optimization/xformers.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# xFormers 설치하기\n\n추론과 학습 모두에 [xFormers](https://github.com/facebookresearch/xformers)를 사용하는 것이 좋습니다.\n자체 테스트로 어텐션 블록에서 수행된 최적화가 더 빠른 속도와 적은 메모리 소비를 확인했습니다.\n\n2023년 1월에 출시된 xFormers 버전 '0.0.16'부터 사전 빌드된 pip wheel을 사용하여 쉽게 설치할 수 있습니다:\n\n```bash\npip install xformers\n```\n\n> [!TIP]\n> xFormers PIP 패키지에는 최신 버전의 PyTorch(xFormers 0.0.16에 1.13.1)가 필요합니다. 이전 버전의 PyTorch를 사용해야 하는 경우 [프로젝트 지침](https://github.com/facebookresearch/xformers#installing-xformers)의 소스를 사용해 xFormers를 설치하는 것이 좋습니다.\n\nxFormers를 설치하면, [여기](fp16#memory-efficient-attention)서 설명한 것처럼 'enable_xformers_memory_efficient_attention()'을 사용하여 추론 속도를 높이고 메모리 소비를 줄일 수 있습니다.\n\n> [!WARNING]\n> [이 이슈](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)에 따르면 xFormers `v0.0.16`에서 GPU를 사용한 학습(파인 튜닝 또는 Dreambooth)을 할 수 없습니다. 해당 문제가 발견되면. 해당 코멘트를 참고해 development 버전을 설치하세요.\n"
  },
  {
    "path": "docs/source/ko/quicktour.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n[[open-in-colab]]\n\n# 훑어보기\n\nDiffusion 모델은 이미지나 오디오와 같은 관심 샘플들을 생성하기 위해 랜덤 가우시안 노이즈를 단계별로 제거하도록 학습됩니다. 이로 인해 생성 AI에 대한 관심이 매우 높아졌으며, 인터넷에서 diffusion 생성 이미지의 예를 본 적이 있을 것입니다. 🧨 Diffusers는 누구나 diffusion 모델들을 널리 이용할 수 있도록 하기 위한 라이브러리입니다.\n\n개발자든 일반 사용자든 이 훑어보기를 통해 🧨 Diffusers를 소개하고 빠르게 생성할 수 있도록 도와드립니다! 알아야 할 라이브러리의 주요 구성 요소는 크게 세 가지입니다:\n\n* [`DiffusionPipeline`]은 추론을 위해 사전 학습된 diffusion 모델에서 샘플을 빠르게 생성하도록 설계된 높은 수준의 엔드투엔드 클래스입니다.\n* Diffusion 시스템 생성을 위한 빌딩 블록으로 사용할 수 있는 널리 사용되는 사전 학습된 [model](./api/models) 아키텍처 및 모듈.\n* 다양한 [schedulers](./api/schedulers/overview) - 학습을 위해 노이즈를 추가하는 방법과 추론 중에 노이즈 제거된 이미지를 생성하는 방법을 제어하는 알고리즘입니다.\n\n훑어보기에서는 추론을 위해 [`DiffusionPipeline`]을 사용하는 방법을 보여준 다음, 모델과 스케줄러를 결합하여 [`DiffusionPipeline`] 내부에서 일어나는 일을 복제하는 방법을 안내합니다.\n\n> [!TIP]\n> 훑어보기는 간결한 버전의 🧨 Diffusers 소개로서 [노트북](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) 빠르게 시작할 수 있도록 도와드립니다. 
디퓨저의 목표, 디자인 철학, 핵심 API에 대한 추가 세부 정보를 자세히 알아보려면 노트북을 확인하세요!\n\n시작하기 전에 필요한 라이브러리가 모두 설치되어 있는지 확인하세요:\n\n```py\n# 주석 풀어서 Colab에 필요한 라이브러리 설치하기.\n#!pip install --upgrade diffusers accelerate transformers\n```\n\n- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index)는 추론 및 학습을 위한 모델 로딩 속도를 높여줍니다.\n- [🤗 Transformers](https://huggingface.co/docs/transformers/index)는 [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview)과 같이 가장 많이 사용되는 diffusion 모델을 실행하는 데 필요합니다.\n\n## DiffusionPipeline\n\n[`DiffusionPipeline`] 은 추론을 위해 사전 학습된 diffusion 시스템을 사용하는 가장 쉬운 방법입니다. 모델과 스케줄러를 포함하는 엔드 투 엔드 시스템입니다. 다양한 작업에 [`DiffusionPipeline`]을 바로 사용할 수 있습니다. 아래 표에서 지원되는 몇 가지 작업을 살펴보고, 지원되는 작업의 전체 목록은 [🧨 Diffusers Summary](./api/pipelines/overview#diffusers-summary) 표에서 확인할 수 있습니다.\n\n| **Task**                     | **Description**                                                                                              | **Pipeline**\n|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|\n| Unconditional Image Generation          | generate an image from Gaussian noise | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |\n| Text-Guided Image Generation | generate an image given a text prompt | [conditional_image_generation](./using-diffusers/conditional_image_generation) |\n| Text-Guided Image-to-Image Translation     | adapt an image guided by a text prompt | [img2img](./using-diffusers/img2img) |\n| Text-Guided Image-Inpainting          | fill the masked part of an image given the image, the mask and a text prompt | [inpaint](./using-diffusers/inpaint) |\n| Text-Guided Depth-to-Image Translation | adapt parts of an image guided by a text prompt while preserving structure via depth estimation | [depth2img](./using-diffusers/depth2img) |\n\n먼저 [`DiffusionPipeline`]의 인스턴스를 생성하고 다운로드할 파이프라인 체크포인트를 
지정합니다.\n허깅페이스 허브에 저장된 모든 [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads)에 대해 [`DiffusionPipeline`]을 사용할 수 있습니다.\n이 훑어보기에서는 text-to-image 생성을 위한 [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 체크포인트를 로드합니다.\n\n> [!WARNING]\n> [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) 모델의 경우, 모델을 실행하기 전에 [라이선스](https://huggingface.co/spaces/CompVis/stable-diffusion-license)를 먼저 주의 깊게 읽어주세요. 🧨 Diffusers는 불쾌하거나 유해한 콘텐츠를 방지하기 위해 [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py)를 구현하고 있지만, 모델의 향상된 이미지 생성 기능으로 인해 여전히 잠재적으로 유해한 콘텐츠가 생성될 수 있습니다.\n\n[`~DiffusionPipeline.from_pretrained`] 메서드로 모델을 로드합니다:\n\n```python\n>>> from diffusers import DiffusionPipeline\n\n>>> pipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\n```\n\n[`DiffusionPipeline`]은 모든 모델링, 토큰화, 스케줄링 컴포넌트를 다운로드하고 캐시합니다. Stable Diffusion Pipeline은 무엇보다도 [`UNet2DConditionModel`]과 [`PNDMScheduler`]로 구성되어 있음을 알 수 있습니다:\n\n```py\n>>> pipeline\nStableDiffusionPipeline {\n  \"_class_name\": \"StableDiffusionPipeline\",\n  \"_diffusers_version\": \"0.13.1\",\n  ...,\n  \"scheduler\": [\n    \"diffusers\",\n    \"PNDMScheduler\"\n  ],\n  ...,\n  \"unet\": [\n    \"diffusers\",\n    \"UNet2DConditionModel\"\n  ],\n  \"vae\": [\n    \"diffusers\",\n    \"AutoencoderKL\"\n  ]\n}\n```\n\n이 모델은 약 14억 개의 파라미터로 구성되어 있으므로 GPU에서 파이프라인을 실행할 것을 강력히 권장합니다.\nPyTorch에서와 마찬가지로 파이프라인 객체를 GPU로 이동할 수 있습니다:\n\n```python\n>>> pipeline.to(\"cuda\")\n```\n\n이제 `pipeline`에 텍스트 프롬프트를 전달하여 이미지를 생성한 다음 노이즈가 제거된 이미지에 액세스할 수 있습니다. 
기본적으로 이미지 출력은 [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) 객체로 감싸집니다.\n\n```python\n>>> image = pipeline(\"An image of a squirrel in Picasso style\").images[0]\n>>> image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png\"/>\n</div>\n\n`save`를 호출하여 이미지를 저장합니다:\n\n```python\n>>> image.save(\"image_of_squirrel_painting.png\")\n```\n\n### 로컬 파이프라인\n\n파이프라인을 로컬에서 사용할 수도 있습니다. 유일한 차이점은 가중치를 먼저 다운로드해야 한다는 점입니다:\n\n```bash\n!git lfs install\n!git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5\n```\n\n그런 다음 저장된 가중치를 파이프라인에 로드합니다:\n\n```python\n>>> pipeline = DiffusionPipeline.from_pretrained(\"./stable-diffusion-v1-5\")\n```\n\n이제 위 섹션에서와 같이 파이프라인을 실행할 수 있습니다.\n\n### 스케줄러 교체\n\n스케줄러마다 노이즈 제거 속도와 품질이 서로 다릅니다. 자신에게 가장 적합한 스케줄러를 찾는 가장 좋은 방법은 직접 사용해 보는 것입니다! 🧨 Diffusers의 주요 기능 중 하나는 스케줄러 간에 쉽게 전환이 가능하다는 것입니다. 예를 들어, 기본 스케줄러인 [`PNDMScheduler`]를 [`EulerDiscreteScheduler`]로 바꾸려면, [`~diffusers.ConfigMixin.from_config`] 메서드를 사용하여 로드하세요:\n\n```py\n>>> from diffusers import EulerDiscreteScheduler\n\n>>> pipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\n>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)\n```\n\n새 스케줄러로 이미지를 생성해보고 어떤 차이가 있는지 확인해 보세요!\n\n다음 섹션에서는 모델과 스케줄러라는 [`DiffusionPipeline`]을 구성하는 컴포넌트를 자세히 살펴보고 이러한 컴포넌트를 사용하여 고양이 이미지를 생성하는 방법을 배워보겠습니다.\n\n## 모델\n\n대부분의 모델은 노이즈가 있는 샘플을 가져와 각 시간 간격마다 노이즈가 적은 이미지와 입력 이미지 사이의 차이인 *노이즈 잔차*(다른 모델은 이전 샘플을 직접 예측하거나 속도 또는 [`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)을 예측하는 학습을 합니다)을 예측합니다. 모델을 믹스 앤 매치하여 다른 diffusion 시스템을 만들 수 있습니다.\n\n모델은 [`~ModelMixin.from_pretrained`] 메서드로 시작되며, 이 메서드는 모델 가중치를 로컬에 캐시하여 다음에 모델을 로드할 때 더 빠르게 로드할 수 있습니다. 
훑어보기에서는 고양이 이미지에 대해 학습된 체크포인트가 있는 기본적인 unconditional 이미지 생성 모델인 [`UNet2DModel`]을 로드합니다:\n\n```py\n>>> from diffusers import UNet2DModel\n\n>>> repo_id = \"google/ddpm-cat-256\"\n>>> model = UNet2DModel.from_pretrained(repo_id)\n```\n\n모델 매개변수에 액세스하려면 `model.config`를 확인합니다:\n\n```py\n>>> model.config\n```\n\n모델 구성은 🧊 고정된 🧊 딕셔너리로, 모델이 생성된 후에는 해당 매개변수들을 변경할 수 없습니다. 이는 의도적인 것으로, 처음에 모델 아키텍처를 정의하는 데 사용된 매개변수는 동일하게 유지하면서 다른 매개변수는 추론 중에 조정할 수 있도록 하기 위한 것입니다.\n\n가장 중요한 매개변수들은 다음과 같습니다:\n\n* `sample_size`: 입력 샘플의 높이 및 너비 치수입니다.\n* `in_channels`: 입력 샘플의 입력 채널 수입니다.\n* `down_block_types` 및 `up_block_types`: UNet 아키텍처를 생성하는 데 사용되는 다운 및 업샘플링 블록의 유형.\n* `block_out_channels`: 다운샘플링 블록의 출력 채널 수. 업샘플링 블록의 입력 채널 수에 역순으로 사용되기도 합니다.\n* `layers_per_block`: 각 UNet 블록에 존재하는 ResNet 블록의 수입니다.\n\n추론에 모델을 사용하려면 이미지 모양의 랜덤 가우시안 노이즈를 만듭니다. 모델이 여러 개의 무작위 노이즈를 수신할 수 있으므로 `batch` 축, 입력 채널 수에 해당하는 `channel` 축, 이미지의 높이와 너비를 나타내는 `sample_size` 축이 있어야 합니다:\n\n```py\n>>> import torch\n\n>>> torch.manual_seed(0)\n\n>>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n>>> noisy_sample.shape\ntorch.Size([1, 3, 256, 256])\n```\n\n추론을 위해 모델에 노이즈가 있는 이미지와 `timestep`을 전달합니다. `timestep`은 입력 이미지의 노이즈 정도를 나타내며, 시작 부분에 더 많은 노이즈가 있고 끝 부분에 더 적은 노이즈가 있습니다. 이를 통해 모델이 diffusion 과정에서 시작과 끝 중 어느 쪽에 더 가까운지 판단할 수 있습니다. 모델 출력은 `sample` 속성으로 얻습니다:\n\n```py\n>>> with torch.no_grad():\n...     noisy_residual = model(sample=noisy_sample, timestep=2).sample\n```\n\n하지만 실제 예를 생성하려면 노이즈 제거 프로세스를 안내할 스케줄러가 필요합니다. 다음 섹션에서는 모델을 스케줄러와 결합하는 방법에 대해 알아봅니다.\n\n## 스케줄러\n\n스케줄러는 모델 출력(이 경우 `noisy_residual`)이 주어졌을 때 노이즈가 많은 샘플에서 노이즈가 적은 샘플로 전환하는 것을 관리합니다.\n\n> [!TIP]\n> 🧨 Diffusers는 Diffusion 시스템을 구축하기 위한 툴박스입니다. 
[`DiffusionPipeline`]을 사용하면 미리 만들어진 Diffusion 시스템을 편리하게 시작할 수 있지만, 모델과 스케줄러 구성 요소를 개별적으로 선택하여 사용자 지정 Diffusion 시스템을 구축할 수도 있습니다.\n\n훑어보기의 경우, [`~diffusers.ConfigMixin.from_config`] 메서드를 사용하여 [`DDPMScheduler`]를 인스턴스화합니다:\n\n```py\n>>> from diffusers import DDPMScheduler\n\n>>> scheduler = DDPMScheduler.from_config(repo_id)\n>>> scheduler\nDDPMScheduler {\n  \"_class_name\": \"DDPMScheduler\",\n  \"_diffusers_version\": \"0.13.1\",\n  \"beta_end\": 0.02,\n  \"beta_schedule\": \"linear\",\n  \"beta_start\": 0.0001,\n  \"clip_sample\": true,\n  \"clip_sample_range\": 1.0,\n  \"num_train_timesteps\": 1000,\n  \"prediction_type\": \"epsilon\",\n  \"trained_betas\": null,\n  \"variance_type\": \"fixed_small\"\n}\n```\n\n> [!TIP]\n> 💡 스케줄러가 구성에서 어떻게 인스턴스화되는지 주목하세요. 모델과 달리 스케줄러에는 학습 가능한 가중치가 없으며 매개변수도 없습니다!\n\n가장 중요한 매개변수는 다음과 같습니다:\n\n* `num_train_timesteps`: 노이즈 제거 프로세스의 길이, 즉 랜덤 가우스 노이즈를 데이터 샘플로 처리하는 데 필요한 타임스텝 수입니다.\n* `beta_schedule`: 추론 및 학습에 사용할 노이즈 스케줄 유형입니다.\n* `beta_start` 및 `beta_end`: 노이즈 스케줄의 시작 및 종료 노이즈 값입니다.\n\n노이즈가 약간 적은 이미지를 예측하려면 스케줄러의 [`~diffusers.DDPMScheduler.step`] 메서드에 모델 출력, `timestep`, 현재 `sample`을 전달하세요.\n\n```py\n>>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample\n>>> less_noisy_sample.shape\n```\n\n`less_noisy_sample`을 다음 `timestep`으로 넘기면 노이즈가 더 줄어듭니다! 이제 이 모든 것을 한데 모아 전체 노이즈 제거 과정을 시각화해 보겠습니다.\n\n먼저 노이즈 제거된 이미지를 후처리하여 `PIL.Image`로 표시하는 함수를 만듭니다:\n\n```py\n>>> import PIL.Image\n>>> import numpy as np\n\n\n>>> def display_sample(sample, i):\n...     image_processed = sample.cpu().permute(0, 2, 3, 1)\n...     image_processed = (image_processed + 1.0) * 127.5\n...     image_processed = image_processed.numpy().astype(np.uint8)\n\n...     image_pil = PIL.Image.fromarray(image_processed[0])\n...     display(f\"Image at step {i}\")\n...     
display(image_pil)\n```\n\n노이즈 제거 프로세스의 속도를 높이려면 입력과 모델을 GPU로 옮기세요:\n\n```py\n>>> model.to(\"cuda\")\n>>> noisy_sample = noisy_sample.to(\"cuda\")\n```\n\n이제 노이즈가 적은 샘플의 잔차를 예측하고 스케줄러로 노이즈가 적은 샘플을 계산하는 노이즈 제거 루프를 생성합니다:\n\n```py\n>>> import tqdm\n\n>>> sample = noisy_sample\n\n>>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):\n...     # 1. predict noise residual\n...     with torch.no_grad():\n...         residual = model(sample, t).sample\n\n...     # 2. compute less noisy image and set x_t -> x_t-1\n...     sample = scheduler.step(residual, t, sample).prev_sample\n\n...     # 3. optionally look at image\n...     if (i + 1) % 50 == 0:\n...         display_sample(sample, i + 1)\n```\n\n가만히 앉아서 노이즈만으로 고양이가 생성되는 것을 지켜보세요!😻\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/diffusion-quicktour.png\"/>\n</div>\n\n## 다음 단계\n\n이번 훑어보기에서 🧨 Diffusers로 멋진 이미지를 만들어 보셨기를 바랍니다! 다음 단계로 넘어가세요:\n\n* [training](./tutorials/basic_training) 튜토리얼에서 모델을 학습하거나 파인튜닝하여 나만의 이미지를 생성할 수 있습니다.\n* 다양한 사용 사례는 공식 및 커뮤니티 [학습 또는 파인튜닝 스크립트](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples) 예시를 참조하세요.\n* 스케줄러 로드, 액세스, 변경 및 비교에 대한 자세한 내용은 [다른 스케줄러 사용](./using-diffusers/schedulers) 가이드에서 확인하세요.\n* [Stable Diffusion](./stable_diffusion) 가이드에서 프롬프트 엔지니어링, 속도 및 메모리 최적화, 고품질 이미지 생성을 위한 팁과 요령을 살펴보세요.\n* [GPU에서 PyTorch 최적화](./optimization/fp16) 가이드와 [애플 실리콘(M1/M2)에서의 Stable Diffusion](./optimization/mps) 및 [ONNX 런타임](./optimization/onnx) 실행에 대한 추론 가이드를 통해 🧨 Diffusers의 속도를 높이는 방법을 더 자세히 알아보세요."
  },
  {
    "path": "docs/source/ko/stable_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 효과적이고 효율적인 Diffusion\n\n[[open-in-colab]]\n\n특정 스타일로 이미지를 생성하거나 원하는 내용을 포함하도록[`DiffusionPipeline`]을 설정하는 것은 까다로울 수 있습니다. 종종 만족스러운 이미지를 얻기까지 [`DiffusionPipeline`]을 여러 번 실행해야 하는 경우가 많습니다. 그러나 무에서 유를 창조하는 것은 특히 추론을 반복해서 실행하는 경우 계산 집약적인 프로세스입니다.\n\n그렇기 때문에 파이프라인에서 *계산*(속도) 및 *메모리*(GPU RAM) 효율성을 극대화하여 추론 주기 사이의 시간을 단축하여 더 빠르게 반복할 수 있도록 하는 것이 중요합니다.\n\n이 튜토리얼에서는 [`DiffusionPipeline`]을 사용하여 더 빠르고 효과적으로 생성하는 방법을 안내합니다.\n\n[`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 모델을 불러와서 시작합니다:\n\n```python\nfrom diffusers import DiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipeline = DiffusionPipeline.from_pretrained(model_id)\n```\n\n예제 프롬프트는 \"portrait of an old warrior chief\" 이지만, 자유롭게 자신만의 프롬프트를 사용해도 됩니다:\n\n```python\nprompt = \"portrait photo of a old warrior chief\"\n```\n\n## 속도\n\n> [!TIP]\n> 💡 GPU에 액세스할 수 없는 경우 다음과 같은 GPU 제공업체에서 무료로 사용할 수 있습니다!. 
[Colab](https://colab.research.google.com/)\n\n추론 속도를 높이는 가장 간단한 방법 중 하나는 Pytorch 모듈을 사용할 때와 같은 방식으로 GPU에 파이프라인을 배치하는 것입니다:\n\n```python\npipeline = pipeline.to(\"cuda\")\n```\n\n동일한 이미지를 사용하고 개선할 수 있는지 확인하려면 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html)를 사용하고 [재현성](./using-diffusers/reusing_seeds)에 대한 시드를 설정하세요:\n\n```python\nimport torch\n\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n```\n\n이제 이미지를 생성할 수 있습니다:\n\n```python\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_1.png\">\n</div>\n\n이 프로세스는 T4 GPU에서 약 30초가 소요되었습니다(할당된 GPU가 T4보다 나은 경우 더 빠를 수 있음). 기본적으로 [`DiffusionPipeline`]은 50개의 추론 단계에 대해 전체 `float32` 정밀도로 추론을 실행합니다. `float16`과 같은 더 낮은 정밀도로 전환하거나 추론 단계를 더 적게 실행하여 속도를 높일 수 있습니다.\n\n`float16`으로 모델을 로드하고 이미지를 생성해 보겠습니다:\n\n\n```python\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)\npipeline = pipeline.to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_2.png\">\n</div>\n\n이번에는 이미지를 생성하는 데 약 11초밖에 걸리지 않아 이전보다 3배 가까이 빨라졌습니다!\n\n> [!TIP]\n> 💡 파이프라인은 항상 `float16`에서 실행할 것을 강력히 권장하며, 지금까지 출력 품질이 저하되는 경우는 거의 없었습니다.\n\n또 다른 옵션은 추론 단계의 수를 줄이는 것입니다. 보다 효율적인 스케줄러를 선택하면 출력 품질 저하 없이 단계 수를 줄이는 데 도움이 될 수 있습니다. 
현재 모델과 호환되는 스케줄러는 `compatibles` 메서드를 호출하여 [`DiffusionPipeline`]에서 찾을 수 있습니다:\n\n```python\npipeline.scheduler.compatibles\n[\n    diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,\n    diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,\n    diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,\n    diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,\n    diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,\n    diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,\n    diffusers.schedulers.scheduling_ddpm.DDPMScheduler,\n    diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,\n    diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,\n    diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,\n    diffusers.schedulers.scheduling_pndm.PNDMScheduler,\n    diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,\n    diffusers.schedulers.scheduling_ddim.DDIMScheduler,\n]\n```\n\nStable Diffusion 모델은 일반적으로 약 50개의 추론 단계가 필요한 [`PNDMScheduler`]를 기본으로 사용하지만, [`DPMSolverMultistepScheduler`]와 같이 성능이 더 뛰어난 스케줄러는 약 20개 또는 25개의 추론 단계만 필요로 합니다. 새 스케줄러를 로드하려면 [`ConfigMixin.from_config`] 메서드를 사용합니다:\n\n```python\nfrom diffusers import DPMSolverMultistepScheduler\n\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n```\n\n`num_inference_steps`를 20으로 설정합니다:\n\n```python\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\nimage = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_3.png\">\n</div>\n\n추론시간을 4초로 단축할 수 있었습니다! 
⚡️\n\n## 메모리\n\n파이프라인 성능 향상의 또 다른 핵심은 메모리 사용량을 줄이는 것인데, 초당 생성되는 이미지 수를 최대화하려고 하는 경우가 많기 때문에 간접적으로 더 빠른 속도를 의미합니다. 한 번에 생성할 수 있는 이미지 수를 확인하는 가장 쉬운 방법은 `OutOfMemoryError`(OOM)이 발생할 때까지 다양한 배치 크기를 시도해 보는 것입니다.\n\n프롬프트 목록과 `Generators`에서 이미지 배치를 생성하는 함수를 만듭니다. 좋은 결과를 생성하는 경우 재사용할 수 있도록 각 `Generator`에 시드를 할당해야 합니다.\n\n```python\ndef get_inputs(batch_size=1):\n    generator = [torch.Generator(\"cuda\").manual_seed(i) for i in range(batch_size)]\n    prompts = batch_size * [prompt]\n    num_inference_steps = 20\n\n    return {\"prompt\": prompts, \"generator\": generator, \"num_inference_steps\": num_inference_steps}\n```\n\n또한 각 이미지 배치를 보여주는 기능이 필요합니다:\n\n```python\nfrom PIL import Image\n\n\ndef image_grid(imgs, rows=2, cols=2):\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n```\n\n`batch_size=4`부터 시작해 얼마나 많은 메모리를 소비했는지 확인합니다:\n\n```python\nimages = pipeline(**get_inputs(batch_size=4)).images\nimage_grid(images)\n```\n\nRAM이 더 많은 GPU가 아니라면 위의 코드에서 `OOM` 오류가 반환되었을 것입니다! 대부분의 메모리는 cross-attention 레이어가 차지합니다. 이 작업을 배치로 실행하는 대신 순차적으로 실행하면 상당한 양의 메모리를 절약할 수 있습니다. 파이프라인을 구성하여 [`~DiffusionPipeline.enable_attention_slicing`] 함수를 사용하기만 하면 됩니다:\n\n\n```python\npipeline.enable_attention_slicing()\n```\n\n이제 `batch_size`를 8로 늘려보세요!\n\n```python\nimages = pipeline(**get_inputs(batch_size=8)).images\nimage_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_5.png\">\n</div>\n\n이전에는 4개의 이미지를 배치로 생성할 수도 없었지만, 이제는 이미지당 약 3.5초 만에 8개의 이미지를 배치로 생성할 수 있습니다! 이는 아마도 품질 저하 없이 T4 GPU에서 가장 빠른 속도일 것입니다.\n\n## 품질\n\n지난 두 섹션에서는 `fp16`을 사용하여 파이프라인의 속도를 최적화하고, 더 성능이 좋은 스케줄러를 사용하여 추론 단계의 수를 줄이고, attention slicing을 활성화하여 메모리 소비를 줄이는 방법을 배웠습니다. 
이제 생성된 이미지의 품질을 개선하는 방법에 대해 집중적으로 알아보겠습니다.\n\n\n### 더 나은 체크포인트\n\n가장 확실한 단계는 더 나은 체크포인트를 사용하는 것입니다. Stable Diffusion 모델은 좋은 출발점이며, 공식 출시 이후 몇 가지 개선된 버전도 출시되었습니다. 하지만 최신 버전을 사용한다고 해서 자동으로 더 나은 결과를 얻을 수 있는 것은 아닙니다. 여전히 다양한 체크포인트를 직접 실험해보고, [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/) 사용 등 약간의 조사를 통해 최상의 결과를 얻어야 합니다.\n\n이 분야가 성장함에 따라 특정 스타일을 연출할 수 있도록 파인튜닝된 고품질 체크포인트가 점점 더 많아지고 있습니다. [Hub](https://huggingface.co/models?library=diffusers&sort=downloads)와 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery)를 둘러보고 관심 있는 것을 찾아보세요!\n\n\n### 더 나은 파이프라인 구성 요소\n\n현재 파이프라인 구성 요소를 최신 버전으로 교체해 볼 수도 있습니다. Stability AI의 최신 [autoencoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae)를 파이프라인에 로드하고 몇 가지 이미지를 생성해 보겠습니다:\n\n\n```python\nfrom diffusers import AutoencoderKL\n\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.vae = vae\nimages = pipeline(**get_inputs(batch_size=8)).images\nimage_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_6.png\">\n</div>\n\n### 더 나은 프롬프트 엔지니어링\n\n이미지를 생성하는 데 사용하는 텍스트 프롬프트는 *prompt engineering*이라는 용어가 있을 정도로 매우 중요합니다. 
프롬프트 엔지니어링 시 고려해야 할 몇 가지 사항은 다음과 같습니다:\n\n- 생성하려는 이미지 또는 유사한 이미지가 인터넷에 어떻게 저장되어 있는가?\n- 내가 원하는 스타일로 모델을 유도하기 위해 어떤 추가 세부 정보를 제공할 수 있는가?\n\n이를 염두에 두고 색상과 더 높은 품질의 디테일을 포함하도록 프롬프트를 개선해 봅시다:\n\n\n```python\nprompt += \", tribal panther make up, blue on red, side profile, looking away, serious eyes\"\nprompt += \" 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\"\n```\n\n새로운 프롬프트로 이미지 배치를 생성합니다:\n\n```python\nimages = pipeline(**get_inputs(batch_size=8)).images\nimage_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_7.png\">\n</div>\n\n꽤 인상적입니다! `1`의 시드를 가진 `Generator`에 해당하는 두 번째 이미지에 피사체의 나이에 대한 텍스트를 추가하여 조금 더 조정해 보겠습니다:\n\n```python\nprompts = [\n    \"portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n]\n\ngenerator = [torch.Generator(\"cuda\").manual_seed(1) for _ in range(len(prompts))]\nimages = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images\nimage_grid(images)\n```\n\n<div class=\"flex justify-center\">\n    <img 
src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_8.png\">\n</div>\n\n## 다음 단계\n\n이 튜토리얼에서는 계산 및 메모리 효율을 높이고 생성된 출력의 품질을 개선하기 위해 [`DiffusionPipeline`]을 최적화하는 방법을 배웠습니다. 파이프라인을 더 빠르게 만드는 데 관심이 있다면 다음 리소스를 살펴보세요:\n\n- [PyTorch 2.0](./optimization/torch2.0) 및 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html)이 어떻게 추론 속도를 5~300% 향상시킬 수 있는지 알아보세요. A100 GPU에서는 추론 속도가 최대 50%까지 빨라질 수 있습니다!\n- PyTorch 2를 사용할 수 없는 경우, [xFormers](./optimization/xformers)를 설치하는 것이 좋습니다. 메모리 효율적인 어텐션 메커니즘은 PyTorch 1.13.1과 함께 사용하면 속도가 빨라지고 메모리 소비가 줄어듭니다.\n- 모델 오프로딩과 같은 다른 최적화 기법은 [이 가이드](./optimization/fp16)에서 다루고 있습니다."
  },
  {
    "path": "docs/source/ko/training/adapt_a_model.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 새로운 작업에 대한 모델을 적용하기\n\n많은 diffusion 시스템은 같은 구성 요소들을 공유하므로 한 작업에 대해 사전학습된 모델을 완전히 다른 작업에 적용할 수 있습니다.\n\n이 인페인팅을 위한 가이드는 사전학습된 [`UNet2DConditionModel`]의 아키텍처를 초기화하고 수정하여 사전학습된 text-to-image 모델을 어떻게 인페인팅에 적용하는지를 알려줄 것입니다.\n\n## UNet2DConditionModel 파라미터 구성\n\n[`UNet2DConditionModel`]은 [input sample](https://huggingface.co/docs/diffusers/v0.16.0/en/api/models#diffusers.UNet2DConditionModel.in_channels)에서 4개의 채널을 기본적으로 허용합니다. 예를 들어,  [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)와 같은 사전학습된 text-to-image 모델을 불러오고 `in_channels`의 수를 확인합니다:\n\n```py\nfrom diffusers import StableDiffusionPipeline\n\npipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\npipeline.unet.config[\"in_channels\"]\n4\n```\n\n인페인팅은 입력 샘플에 9개의 채널이 필요합니다. 
[`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)와 같은 사전학습된 인페인팅 모델에서 이 값을 확인할 수 있습니다:\n\n```py\nfrom diffusers import StableDiffusionPipeline\n\npipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-inpainting\")\npipeline.unet.config[\"in_channels\"]\n9\n```\n\nText-to-image 모델을 인페인팅에 적용하려면 `in_channels` 수를 4에서 9로 변경해야 합니다.\n\n사전학습된 text-to-image 모델의 가중치로 [`UNet2DConditionModel`]을 초기화하고 `in_channels`를 9로 변경하세요. `in_channels`의 수를 변경하면 입력 가중치의 크기가 달라지므로, 크기 불일치 오류를 피하려면 `ignore_mismatched_sizes=True` 및 `low_cpu_mem_usage=False`를 설정해야 합니다.\n\n```py\nfrom diffusers import UNet2DConditionModel\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nunet = UNet2DConditionModel.from_pretrained(\n    model_id, subfolder=\"unet\", in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True\n)\n```\n\nText-to-image 모델의 다른 구성 요소는 체크포인트의 사전학습된 가중치로 초기화되지만, `unet`의 입력 채널 가중치(`conv_in.weight`)는 랜덤하게 초기화됩니다. 그대로 두면 모델이 노이즈만 반환하므로, 인페인팅을 위해 모델을 파인튜닝하는 것이 중요합니다.\n"
  },
  {
    "path": "docs/source/ko/training/controlnet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet\n\n[Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) (ControlNet)은 Lvmin Zhang과 Maneesh Agrawala에 의해 쓰여졌습니다.\n\n이 예시는 [원본 ControlNet 리포지토리에서 예시 학습하기](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md)에 기반합니다. ControlNet은 원들을 채우기 위해 [small synthetic dataset](https://huggingface.co/datasets/fusing/fill50k)을 사용해서 학습됩니다.\n\n## 의존성 설치하기\n\n아래의 스크립트를 실행하기 전에, 라이브러리의 학습 의존성을 설치해야 합니다.\n\n> [!WARNING]\n> 가장 최신 버전의 예시 스크립트를 성공적으로 실행하기 위해서는, 소스에서 설치하고 최신 버전의 설치를 유지하는 것을 강력하게 추천합니다. 
우리는 예시 스크립트들을 자주 업데이트하고 예시에 맞춘 특정한 요구사항을 설치합니다.\n\n위 사항을 만족시키기 위해서, 새로운 가상환경에서 다음 일련의 스텝을 실행하세요:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\n그 다음에는 [예시 폴더](https://github.com/huggingface/diffusers/tree/main/examples/controlnet)로 이동합니다.\n\n```bash\ncd examples/controlnet\n```\n\n이제 실행하세요:\n\n```bash\npip install -r requirements.txt\n```\n\n[🤗Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화합니다:\n\n```bash\naccelerate config\n```\n\n혹은 환경에 대한 질문에 답하지 않고 기본 🤗Accelerate 구성으로 초기화할 수 있습니다:\n\n```bash\naccelerate config default\n```\n\n혹은 여러분의 환경이 노트북처럼 대화형 셸을 지원하지 않는다면, 아래의 코드로 초기화할 수 있습니다:\n\n```python\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n자체 데이터셋을 사용하기 위해서는 [학습을 위한 데이터셋 생성하기](create_dataset) 가이드를 확인하세요.\n\n## 학습\n\n이 학습에 사용될 다음 이미지들을 다운로드하세요:\n\n```sh\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\n\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\n`MODEL_DIR` 환경 변수(Hub 모델 리포지토리 아이디 혹은 모델 가중치가 있는 디렉토리의 경로)를 지정하고 [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) 인자로 전달합니다.\n\n학습 스크립트는 여러분의 리포지토리에 `diffusion_pytorch_model.bin` 파일을 생성하고 저장합니다.\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral 
background\" \\\n --train_batch_size=4 \\\n --push_to_hub\n```\n\n이 기본적인 설정으로는 ~38 GB VRAM이 필요합니다.\n\n기본적으로 학습 스크립트는 결과를 텐서보드에 기록합니다. Weights & Biases를 사용하려면 `--report_to wandb`를 전달합니다.\n\n더 작은 batch(배치) 크기로 gradient accumulation(기울기 누적)을 하면 학습 요구사항을 ~20 GB VRAM으로 줄일 수 있습니다.\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --push_to_hub\n```\n\n## 여러 개의 GPU로 학습하기\n\n`accelerate`는 multi-GPU 학습을 매끄럽게 지원합니다. `accelerate`로 분산 학습을 실행하려면 [여기](https://huggingface.co/docs/accelerate/basic_tutorials/launch)의 설명을 확인하세요. 
아래는 예시 명령어입니다:\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=4 \\\n --mixed_precision=\"fp16\" \\\n --tracker_project_name=\"controlnet-demo\" \\\n --report_to=wandb \\\n  --push_to_hub\n```\n\n## 예시 결과\n\n#### 배치 사이즈 8로 300 스텝 이후:\n\n| |  |\n|-------------------|:-------------------------:|\n| | 푸른 배경과 빨간 원  |\n![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![푸른 배경과 빨간 원](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_300_steps.png) |\n| | 갈색 꽃 배경과 청록색 원 |\n![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![갈색 꽃 배경과 청록색 원](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_300_steps.png) |\n\n#### 배치 사이즈 8로 6000 스텝 이후:\n\n| |  |\n|-------------------|:-------------------------:|\n| | 푸른 배경과 빨간 원  |\n![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![푸른 배경과 빨간 원](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_6000_steps.png) |\n| | 갈색 꽃 배경과 청록색 
원 |\n![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![갈색 꽃 배경과 청록색 원](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_6000_steps.png) |\n\n## 16GB GPU에서 학습하기\n\n16GB GPU에서 학습하기 위해 다음의 최적화를 진행하세요:\n\n- Gradient checkpointing(기울기 체크포인팅)\n- bitsandbytes의 [8-bit optimizer](https://github.com/TimDettmers/bitsandbytes#requirements--installation) (설치되지 않았다면 링크의 설명서를 참고하세요)\n\n이제 학습 스크립트를 시작할 수 있습니다:\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --gradient_checkpointing \\\n --use_8bit_adam \\\n  --push_to_hub\n```\n\n## 12GB GPU에서 학습하기\n\n12GB GPU에서 실행하기 위해 다음의 최적화를 진행하세요:\n\n- Gradient checkpointing(기울기 체크포인팅)\n- bitsandbytes의 [8-bit optimizer](https://github.com/TimDettmers/bitsandbytes#requirements--installation) (설치되지 않았다면 링크의 설명서를 참고하세요)\n- [xFormers](https://huggingface.co/docs/diffusers/training/optimization/xformers) (설치되지 않았다면 링크의 설명서를 참고하세요)\n- 기울기를 `None`으로 설정\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" 
\\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --gradient_checkpointing \\\n --use_8bit_adam \\\n --enable_xformers_memory_efficient_attention \\\n --set_grads_to_none \\\n  --push_to_hub\n```\n\n`enable_xformers_memory_efficient_attention`을 사용하려면 `pip install xformers`로 `xformers`를 먼저 설치하세요.\n\n## 8GB GPU에서 학습하기\n\nControlNet에 대한 DeepSpeed 지원은 철저하게 테스트되지 않았습니다. 아래 설정은 메모리를 절약해 주지만,\n이 설정으로 학습이 성공적으로 완료되는지는 확인하지 못했습니다. 학습을 성공시키려면 설정을 변경해야 할 가능성이 높습니다.\n\n8GB GPU에서 실행하기 위해 다음의 최적화를 진행하세요:\n\n- Gradient checkpointing(기울기 체크포인팅)\n- bitsandbytes의 [8-bit optimizer](https://github.com/TimDettmers/bitsandbytes#requirements--installation) (설치되지 않았다면 링크의 설명서를 참고하세요)\n- [xFormers](https://huggingface.co/docs/diffusers/training/optimization/xformers) (설치되지 않았다면 링크의 설명서를 참고하세요)\n- 기울기를 `None`으로 설정\n- 파라미터와 optimizer를 오프로드하는 DeepSpeed stage 2\n- fp16 혼합 정밀도(precision)\n\n[DeepSpeed](https://www.deepspeed.ai/)는 텐서를 VRAM에서 CPU 또는 NVME로 오프로드할 수 있습니다.\n이를 위해서는 훨씬 더 많은 RAM(약 25 GB)이 필요합니다.\n\nDeepSpeed stage 2를 활성화하려면 `accelerate config`로 환경을 구성해야 합니다.\n\n구성(configuration) 파일은 이런 모습이어야 합니다:\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  gradient_accumulation_steps: 4\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\n```\n\n> [!TIP]\n> 더 많은 DeepSpeed 설정 옵션은 [문서](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)를 참고하세요.\n\n기본 Adam optimizer를 DeepSpeed의 Adam인\n`deepspeed.ops.adam.DeepSpeedCPUAdam`으로 바꾸면 상당한 속도 향상을 얻을 수 있지만,\nPyTorch와 같은 버전의 CUDA toolchain이 필요합니다. 
8-비트 optimizer는 현재 DeepSpeed와\n호환되지 않는 것 같습니다.\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --gradient_checkpointing \\\n --enable_xformers_memory_efficient_attention \\\n --set_grads_to_none \\\n --mixed_precision fp16 \\\n --push_to_hub\n```\n\n## 추론\n\n학습된 모델은 [`StableDiffusionControlNetPipeline`]으로 실행할 수 있습니다.\n`base_model_path`와 `controlnet_path`를 각각 학습 스크립트에 지정했던 `--pretrained_model_name_or_path`와\n`--output_dir` 값으로 설정하세요.\n\n```py\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler\nfrom diffusers.utils import load_image\nimport torch\n\nbase_model_path = \"path to model\"\ncontrolnet_path = \"path to controlnet\"\n\ncontrolnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)\npipe = StableDiffusionControlNetPipeline.from_pretrained(\n    base_model_path, controlnet=controlnet, torch_dtype=torch.float16\n)\n\n# 더 빠른 스케줄러와 메모리 최적화로 diffusion 프로세스 속도 올리기\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n# xformers가 설치되지 않으면 아래 줄을 삭제하기\npipe.enable_xformers_memory_efficient_attention()\n\npipe.enable_model_cpu_offload()\n\ncontrol_image = load_image(\"./conditioning_image_1.png\")\nprompt = \"pale golden rod circle with old lace background\"\n\n# 이미지 생성하기\ngenerator = torch.manual_seed(0)\nimage = pipe(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]\n\nimage.save(\"./output.png\")\n```\n"
  },
  {
    "path": "docs/source/ko/training/create_dataset.md",
    "content": "# 학습을 위한 데이터셋 만들기\n\n[Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) 에는 모델 교육을 위한 많은 데이터셋이 있지만,\n관심이 있거나 사용하고 싶은 데이터셋을 찾을 수 없는 경우 🤗 [Datasets](https://huggingface.co/docs/datasets) 라이브러리를 사용하여 데이터셋을 만들 수 있습니다.\n데이터셋 구조는 모델을 학습하려는 작업에 따라 달라집니다.\n가장 기본적인 데이터셋 구조는 unconditional 이미지 생성과 같은 작업을 위한 이미지 디렉토리입니다.\n또 다른 데이터셋 구조는 이미지 디렉토리와 text-to-image 생성과 같은 작업에 해당하는 텍스트 캡션이 포함된 텍스트 파일일 수 있습니다.\n\n이 가이드에는 파인 튜닝할 데이터셋을 만드는 두 가지 방법을 소개합니다:\n\n- 이미지 폴더를 `--train_data_dir` 인수에 제공합니다.\n- 데이터셋을 Hub에 업로드하고 데이터셋 리포지토리 id를 `--dataset_name` 인수에 전달합니다.\n\n> [!TIP]\n> 💡 학습에 사용할 이미지 데이터셋을 만드는 방법에 대한 자세한 내용은 [이미지 데이터셋 만들기](https://huggingface.co/docs/datasets/image_dataset) 가이드를 참고하세요.\n\n## 폴더 형태로 데이터셋 구축하기\n\nUnconditional 생성을 위해 이미지 폴더로 자신의 데이터셋을 구축할 수 있습니다.\n학습 스크립트는 🤗 Datasets의 [ImageFolder](https://huggingface.co/docs/datasets/en/image_dataset#imagefolder) 빌더를 사용하여\n자동으로 폴더에서 데이터셋을 구축합니다. 디렉토리 구조는 다음과 같아야 합니다 :\n\n```bash\ndata_dir/xxx.png\ndata_dir/xxy.png\ndata_dir/[...]/xxz.png\n```\n\n데이터셋 디렉터리의 경로를 `--train_data_dir` 인수로 전달한 다음 학습을 시작할 수 있습니다:\n\n```bash\naccelerate launch train_unconditional.py \\\n    # argument로 폴더 지정하기 \\\n    --train_data_dir <path-to-train-directory> \\\n    <other-arguments>\n```\n\n## Hub에 데이터 올리기\n\n> [!TIP]\n> 💡 데이터셋을 만들고 Hub에 업로드하는 것에 대한 자세한 내용은 [🤗 Datasets을 사용한 이미지 검색](https://huggingface.co/blog/image-search-datasets) 게시물을 참고하세요.\n\nPIL 인코딩된 이미지가 포함된 `이미지` 열을 생성하는 [이미지 폴더](https://huggingface.co/docs/datasets/image_load#imagefolder) 기능을 사용하여 데이터셋 생성을 시작합니다.\n\n`data_dir` 또는 `data_files` 매개 변수를 사용하여 데이터셋의 위치를 지정할 수 있습니다.\n`data_files` 매개변수는 특정 파일을 `train` 이나 `test` 로 분리한 데이터셋에 매핑하는 것을 지원합니다:\n\n```python\nfrom datasets import load_dataset\n\n# 예시 1: 로컬 폴더\ndataset = load_dataset(\"imagefolder\", data_dir=\"path_to_your_folder\")\n\n# 예시 2: 로컬 파일 (지원 포맷 : tar, gzip, zip, xz, rar, zstd)\ndataset = load_dataset(\"imagefolder\", data_files=\"path_to_zip_file\")\n\n# 예시 3: 원격 파일 (지원 포맷 
: tar, gzip, zip, xz, rar, zstd)\ndataset = load_dataset(\n    \"imagefolder\",\n    data_files=\"https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip\",\n)\n\n# 예시 4: 여러개로 분할\ndataset = load_dataset(\n    \"imagefolder\", data_files={\"train\": [\"path/to/file1\", \"path/to/file2\"], \"test\": [\"path/to/file3\", \"path/to/file4\"]}\n)\n```\n\n[push_to_hub](https://huggingface.co/docs/datasets/v2.13.1/en/package_reference/main_classes#datasets.Dataset.push_to_hub)을 사용해서 Hub에 데이터셋을 업로드합니다:\n\n```python\n# 터미널에서 hf auth login 커맨드를 이미 실행했다고 가정합니다\ndataset.push_to_hub(\"name_of_your_dataset\")\n\n# 개인 repo로 push 하고 싶다면, `private=True` 을 추가하세요:\ndataset.push_to_hub(\"name_of_your_dataset\", private=True)\n```\n\n이제 데이터셋 이름을 `--dataset_name` 인수에 전달하여 데이터셋을 학습에 사용할 수 있습니다:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image.py \\\n  --pretrained_model_name_or_path=\"stable-diffusion-v1-5/stable-diffusion-v1-5\" \\\n  --dataset_name=\"name_of_your_dataset\" \\\n  <other-arguments>\n```\n\n## 다음 단계\n\n데이터셋을 생성했으니 이제 학습 스크립트의 `train_data_dir` (데이터셋이 로컬이면) 혹은 `dataset_name` (Hub에 데이터셋을 올렸으면) 인수에 연결할 수 있습니다.\n\n다음 단계에서는 데이터셋을 사용하여 [unconditional 생성](https://huggingface.co/docs/diffusers/v0.18.2/en/training/unconditional_training) 또는 [텍스트-이미지 생성](https://huggingface.co/docs/diffusers/training/text2image)을 위한 모델을 학습시켜보세요!\n"
  },
  {
    "path": "docs/source/ko/training/custom_diffusion.md",
    "content": "<!--Copyright 2025 Custom Diffusion authors The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 커스텀 Diffusion 학습 예제\n\n[커스텀 Diffusion](https://huggingface.co/papers/2212.04488)은 피사체의 이미지 몇 장(4~5장)만 주어지면 Stable Diffusion처럼 text-to-image 모델을 커스터마이징하는 방법입니다.\n'train_custom_diffusion.py' 스크립트는 학습 과정을 구현하고 이를 Stable Diffusion에 맞게 조정하는 방법을 보여줍니다.\n\n이 교육 사례는 [Nupur Kumari](https://nupurkmr9.github.io/)가 제공하였습니다. (Custom Diffusion의 저자 중 한명).\n\n## 로컬에서 PyTorch로 실행하기\n\n### Dependencies 설치하기\n\n스크립트를 실행하기 전에 라이브러리의 학습 dependencies를 설치해야 합니다:\n\n**중요**\n\n예제 스크립트의 최신 버전을 성공적으로 실행하려면 **소스로부터 설치**하는 것을 매우 권장하며, 예제 스크립트를 자주 업데이트하는 만큼 일부 예제별 요구 사항을 설치하고 설치를 최신 상태로 유지하는 것이 좋습니다. 이를 위해 새 가상 환경에서 다음 단계를 실행하세요:\n\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\n[example folder](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion)로 cd하여 이동하세요.\n\n```\ncd examples/custom_diffusion\n```\n\n이제 실행\n\n```bash\npip install -r requirements.txt\npip install clip-retrieval\n```\n\n그리고 [🤗Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화:\n\n```bash\naccelerate config\n```\n\n또는 사용자 환경에 대한 질문에 답하지 않고 기본 가속 구성을 사용하려면 다음과 같이 하세요.\n\n```bash\naccelerate config default\n```\n\n또는 사용 중인 환경이 대화형 셸을 지원하지 않는 경우(예: jupyter notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n### 고양이 예제 😺\n\n이제 데이터셋을 가져옵니다. 
[여기](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip)에서 데이터셋을 다운로드하고 압축을 풉니다. 직접 데이터셋을 사용하려면 [학습용 데이터셋 생성하기](create_dataset) 가이드를 참고하세요.\n\n또한 `clip-retrieval`을 사용하여 200개의 실제 이미지를 수집하고, regularization으로서 이를 학습 데이터셋의 타겟 이미지와 결합합니다. 이렇게 하면 주어진 타겟 이미지에 대한 과적합을 방지할 수 있습니다. 다음 플래그를 사용하면 `prior_loss_weight=1.`로 `with_prior_preservation`, `real_prior` regularization을 활성화할 수 있습니다.\n`class_prompt`는 대상 이미지와 동일한 카테고리 이름이어야 합니다. 수집된 실제 이미지에는 `class_prompt`와 유사한 텍스트 캡션이 있습니다. 검색된 이미지는 `class_data_dir`에 저장됩니다. 생성된 이미지를 regularization으로 사용하려면 `real_prior`를 비활성화하면 됩니다. 실제 이미지를 수집하려면 훈련 전에 이 명령을 먼저 사용하십시오.\n\n```bash\npip install clip-retrieval\npython retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200\n```\n\n**___참고: [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 모델을 사용하는 경우 `resolution`을 768로 변경하세요.___**\n\n스크립트는 모델 체크포인트와 `pytorch_custom_diffusion_weights.bin` 파일을 생성하여 저장소에 저장합니다.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\nexport INSTANCE_DIR=\"./data/cat\"\n\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --class_data_dir=./real_reg/samples_cat/ \\\n  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n  --class_prompt=\"cat\" --num_class_images=200 \\\n  --instance_prompt=\"photo of a <new1> cat\"  \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=1e-5  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=250 \\\n  --scale_lr --hflip  \\\n  --modifier_token \"<new1>\" \\\n  --push_to_hub\n```\n\n**더 낮은 VRAM 요구 사항(GPU당 16GB)으로 더 빠르게 훈련하려면 `--enable_xformers_memory_efficient_attention`을 사용하세요. 
설치 방법은 [가이드](https://github.com/facebookresearch/xformers)를 따르세요.**\n\n가중치 및 편향(`wandb`)을 사용하여 실험을 추적하고 중간 결과를 저장하려면(강력히 권장합니다) 다음 단계를 따르세요:\n\n* `wandb` 설치: `pip install wandb`.\n* 로그인 : `wandb login`.\n* 그런 다음 트레이닝을 시작하는 동안 `validation_prompt`를 지정하고 `report_to`를 `wandb`로 설정합니다. 다음과 같은 관련 인수를 구성할 수도 있습니다:\n    * `num_validation_images`\n    * `validation_steps`\n\n```bash\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --class_data_dir=./real_reg/samples_cat/ \\\n  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n  --class_prompt=\"cat\" --num_class_images=200 \\\n  --instance_prompt=\"photo of a <new1> cat\"  \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=1e-5  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=250 \\\n  --scale_lr --hflip  \\\n  --modifier_token \"<new1>\" \\\n  --validation_prompt=\"<new1> cat sitting in a bucket\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub\n```\n\n다음은 [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau)의 예시이며, 여러 학습 세부 정보와 함께 중간 결과들을 확인할 수 있습니다.\n\n`--push_to_hub`를 지정하면 학습된 파라미터가 허깅 페이스 허브의 리포지토리에 푸시됩니다. 
다음은 [예제 리포지토리](https://huggingface.co/sayakpaul/custom-diffusion-cat)입니다.\n\n### 멀티 컨셉에 대한 학습 🐱🪵\n\n[this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)와 유사하게 각 컨셉에 대한 정보가 포함된 [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) 파일을 제공합니다.\n\n실제 이미지를 수집하려면 json 파일의 각 컨셉에 대해 이 명령을 실행합니다.\n\n```bash\npip install clip-retrieval\npython retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200\n```\n\n그럼 우리는 학습시킬 준비가 되었습니다!\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --output_dir=$OUTPUT_DIR \\\n  --concepts_list=./concept_list.json \\\n  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=1e-5  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --num_class_images=200 \\\n  --scale_lr --hflip  \\\n  --modifier_token \"<new1>+<new2>\" \\\n  --push_to_hub\n```\n\n다음은 [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg)의 예시이며, 다른 학습 세부 정보와 함께 중간 결과들을 확인할 수 있습니다.\n\n### 사람 얼굴에 대한 학습\n\n사람 얼굴에 대한 파인튜닝을 위해 다음과 같은 설정이 더 효과적이라는 것을 확인했습니다: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, `freeze_model=crossattn`을 최소 15~20개의 이미지로 설정합니다.\n\n실제 이미지를 수집하려면 훈련 전에 이 명령을 먼저 사용하십시오.\n\n```bash\npip install clip-retrieval\npython retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200\n```\n\n이제 학습을 시작하세요!\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\nexport INSTANCE_DIR=\"path-to-images\"\n\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  
--class_data_dir=./real_reg/samples_person/ \\\n  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n  --class_prompt=\"person\" --num_class_images=200 \\\n  --instance_prompt=\"photo of a <new1> person\"  \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=5e-6  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=1000 \\\n  --scale_lr --hflip --noaug \\\n  --freeze_model crossattn \\\n  --modifier_token \"<new1>\" \\\n  --enable_xformers_memory_efficient_attention \\\n  --push_to_hub\n```\n\n## 추론\n\n위 프롬프트를 사용하여 모델을 학습시킨 후에는 아래 프롬프트를 사용하여 추론을 실행할 수 있습니다. 프롬프트에 'modifier token'(예: 위 예제에서는 \\<new1\\>)을 반드시 포함해야 합니다.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16).to(\"cuda\")\npipe.unet.load_attn_procs(\"path-to-save-model\", weight_name=\"pytorch_custom_diffusion_weights.bin\")\npipe.load_textual_inversion(\"path-to-save-model\", weight_name=\"<new1>.bin\")\n\nimage = pipe(\n    \"<new1> cat sitting in a bucket\",\n    num_inference_steps=100,\n    guidance_scale=6.0,\n    eta=1.0,\n).images[0]\nimage.save(\"cat.png\")\n```\n\n허브 리포지토리에서 이러한 매개변수를 직접 로드할 수 있습니다:\n\n```python\nimport torch\nfrom huggingface_hub.repocard import RepoCard\nfrom diffusers import DiffusionPipeline\n\nmodel_id = \"sayakpaul/custom-diffusion-cat\"\ncard = RepoCard.load(model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\npipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(\"cuda\")\npipe.unet.load_attn_procs(model_id, weight_name=\"pytorch_custom_diffusion_weights.bin\")\npipe.load_textual_inversion(model_id, weight_name=\"<new1>.bin\")\n\nimage = pipe(\n    \"<new1> cat sitting in a bucket\",\n    num_inference_steps=100,\n    guidance_scale=6.0,\n    eta=1.0,\n).images[0]\nimage.save(\"cat.png\")\n```\n\n다음은 여러 컨셉으로 추론을 수행하는 예제입니다:\n\n```python\nimport torch\nfrom 
huggingface_hub.repocard import RepoCard\nfrom diffusers import DiffusionPipeline\n\nmodel_id = \"sayakpaul/custom-diffusion-cat-wooden-pot\"\ncard = RepoCard.load(model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\npipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(\"cuda\")\npipe.unet.load_attn_procs(model_id, weight_name=\"pytorch_custom_diffusion_weights.bin\")\npipe.load_textual_inversion(model_id, weight_name=\"<new1>.bin\")\npipe.load_textual_inversion(model_id, weight_name=\"<new2>.bin\")\n\nimage = pipe(\n    \"the <new1> cat sculpture in the style of a <new2> wooden pot\",\n    num_inference_steps=100,\n    guidance_scale=6.0,\n    eta=1.0,\n).images[0]\nimage.save(\"multi-subject.png\")\n```\n\n여기서 `cat`과 `wooden pot`은 학습된 각각의 컨셉을 가리킵니다.\n\n### 학습된 체크포인트에서 추론하기\n\n`--checkpointing_steps`  인수를 사용한 경우 학습 과정에서 저장된 전체 체크포인트 중 하나에서 추론을 수행할 수도 있습니다.\n\n## Grads를 None으로 설정\n\n더 많은 메모리를 절약하려면 스크립트에 `--set_grads_to_none` 인수를 전달하세요. 이렇게 하면 그래디언트(grads)가 0이 아닌 `None`으로 설정됩니다. 그러나 특정 동작이 변경되므로 문제가 발생하면 이 인수를 제거하세요.\n\n자세한 정보: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\n\n## 실험 결과\n\n실험에 대한 자세한 내용은 [당사 웹페이지](https://www.cs.cmu.edu/~custom-diffusion/)를 참조하세요."
  },
  {
    "path": "docs/source/ko/training/distributed_inference.md",
    "content": "# 여러 GPU를 사용한 분산 추론\n\n분산 설정에서는 여러 개의 프롬프트를 동시에 생성할 때 유용한 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) 또는 [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html)를 사용하여 여러 GPU에서 추론을 실행할 수 있습니다.\n\n이 가이드에서는 분산 추론을 위해 🤗 Accelerate와 PyTorch Distributed를 사용하는 방법을 보여드립니다.\n\n## 🤗 Accelerate\n\n🤗 [Accelerate](https://huggingface.co/docs/accelerate/index)는 분산 설정에서 추론을 쉽게 훈련하거나 실행할 수 있도록 설계된 라이브러리입니다. 분산 환경 설정 프로세스를 간소화하여 PyTorch 코드에 집중할 수 있도록 해줍니다.\n\n시작하려면 Python 파일을 생성하고 [`accelerate.PartialState`]를 초기화하여 분산 환경을 생성하면, 설정이 자동으로 감지되므로 `rank` 또는 `world_size`를 명시적으로 정의할 필요가 없습니다. ['DiffusionPipeline`]을 `distributed_state.device`로 이동하여 각 프로세스에 GPU를 할당합니다.\n\n이제 컨텍스트 관리자로 [`~accelerate.PartialState.split_between_processes`] 유틸리티를 사용하여 프로세스 수에 따라 프롬프트를 자동으로 분배합니다.\n\n\n```py\nfrom accelerate import PartialState\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\ndistributed_state = PartialState()\npipeline.to(distributed_state.device)\n\nwith distributed_state.split_between_processes([\"a dog\", \"a cat\"]) as prompt:\n    result = pipeline(prompt).images[0]\n    result.save(f\"result_{distributed_state.process_index}.png\")\n```\n\nUse the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:\n\n```bash\naccelerate launch run_distributed.py --num_processes=2\n```\n\n> [!TIP]\n> 자세한 내용은 [🤗 Accelerate를 사용한 분산 추론](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 가이드를 참조하세요.\n\n## Pytoerch 분산\n\nPyTorch는 데이터 병렬 처리를 가능하게 하는 [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)을 지원합니다.\n\n시작하려면 Python 파일을 생성하고 `torch.distributed` 및 `torch.multiprocessing`을 임포트하여 분산 프로세스 그룹을 설정하고 각 GPU에서 추론용 프로세스를 생성합니다. 
그리고 [`DiffusionPipeline`]도 초기화해야 합니다:\n\n```py\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom diffusers import DiffusionPipeline\n\nsd = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\n```\n\n다음으로 추론을 실행할 함수를 만들어야 합니다. [`init_process_group`]은 사용할 백엔드 유형, 현재 프로세스의 `rank`, 그리고 `world_size`(참여하는 프로세스 수)를 받아 분산 환경 생성을 처리합니다.\n\n2개의 GPU에서 추론을 병렬로 실행하는 경우 `world_size`는 2입니다.\n\n[`DiffusionPipeline`]을 `rank`로 이동하고 `get_rank`를 사용하여 각 프로세스에 GPU를 할당하면 각 프로세스가 다른 프롬프트를 처리합니다:\n\n```py\ndef run_inference(rank, world_size):\n    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n\n    sd.to(rank)\n\n    if torch.distributed.get_rank() == 0:\n        prompt = \"a dog\"\n    elif torch.distributed.get_rank() == 1:\n        prompt = \"a cat\"\n\n    image = sd(prompt).images[0]\n    # 프롬프트의 공백을 밑줄로 바꿔 파일 이름으로 사용합니다 (예: a_dog.png)\n    image.save(f\"./{prompt.replace(' ', '_')}.png\")\n```\n\n분산 추론을 실행하려면 [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn)을 호출하여 `world_size`에 정의된 GPU 수에 대해 `run_inference` 함수를 실행합니다:\n\n```py\ndef main():\n    world_size = 2\n    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)\n\n\nif __name__ == \"__main__\":\n    main()\n```\n\n추론 스크립트를 완료했으면 `--nproc_per_node` 인수를 사용하여 사용할 GPU 수를 지정하고 `torchrun`을 호출하여 스크립트를 실행합니다:\n\n```bash\ntorchrun run_distributed.py --nproc_per_node=2\n```"
  },
  {
    "path": "docs/source/ko/training/dreambooth.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DreamBooth\n\n[DreamBooth](https://huggingface.co/papers/2208.12242)는 한 주제에 대한 적은 이미지(3~5개)만으로도 stable diffusion과 같이 text-to-image 모델을 개인화할 수 있는 방법입니다. 이를 통해 모델은 다양한 장면, 포즈 및 장면(뷰)에서 피사체에 대해 맥락화(contextualized)된 이미지를 생성할 수 있습니다.\n\n![프로젝트 블로그에서의 DreamBooth 예시](https://dreambooth.github.io/DreamBooth_files/teaser_static.jpg)\n<small>에서의 Dreambooth 예시 <a href=\"https://dreambooth.github.io\">project's blog.</a></small>\n\n\n이 가이드는 다양한 GPU, Flax 사양에 대해 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) 모델로 DreamBooth를 파인튜닝하는 방법을 보여줍니다. 더 깊이 파고들어 작동 방식을 확인하는 데 관심이 있는 경우, 이 가이드에 사용된 DreamBooth의 모든 학습 스크립트를 [여기](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)에서 찾을 수 있습니다.\n\n스크립트를 실행하기 전에 라이브러리의 학습에 필요한 dependencies를 설치해야 합니다. 또한 `main` GitHub 브랜치에서 🧨 Diffusers를 설치하는 것이 좋습니다.\n\n```bash\npip install git+https://github.com/huggingface/diffusers\npip install -U -r diffusers/examples/dreambooth/requirements.txt\n```\n\nxFormers는 학습에 필요한 요구 사항은 아니지만, 가능하면 [설치](../optimization/xformers)하는 것이 좋습니다. 
학습 속도를 높이고 메모리 사용량을 줄일 수 있기 때문입니다.\n\n모든 dependencies을 설정한 후 다음을 사용하여 [🤗 Accelerate](https://github.com/huggingface/accelerate/) 환경을 다음과 같이 초기화합니다:\n\n```bash\naccelerate config\n```\n\n별도 설정 없이 기본 🤗 Accelerate 환경을 설치하려면 다음을 실행합니다:\n\n```bash\naccelerate config default\n```\n\n또는 현재 환경이 노트북과 같은 대화형 셸을 지원하지 않는 경우 다음을 사용할 수 있습니다:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n## 파인튜닝\n\n> [!WARNING]\n> DreamBooth 파인튜닝은 하이퍼파라미터에 매우 민감하고 과적합되기 쉽습니다. 적절한 하이퍼파라미터를 선택하는 데 도움이 되도록 다양한 권장 설정이 포함된 [심층 분석](https://huggingface.co/blog/dreambooth)을 살펴보는 것이 좋습니다.\n\n<frameworkcontent>\n<pt>\n[몇 장의 강아지 이미지들](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ)로 DreamBooth를 시도해봅시다.\n이를 다운로드해 디렉터리에 저장한 다음 `INSTANCE_DIR` 환경 변수를 해당 경로로 설정합니다:\n\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path_to_training_images\"\nexport OUTPUT_DIR=\"path_to_saved_model\"\n```\n\n그런 다음, 다음 명령을 사용하여 학습 스크립트를 실행할 수 있습니다 (전체 학습 스크립트는 [여기](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)에서 찾을 수 있습니다):\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=400\n```\n</pt>\n<jax>\n\nTPU에 액세스할 수 있거나 더 빠르게 훈련하고 싶다면 [Flax 학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_flax.py)를 사용해 볼 수 있습니다. 
Flax 학습 스크립트는 gradient checkpointing 또는 gradient accumulation을 지원하지 않으므로, 메모리가 30GB 이상인 GPU가 필요합니다.\n\n스크립트를 실행하기 전에 요구 사항이 설치되어 있는지 확인하십시오.\n\n```bash\npip install -U -r requirements.txt\n```\n\n그러면 다음 명령어로 학습 스크립트를 실행시킬 수 있습니다:\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\npython train_dreambooth_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=5e-6 \\\n  --max_train_steps=400\n```\n</jax>\n</frameworkcontent>\n\n### Prior-preserving(사전 보존) loss를 사용한 파인튜닝\n\n과적합과 language drift를 방지하기 위해 사전 보존이 사용됩니다(관심이 있는 경우 [논문](https://huggingface.co/papers/2208.12242)을 참조하세요).  사전 보존을 위해 동일한 클래스의 다른 이미지를 학습 프로세스의 일부로 사용합니다. 좋은 점은 Stable Diffusion 모델 자체를 사용하여 이러한 이미지를 생성할 수 있다는 것입니다! 학습 스크립트는 생성된 이미지를 우리가 지정한 로컬 경로에 저장합니다.\n\n저자들에 따르면 사전 보존을 위해 `num_epochs * num_samples`개의 이미지를 생성하는 것이 좋습니다. 
200-300개에서 대부분 잘 작동합니다.\n\n<frameworkcontent>\n<pt>\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path_to_training_images\"\nexport CLASS_DIR=\"path_to_class_images\"\nexport OUTPUT_DIR=\"path_to_saved_model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n</pt>\n<jax>\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\npython train_dreambooth_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=5e-6 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n</jax>\n</frameworkcontent>\n\n## 텍스트 인코더와 and UNet로 파인튜닝하기\n\n해당 스크립트를 사용하면 `unet`과 함께 `text_encoder`를 파인튜닝할 수 있습니다. 실험에서(자세한 내용은 [🧨 Diffusers를 사용해 DreamBooth로 Stable Diffusion 학습하기](https://huggingface.co/blog/dreambooth) 게시물을 확인하세요), 특히 얼굴 이미지를 생성할 때 훨씬 더 나은 결과를 얻을 수 있습니다.\n\n> [!WARNING]\n> 텍스트 인코더를 학습시키려면 추가 메모리가 필요해 16GB GPU로는 동작하지 않습니다. 
이 옵션을 사용하려면 최소 24GB VRAM이 필요합니다.\n\n`--train_text_encoder` 인수를 학습 스크립트에 전달하여 `text_encoder` 및 `unet`을 파인튜닝할 수 있습니다:\n\n<frameworkcontent>\n<pt>\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path_to_training_images\"\nexport CLASS_DIR=\"path_to_class_images\"\nexport OUTPUT_DIR=\"path_to_saved_model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --train_text_encoder \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --use_8bit_adam \\\n  --gradient_checkpointing \\\n  --learning_rate=2e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n</pt>\n<jax>\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\npython train_dreambooth_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --train_text_encoder \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=2e-6 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n</jax>\n</frameworkcontent>\n\n## LoRA로 파인튜닝하기\n\nDreamBooth에서 대규모 모델의 학습을 가속화하기 위한 파인튜닝 기술인 LoRA(Low-Rank Adaptation of Large Language Models)를 사용할 수 있습니다. 자세한 내용은 [LoRA 학습](training/lora#dreambooth) 가이드를 참조하세요.\n\n### 학습 중 체크포인트 저장하기\n\nDreambooth로 훈련하는 동안 과적합하기 쉬우므로, 때때로 학습 중에 정기적인 체크포인트를 저장하는 것이 유용합니다. 
중간 체크포인트 중 하나가 최종 모델보다 더 잘 작동할 수 있습니다! 체크포인트 저장 기능을 활성화하려면 학습 스크립트에 다음 인수를 전달해야 합니다:\n\n```bash\n  --checkpointing_steps=500\n```\n\n이렇게 하면 `output_dir`의 하위 폴더에 전체 학습 상태가 저장됩니다. 하위 폴더 이름은 접두사 `checkpoint-`로 시작하고 지금까지 수행된 step 수입니다. 예시로 `checkpoint-1500`은 1500 학습 step 후에 저장된 체크포인트입니다.\n\n#### 저장된 체크포인트에서 훈련 재개하기\n\n저장된 체크포인트에서 훈련을 재개하려면, `--resume_from_checkpoint` 인수를 전달한 다음 사용할 체크포인트의 이름을 지정하면 됩니다. 특수 문자열 `\"latest\"`를 사용하여 저장된 마지막 체크포인트(즉, step 수가 가장 많은 체크포인트)에서 재개할 수도 있습니다. 예를 들어 다음은 1500 step 후에 저장된 체크포인트에서부터 학습을 재개합니다:\n\n```bash\n  --resume_from_checkpoint=\"checkpoint-1500\"\n```\n\n원하는 경우 일부 하이퍼파라미터를 조정할 수 있습니다.\n\n#### 저장된 체크포인트를 사용하여 추론 수행하기\n\n저장된 체크포인트는 훈련 재개에 적합한 형식으로 저장됩니다. 여기에는 모델 가중치뿐만 아니라 옵티마이저, 데이터 로더 및 학습률의 상태도 포함됩니다.\n\n**`\"accelerate>=0.16.0\"`**이 설치된 경우 다음 코드를 사용하여 중간 체크포인트에서 추론을 실행합니다.\n\n```python\nfrom diffusers import DiffusionPipeline, UNet2DConditionModel\nfrom transformers import CLIPTextModel\nimport torch\n\n# 학습에 사용된 것과 동일한 인수(model, revision)로 파이프라인을 불러옵니다.\nmodel_id = \"CompVis/stable-diffusion-v1-4\"\n\nunet = UNet2DConditionModel.from_pretrained(\"/sddata/dreambooth/daruma-v2-1/checkpoint-100/unet\")\n\n# `args.train_text_encoder`로 학습한 경우면 텍스트 인코더를 꼭 불러오세요\ntext_encoder = CLIPTextModel.from_pretrained(\"/sddata/dreambooth/daruma-v2-1/checkpoint-100/text_encoder\")\n\npipeline = DiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder, dtype=torch.float16)\npipeline.to(\"cuda\")\n\n# 추론을 수행하거나 저장하거나, 허브에 푸시합니다.\npipeline.save_pretrained(\"dreambooth-pipeline\")\n```\n\nIf you have **`\"accelerate<0.16.0\"`** installed, you need to convert it to an inference pipeline first:\n\n```python\nfrom accelerate import Accelerator\nfrom diffusers import DiffusionPipeline\n\n# 학습에 사용된 것과 동일한 인수(model, revision)로 파이프라인을 불러옵니다.\nmodel_id = \"CompVis/stable-diffusion-v1-4\"\npipeline = DiffusionPipeline.from_pretrained(model_id)\n\naccelerator = Accelerator()\n\n# 초기 학습에 `--train_text_encoder`가 사용된 경우 text_encoder를 
사용합니다.\nunet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)\n\n# 체크포인트 경로로부터 상태를 복원합니다. 여기서는 절대 경로를 사용해야 합니다.\naccelerator.load_state(\"/sddata/dreambooth/daruma-v2-1/checkpoint-100\")\n\n# unwrapped 모델로 파이프라인을 다시 빌드합니다.(.unet and .text_encoder로의 할당도 작동해야 합니다)\npipeline = DiffusionPipeline.from_pretrained(\n    model_id,\n    unet=accelerator.unwrap_model(unet),\n    text_encoder=accelerator.unwrap_model(text_encoder),\n)\n\n# 추론을 수행하거나 저장하거나, 허브에 푸시합니다.\npipeline.save_pretrained(\"dreambooth-pipeline\")\n```\n\n## 각 GPU 용량에서의 최적화\n\n하드웨어에 따라 16GB에서 8GB까지 GPU에서 DreamBooth를 최적화하는 몇 가지 방법이 있습니다!\n\n### xFormers\n\n[xFormers](https://github.com/facebookresearch/xformers)는 Transformers를 최적화하기 위한 toolbox이며, 🧨 Diffusers에서 사용되는[memory-efficient attention](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)  메커니즘을 포함하고 있습니다. [xFormers를 설치](./optimization/xformers)한 다음 학습 스크립트에 다음 인수를 추가합니다:\n\n```bash\n  --enable_xformers_memory_efficient_attention\n```\n\nxFormers는 Flax에서 사용할 수 없습니다.\n\n### 그래디언트 없음으로 설정\n\n메모리 사용량을 줄일 수 있는 또 다른 방법은 [기울기 설정](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html)을 0 대신 `None`으로 하는 것입니다. 그러나 이로 인해 특정 동작이 변경될 수 있으므로 문제가 발생하면 이 인수를 제거해 보십시오. 학습 스크립트에 다음 인수를 추가하여 그래디언트를 `None`으로 설정합니다.\n\n```bash\n  --set_grads_to_none\n```\n\n### 16GB GPU\n\nGradient checkpointing과 [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)의 8비트 옵티마이저의 도움으로, 16GB GPU에서 dreambooth를 훈련할 수 있습니다. 
bitsandbytes가 설치되어 있는지 확인하세요:\n\n```bash\npip install bitsandbytes\n```\n\n그 다음, 학습 스크립트에 `--use_8bit_adam` 옵션을 명시합니다:\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path_to_training_images\"\nexport CLASS_DIR=\"path_to_class_images\"\nexport OUTPUT_DIR=\"path_to_saved_model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=2 --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n### 12GB GPU\n\n12GB GPU에서 DreamBooth를 실행하려면 gradient checkpointing, 8비트 옵티마이저, xFormers를 활성화하고 그래디언트를 `None`으로 설정해야 합니다.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --enable_xformers_memory_efficient_attention \\\n  --set_grads_to_none \\\n  --learning_rate=2e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n### 8GB GPU에서 학습하기\n\n8GB GPU에 
대해서는 [DeepSpeed](https://www.deepspeed.ai/)를 사용해 일부 텐서를 VRAM에서 CPU 또는 NVME로 오프로드하여 더 적은 GPU 메모리로 학습할 수도 있습니다.\n\n🤗 Accelerate 환경을 구성하려면 다음 명령을 실행하세요:\n\n```bash\naccelerate config\n```\n\n환경 구성 중에 DeepSpeed를 사용할 것을 확인하세요.\n그러면 DeepSpeed stage 2, fp16 혼합 정밀도를 결합하고 모델 매개변수와 옵티마이저 상태를 모두 CPU로 오프로드하면 8GB VRAM 미만에서 학습할 수 있습니다.\n단점은 더 많은 시스템 RAM(약 25GB)이 필요하다는 것입니다. 추가 구성 옵션은 [DeepSpeed 문서](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)를 참조하세요.\n\n또한 기본 Adam 옵티마이저를 DeepSpeed의 최적화된 Adam 버전으로 변경해야 합니다.\n이는 상당한 속도 향상을 위한 Adam인 [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu)입니다.\n`DeepSpeedCPUAdam`을 활성화하려면 시스템의 CUDA toolchain 버전이 PyTorch와 함께 설치된 것과 동일해야 합니다.\n\n8비트 옵티마이저는 현재 DeepSpeed와 호환되지 않는 것 같습니다.\n\n다음 명령으로 학습을 시작합니다:\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path_to_training_images\"\nexport CLASS_DIR=\"path_to_class_images\"\nexport OUTPUT_DIR=\"path_to_saved_model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --sample_batch_size=1 \\\n  --gradient_accumulation_steps=1 --gradient_checkpointing \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800 \\\n  --mixed_precision=fp16\n```\n\n## 추론\n\n모델을 학습한 후에는, 모델이 저장된 경로를 지정해 [`StableDiffusionPipeline`]로 추론을 수행할 수 있습니다. 
프롬프트에 학습에 사용된 특수 `식별자`(이전 예시의 `sks`)가 포함되어 있는지 확인하세요.\n\n**`\"accelerate>=0.16.0\"`**이 설치되어 있는 경우 다음 코드를 사용하여 추론을 실행할 수 있습니다:\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_id = \"path_to_saved_model\"\npipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A photo of sks dog in a bucket\"\nimage = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\n\nimage.save(\"dog-bucket.png\")\n```\n\n[저장된 학습 체크포인트](#inference-from-a-saved-checkpoint)에서도 추론을 실행할 수 있습니다.\n"
  },
  {
    "path": "docs/source/ko/training/instructpix2pix.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# InstructPix2Pix\n\n[InstructPix2Pix](https://huggingface.co/papers/2211.09800)는 text-conditioned diffusion 모델이 한 이미지에 편집을 따를 수 있도록 파인튜닝하는 방법입니다. 이 방법을 사용하여 파인튜닝된 모델은 다음을 입력으로 사용합니다:\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png\" alt=\"instructpix2pix-inputs\" width=600/>\n</p>\n\n출력은 입력 이미지에 편집 지시가 반영된 \"수정된\" 이미지입니다:\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/output-gs%407-igs%401-steps%4050.png\" alt=\"instructpix2pix-output\" width=600/>\n</p>\n\n`train_instruct_pix2pix.py` 스크립트([여기](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py)에서 찾을 수 있습니다.)는 학습 절차를 설명하고 Stable Diffusion에 적용할 수 있는 방법을 보여줍니다.\n\n\n*** `train_instruct_pix2pix.py`는 [원래 구현](https://github.com/timothybrooks/instruct-pix2pix)에 충실하면서 InstructPix2Pix 학습 절차를 구현하고 있지만, [소규모 데이터셋](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples)에서만 테스트를 했습니다. 이는 최종 결과에 영향을 끼칠 수 있습니다. 더 나은 결과를 위해, 더 큰 데이터셋에서 더 길게 학습하는 것을 권장합니다. 
[여기](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered)에서 InstructPix2Pix 학습을 위해 큰 데이터셋을 찾을 수 있습니다.\n***\n\n## PyTorch로 로컬에서 실행하기\n\n### 종속성(dependencies) 설치하기\n\n이 스크립트를 실행하기 전에, 라이브러리의 학습 종속성을 설치하세요:\n\n**중요**\n\n최신 버전의 예제 스크립트를 성공적으로 실행하기 위해, **원본으로부터 설치**하는 것과 예제 스크립트를 자주 업데이트하고 예제별 요구사항을 설치하기 때문에 최신 상태로 유지하는 것을 권장합니다. 이를 위해, 새로운 가상 환경에서 다음 스텝을 실행하세요:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\ncd 명령어로 예제 폴더로 이동하세요.\n```bash\ncd examples/instruct_pix2pix\n```\n\n이제 실행하세요.\n```bash\npip install -r requirements.txt\n```\n\n그리고 [🤗Accelerate](https://github.com/huggingface/accelerate/) 환경에서 초기화하세요:\n\n```bash\naccelerate config\n```\n\n혹은 환경에 대한 질문 없이 기본적인 accelerate 구성을 사용하려면 다음을 실행하세요.\n\n```bash\naccelerate config default\n```\n\n혹은 사용 중인 환경이 notebook과 같은 대화형 쉘은 지원하지 않는 경우는 다음 절차를 따라주세요.\n\n```python\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n### 예시\n\n이전에 언급했듯이, 학습을 위해 [작은 데이터셋](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples)을 사용할 것입니다. 그 데이터셋은 InstructPix2Pix 논문에서 사용된 [원래의 데이터셋](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered)보다 작은 버전입니다. 자신의 데이터셋을 사용하기 위해, [학습을 위한 데이터셋 만들기](create_dataset) 가이드를 참고하세요.\n\n`MODEL_NAME` 환경 변수(허브 모델 레포지토리 또는 모델 가중치가 포함된 폴더 경로)를 지정하고 [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) 인수에 전달합니다. `DATASET_ID`에 데이터셋 이름을 지정해야 합니다:\n\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport DATASET_ID=\"fusing/instructpix2pix-1000-samples\"\n```\n\n지금, 학습을 실행할 수 있습니다. 
스크립트는 레포지토리의 하위 폴더의 모든 구성요소(`feature_extractor`, `scheduler`, `text_encoder`, `unet` 등)를 저장합니다.\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" train_instruct_pix2pix.py \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --dataset_name=$DATASET_ID \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=256 --random_flip \\\n    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --mixed_precision=fp16 \\\n    --seed=42 \\\n    --push_to_hub\n```\n\n\n추가적으로, 가중치와 바이어스를 학습 과정에 모니터링하여 검증 추론을 수행하는 것을 지원합니다. `report_to=\"wandb\"`와 이 기능을 사용할 수 있습니다:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" train_instruct_pix2pix.py \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --dataset_name=$DATASET_ID \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=256 --random_flip \\\n    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --mixed_precision=fp16 \\\n    --val_image_url=\"https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png\" \\\n    --validation_prompt=\"make the mountains snowy\" \\\n    --seed=42 \\\n    --report_to=wandb \\\n    --push_to_hub\n ```\n\n모델 디버깅에 유용한 이 평가 방법 권장합니다. 이를 사용하기 위해 `wandb`를 설치하는 것을 주목해주세요. `pip install wandb`로 실행해 `wandb`를 설치할 수 있습니다.\n\n[여기](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), 몇 가지 평가 방법과 학습 파라미터를 포함하는 예시를 볼 수 있습니다.\n\n ***참고: 원본 논문에서, 저자들은 256x256 이미지 해상도로 학습한 모델로 512x512와 같은 더 큰 해상도로 잘 일반화되는 것을 볼 수 있었습니다. 
이는 학습에 사용한 큰 데이터셋을 사용했기 때문입니다.***\n\n ## 다수의 GPU로 학습하기\n\n`accelerate`는 원활한 다수의 GPU로 학습을 가능하게 합니다. `accelerate`로 분산 학습을 실행하는 [여기](https://huggingface.co/docs/accelerate/basic_tutorials/launch) 설명을 따라 해 주시기 바랍니다. 예시의 명령어 입니다:\n\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu train_instruct_pix2pix.py \\\n --pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5 \\\n --dataset_name=sayakpaul/instructpix2pix-1000-samples \\\n --use_ema \\\n --enable_xformers_memory_efficient_attention \\\n --resolution=512 --random_flip \\\n --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n --max_train_steps=15000 \\\n --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n --learning_rate=5e-05 --lr_warmup_steps=0 \\\n --conditioning_dropout_prob=0.05 \\\n --mixed_precision=fp16 \\\n --seed=42 \\\n --push_to_hub\n```\n\n ## 추론하기\n\n일단 학습이 완료되면, 추론 할 수 있습니다:\n\n ```python\nimport PIL\nimport requests\nimport torch\nfrom diffusers import StableDiffusionInstructPix2PixPipeline\n\nmodel_id = \"your_model_id\"  # <- 이를 수정하세요.\npipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n\nurl = \"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png\"\n\n\ndef download_image(url):\n    image = PIL.Image.open(requests.get(url, stream=True).raw)\n    image = PIL.ImageOps.exif_transpose(image)\n    image = image.convert(\"RGB\")\n    return image\n\n\nimage = download_image(url)\nprompt = \"wipe out the lake\"\nnum_inference_steps = 20\nimage_guidance_scale = 1.5\nguidance_scale = 10\n\nedited_image = pipe(\n    prompt,\n    image=image,\n    num_inference_steps=num_inference_steps,\n    image_guidance_scale=image_guidance_scale,\n    guidance_scale=guidance_scale,\n    generator=generator,\n).images[0]\nedited_image.save(\"edited_image.png\")\n```\n\n학습 스크립트를 사용해 얻은 예시의 모델 
레포지토리는 [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix)에서 확인할 수 있습니다.\n\n속도와 품질 사이의 균형을 조절하려면 다음 세 가지 파라미터를 사용하는 것이 좋습니다:\n\n* `num_inference_steps`\n* `image_guidance_scale`\n* `guidance_scale`\n\n특히, `image_guidance_scale`와 `guidance_scale`는 생성된(\"수정된\") 이미지에 큰 영향을 미칠 수 있습니다. ([여기](https://twitter.com/RisingSayak/status/1628392199196151808?s=20)의 예시를 참고해주세요.)\n\n\nInstructPix2Pix 학습 방법을 활용하는 흥미로운 방법을 찾고 있다면, 블로그 게시물 [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd)을 확인해주세요."
  },
  {
    "path": "docs/source/ko/training/lora.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Low-Rank Adaptation of Large Language Models (LoRA)\n\n[[open-in-colab]]\n\n> [!WARNING]\n> 현재 LoRA는 [`UNet2DConditionalModel`]의 어텐션 레이어에서만 지원됩니다.\n\n[LoRA(Low-Rank Adaptation of Large Language Models)](https://huggingface.co/papers/2106.09685)는 메모리를 적게 사용하면서 대규모 모델의 학습을 가속화하는 학습 방법입니다. 이는 rank-decomposition weight 행렬 쌍(**업데이트 행렬**이라고 함)을 추가하고 새로 추가된 가중치**만** 학습합니다. 여기에는 몇 가지 장점이 있습니다.\n\n- 이전에 미리 학습된 가중치는 고정된 상태로 유지되므로 모델이 [치명적인 망각](https://www.pnas.org/doi/10.1073/pnas.1611835114) 경향이 없습니다.\n- Rank-decomposition 행렬은 원래 모델보다 파라메터 수가 훨씬 적으므로 학습된 LoRA 가중치를 쉽게 끼워넣을 수 있습니다.\n- LoRA 매트릭스는 일반적으로 원본 모델의 어텐션 레이어에 추가됩니다. 🧨 Diffusers는 [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] 메서드를 제공하여 LoRA 가중치를 모델의 어텐션 레이어로 불러옵니다. `scale` 매개변수를 통해 모델이 새로운 학습 이미지에 맞게 조정되는 범위를 제어할 수 있습니다.\n- 메모리 효율성이 향상되어 Tesla T4, RTX 3080 또는 RTX 2080 Ti와 같은 소비자용 GPU에서 파인튜닝을 실행할 수 있습니다! T4와 같은 GPU는 무료이며 Kaggle 또는 Google Colab 노트북에서 쉽게 액세스할 수 있습니다.\n\n\n> [!TIP]\n> 💡 LoRA는 어텐션 레이어에만 한정되지는 않습니다. 저자는 언어 모델의 어텐션 레이어를 수정하는 것이 매우 효율적으로 죻은 성능을 얻기에 충분하다는 것을 발견했습니다. 이것이 LoRA 가중치를 모델의 어텐션 레이어에 추가하는 것이 일반적인 이유입니다. 
LoRA 작동 방식에 대한 자세한 내용은 [Using LoRA for effective Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) 블로그를 확인하세요!\n\n[cloneofsimo](https://github.com/cloneofsimo)는 인기 있는 [lora](https://github.com/cloneofsimo/lora) GitHub 리포지토리에서 Stable Diffusion을 위한 LoRA 학습을 최초로 시도했습니다. 🧨 Diffusers는 [text-to-image 생성](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) 및 [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)을 지원합니다. 이 가이드는 두 가지를 모두 수행하는 방법을 보여줍니다.\n\n모델을 저장하거나 커뮤니티와 공유하려면 Hugging Face 계정에 로그인하세요(아직 계정이 없는 경우 [생성](https://huggingface.co/join)하세요):\n\n```bash\nhf auth login\n```\n\n## Text-to-image\n\n수십억 개의 파라메터들이 있는 Stable Diffusion과 같은 모델을 파인튜닝하는 것은 느리고 어려울 수 있습니다. LoRA를 사용하면 diffusion 모델을 파인튜닝하는 것이 훨씬 쉽고 빠릅니다. 8비트 옵티마이저와 같은 트릭에 의존하지 않고도 11GB의 GPU RAM으로 하드웨어에서 실행할 수 있습니다.\n\n\n### 학습[[dreambooth-training]]\n\n[Naruto BLIP 캡션](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) 데이터셋으로 [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)를 파인튜닝해 나만의 포켓몬을 생성해 보겠습니다.\n\n시작하려면 `MODEL_NAME` 및 `DATASET_NAME` 환경 변수가 설정되어 있는지 확인하십시오. 
`OUTPUT_DIR` 및 `HUB_MODEL_ID` 변수는 선택 사항이며 허브에서 모델을 저장할 위치를 지정합니다.\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"/sddata/finetune/lora/naruto\"\nexport HUB_MODEL_ID=\"naruto-lora\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n```\n\n학습을 시작하기 전에 알아야 할 몇 가지 플래그가 있습니다.\n\n* `--push_to_hub`를 명시하면 학습된 LoRA 임베딩을 허브에 저장합니다.\n* `--report_to=wandb`는 학습 결과를 가중치 및 편향 대시보드에 보고하고 기록합니다(예를 들어, 이 [보고서](https://wandb.ai/pcuenq/text2image-fine-tune/run/b4k1w0tn?workspace=user-pcuenq)를 참조하세요).\n* `--learning_rate=1e-04`, 일반적으로 LoRA에서 사용하는 것보다 더 높은 학습률을 사용할 수 있습니다.\n\n이제 학습을 시작할 준비가 되었습니다 (전체 학습 스크립트는 [여기](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)에서 찾을 수 있습니다).\n\n```bash\naccelerate launch train_dreambooth_lora.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --checkpointing_steps=100 \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=50 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n### 추론[[dreambooth-inference]]\n\n이제 [`StableDiffusionPipeline`]에서 기본 모델을 불러와 추론을 위해 모델을 사용할 수 있습니다:\n\n```py\n>>> import torch\n>>> from diffusers import StableDiffusionPipeline\n\n>>> model_base = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n\n>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)\n```\n\n*기본 모델의 가중치 위에* 파인튜닝된 DreamBooth 모델에서 LoRA 가중치를 불러온 다음, 더 빠른 추론을 위해 파이프라인을 GPU로 이동합니다. 
LoRA 가중치를 프리징된 사전 훈련된 모델 가중치와 병합할 때, 선택적으로 'scale' 매개변수로 어느 정도의 가중치를 병합할 지 조절할 수 있습니다:\n\n> [!TIP]\n> 💡 `0`의 `scale` 값은 LoRA 가중치를 사용하지 않아 원래 모델의 가중치만 사용한 것과 같고, `1`의 `scale` 값은 파인튜닝된 LoRA 가중치만 사용함을 의미합니다. 0과 1 사이의 값들은 두 결과들 사이로 보간됩니다.\n\n```py\n>>> pipe.unet.load_attn_procs(model_path)\n>>> pipe.to(\"cuda\")\n# LoRA 파인튜닝된 모델의 가중치 절반과 기본 모델의 가중치 절반 사용\n\n>>> image = pipe(\n...     \"A picture of a sks dog in a bucket.\",\n...     num_inference_steps=25,\n...     guidance_scale=7.5,\n...     cross_attention_kwargs={\"scale\": 0.5},\n... ).images[0]\n# 완전히 파인튜닝된 LoRA 모델의 가중치 사용\n\n>>> image = pipe(\"A picture of a sks dog in a bucket.\", num_inference_steps=25, guidance_scale=7.5).images[0]\n>>> image.save(\"bucket-dog.png\")\n```"
  },
  {
    "path": "docs/source/ko/training/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 🧨 Diffusers 학습 예시\n\n이번 챕터에서는 다양한 유즈케이스들에 대한 예제 코드들을 통해 어떻게하면 효과적으로 `diffusers` 라이브러리를 사용할 수 있을까에 대해 알아보도록 하겠습니다.\n\n**Note**: 혹시 오피셜한 예시코드를 찾고 있다면, [여기](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)를 참고해보세요!\n\n여기서 다룰 예시들은 다음을 지향합니다.\n\n- **손쉬운 디펜던시 설치** (Self-contained) : 여기서 사용될 예시 코드들의 디펜던시 패키지들은 전부 `pip install` 명령어를 통해 설치 가능한 패키지들입니다. 또한 친절하게 `requirements.txt` 파일에 해당 패키지들이 명시되어 있어, `pip install -r requirements.txt`로 간편하게 해당 디펜던시들을 설치할 수 있습니다. 예시: [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt)\n- **손쉬운 수정** (Easy-to-tweak) : 저희는 가능하면 많은 유즈 케이스들을 제공하고자 합니다. 하지만 예시는 결국 그저 예시라는 점들 기억해주세요. 여기서 제공되는 예시코드들을 그저 단순히 복사-붙혀넣기하는 식으로는 여러분이 마주한 문제들을 손쉽게 해결할 순 없을 것입니다. 다시 말해 어느 정도는 여러분의 상황과 니즈에 맞춰 코드를 일정 부분 고쳐나가야 할 것입니다. 따라서 대부분의 학습 예시들은 데이터의 전처리 과정과 학습 과정에 대한 코드들을 함께 제공함으로써, 사용자가 니즈에 맞게 손쉬운 수정할 수 있도록 돕고 있습니다.\n- **입문자 친화적인** (Beginner-friendly) : 이번 챕터는 diffusion 모델과 `diffusers` 라이브러리에 대한 전반적인 이해를 돕기 위해 작성되었습니다. 따라서 diffusion 모델에 대한 최신 SOTA (state-of-the-art) 방법론들 가운데서도, 입문자에게는 많이 어려울 수 있다고 판단되면, 해당 방법론들은 여기서 다루지 않으려고 합니다.\n- **하나의 태스크만 포함할 것**(One-purpose-only): 여기서 다룰 예시들은 하나의 태스크만 포함하고 있어야 합니다. 
물론 이미지 초해상화(super-resolution)와 이미지 보정(modification)과 같은 유사한 모델링 프로세스를 갖는 태스크들이 존재하겠지만, 하나의 예제에 하나의 태스크만을 담는 것이 더 이해하기 용이하다고 판단했기 때문입니다.\n\n\n\n저희는 diffusion 모델의 대표적인 태스크들을 다루는 공식 예제를 제공하고 있습니다. *공식* 예제는 현재 진행형으로 `diffusers` 관리자들(maintainers)에 의해 관리되고 있습니다. 또한 저희는 앞서 정의한 저희의 철학을 엄격하게 따르고자 노력하고 있습니다. 혹시 여러분께서 이러한 예시가 반드시 필요하다고 생각되신다면, 언제든지 [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) 혹은 직접 [Pull Request](https://github.com/huggingface/diffusers/compare)를 주시기 바랍니다. 저희는 언제나 환영입니다!\n\n학습 예시들은 다양한 태스크들에 대해 diffusion 모델을 사전학습(pretrain)하거나 파인튜닝(fine-tuning)하는 법을 보여줍니다. 현재 다음과 같은 예제들을 지원하고 있습니다.\n\n- [Unconditional Training](./unconditional_training)\n- [Text-to-Image Training](./text2image)\n- [Text Inversion](./text_inversion)\n- [Dreambooth](./dreambooth)\n\nmemory-efficient attention 연산을 수행하기 위해, 가능하면 [xFormers](../optimization/xformers)를 설치해주시기 바랍니다. 이를 통해 학습 속도를 늘리고 메모리에 대한 부담을 줄일 수 있습니다.\n\n| Task | 🤗 Accelerate | 🤗 Datasets | Colab\n|---|---|:---:|:---:|\n| [**Unconditional Image Generation**](./unconditional_training) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)\n| [**Text-to-Image fine-tuning**](./text2image) | ✅ | ✅ |\n| [**Textual Inversion**](./text_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)\n| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)\n| [**Training with LoRA**](./lora) | ✅ | - | - |\n| [**ControlNet**](./controlnet) | ✅ | ✅ | - |\n| [**InstructPix2Pix**](./instructpix2pix) | ✅ 
| ✅ | - |\n| [**Custom Diffusion**](./custom_diffusion) | ✅ | ✅ | - |\n\n\n## 커뮤니티\n\n공식 예제 외에도 **커뮤니티 예제** 역시 제공하고 있습니다. 해당 예제들은 우리의 커뮤니티에 의해 관리됩니다. 커뮤니티 예제는 학습 예시나 추론 파이프라인으로 구성될 수 있습니다. 이러한 커뮤니티 예시들의 경우, 앞서 정의했던 철학들을 좀 더 관대하게 적용하고 있습니다. 또한 이러한 커뮤니티 예시들의 경우, 모든 이슈들에 대한 유지보수를 보장할 수는 없습니다.\n\n유용하긴 하지만, 아직은 대중적이지 못하거나 저희의 철학에 부합하지 않는 예제들은 [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) 폴더에 담기게 됩니다.\n\n**Note**: 커뮤니티 예제는 `diffusers`에 기여(contribution)를 희망하는 분들에게 [아주 좋은 기여 수단](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)이 될 수 있습니다.\n\n## 주목할 사항들\n\n최신 버전의 예시 코드들의 성공적인 구동을 보장하기 위해서는, 반드시 **소스코드를 통해 `diffusers`를 설치해야 하며,** 해당 예시 코드들이 요구하는 디펜던시들 역시 설치해야 합니다. 이를 위해 새로운 가상 환경을 구축하고 다음의 명령어를 실행해야 합니다.\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n그 다음 `cd` 명령어를 통해 해당 예제 디렉토리에 접근해서 다음 명령어를 실행하면 됩니다.\n\n```bash\npip install -r requirements.txt\n```"
  },
  {
    "path": "docs/source/ko/training/text2image.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n\n# Text-to-image\n\n> [!WARNING]\n> text-to-image 파인튜닝 스크립트는 experimental 상태입니다. 과적합하기 쉽고 치명적인 망각과 같은 문제에 부딪히기 쉽습니다. 자체 데이터셋에서 최상의 결과를 얻으려면 다양한 하이퍼파라미터를 탐색하는 것이 좋습니다.\n\nStable Diffusion과 같은 text-to-image 모델은 텍스트 프롬프트에서 이미지를 생성합니다. 이 가이드는 PyTorch 및 Flax를 사용하여 자체 데이터셋에서 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) 모델로 파인튜닝하는 방법을 보여줍니다. 이 가이드에 사용된 text-to-image 파인튜닝을 위한 모든 학습 스크립트에 관심이 있는 경우 이 [리포지토리](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image)에서 자세히 찾을 수 있습니다.\n\n스크립트를 실행하기 전에, 라이브러리의 학습 dependency들을 설치해야 합니다:\n\n```bash\npip install git+https://github.com/huggingface/diffusers.git\npip install -U -r requirements.txt\n```\n\n그리고 [🤗Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화합니다:\n\n```bash\naccelerate config\n```\n\n리포지토리를 이미 복제한 경우, 이 단계를 수행할 필요가 없습니다. 대신, 로컬 체크아웃 경로를 학습 스크립트에 명시할 수 있으며 거기에서 로드됩니다.\n\n### 하드웨어 요구 사항\n\n`gradient_checkpointing` 및 `mixed_precision`을 사용하면 단일 24GB GPU에서 모델을 파인튜닝할 수 있습니다. 더 높은 `batch_size`와 더 빠른 훈련을 위해서는 GPU 메모리가 30GB 이상인 GPU를 사용하는 것이 좋습니다. TPU 또는 GPU에서 파인튜닝을 위해 JAX나 Flax를 사용할 수도 있습니다. 자세한 내용은 [아래](#flax-jax-finetuning)를 참조하세요.\n\nxFormers로 memory efficient attention을 활성화하여 메모리 사용량 훨씬 더 줄일 수 있습니다. 
[xFormers가 설치](./optimization/xformers)되어 있는지 확인하고 `--enable_xformers_memory_efficient_attention`를 학습 스크립트에 명시합니다.\n\nxFormers는 Flax에 사용할 수 없습니다.\n\n## Hub에 모델 업로드하기\n\n학습 스크립트에 다음 인수를 추가하여 모델을 허브에 저장합니다:\n\n```bash\n  --push_to_hub\n```\n\n\n## 체크포인트 저장 및 불러오기\n\n학습 중 발생할 수 있는 일에 대비하여 정기적으로 체크포인트를 저장해 두는 것이 좋습니다. 체크포인트를 저장하려면 학습 스크립트에 다음 인수를 명시합니다.\n\n```bash\n  --checkpointing_steps=500\n```\n\n500 step마다 전체 학습 state가 `output_dir`의 하위 폴더에 저장됩니다. 하위 폴더 이름은 접두사 `checkpoint-` 뒤에 지금까지 학습된 step 수를 붙인 형태입니다. 예를 들어 `checkpoint-1500`은 1500 학습 step 후에 저장된 체크포인트입니다.\n\n학습을 재개하기 위해 체크포인트를 불러오려면 `--resume_from_checkpoint` 인수를 학습 스크립트에 명시하고 재개할 체크포인트를 지정하십시오. 예를 들어 다음 인수는 1500 학습 step 후에 저장된 체크포인트에서부터 학습을 재개합니다.\n\n```bash\n  --resume_from_checkpoint=\"checkpoint-1500\"\n```\n\n## 파인튜닝\n\n<frameworkcontent>\n<pt>\n다음과 같이 [Naruto BLIP 캡션](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) 데이터셋에서 파인튜닝 실행을 위해 [PyTorch 학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)를 실행합니다:\n\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport dataset_name=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch train_text_to_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$dataset_name \\\n  --use_ema \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --mixed_precision=\"fp16\" \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --output_dir=\"sd-naruto-model\"\n```\n\n자체 데이터셋으로 파인튜닝하려면 🤗 [Datasets](https://huggingface.co/docs/datasets/index)에서 요구하는 형식에 따라 데이터셋을 준비하세요. [데이터셋을 허브에 업로드](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub)하거나 [파일들이 있는 로컬 폴더를 준비](https://huggingface.co/docs/datasets/image_dataset#imagefolder)할 수 있습니다.\n\n사용자 커스텀 loading logic을 사용하려면 스크립트를 수정하십시오. 
도움이 되도록 코드의 적절한 위치에 포인터를 남겼습니다. 🤗 아래 예제 스크립트는 `TRAIN_DIR`의 로컬 데이터셋으로 파인튜닝하는 방법과 `OUTPUT_DIR`에 모델을 저장하는 방법을 보여줍니다:\n\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport TRAIN_DIR=\"path_to_your_dataset\"\nexport OUTPUT_DIR=\"path_to_save_model\"\n\naccelerate launch train_text_to_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$TRAIN_DIR \\\n  --use_ema \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --mixed_precision=\"fp16\" \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --output_dir=${OUTPUT_DIR}\n```\n\n</pt>\n<jax>\n[@duongna211](https://github.com/duongna21)의 기여로, Flax를 사용해 TPU 및 GPU에서 Stable Diffusion 모델을 더 빠르게 학습할 수 있습니다. 이는 TPU 하드웨어에서 매우 효율적이지만 GPU에서도 훌륭하게 작동합니다. Flax 학습 스크립트는 gradient checkpointing이나 gradient accumulation과 같은 기능을 아직 지원하지 않으므로 메모리가 30GB 이상인 GPU 또는 TPU v3가 필요합니다.\n\n스크립트를 실행하기 전에 요구 사항이 설치되어 있는지 확인하십시오:\n\n```bash\npip install -U -r requirements_flax.txt\n```\n\n그러면 다음과 같이 [Flax 학습 스크립트](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py)를 실행할 수 있습니다.\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport dataset_name=\"lambdalabs/naruto-blip-captions\"\n\npython train_text_to_image_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$dataset_name \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --output_dir=\"sd-naruto-model\"\n```\n\n자체 데이터셋으로 파인튜닝하려면 🤗 [Datasets](https://huggingface.co/docs/datasets/index)에서 요구하는 형식에 따라 데이터셋을 준비하세요. 
You can either [upload the dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).\n\nModify the script if you want to use custom loading logic. We left pointers in the appropriate places in the code to help you. 🤗 The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR`:\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport TRAIN_DIR=\"path_to_your_dataset\"\n\npython train_text_to_image_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$TRAIN_DIR \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --mixed_precision=\"fp16\" \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --output_dir=\"sd-naruto-model\"\n```\n</jax>\n</frameworkcontent>\n\n## LoRA\n\nTo fine-tune text-to-image models, you can also use LoRA (Low-Rank Adaptation of Large Language Models), a fine-tuning technique for accelerating the training of large models. See the [LoRA training](lora#text-to-image) guide for more details.\n\n## Inference\n\nYou can load the fine-tuned model for inference by passing the model path or the model name on the Hub to [`StableDiffusionPipeline`]:\n\n<frameworkcontent>\n<pt>\n```python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\nmodel_path = \"path_to_saved_model\"\npipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)\npipe.to(\"cuda\")\n\nimage = pipe(prompt=\"yoda\").images[0]\nimage.save(\"yoda-naruto.png\")\n```\n</pt>\n<jax>\n```python\nimport jax\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom flax.training.common_utils import shard\nfrom diffusers import FlaxStableDiffusionPipeline\n\nmodel_path = \"path_to_saved_model\"\npipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)\n\nprompt = \"yoda naruto\"\nprng_seed = jax.random.PRNGKey(0)\nnum_inference_steps = 50\n\nnum_samples = jax.device_count()\nprompt = num_samples * [prompt]\nprompt_ids = pipeline.prepare_inputs(prompt)\n\n# shard inputs and rng\nparams = 
replicate(params)\nprng_seed = jax.random.split(prng_seed, jax.device_count())\nprompt_ids = shard(prompt_ids)\n\nimages = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images\nimages = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))\n# `images` is a list of PIL images, so save the first one\nimages[0].save(\"yoda-naruto.png\")\n```\n</jax>\n</frameworkcontent>"
  },
  {
    "path": "docs/source/ko/training/text_inversion.md",
    "content": " <!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n\n\n# Textual-Inversion\n\n[[open-in-colab]]\n\n[textual-inversion](https://huggingface.co/papers/2208.01618)은 소수의 예시 이미지에서 새로운 콘셉트를 포착하는 기법입니다. 이 기술은 원래 [Latent Diffusion](https://github.com/CompVis/latent-diffusion)에서 시연되었지만, 이후 [Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/conceptual/stable_diffusion)과 같은 유사한 다른 모델에도 적용되었습니다. 학습된 콘셉트는 text-to-image 파이프라인에서 생성된 이미지를 더 잘 제어하는 데 사용할 수 있습니다. 이 모델은 텍스트 인코더의 임베딩 공간에서 새로운 '단어'를 학습하여 개인화된 이미지 생성을 위한 텍스트 프롬프트 내에서 사용됩니다.\n\n![Textual Inversion example](https://textual-inversion.github.io/static/images/editing/colorful_teapot.JPG)\n<small>By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation <a href=\"https://github.com/rinongal/textual_inversion\">(image source)</a>.</small>\n\n이 가이드에서는 textual-inversion으로 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 모델을 학습하는 방법을 설명합니다. 이 가이드에서 사용된 모든 textual-inversion 학습 스크립트는 [여기](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)에서 확인할 수 있습니다. 내부적으로 어떻게 작동하는지 자세히 살펴보고 싶으시다면 해당 링크를 참조해주시기 바랍니다.\n\n> [!TIP]\n> [Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library)에는 커뮤니티에서 제작한 학습된 textual-inversion 모델들이 있습니다. 
It will grow into a useful resource as more concepts are added over time!\n\nBefore you begin, install the library's training dependencies:\n\n```bash\npip install diffusers accelerate transformers\n```\n\nOnce the dependencies are installed, initialize a [🤗Accelerate](https://github.com/huggingface/accelerate/) environment.\n\n```bash\naccelerate config\n```\n\nTo set up a default 🤗Accelerate environment without choosing any configurations:\n\n```bash\naccelerate config default\n```\n\nOr, if your environment does not support an interactive shell, such as a notebook, you can use:\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\nFinally, install [xFormers](https://huggingface.co/docs/diffusers/main/en/training/optimization/xformers) to reduce memory usage with memory-efficient attention. Once xFormers is installed, add the `--enable_xformers_memory_efficient_attention` argument to the training script. xFormers is not supported for Flax.\n\n## Upload the model to the Hub\n\nTo store the model on the Hub, add the following argument to the training script.\n\n```bash\n--push_to_hub\n```\n\n## Save and load checkpoints\n\nIt is a good idea to save checkpoints of the model regularly during training. That way, if training is interrupted for any reason, you can resume from a saved checkpoint. Passing the following argument to the training script saves the full training state as a checkpoint in a subfolder of `output_dir` every 500 steps.\n\n```bash\n--checkpointing_steps=500\n```\n\nTo resume training from a saved checkpoint, pass the following argument to the training script, specifying the checkpoint to resume from.\n\n```bash\n--resume_from_checkpoint=\"checkpoint-1500\"\n```\n\n## Fine-tuning\n\nDownload the [cat toy dataset](https://huggingface.co/datasets/diffusers/cat_toy_example) as the training dataset and store it in a directory. To use your own dataset, see the [Create a dataset for training](https://huggingface.co/docs/diffusers/training/create_dataset) guide.\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./cat\"\nsnapshot_download(\n    \"diffusers/cat_toy_example\", local_dir=local_dir, repo_type=\"dataset\", ignore_patterns=\".gitattributes\"\n)\n```\n\nAssign the model's repository ID (or the path to a directory containing the model weights) to the `MODEL_NAME` environment variable and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. 
Then assign the path to the directory containing the images to the `DATA_DIR` environment variable.\n\nNow you can run the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py). The script creates the following files and saves them to your repository.\n\n- `learned_embeds.bin`\n- `token_identifier.txt`\n- `type_of_concept.txt`\n\n> [!TIP]\n> 💡 A full training run takes up to 1 hour on a single V100 GPU. While you wait for training to complete, feel free to check out [how Textual Inversion works](https://huggingface.co/docs/diffusers/training/text_inversion#how-it-works) in the section below if you are curious!\n\n<frameworkcontent>\n<pt>\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport DATA_DIR=\"./cat\"\n\naccelerate launch textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 --scale_lr \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=\"textual_inversion_cat\" \\\n  --push_to_hub\n```\n\n> [!TIP]\n> 💡 To improve training performance, you can also consider representing the placeholder token (`<cat-toy>`) with multiple embedding vectors rather than a single one. This trick can help the model better capture the style (the concept mentioned above) of more complex images. To enable training multiple embedding vectors, pass the following option.\n>\n> ```bash\n> --num_vectors=5\n> ```\n</pt>\n<jax>\n\nIf you have access to TPUs, try the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py) to train even faster (it also works on GPUs). With the same configuration, the Flax training script should be at least 70% faster than the PyTorch training script! 
⚡️\n\nBefore you begin, install the Flax-specific dependencies.\n\n```bash\npip install -U -r requirements_flax.txt\n```\n\nAssign the model's repository ID (or the path to a directory containing the model weights) to the `MODEL_NAME` environment variable and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument.\n\nThen you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py).\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport DATA_DIR=\"./cat\"\n\npython textual_inversion_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 --scale_lr \\\n  --output_dir=\"textual_inversion_cat\" \\\n  --push_to_hub\n```\n</jax>\n</frameworkcontent>\n\n### Intermediate logging\n\nIf you are interested in tracking the model's training progress, you can save the images generated during training. Add the following arguments to the training script to enable intermediate logging.\n\n- `validation_prompt`: the prompt used to generate samples (defaults to `None`, which disables intermediate logging)\n- `num_validation_images`: the number of sample images to generate\n- `validation_steps`: the number of steps before generating sample images from `validation_prompt`\n\n```bash\n--validation_prompt=\"A <cat-toy> backpack\"\n--num_validation_images=4\n--validation_steps=100\n```\n\n## Inference\n\nOnce the model is trained, you can use it for inference with [`StableDiffusionPipeline`].\n\nBy default, the Textual Inversion script only saves the embedding vectors obtained through Textual Inversion. These embedding vectors are added to the text encoder's embedding matrix.\n\n<frameworkcontent>\n<pt>\n> [!TIP]\n> 💡 The community has created a large library of Textual Inversion embedding vectors called [sd-concepts-library](https://huggingface.co/sd-concepts-library). 
Instead of training a Textual Inversion embedding from scratch, it is worth checking whether a suitable embedding has already been added to the library.\n\nTo load Textual Inversion embedding vectors, you first need to load the model that was used to train them. Here we assume the [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) model was used and load it.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n```\n\nNext, load the Textual Inversion embedding vectors with the `TextualInversionLoaderMixin.load_textual_inversion` function. Here we load the embeddings from the earlier `<cat-toy>` example.\n\n```python\npipe.load_textual_inversion(\"sd-concepts-library/cat-toy\")\n```\n\nNow you can run the pipeline to check that the placeholder token (`<cat-toy>`) works correctly.\n\n```python\nprompt = \"A <cat-toy> backpack\"\n\nimage = pipe(prompt, num_inference_steps=50).images[0]\nimage.save(\"cat-backpack.png\")\n```\n\n`TextualInversionLoaderMixin.load_textual_inversion` can load not only text embedding vectors saved in the Diffusers format, but also embedding vectors saved in the [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) format. To do so, first download the embedding vectors from [civitAI](https://civitai.com/models/3036?modelVersionId=8387) and then load them locally.\n\n```python\npipe.load_textual_inversion(\"./charturnerv2.pt\")\n```\n</pt>\n<jax>\n\nThere is currently no `load_textual_inversion` function for Flax, so you need to make sure the Textual Inversion embedding vectors are saved as part of the model after training. 
The model can then be run just like any other Flax model.\n\n```python\nimport jax\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom flax.training.common_utils import shard\nfrom diffusers import FlaxStableDiffusionPipeline\n\nmodel_path = \"path-to-your-trained-model\"\npipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)\n\nprompt = \"A <cat-toy> backpack\"\nprng_seed = jax.random.PRNGKey(0)\nnum_inference_steps = 50\n\nnum_samples = jax.device_count()\nprompt = num_samples * [prompt]\nprompt_ids = pipeline.prepare_inputs(prompt)\n\n# shard inputs and rng\nparams = replicate(params)\nprng_seed = jax.random.split(prng_seed, jax.device_count())\nprompt_ids = shard(prompt_ids)\n\nimages = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images\nimages = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))\n# `images` is a list of PIL images, so save the first one\nimages[0].save(\"cat-backpack.png\")\n```\n</jax>\n</frameworkcontent>\n\n## How it works\n\n![Diagram from the paper showing overview](https://textual-inversion.github.io/static/images/training/training.JPG)\n<small>Architecture overview from the Textual Inversion <a href=\"https://textual-inversion.github.io/\">blog post.</a></small>\n\nUsually, text prompts are tokenized into embeddings before being passed to a model. Textual Inversion does something similar, but it learns a new token embedding, `v*`, from the special token `S*` in the diagram above. The model output is used to condition the diffusion model, which helps the diffusion model quickly understand new concepts from just a few example images.\n\nTo do this, Textual Inversion uses a generator model and noisy versions of the training images. The generator tries to predict less noisy versions of the images, and the token embedding `v*` is optimized based on how well the generator does. If the token embedding successfully captures the new concept, it gives more useful information to the diffusion model and helps create clearer images with less noise. This optimization process typically involves several thousand exposures to a variety of prompts and images.\n\n"
  },
  {
    "path": "docs/source/ko/training/unconditional_training.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Unconditional 이미지 생성\n\nunconditional 이미지 생성은 text-to-image 또는 image-to-image 모델과 달리 텍스트나 이미지에 대한 조건이 없이 학습 데이터 분포와 유사한 이미지만을 생성합니다.\n\n<iframe\n\tsrc=\"https://stevhliu-ddpm-butterflies-128.hf.space\"\n\tframeborder=\"0\"\n\twidth=\"850\"\n\theight=\"550\"\n></iframe>\n\n\n이 가이드에서는 기존에 존재하던 데이터셋과 자신만의 커스텀 데이터셋에 대해 unconditional image generation 모델을 훈련하는 방법을 설명합니다. 훈련 세부 사항에 대해 더 자세히 알고 싶다면 unconditional image generation을 위한 모든 학습 스크립트를 [여기](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation)에서 확인할 수 있습니다.\n\n스크립트를 실행하기 전, 먼저 의존성 라이브러리들을 설치해야 합니다.\n\n```bash\npip install diffusers[training] accelerate datasets\n```\n\n그 다음 🤗 [Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화합니다.\n\n```bash\naccelerate config\n```\n\n별도의 설정 없이 기본 설정으로 🤗 [Accelerate](https://github.com/huggingface/accelerate/) 환경을 초기화해봅시다.\n\n```bash\naccelerate config default\n```\n\n노트북과 같은 대화형 쉘을 지원하지 않는 환경의 경우, 다음과 같이 사용해볼 수도 있습니다.\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n## 모델을 허브에 업로드하기\n\n학습 스크립트에 다음 인자를 추가하여 허브에 모델을 업로드할 수 있습니다.\n\n```bash\n--push_to_hub\n```\n\n## 체크포인트 저장하고 불러오기\n\n훈련 중 문제가 발생할 경우를 대비하여 체크포인트를 정기적으로 저장하는 것이 좋습니다. 
To save a checkpoint, pass the following argument to the training script:\n\n```bash\n--checkpointing_steps=500\n```\n\nThe full training state is saved in a subfolder of `output_dir` every 500 steps, and you can load a checkpoint and resume training by passing the `--resume_from_checkpoint` argument to the training script.\n\n```bash\n--resume_from_checkpoint=\"checkpoint-1500\"\n```\n\n## Fine-tuning\n\nYou are now ready to launch the training script! Specify the name of the dataset to fine-tune on with the `--dataset_name` argument; the model is saved to the path given by the `--output_dir` argument. To use your own dataset, see the [Create a dataset for training](create_dataset) guide.\n\nThe training script creates a `diffusion_pytorch_model.bin` file and saves it to your repository.\n\n> [!TIP]\n> 💡 A full training run takes 2 hours on 4 V100 GPUs.\n\nFor example, to fine-tune on the [Oxford Flowers](https://huggingface.co/datasets/huggan/flowers-102-categories) dataset:\n\n```bash\naccelerate launch train_unconditional.py \\\n  --dataset_name=\"huggan/flowers-102-categories\" \\\n  --resolution=64 \\\n  --output_dir=\"ddpm-ema-flowers-64\" \\\n  --train_batch_size=16 \\\n  --num_epochs=100 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-4 \\\n  --lr_warmup_steps=500 \\\n  --mixed_precision=no \\\n  --push_to_hub\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://user-images.githubusercontent.com/26864830/180248660-a0b143d0-b89a-42c5-8656-2ebf6ece7e52.png\"/>\n</div>\n\nOr, to use the [Naruto](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset:\n\n```bash\naccelerate launch train_unconditional.py \\\n  --dataset_name=\"lambdalabs/naruto-blip-captions\" \\\n  --resolution=64 \\\n  --output_dir=\"ddpm-ema-naruto-64\" \\\n  --train_batch_size=16 \\\n  --num_epochs=100 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-4 \\\n  --lr_warmup_steps=500 \\\n  --mixed_precision=no \\\n  --push_to_hub\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png\"/>\n</div>\n\n### Training with multiple GPUs\n\n`accelerate` allows for seamless multi-GPU training. 
Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch) to run distributed training with `accelerate`. Here is an example command:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu train_unconditional.py \\\n  --dataset_name=\"lambdalabs/naruto-blip-captions\" \\\n  --resolution=64 --center_crop --random_flip \\\n  --output_dir=\"ddpm-ema-naruto-64\" \\\n  --train_batch_size=16 \\\n  --num_epochs=100 \\\n  --gradient_accumulation_steps=1 \\\n  --use_ema \\\n  --learning_rate=1e-4 \\\n  --lr_warmup_steps=500 \\\n  --mixed_precision=\"fp16\" \\\n  --logger=\"wandb\" \\\n  --push_to_hub\n```\n"
  },
  {
    "path": "docs/source/ko/tutorials/basic_training.md",
    "content": "﻿<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n\n# Diffusion 모델을 학습하기\n\nUnconditional 이미지 생성은 학습에 사용된 데이터셋과 유사한 이미지를 생성하는 diffusion 모델에서 인기 있는 어플리케이션입니다. 일반적으로, 가장 좋은 결과는 특정 데이터셋에 사전 훈련된 모델을 파인튜닝하는 것으로 얻을 수 있습니다. 이 [허브](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model)에서 이러한 많은 체크포인트를 찾을 수 있지만, 만약 마음에 드는 체크포인트를 찾지 못했다면, 언제든지 스스로 학습할 수 있습니다!\n\n이 튜토리얼은 나만의 🦋 나비 🦋를 생성하기 위해 [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) 데이터셋의 하위 집합에서 [`UNet2DModel`] 모델을 학습하는 방법을 가르쳐줄 것입니다.\n\n> [!TIP]\n> 💡 이 학습 튜토리얼은 [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) 노트북 기반으로 합니다. Diffusion 모델의 작동 방식 및 자세한 내용은 노트북을 확인하세요!\n\n시작 전에, 🤗 Datasets을 불러오고 전처리하기 위해 데이터셋이 설치되어 있는지 다수 GPU에서 학습을 간소화하기 위해 🤗 Accelerate 가 설치되어 있는지 확인하세요. 그 후 학습 메트릭을 시각화하기 위해 [TensorBoard](https://www.tensorflow.org/tensorboard)를 또한 설치하세요. (또한 학습 추적을 위해 [Weights & Biases](https://docs.wandb.ai/)를 사용할 수 있습니다.)\n\n```bash\n!pip install diffusers[training]\n```\n\n커뮤니티에 모델을 공유할 것을 권장하며, 이를 위해서 Hugging Face 계정에 로그인을 해야 합니다. (계정이 없다면 [여기](https://hf.co/join)에서 만들 수 있습니다.) 
You can log in from a notebook and enter your token when prompted.\n\n```py\n>>> from huggingface_hub import notebook_login\n\n>>> notebook_login()\n```\n\nOr log in from the terminal:\n\n```bash\nhf auth login\n```\n\nSince the model checkpoints are quite large, install [Git-LFS](https://git-lfs.com/) to version these large files.\n\n```bash\n!sudo apt -qq install git-lfs\n!git config --global credential.helper store\n```\n\n\n## Training configuration\n\nFor convenience, create a `TrainingConfig` class containing the training hyperparameters (feel free to adjust them):\n\n```py\n>>> from dataclasses import dataclass\n\n\n>>> @dataclass\n... class TrainingConfig:\n...     image_size = 128  # the generated image resolution\n...     train_batch_size = 16\n...     eval_batch_size = 16  # how many images to sample during evaluation\n...     num_epochs = 50\n...     gradient_accumulation_steps = 1\n...     learning_rate = 1e-4\n...     lr_warmup_steps = 500\n...     save_image_epochs = 10\n...     save_model_epochs = 30\n...     mixed_precision = \"fp16\"  # `no` for float32, `fp16` for automatic mixed precision\n...     output_dir = \"ddpm-butterflies-128\"  # the model name locally and on the HF Hub\n\n...     push_to_hub = True  # whether to upload the saved model to the HF Hub\n...     hub_model_id = None  # Hub repository name; when None, the training loop falls back to the `output_dir` name\n...     hub_private_repo = None\n...     overwrite_output_dir = True  # overwrite the old model when re-running the notebook\n...     seed = 0\n\n\n>>> config = TrainingConfig()\n```\n\n\n## Load the dataset\n\nYou can easily load the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset with the 🤗 Datasets library.\n\n```py\n>>> from datasets import load_dataset\n\n>>> config.dataset_name = \"huggan/smithsonian_butterflies_subset\"\n>>> dataset = load_dataset(config.dataset_name, split=\"train\")\n```\n\n💡 You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan) or use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). 
Set `config.dataset_name` to the repository id of the dataset if it is from the HugGan Community Event, or to `imagefolder` if you are using your own images.\n\n🤗 Datasets uses the [`~datasets.Image`] feature to automatically decode the image data and load it as a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html), which you can visualize:\n\n```py\n>>> import matplotlib.pyplot as plt\n\n>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))\n>>> for i, image in enumerate(dataset[:4][\"image\"]):\n...     axs[i].imshow(image)\n...     axs[i].set_axis_off()\n>>> fig.show()\n```\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_ds.png)\n\nThe images are all different sizes, so you need to preprocess them first:\n\n-   `Resize` changes the image size to the one defined in `config.image_size`.\n-   `RandomHorizontalFlip` augments the dataset by randomly mirroring the images.\n-   `Normalize` is important to rescale the pixel values into the [-1, 1] range that the model expects.\n\n```py\n>>> from torchvision import transforms\n\n>>> preprocess = transforms.Compose(\n...     [\n...         transforms.Resize((config.image_size, config.image_size)),\n...         transforms.RandomHorizontalFlip(),\n...         transforms.ToTensor(),\n...         transforms.Normalize([0.5], [0.5]),\n...     ]\n... )\n```\n\nUse the 🤗 Datasets [`~datasets.Dataset.set_transform`] method to apply the `preprocess` function on the fly during training.\n\n```py\n>>> def transform(examples):\n...     images = [preprocess(image.convert(\"RGB\")) for image in examples[\"image\"]]\n...     return {\"images\": images}\n\n\n>>> dataset.set_transform(transform)\n```\n\nFeel free to visualize the images again to confirm that they have been resized. Now you are ready to wrap the dataset in a [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader) for training!\n\n```py\n>>> import torch\n\n>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n```\n\n\n## Create a UNet2DModel\n\nPretrained models in 🧨 Diffusers are easy to create from their model class with the parameters you want. For example, to create a [`UNet2DModel`]:\n\n```py\n>>> from diffusers import UNet2DModel\n\n>>> model = UNet2DModel(\n...     
sample_size=config.image_size,  # the target image resolution\n...     in_channels=3,  # the number of input channels, 3 for RGB images\n...     out_channels=3,  # the number of output channels\n...     layers_per_block=2,  # how many ResNet layers to use per UNet block\n...     block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block\n...     down_block_types=(\n...         \"DownBlock2D\",  # a regular ResNet downsampling block\n...         \"DownBlock2D\",\n...         \"DownBlock2D\",\n...         \"DownBlock2D\",\n...         \"AttnDownBlock2D\",  # a ResNet downsampling block with spatial self-attention\n...         \"DownBlock2D\",\n...     ),\n...     up_block_types=(\n...         \"UpBlock2D\",  # a regular ResNet upsampling block\n...         \"AttnUpBlock2D\",  # a ResNet upsampling block with spatial self-attention\n...         \"UpBlock2D\",\n...         \"UpBlock2D\",\n...         \"UpBlock2D\",\n...         \"UpBlock2D\",\n...     ),\n... )\n```\n\nIt is a good idea to quickly check that the sample image shape matches the model output shape:\n\n```py\n>>> sample_image = dataset[0][\"images\"].unsqueeze(0)\n>>> print(\"Input shape:\", sample_image.shape)\nInput shape: torch.Size([1, 3, 128, 128])\n\n>>> print(\"Output shape:\", model(sample_image, timestep=0).sample.shape)\nOutput shape: torch.Size([1, 3, 128, 128])\n```\n\nGreat! Next, you need a scheduler to add some noise to the image.\n\n\n## Create a scheduler\n\nThe scheduler behaves differently depending on whether you are using the model for training or inference. During inference, the scheduler generates an image from the noise. 
During training, the scheduler takes a model output, or a sample, from a specific point in the diffusion process and applies noise to the image according to a *noise schedule* and an *update rule*.\n\nTake a look at the `DDPMScheduler` and use its `add_noise` method to add some random noise to the `sample_image` from before:\n\n```py\n>>> import torch\n>>> from PIL import Image\n>>> from diffusers import DDPMScheduler\n\n>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)\n>>> noise = torch.randn(sample_image.shape)\n>>> timesteps = torch.LongTensor([50])\n>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)\n\n>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])\n```\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/noisy_butterfly.png)\n\nThe training objective of the model is to predict the noise added to the image. The loss at this step can be calculated as follows:\n\n```py\n>>> import torch.nn.functional as F\n\n>>> noise_pred = model(noisy_image, timesteps).sample\n>>> loss = F.mse_loss(noise_pred, noise)\n```\n\n## Train the model\n\nBy now, you have most of the pieces needed to start training the model, and all that is left is putting everything together.\n\nFirst, you need an optimizer and a learning rate scheduler:\n\n```py\n>>> from diffusers.optimization import get_cosine_schedule_with_warmup\n\n>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)\n>>> lr_scheduler = get_cosine_schedule_with_warmup(\n...     optimizer=optimizer,\n...     num_warmup_steps=config.lr_warmup_steps,\n...     num_training_steps=(len(train_dataloader) * config.num_epochs),\n... )\n```\n\nThen, you need a way to evaluate the model. For evaluation, you can use the `DDPMPipeline` to generate a batch of sample images and save them as a grid:\n\n```py\n>>> from diffusers import DDPMPipeline\n>>> import math\n>>> import os\n\n\n>>> def make_grid(images, rows, cols):\n...     w, h = images[0].size\n...     grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n...     for i, image in enumerate(images):\n...         grid.paste(image, box=(i % cols * w, i // cols * h))\n...     return grid\n\n\n>>> def evaluate(config, epoch, pipeline):\n...     
# Sample some images from random noise (this is the reverse diffusion process).\n...     # The default pipeline output type is `List[PIL.Image]`.\n...     images = pipeline(\n...         batch_size=config.eval_batch_size,\n...         generator=torch.manual_seed(config.seed),\n...     ).images\n\n...     # Make a grid out of the images.\n...     image_grid = make_grid(images, rows=4, cols=4)\n\n...     # Save the images.\n...     test_dir = os.path.join(config.output_dir, \"samples\")\n...     os.makedirs(test_dir, exist_ok=True)\n...     image_grid.save(f\"{test_dir}/{epoch:04d}.png\")\n```\n\nNow you can wrap all these components together in a training loop with 🤗 Accelerate for easy TensorBoard logging, gradient accumulation, and mixed precision training. To upload the model to the Hub, write a function that gets your repository name and information and then pushes the model to the Hub.\n\n💡 The training loop below may look intimidating and long, but it will be worth it later when you launch training in just one line of code! If you can't wait and want to start generating images, feel free to copy and run the code below. 🤗\n\n```py\n>>> from accelerate import Accelerator\n>>> from huggingface_hub import create_repo, upload_folder\n>>> from tqdm.auto import tqdm\n>>> from pathlib import Path\n>>> import os\n\n\n>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):\n...     # Initialize accelerator and tensorboard logging\n...     accelerator = Accelerator(\n...         mixed_precision=config.mixed_precision,\n...         gradient_accumulation_steps=config.gradient_accumulation_steps,\n...         log_with=\"tensorboard\",\n...         project_dir=os.path.join(config.output_dir, \"logs\"),\n...     )\n...     if accelerator.is_main_process:\n...         if config.output_dir is not None:\n...             os.makedirs(config.output_dir, exist_ok=True)\n...         if config.push_to_hub:\n...             repo_id = create_repo(\n...                 repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n...             ).repo_id\n...         accelerator.init_trackers(\"train_example\")\n\n...     # Prepare everything.\n...     # There is no specific order to remember; just unpack the objects in the same order you passed them to the prepare method.\n...     
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n...         model, optimizer, train_dataloader, lr_scheduler\n...     )\n\n...     global_step = 0\n\n...     # Now train the model.\n...     for epoch in range(config.num_epochs):\n...         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)\n...         progress_bar.set_description(f\"Epoch {epoch}\")\n\n...         for step, batch in enumerate(train_dataloader):\n...             clean_images = batch[\"images\"]\n...             # Sample the noise to add to the images.\n...             noise = torch.randn(clean_images.shape, device=clean_images.device)\n...             bs = clean_images.shape[0]\n\n...             # Sample a random timestep for each image.\n...             timesteps = torch.randint(\n...                 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,\n...                 dtype=torch.int64\n...             )\n\n...             # Add noise to the clean images according to the noise magnitude at each timestep.\n...             # (This is the forward diffusion process.)\n...             noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n\n...             with accelerator.accumulate(model):\n...                 # Predict the noise residual.\n...                 noise_pred = model(noisy_images, timesteps, return_dict=False)[0]\n...                 loss = F.mse_loss(noise_pred, noise)\n...                 accelerator.backward(loss)\n\n...                 accelerator.clip_grad_norm_(model.parameters(), 1.0)\n...                 optimizer.step()\n...                 lr_scheduler.step()\n...                 optimizer.zero_grad()\n\n...             progress_bar.update(1)\n...             logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n...             progress_bar.set_postfix(**logs)\n...             accelerator.log(logs, step=global_step)\n...             global_step += 1\n\n...         
# After each epoch, optionally sample some demo images with evaluate() and save the model.\n...         if accelerator.is_main_process:\n...             pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)\n\n...             if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:\n...                 evaluate(config, epoch, pipeline)\n\n...             if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:\n...                 if config.push_to_hub:\n...                     upload_folder(\n...                         repo_id=repo_id,\n...                         folder_path=config.output_dir,\n...                         commit_message=f\"Epoch {epoch}\",\n...                         ignore_patterns=[\"step_*\", \"epoch_*\"],\n...                     )\n...                 else:\n...                     pipeline.save_pretrained(config.output_dir)\n```\n\nPhew, that was quite a bit of code! But you are finally ready to launch training with 🤗 Accelerate's [`~accelerate.notebook_launcher`] function. Pass the function the training loop, all the training arguments, and the number of processes to use for training (you can change this value to the number of GPUs available to you):\n\n```py\n>>> from accelerate import notebook_launcher\n\n>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n\n>>> notebook_launcher(train_loop, args, num_processes=1)\n```\n\nOnce training is complete, take a look at the final 🦋 images 🦋 generated by your diffusion model!\n\n```py\n>>> import glob\n\n>>> sample_images = sorted(glob.glob(f\"{config.output_dir}/samples/*.png\"))\n>>> Image.open(sample_images[-1])\n```\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_final.png)\n\n## Next steps\n\nUnconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques on the [🧨 Diffusers Training Examples](../training/overview) page. 
다음은 학습할 수 있는 몇 가지 예시입니다:\n\n-   [Textual Inversion](../training/text_inversion), 특정 시각적 개념을 학습시켜 생성된 이미지에 통합시키는 알고리즘입니다.\n-   [DreamBooth](../training/dreambooth), 주제에 대한 몇 가지 입력 이미지들이 주어지면 주제에 대한 개인화된 이미지를 생성하기 위한 기술입니다.\n-   [Guide](../training/text2image) 데이터셋에 Stable Diffusion 모델을 파인튜닝하는 방법입니다.\n-   [Guide](../training/lora)  LoRA를 사용해 매우 큰 모델을 빠르게 파인튜닝하기 위한 메모리 효율적인 기술입니다.\n"
  },
  {
    "path": "docs/source/ko/tutorials/tutorial_overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Overview\n\n🧨 Diffusers에 오신 걸 환영합니다! 여러분이 diffusion 모델과 생성 AI를 처음 접하고, 더 많은 걸 배우고 싶으셨다면 제대로 찾아오셨습니다. 이 튜토리얼은 diffusion model을 여러분에게 젠틀하게 소개하고, 라이브러리의 기본 사항(핵심 구성요소와 🧨 Diffusers 사용법)을 이해하는 데 도움이 되도록 설계되었습니다.\n\n여러분은 이 튜토리얼을 통해 빠르게 생성하기 위해선 추론 파이프라인을 어떻게 사용해야 하는지, 그리고 라이브러리를 modular toolbox처럼 이용해서 여러분만의 diffusion system을 구축할 수 있도록 파이프라인을 분해하는 법을 배울 수 있습니다. 다음 단원에서는 여러분이 원하는 것을 생성하기 위해 자신만의 diffusion model을 학습하는 방법을 배우게 됩니다.\n\n튜토리얼을 완료한다면 여러분은 라이브러리를 직접 탐색하고, 자신의 프로젝트와 애플리케이션에 적용할 스킬들을 습득할 수 있을 겁니다.\n\n[Discord](https://discord.com/invite/JfAtkvEtRb)나 [포럼](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) 커뮤니티에 자유롭게 참여해서 다른 사용자와 개발자들과 교류하고 협업해 보세요!\n\n자 지금부터 diffusing을 시작해 보겠습니다! 🧨"
  },
  {
    "path": "docs/source/ko/using-diffusers/conditional_image_generation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 조건부 이미지 생성\n\n[[open-in-colab]]\n\n조건부 이미지 생성을 사용하면 텍스트 프롬프트에서 이미지를 생성할 수 있습니다. 텍스트는 임베딩으로 변환되며, 임베딩은 노이즈에서 이미지를 생성하도록 모델을 조건화하는 데 사용됩니다.\n\n[`DiffusionPipeline`]은 추론을 위해 사전 훈련된 diffusion 시스템을 사용하는 가장 쉬운 방법입니다.\n\n먼저 [`DiffusionPipeline`]의 인스턴스를 생성하고 다운로드할 파이프라인 [체크포인트](https://huggingface.co/models?library=diffusers&sort=downloads)를 지정합니다.\n\n이 가이드에서는 [잠재 Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)과 함께 텍스트-이미지 생성에 [`DiffusionPipeline`]을 사용합니다:\n\n```python\n>>> from diffusers import DiffusionPipeline\n\n>>> generator = DiffusionPipeline.from_pretrained(\"CompVis/ldm-text2im-large-256\")\n```\n\n[`DiffusionPipeline`]은 모든 모델링, 토큰화, 스케줄링 구성 요소를 다운로드하고 캐시합니다.\n이 모델은 약 14억 개의 파라미터로 구성되어 있기 때문에 GPU에서 실행할 것을 강력히 권장합니다.\nPyTorch에서와 마찬가지로 생성기 객체를 GPU로 이동할 수 있습니다:\n\n```python\n>>> generator.to(\"cuda\")\n```\n\n이제 텍스트 프롬프트에서 `생성기`를 사용할 수 있습니다:\n\n```python\n>>> image = generator(\"An image of a squirrel in Picasso style\").images[0]\n```\n\n출력값은 기본적으로 [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) 객체로 래핑됩니다.\n\n호출하여 이미지를 저장할 수 있습니다:\n\n```python\n>>> image.save(\"image_of_squirrel_painting.png\")\n```\n\n아래 스페이스를 사용해보고 안내 배율 매개변수를 자유롭게 조정하여 이미지 품질에 어떤 영향을 미치는지 확인해 
보세요!\n\n<iframe\n\tsrc=\"https://stabilityai-stable-diffusion.hf.space\"\n\tframeborder=\"0\"\n\twidth=\"850\"\n\theight=\"500\"\n></iframe>"
  },
  {
    "path": "docs/source/ko/using-diffusers/controlling_generation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 제어된 생성\n\nDiffusion 모델에 의해 생성된 출력을 제어하는 것은 커뮤니티에서 오랫동안 추구해 왔으며 현재 활발한 연구 주제입니다. 널리 사용되는 많은 diffusion 모델에서는 이미지와 텍스트 프롬프트 등 입력의 미묘한 변화로 인해 출력이 크게 달라질 수 있습니다. 이상적인 세계에서는 의미가 유지되고 변경되는 방식을 제어할 수 있기를 원합니다.\n\n의미 보존의 대부분의 예는 입력의 변화를 출력의 변화에 정확하게 매핑하는 것으로 축소됩니다. 즉, 프롬프트에서 피사체에 형용사를 추가하면 전체 이미지가 보존되고 변경된 피사체만 수정됩니다. 또는 특정 피사체의 이미지를 변형하면 피사체의 포즈가 유지됩니다.\n\n추가적으로 생성된 이미지의 품질에는 의미 보존 외에도 영향을 미치고자 하는 품질이 있습니다. 즉, 일반적으로 결과물의 품질이 좋거나 특정 스타일을 고수하거나 사실적이기를 원합니다.\n\ndiffusion 모델 생성을 제어하기 위해 `diffusers`가 지원하는 몇 가지 기술을 문서화합니다. 많은 부분이 최첨단 연구이며 미묘한 차이가 있을 수 있습니다. 명확한 설명이 필요하거나 제안 사항이 있으면 주저하지 마시고 [포럼](https://discuss.huggingface.co/) 또는 [GitHub 이슈](https://github.com/huggingface/diffusers/issues)에서 토론을 시작하세요.\n\n생성 제어 방법에 대한 개략적인 설명과 기술 개요를 제공합니다. 기술에 대한 자세한 설명은 파이프라인에서 링크된 원본 논문을 참조하는 것이 가장 좋습니다.\n\n사용 사례에 따라 적절한 기술을 선택해야 합니다. 많은 경우 이러한 기법을 결합할 수 있습니다. 예를 들어, 텍스트 반전과 SEGA를 결합하여 텍스트 반전을 사용하여 생성된 출력에 더 많은 의미적 지침을 제공할 수 있습니다.\n\n별도의 언급이 없는 한, 이러한 기법은 기존 모델과 함께 작동하며 자체 가중치가 필요하지 않은 기법입니다.\n\n1. [Instruct Pix2Pix](#instruct-pix2pix)\n2. [Pix2Pix Zero](#pix2pixzero)\n3. [Attend and Excite](#attend-and-excite)\n4. [Semantic Guidance](#semantic-guidance)\n5. [Self-attention Guidance](#self-attention-guidance)\n6. [Depth2Image](#depth2image)\n7. [MultiDiffusion Panorama](#multidiffusion-panorama)\n8. [DreamBooth](#dreambooth)\n9. [Textual Inversion](#textual-inversion)\n10. 
[ControlNet](#controlnet)\n11. [Prompt Weighting](#prompt-weighting)\n12. [Custom Diffusion](#custom-diffusion)\n13. [Model Editing](#model-editing)\n14. [DiffEdit](#diffedit)\n15. [T2I-Adapter](#t2i-adapter)\n\n편의를 위해, 추론만 하거나 파인튜닝/학습하는 방법에 대한 표를 제공합니다.\n\n|                     **Method**                      | **Inference only** | **Requires training /<br> fine-tuning** |                                          **Comments**                                           |\n| :-------------------------------------------------: | :----------------: | :-------------------------------------: | :---------------------------------------------------------------------------------------------: |\n|        [Instruct Pix2Pix](#instruct-pix2pix)        |         ✅         |                   ❌                    | Can additionally be<br>fine-tuned for better <br>performance on specific <br>edit instructions. |\n|            [Pix2Pix Zero](#pix2pixzero)             |         ✅         |                   ❌                    |                                                                                                 |\n|       [Attend and Excite](#attend-and-excite)       |         ✅         |                   ❌                    |                                                                                                 |\n|       [Semantic Guidance](#semantic-guidance)       |         ✅         |                   ❌                    |                                                                                                 |\n| [Self-attention Guidance](#self-attention-guidance) |         ✅         |                   ❌                    |                                                                                                 |\n|             [Depth2Image](#depth2image)             |         ✅         |                   ❌                    |                                                                                                 |\n| 
[MultiDiffusion Panorama](#multidiffusion-panorama) |         ✅         |                   ❌                    |                                                                                                 |\n|              [DreamBooth](#dreambooth)              |         ❌         |                   ✅                    |                                                                                                 |\n|       [Textual Inversion](#textual-inversion)       |         ❌         |                   ✅                    |                                                                                                 |\n|              [ControlNet](#controlnet)              |         ✅         |                   ❌                    |             A ControlNet can be <br>trained/fine-tuned on<br>a custom conditioning.             |\n|        [Prompt Weighting](#prompt-weighting)        |         ✅         |                   ❌                    |                                                                                                 |\n|        [Custom Diffusion](#custom-diffusion)        |         ❌         |                   ✅                    |                                                                                                 |\n|           [Model Editing](#model-editing)           |         ✅         |                   ❌                    |                                                                                                 |\n|                [DiffEdit](#diffedit)                |         ✅         |                   ❌                    |                                                                                                 |\n|             [T2I-Adapter](#t2i-adapter)             |         ✅         |                   ❌                    |                                                                                                 |\n\n## Pix2Pix 
Instruct\n\n[Paper](https://huggingface.co/papers/2211.09800)\n\n[Instruct Pix2Pix](../api/pipelines/stable_diffusion/pix2pix)는 입력 이미지 편집을 지원하도록 Stable Diffusion에서 파인튜닝되었습니다. 이미지와 편집을 설명하는 프롬프트를 입력으로 받아 편집된 이미지를 출력합니다.\nInstruct Pix2Pix는 [InstructGPT](https://openai.com/blog/instruction-following/) 스타일의 프롬프트와 잘 작동하도록 명시적으로 훈련되었습니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/pix2pix)를 참조하세요.\n\n## Pix2Pix Zero\n\n[Paper](https://huggingface.co/papers/2302.03027)\n\n[Pix2Pix Zero](../api/pipelines/stable_diffusion/pix2pix_zero)를 사용하면 일반적인 이미지 의미를 유지하면서 한 개념이나 피사체가 다른 개념이나 피사체로 변환되도록 이미지를 수정할 수 있습니다.\n\n노이즈 제거 프로세스는 한 개념적 임베딩에서 다른 개념적 임베딩으로 안내됩니다. 중간 latents(intermediate latents)는 노이즈 제거(denoising) 프로세스 중에 참조 어텐션 맵(reference attention maps)에 가까워지도록 최적화됩니다. 참조 어텐션 맵은 입력 이미지의 노이즈 제거 프로세스에서 얻은 것으로, 의미 보존을 유도하는 데 사용됩니다.\n\nPix2Pix Zero는 합성 이미지와 실제 이미지를 편집하는 데 모두 사용할 수 있습니다.\n\n- 합성 이미지를 편집하려면 먼저 캡션이 지정된 이미지를 생성합니다.\n  다음으로 편집할 컨셉과 새로운 타겟 컨셉에 대한 이미지 캡션을 생성합니다. 이를 위해 [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)와 같은 모델을 사용할 수 있습니다. 그런 다음 텍스트 인코더를 통해 소스 개념과 대상 개념 모두에 대한 \"평균\" 프롬프트 임베딩을 생성합니다. 마지막으로, 합성 이미지를 편집하기 위해 pix2pix-zero 알고리즘을 사용합니다.\n- 실제 이미지를 편집하려면 먼저 [BLIP](https://huggingface.co/docs/transformers/model_doc/blip)과 같은 모델을 사용하여 이미지 캡션을 생성합니다. 그런 다음 프롬프트와 이미지에 DDIM 반전(inversion)을 적용하여 \"역(inverse)\" latents를 생성합니다. 
이전과 마찬가지로 소스 및 대상 개념 모두에 대한 \"평균(mean)\" 프롬프트 임베딩이 생성되고 마지막으로 \"역(inverse)\" latents와 결합된 pix2pix-zero 알고리즘이 이미지를 편집하는 데 사용됩니다.\n\n> [!TIP]\n> Pix2Pix Zero는 '제로 샷(zero-shot)' 이미지 편집이 가능한 최초의 모델입니다.\n> 즉, 이 모델은 [여기](../api/pipelines/stable_diffusion/pix2pix_zero#usage-example)에서 볼 수 있듯이 일반 소비자용 GPU에서 1분 이내에 이미지를 편집할 수 있습니다.\n\n위에서 언급했듯이 Pix2Pix Zero에는 특정 개념으로 생성을 유도하기 위해 (UNet, VAE 또는 텍스트 인코더가 아닌) latents를 최적화하는 과정이 포함되어 있습니다. 즉, 전체 파이프라인에 표준 [StableDiffusionPipeline](../api/pipelines/stable_diffusion/text2img)보다 더 많은 메모리가 필요할 수 있습니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/pix2pix_zero)를 참조하세요.\n\n## Attend and Excite\n\n[Paper](https://huggingface.co/papers/2301.13826)\n\n[Attend and Excite](../api/pipelines/stable_diffusion/attend_and_excite)를 사용하면 프롬프트의 피사체가 최종 이미지에 충실하게 표현되도록 할 수 있습니다.\n\n이미지에 존재해야 하는 프롬프트의 피사체에 해당하는 일련의 토큰 인덱스가 입력으로 제공됩니다. 노이즈 제거 중에 각 토큰 인덱스는 이미지의 최소 한 패치 이상에서 최소 어텐션 임계값을 갖도록 보장됩니다. 모든 피사체 토큰이 어텐션 임계값을 넘을 때까지, 노이즈 제거 프로세스 중에 중간 latents가 반복적으로 최적화되어 가장 소홀히 다뤄진 피사체 토큰의 어텐션을 강화합니다.\n\nPix2Pix Zero와 마찬가지로 Attend and Excite 역시 파이프라인에 미니 최적화 루프(사전 학습된 가중치를 그대로 둔 채)가 포함되며, 일반적인 `StableDiffusionPipeline`보다 더 많은 메모리가 필요할 수 있습니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/attend_and_excite)를 참조하세요.\n\n## Semantic Guidance (SEGA)\n\n[Paper](https://huggingface.co/papers/2301.12247)\n\nSemantic Guidance(SEGA)를 사용하면 이미지에서 하나 이상의 컨셉을 적용하거나 제거할 수 있습니다. 컨셉의 강도도 조절할 수 있습니다. 즉, 미소(smile) 컨셉을 사용하여 인물 사진의 미소를 점진적으로 늘리거나 줄일 수 있습니다.\n\nclassifier-free guidance가 빈 프롬프트 입력을 통해 가이던스를 제공하는 방식과 유사하게, SEGA는 개념 프롬프트에 대한 가이던스를 제공합니다. 이러한 개념 프롬프트는 여러 개를 동시에 적용할 수 있습니다. 
각 개념 프롬프트는 가이던스가 양의 방향으로 적용되는지 음의 방향으로 적용되는지에 따라 해당 개념을 추가하거나 제거할 수 있습니다.\n\nPix2Pix Zero 또는 Attend and Excite와 달리 SEGA는 명시적인 그래디언트 기반 최적화를 수행하는 대신 diffusion 프로세스와 직접 상호 작용합니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/semantic_stable_diffusion)를 참조하세요.\n\n## Self-attention Guidance (SAG)\n\n[Paper](https://huggingface.co/papers/2210.00939)\n\n[Self-attention Guidance](../api/pipelines/stable_diffusion/self_attention_guidance)는 이미지의 전반적인 품질을 개선합니다.\n\nSAG는 고주파(high-frequency) 디테일에 조건화되지 않은 예측으로부터 완전히 조건화된 이미지 방향으로 가이던스를 제공합니다. 고주파 디테일은 UNet의 self-attention 맵에서 추출됩니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/self_attention_guidance)를 참조하세요.\n\n## Depth2Image\n\n[Project](https://huggingface.co/stabilityai/stable-diffusion-2-depth)\n\n[Depth2Image](../api/pipelines/stable_diffusion_2#depthtoimage)는 텍스트 기반 이미지 변형에서 의미를 더 잘 보존하도록 Stable Diffusion에서 파인튜닝되었습니다.\n\n원본 이미지의 단안(monocular) 깊이 추정치를 조건으로 합니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion_2#depthtoimage)를 참조하세요.\n\n> [!TIP]\n> InstructPix2Pix와 Pix2Pix Zero와 같은 방법의 중요한 차이점은 전자는\n> 사전 학습된 가중치를 파인튜닝하는 반면, 후자는 그렇지 않다는 것입니다.\n> 즉, 사용 가능한 모든 Stable Diffusion 모델에 Pix2Pix Zero를 적용할 수 있습니다.\n\n## MultiDiffusion Panorama\n\n[Paper](https://huggingface.co/papers/2302.08113)\n\nMultiDiffusion은 사전 학습된 diffusion model을 통해 새로운 생성 프로세스를 정의합니다. 이 프로세스는 고품질의 다양한 이미지를 생성하는 데 쉽게 적용할 수 있는 여러 diffusion 생성 방법을 하나로 묶습니다. 결과는 원하는 종횡비(예: 파노라마) 및 타이트한 분할 마스크에서 바운딩 박스에 이르는 공간 가이던스 신호와 같은 사용자가 제공한 제어를 준수합니다.\n[MultiDiffusion 파노라마](../api/pipelines/stable_diffusion/panorama)를 사용하면 임의의 종횡비(예: 파노라마)로 고품질 이미지를 생성할 수 있습니다.\n\n파노라마 이미지를 생성하는 데 사용하는 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/panorama)를 참조하세요.\n\n## 나만의 모델 파인튜닝\n\n사전 학습된 모델 외에도 Diffusers에는 사용자가 제공한 데이터에 대해 모델을 파인튜닝할 수 있는 학습 스크립트가 있습니다.\n\n## DreamBooth\n\n[DreamBooth](../training/dreambooth)는 모델을 파인튜닝하여 새로운 주제(subject)를 학습시킵니다. 
즉, 한 사람의 사진 몇 장을 사용하여 다양한 스타일로 그 사람의 이미지를 생성할 수 있습니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../training/dreambooth)를 참조하세요.\n\n## Textual Inversion\n\n[Textual Inversion](../training/text_inversion)은 모델을 파인튜닝하여 새로운 개념에 대해 학습시킵니다. 즉, 특정 스타일의 아트웍 사진 몇 장을 사용하여 해당 스타일의 이미지를 생성할 수 있습니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../training/text_inversion)를 참조하세요.\n\n## ControlNet\n\n[Paper](https://huggingface.co/papers/2302.05543)\n\n[ControlNet](../api/pipelines/stable_diffusion/controlnet)은 조건을 추가로 부여하는 보조 네트워크입니다.\n가장자리 감지, 낙서, 깊이 맵, 시맨틱 세그먼테이션과 같은 다양한 조건에 대해 사전 훈련된 8개의 표준 ControlNet이 있습니다.\n\n사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion/controlnet)를 참조하세요.\n\n## Prompt Weighting\n\n프롬프트 가중치는 텍스트 입력의 특정 부분에 더 많은 어텐션 가중치를 부여하는 간단한 기법입니다.\n\n자세한 설명과 예시는 [여기](../using-diffusers/weighted_prompts)를 참조하세요.\n\n## Custom Diffusion\n\n[Custom Diffusion](../training/custom_diffusion)은 사전 학습된 text-to-image diffusion 모델의 cross-attention 맵만 파인튜닝합니다.\n또한 textual inversion을 추가로 수행할 수 있습니다. 설계상 다중 개념 훈련을 지원합니다.\nDreamBooth 및 Textual Inversion과 마찬가지로, Custom Diffusion은 사전학습된 text-to-image diffusion 모델에 새로운 개념을 학습시켜 관심 있는 개념과 관련된 출력을 생성하는 데에도 사용됩니다.\n\n자세한 설명은 [공식 문서](../training/custom_diffusion)를 참조하세요.\n\n## Model Editing\n\n[Paper](https://huggingface.co/papers/2303.08084)\n\n[텍스트-이미지 모델 편집 파이프라인](../api/pipelines/model_editing)은 사전학습된 text-to-image diffusion 모델이 입력 프롬프트에 있는 피사체에 대해 내릴 수 있는 잘못된 암시적 가정을 완화하는 데 도움이 됩니다.\n예를 들어, Stable Diffusion에 \"A pack of roses\"에 대한 이미지를 생성하도록 프롬프트를 입력하면 생성된 이미지의 장미는 빨간색일 가능성이 높습니다. 
이 파이프라인은 이러한 가정을 변경하는 데 도움이 됩니다.\n\n자세한 설명은 [공식 문서](../api/pipelines/model_editing)를 참조하세요.\n\n## DiffEdit\n\n[Paper](https://huggingface.co/papers/2210.11427)\n\n[DiffEdit](../api/pipelines/diffedit)를 사용하면 원본 입력 이미지를 최대한 보존하면서 입력 프롬프트에 따라 입력 이미지를 의미론적으로 편집할 수 있습니다.\n\n자세한 설명은 [공식 문서](../api/pipelines/diffedit)를 참조하세요.\n\n## T2I-Adapter\n\n[Paper](https://huggingface.co/papers/2302.08453)\n\n[T2I-Adapter](../api/pipelines/stable_diffusion/adapter)는 조건을 추가로 부여하는 보조 네트워크입니다.\n가장자리 감지, 스케치, 깊이 맵, 시맨틱 세그먼테이션과 같은 다양한 조건에 대해 사전 훈련된 8개의 표준 adapter가 있습니다.\n\n사용 방법에 대한 자세한 내용은 [공식 문서](../api/pipelines/stable_diffusion/adapter)를 참조하세요."
  },
  {
    "path": "docs/source/ko/using-diffusers/custom_pipeline_overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 커스텀 파이프라인 불러오기\n\n[[open-in-colab]]\n\n커뮤니티 파이프라인은 논문에 명시된 원래의 구현체와 다른 형태로 구현된 모든 [`DiffusionPipeline`] 클래스를 의미합니다. (예를 들어, [`StableDiffusionControlNetPipeline`]는 [\"Text-to-Image Generation with ControlNet Conditioning\"](https://huggingface.co/papers/2302.05543) 해당) 이들은 추가 기능을 제공하거나 파이프라인의 원래 구현을 확장합니다.\n\n[Speech to Image](https://github.com/huggingface/diffusers/tree/main/examples/community#speech-to-image) 또는 [Composable Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#composable-stable-diffusion) 과 같은 멋진 커뮤니티 파이프라인이 많이 있으며 [여기에서](https://github.com/huggingface/diffusers/tree/main/examples/community) 모든 공식 커뮤니티 파이프라인을 찾을 수 있습니다.\n\n허브에서 커뮤니티 파이프라인을 로드하려면, 커뮤니티 파이프라인의 리포지토리 ID와 (파이프라인 가중치 및 구성 요소를 로드하려는) 모델의 리포지토리 ID를 인자로 전달해야 합니다. 예를 들어, 아래 예시에서는 `hf-internal-testing/diffusers-dummy-pipeline`에서 더미 파이프라인을 불러오고, `google/ddpm-cifar10-32`에서 파이프라인의 가중치와 컴포넌트들을 로드합니다.\n\n> [!WARNING]\n> 🔒 허깅 페이스 허브에서 커뮤니티 파이프라인을 불러오는 것은 곧 해당 코드가 안전하다고 신뢰하는 것입니다. 코드를 자동으로 불러오고 실행하기 앞서 반드시 온라인으로 해당 코드의 신뢰성을 검사하세요!\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"google/ddpm-cifar10-32\", custom_pipeline=\"hf-internal-testing/diffusers-dummy-pipeline\"\n)\n```\n\n공식 커뮤니티 파이프라인을 불러오는 것은 비슷하지만, 공식 리포지토리 ID에서 가중치를 불러오는 것과 더불어 해당 파이프라인 내의 컴포넌트를 직접 지정하는 것 역시 가능합니다. 
아래 예제를 보면 커뮤니티 [CLIP Guided Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion) 파이프라인을 로드할 때, 해당 파이프라인에서 사용할 `clip_model` 컴포넌트와 `feature_extractor` 컴포넌트를 직접 설정하는 것을 확인할 수 있습니다.\n\n```py\nfrom diffusers import DiffusionPipeline\nfrom transformers import CLIPImageProcessor, CLIPModel\n\nclip_model_id = \"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\"\n\nfeature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id)\nclip_model = CLIPModel.from_pretrained(clip_model_id)\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    custom_pipeline=\"clip_guided_stable_diffusion\",\n    clip_model=clip_model,\n    feature_extractor=feature_extractor,\n)\n```\n\n커뮤니티 파이프라인에 대한 자세한 내용은 [커뮤니티 파이프라인](https://github.com/huggingface/diffusers/blob/main/docs/source/en/using-diffusers/custom_pipeline_examples) 가이드를 살펴보세요. 커뮤니티 파이프라인 등록에 관심이 있는 경우 [커뮤니티 파이프라인에 기여하는 방법](https://github.com/huggingface/diffusers/blob/main/docs/source/en/using-diffusers/contribute_pipeline)에 대한 가이드를 확인하세요 !"
  },
  {
    "path": "docs/source/ko/using-diffusers/depth2img.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Text-guided depth-to-image 생성\n\n[[open-in-colab]]\n\n[`StableDiffusionDepth2ImgPipeline`]을 사용하면 텍스트 프롬프트와 초기 이미지를 전달하여 새 이미지의 생성을 조절할 수 있습니다. 또한 이미지 구조를 보존하기 위해 `depth_map`을 전달할 수도 있습니다. `depth_map`이 제공되지 않으면 파이프라인은 통합된 [depth-estimation model](https://github.com/isl-org/MiDaS)을 통해 자동으로 깊이를 예측합니다.\n\n\n먼저 [`StableDiffusionDepth2ImgPipeline`]의 인스턴스를 생성합니다:\n\n```python\nimport torch\nimport requests\nfrom PIL import Image\n\nfrom diffusers import StableDiffusionDepth2ImgPipeline\n\npipe = StableDiffusionDepth2ImgPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-depth\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n```\n\n이제 프롬프트를 파이프라인에 전달합니다. 
특정 단어가 이미지 생성을 가이드 하는것을 방지하기 위해 `negative_prompt`를 전달할 수도 있습니다:\n\n```python\nurl = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\ninit_image = Image.open(requests.get(url, stream=True).raw)\nprompt = \"two tigers\"\nn_prompt = \"bad, deformed, ugly, bad anatomy\"\nimage = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]\nimage\n```\n\n| Input                                                                           | Output                                                                                                                                |\n|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|\n| <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/coco-cats.png\" width=\"500\"/> | <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/depth2img-tigers.png\" width=\"500\"/> |\n\n아래의 Spaces를 가지고 놀며 depth map이 있는 이미지와 없는 이미지의 차이가 있는지 확인해 보세요!\n\n<iframe\n\tsrc=\"https://radames-stable-diffusion-depth2img.hf.space\"\n\tframeborder=\"0\"\n\twidth=\"850\"\n\theight=\"500\"\n></iframe>\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/diffedit.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# DiffEdit\n\n[[open-in-colab]]\n\n이미지 편집을 하려면 일반적으로 편집할 영역의 마스크를 제공해야 합니다. DiffEdit는 텍스트 쿼리를 기반으로 마스크를 자동으로 생성하므로 이미지 편집 소프트웨어 없이도 마스크를 만들기가 전반적으로 더 쉬워집니다. DiffEdit 알고리즘은 세 단계로 작동합니다:\n\n1. Diffusion 모델이 일부 쿼리 텍스트와 참조 텍스트를 조건부로 이미지의 노이즈를 제거하여 이미지의 여러 영역에 대해 서로 다른 노이즈 추정치를 생성하고, 그 차이를 사용하여 쿼리 텍스트와 일치하도록 이미지의 어느 영역을 변경해야 하는지 식별하기 위한 마스크를 추론합니다.\n2. 입력 이미지가 DDIM을 사용하여 잠재 공간으로 인코딩됩니다.\n3. 마스크 외부의 픽셀이 입력 이미지와 동일하게 유지되도록 마스크를 가이드로 사용하여 텍스트 쿼리에 조건이 지정된 diffusion 모델로 latents를 디코딩합니다.\n\n이 가이드에서는 마스크를 수동으로 만들지 않고 DiffEdit를 사용하여 이미지를 편집하는 방법을 설명합니다.\n\n시작하기 전에 다음 라이브러리가 설치되어 있는지 확인하세요:\n\n```py\n# Colab에서 필요한 라이브러리를 설치하기 위해 주석을 제외하세요\n#!pip install -q diffusers transformers accelerate\n```\n\n[`StableDiffusionDiffEditPipeline`]에는 이미지 마스크와 부분적으로 반전된 latents 집합이 필요합니다. 이미지 마스크는 [`~StableDiffusionDiffEditPipeline.generate_mask`] 함수에서 생성되며, 두 개의 파라미터인 `source_prompt`와 `target_prompt`가 포함됩니다. 이 매개변수는 이미지에서 무엇을 편집할지 결정합니다. 예를 들어, *과일* 한 그릇을 *배* 한 그릇으로 변경하려면 다음과 같이 하세요:\n\n```py\nsource_prompt = \"a bowl of fruits\"\ntarget_prompt = \"a bowl of pears\"\n```\n\n부분적으로 반전된 latents는 [`~StableDiffusionDiffEditPipeline.invert`] 함수에서 생성되며, 일반적으로 이미지를 설명하는 `prompt` 또는 *캡션*을 포함하는 것이 inverse latent sampling 프로세스를 가이드하는 데 도움이 됩니다. 
캡션은 종종 `source_prompt`가 될 수 있지만, 다른 텍스트 설명으로 자유롭게 실험해 보세요!\n\n파이프라인, 스케줄러, 역 스케줄러를 불러오고 메모리 사용량을 줄이기 위해 몇 가지 최적화를 활성화해 보겠습니다:\n\n```py\nimport torch\nfrom diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline\n\npipeline = StableDiffusionDiffEditPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-1\",\n    torch_dtype=torch.float16,\n    safety_checker=None,\n    use_safetensors=True,\n)\npipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\npipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)\npipeline.enable_model_cpu_offload()\npipeline.enable_vae_slicing()\n```\n\n수정하기 위한 이미지를 불러옵니다:\n\n```py\nfrom diffusers.utils import load_image, make_image_grid\n\nimg_url = \"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"\nraw_image = load_image(img_url).resize((768, 768))\nraw_image\n```\n\n이미지 마스크를 생성하기 위해 [`~StableDiffusionDiffEditPipeline.generate_mask`] 함수를 사용합니다. 이미지에서 편집할 내용을 지정하기 위해 `source_prompt`와 `target_prompt`를 전달해야 합니다:\n\n```py\nfrom PIL import Image\n\nsource_prompt = \"a bowl of fruits\"\ntarget_prompt = \"a basket of pears\"\nmask_image = pipeline.generate_mask(\n    image=raw_image,\n    source_prompt=source_prompt,\n    target_prompt=target_prompt,\n)\nImage.fromarray((mask_image.squeeze()*255).astype(\"uint8\"), \"L\").resize((768, 768))\n```\n\n다음으로, 반전된 latents를 생성하고 이미지를 묘사하는 캡션에 전달합니다:\n\n```py\ninv_latents = pipeline.invert(prompt=source_prompt, image=raw_image).latents\n```\n\n마지막으로, 이미지 마스크와 반전된 latents를 파이프라인에 전달합니다. 
`target_prompt`는 이제 `prompt`가 되며, `source_prompt`는 `negative_prompt`로 사용됩니다.\n\n```py\noutput_image = pipeline(\n    prompt=target_prompt,\n    mask_image=mask_image,\n    image_latents=inv_latents,\n    negative_prompt=source_prompt,\n).images[0]\nmask_image = Image.fromarray((mask_image.squeeze()*255).astype(\"uint8\"), \"L\").resize((768, 768))\nmake_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">original image</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/assets/target.png?raw=true\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">edited image</figcaption>\n  </div>\n</div>\n\n## Source와 target 임베딩 생성하기\n\nSource와 target 임베딩은 수동으로 생성하는 대신 [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) 모델을 사용하여 자동으로 생성할 수 있습니다.\n\nFlan-T5 모델과 토크나이저를 🤗 Transformers 라이브러리에서 불러옵니다:\n\n```py\nimport torch\nfrom transformers import AutoTokenizer, T5ForConditionalGeneration\n\ntokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-large\")\nmodel = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-large\", device_map=\"auto\", torch_dtype=torch.float16)\n```\n\n모델에 입력할 source와 target 프롬프트를 생성하기 위해 초기 텍스트들을 제공합니다.\n\n```py\nsource_concept = \"bowl\"\ntarget_concept = \"basket\"\n\n# 괄호로 묶어야 두 문자열 리터럴이 하나의 문자열로 이어집니다.\nsource_text = (\n    f\"Provide a caption for images containing a {source_concept}. \"\n    \"The captions should be in English and should be no longer than 150 characters.\"\n)\n\ntarget_text = (\n    f\"Provide a caption for images containing a {target_concept}. 
\"\n    \"The captions should be in English and should be no longer than 150 characters.\"\n)\n```\n\n다음으로, 프롬프트들을 생성하기 위해 유틸리티 함수를 생성합니다.\n\n```py\n@torch.no_grad()\ndef generate_prompts(input_prompt):\n    input_ids = tokenizer(input_prompt, return_tensors=\"pt\").input_ids.to(\"cuda\")\n\n    outputs = model.generate(\n        input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10\n    )\n    return tokenizer.batch_decode(outputs, skip_special_tokens=True)\n\nsource_prompts = generate_prompts(source_text)\ntarget_prompts = generate_prompts(target_text)\nprint(source_prompts)\nprint(target_prompts)\n```\n\n> [!TIP]\n> 다양한 품질의 텍스트를 생성하는 전략에 대해 자세히 알아보려면 [생성 전략](https://huggingface.co/docs/transformers/main/en/generation_strategies) 가이드를 참조하세요.\n\n텍스트 인코딩을 위해 [`StableDiffusionDiffEditPipeline`]에서 사용하는 텍스트 인코더 모델을 불러옵니다. 텍스트 인코더를 사용하여 텍스트 임베딩을 계산합니다:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionDiffEditPipeline\n\npipeline = StableDiffusionDiffEditPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-1\", torch_dtype=torch.float16, use_safetensors=True\n)\npipeline.enable_model_cpu_offload()\npipeline.enable_vae_slicing()\n\n@torch.no_grad()\ndef embed_prompts(sentences, tokenizer, text_encoder, device=\"cuda\"):\n    embeddings = []\n    for sent in sentences:\n        text_inputs = tokenizer(\n            sent,\n            padding=\"max_length\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]\n        embeddings.append(prompt_embeds)\n    return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)\n\nsource_embeds = embed_prompts(source_prompts, pipeline.tokenizer, pipeline.text_encoder)\ntarget_embeds = embed_prompts(target_prompts, pipeline.tokenizer, 
pipeline.text_encoder)\n```\n\n마지막으로, 임베딩을 [`~StableDiffusionDiffEditPipeline.generate_mask`] 및 [`~StableDiffusionDiffEditPipeline.invert`] 함수와 파이프라인에 전달하여 이미지를 생성합니다:\n\n```diff\n  from diffusers import DDIMInverseScheduler, DDIMScheduler\n  from diffusers.utils import load_image, make_image_grid\n  from PIL import Image\n\n  pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\n  pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)\n\n  img_url = \"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"\n  raw_image = load_image(img_url).resize((768, 768))\n\n  mask_image = pipeline.generate_mask(\n      image=raw_image,\n-     source_prompt=source_prompt,\n-     target_prompt=target_prompt,\n+     source_prompt_embeds=source_embeds,\n+     target_prompt_embeds=target_embeds,\n  )\n\n  inv_latents = pipeline.invert(\n-     prompt=source_prompt,\n+     prompt_embeds=source_embeds,\n      image=raw_image,\n  ).latents\n\n  output_image = pipeline(\n      mask_image=mask_image,\n      image_latents=inv_latents,\n-     prompt=target_prompt,\n-     negative_prompt=source_prompt,\n+     prompt_embeds=target_embeds,\n+     negative_prompt_embeds=source_embeds,\n  ).images[0]\n  mask_image = Image.fromarray((mask_image.squeeze()*255).astype(\"uint8\"), \"L\")\n  make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)\n```\n\n## 반전을 위한 캡션 생성하기\n\n`source_prompt`를 캡션으로 사용하여 부분적으로 반전된 latents를 생성할 수 있지만, [BLIP](https://huggingface.co/docs/transformers/model_doc/blip) 모델을 사용하여 캡션을 자동으로 생성할 수도 있습니다.\n\n🤗 Transformers 라이브러리에서 BLIP 모델과 프로세서를 불러옵니다:\n\n```py\nimport torch\nfrom transformers import BlipForConditionalGeneration, BlipProcessor\n\nprocessor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-base\")\nmodel = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-base\", torch_dtype=torch.float16, 
low_cpu_mem_usage=True)\n```\n\n입력 이미지에서 캡션을 생성하는 유틸리티 함수를 만듭니다:\n\n```py\n@torch.no_grad()\ndef generate_caption(images, caption_generator, caption_processor):\n    text = \"a photograph of\"\n\n    inputs = caption_processor(images, text, return_tensors=\"pt\").to(device=\"cuda\", dtype=caption_generator.dtype)\n    caption_generator.to(\"cuda\")\n    outputs = caption_generator.generate(**inputs, max_new_tokens=128)\n\n    # 캡션 생성 모델을 CPU로 오프로드\n    caption_generator.to(\"cpu\")\n\n    caption = caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]\n    return caption\n```\n\n입력 이미지를 불러오고 `generate_caption` 함수를 사용하여 해당 이미지에 대한 캡션을 생성합니다:\n\n```py\nfrom diffusers.utils import load_image\n\nimg_url = \"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"\nraw_image = load_image(img_url).resize((768, 768))\ncaption = generate_caption(raw_image, model, processor)\n```\n\n<div class=\"flex justify-center\">\n    <figure>\n        <img class=\"rounded-xl\" src=\"https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png\"/>\n        <figcaption class=\"text-center\">generated caption: \"a photograph of a bowl of fruit on a table\"</figcaption>\n    </figure>\n</div>\n\n이제 캡션을 [`~StableDiffusionDiffEditPipeline.invert`] 함수에 전달하여 부분적으로 반전된 latents를 생성할 수 있습니다!\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/img2img.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 텍스트 기반 image-to-image 생성\n\n[[open-in-colab]]\n\n[`StableDiffusionImg2ImgPipeline`]을 사용하면 텍스트 프롬프트와 시작 이미지를 전달하여 새 이미지 생성의 조건을 지정할 수 있습니다.\n\n시작하기 전에 필요한 라이브러리가 모두 설치되어 있는지 확인하세요:\n\n```bash\n!pip install diffusers transformers ftfy accelerate\n```\n\n[`nitrosocke/Ghibli-Diffusion`](https://huggingface.co/nitrosocke/Ghibli-Diffusion)과 같은 사전학습된 stable diffusion 모델로 [`StableDiffusionImg2ImgPipeline`]을 생성하여 시작하세요.\n\n\n```python\nimport torch\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\nfrom diffusers import StableDiffusionImg2ImgPipeline\n\ndevice = \"cuda\"\npipe = StableDiffusionImg2ImgPipeline.from_pretrained(\"nitrosocke/Ghibli-Diffusion\", torch_dtype=torch.float16).to(\n    device\n)\n```\n\n초기 이미지를 다운로드하고 사전 처리하여 파이프라인에 전달할 수 있습니다:\n\n```python\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n\nresponse = requests.get(url)\ninit_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\ninit_image.thumbnail((768, 768))\ninit_image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/image_2_image_using_diffusers_cell_8_output_0.jpeg\"/>\n</div>\n\n> [!TIP]\n> 💡 `strength`는 입력 이미지에 추가되는 노이즈의 양을 제어하는 0.0에서 1.0 사이의 값입니다. 
1.0에 가까운 값은 다양한 변형을 허용하지만 입력 이미지와 의미적으로 일치하지 않는 이미지를 생성할 수 있습니다.\n\n프롬프트를 정의하고(지브리 스타일(Ghibli-style)에 맞게 조정된 이 체크포인트의 경우 프롬프트 앞에 `ghibli style` 토큰을 붙여야 합니다) 파이프라인을 실행합니다:\n\n```python\nprompt = \"ghibli style, a fantasy landscape with castles\"\ngenerator = torch.Generator(device=device).manual_seed(1024)\nimage = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ghibli-castles.png\"/>\n</div>\n\n다른 스케줄러로 실험하여 출력에 어떤 영향을 미치는지 확인할 수도 있습니다:\n\n```python\nfrom diffusers import LMSDiscreteScheduler\n\nlms = LMSDiscreteScheduler.from_config(pipe.scheduler.config)\npipe.scheduler = lms\ngenerator = torch.Generator(device=device).manual_seed(1024)\nimage = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lms-ghibli.png\"/>\n</div>\n\n아래 Space에서 `strength` 값을 다르게 설정하여 이미지를 생성해 보세요. `strength`를 낮게 설정하면 원본 이미지와 더 유사한 이미지가 생성되는 것을 확인할 수 있습니다.\n\n자유롭게 스케줄러를 [`LMSDiscreteScheduler`]로 전환하여 출력에 어떤 영향을 미치는지 확인해 보세요.\n\n<iframe\n\tsrc=\"https://stevhliu-ghibli-img2img.hf.space\"\n\tframeborder=\"0\"\n\twidth=\"850\"\n\theight=\"500\"\n></iframe>"
  },
  {
    "path": "docs/source/ko/using-diffusers/inpaint.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Text-guided 이미지 인페인팅(inpainting)\n\n[[open-in-colab]]\n\n[`StableDiffusionInpaintPipeline`]은 마스크와 텍스트 프롬프트를 제공하여 이미지의 특정 부분을 편집할 수 있도록 합니다. 이 기능은 인페인팅 작업을 위해 특별히 훈련된 [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)과 같은 Stable Diffusion 버전을 사용합니다.\n\n먼저 [`StableDiffusionInpaintPipeline`] 인스턴스를 불러옵니다:\n\n```python\nimport PIL\nimport requests\nimport torch\nfrom io import BytesIO\n\nfrom diffusers import StableDiffusionInpaintPipeline\n\npipeline = StableDiffusionInpaintPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\",\n    torch_dtype=torch.float16,\n)\npipeline = pipeline.to(\"cuda\")\n```\n\n나중에 교체할 강아지 이미지와 마스크를 다운로드하세요:\n\n```python\ndef download_image(url):\n    response = requests.get(url)\n    return PIL.Image.open(BytesIO(response.content)).convert(\"RGB\")\n\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ninit_image = download_image(img_url).resize((512, 512))\nmask_image = download_image(mask_url).resize((512, 512))\n```\n\n이제 마스크를 다른 것으로 교체하라는 프롬프트를 만들 수 있습니다:\n\n```python\nprompt = \"Face of a 
yellow cat, high resolution, sitting on a park bench\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]\n```\n\n`image`          | `mask_image` | `prompt` | output |\n:-------------------------:|:-------------------------:|:-------------------------:|-------------------------:|\n<img src=\"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\" alt=\"drawing\" width=\"250\"/> | <img src=\"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\" alt=\"drawing\" width=\"250\"/> | ***Face of a yellow cat, high resolution, sitting on a park bench*** | <img src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/yellow_cat_sitting_on_a_park_bench.png\" alt=\"drawing\" width=\"250\"/> |\n\n> [!WARNING]\n> 이전의 실험적인 인페인팅 구현에서는 품질이 낮은 다른 프로세스를 사용했습니다. 이전 버전과의 호환성을 보장하기 위해 새 모델이 포함되지 않은 사전학습된 파이프라인을 불러오면 이전 인페인팅 방법이 계속 적용됩니다.\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/kandinsky.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Kandinsky\n\n[[open-in-colab]]\n\nKandinsky 모델은 일련의 다국어 text-to-image 생성 모델입니다. Kandinsky 2.0 모델은 두 개의 다국어 텍스트 인코더를 사용하고 그 결과를 연결해 UNet에 사용됩니다.\n\n[Kandinsky 2.1](../api/pipelines/kandinsky)은 텍스트와 이미지 임베딩 간의 매핑을 생성하는 image prior 모델([`CLIP`](https://huggingface.co/docs/transformers/model_doc/clip))을 포함하도록 아키텍처를 변경했습니다. 이 매핑은 더 나은 text-image alignment를 제공하며, 학습 중에 텍스트 임베딩과 함께 사용되어 더 높은 품질의 결과를 가져옵니다. 마지막으로, Kandinsky 2.1은 spatial conditional 정규화 레이어를 추가하여 사실감을 높여주는 [Modulating Quantized Vectors (MoVQ)](https://huggingface.co/papers/2209.09002) 디코더를 사용하여 latents를 이미지로 디코딩합니다.\n\n[Kandinsky 2.2](../api/pipelines/kandinsky_v22)는 image prior 모델의 이미지 인코더를 더 큰 CLIP-ViT-G 모델로 교체하여 품질을 개선함으로써 이전 모델을 개선했습니다. 또한 image prior 모델은 해상도와 종횡비가 다른 이미지로 재훈련되어 더 높은 해상도의 이미지와 다양한 이미지 크기를 생성합니다.\n\n[Kandinsky 3](../api/pipelines/kandinsky3)는 아키텍처를 단순화하고 prior 모델과 diffusion 모델을 포함하는 2단계 생성 프로세스에서 벗어나고 있습니다. 대신, Kandinsky 3는 [Flan-UL2](https://huggingface.co/google/flan-ul2)를 사용하여 텍스트를 인코딩하고, [BigGan-deep](https://hf.co/papers/1809.11096) 블록이 포함된 UNet을 사용하며, [Sber-MoVQGAN](https://github.com/ai-forever/MoVQGAN)을 사용하여 latents를 이미지로 디코딩합니다. 
텍스트 이해와 생성된 이미지 품질의 향상은 주로 더 큰 텍스트 인코더와 UNet을 사용함으로써 달성됩니다.\n\n이 가이드에서는 text-to-image, image-to-image, 인페인팅, 보간 등을 위해 Kandinsky 모델을 사용하는 방법을 설명합니다.\n\n시작하기 전에 다음 라이브러리가 설치되어 있는지 확인하세요:\n\n```py\n# Colab에서 필요한 라이브러리를 설치하려면 주석을 해제하세요\n#!pip install -q diffusers transformers accelerate\n```\n\n> [!WARNING]\n> Kandinsky 2.1과 2.2의 사용법은 매우 유사합니다! 유일한 차이점은 Kandinsky 2.2는 latents를 디코딩할 때 `prompt`를 입력으로 받지 않는다는 것입니다. 대신, Kandinsky 2.2는 디코딩 중에는 `image_embeds`만 받아들입니다.\n>\n> <br>\n>\n> Kandinsky 3는 더 간결한 아키텍처를 가지고 있으며 prior 모델이 필요하지 않습니다. 즉, [Stable Diffusion XL](sdxl)과 같은 다른 diffusion 모델과 사용법이 동일합니다.\n\n## Text-to-image\n\n모든 작업에 Kandinsky 모델을 사용하려면 항상 프롬프트를 인코딩하고 이미지 임베딩을 생성하는 prior 파이프라인을 설정하는 것부터 시작해야 합니다. Prior 파이프라인은 negative 프롬프트 `\"\"`에 해당하는 `negative_image_embeds`도 생성합니다. 더 나은 결과를 얻으려면 prior 파이프라인에 실제 `negative_prompt`를 전달할 수 있지만, 이렇게 하면 prior 파이프라인의 유효 배치 크기가 2배로 증가합니다.\n\n<hfoptions id=\"text-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import KandinskyPriorPipeline, KandinskyPipeline\nimport torch\n\nprior_pipeline = KandinskyPriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16).to(\"cuda\")\npipeline = KandinskyPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nnegative_prompt = \"low quality, bad quality\" # negative 프롬프트 포함은 선택적이지만, 보통 결과는 더 좋습니다\nimage_embeds, negative_image_embeds = prior_pipeline(prompt, negative_prompt, guidance_scale=1.0).to_tuple()\n```\n\n이제 모든 프롬프트와 임베딩을 [`KandinskyPipeline`]에 전달하여 이미지를 생성합니다:\n\n```py\nimage = pipeline(prompt, image_embeds=image_embeds, negative_prompt=negative_prompt, negative_image_embeds=negative_image_embeds, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/cheeseburger.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline\nimport torch\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16).to(\"cuda\")\npipeline = KandinskyV22Pipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nnegative_prompt = \"low quality, bad quality\" # negative 프롬프트 포함은 선택적이지만, 보통 결과는 더 좋습니다\nimage_embeds, negative_image_embeds = prior_pipeline(prompt, guidance_scale=1.0).to_tuple()\n```\n\n이미지 생성을 위해 `image_embeds`와 `negative_image_embeds`를 [`KandinskyV22Pipeline`]에 전달합니다:\n\n```py\nimage = pipeline(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-text-to-image.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 3\">\n\nKandinsky 3는 prior 모델이 필요하지 않으므로 [`Kandinsky3Pipeline`]을 직접 불러오고 이미지 생성 프롬프트를 전달할 수 있습니다:\n\n```py\nfrom diffusers import Kandinsky3Pipeline\nimport torch\n\npipeline = Kandinsky3Pipeline.from_pretrained(\"kandinsky-community/kandinsky-3\", variant=\"fp16\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nimage = pipeline(prompt).images[0]\nimage\n```\n\n</hfoption>\n</hfoptions>\n\n🤗 Diffusers는 또한 [`KandinskyCombinedPipeline`] 및 [`KandinskyV22CombinedPipeline`]이 포함된 end-to-end API를 제공하므로 prior 파이프라인과 text-to-image 변환 파이프라인을 별도로 불러올 필요가 없습니다. 
결합된 파이프라인은 prior 모델과 디코더를 모두 자동으로 불러옵니다. 원하는 경우 `prior_guidance_scale` 및 `prior_num_inference_steps` 매개 변수를 사용하여 prior 파이프라인에 대해 다른 값을 설정할 수 있습니다.\n\n내부에서 결합된 파이프라인을 자동으로 호출하려면 [`AutoPipelineForText2Image`]를 사용합니다:\n\n<hfoptions id=\"text-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale=1.0, guidance_scale=4.0, height=768, width=768).images[0]\nimage\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale=1.0, guidance_scale=4.0, height=768, width=768).images[0]\nimage\n```\n\n</hfoption>\n</hfoptions>\n\n## Image-to-image\n\nImage-to-image 경우, 초기 이미지와 텍스트 프롬프트를 전달하여 파이프라인에 이미지를 conditioning합니다. 
Prior 파이프라인을 불러오는 것으로 시작합니다:\n\n<hfoptions id=\"image-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nimport torch\nfrom diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline\n\nprior_pipeline = KandinskyPriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = KandinskyImg2ImgPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nimport torch\nfrom diffusers import KandinskyV22Img2ImgPipeline, KandinskyV22PriorPipeline\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = KandinskyV22Img2ImgPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 3\">\n\nKandinsky 3는 prior 모델이 필요하지 않으므로 image-to-image 파이프라인을 직접 불러올 수 있습니다:\n\n```py\nfrom diffusers import Kandinsky3Img2ImgPipeline\nfrom diffusers.utils import load_image\nimport torch\n\npipeline = Kandinsky3Img2ImgPipeline.from_pretrained(\"kandinsky-community/kandinsky-3\", variant=\"fp16\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n```\n\n</hfoption>\n</hfoptions>\n\nConditioning할 이미지를 다운로드합니다:\n\n```py\nfrom diffusers.utils import load_image\n\n# 이미지 다운로드\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\noriginal_image = load_image(url)\noriginal_image = original_image.resize((768, 512))\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"/>\n</div>\n\nPrior 파이프라인으로 `image_embeds`와 
`negative_image_embeds`를 생성합니다:\n\n```py\nprompt = \"A fantasy landscape, Cinematic lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nimage_embeds, negative_image_embeds = prior_pipeline(prompt, negative_prompt).to_tuple()\n```\n\n이제 원본 이미지와 모든 프롬프트 및 임베딩을 파이프라인으로 전달하여 이미지를 생성합니다:\n\n<hfoptions id=\"image-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers.utils import make_image_grid\n\nimage = pipeline(prompt, negative_prompt=negative_prompt, image=original_image, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768, strength=0.3).images[0]\nmake_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/img2img_fantasyland.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers.utils import make_image_grid\n\nimage = pipeline(image=original_image, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, height=768, width=768, strength=0.3).images[0]\nmake_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-image-to-image.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 3\">\n\n```py\nimage = pipeline(prompt, negative_prompt=negative_prompt, image=original_image, strength=0.75, num_inference_steps=25).images[0]\nimage\n```\n\n</hfoption>\n</hfoptions>\n\n또한 🤗 Diffusers에서는 [`KandinskyImg2ImgCombinedPipeline`] 및 [`KandinskyV22Img2ImgCombinedPipeline`]이 포함된 end-to-end API를 제공하므로 prior 파이프라인과 image-to-image 파이프라인을 별도로 불러올 필요가 없습니다. 결합된 파이프라인은 prior 모델과 디코더를 모두 자동으로 불러옵니다. 
원하는 경우 `prior_guidance_scale` 및 `prior_num_inference_steps` 매개 변수를 사용하여 이전 파이프라인에 대해 다른 값을 설정할 수 있습니다.\n\n내부에서 결합된 파이프라인을 자동으로 호출하려면 [`AutoPipelineForImage2Image`]를 사용합니다:\n\n<hfoptions id=\"image-to-image\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\nimport torch\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16, use_safetensors=True)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A fantasy landscape, Cinematic lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\noriginal_image = load_image(url)\n\noriginal_image.thumbnail((768, 768))\n\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=original_image, strength=0.3).images[0]\nmake_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import make_image_grid, load_image\nimport torch\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt = \"A fantasy landscape, Cinematic lighting\"\nnegative_prompt = \"low quality, bad quality\"\n\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\noriginal_image = load_image(url)\n\noriginal_image.thumbnail((768, 768))\n\nimage = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=original_image, strength=0.3).images[0]\nmake_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n</hfoption>\n</hfoptions>\n\n## 
Inpainting\n\n> [!WARNING]\n> ⚠️ Kandinsky 모델은 이제 검은색 픽셀 대신 ⬜️ **흰색 픽셀**을 사용하여 마스크 영역을 표현합니다. 프로덕션에서 [`KandinskyInpaintPipeline`]을 사용하는 경우 흰색 픽셀을 사용하도록 마스크를 변경해야 합니다:\n>\n> ```py\n> # PIL 입력에 대해\n> import PIL.ImageOps\n> mask = PIL.ImageOps.invert(mask)\n>\n> # PyTorch와 NumPy 입력에 대해\n> mask = 1 - mask\n> ```\n\n인페인팅에서는 원본 이미지, 원본 이미지에서 대체할 영역의 마스크, 인페인팅할 내용에 대한 텍스트 프롬프트가 필요합니다. Prior 파이프라인을 불러옵니다:\n\n<hfoptions id=\"inpaint\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\nimport numpy as np\nfrom PIL import Image\n\nprior_pipeline = KandinskyPriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = KandinskyInpaintPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-inpaint\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\nimport numpy as np\nfrom PIL import Image\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = KandinskyV22InpaintPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder-inpaint\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n\n</hfoption>\n</hfoptions>\n\n초기 이미지를 불러오고 마스크를 생성합니다:\n\n```py\ninit_image = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nmask = np.zeros((768, 768), dtype=np.float32)\n# mask area above cat's head\nmask[:250, 250:-250] = 1\n```\n\nPrior 파이프라인으로 임베딩을 생성합니다:\n\n```py\nprompt = \"a hat\"\nprior_output = prior_pipeline(prompt)\n```\n\n이제 
이미지 생성을 위해 초기 이미지, 마스크, 프롬프트와 임베딩을 파이프라인에 전달합니다:\n\n<hfoptions id=\"inpaint\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\noutput_image = pipeline(prompt, image=init_image, mask_image=mask, **prior_output, height=768, width=768, num_inference_steps=150).images[0]\nmask = Image.fromarray((mask*255).astype('uint8'), 'L')\nmake_image_grid([init_image, mask, output_image], rows=1, cols=3)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/inpaint_cat_hat.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\noutput_image = pipeline(image=init_image, mask_image=mask, **prior_output, height=768, width=768, num_inference_steps=150).images[0]\nmask = Image.fromarray((mask*255).astype('uint8'), 'L')\nmake_image_grid([init_image, mask, output_image], rows=1, cols=3)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinskyv22-inpaint.png\"/>\n</div>\n\n</hfoption>\n</hfoptions>\n\n[`KandinskyInpaintCombinedPipeline`] 및 [`KandinskyV22InpaintCombinedPipeline`]을 사용하여 내부에서 prior 및 디코더 파이프라인을 함께 호출할 수 있습니다. 
이를 위해 [`AutoPipelineForInpainting`]을 사용합니다:\n\n<hfoptions id=\"inpaint\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipe = AutoPipelineForInpainting.from_pretrained(\"kandinsky-community/kandinsky-2-1-inpaint\", torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n\ninit_image = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nmask = np.zeros((768, 768), dtype=np.float32)\n# 고양이 머리 위 마스크 영역\nmask[:250, 250:-250] = 1\nprompt = \"a hat\"\n\noutput_image = pipe(prompt=prompt, image=init_image, mask_image=mask).images[0]\nmask = Image.fromarray((mask*255).astype('uint8'), 'L')\nmake_image_grid([init_image, mask, output_image], rows=1, cols=3)\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom diffusers import AutoPipelineForInpainting\nfrom diffusers.utils import load_image, make_image_grid\n\npipe = AutoPipelineForInpainting.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder-inpaint\", torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n\ninit_image = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nmask = np.zeros((768, 768), dtype=np.float32)\n# 고양이 머리 위 마스크 영역\nmask[:250, 250:-250] = 1\nprompt = \"a hat\"\n\noutput_image = pipe(prompt=prompt, image=init_image, mask_image=mask).images[0]\nmask = Image.fromarray((mask*255).astype('uint8'), 'L')\nmake_image_grid([init_image, mask, output_image], rows=1, cols=3)\n```\n\n</hfoption>\n</hfoptions>\n\n## Interpolation (보간)\n\nInterpolation(보간)을 사용하면 이미지와 텍스트 임베딩 사이의 latent space를 탐색할 수 있어 prior 모델의 중간 결과물을 볼 수 있는 멋진 방법입니다. 
Prior 파이프라인과 보간하려는 두 개의 이미지를 불러옵니다:\n\n<hfoptions id=\"interpolate\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\nfrom diffusers import KandinskyPriorPipeline, KandinskyPipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\n\nprior_pipeline = KandinskyPriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\nimg_1 = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nimg_2 = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg\")\nmake_image_grid([img_1.resize((512,512)), img_2.resize((512,512))], rows=1, cols=2)\n```\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nfrom diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline\nfrom diffusers.utils import load_image, make_image_grid\nimport torch\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\nimg_1 = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\")\nimg_2 = load_image(\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg\")\nmake_image_grid([img_1.resize((512,512)), img_2.resize((512,512))], rows=1, cols=2)\n```\n\n</hfoption>\n</hfoptions>\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">a cat</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/starry_night.jpeg\"/>\n    <figcaption class=\"mt-2 text-center text-sm 
text-gray-500\">Van Gogh's Starry Night painting</figcaption>\n  </div>\n</div>\n\n보간할 텍스트 또는 이미지를 지정하고 각 텍스트 또는 이미지에 대한 가중치를 설정합니다. 가중치를 실험하여 보간에 어떤 영향을 미치는지 확인하세요!\n\n```py\nimages_texts = [\"a cat\", img_1, img_2]\nweights = [0.3, 0.3, 0.4]\n```\n\n`interpolate` 함수를 호출하여 임베딩을 생성한 다음, 파이프라인으로 전달하여 이미지를 생성합니다:\n\n<hfoptions id=\"interpolate\">\n<hfoption id=\"Kandinsky 2.1\">\n\n```py\n# 프롬프트는 빈칸으로 남겨도 됩니다\nprompt = \"\"\nprior_out = prior_pipeline.interpolate(images_texts, weights)\n\npipeline = KandinskyPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n\nimage = pipeline(prompt, **prior_out, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png\"/>\n</div>\n\n</hfoption>\n<hfoption id=\"Kandinsky 2.2\">\n\n```py\nprior_out = prior_pipeline.interpolate(images_texts, weights)\n\npipeline = KandinskyV22Pipeline.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n\n# Kandinsky 2.2 디코더는 프롬프트를 입력으로 받지 않으므로 임베딩만 전달합니다\nimage = pipeline(**prior_out, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinskyv22-interpolate.png\"/>\n</div>\n\n</hfoption>\n</hfoptions>\n\n## ControlNet\n\n> [!WARNING]\n> ⚠️ ControlNet은 Kandinsky 2.2에서만 지원됩니다!\n\nControlNet을 사용하면 depth map이나 edge detection과 같은 추가 입력을 통해 사전학습된 large diffusion 모델을 conditioning할 수 있습니다. 
예를 들어, 모델이 depth map의 구조를 이해하고 보존할 수 있도록 깊이 맵으로 Kandinsky 2.2를 conditioning할 수 있습니다.\n\n이미지를 불러오고 depth map을 추출해 보겠습니다:\n\n```py\nfrom diffusers.utils import load_image\n\nimg = load_image(\n    \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png\"\n).resize((768, 768))\nimg\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png\"/>\n</div>\n\n그런 다음 🤗 Transformers의 `depth-estimation` [`~transformers.Pipeline`]을 사용하여 이미지를 처리해 depth map을 구할 수 있습니다:\n\n```py\nimport torch\nimport numpy as np\n\nfrom transformers import pipeline\n\ndef make_hint(image, depth_estimator):\n    image = depth_estimator(image)[\"depth\"]\n    image = np.array(image)\n    image = image[:, :, None]\n    image = np.concatenate([image, image, image], axis=2)\n    detected_map = torch.from_numpy(image).float() / 255.0\n    hint = detected_map.permute(2, 0, 1)\n    return hint\n\ndepth_estimator = pipeline(\"depth-estimation\")\nhint = make_hint(img, depth_estimator).unsqueeze(0).half().to(\"cuda\")\n```\n\n### Text-to-image [[controlnet-text-to-image]]\n\nPrior 파이프라인과 [`KandinskyV22ControlnetPipeline`]를 불러옵니다:\n\n```py\nfrom diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline\n\nprior_pipeline = KandinskyV22PriorPipeline.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True\n).to(\"cuda\")\n\npipeline = KandinskyV22ControlnetPipeline.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-controlnet-depth\", torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\n프롬프트와 negative 프롬프트로 이미지 임베딩을 생성합니다:\n\n```py\nprompt = \"A robot, 4k photo\"\nnegative_prior_prompt = \"lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, 
poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature\"\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(43)\n\nimage_emb, zero_image_emb = prior_pipeline(\n    prompt=prompt, negative_prompt=negative_prior_prompt, generator=generator\n).to_tuple()\n```\n\n마지막으로 이미지 임베딩과 depth 이미지를 [`KandinskyV22ControlnetPipeline`]에 전달하여 이미지를 생성합니다:\n\n```py\nimage = pipeline(image_embeds=image_emb, negative_image_embeds=zero_image_emb, hint=hint, num_inference_steps=50, generator=generator, height=768, width=768).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/robot_cat_text2img.png\"/>\n</div>\n\n### Image-to-image [[controlnet-image-to-image]]\n\nControlNet을 사용한 image-to-image의 경우, 다음을 사용할 필요가 있습니다:\n\n- [`KandinskyV22PriorEmb2EmbPipeline`]로 텍스트 프롬프트와 이미지에서 이미지 임베딩을 생성합니다.\n- [`KandinskyV22ControlnetImg2ImgPipeline`]로 초기 이미지와 이미지 임베딩에서 이미지를 생성합니다.\n\n🤗 Transformers에서 `depth-estimation` [`~transformers.Pipeline`]을 사용하여 고양이의 초기 이미지의 depth map을 처리해 추출합니다:\n\n```py\nimport torch\nimport numpy as np\n\nfrom diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline\nfrom diffusers.utils import load_image\nfrom transformers import pipeline\n\nimg = load_image(\n    \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/cat.png\"\n).resize((768, 768))\n\ndef make_hint(image, depth_estimator):\n    image = depth_estimator(image)[\"depth\"]\n    image = np.array(image)\n    image = image[:, :, None]\n    image = np.concatenate([image, image, image], axis=2)\n    detected_map = torch.from_numpy(image).float() / 255.0\n    hint = 
detected_map.permute(2, 0, 1)\n    return hint\n\ndepth_estimator = pipeline(\"depth-estimation\")\nhint = make_hint(img, depth_estimator).unsqueeze(0).half().to(\"cuda\")\n```\n\nPrior 파이프라인과 [`KandinskyV22ControlnetImg2ImgPipeline`]을 불러옵니다:\n\n```py\nprior_pipeline = KandinskyV22PriorEmb2EmbPipeline.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-prior\", torch_dtype=torch.float16, use_safetensors=True\n).to(\"cuda\")\n\npipeline = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(\n    \"kandinsky-community/kandinsky-2-2-controlnet-depth\", torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\n텍스트 프롬프트와 초기 이미지를 prior 파이프라인에 전달하여 이미지 임베딩을 생성합니다:\n\n```py\nprompt = \"A robot, 4k photo\"\nnegative_prior_prompt = \"lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature\"\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(43)\n\nimg_emb = prior_pipeline(prompt=prompt, image=img, strength=0.85, generator=generator)\nnegative_emb = prior_pipeline(prompt=negative_prior_prompt, image=img, strength=1, generator=generator)\n```\n\n이제 [`KandinskyV22ControlnetImg2ImgPipeline`]을 실행하여 초기 이미지와 이미지 임베딩으로부터 이미지를 생성할 수 있습니다:\n\n```py\nfrom diffusers.utils import make_image_grid\n\nimage = pipeline(image=img, strength=0.5, image_embeds=img_emb.image_embeds, negative_image_embeds=negative_emb.image_embeds, hint=hint, num_inference_steps=50, generator=generator, height=768, width=768).images[0]\nmake_image_grid([img.resize((512, 512)), image.resize((512, 512))], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" 
src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinskyv22/robot_cat.png\"/>\n</div>\n\n## 최적화\n\nKandinsky는 mapping을 생성하기 위한 prior 파이프라인과 latents를 이미지로 디코딩하기 위한 두 번째 파이프라인이 필요하다는 점에서 독특합니다. 대부분의 계산이 두 번째 파이프라인에서 이루어지므로 최적화 노력은 두 번째 파이프라인에 집중되어야 합니다. 다음은 추론 중 Kandinsky의 성능을 개선하기 위한 몇 가지 팁입니다.\n\n1. PyTorch < 2.0을 사용할 경우 [xFormers](../optimization/xformers)를 활성화합니다.\n\n```diff\n  from diffusers import DiffusionPipeline\n  import torch\n\n  pipe = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16)\n+ pipe.enable_xformers_memory_efficient_attention()\n```\n\n2. PyTorch >= 2.0을 사용할 경우 `torch.compile`을 활성화하여 scaled dot-product attention (SDPA)를 자동으로 사용하도록 합니다:\n\n```diff\n  pipe.unet.to(memory_format=torch.channels_last)\n+ pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n이는 attention processor를 명시적으로 [`~models.attention_processor.AttnAddedKVProcessor2_0`]을 사용하도록 설정하는 것과 동일합니다:\n\n```py\nfrom diffusers.models.attention_processor import AttnAddedKVProcessor2_0\n\npipe.unet.set_attn_processor(AttnAddedKVProcessor2_0())\n```\n\n3. 메모리 부족 오류를 방지하기 위해 [`~KandinskyPriorPipeline.enable_model_cpu_offload`]를 사용하여 모델을 CPU로 오프로드합니다:\n\n```diff\n  from diffusers import DiffusionPipeline\n  import torch\n\n  pipe = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16)\n+ pipe.enable_model_cpu_offload()\n```\n\n4. 기본적으로 text-to-image 파이프라인은 [`DDIMScheduler`]를 사용하지만, [`DDPMScheduler`]와 같은 다른 스케줄러로 대체하여 추론 속도와 이미지 품질 간의 균형에 어떤 영향을 미치는지 확인할 수 있습니다:\n\n```py\nimport torch\nfrom diffusers import DDPMScheduler\nfrom diffusers import DiffusionPipeline\n\nscheduler = DDPMScheduler.from_pretrained(\"kandinsky-community/kandinsky-2-1\", subfolder=\"ddpm_scheduler\")\npipe = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", scheduler=scheduler, torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n```\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/loading.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n\n\n# 파이프라인, 모델, 스케줄러 불러오기\n\n기본적으로 diffusion 모델은 다양한 컴포넌트들(모델, 토크나이저, 스케줄러) 간의 복잡한 상호작용을 기반으로 동작합니다. 디퓨저스(Diffusers)는 이러한 diffusion 모델을 보다 쉽고 간편한 API로 제공하는 것을 목표로 설계되었습니다. [`DiffusionPipeline`]은 diffusion 모델이 갖는 복잡성을 하나의 파이프라인 API로 통합하고, 동시에 이를 구성하는 각각의 컴포넌트들을 태스크에 맞춰 유연하게 커스터마이징할 수 있도록 지원하고 있습니다.\n\ndiffusion 모델의 훈련과 추론에 필요한 모든 것은 [`DiffusionPipeline.from_pretrained`] 메서드를 통해 접근할 수 있습니다. (이 말의 의미는 다음 단락에서 보다 자세하게 다뤄보도록 하겠습니다.)\n\n이 문서에서는 설명할 내용은 다음과 같습니다.\n\n* 허브를 통해 혹은 로컬로 파이프라인을 불러오는 법\n\n* 파이프라인에 다른 컴포넌트들을 적용하는 법\n* 오리지널 체크포인트가 아닌 variant를 불러오는 법  (variant란 기본으로 설정된 `fp32`가 아닌 다른  부동 소수점 타입(예: `fp16`)을 사용하거나 Non-EMA 가중치를 사용하는 체크포인트들을 의미합니다.)\n* 모델과 스케줄러를 불러오는 법\n\n\n\n## Diffusion 파이프라인\n\n> [!TIP]\n> 💡 [`DiffusionPipeline`] 클래스가 동작하는 방식에 보다 자세한 내용이 궁금하다면,  [DiffusionPipeline explained](#diffusionpipeline에-대해-알아보기) 섹션을 확인해보세요.\n\n[`DiffusionPipeline`] 클래스는 diffusion 모델을 [허브](https://huggingface.co/models?library=diffusers)로부터 불러오는 가장 심플하면서 보편적인 방식입니다. [`DiffusionPipeline.from_pretrained`] 메서드는 적합한 파이프라인 클래스를 자동으로 탐지하고, 필요한 구성요소(configuration)와 가중치(weight) 파일들을 다운로드하고 캐싱한 다음, 해당 파이프라인 인스턴스를 반환합니다.\n\n```python\nfrom diffusers import DiffusionPipeline\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipe = DiffusionPipeline.from_pretrained(repo_id)\n```\n\n물론 [`DiffusionPipeline`] 클래스를 사용하지 않고, 명시적으로 직접 해당 파이프라인 클래스를 불러오는 것도 가능합니다. 
아래 예시 코드는 위 예시와 동일한 인스턴스를 반환합니다.\n\n```python\nfrom diffusers import StableDiffusionPipeline\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipe = StableDiffusionPipeline.from_pretrained(repo_id)\n```\n\n[CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4)이나 [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 같은 체크포인트들의 경우, 하나 이상의 다양한 태스크에 활용될 수 있습니다. (예를 들어 위의 두 체크포인트의 경우, text-to-image와 image-to-image에 모두 활용될 수 있습니다.)  만약 이러한 체크포인트들을 기본 설정 태스크가 아닌 다른 태스크에 활용하고자 한다면, 해당 태스크에 대응되는 파이프라인(task-specific pipeline)을 사용해야 합니다.\n\n```python\nfrom diffusers import StableDiffusionImg2ImgPipeline\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)\n```\n\n\n\n### 로컬 파이프라인\n\n파이프라인을 로컬로 불러오고자 한다면, `git-lfs`를 사용하여 직접 체크포인트를 로컬 디스크에 다운로드 받아야 합니다. 아래의 명령어를 실행하면 `./stable-diffusion-v1-5`란 이름으로 폴더가 로컬디스크에 생성됩니다.\n\n```bash\ngit lfs install\ngit clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5\n```\n\n그런 다음 해당 로컬 경로를 [`~DiffusionPipeline.from_pretrained`] 메서드에 전달합니다.\n\n```python\nfrom diffusers import DiffusionPipeline\n\nrepo_id = \"./stable-diffusion-v1-5\"\nstable_diffusion = DiffusionPipeline.from_pretrained(repo_id)\n```\n\n위의 예시코드처럼 만약 `repo_id`가 로컬 패스(local path)라면, [`~DiffusionPipeline.from_pretrained`] 메서드는 이를 자동으로 감지하여 허브에서 파일을 다운로드하지 않습니다. 만약 로컬 디스크에 저장된 파이프라인 체크포인트가 최신 버전이 아닐 경우에도, 최신 버전을 다운로드하지 않고 기존 로컬 디스크에 저장된 체크포인트를 사용한다는 것을 의미합니다.\n\n\n\n### 파이프라인 내부의 컴포넌트 교체하기\n\n파이프라인 내부의 컴포넌트들은 호환 가능한 다른 컴포넌트로 교체될 수 있습니다. 
이와 같은 컴포넌트 교체가 중요한 이유는 다음과 같습니다.\n\n- 어떤 스케줄러를 사용할 것인가는 생성속도와 생성품질 간의 트레이드오프를 정의하는 중요한 요소입니다.\n- diffusion 모델 내부의 컴포넌트들은 일반적으로 각각 독립적으로 훈련되기 때문에, 더 좋은 성능을 보여주는 컴포넌트가 있다면 그걸로 교체하는 식으로 성능을 향상시킬 수 있습니다.\n- 파인 튜닝 단계에서는 일반적으로 UNet 혹은 텍스트 인코더와 같은 일부 컴포넌트들만 훈련하게 됩니다.\n\n어떤 스케줄러들이 호환가능한지는 `compatibles` 속성을 통해 확인할 수 있습니다.\n\n```python\nfrom diffusers import DiffusionPipeline\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nstable_diffusion = DiffusionPipeline.from_pretrained(repo_id)\nstable_diffusion.scheduler.compatibles\n```\n\n이번에는 [`SchedulerMixin.from_pretrained`] 메서드를 사용해서, 기존 기본 스케줄러였던 [`PNDMScheduler`]를 보다 우수한 성능의 [`EulerDiscreteScheduler`]로 바꿔봅시다. 스케줄러를 로드할 때는 `subfolder` 인자를 통해, 해당 파이프라인의 리포지토리에서 [스케줄러에 관한 하위폴더](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/scheduler)를 명시해주어야 합니다.\n\n그 다음 새롭게 생성한 [`EulerDiscreteScheduler`] 인스턴스를 [`DiffusionPipeline`]의 `scheduler` 인자에 전달합니다.\n\n```python\nfrom diffusers import DiffusionPipeline, EulerDiscreteScheduler\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n\nscheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder=\"scheduler\")\n\nstable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)\n```\n\n### 세이프티 체커\n\nStable Diffusion과 같은 diffusion 모델들은 유해한 이미지를 생성할 수도 있습니다. 이를 예방하기 위해 디퓨저스는 생성된 이미지의 유해성을 판단하는 [세이프티 체커(safety checker)](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) 기능을 지원하고 있습니다. 만약 세이프티 체커의 사용을 원하지 않는다면, `safety_checker` 인자에 `None`을 전달해주시면 됩니다.\n\n```python\nfrom diffusers import DiffusionPipeline\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nstable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None)\n```\n\n### 컴포넌트 재사용\n\n복수의 파이프라인에서 동일한 모델을 반복적으로 사용한다면, 굳이 해당 모델의 동일한 가중치를 중복으로 RAM에 불러올 필요는 없을 것입니다. 
[`~DiffusionPipeline.components`] 속성을 통해 파이프라인 내부의 컴포넌트들을 참조할 수 있는데, 이번 단락에서는 이를 통해 동일한 모델 가중치를 RAM에 중복으로 불러오는 것을 방지하는 법에 대해 알아보겠습니다.\n\n```python\nfrom diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nstable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)\n\ncomponents = stable_diffusion_txt2img.components\n```\n\n그 다음 위 예시 코드에서 선언한 `components` 변수를 다른 파이프라인에 전달함으로써, 모델의 가중치를 중복으로 RAM에 로딩하지 않고, 동일한 컴포넌트를 재사용할 수 있습니다.\n\n```python\nstable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)\n```\n\n물론 각각의 컴포넌트들을 따로 따로 파이프라인에 전달할 수도 있습니다.  예를 들어 `stable_diffusion_txt2img` 파이프라인 안의 컴포넌트들 가운데서 세이프티 체커(`safety_checker`)와 피쳐 익스트랙터(`feature_extractor`)를 제외한 컴포넌트들만 `stable_diffusion_img2img` 파이프라인에서 재사용하는 방식 역시 가능합니다.\n\n```python\nfrom diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nstable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)\nstable_diffusion_img2img = StableDiffusionImg2ImgPipeline(\n    vae=stable_diffusion_txt2img.vae,\n    text_encoder=stable_diffusion_txt2img.text_encoder,\n    tokenizer=stable_diffusion_txt2img.tokenizer,\n    unet=stable_diffusion_txt2img.unet,\n    scheduler=stable_diffusion_txt2img.scheduler,\n    safety_checker=None,\n    feature_extractor=None,\n    requires_safety_checker=False,\n)\n```\n\n## Checkpoint variants\n\nVariant란 일반적으로 다음과 같은 체크포인트들을 의미합니다.\n\n-  `torch.float16`과 같이 정밀도는 더 낮지만, 용량 역시 더 작은 부동소수점 타입의 가중치를 사용하는 체크포인트. *(다만 이와 같은 variant의 경우, 추가적인 훈련과 CPU환경에서의 구동이 불가능합니다.)*\n- Non-EMA 가중치를 사용하는 체크포인트. *(Non-EMA 가중치의 경우, 파인 튜닝 단계에서 사용하는 것이 권장되는데, 추론 단계에선 사용하지 않는 것이 권장됩니다.)*\n\n> [!TIP]\n> 💡 모델 구조는 동일하지만 서로 다른 학습 환경에서 서로 다른 데이터셋으로 학습된 체크포인트들이 있을 경우, 해당 체크포인트들은 variant 단계가 아닌 리포지토리 단계에서 분리되어 관리되어야 합니다. (즉, 해당 체크포인트들은 서로 다른 리포지토리에서 따로 관리되어야 합니다. 
예시: [`stable-diffusion-v1-4`], [`stable-diffusion-v1-5`]).\n\n| **checkpoint type** | **weight name**                     | **argument for loading weights** |\n| ------------------- | ----------------------------------- | -------------------------------- |\n| original            | diffusion_pytorch_model.bin         |                                  |\n| floating point      | diffusion_pytorch_model.fp16.bin    | `variant`, `torch_dtype`         |\n| non-EMA             | diffusion_pytorch_model.non_ema.bin | `variant`                        |\n\nvariant를 로드할 때 2개의 중요한 argument가 있습니다.\n\n* `torch_dtype`은 불러올 체크포인트의 부동소수점 타입을 정의합니다. 예를 들어 `torch_dtype=torch.float16`을 명시함으로써 가중치의 부동소수점 타입을 `fp16`으로 변환할 수 있습니다. (만약 따로 설정하지 않을 경우, 기본값으로 `fp32` 타입의 가중치가 로딩됩니다.) 또한 `variant` 인자를 명시하지 않은 채로 체크포인트를 불러온 다음, 해당 체크포인트를 `torch_dtype=torch.float16` 인자를 통해 `fp16` 타입으로 변환하는 것 역시 가능합니다. 이 경우 기본으로 설정된 `fp32` 가중치가 먼저 다운로드되고, 해당 가중치들을 불러온 다음 `fp16` 타입으로 변환하게 됩니다.\n* `variant` 인자는 리포지토리에서 어떤 variant를 불러올 것인가를 정의합니다. 가령 [`diffusers/stable-diffusion-variants`](https://huggingface.co/diffusers/stable-diffusion-variants/tree/main/unet) 리포지토리로부터 `non_ema` 체크포인트를 불러오고자 한다면, `variant=\"non_ema\"` 인자를 전달해야 합니다.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\n# load fp16 variant\nstable_diffusion = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", variant=\"fp16\", torch_dtype=torch.float16\n)\n# load non_ema variant\nstable_diffusion = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", variant=\"non_ema\")\n```\n\n다른 부동소수점 타입의 가중치 혹은 non-EMA 가중치를 사용하는 체크포인트를 저장하기 위해서는, [`DiffusionPipeline.save_pretrained`] 메서드를 사용해야 하며, 이 때 `variant` 인자를 명시해줘야 합니다. 
원래의 체크포인트와 동일한 폴더에 variant를 저장해야 하며, 이렇게 하면 동일한 폴더에서 오리지널 체크포인트와 variant를 모두 불러올 수 있습니다.\n\n```python\nfrom diffusers import DiffusionPipeline\n\n# save as fp16 variant\nstable_diffusion.save_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", variant=\"fp16\")\n# save as non-ema variant\nstable_diffusion.save_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", variant=\"non_ema\")\n```\n\n만약 variant를 기존 폴더에 저장하지 않을 경우, `variant` 인자를 반드시 명시해야 합니다. 그렇게 하지 않을 경우 오리지널 체크포인트를 찾을 수 없게 되기 때문에 에러가 발생합니다.\n\n```python\n# 👎 this won't work\nstable_diffusion = DiffusionPipeline.from_pretrained(\"./stable-diffusion-v1-5\", torch_dtype=torch.float16)\n# 👍 this works\nstable_diffusion = DiffusionPipeline.from_pretrained(\n    \"./stable-diffusion-v1-5\", variant=\"fp16\", torch_dtype=torch.float16\n)\n```\n\n### 모델 불러오기\n\n모델들은 [`ModelMixin.from_pretrained`] 메서드를 통해 불러올 수 있습니다. 해당 메서드는 최신 버전의 모델 가중치 파일과 설정 파일(configurations)을 다운로드하고 캐싱합니다. 만약 이러한 파일들이 최신 버전으로 로컬 캐시에 저장되어 있다면, [`ModelMixin.from_pretrained`]는 굳이 해당 파일들을 다시 다운로드하지 않으며, 그저 캐시에 있는 최신 파일들을 재사용합니다.\n\n모델은 `subfolder` 인자에 명시된 하위 폴더로부터 로드됩니다. 
예를 들어 `stable-diffusion-v1-5/stable-diffusion-v1-5`의 UNet 모델의 가중치는 [`unet`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet) 폴더에 저장되어 있습니다.\n\n```python\nfrom diffusers import UNet2DConditionModel\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nmodel = UNet2DConditionModel.from_pretrained(repo_id, subfolder=\"unet\")\n```\n\n혹은 [해당 모델의 리포지토리](https://huggingface.co/google/ddpm-cifar10-32/tree/main)로부터 다이렉트로 가져오는 것 역시 가능합니다.\n\n```python\nfrom diffusers import UNet2DModel\n\nrepo_id = \"google/ddpm-cifar10-32\"\nmodel = UNet2DModel.from_pretrained(repo_id)\n```\n\n또한 앞서 봤던 `variant` 인자를 명시함으로써, Non-EMA나 `fp16`의 가중치를 가져오는 것 역시 가능합니다.\n\n```python\nfrom diffusers import UNet2DConditionModel\n\nmodel = UNet2DConditionModel.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"unet\", variant=\"non_ema\")\nmodel.save_pretrained(\"./local-unet\", variant=\"non_ema\")\n```\n\n### 스케줄러\n\n스케줄러들은 [`SchedulerMixin.from_pretrained`] 메서드를 통해 불러올 수 있습니다. 모델과 달리 스케줄러는 별도의 가중치를 갖지 않으며, 따라서 당연히 별도의 학습과정을 요구하지 않습니다. 이러한 스케줄러들은 (해당 스케줄러 하위폴더의) configuration 파일을 통해 정의됩니다.\n\n여러 개의 스케줄러를 불러온다고 해서 많은 메모리를 소모하는 것은 아니며, 다양한 스케줄러들에 동일한 스케줄러 configuration을 적용하는 것 역시 가능합니다. 
다음 예시 코드에서 불러오는 스케줄러들은 모두 [`StableDiffusionPipeline`]과 호환되는데, 이는 곧 해당 스케줄러들에 동일한 스케줄러 configuration 파일을 적용할 수 있음을 의미합니다.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers import (\n    DDPMScheduler,\n    DDIMScheduler,\n    PNDMScheduler,\n    LMSDiscreteScheduler,\n    EulerDiscreteScheduler,\n    EulerAncestralDiscreteScheduler,\n    DPMSolverMultistepScheduler,\n)\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n\nddpm = DDPMScheduler.from_pretrained(repo_id, subfolder=\"scheduler\")\nddim = DDIMScheduler.from_pretrained(repo_id, subfolder=\"scheduler\")\npndm = PNDMScheduler.from_pretrained(repo_id, subfolder=\"scheduler\")\nlms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder=\"scheduler\")\neuler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder=\"scheduler\")\neuler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder=\"scheduler\")\ndpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder=\"scheduler\")\n\n# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler_anc`, `euler`\npipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)\n```\n\n### DiffusionPipeline에 대해 알아보기\n\n클래스 메서드로서 [`DiffusionPipeline.from_pretrained`]은 2가지를 담당합니다.\n\n- 첫째로, `from_pretrained` 메서드는 최신 버전의 파이프라인을 다운로드하고, 캐시에 저장합니다. 이미 로컬 캐시에 최신 버전의 파이프라인이 저장되어 있다면, [`DiffusionPipeline.from_pretrained`]은 해당 파일들을 다시 다운로드하지 않고, 로컬 캐시에 저장되어 있는 파이프라인을 불러옵니다.\n- 둘째로, `model_index.json` 파일을 통해 체크포인트에 대응되는 적합한 파이프라인 클래스로 불러옵니다.\n\n파이프라인의 폴더 구조는 해당 파이프라인 클래스의 구조와 직접적으로 일치합니다. 
예를 들어 [`StableDiffusionPipeline`] 클래스는 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 리포지토리와 대응되는 구조를 갖습니다.\n\n```python\nfrom diffusers import DiffusionPipeline\n\nrepo_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipeline = DiffusionPipeline.from_pretrained(repo_id)\nprint(pipeline)\n```\n\n위의 코드 출력 결과를 확인해보면, `pipeline`은 [`StableDiffusionPipeline`]의 인스턴스이며, 다음과 같이 총 7개의 컴포넌트로 구성된다는 것을 알 수 있습니다.\n\n- `\"feature_extractor\"`: [`~transformers.CLIPImageProcessor`]의 인스턴스\n- `\"safety_checker\"`: 유해한 컨텐츠를 스크리닝하기 위한 [컴포넌트](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32)\n- `\"scheduler\"`: [`PNDMScheduler`]의 인스턴스\n- `\"text_encoder\"`: [`~transformers.CLIPTextModel`]의 인스턴스\n- `\"tokenizer\"`: [`~transformers.CLIPTokenizer`]의 인스턴스\n- `\"unet\"`: [`UNet2DConditionModel`]의 인스턴스\n- `\"vae\"`: [`AutoencoderKL`]의 인스턴스\n\n```json\nStableDiffusionPipeline {\n  \"feature_extractor\": [\n    \"transformers\",\n    \"CLIPImageProcessor\"\n  ],\n  \"safety_checker\": [\n    \"stable_diffusion\",\n    \"StableDiffusionSafetyChecker\"\n  ],\n  \"scheduler\": [\n    \"diffusers\",\n    \"PNDMScheduler\"\n  ],\n  \"text_encoder\": [\n    \"transformers\",\n    \"CLIPTextModel\"\n  ],\n  \"tokenizer\": [\n    \"transformers\",\n    \"CLIPTokenizer\"\n  ],\n  \"unet\": [\n    \"diffusers\",\n    \"UNet2DConditionModel\"\n  ],\n  \"vae\": [\n    \"diffusers\",\n    \"AutoencoderKL\"\n  ]\n}\n```\n\n파이프라인 인스턴스의 컴포넌트들을 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)의 폴더 구조와 비교해볼 경우, 각각의 컴포넌트마다 별도의 폴더가 있음을 확인할 수 있습니다.\n\n```\n.\n├── feature_extractor\n│   └── preprocessor_config.json\n├── model_index.json\n├── safety_checker\n│   ├── config.json\n│   └── pytorch_model.bin\n├── scheduler\n│   └── scheduler_config.json\n├── text_encoder\n│   ├── 
config.json\n│   └── pytorch_model.bin\n├── tokenizer\n│   ├── merges.txt\n│   ├── special_tokens_map.json\n│   ├── tokenizer_config.json\n│   └── vocab.json\n├── unet\n│   ├── config.json\n│   ├── diffusion_pytorch_model.bin\n└── vae\n    ├── config.json\n    ├── diffusion_pytorch_model.bin\n```\n\n또한 각각의 컴포넌트들을 파이프라인 인스턴스의 속성으로써 참조할 수 있습니다.\n\n```py\npipeline.tokenizer\n```\n\n```python\nCLIPTokenizer(\n    name_or_path=\"/root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/tokenizer\",\n    vocab_size=49408,\n    model_max_length=77,\n    is_fast=False,\n    padding_side=\"right\",\n    truncation_side=\"right\",\n    special_tokens={\n        \"bos_token\": AddedToken(\"<|startoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True),\n        \"eos_token\": AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True),\n        \"unk_token\": AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=True),\n        \"pad_token\": \"<|endoftext|>\",\n    },\n)\n```\n\n모든 파이프라인은 `model_index.json` 파일을 통해 [`DiffusionPipeline`]에 다음과 같은 정보를 전달합니다.\n\n- `_class_name` 는 어떤 파이프라인 클래스를 사용해야 하는지에 대해 알려줍니다.\n- `_diffusers_version`는 어떤 버전의 디퓨저스로 파이프라인 안의 모델들이 만들어졌는지를 알려줍니다.\n- 그 다음은 각각의 컴포넌트들이 어떤 라이브러리의 어떤 클래스로 만들어졌는지에 대해 알려줍니다. 
(아래 예시에서 `\"feature_extractor\" : [\"transformers\", \"CLIPImageProcessor\"]`의 경우, `feature_extractor` 컴포넌트는 `transformers` 라이브러리의 `CLIPImageProcessor` 클래스를 통해 만들어졌다는 것을 의미합니다.)\n\n```json\n{\n  \"_class_name\": \"StableDiffusionPipeline\",\n  \"_diffusers_version\": \"0.6.0\",\n  \"feature_extractor\": [\n    \"transformers\",\n    \"CLIPImageProcessor\"\n  ],\n  \"safety_checker\": [\n    \"stable_diffusion\",\n    \"StableDiffusionSafetyChecker\"\n  ],\n  \"scheduler\": [\n    \"diffusers\",\n    \"PNDMScheduler\"\n  ],\n  \"text_encoder\": [\n    \"transformers\",\n    \"CLIPTextModel\"\n  ],\n  \"tokenizer\": [\n    \"transformers\",\n    \"CLIPTokenizer\"\n  ],\n  \"unet\": [\n    \"diffusers\",\n    \"UNet2DConditionModel\"\n  ],\n  \"vae\": [\n    \"diffusers\",\n    \"AutoencoderKL\"\n  ]\n}\n```\n\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/loading_adapters.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 어댑터 불러오기\n\n[[open-in-colab]]\n\n특정 물체의 이미지 또는 특정 스타일의 이미지를 생성하도록 diffusion 모델을 개인화하기 위한 몇 가지 [학습](../training/overview) 기법이 있습니다. 이러한 학습 방법은 각각 다른 유형의 어댑터를 생성합니다. 일부 어댑터는 완전히 새로운 모델을 생성하는 반면, 다른 어댑터는 임베딩 또는 가중치의 작은 부분만 수정합니다. 이는 각 어댑터의 로딩 프로세스도 다르다는 것을 의미합니다.\n\n이 가이드에서는 DreamBooth, textual inversion 및 LoRA 가중치를 불러오는 방법을 설명합니다.\n\n> [!TIP]\n> 사용할 체크포인트와 임베딩은 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer), [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), [Diffusers Models Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery)에서 찾아보시기 바랍니다.\n\n## DreamBooth\n\n[DreamBooth](https://dreambooth.github.io/)는 물체의 여러 이미지에 대한 *diffusion 모델 전체*를 미세 조정하여 새로운 스타일과 설정으로 해당 물체의 이미지를 생성합니다. 이 방법은 모델이 물체 이미지와 연관시키는 방법을 학습하는 프롬프트에 특수 단어를 사용하는 방식으로 작동합니다. 모든 학습 방법 중에서 드림부스는 전체 체크포인트 모델이기 때문에 파일 크기가 가장 큽니다(보통 몇 GB).\n\nHergé가 그린 단 10개의 이미지로 학습된 [herge_style](https://huggingface.co/sd-dreambooth-library/herge-style) 체크포인트를 불러와 해당 스타일의 이미지를 생성해 보겠습니다. 
이 모델이 작동하려면 체크포인트를 트리거하는 프롬프트에 특수 단어 `herge_style`을 포함시켜야 합니다:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"sd-dreambooth-library/herge-style\", torch_dtype=torch.float16).to(\"cuda\")\nprompt = \"A cute herge_style brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration\"\nimage = pipeline(prompt).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_dreambooth.png\" />\n</div>\n\n## Textual inversion\n\n[Textual inversion](https://textual-inversion.github.io/)은 DreamBooth와 매우 유사하며 몇 개의 이미지만으로 특정 개념(스타일, 개체)을 생성하는 diffusion 모델을 개인화할 수도 있습니다. 이 방법은 프롬프트에 특정 단어를 입력하면 해당 이미지를 나타내는 새로운 임베딩을 학습하고 찾아내는 방식으로 작동합니다. 결과적으로 diffusion 모델 가중치는 동일하게 유지되고 훈련 프로세스는 비교적 작은(수 KB) 파일을 생성합니다.\n\nTextual inversion은 임베딩을 생성하기 때문에 DreamBooth처럼 단독으로 사용할 수 없으며 또 다른 모델이 필요합니다.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(\"cuda\")\n```\n\n이제 [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 메서드를 사용하여 textual inversion 임베딩을 불러와 이미지를 생성할 수 있습니다. [sd-concepts-library/gta5-artwork](https://huggingface.co/sd-concepts-library/gta5-artwork) 임베딩을 불러와 보겠습니다. 
이를 트리거하려면 프롬프트에 특수 단어 `<gta5-artwork>`를 포함시켜야 합니다:\n\n```py\npipeline.load_textual_inversion(\"sd-concepts-library/gta5-artwork\")\nprompt = \"A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, <gta5-artwork> style\"\nimage = pipeline(prompt).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_txt_embed.png\" />\n</div>\n\nTextual inversion은 또한 바람직하지 않은 사물에 대해 *네거티브 임베딩*을 생성하여 모델이 흐릿한 이미지나 손의 추가 손가락과 같은 바람직하지 않은 사물이 포함된 이미지를 생성하지 못하도록 학습할 수도 있습니다. 이는 프롬프트를 빠르게 개선하는 것이 쉬운 방법이 될 수 있습니다. 이는 이전과 같이 임베딩을 [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`]으로 불러오지만 이번에는 두 개의 매개변수가 더 필요합니다:\n\n- `weight_name`: 파일이 특정 이름의 🤗 Diffusers 형식으로 저장된 경우이거나 파일이 A1111 형식으로 저장된 경우, 불러올 가중치 파일을 지정합니다.\n- `token`: 임베딩을 트리거하기 위해 프롬프트에서 사용할 특수 단어를 지정합니다.\n\n[sayakpaul/EasyNegative-test](https://huggingface.co/sayakpaul/EasyNegative-test) 임베딩을 불러와 보겠습니다:\n\n```py\npipeline.load_textual_inversion(\n    \"sayakpaul/EasyNegative-test\", weight_name=\"EasyNegative.safetensors\", token=\"EasyNegative\"\n)\n```\n\n이제 `token`을 사용해 네거티브 임베딩이 있는 이미지를 생성할 수 있습니다:\n\n```py\nprompt = \"A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, EasyNegative\"\nnegative_prompt = \"EasyNegative\"\n\nimage = pipeline(prompt, negative_prompt=negative_prompt, num_inference_steps=50).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png\" />\n</div>\n\n## LoRA\n\n[Low-Rank Adaptation (LoRA)](https://huggingface.co/papers/2106.09685)은 속도가 빠르고 파일 크기가 (수백 MB로) 작기 때문에 널리 사용되는 학습 기법입니다. 이 가이드의 다른 방법과 마찬가지로, LoRA는 몇 장의 이미지만으로 새로운 스타일을 학습하도록 모델을 학습시킬 수 있습니다. 이는 diffusion 모델에 새로운 가중치를 삽입한 다음 전체 모델 대신 새로운 가중치만 학습시키는 방식으로 작동합니다. 
따라서 LoRA를 더 빠르게 학습시키고 더 쉽게 저장할 수 있습니다.\n\n> [!TIP]\n> LoRA는 다른 학습 방법과 함께 사용할 수 있는 매우 일반적인 학습 기법입니다. 예를 들어, DreamBooth와 LoRA로 모델을 학습하는 것이 일반적입니다. 또한 새롭고 고유한 이미지를 생성하기 위해 여러 개의 LoRA를 불러오고 병합하는 것이 점점 더 일반화되고 있습니다. 병합은 이 불러오기 가이드의 범위를 벗어나므로 자세한 내용은 심층적인 [LoRA 병합](merge_loras) 가이드에서 확인할 수 있습니다.\n\nLoRA는 다른 모델과 함께 사용해야 합니다:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16).to(\"cuda\")\n```\n\n그리고 [`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드를 사용하여 [ostris/super-cereal-sdxl-lora](https://huggingface.co/ostris/super-cereal-sdxl-lora) 가중치를 불러오고 리포지토리에서 가중치 파일명을 지정합니다:\n\n```py\npipeline.load_lora_weights(\"ostris/super-cereal-sdxl-lora\", weight_name=\"cereal_box_sdxl_v1.safetensors\")\nprompt = \"bears, pizza bites\"\nimage = pipeline(prompt).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_lora.png\" />\n</div>\n\n[`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드는 LoRA 가중치를 UNet과 텍스트 인코더에 모두 불러옵니다. 이 메서드는 해당 케이스에서 LoRA를 불러오는 데 선호되는 방식입니다:\n\n- LoRA 가중치에 UNet 및 텍스트 인코더에 대한 별도의 식별자가 없는 경우\n- LoRA 가중치에 UNet과 텍스트 인코더에 대한 별도의 식별자가 있는 경우\n\n하지만 LoRA 가중치만 UNet에 로드해야 하는 경우에는 [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] 메서드를 사용할 수 있습니다. 
[jbilcke-hf/sdxl-cinematic-1](https://huggingface.co/jbilcke-hf/sdxl-cinematic-1) LoRA를 불러와 보겠습니다:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.unet.load_attn_procs(\"jbilcke-hf/sdxl-cinematic-1\", weight_name=\"pytorch_lora_weights.safetensors\")\n\n# 프롬프트에서 cnmt를 사용하여 LoRA를 트리거합니다.\nprompt = \"A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration\"\nimage = pipeline(prompt).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png\" />\n</div>\n\nLoRA 가중치를 언로드하려면 [`~loaders.LoraLoaderMixin.unload_lora_weights`] 메서드를 사용하여 LoRA 가중치를 삭제하고 모델을 원래 가중치로 복원합니다:\n\n```py\npipeline.unload_lora_weights()\n```\n\n### LoRA 가중치 스케일 조정하기\n\n[`~loaders.LoraLoaderMixin.load_lora_weights`] 및 [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] 모두 `cross_attention_kwargs={\"scale\": 0.5}` 파라미터를 전달하여 얼마나 LoRA 가중치를 사용할지 조정할 수 있습니다. 값이 `0`이면 기본 모델 가중치만 사용하는 것과 같고, 값이 `1`이면 완전히 미세 조정된 LoRA를 사용하는 것과 같습니다.\n\n레이어당 사용되는 LoRA 가중치의 양을 보다 세밀하게 제어하려면 [`~loaders.LoraLoaderMixin.set_adapters`]를 사용하여 각 레이어의 가중치를 얼마만큼 조정할지 지정하는 딕셔너리를 전달할 수 있습니다.\n```python\npipe = ... 
# 파이프라인 생성\npipe.load_lora_weights(..., adapter_name=\"my_adapter\")\nscales = {\n    \"text_encoder\": 0.5,\n    \"text_encoder_2\": 0.5,  # 파이프에 두 번째 텍스트 인코더가 있는 경우에만 사용 가능\n    \"unet\": {\n        \"down\": 0.9,  # down 부분의 모든 트랜스포머는 스케일 0.9를 사용\n        # \"mid\"  # 이 예제에서는 \"mid\"가 지정되지 않았으므로 중간 부분의 모든 트랜스포머는 기본 스케일 1.0을 사용\n        \"up\": {\n            \"block_0\": 0.6,  # up의 0번째 블록에 있는 3개의 트랜스포머는 모두 스케일 0.6을 사용\n            \"block_1\": [0.4, 0.8, 1.0],  # up의 1번째 블록에 있는 3개의 트랜스포머는 각각 스케일 0.4, 0.8, 1.0을 사용\n        }\n    }\n}\npipe.set_adapters(\"my_adapter\", scales)\n```\n\n이는 여러 어댑터에서도 작동합니다. 방법은 [이 가이드](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength)를 참조하세요.\n\n> [!WARNING]\n> 현재 [`~loaders.LoraLoaderMixin.set_adapters`]는 어텐션 가중치의 스케일링만 지원합니다. LoRA에 다른 부분(예: resnets 또는 down-/upsamplers)이 있는 경우 해당 부분은 1.0의 스케일을 유지합니다.\n\n### Kohya와 TheLastBen\n\n커뮤니티에서 인기 있는 다른 LoRA trainer로는 [Kohya](https://github.com/kohya-ss/sd-scripts/)와 [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion)의 trainer가 있습니다. 
이 trainer들은 🤗 Diffusers가 훈련한 것과는 다른 LoRA 체크포인트를 생성하지만, 같은 방식으로 불러올 수 있습니다.\n\n<hfoptions id=\"other-trainers\">\n<hfoption id=\"Kohya\">\n\nKohya LoRA를 불러오기 위해, 예시로 [Civitai](https://civitai.com/)에서 [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) 체크포인트를 다운로드합니다:\n\n```sh\n!wget https://civitai.com/api/download/models/168776 -O blueprintify-sd-xl-10.safetensors\n```\n\nLoRA 체크포인트를 [`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드로 불러오고 `weight_name` 파라미터에 파일명을 지정합니다:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.load_lora_weights(\"path/to/weights\", weight_name=\"blueprintify-sd-xl-10.safetensors\")\n```\n\n이미지를 생성합니다:\n\n```py\n# LoRA를 트리거하기 위해 bl3uprint를 프롬프트에 사용\nprompt = \"bl3uprint, a highly detailed blueprint of the eiffel tower, explaining how to build all parts, many txt, blueprint grid backdrop\"\nimage = pipeline(prompt).images[0]\nimage\n```\n\n> [!WARNING]\n> Kohya LoRA를 🤗 Diffusers와 함께 사용할 때 몇 가지 제한 사항이 있습니다:\n>\n> - [여기](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736)에 설명된 여러 가지 이유로 인해 이미지가 ComfyUI와 같은 UI에서 생성된 이미지와 다르게 보일 수 있습니다.\n> - [LyCORIS 체크포인트](https://github.com/KohakuBlueleaf/LyCORIS)가 완전히 지원되지 않습니다. [`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드는 LoRA 및 LoCon 모듈로 LyCORIS 체크포인트를 불러올 수 있지만, Hada 및 LoKR은 지원되지 않습니다.\n\n</hfoption>\n<hfoption id=\"TheLastBen\">\n\nTheLastBen에서 체크포인트를 불러오는 방법은 매우 유사합니다. 
예를 들어, [TheLastBen/William_Eggleston_Style_SDXL](https://huggingface.co/TheLastBen/William_Eggleston_Style_SDXL) 체크포인트를 불러오려면:\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.load_lora_weights(\"TheLastBen/William_Eggleston_Style_SDXL\", weight_name=\"wegg.safetensors\")\n\n# LoRA를 트리거하기 위해 william eggleston을 프롬프트에 사용\nprompt = \"a house by william eggleston, sunrays, beautiful, sunlight, sunrays, beautiful\"\nimage = pipeline(prompt=prompt).images[0]\nimage\n```\n\n</hfoption>\n</hfoptions>\n\n## IP-Adapter\n\n[IP-Adapter](https://ip-adapter.github.io/)는 모든 diffusion 모델에 이미지 프롬프트를 사용할 수 있게 해주는 경량 어댑터입니다. 이 어댑터는 이미지와 텍스트 feature의 cross-attention 레이어를 분리하여 작동합니다. 다른 모든 모델 컴포넌트는 freeze되고 UNet의 embedded 이미지 features만 학습됩니다. 따라서 IP-Adapter 파일은 일반적으로 약 100MB에 불과합니다.\n\n다양한 작업과 구체적인 사용 사례에 IP-Adapter를 사용하는 방법에 대한 자세한 내용은 [IP-Adapter](../using-diffusers/ip_adapter) 가이드에서 확인할 수 있습니다.\n\n> [!TIP]\n> Diffusers는 현재 가장 많이 사용되는 일부 파이프라인에 대해서만 IP-Adapter를 지원합니다. 
멋진 사용 사례가 있는 지원되지 않는 파이프라인에 IP-Adapter를 통합하고 싶다면 언제든지 기능 요청을 여세요!\n> 공식 IP-Adapter 체크포인트는 [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)에서 확인할 수 있습니다.\n\n시작하려면 Stable Diffusion 체크포인트를 불러오세요.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\nfrom diffusers.utils import load_image\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(\"cuda\")\n```\n\n그런 다음 IP-Adapter 가중치를 불러와 [`~loaders.IPAdapterMixin.load_ip_adapter`] 메서드를 사용하여 파이프라인에 추가합니다.\n\n```py\npipeline.load_ip_adapter(\"h94/IP-Adapter\", subfolder=\"models\", weight_name=\"ip-adapter_sd15.bin\")\n```\n\n불러온 뒤, 이미지 및 텍스트 프롬프트가 있는 파이프라인을 사용하여 이미지 생성 프로세스를 가이드할 수 있습니다.\n\n```py\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png\")\ngenerator = torch.Generator(device=\"cpu\").manual_seed(33)\nimages = pipeline(\n    prompt='best quality, high quality, wearing sunglasses',\n    ip_adapter_image=image,\n    negative_prompt=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n    num_inference_steps=50,\n    generator=generator,\n).images[0]\nimages\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip-bear.png\" />\n</div>\n\n### IP-Adapter Plus\n\nIP-Adapter는 이미지 인코더를 사용하여 이미지 feature를 생성합니다. IP-Adapter 리포지토리에 `image_encoder` 하위 폴더가 있는 경우, 이미지 인코더가 자동으로 불러와 파이프라인에 등록됩니다. 
그렇지 않은 경우, [`~transformers.CLIPVisionModelWithProjection`] 모델을 사용하여 이미지 인코더를 명시적으로 불러와 파이프라인에 전달해야 합니다.\n\n이는 ViT-H 이미지 인코더를 사용하는 *IP-Adapter Plus* 체크포인트에 해당하는 케이스입니다.\n\n```py\nfrom transformers import CLIPVisionModelWithProjection\n\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(\n    \"h94/IP-Adapter\",\n    subfolder=\"models/image_encoder\",\n    torch_dtype=torch.float16\n)\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    image_encoder=image_encoder,\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipeline.load_ip_adapter(\"h94/IP-Adapter\", subfolder=\"sdxl_models\", weight_name=\"ip-adapter-plus_sdxl_vit-h.safetensors\")\n```\n\n### IP-Adapter Face ID 모델\n\nIP-Adapter FaceID 모델은 CLIP 이미지 임베딩 대신 `insightface`에서 생성한 이미지 임베딩을 사용하는 실험적인 IP Adapter입니다. 이러한 모델 중 일부는 LoRA를 사용하여 ID 일관성을 개선하기도 합니다.\n이러한 모델을 사용하려면 `insightface`와 해당 요구 사항을 모두 설치해야 합니다.\n\n> [!WARNING]\n> InsightFace 사전학습된 모델은 비상업적 연구 목적으로만 사용할 수 있으므로, IP-Adapter-FaceID 모델은 연구 목적으로만 릴리즈되었으며 상업적 용도로는 사용할 수 없습니다.\n\n```py\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipeline.load_ip_adapter(\"h94/IP-Adapter-FaceID\", subfolder=None, weight_name=\"ip-adapter-faceid_sdxl.bin\", image_encoder_folder=None)\n```\n\n두 가지 IP 어댑터 FaceID Plus 모델 중 하나를 사용하려는 경우, 이 모델들은 더 나은 사실감을 얻기 위해 `insightface`와 CLIP 이미지 임베딩을 모두 사용하므로, CLIP 이미지 인코더도 불러와야 합니다.\n\n```py\nfrom transformers import CLIPVisionModelWithProjection\n\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(\n    \"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\",\n    torch_dtype=torch.float16,\n)\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    image_encoder=image_encoder,\n    torch_dtype=torch.float16\n).to(\"cuda\")\n\npipeline.load_ip_adapter(\"h94/IP-Adapter-FaceID\", subfolder=None, 
weight_name=\"ip-adapter-faceid-plus_sd15.bin\")\n```\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/other-formats.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 다양한 Stable Diffusion 포맷 불러오기\n\nStable Diffusion 모델들은 학습 및 저장된 프레임워크와 다운로드 위치에 따라 다양한 형식으로 제공됩니다. 이러한 형식을 🤗 Diffusers에서 사용할 수 있도록 변환하면 추론을 위한 [다양한 스케줄러 사용](schedulers), 사용자 지정 파이프라인 구축, 추론 속도 최적화를 위한 다양한 기법과 방법 등 라이브러리에서 지원하는 모든 기능을 사용할 수 있습니다.\n\n> [!TIP]\n> 우리는 `.safetensors` 형식을 추천합니다. 왜냐하면 기존의 pickled 파일은 취약하고 머신에서 코드를 실행할 때 악용될 수 있는 것에 비해 훨씬 더 안전합니다. (safetensors 불러오기 가이드에서 자세히 알아보세요.)\n\n이 가이드에서는 다른 Stable Diffusion 형식을 🤗 Diffusers와 호환되도록 변환하는 방법을 설명합니다.\n\n## PyTorch .ckpt\n\n체크포인트 또는 `.ckpt` 형식은 일반적으로 모델을 저장하는 데 사용됩니다. `.ckpt` 파일은 전체 모델을 포함하며 일반적으로 크기가 몇 GB입니다. `.ckpt` 파일을 [~StableDiffusionPipeline.from_ckpt] 메서드를 사용하여 직접 불러와서 사용할 수도 있지만, 일반적으로 두 가지 형식을 모두 사용할 수 있도록 `.ckpt` 파일을 🤗 Diffusers로 변환하는 것이 더 좋습니다.\n\n`.ckpt` 파일을 변환하는 두 가지 옵션이 있습니다. Space를 사용하여 체크포인트를 변환하거나 스크립트를 사용하여 `.ckpt` 파일을 변환합니다.\n\n### Space로 변환하기\n\n`.ckpt` 파일을 변환하는 가장 쉽고 편리한 방법은 SD에서 Diffusers로 스페이스를 사용하는 것입니다. Space의 지침에 따라 .ckpt 파일을 변환 할 수 있습니다.\n\n이 접근 방식은 기본 모델에서는 잘 작동하지만 더 많은 사용자 정의 모델에서는 어려움을 겪을 수 있습니다. 빈 pull request나 오류를 반환하면 Space가 실패한 것입니다.\n이 경우 스크립트를 사용하여 `.ckpt` 파일을 변환해 볼 수 있습니다.\n\n### 스크립트로 변환하기\n\n🤗 Diffusers는 `.ckpt`  파일 변환을 위한 변환 스크립트를 제공합니다. 이 접근 방식은 위의 Space보다 더 안정적입니다.\n\n시작하기 전에 스크립트를 실행할 🤗 Diffusers의 로컬 클론(clone)이 있는지 확인하고 Hugging Face 계정에 로그인하여 pull request를 열고 변환된 모델을 허브에 푸시할 수 있도록 하세요.\n\n```bash\nhf auth login\n```\n\n스크립트를 사용하려면:\n\n1. 
변환하려는 `.ckpt` 파일이 포함된 리포지토리를 Git으로 클론(clone)합니다.\n\n이 예제에서는 TemporalNet `.ckpt` 파일을 변환해 보겠습니다:\n\n```bash\ngit lfs install\ngit clone https://huggingface.co/CiaraRowles/TemporalNet\n```\n\n2. 체크포인트를 변환할 리포지토리에서 pull request를 엽니다:\n\n```bash\ncd TemporalNet && git fetch origin refs/pr/13:pr/13\ngit checkout pr/13\n```\n\n3. 변환 스크립트에서 구성할 입력 인수는 여러 가지가 있지만 가장 중요한 인수는 다음과 같습니다:\n\n- `checkpoint_path`: 변환할 `.ckpt` 파일의 경로를 입력합니다.\n- `original_config_file`: 원래 아키텍처의 구성을 정의하는 YAML 파일입니다. 이 파일을 찾을 수 없는 경우 `.ckpt` 파일을 찾은 GitHub 리포지토리에서 YAML 파일을 검색해 보세요.\n- `dump_path`: 변환된 모델의 경로\n\n예를 들어, TemporalNet 모델은 Stable Diffusion v1.5 및 ControlNet 모델이기 때문에 ControlNet 리포지토리에서 `cldm_v15.yaml` 파일을 가져올 수 있습니다.\n\n4. 이제 스크립트를 실행하여 `.ckpt` 파일을 변환할 수 있습니다:\n\n```bash\npython ../diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path temporalnetv3.ckpt --original_config_file cldm_v15.yaml --dump_path ./ --controlnet\n```\n\n5. 변환이 완료되면 변환된 모델을 업로드하고 결과물인 [pull request](https://huggingface.co/CiaraRowles/TemporalNet/discussions/13)를 테스트하세요!\n\n```bash\ngit push origin pr/13:refs/pr/13\n```\n\n## **Keras .pb or .h5**\n\n🧪 이 기능은 실험적인 기능입니다. 현재로서는 Convert KerasCV Space에서 Stable Diffusion v1 체크포인트만 지원됩니다.\n\n[KerasCV](https://keras.io/keras_cv/)는 [Stable Diffusion](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion) v1 및 v2에 대한 학습을 지원합니다. 그러나 추론 및 배포를 위한 Stable Diffusion 모델 실험을 제한적으로 지원하는 반면, 🤗 Diffusers는 다양한 [noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers), [flash attention](https://huggingface.co/docs/diffusers/optimization/xformers), [other optimization techniques](https://huggingface.co/docs/diffusers/optimization/fp16) 등 이러한 목적을 위한 보다 완벽한 기능을 갖추고 있습니다.\n\n[Convert KerasCV](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers) Space는 `.pb` 또는 `.h5`를 PyTorch로 변환한 다음, 추론할 수 있도록 [`StableDiffusionPipeline`]으로 감싸서 준비합니다. 
변환된 체크포인트는 Hugging Face Hub의 리포지토리에 저장됩니다.\n\n예제로, textual-inversion으로 학습된 `[sayakpaul/textual-inversion-kerasio](https://huggingface.co/sayakpaul/textual-inversion-kerasio/tree/main)` 체크포인트를 변환해 보겠습니다. 이것은 특수 토큰  `<my-funny-cat>`을 사용하여 고양이로 이미지를 개인화합니다.\n\nKerasCV Space 변환에서는 다음을 입력할 수 있습니다:\n\n- Hugging Face 토큰.\n- UNet 과 텍스트 인코더(text encoder) 가중치를 다운로드하는 경로입니다. 모델을 어떻게 학습할지 방식에 따라, UNet과 텍스트 인코더의 경로를 모두 제공할 필요는 없습니다. 예를 들어, textual-inversion에는 텍스트 인코더의 임베딩만 필요하고 텍스트-이미지(text-to-image) 모델 변환에는 UNet 가중치만 필요합니다.\n- Placeholder 토큰은 textual-inversion 모델에만 적용됩니다.\n- `output_repo_prefix`는 변환된 모델이 저장되는 리포지토리의 이름입니다.\n\n**Submit** (제출) 버튼을 클릭하면 KerasCV 체크포인트가 자동으로 변환됩니다! 체크포인트가 성공적으로 변환되면, 변환된 체크포인트가 포함된 새 리포지토리로 연결되는 링크가 표시됩니다. 새 리포지토리로 연결되는 링크를 따라가면 변환된 모델을 사용해 볼 수 있는 추론 위젯이 포함된 모델 카드가 생성된 KerasCV Space 변환을 확인할 수 있습니다.\n\n코드를 사용하여 추론을 실행하려면 모델 카드의 오른쪽 상단 모서리에 있는 **Use in Diffusers**  버튼을 클릭하여 예시 코드를 복사하여 붙여넣습니다:\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\"sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline\")\n```\n\n그러면 다음과 같은 이미지를 생성할 수 있습니다:\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\"sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline\")\npipeline.to(\"cuda\")\n\nplaceholder_token = \"<my-funny-cat-token>\"\nprompt = f\"two {placeholder_token} getting married, photorealistic, high quality\"\nimage = pipeline(prompt, num_inference_steps=50).images[0]\n```\n\n## **A1111 LoRA files**\n\n[Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (A1111)은 Stable Diffusion을 위해 널리 사용되는 웹 UI로, [Civitai](https://civitai.com/) 와 같은 모델 공유 플랫폼을 지원합니다. 
특히 LoRA 기법으로 학습된 모델은 학습 속도가 빠르고 완전히 파인튜닝된 모델보다 파일 크기가 훨씬 작기 때문에 인기가 높습니다.\n\n🤗 Diffusers는 [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]를 사용하여 A1111 LoRA 체크포인트 불러오기를 지원합니다:\n\n```py\nfrom diffusers import DiffusionPipeline, UniPCMultistepScheduler\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"andite/anything-v4.0\", torch_dtype=torch.float16, safety_checker=None\n).to(\"cuda\")\npipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n```\n\nCivitai에서 LoRA 체크포인트를 다운로드하세요. 이 예제에서는 [Howls Moving Castle,Interior/Scenery LoRA (Ghibli Stlye)](https://civitai.com/models/14605?modelVersionId=19998) 체크포인트를 사용했지만, 어떤 LoRA 체크포인트든 자유롭게 사용해 보세요!\n\n```bash\n!wget https://civitai.com/api/download/models/19998 -O howls_moving_castle.safetensors\n```\n\n[`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] 메서드를 사용하여 파이프라인에 LoRA 체크포인트를 불러옵니다:\n\n```py\npipeline.load_lora_weights(\".\", weight_name=\"howls_moving_castle.safetensors\")\n```\n\n이제 파이프라인을 사용하여 이미지를 생성할 수 있습니다:\n\n```py\nprompt = \"masterpiece, illustration, ultra-detailed, cityscape, san francisco, golden gate bridge, california, bay area, in the snow, beautiful detailed starry sky\"\nnegative_prompt = \"lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture\"\n\nimages = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=512,\n    height=512,\n    num_inference_steps=25,\n    num_images_per_prompt=4,\n    generator=torch.manual_seed(0),\n).images\n```\n\n마지막으로, 이미지들을 그리드로 표시하는 헬퍼 함수를 만듭니다:\n\n```py\nfrom PIL import Image\n\n\ndef image_grid(imgs, rows=2, cols=2):\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\n\nimage_grid(images)\n```\n\n<div class=\"flex justify-center\">\n  <img 
src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/a1111-lora-sf.png\" />\n</div>\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/push_to_hub.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 파일들을 Hub로 푸시하기\n\n[[open-in-colab]]\n\n🤗 Diffusers는 모델, 스케줄러 또는 파이프라인을 Hub에 업로드할 수 있는 [`~diffusers.utils.PushToHubMixin`]을 제공합니다. 이는 Hub에 당신의 파일을 저장하는 쉬운 방법이며, 다른 사람들과 작업을 공유할 수도 있습니다. 실제적으로 [`~diffusers.utils.PushToHubMixin`]가 동작하는 방식은 다음과 같습니다:\n\n1. Hub에 리포지토리를 생성합니다.\n2. 나중에 다시 불러올 수 있도록 모델, 스케줄러 또는 파이프라인 파일을 저장합니다.\n3. 이러한 파일이 포함된 폴더를 Hub에 업로드합니다.\n\n이 가이드는 [`~diffusers.utils.PushToHubMixin`]을 사용하여 Hub에 파일을 업로드하는 방법을 보여줍니다.\n\n먼저 액세스 [토큰](https://huggingface.co/settings/tokens)으로 Hub 계정에 로그인해야 합니다:\n\n```py\nfrom huggingface_hub import notebook_login\n\nnotebook_login()\n```\n\n## 모델\n\n모델을 허브에 푸시하려면 [`~diffusers.utils.PushToHubMixin.push_to_hub`]를 호출하고 Hub에 저장할 모델의 리포지토리 id를 지정합니다:\n\n```py\nfrom diffusers import ControlNetModel\n\ncontrolnet = ControlNetModel(\n    block_out_channels=(32, 64),\n    layers_per_block=2,\n    in_channels=4,\n    down_block_types=(\"DownBlock2D\", \"CrossAttnDownBlock2D\"),\n    cross_attention_dim=32,\n    conditioning_embedding_out_channels=(16, 32),\n)\ncontrolnet.push_to_hub(\"my-controlnet-model\")\n```\n\n모델의 경우 Hub에 푸시할 가중치의 [*변형*](loading#checkpoint-variants)을 지정할 수도 있습니다. 
예를 들어, `fp16` 가중치를 푸시하려면 다음과 같이 하세요:\n\n```py\ncontrolnet.push_to_hub(\"my-controlnet-model\", variant=\"fp16\")\n```\n\n[`~diffusers.utils.PushToHubMixin.push_to_hub`] 함수는 모델의 `config.json` 파일을 저장하고, 가중치를 `safetensors` 형식으로 자동 저장합니다.\n\n이제 Hub의 리포지토리에서 모델을 다시 불러올 수 있습니다:\n\n```py\nmodel = ControlNetModel.from_pretrained(\"your-namespace/my-controlnet-model\")\n```\n\n## 스케줄러\n\n스케줄러를 Hub에 푸시하려면 [`~diffusers.utils.PushToHubMixin.push_to_hub`]를 호출하고 Hub에 저장할 스케줄러의 리포지토리 id를 지정합니다:\n\n```py\nfrom diffusers import DDIMScheduler\n\nscheduler = DDIMScheduler(\n    beta_start=0.00085,\n    beta_end=0.012,\n    beta_schedule=\"scaled_linear\",\n    clip_sample=False,\n    set_alpha_to_one=False,\n)\nscheduler.push_to_hub(\"my-controlnet-scheduler\")\n```\n\n[`~diffusers.utils.PushToHubMixin.push_to_hub`] 함수는 스케줄러의 `scheduler_config.json` 파일을 지정된 리포지토리에 저장합니다.\n\n이제 Hub의 리포지토리에서 스케줄러를 다시 불러올 수 있습니다:\n\n```py\nscheduler = DDIMScheduler.from_pretrained(\"your-namespace/my-controlnet-scheduler\")\n```\n\n## 파이프라인\n\n모든 컴포넌트가 포함된 전체 파이프라인을 Hub로 푸시할 수도 있습니다. 
예를 들어, 원하는 파라미터로 [`StableDiffusionPipeline`]의 컴포넌트들을 초기화합니다:\n\n```py\nfrom diffusers import (\n    UNet2DConditionModel,\n    AutoencoderKL,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\nfrom transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer\n\nunet = UNet2DConditionModel(\n    block_out_channels=(32, 64),\n    layers_per_block=2,\n    sample_size=32,\n    in_channels=4,\n    out_channels=4,\n    down_block_types=(\"DownBlock2D\", \"CrossAttnDownBlock2D\"),\n    up_block_types=(\"CrossAttnUpBlock2D\", \"UpBlock2D\"),\n    cross_attention_dim=32,\n)\n\nscheduler = DDIMScheduler(\n    beta_start=0.00085,\n    beta_end=0.012,\n    beta_schedule=\"scaled_linear\",\n    clip_sample=False,\n    set_alpha_to_one=False,\n)\n\nvae = AutoencoderKL(\n    block_out_channels=[32, 64],\n    in_channels=3,\n    out_channels=3,\n    down_block_types=[\"DownEncoderBlock2D\", \"DownEncoderBlock2D\"],\n    up_block_types=[\"UpDecoderBlock2D\", \"UpDecoderBlock2D\"],\n    latent_channels=4,\n)\n\ntext_encoder_config = CLIPTextConfig(\n    bos_token_id=0,\n    eos_token_id=2,\n    hidden_size=32,\n    intermediate_size=37,\n    layer_norm_eps=1e-05,\n    num_attention_heads=4,\n    num_hidden_layers=5,\n    pad_token_id=1,\n    vocab_size=1000,\n)\ntext_encoder = CLIPTextModel(text_encoder_config)\ntokenizer = CLIPTokenizer.from_pretrained(\"hf-internal-testing/tiny-random-clip\")\n```\n\n모든 컴포넌트들을 [`StableDiffusionPipeline`]에 전달하고 [`~diffusers.utils.PushToHubMixin.push_to_hub`]를 호출하여 파이프라인을 Hub로 푸시합니다:\n\n```py\ncomponents = {\n    \"unet\": unet,\n    \"scheduler\": scheduler,\n    \"vae\": vae,\n    \"text_encoder\": text_encoder,\n    \"tokenizer\": tokenizer,\n    \"safety_checker\": None,\n    \"feature_extractor\": None,\n}\n\npipeline = StableDiffusionPipeline(**components)\npipeline.push_to_hub(\"my-pipeline\")\n```\n\n[`~diffusers.utils.PushToHubMixin.push_to_hub`] 함수는 각 컴포넌트를 리포지토리의 하위 폴더에 저장합니다. 
이제 Hub의 리포지토리에서 파이프라인을 다시 불러올 수 있습니다:\n\n```py\npipeline = StableDiffusionPipeline.from_pretrained(\"your-namespace/my-pipeline\")\n```\n\n## 비공개\n\n모델, 스케줄러 또는 파이프라인 파일들을 비공개로 두려면 [`~diffusers.utils.PushToHubMixin.push_to_hub`] 함수에서 `private=True`를 설정하세요:\n\n```py\ncontrolnet.push_to_hub(\"my-controlnet-model-private\", private=True)\n```\n\n비공개 리포지토리는 본인만 볼 수 있으며 다른 사용자는 리포지토리를 복제할 수 없고 리포지토리가 검색 결과에 표시되지 않습니다. 사용자가 비공개 리포지토리의 URL을 가지고 있더라도 `404 - Sorry, we can't find the page you are looking for`라는 메시지가 표시됩니다. 비공개 리포지토리에서 모델을 로드하려면 [로그인](https://huggingface.co/docs/huggingface_hub/quick-start#login) 상태여야 합니다."
  },
  {
    "path": "docs/source/ko/using-diffusers/schedulers.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 스케줄러\n\ndiffusion 파이프라인은 diffusion 모델, 스케줄러 등의 컴포넌트들로 구성됩니다. 그리고 파이프라인 안의 일부 컴포넌트를 다른 컴포넌트로 교체하는 식의 커스터마이징 역시 가능합니다. 이와 같은 컴포넌트 커스터마이징의 가장 대표적인 예시가 바로 [스케줄러](../api/schedulers/overview.md)를 교체하는 것입니다.\n\n\n\n스케줄러는 다음과 같이 diffusion 시스템의 전반적인 디노이징 프로세스를 정의합니다.\n\n- 디노이징 스텝을 얼마나 가져가야 할까?\n- 확률적으로(stochastic) 혹은 확정적으로(deterministic)?\n- 디노이징 된 샘플을 찾아내기 위해 어떤 알고리즘을 사용해야 할까?\n\n이러한 프로세스는 다소 난해하고, 디노이징 속도와 디노이징 퀄리티 사이의 트레이드 오프를 정의해야 하는 문제가 될 수 있습니다. 주어진 파이프라인에 어떤 스케줄러가 가장 적합한지를 정량적으로 판단하는 것은 매우 어려운 일입니다. 이로 인해 일단 해당 스케줄러를 직접 사용하여, 생성되는 이미지를 직접 눈으로 보며, 정성적으로 성능을 판단해보는 것이 추천되곤 합니다.\n\n\n\n\n\n## 파이프라인 불러오기\n\n먼저 Stable Diffusion 파이프라인을 불러오도록 해보겠습니다. 
물론 스테이블 diffusion을 사용하기 위해서는, 허깅페이스 허브에 등록된 사용자여야 하며, 관련 [라이센스](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)에 동의해야 한다는 점을 잊지 말아주세요.\n\n*역자 주: 다만, 현재 신규로 생성한 허깅페이스 계정에 대해서는 라이센스 동의를 요구하지 않는 것으로 보입니다!*\n\n```python\nfrom huggingface_hub import login\nfrom diffusers import DiffusionPipeline\nimport torch\n\n# first we need to login with our access token\nlogin()\n\n# Now we can download the pipeline\npipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\n```\n\n다음으로, GPU로 이동합니다.\n\n```python\npipeline.to(\"cuda\")\n```\n\n\n\n\n\n## 스케줄러 액세스\n\n스케줄러는 언제나 파이프라인의 컴포넌트로서 존재하며, 일반적으로 파이프라인 인스턴스 내에 `scheduler`라는 이름의 속성(property)으로 정의되어 있습니다.\n\n```python\npipeline.scheduler\n```\n\n**Output**:\n\n```\nPNDMScheduler {\n  \"_class_name\": \"PNDMScheduler\",\n  \"_diffusers_version\": \"0.8.0.dev0\",\n  \"beta_end\": 0.012,\n  \"beta_schedule\": \"scaled_linear\",\n  \"beta_start\": 0.00085,\n  \"clip_sample\": false,\n  \"num_train_timesteps\": 1000,\n  \"set_alpha_to_one\": false,\n  \"skip_prk_steps\": true,\n  \"steps_offset\": 1,\n  \"trained_betas\": null\n}\n```\n\n출력 결과를 통해, 우리는 해당 스케줄러가 [`PNDMScheduler`]의 인스턴스라는 것을 알 수 있습니다. 이제 [`PNDMScheduler`]와 다른 스케줄러들의 성능을 비교해보도록 하겠습니다. 먼저 테스트에 사용할 프롬프트를 다음과 같이 정의해보도록 하겠습니다.\n\n```python\nprompt = \"A photograph of an astronaut riding a horse on Mars, high resolution, high definition.\"\n```\n\n다음으로 유사한 이미지 생성을 보장하기 위해서, 다음과 같이 랜덤시드를 고정해주도록 하겠습니다.\n\n```python\ngenerator = torch.Generator(device=\"cuda\").manual_seed(8)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_pndm.png\" width=\"400\"/>\n    <br>\n</p>\n\n\n\n\n## 스케줄러 교체하기\n\n다음으로 파이프라인의 스케줄러를 다른 스케줄러로 교체하는 방법에 대해 알아보겠습니다. 모든 스케줄러는 [`SchedulerMixin.compatibles`]라는 속성(property)을 갖고 있습니다. 
해당 속성은 **호환 가능한** 스케줄러들에 대한 정보를 담고 있습니다.\n\n```python\npipeline.scheduler.compatibles\n```\n\n**Output**:\n\n```\n[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,\n diffusers.schedulers.scheduling_ddim.DDIMScheduler,\n diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,\n diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,\n diffusers.schedulers.scheduling_pndm.PNDMScheduler,\n diffusers.schedulers.scheduling_ddpm.DDPMScheduler,\n diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]\n```\n\n호환되는 스케줄러들을 살펴보면 아래와 같습니다.\n\n- [`LMSDiscreteScheduler`],\n- [`DDIMScheduler`],\n- [`DPMSolverMultistepScheduler`],\n- [`EulerDiscreteScheduler`],\n- [`PNDMScheduler`],\n- [`DDPMScheduler`],\n- [`EulerAncestralDiscreteScheduler`].\n\n앞서 정의했던 프롬프트를 사용해서 각각의 스케줄러들을 비교해보도록 하겠습니다.\n\n먼저 파이프라인 안의 스케줄러를 바꾸기 위해 [`ConfigMixin.config`] 속성과 [`ConfigMixin.from_config`] 메서드를 활용해보려고 합니다.\n\n\n\n```python\npipeline.scheduler.config\n```\n\n**Output**:\n\n```\nFrozenDict([('num_train_timesteps', 1000),\n            ('beta_start', 0.00085),\n            ('beta_end', 0.012),\n            ('beta_schedule', 'scaled_linear'),\n            ('trained_betas', None),\n            ('skip_prk_steps', True),\n            ('set_alpha_to_one', False),\n            ('steps_offset', 1),\n            ('_class_name', 'PNDMScheduler'),\n            ('_diffusers_version', '0.8.0.dev0'),\n            ('clip_sample', False)])\n```\n\n기존 스케줄러의 config를 호환 가능한 다른 스케줄러에 이식하는 것 역시 가능합니다.\n\n다음 예시는 기존 스케줄러(`pipeline.scheduler`)를 다른 종류의 스케줄러(`DDIMScheduler`)로 바꾸는 코드입니다. 
기존 스케줄러가 갖고 있던 config를 `.from_config` 메서드의 인자로 전달하는 것을 확인할 수 있습니다.\n\n```python\nfrom diffusers import DDIMScheduler\n\npipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)\n```\n\n\n\n이제 파이프라인을 실행해서 두 스케줄러가 생성한 이미지의 퀄리티를 비교해봅시다.\n\n```python\ngenerator = torch.Generator(device=\"cuda\").manual_seed(8)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_ddim.png\" width=\"400\"/>\n    <br>\n</p>\n\n\n\n\n## 스케줄러들 비교해보기\n\n지금까지는 [`PNDMScheduler`]와 [`DDIMScheduler`] 스케줄러를 실행해보았습니다. 아직 비교해볼 스케줄러들이 더 많이 남아있으니 계속 비교해보도록 하겠습니다.\n\n\n\n[`LMSDiscreteScheduler`]는 일반적으로 더 좋은 결과를 보여줍니다.\n\n```python\nfrom diffusers import LMSDiscreteScheduler\n\npipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(8)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png\" width=\"400\"/>\n    <br>\n</p>\n\n\n[`EulerDiscreteScheduler`]와 [`EulerAncestralDiscreteScheduler`]는 고작 30번의 inference step만으로도 높은 퀄리티의 이미지를 생성하는 것을 알 수 있습니다.\n\n```python\nfrom diffusers import EulerDiscreteScheduler\n\npipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(8)\nimage = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]\nimage\n```\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png\" width=\"400\"/>\n    <br>\n</p>\n\n\n```python\nfrom diffusers import EulerAncestralDiscreteScheduler\n\npipeline.scheduler = 
EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(8)\nimage = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]\nimage\n```\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png\" width=\"400\"/>\n    <br>\n</p>\n\n\n지금 이 문서를 작성하는 현시점 기준에선, [`DPMSolverMultistepScheduler`]가 시간 대비 가장 좋은 품질의 이미지를 생성하는 것 같습니다. 20번 정도의 스텝만으로도 실행될 수 있습니다.\n\n\n\n```python\nfrom diffusers import DPMSolverMultistepScheduler\n\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(8)\nimage = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]\nimage\n```\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png\" width=\"400\"/>\n    <br>\n</p>\n\n\n보시다시피 생성된 이미지들은 매우 비슷하고, 비슷한 퀄리티를 보이는 것 같습니다. 실제로 어떤 스케줄러를 선택할 것인가는 종종 특정 이용 사례에 기반해서 결정되곤 합니다. 결국 여러 종류의 스케줄러를 직접 실행시켜보고 눈으로 직접 비교해서 판단하는 게 좋은 선택일 것 같습니다.\n\n\n\n## Flax에서 스케줄러 교체하기\n\nJAX/Flax 사용자인 경우 기본 파이프라인 스케줄러를 변경할 수도 있습니다. 
다음은 Flax Stable Diffusion 파이프라인과 초고속 [DPM-Solver++ 스케줄러](../api/schedulers/multistep_dpm_solver)를 사용하여 추론을 실행하는 방법에 대한 예시입니다.\n\n```python\nimport jax\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom flax.training.common_utils import shard\n\nfrom diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nscheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(\n    model_id,\n    subfolder=\"scheduler\"\n)\npipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\n    model_id,\n    scheduler=scheduler,\n    variant=\"bf16\",\n    dtype=jax.numpy.bfloat16,\n)\nparams[\"scheduler\"] = scheduler_state\n\n# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)\nprompt = \"a photo of an astronaut riding a horse on mars\"\nnum_samples = jax.device_count()\nprompt_ids = pipeline.prepare_inputs([prompt] * num_samples)\n\nprng_seed = jax.random.PRNGKey(0)\nnum_inference_steps = 25\n\n# shard inputs and rng\nparams = replicate(params)\nprng_seed = jax.random.split(prng_seed, jax.device_count())\nprompt_ids = shard(prompt_ids)\n\nimages = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images\nimages = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))\n```\n\n> [!WARNING]\n> 다음 Flax 스케줄러는 *아직* Flax Stable Diffusion 파이프라인과 호환되지 않습니다.\n>\n> - `FlaxLMSDiscreteScheduler`\n> - `FlaxDDPMScheduler`\n\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/sdxl_turbo.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Diffusion XL Turbo\n\n[[open-in-colab]]\n\nSDXL Turbo는 adversarial time-distilled(적대적 시간 전이) [Stable Diffusion XL](https://huggingface.co/papers/2307.01952)(SDXL) 모델로, 단 한 번의 스텝만으로 추론을 실행할 수 있습니다.\n\n이 가이드에서는 text-to-image와 image-to-image를 위한 SDXL-Turbo를 사용하는 방법을 설명합니다.\n\n시작하기 전에 다음 라이브러리가 설치되어 있는지 확인하세요:\n\n```py\n# Colab에서 필요한 라이브러리를 설치하기 위해 주석을 제외하세요\n#!pip install -q diffusers transformers accelerate\n```\n\n## 모델 체크포인트 불러오기\n\n모델 가중치는 Hub의 별도 하위 폴더 또는 로컬에 저장할 수 있으며, 이 경우 [`~StableDiffusionXLPipeline.from_pretrained`] 메서드를 사용해야 합니다:\n\n```py\nfrom diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stabilityai/sdxl-turbo\", torch_dtype=torch.float16, variant=\"fp16\")\npipeline = pipeline.to(\"cuda\")\n```\n\n또한 [`~StableDiffusionXLPipeline.from_single_file`] 메서드를 사용하여 허브 또는 로컬에서 단일 파일 형식(`.ckpt` 또는 `.safetensors`)으로 저장된 모델 체크포인트를 불러올 수도 있습니다:\n\n```py\nfrom diffusers import StableDiffusionXLPipeline\nimport torch\n\npipeline = StableDiffusionXLPipeline.from_single_file(\n    \"https://huggingface.co/stabilityai/sdxl-turbo/blob/main/sd_xl_turbo_1.0_fp16.safetensors\", torch_dtype=torch.float16)\npipeline = pipeline.to(\"cuda\")\n```\n\n## Text-to-image\n\nText-to-image의 경우 텍스트 프롬프트를 전달합니다. 기본적으로 SDXL Turbo는 512x512 이미지를 생성하며, 이 해상도에서 최상의 결과를 제공합니다. 
`height` 및 `width` 매개 변수를 768x768 또는 1024x1024로 설정할 수 있지만 이 경우 품질 저하를 예상할 수 있습니다.\n\n모델이 `guidance_scale` 없이 학습되었으므로 이를 0.0으로 설정해 비활성화해야 합니다. 단일 추론 스텝만으로도 고품질 이미지를 생성할 수 있습니다.\n스텝 수를 2, 3 또는 4로 늘리면 이미지 품질이 향상됩니다.\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline_text2image = AutoPipelineForText2Image.from_pretrained(\"stabilityai/sdxl-turbo\", torch_dtype=torch.float16, variant=\"fp16\")\npipeline_text2image = pipeline_text2image.to(\"cuda\")\n\nprompt = \"A cinematic shot of a baby racoon wearing an intricate italian priest robe.\"\n\nimage = pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sdxl-turbo-text2img.png\" alt=\"generated image of a racoon in a robe\"/>\n</div>\n\n## Image-to-image\n\nImage-to-image 생성의 경우 `num_inference_steps * strength`가 1보다 크거나 같은지 확인하세요.\nImage-to-image 파이프라인은 아래 예제에서 `0.5 * 2.0 = 1` 스텝과 같이 `int(num_inference_steps * strength)` 스텝으로 실행됩니다.\n\n```py\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import load_image, make_image_grid\n\n# 체크포인트를 불러올 때 추가 메모리 소모를 피하려면 from_pipe를 사용하세요.\npipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to(\"cuda\")\n\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\")\ninit_image = init_image.resize((512, 512))\n\nprompt = \"cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k\"\n\nimage = pipeline(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2).images[0]\nmake_image_grid([init_image, image], rows=1, cols=2)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sdxl-turbo-img2img.png\" alt=\"Image-to-image generation 
sample using SDXL Turbo\"/>\n</div>\n\n## SDXL Turbo 속도 훨씬 더 빠르게 하기\n\n- PyTorch 버전 2 이상을 사용하는 경우 UNet을 컴파일합니다. 첫 번째 추론 실행은 매우 느리지만 이후 실행은 훨씬 빨라집니다.\n\n```py\npipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n- 기본 VAE를 사용하는 경우, 각 생성 전후에 비용이 많이 드는 `dtype` 변환을 피하기 위해 `float32`로 유지하세요. 이 작업은 첫 생성 이전에 한 번만 수행하면 됩니다:\n\n```py\npipe.upcast_vae()\n```\n\n또는, 커뮤니티 회원인 [`@madebyollin`](https://huggingface.co/madebyollin)이 만든 [16비트 VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)를 사용할 수도 있으며, 이는 `float32`로 업캐스트할 필요가 없습니다."
  },
  {
    "path": "docs/source/ko/using-diffusers/shap-e.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Shap-E\n\n[[open-in-colab]]\n\nShap-E는 비디오 게임 개발, 인테리어 디자인, 건축에 사용할 수 있는 3D 에셋을 생성하기 위한 conditional 모델입니다. 대규모 3D 에셋 데이터셋을 학습되었고, 각 오브젝트의 더 많은 뷰를 렌더링하고 4K point cloud 대신 16K를 생성하도록 후처리합니다. Shap-E 모델은 두 단계로 학습됩니다:\n\n1. 인코더가 3D 에셋의 포인트 클라우드와 렌더링된 뷰를 받아들이고 에셋을 나타내는 implicit functions의 파라미터를 출력합니다.\n2. 인코더가 생성한 latents를 바탕으로 diffusion 모델을 훈련하여 neural radiance fields(NeRF) 또는 textured 3D 메시를 생성하여 다운스트림 애플리케이션에서 3D 에셋을 더 쉽게 렌더링하고 사용할 수 있도록 합니다.\n\n이 가이드에서는 Shap-E를 사용하여 나만의 3D 에셋을 생성하는 방법을 보입니다!\n\n시작하기 전에 다음 라이브러리가 설치되어 있는지 확인하세요:\n\n```py\n# Colab에서 필요한 라이브러리를 설치하기 위해 주석을 제외하세요\n#!pip install -q diffusers transformers accelerate trimesh\n```\n\n## Text-to-3D\n\n3D 객체의 gif를 생성하려면 텍스트 프롬프트를 [`ShapEPipeline`]에 전달합니다. 
파이프라인은 3D 객체를 생성하는 데 사용되는 이미지 프레임 리스트를 생성합니다.\n\n```py\nimport torch\nfrom diffusers import ShapEPipeline\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\npipe = ShapEPipeline.from_pretrained(\"openai/shap-e\", torch_dtype=torch.float16, variant=\"fp16\")\npipe = pipe.to(device)\n\nguidance_scale = 15.0\nprompt = [\"A firecracker\", \"A birthday cupcake\"]\n\nimages = pipe(\n    prompt,\n    guidance_scale=guidance_scale,\n    num_inference_steps=64,\n    frame_size=256,\n).images\n```\n\n이제 [`~utils.export_to_gif`] 함수를 사용하여 이미지 프레임 리스트를 3D 객체의 gif로 변환합니다.\n\n```py\nfrom diffusers.utils import export_to_gif\n\nexport_to_gif(images[0], \"firecracker_3d.gif\")\nexport_to_gif(images[1], \"cake_3d.gif\")\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/firecracker_out.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">prompt = \"A firecracker\"</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/cake_out.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">prompt = \"A birthday cupcake\"</figcaption>\n  </div>\n</div>\n\n## Image-to-3D\n\n다른 이미지로부터 3D 개체를 생성하려면 [`ShapEImg2ImgPipeline`]을 사용합니다. 기존 이미지를 사용하거나 완전히 새로운 이미지를 생성할 수 있습니다. 
[Kandinsky 2.1](../api/pipelines/kandinsky) 모델을 사용하여 새 이미지를 생성해 보겠습니다.\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nprior_pipeline = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1-prior\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\npipeline = DiffusionPipeline.from_pretrained(\"kandinsky-community/kandinsky-2-1\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n\nprompt = \"A cheeseburger, white background\"\n\nimage_embeds, negative_image_embeds = prior_pipeline(prompt, guidance_scale=1.0).to_tuple()\nimage = pipeline(\n    prompt,\n    image_embeds=image_embeds,\n    negative_image_embeds=negative_image_embeds,\n).images[0]\n\nimage.save(\"burger.png\")\n```\n\n치즈버거를 [`ShapEImg2ImgPipeline`]에 전달하여 3D representation을 생성합니다.\n\n```py\nfrom PIL import Image\nfrom diffusers import ShapEImg2ImgPipeline\nfrom diffusers.utils import export_to_gif\n\npipe = ShapEImg2ImgPipeline.from_pretrained(\"openai/shap-e-img2img\", torch_dtype=torch.float16, variant=\"fp16\").to(\"cuda\")\n\nguidance_scale = 3.0\nimage = Image.open(\"burger.png\").resize((256, 256))\n\nimages = pipe(\n    image,\n    guidance_scale=guidance_scale,\n    num_inference_steps=64,\n    frame_size=256,\n).images\n\ngif_path = export_to_gif(images[0], \"burger_3d.gif\")\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_in.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">cheeseburger</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_out.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">3D cheeseburger</figcaption>\n  </div>\n</div>\n\n## 메시 생성하기\n\nShap-E는 다운스트림 애플리케이션에 렌더링할 textured 메시 출력을 생성할 수도 있는 유연한 모델입니다. 
이 예제에서는 🤗 Datasets 라이브러리에서 [Dataset viewer](https://huggingface.co/docs/hub/datasets-viewer#dataset-preview)를 사용해 메시 시각화를 지원하는 `glb` 파일로 변환합니다.\n\n`output_type` 매개변수를 `\"mesh\"`로 지정함으로써 [`ShapEPipeline`]과 [`ShapEImg2ImgPipeline`] 모두에 대한 메시 출력을 생성할 수 있습니다:\n\n```py\nimport torch\nfrom diffusers import ShapEPipeline\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\npipe = ShapEPipeline.from_pretrained(\"openai/shap-e\", torch_dtype=torch.float16, variant=\"fp16\")\npipe = pipe.to(device)\n\nguidance_scale = 15.0\nprompt = \"A birthday cupcake\"\n\nimages = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, frame_size=256, output_type=\"mesh\").images\n```\n\n메시 출력을 `ply` 파일로 저장하려면 [`~utils.export_to_ply`] 함수를 사용합니다:\n\n> [!TIP]\n> 선택적으로 [`~utils.export_to_obj`] 함수를 사용하여 메시 출력을 `obj` 파일로 저장할 수 있습니다. 다양한 형식으로 메시 출력을 저장할 수 있어 다운스트림에서 더욱 유연하게 사용할 수 있습니다!\n\n```py\nfrom diffusers.utils import export_to_ply\n\nply_path = export_to_ply(images[0], \"3d_cake.ply\")\nprint(f\"Saved to folder: {ply_path}\")\n```\n\n그 다음 trimesh 라이브러리를 사용하여 `ply` 파일을 `glb` 파일로 변환할 수 있습니다:\n\n```py\nimport trimesh\n\nmesh = trimesh.load(\"3d_cake.ply\")\nmesh_export = mesh.export(\"3d_cake.glb\", file_type=\"glb\")\n```\n\n기본적으로 메시 출력은 아래쪽 시점에 초점이 맞춰져 있지만 회전 변환을 적용하여 기본 시점을 변경할 수 있습니다:\n\n```py\nimport trimesh\nimport numpy as np\n\nmesh = trimesh.load(\"3d_cake.ply\")\nrot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])\nmesh = mesh.apply_transform(rot)\nmesh_export = mesh.export(\"3d_cake.glb\", file_type=\"glb\")\n```\n\n메시 파일을 데이터셋 레포지토리에 업로드해 Dataset viewer로 시각화하세요!\n\n<div class=\"flex justify-center\">\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/3D-cake.gif\"/>\n</div>\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/stable_diffusion_jax_how_to.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# JAX / Flax에서의 🧨 Stable Diffusion!\n\n[[open-in-colab]]\n\n🤗 Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) 는 버전 0.5.1부터 Flax를 지원합니다! 이를 통해 Colab, Kaggle, Google Cloud Platform에서 사용할 수 있는 것처럼 Google TPU에서 초고속 추론이 가능합니다.\n\n이 노트북은 JAX / Flax를 사용해 추론을 실행하는 방법을 보여줍니다. Stable Diffusion의 작동 방식에 대한 자세한 내용을 원하거나 GPU에서 실행하려면 이 [노트북] ](https://huggingface.co/docs/diffusers/stable_diffusion)을 참조하세요.\n\n먼저, TPU 백엔드를 사용하고 있는지 확인합니다. 
Colab에서 이 노트북을 실행하는 경우, 메뉴에서 런타임을 선택한 다음 \"런타임 유형 변경\" 옵션을 선택한 다음 하드웨어 가속기 설정에서 TPU를 선택합니다.\n\nJAX는 TPU 전용은 아니지만 각 TPU 서버에는 8개의 TPU 가속기가 병렬로 작동하기 때문에 해당 하드웨어에서 더 빛을 발한다는 점은 알아두세요.\n\n\n## Setup\n\n먼저 diffusers가 설치되어 있는지 확인합니다.\n\n```bash\n!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy\n!pip install diffusers\n```\n\n```python\nimport jax.tools.colab_tpu\n\njax.tools.colab_tpu.setup_tpu()\nimport jax\n```\n\n```python\nnum_devices = jax.device_count()\ndevice_type = jax.devices()[0].device_kind\n\nprint(f\"Found {num_devices} JAX devices of type {device_type}.\")\nassert (\n    \"TPU\" in device_type\n), \"Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator\"\n```\n\n```python out\nFound 8 JAX devices of type Cloud TPU.\n```\n\n그런 다음 모든 dependencies를 가져옵니다.\n\n```python\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\n\nfrom pathlib import Path\nfrom jax import pmap\nfrom flax.jax_utils import replicate\nfrom flax.training.common_utils import shard\nfrom PIL import Image\n\nfrom huggingface_hub import notebook_login\nfrom diffusers import FlaxStableDiffusionPipeline\n```\n\n## 모델 불러오기\n\nTPU 장치는 효율적인 half-float 유형인 bfloat16을 지원합니다. 테스트에는 이 유형을 사용하지만 대신 float32를 사용하여 전체 정밀도(full precision)를 사용할 수도 있습니다.\n\n```python\ndtype = jnp.bfloat16\n```\n\nFlax는 함수형 프레임워크이므로 모델은 무상태(stateless)형이며 매개변수는 모델 외부에 저장됩니다. 사전학습된 Flax 파이프라인을 불러오면 파이프라인 자체와 모델 가중치(또는 매개변수)가 모두 반환됩니다. 저희는 bf16 버전의 가중치를 사용하고 있으므로 유형 경고가 표시되지만 무시해도 됩니다.\n\n```python\npipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    variant=\"bf16\",\n    dtype=dtype,\n)\n```\n\n## 추론\n\nTPU에는 일반적으로 8개의 디바이스가 병렬로 작동하므로 보유한 디바이스 수만큼 프롬프트를 복제합니다. 그런 다음 각각 하나의 이미지 생성을 담당하는 8개의 디바이스에서 한 번에 추론을 수행합니다. 따라서 하나의 칩이 하나의 이미지를 생성하는 데 걸리는 시간과 동일한 시간에 8개의 이미지를 얻을 수 있습니다.\n\n프롬프트를 복제하고 나면 파이프라인의 `prepare_inputs` 함수를 호출하여 토큰화된 텍스트 ID를 얻습니다. 
토큰화된 텍스트의 길이는 기본 CLIP 텍스트 모델의 구성에 따라 77토큰으로 설정됩니다.\n\n```python\nprompt = \"A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic\"\nprompt = [prompt] * jax.device_count()\nprompt_ids = pipeline.prepare_inputs(prompt)\nprompt_ids.shape\n```\n\n```python out\n(8, 77)\n```\n\n### 복제(Replication) 및 병렬화\n\n모델 매개변수와 입력값은 우리가 보유한 8개의 병렬 장치에 복제(Replication)되어야 합니다. 매개변수 딕셔너리는 `flax.jax_utils.replicate`(딕셔너리를 순회하며 가중치의 모양을 변경하여 8번 반복하는 함수)를 사용하여 복제됩니다. 배열은 `shard`를 사용하여 디바이스별로 분할됩니다.\n\n```python\np_params = replicate(params)\n```\n\n```python\nprompt_ids = shard(prompt_ids)\nprompt_ids.shape\n```\n\n```python out\n(8, 1, 77)\n```\n\n이 shape은 8개의 디바이스 각각이 shape `(1, 77)`의 jnp 배열을 입력값으로 받는다는 의미입니다. 즉 1은 디바이스당 batch(배치) 크기입니다. 메모리가 충분한 TPU에서는 한 번에 여러 이미지(칩당)를 생성하려는 경우 1보다 클 수 있습니다.\n\n이미지를 생성할 준비가 거의 완료되었습니다! 이제 생성 함수에 전달할 난수 생성기만 만들면 됩니다. 이것은 난수를 다루는 모든 함수에 난수 생성기가 있어야 한다는, 난수에 대해 매우 진지하고 독단적인 Flax의 표준 절차입니다. 이렇게 하면 여러 분산된 기기에서 훈련할 때에도 재현성이 보장됩니다.\n\n아래 헬퍼 함수는 시드를 사용하여 난수 생성기를 초기화합니다. 동일한 시드를 사용하는 한 정확히 동일한 결과를 얻을 수 있습니다. 나중에 노트북에서 결과를 탐색할 때엔 다른 시드를 자유롭게 사용하세요.\n\n```python\ndef create_key(seed=0):\n    return jax.random.PRNGKey(seed)\n```\n\nrng를 얻은 다음 8번 '분할'하여 각 디바이스가 다른 제너레이터를 수신하도록 합니다. 따라서 각 디바이스마다 다른 이미지가 생성되며 전체 프로세스를 재현할 수 있습니다.\n\n```python\nrng = create_key(0)\nrng = jax.random.split(rng, jax.device_count())\n```\n\nJAX 코드는 매우 빠르게 실행되는 효율적인 표현으로 컴파일할 수 있습니다. 하지만 후속 호출에서 모든 입력이 동일한 모양을 갖도록 해야 하며, 그렇지 않으면 JAX가 코드를 다시 컴파일해야 하므로 최적화된 속도를 활용할 수 없습니다.\n\n`jit = True`를 인수로 전달하면 Flax 파이프라인이 코드를 컴파일할 수 있습니다. 또한 모델이 사용 가능한 8개의 디바이스에서 병렬로 실행되도록 보장합니다.\n\n다음 셀을 처음 실행하면 컴파일하는 데 시간이 오래 걸리지만 이후 호출(입력이 다른 경우에도)은 훨씬 빨라집니다. 
예를 들어, 테스트했을 때 TPU v2-8에서 컴파일하는 데 1분 이상 걸리지만 이후 추론 실행에는 약 7초가 걸립니다.\n\n```\n%%time\nimages = pipeline(prompt_ids, p_params, rng, jit=True)[0]\n```\n\n```python out\nCPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s\nWall time: 1min 29s\n```\n\n반환된 배열의 shape은 `(8, 1, 512, 512, 3)`입니다. 이를 재구성하여 두 번째 차원을 제거하고 512 × 512 × 3의 이미지 8개를 얻은 다음 PIL로 변환합니다.\n\n```python\nimages = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])\nimages = pipeline.numpy_to_pil(images)\n```\n\n### 시각화\n\n이미지를 그리드에 표시하는 도우미 함수를 만들어 보겠습니다.\n\n```python\ndef image_grid(imgs, rows, cols):\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n```\n\n```python\nimage_grid(images, 2, 4)\n```\n\n![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)\n\n\n## 다른 프롬프트 사용\n\n모든 디바이스에서 동일한 프롬프트를 복제할 필요는 없습니다. 프롬프트 2개를 각각 4번씩 생성하거나 한 번에 8개의 서로 다른 프롬프트를 생성하는 등 원하는 것은 무엇이든 할 수 있습니다. 
한번 해보세요!\n\n먼저 입력 준비 코드를 편리한 함수로 리팩터링하겠습니다:\n\n```python\nprompts = [\n    \"Labrador in the style of Hokusai\",\n    \"Painting of a squirrel skating in New York\",\n    \"HAL-9000 in the style of Van Gogh\",\n    \"Times Square under water, with fish and a dolphin swimming around\",\n    \"Ancient Roman fresco showing a man working on his laptop\",\n    \"Close-up photograph of young black woman against urban background, high quality, bokeh\",\n    \"Armchair in the shape of an avocado\",\n    \"Clown astronaut in space, with Earth in the background\",\n]\n```\n\n```python\nprompt_ids = pipeline.prepare_inputs(prompts)\nprompt_ids = shard(prompt_ids)\n\nimages = pipeline(prompt_ids, p_params, rng, jit=True).images\nimages = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])\nimages = pipeline.numpy_to_pil(images)\n\nimage_grid(images, 2, 4)\n```\n\n![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)\n\n\n## 병렬화(parallelization)는 어떻게 작동하는가?\n\n앞서 `diffusers` Flax 파이프라인이 모델을 자동으로 컴파일하고 사용 가능한 모든 기기에서 병렬로 실행한다고 말씀드렸습니다. 이제 그 프로세스를 간략하게 살펴보고 작동 방식을 보여드리겠습니다.\n\nJAX 병렬화는 여러 가지 방법으로 수행할 수 있습니다. 가장 쉬운 방법은 jax.pmap 함수를 사용하여 단일 프로그램, 다중 데이터(SPMD) 병렬화를 달성하는 것입니다. 즉, 동일한 코드의 복사본을 각각 다른 데이터 입력에 대해 여러 개 실행하는 것입니다. 더 정교한 접근 방식도 가능하므로 관심이 있으시다면 [JAX 문서](https://jax.readthedocs.io/en/latest/index.html)와 [`pjit` 페이지](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html?highlight=pjit)에서 이 주제를 살펴보시기 바랍니다!\n\n`jax.pmap`은 두 가지 기능을 수행합니다:\n\n- `jax.jit()`를 호출한 것처럼 코드를 컴파일(또는 `jit`)합니다. 이 작업은 `pmap`을 호출할 때가 아니라 pmapped 함수가 처음 호출될 때 수행됩니다.\n- 컴파일된 코드가 사용 가능한 모든 기기에서 병렬로 실행되도록 합니다.\n\n작동 방식을 보여드리기 위해 이미지 생성을 실행하는 비공개 메서드인 파이프라인의 `_generate` 메서드를 `pmap`합니다. 
이 메서드는 향후 `Diffusers` 릴리스에서 이름이 변경되거나 제거될 수 있다는 점에 유의하세요.\n\n```python\np_generate = pmap(pipeline._generate)\n```\n\n`pmap`을 사용한 후 준비된 함수 `p_generate`는 개념적으로 다음을 수행합니다:\n* 각 장치에서 기본 함수 `pipeline._generate`의 복사본을 호출합니다.\n* 각 장치에 입력 인수의 다른 부분을 보냅니다. 이것이 바로 샤딩이 사용되는 이유입니다. 이 경우 `prompt_ids`의 shape은 `(8, 1, 77)`입니다. 이 배열은 8개로 분할되고 `_generate`의 각 복사본은 `(1, 77)`의 shape을 가진 입력을 받게 됩니다.\n\n병렬로 호출된다는 사실을 완전히 무시하고 `_generate`를 코딩할 수 있습니다. batch(배치) 크기(이 예제에서는 `1`)와 코드에 적합한 차원만 신경 쓰면 되며, 병렬로 작동하기 위해 아무것도 변경할 필요가 없습니다.\n\n파이프라인 호출을 사용할 때와 마찬가지로, 다음 셀을 처음 실행할 때는 시간이 걸리지만 그 이후에는 훨씬 빨라집니다.\n\n```\n%%time\nimages = p_generate(prompt_ids, p_params, rng)\nimages = images.block_until_ready()\nimages.shape\n```\n\n```python out\nCPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s\nWall time: 1min 15s\n```\n\n```python\nimages.shape\n```\n\n```python out\n(8, 1, 512, 512, 3)\n```\n\nJAX는 비동기 디스패치를 사용하고 가능한 한 빨리 제어권을 Python 루프에 반환하기 때문에 추론 시간을 정확하게 측정하기 위해 `block_until_ready()`를 사용합니다. 아직 구체화되지 않은 계산 결과를 사용하려는 경우 자동으로 차단이 수행되므로 코드에서 이 함수를 사용할 필요가 없습니다."
  },
  {
    "path": "docs/source/ko/using-diffusers/svd.md",
    "content": "<!--Copyright 2023 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Stable Video Diffusion\n\n[[open-in-colab]]\n\n[Stable Video Diffusion (SVD)](https://huggingface.co/papers/2311.15127)은 입력 이미지에 맞춰 2~4초 분량의 고해상도(576x1024) 비디오를 생성할 수 있는 강력한 image-to-video 생성 모델입니다.\n\n이 가이드에서는 SVD를 사용하여 이미지에서 짧은 동영상을 생성하는 방법을 설명합니다.\n\n시작하기 전에 다음 라이브러리가 설치되어 있는지 확인하세요:\n\n```py\n!pip install -q -U diffusers transformers accelerate\n```\n\n이 모델에는 [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)와 [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) 두 가지 종류가 있습니다. 
SVD 체크포인트는 14개의 프레임을 생성하도록 학습되었고, SVD-XT 체크포인트는 25개의 프레임을 생성하도록 파인튜닝되었습니다.\n\n이 가이드에서는 SVD-XT 체크포인트를 사용합니다.\n\n```python\nimport torch\n\nfrom diffusers import StableVideoDiffusionPipeline\nfrom diffusers.utils import load_image, export_to_video\n\npipe = StableVideoDiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-video-diffusion-img2vid-xt\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipe.enable_model_cpu_offload()\n\n# Conditioning 이미지 불러오기\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png\")\nimage = image.resize((1024, 576))\n\ngenerator = torch.manual_seed(42)\nframes = pipe(image, decode_chunk_size=8, generator=generator).frames[0]\n\nexport_to_video(frames, \"generated.mp4\", fps=7)\n```\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">\"source image of a rocket\"</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/output_rocket.gif\"/>\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">\"generated video from source image\"</figcaption>\n  </div>\n</div>\n\n## torch.compile\n\nUNet을 [컴파일](../optimization/torch2.0#torchcompile)하면 메모리 사용량이 살짝 증가하지만, 20~25%의 속도 향상을 얻을 수 있습니다.\n\n```diff\n- pipe.enable_model_cpu_offload()\n+ pipe.to(\"cuda\")\n+ pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n```\n\n## 메모리 사용량 줄이기\n\n비디오 생성은 기본적으로 배치 크기가 큰 text-to-image 생성과 유사하게 'num_frames'를 한 번에 생성하기 때문에 메모리 사용량이 매우 높습니다. 
메모리 사용량을 줄이기 위해 추론 속도와 메모리 사용량을 절충하는 여러 가지 옵션이 있습니다:\n\n- 모델 오프로딩 활성화: 파이프라인의 각 구성 요소가 더 이상 필요하지 않을 때 CPU로 오프로드됩니다.\n- Feed-forward chunking 활성화: feed-forward 레이어가 배치 크기가 큰 단일 feed-forward를 실행하는 대신 루프로 반복해서 실행됩니다.\n- `decode_chunk_size` 감소: VAE가 프레임들을 한꺼번에 디코딩하는 대신 chunk 단위로 디코딩합니다. `decode_chunk_size=1`을 설정하면 한 번에 한 프레임씩 디코딩하고 최소한의 메모리만 사용하지만(GPU 메모리에 따라 이 값을 조정하는 것이 좋습니다), 동영상에 약간의 깜박임이 발생할 수 있습니다.\n\n```diff\n- pipe.enable_model_cpu_offload()\n- frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]\n+ pipe.enable_model_cpu_offload()\n+ pipe.unet.enable_forward_chunking()\n+ frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]\n```\n\n이러한 모든 방법들을 사용하면 메모리 사용량이 8GB VRAM보다 적을 것입니다.\n\n## Micro-conditioning\n\nStable Video Diffusion은 이미지 conditioning 외에도 micro-conditioning을 허용하므로 생성된 비디오를 더 잘 제어할 수 있습니다:\n\n- `fps`: 생성된 비디오의 초당 프레임 수입니다.\n- `motion_bucket_id`: 생성된 동영상에 사용할 모션 버킷 아이디입니다. 생성된 동영상의 모션을 제어하는 데 사용할 수 있습니다. 모션 버킷 아이디를 늘리면 생성되는 동영상의 모션이 증가합니다.\n- `noise_aug_strength`: Conditioning 이미지에 추가되는 노이즈의 양입니다. 값이 클수록 비디오가 conditioning 이미지와 덜 유사해집니다. 
이 값을 높이면 생성된 비디오의 움직임도 증가합니다.\n\n예를 들어, 모션이 더 많은 동영상을 생성하려면 `motion_bucket_id` 및 `noise_aug_strength` micro-conditioning 파라미터를 사용합니다:\n\n```python\nimport torch\n\nfrom diffusers import StableVideoDiffusionPipeline\nfrom diffusers.utils import load_image, export_to_video\n\npipe = StableVideoDiffusionPipeline.from_pretrained(\n  \"stabilityai/stable-video-diffusion-img2vid-xt\", torch_dtype=torch.float16, variant=\"fp16\"\n)\npipe.enable_model_cpu_offload()\n\n# Conditioning 이미지 불러오기\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png\")\nimage = image.resize((1024, 576))\n\ngenerator = torch.manual_seed(42)\nframes = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0]\nexport_to_video(frames, \"generated.mp4\", fps=7)\n```\n\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/output_rocket_with_conditions.gif)\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/textual_inversion_inference.md",
    "content": "# Textual inversion\n\n[[open-in-colab]]\n\n[`StableDiffusionPipeline`]은  textual-inversion을 지원하는데, 이는 몇 개의 샘플 이미지만으로 stable diffusion과 같은 모델이 새로운 컨셉을 학습할 수 있도록 하는 기법입니다. 이를 통해 생성된 이미지를 더 잘 제어하고 특정 컨셉에 맞게 모델을 조정할 수 있습니다. 커뮤니티에서 만들어진 컨셉들의 컬렉션은 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer)를 통해 빠르게 사용해볼 수 있습니다.\n\n이 가이드에서는 Stable Diffusion Conceptualizer에서 사전학습한 컨셉을 사용하여 textual-inversion으로 추론을 실행하는 방법을 보여드립니다. textual-inversion으로 모델에 새로운 컨셉을 학습시키는 데 관심이 있으시다면,  [Textual Inversion](./training/text_inversion)  훈련 가이드를 참조하세요.\n\nHugging Face 계정으로 로그인하세요:\n\n```py\nfrom huggingface_hub import notebook_login\n\nnotebook_login()\n```\n\n필요한 라이브러리를 불러오고 생성된 이미지를 시각화하기 위한 도우미 함수 `image_grid`를 만듭니다:\n\n```py\nimport os\nimport torch\n\nimport PIL\nfrom PIL import Image\n\nfrom diffusers import StableDiffusionPipeline\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\n\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows * cols\n\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n    grid_w, grid_h = grid.size\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n```\n\nStable Diffusion과 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer)에서 사전학습된 컨셉을 선택합니다:\n\n```py\npretrained_model_name_or_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nrepo_id_embeds = \"sd-concepts-library/cat-toy\"\n```\n\n이제 파이프라인을 로드하고 사전학습된 컨셉을 파이프라인에 전달할 수 있습니다:\n\n```py\npipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16).to(\"cuda\")\n\npipeline.load_textual_inversion(repo_id_embeds)\n```\n\n특별한 placeholder token '`<cat-toy>`'를 사용하여 사전학습된 컨셉으로 프롬프트를 만들고, 생성할 샘플의 수와 이미지 행의 수를 선택합니다:\n\n```py\nprompt = \"a grafitti in a favela wall with a <cat-toy> on 
it\"\n\nnum_samples = 2\nnum_rows = 2\n```\n\n그런 다음 파이프라인을 실행하고, 생성된 이미지들을 저장합니다. 그리고 처음에 만들었던 도우미 함수 `image_grid`를 사용하여 생성 결과들을 시각화합니다. 이때 `num_inference_steps`와 `guidance_scale`과 같은 매개변수들을 조정하여, 이것들이 이미지 품질에 어떠한 영향을 미치는지를 자유롭게 확인해보시기 바랍니다.\n\n```py\nall_images = []\nfor _ in range(num_rows):\n    images = pipeline(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images\n    all_images.extend(images)\n\ngrid = image_grid(all_images, num_rows, num_samples)\ngrid\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/textual_inversion_inference.png\">\n</div>\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/unconditional_image_generation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Unconditional 이미지 생성\n\n[[open-in-colab]]\n\nUnconditional 이미지 생성은 비교적 간단한 작업입니다. 모델이 텍스트나 이미지와 같은 추가 조건 없이 이미 학습된 학습 데이터와 유사한 이미지만 생성합니다.\n\n['DiffusionPipeline']은 추론을 위해 미리 학습된 diffusion 시스템을 사용하는 가장 쉬운 방법입니다.\n\n먼저 ['DiffusionPipeline']의 인스턴스를 생성하고 다운로드할 파이프라인의 [체크포인트](https://huggingface.co/models?library=diffusers&sort=downloads)를 지정합니다. 허브의 🧨 diffusion 체크포인트 중 하나를 사용할 수 있습니다(사용할 체크포인트는 나비 이미지를 생성합니다).\n\n> [!TIP]\n> 💡 나만의 unconditional 이미지 생성 모델을 학습시키고 싶으신가요? 학습 가이드를 살펴보고 나만의 이미지를 생성하는 방법을 알아보세요.\n\n\n이 가이드에서는 unconditional 이미지 생성에 ['DiffusionPipeline']과 [DDPM](https://huggingface.co/papers/2006.11239)을 사용합니다:\n\n```python\n >>> from diffusers import DiffusionPipeline\n\n >>> generator = DiffusionPipeline.from_pretrained(\"anton-l/ddpm-butterflies-128\")\n```\n\n[diffusion 파이프라인]은 모든 모델링, 토큰화, 스케줄링 구성 요소를 다운로드하고 캐시합니다. 이 모델은 약 14억 개의 파라미터로 구성되어 있기 때문에 GPU에서 실행할 것을 강력히 권장합니다. 
PyTorch에서와 마찬가지로 제너레이터 객체를 GPU로 옮길 수 있습니다:\n\n```python\n >>> generator.to(\"cuda\")\n```\n\n이제 제너레이터를 사용하여 이미지를 생성할 수 있습니다:\n\n```python\n >>> image = generator().images[0]\n```\n\n출력은 기본적으로 [PIL.Image](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) 객체로 감싸집니다.\n\n다음을 호출하여 이미지를 저장할 수 있습니다:\n\n```python\n >>> image.save(\"generated_image.png\")\n```\n\n아래 스페이스(데모 링크)를 이용해 보고, 추론 단계의 매개변수를 자유롭게 조절하여 이미지 품질에 어떤 영향을 미치는지 확인해 보세요!\n\n<iframe src=\"https://stevhliu-ddpm-butterflies-128.hf.space\" frameborder=\"0\" width=\"850\" height=\"500\"></iframe>\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/weighted_prompts.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 프롬프트에 가중치 부여하기\n\n[[open-in-colab]]\n\n텍스트 가이드 기반의 diffusion 모델은 주어진 텍스트 프롬프트를 기반으로 이미지를 생성합니다.\n텍스트 프롬프트에는 모델이 생성해야 하는 여러 개념이 포함될 수 있으며 프롬프트의 특정 부분에 가중치를 부여하는 것이 바람직한 경우가 많습니다.\n\nDiffusion 모델은 문맥화된 텍스트 임베딩으로 diffusion 모델의 cross attention 레이어를 조절함으로써 작동합니다.\n([더 많은 정보를 위한 Stable Diffusion Guide](https://huggingface.co/docs/optimum-neuron/main/en/package_reference/modeling#stable-diffusion)를 참고하세요).\n따라서 프롬프트의 특정 부분을 강조하는(또는 강조하지 않는) 간단한 방법은 프롬프트의 관련 부분에 해당하는 텍스트 임베딩 벡터의 크기를 늘리거나 줄이는 것입니다.\n이것은 \"프롬프트 가중치 부여\" 라고 하며, 커뮤니티에서 가장 요구하는 기능입니다.([이곳](https://github.com/huggingface/diffusers/issues/2431)의 issue를 보세요 ).\n\n## Diffusers에서 프롬프트 가중치 부여하는 방법\n\n우리는 `diffusers`의 역할이 다른 프로젝트를 가능하게 하는 필수적인 기능을 제공하는 toolbex라고 생각합니다.\n[InvokeAI](https://github.com/invoke-ai/InvokeAI) 나 [diffuzers](https://github.com/abhishekkrthakur/diffuzers) 같은 강력한 UI를 구축할 수 있습니다.\n프롬프트를 조작하는 방법을 지원하기 위해, `diffusers` 는\n[StableDiffusionPipeline](https://huggingface.co/docs/diffusers/v0.18.2/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline)와 같은\n많은 파이프라인에 [prompt_embeds](https://huggingface.co/docs/diffusers/v0.14.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds)\n인수를 노출시켜, \"prompt-weighted\"/축척된 텍스트 임베딩을 파이프라인에 바로 전달할 수 있게 합니다.\n\n[Compel 라이브러리](https://github.com/damian0815/compel)는 프롬프트의 일부를 강조하거나 강조하지 않을 
수 있는 쉬운 방법을 제공합니다.\n임베딩을 직접 준비하는 것 대신 이 방법을 사용하는 것을 강력히 추천합니다.\n\n간단한 예제를 살펴보겠습니다.\n다음과 같이 `\"공을 갖고 노는 붉은색 고양이\"` 이미지를 생성하고 싶습니다:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionPipeline, UniPCMultistepScheduler\n\npipe = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\")\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\nprompt = \"a red cat playing with a ball\"\n\ngenerator = torch.Generator(device=\"cpu\").manual_seed(33)\n\nimage = pipe(prompt, generator=generator, num_inference_steps=20).images[0]\nimage\n```\n\n생성된 이미지:\n\n![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_0.png)\n\n사진에서 알 수 있듯이, \"공\"은 이미지에 없습니다. 이 부분을 강조해 볼까요!\n\n먼저 `compel` 라이브러리를 설치해야 합니다:\n\n```sh\npip install compel\n```\n\n그런 다음에는 `Compel` 오브젝트를 생성합니다:\n\n```py\nfrom compel import Compel\n\ncompel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)\n```\n\n이제 `\"++\"` 를 사용해서 \"공\" 을 강조해 봅시다:\n\n```py\nprompt = \"a red cat playing with a ball++\"\n```\n\n그리고 이 프롬프트를 파이프라인에 바로 전달하지 않고, `compel_proc` 를 사용하여 처리해야 합니다:\n\n```py\nprompt_embeds = compel_proc(prompt)\n```\n\n파이프라인에 `prompt_embeds` 를 바로 전달할 수 있습니다:\n\n```py\ngenerator = torch.Generator(device=\"cpu\").manual_seed(33)\n\nimage = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]\nimage\n```\n\n이제 \"공\"이 있는 그림을 출력할 수 있습니다!\n\n![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_1.png)\n\n마찬가지로 `--` 접미사를 단어에 사용하여 문장의 일부를 강조하지 않을 수 있습니다. 
한번 시도해 보세요!\n\n즐겨찾는 파이프라인에 `prompt_embeds` 입력이 없는 경우 issue를 새로 만들어주세요.\nDiffusers 팀은 최대한 대응하려고 노력합니다.\n\nCompel 1.1.6은 textual inversion을 더 쉽게 사용할 수 있게 해 주는 유틸리티 클래스를 추가합니다.\n`DiffusersTextualInversionManager`를 인스턴스화한 후 이를 Compel init에 전달합니다:\n\n```\nfrom compel import Compel, DiffusersTextualInversionManager\n\ntextual_inversion_manager = DiffusersTextualInversionManager(pipe)\ncompel = Compel(\n    tokenizer=pipe.tokenizer,\n    text_encoder=pipe.text_encoder,\n    textual_inversion_manager=textual_inversion_manager)\n```\n\n더 많은 정보를 얻고 싶다면 [compel](https://github.com/damian0815/compel) 라이브러리 문서를 참고하세요.\n"
  },
  {
    "path": "docs/source/ko/using-diffusers/write_own_pipeline.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 파이프라인, 모델 및 스케줄러 이해하기\n\n[[open-in-colab]]\n\n🧨 Diffusers는 사용자 친화적이며 유연한 도구 상자로, 사용사례에 맞게 diffusion 시스템을 구축 할 수 있도록 설계되었습니다. 이 도구 상자의 핵심은 모델과 스케줄러입니다. [`DiffusionPipeline`]은 편의를 위해 이러한 구성 요소를 번들로 제공하지만, 파이프라인을 분리하고 모델과 스케줄러를 개별적으로 사용해 새로운 diffusion 시스템을 만들 수도 있습니다.\n\n이 튜토리얼에서는 기본 파이프라인부터 시작해 Stable Diffusion 파이프라인까지 진행하며 모델과 스케줄러를 사용해 추론을 위한 diffusion 시스템을 조립하는 방법을 배웁니다.\n\n## 기본 파이프라인 해체하기\n\n파이프라인은 추론을 위해 모델을 실행하는 빠르고 쉬운 방법으로, 이미지를 생성하는 데 코드가 4줄 이상 필요하지 않습니다:\n\n```py\n>>> from diffusers import DDPMPipeline\n\n>>> ddpm = DDPMPipeline.from_pretrained(\"google/ddpm-cat-256\").to(\"cuda\")\n>>> image = ddpm(num_inference_steps=25).images[0]\n>>> image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ddpm-cat.png\" alt=\"Image of cat created from DDPMPipeline\"/>\n</div>\n\n정말 쉽습니다. 그런데 파이프라인은 어떻게 이렇게 할 수 있었을까요? 파이프라인을 세분화하여 내부에서 어떤 일이 일어나고 있는지 살펴보겠습니다.\n\n위 예시에서 파이프라인에는 [`UNet2DModel`] 모델과 [`DDPMScheduler`]가 포함되어 있습니다. 파이프라인은 원하는 출력 크기의 랜덤 노이즈를 받아 모델을 여러번 통과시켜 이미지의 노이즈를 제거합니다. 각 timestep에서 모델은 *noise residual*을 예측하고 스케줄러는 이를 사용하여 노이즈가 적은 이미지를 예측합니다. 파이프라인은 지정된 추론 스텝수에 도달할 때까지 이 과정을 반복합니다.\n\n모델과 스케줄러를 별도로 사용하여 파이프라인을 다시 생성하기 위해 자체적인 노이즈 제거 프로세스를 작성해 보겠습니다.\n\n1. 
모델과 스케줄러를 불러옵니다:\n\n    ```py\n    >>> from diffusers import DDPMScheduler, UNet2DModel\n\n    >>> scheduler = DDPMScheduler.from_pretrained(\"google/ddpm-cat-256\")\n    >>> model = UNet2DModel.from_pretrained(\"google/ddpm-cat-256\").to(\"cuda\")\n    ```\n\n2. 노이즈 제거 프로세스를 실행할 timestep 수를 설정합니다:\n\n    ```py\n    >>> scheduler.set_timesteps(50)\n    ```\n\n3. 스케줄러의 timestep을 설정하면 균등한 간격의 구성 요소를 가진 텐서가 생성됩니다.(이 예시에서는 50개) 각 요소는 모델이 이미지의 노이즈를 제거하는 시간 간격에 해당합니다. 나중에 노이즈 제거 루프를 만들 때 이 텐서를 반복하여 이미지의 노이즈를 제거합니다:\n\n    ```py\n    >>> scheduler.timesteps\n    tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720,\n        700, 680, 660, 640, 620, 600, 580, 560, 540, 520, 500, 480, 460, 440,\n        420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160,\n        140, 120, 100,  80,  60,  40,  20,   0])\n    ```\n\n4. 원하는 출력과 같은 모양을 가진 랜덤 노이즈를 생성합니다:\n\n    ```py\n    >>> import torch\n\n    >>> sample_size = model.config.sample_size\n    >>> noise = torch.randn((1, 3, sample_size, sample_size), device=\"cuda\")\n    ```\n\n5. 이제 timestep을 반복하는 루프를 작성합니다. 각 timestep에서 모델은 [`UNet2DModel.forward`]를 통해 noisy residual을 반환합니다. 스케줄러의 [`~DDPMScheduler.step`] 메서드는 noisy residual, timestep, 그리고 입력을 받아 이전 timestep에서 이미지를 예측합니다. 이 출력은 노이즈 제거 루프의 모델에 대한 다음 입력이 되며, `timesteps` 배열의 끝에 도달할 때까지 반복됩니다.\n\n    ```py\n    >>> input = noise\n\n    >>> for t in scheduler.timesteps:\n    ...     with torch.no_grad():\n    ...         noisy_residual = model(input, t).sample\n    ...     previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample\n    ...     input = previous_noisy_sample\n    ```\n\n    이것이 전체 노이즈 제거 프로세스이며, 동일한 패턴을 사용해 모든 diffusion 시스템을 작성할 수 있습니다.\n\n6. 
마지막 단계는 노이즈가 제거된 출력을 이미지로 변환하는 것입니다:\n\n    ```py\n    >>> from PIL import Image\n    >>> import numpy as np\n\n    >>> image = (input / 2 + 0.5).clamp(0, 1)\n    >>> image = image.cpu().permute(0, 2, 3, 1).numpy()[0]\n    >>> image = Image.fromarray((image * 255).round().astype(\"uint8\"))\n    >>> image\n    ```\n\n다음 섹션에서는 여러분의 기술을 시험해보고 좀 더 복잡한 Stable Diffusion 파이프라인을 분석해 보겠습니다. 방법은 거의 동일합니다. 필요한 구성요소들을 초기화하고 timestep수를 설정하여 `timestep` 배열을 생성합니다. 노이즈 제거 루프에서 `timestep` 배열이 사용되며, 이 배열의 각 요소에 대해 모델은 노이즈가 적은 이미지를 예측합니다. 노이즈 제거 루프는 `timestep`을 반복하고 각 timestep에서 noise residual을 출력하고 스케줄러는 이를 사용하여 이전 timestep에서 노이즈가 덜한 이미지를 예측합니다. 이 프로세스는 `timestep` 배열의 끝에 도달할 때까지 반복됩니다.\n\n한번 사용해 봅시다!\n\n## Stable Diffusion 파이프라인 해체하기\n\nStable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent diffusion 모델이라고 불리는 이유는 실제 픽셀 공간 대신 이미지의 저차원의 표현으로 작업하기 때문이고, 메모리 효율이 더 높습니다. 인코더는 이미지를 더 작은 표현으로 압축하고, 디코더는 압축된 표현을 다시 이미지로 변환합니다. text-to-image 모델의 경우 텍스트 임베딩을 생성하기 위해 tokenizer와 인코더가 필요합니다. 이전 예제에서 이미 UNet 모델과 스케줄러가 필요하다는 것은 알고 계셨을 것입니다.\n\n보시다시피, 이것은 UNet 모델만 포함된 DDPM 파이프라인보다 더 복잡합니다. Stable Diffusion 모델에는 세 개의 개별 사전학습된 모델이 있습니다.\n\n> [!TIP]\n> 💡 VAE, UNet 및 텍스트 인코더 모델의 작동방식에 대한 자세한 내용은 [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) 블로그를 참조하세요.\n\n이제 Stable Diffusion 파이프라인에 필요한 구성요소들이 무엇인지 알았으니, [`~ModelMixin.from_pretrained`] 메서드를 사용해 모든 구성요소를 불러옵니다. 
구성요소들은 사전학습된 체크포인트 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)에서 찾을 수 있으며, 각 구성요소들은 별도의 하위 폴더에 저장되어 있습니다:\n\n```py\n>>> from PIL import Image\n>>> import torch\n>>> from transformers import CLIPTextModel, CLIPTokenizer\n>>> from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler\n\n>>> vae = AutoencoderKL.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"vae\")\n>>> tokenizer = CLIPTokenizer.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"tokenizer\")\n>>> text_encoder = CLIPTextModel.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"text_encoder\")\n>>> unet = UNet2DConditionModel.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"unet\")\n```\n\n기본 [`PNDMScheduler`] 대신, [`UniPCMultistepScheduler`]로 교체하여 다른 스케줄러를 얼마나 쉽게 연결할 수 있는지 확인합니다:\n\n```py\n>>> from diffusers import UniPCMultistepScheduler\n\n>>> scheduler = UniPCMultistepScheduler.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"scheduler\")\n```\n\n모델은 스케줄러와 달리 학습 가능한 가중치가 있으므로, 추론 속도를 높이려면 모델을 GPU로 옮기세요:\n\n```py\n>>> torch_device = \"cuda\"\n>>> vae.to(torch_device)\n>>> text_encoder.to(torch_device)\n>>> unet.to(torch_device)\n```\n\n### 텍스트 임베딩 생성하기\n\n다음 단계는 임베딩을 생성하기 위해 텍스트를 토큰화하는 것입니다. 이 텍스트는 UNet 모델에서 condition으로 사용되고 입력 프롬프트와 유사한 방향으로 diffusion 프로세스를 조정하는 데 사용됩니다.\n\n> [!TIP]\n> 💡 `guidance_scale` 매개변수는 이미지를 생성할 때 프롬프트에 얼마나 많은 가중치를 부여할지 결정합니다.\n\n다른 이미지를 생성하고 싶다면 원하는 프롬프트를 자유롭게 선택하세요!\n\n```py\n>>> prompt = [\"a photograph of an astronaut riding a horse\"]\n>>> height = 512  # Stable Diffusion의 기본 높이\n>>> width = 512  # Stable Diffusion의 기본 너비\n>>> num_inference_steps = 25  # 노이즈 제거 스텝 수\n>>> guidance_scale = 7.5  # classifier-free guidance를 위한 scale\n>>> # 초기 잠재 노이즈를 생성하는 seed generator (latent와 같은 device에 생성)\n>>> generator = torch.Generator(device=torch_device).manual_seed(0)\n>>> batch_size = len(prompt)\n```\n\n텍스트를 토큰화하고 프롬프트에서 임베딩을 생성합니다:\n\n```py\n>>> text_input = tokenizer(\n...     
prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\"\n... )\n\n>>> with torch.no_grad():\n...     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]\n```\n\n또한 패딩 토큰의 임베딩인 *unconditional 텍스트 임베딩*을 생성해야 합니다. 이 임베딩은 조건부 `text_embeddings`과 동일한 shape(`batch_size` 그리고 `seq_length`)을 가져야 합니다:\n\n```py\n>>> max_length = text_input.input_ids.shape[-1]\n>>> uncond_input = tokenizer([\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\")\n>>> uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]\n```\n\n두번의 forward pass를 피하기 위해 conditional 임베딩과 unconditional 임베딩을 배치(batch)로 연결하겠습니다:\n\n```py\n>>> text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n```\n\n### 랜덤 노이즈 생성\n\n그다음 diffusion 프로세스의 시작점으로 초기 랜덤 노이즈를 생성합니다. 이것이 이미지의 잠재적 표현이며 점차적으로 노이즈가 제거됩니다. 이 시점에서 `latent` 이미지는 최종 이미지 크기보다 작지만 나중에 모델이 이를 512x512 이미지 크기로 변환하므로 괜찮습니다.\n\n> [!TIP]\n> 💡 `vae` 모델에는 3개의 다운 샘플링 레이어가 있기 때문에 높이와 너비가 8로 나뉩니다. 다음을 실행하여 확인할 수 있습니다:\n>\n> ```py\n> 2 ** (len(vae.config.block_out_channels) - 1) == 8\n> ```\n\n```py\n>>> latents = torch.randn(\n...     (batch_size, unet.config.in_channels, height // 8, width // 8),\n...     generator=generator,\n...     device=torch_device,\n... )\n```\n\n### 이미지 노이즈 제거\n\n먼저 [`UniPCMultistepScheduler`]와 같은 향상된 스케줄러에 필요한 노이즈 스케일 값인 초기 노이즈 분포 *sigma* 로 입력을 스케일링 하는 것부터 시작합니다:\n\n```py\n>>> latents = latents * scheduler.init_noise_sigma\n```\n\n마지막 단계는 `latent`의 순수한 노이즈를 점진적으로 프롬프트에 설명된 이미지로 변환하는 노이즈 제거 루프를 생성하는 것입니다. 노이즈 제거 루프는 세 가지 작업을 수행해야 한다는 점을 기억하세요:\n\n1. 노이즈 제거 중에 사용할 스케줄러의 timesteps를 설정합니다.\n2. timestep을 따라 반복합니다.\n3. 각 timestep에서 UNet 모델을 호출하여 noise residual을 예측하고 스케줄러에 전달하여 이전 노이즈 샘플을 계산합니다.\n\n```py\n>>> from tqdm.auto import tqdm\n\n>>> scheduler.set_timesteps(num_inference_steps)\n\n>>> for t in tqdm(scheduler.timesteps):\n...     
# classifier-free guidance를 수행하는 경우 두번의 forward pass를 수행하지 않도록 latent를 확장.\n...     latent_model_input = torch.cat([latents] * 2)\n\n...     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)\n\n...     # noise residual 예측\n...     with torch.no_grad():\n...         noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n...     # guidance 수행\n...     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n...     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n...     # 이전 노이즈 샘플을 계산 x_t -> x_t-1\n...     latents = scheduler.step(noise_pred, t, latents).prev_sample\n```\n\n### 이미지 디코딩\n\n마지막 단계는 `vae`를 이용하여 잠재 표현을 이미지로 디코딩하고 `sample`과 함께 디코딩된 출력을 얻는 것입니다:\n\n```py\n# latent를 스케일링하고 vae로 이미지 디코딩\nlatents = 1 / 0.18215 * latents\nwith torch.no_grad():\n    image = vae.decode(latents).sample\n```\n\n마지막으로 이미지를 `PIL.Image`로 변환하면 생성된 이미지를 확인할 수 있습니다!\n\n```py\n>>> image = (image / 2 + 0.5).clamp(0, 1)\n>>> image = image.detach().cpu().permute(0, 2, 3, 1).numpy()\n>>> images = (image * 255).round().astype(\"uint8\")\n>>> pil_images = [Image.fromarray(image) for image in images]\n>>> pil_images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/blog/assets/98_stable_diffusion/stable_diffusion_k_lms.png\"/>\n</div>\n\n## 다음 단계\n\n기본 파이프라인부터 복잡한 파이프라인까지, 자신만의 diffusion 시스템을 작성하는 데 필요한 것은 노이즈 제거 루프뿐이라는 것을 알 수 있었습니다. 이 루프는 스케줄러의 timesteps를 설정하고, 이를 반복하며, UNet 모델을 호출하여 noise residual을 예측하고 스케줄러에 전달하여 이전 노이즈 샘플을 계산하는 과정을 번갈아 가며 수행해야 합니다.\n\n이것이 바로 🧨 Diffusers가 설계된 목적입니다: 모델과 스케줄러를 사용해 자신만의 diffusion 시스템을 직관적이고 쉽게 작성할 수 있도록 하기 위해서입니다.\n\n다음 단계를 자유롭게 진행하세요:\n\n* 🧨 Diffusers에 [파이프라인 구축 및 기여](using-diffusers/#contribute_pipeline)하는 방법을 알아보세요. 여러분이 어떤 아이디어를 내놓을지 기대됩니다!\n* 라이브러리에서 [기본 파이프라인](./api/pipelines/overview)을 살펴보고, 모델과 스케줄러를 별도로 사용하여 파이프라인을 처음부터 해체하고 빌드할 수 있는지 확인해 보세요.\n"
  },
  {
    "path": "docs/source/pt/_toctree.yml",
    "content": "- sections:\n  - local: index\n    title: Diffusers\n  - local: installation\n    title: Instalação\n  - local: quicktour\n    title: Tour rápido\n  - local: stable_diffusion\n    title: Desempenho básico\n  title: Primeiros passos\n"
  },
  {
    "path": "docs/source/pt/index.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://raw.githubusercontent.com/huggingface/diffusers/77aadfee6a891ab9fcfb780f87c693f7a5beeb8e/docs/source/imgs/diffusers_library.jpg\" width=\"400\"/>\n    <br>\n</p>\n\n# Diffusers\n\n🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou quer treinar seu próprio modelo de difusão, 🤗 Diffusers é uma caixa de ferramentas modular que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).\n\nA Biblioteca tem três componentes principais:\n\n- Pipelines de última geração para a geração em poucas linhas de código. 
Há muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.\n- Intercambiáveis [agendadores de ruído](api/schedulers/overview) para balancear as compensações entre velocidade e qualidade de geração.\n- [Modelos](api/models) pré-treinados que podem ser usados como se fossem blocos de construção, e combinados com agendadores, para criar seu próprio sistema de difusão de ponta a ponta.\n\n<div class=\"mt-10\">\n  <div class=\"w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5\">\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./tutorials/tutorial_overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Tutoriais</div>\n      <p class=\"text-gray-700\">Aprenda as competências fundamentais que precisa para iniciar a gerar saídas, construa seu próprio sistema de difusão, e treine um modelo de difusão. Nós recomendamos começar por aqui se você está utilizando o 🤗 Diffusers pela primeira vez!</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./using-diffusers/loading_overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-indigo-400 to-indigo-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Guias de utilização</div>\n      <p class=\"text-gray-700\">Guias práticos para ajudar você carregar pipelines, modelos, e agendadores. 
Você também aprenderá como usar os pipelines para tarefas específicas, controlar como as saídas são geradas, otimizar a velocidade de geração, e outras técnicas diferentes de treinamento.</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./conceptual/philosophy\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-pink-400 to-pink-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Guias conceituais</div>\n      <p class=\"text-gray-700\">Compreenda porque a biblioteca foi desenhada da forma que ela é, e aprenda mais sobre as diretrizes éticas e implementações de segurança para o uso da biblioteca.</p>\n   </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./api/models/overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Referência</div>\n      <p class=\"text-gray-700\">Descrições técnicas de como funcionam as classes e métodos do 🤗 Diffusers</p>\n    </a>\n  </div>\n</div>\n"
  },
  {
    "path": "docs/source/pt/installation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Instalação\n\n🤗 Diffusers é testado no Python 3.8+, PyTorch 1.7.0+, e Flax. Siga as instruções de instalação abaixo para a biblioteca de deep learning que você está utilizando:\n\n- [PyTorch](https://pytorch.org/get-started/locally/) instruções de instalação\n- [Flax](https://flax.readthedocs.io/en/latest/) instruções de instalação\n\n## Instalação com pip\n\nRecomenda-se instalar 🤗 Diffusers em um [ambiente virtual](https://docs.python.org/3/library/venv.html).\nSe você não está familiarizado com ambiente virtuals, veja o [guia](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).\nUm ambiente virtual facilita gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.\n\nComece criando um ambiente virtual no diretório do projeto:\n\n```bash\npython -m venv .env\n```\n\nAtive o ambiente virtual:\n\n```bash\nsource .env/bin/activate\n```\n\nRecomenda-se a instalação do 🤗 Transformers porque 🤗 Diffusers depende de seus modelos:\n\n<frameworkcontent>\n<pt>\n```bash\npip install diffusers[\"torch\"] transformers\n```\n</pt>\n<jax>\n```bash\npip install diffusers[\"flax\"] transformers\n```\n</jax>\n</frameworkcontent>\n\n## Instalação a partir do código fonte\n\nAntes da instalação do 🤗 Diffusers a partir do código fonte, certifique-se de ter o PyTorch e o 🤗 Accelerate instalados.\n\nPara instalar o 
🤗 Accelerate:\n\n```bash\npip install accelerate\n```\n\nentão instale o 🤗 Diffusers do código fonte:\n\n```bash\npip install git+https://github.com/huggingface/diffusers\n```\n\nEsse comando instala a última versão em desenvolvimento `main` em vez da última versão estável `stable`.\nA versão `main` é útil para se manter atualizado com os últimos desenvolvimentos.\nPor exemplo, se um bug foi corrigido desde o último lançamento estável, mas um novo lançamento ainda não foi lançado.\nNo entanto, isso significa que a versão `main` pode não ser sempre estável.\nNós nos esforçamos para manter a versão `main` operacional, e a maioria dos problemas geralmente são resolvidos em algumas horas ou um dia.\nSe você encontrar um problema, por favor abra uma [Issue](https://github.com/huggingface/diffusers/issues/new/choose), assim conseguimos arrumar o quanto antes!\n\n## Instalação editável\n\nVocê precisará de uma instalação editável se você:\n\n- Usar a versão `main` do código fonte.\n- Contribuir para o 🤗 Diffusers e precisa testar mudanças no código.\n\nClone o repositório e instale o 🤗 Diffusers com os seguintes comandos:\n\n```bash\ngit clone https://github.com/huggingface/diffusers.git\ncd diffusers\n```\n\n<frameworkcontent>\n<pt>\n```bash\npip install -e \".[torch]\"\n```\n</pt>\n<jax>\n```bash\npip install -e \".[flax]\"\n```\n</jax>\n</frameworkcontent>\n\nEsses comandos irão vincular a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.\nPython então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.\nPor exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.10/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.\n\n> [!WARNING]\n> Você deve manter a pasta `diffusers` se quiser continuar usando a biblioteca.\n\nAgora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:\n\n```bash\ncd 
~/diffusers/\ngit pull\n```\n\nSeu ambiente Python vai encontrar a versão `main` do 🤗 Diffusers na próxima execução.\n\n## Cache\n\nOs pesos e os arquivos dos modelos são baixados do Hub para o cache que geralmente é o seu diretório home. Você pode mudar a localização do cache especificando as variáveis de ambiente `HF_HOME` ou `HUGGINGFACE_HUB_CACHE` ou configurando o parâmetro `cache_dir` em métodos como [`~DiffusionPipeline.from_pretrained`].\n\nArquivos em cache permitem que você rode 🤗 Diffusers offline. Para prevenir que o 🤗 Diffusers se conecte à internet, defina a variável de ambiente `HF_HUB_OFFLINE` para `True` e o 🤗 Diffusers irá apenas carregar arquivos previamente baixados em cache.\n\n```shell\nexport HF_HUB_OFFLINE=True\n```\n\nPara mais detalhes de como gerenciar e limpar o cache, olhe o guia de [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).\n\n## Telemetria\n\nNossa biblioteca coleta informações de telemetria durante as requisições [`~DiffusionPipeline.from_pretrained`].\nO dado coletado inclui a versão do 🤗 Diffusers e PyTorch/Flax, o modelo ou classe de pipeline requisitado,\ne o caminho para um checkpoint pré-treinado se ele estiver hospedado no Hugging Face Hub.\nEsse dado de uso nos ajuda a debugar problemas e priorizar novas funcionalidades.\nA telemetria é enviada apenas quando modelos e pipelines são carregados do Hub,\ne não é coletada se você estiver carregando arquivos locais.\n\nNós entendemos que nem todo mundo quer compartilhar informações adicionais, e nós respeitamos sua privacidade.\nVocê pode desabilitar a coleta de telemetria definindo a variável de ambiente `DISABLE_TELEMETRY` do seu terminal:\n\nNo Linux/MacOS:\n\n```bash\nexport DISABLE_TELEMETRY=YES\n```\n\nNo Windows:\n\n```bash\nset DISABLE_TELEMETRY=YES\n```\n"
  },
  {
    "path": "docs/source/pt/quicktour.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Tour rápido\n\nModelos de difusão são treinados para remover o ruído Gaussiano aleatório passo a passo para gerar uma amostra de interesse, como uma imagem ou áudio. Isso despertou um tremendo interesse em IA generativa, e você provavelmente já viu exemplos de imagens geradas por difusão na internet. 🧨 Diffusers é uma biblioteca que visa tornar os modelos de difusão amplamente acessíveis a todos.\n\nSeja você um desenvolvedor ou um usuário, esse tour rápido irá introduzir você ao 🧨 Diffusers e ajudar você a começar a gerar rapidamente! 
Há três componentes principais da biblioteca para conhecer:\n\n- O [`DiffusionPipeline`] é uma classe de alto nível de ponta a ponta desenhada para gerar rapidamente amostras de modelos de difusão pré-treinados para inferência.\n- [Modelos](./api/models) pré-treinados populares e módulos que podem ser usados como blocos de construção para criar sistemas de difusão.\n- Vários [Agendadores](./api/schedulers/overview) diferentes - algoritmos que controlam como o ruído é adicionado para treinamento, e como gerar imagens sem o ruído durante a inferência.\n\nEsse tour rápido mostrará como usar o [`DiffusionPipeline`] para inferência, e então mostrará como combinar um modelo e um agendador para replicar o que está acontecendo dentro do [`DiffusionPipeline`].\n\n> [!TIP]\n> Esse tour rápido é uma versão simplificada da introdução 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) para ajudar você a começar rápido. Se você quer aprender mais sobre o objetivo do 🧨 Diffusers, filosofia de design, e detalhes adicionais sobre a API principal, veja o notebook!\n\nAntes de começar, certifique-se de ter todas as bibliotecas necessárias instaladas:\n\n```py\n# uncomment to install the necessary libraries in Colab\n#!pip install --upgrade diffusers accelerate transformers\n```\n\n- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) acelera o carregamento do modelo para geração e treinamento.\n- [🤗 Transformers](https://huggingface.co/docs/transformers/index) é necessário para executar os modelos mais populares de difusão, como o [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).\n\n## DiffusionPipeline\n\nO [`DiffusionPipeline`] é a forma mais fácil de usar um sistema de difusão pré-treinado para geração. É um sistema de ponta a ponta contendo o modelo e o agendador. Você pode usar o [`DiffusionPipeline`] pronto para muitas tarefas. 
Dê uma olhada na tabela abaixo para algumas tarefas suportadas, e para uma lista completa de tarefas suportadas, veja a tabela [Resumo do 🧨 Diffusers](./api/pipelines/overview#diffusers-summary).\n\n| **Tarefa**                             | **Descrição**                                                                                                             | **Pipeline**                                                                       |\n| -------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- |\n| Unconditional Image Generation         | gera uma imagem a partir do ruído Gaussiano                                                                               | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |\n| Text-Guided Image Generation           | gera uma imagem a partir de um prompt de texto                                                                            | [conditional_image_generation](./using-diffusers/conditional_image_generation)     |\n| Text-Guided Image-to-Image Translation | adapta uma imagem guiada por um prompt de texto                                                                           | [img2img](./using-diffusers/img2img)                                               |\n| Text-Guided Image-Inpainting           | preenche a parte da máscara da imagem, dado a imagem, a máscara e o prompt de texto                                       | [inpaint](./using-diffusers/inpaint)                                               |\n| Text-Guided Depth-to-Image Translation | adapta as partes de uma imagem guiada por um prompt de texto enquanto preserva a estrutura por estimativa de profundidade | [depth2img](./using-diffusers/depth2img)                                           |\n\nComece criando uma 
instância do [`DiffusionPipeline`] e especifique qual checkpoint do pipeline você gostaria de baixar.\nVocê pode usar o [`DiffusionPipeline`] para qualquer [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) armazenado no Hugging Face Hub.\nNesse quicktour, você carregará o checkpoint [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) para geração de texto para imagem.\n\n> [!WARNING]\n> Para os modelos de [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), por favor leia cuidadosamente a [licença](https://huggingface.co/spaces/CompVis/stable-diffusion-license) primeiro antes de rodar o modelo. 🧨 Diffusers implementa uma verificação de segurança: [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) para prevenir conteúdo ofensivo ou nocivo, mas as capacidades de geração de imagem aprimorada do modelo podem ainda produzir conteúdo potencialmente nocivo.\n\nPara carregar o modelo com o método [`~DiffusionPipeline.from_pretrained`]:\n\n```python\n>>> from diffusers import DiffusionPipeline\n\n>>> pipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\n```\n\nO [`DiffusionPipeline`] baixa e armazena em cache todos os componentes de modelagem, tokenização, e agendamento. 
Você verá que o pipeline do Stable Diffusion é composto pelo [`UNet2DConditionModel`] e [`PNDMScheduler`] entre outras coisas:\n\n```py\n>>> pipeline\nStableDiffusionPipeline {\n  \"_class_name\": \"StableDiffusionPipeline\",\n  \"_diffusers_version\": \"0.13.1\",\n  ...,\n  \"scheduler\": [\n    \"diffusers\",\n    \"PNDMScheduler\"\n  ],\n  ...,\n  \"unet\": [\n    \"diffusers\",\n    \"UNet2DConditionModel\"\n  ],\n  \"vae\": [\n    \"diffusers\",\n    \"AutoencoderKL\"\n  ]\n}\n```\n\nNós fortemente recomendamos rodar o pipeline em uma placa de vídeo, pois o modelo consiste em aproximadamente 1.4 bilhões de parâmetros.\nVocê pode mover o objeto gerador para uma placa de vídeo, assim como você faria no PyTorch:\n\n```python\n>>> pipeline.to(\"cuda\")\n```\n\nAgora você pode passar o prompt de texto para o `pipeline` para gerar uma imagem, e então acessar a imagem sem ruído. Por padrão, a saída da imagem é embrulhada em um objeto [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).\n\n```python\n>>> image = pipeline(\"An image of a squirrel in Picasso style\").images[0]\n>>> image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png\"/>\n</div>\n\nSalve a imagem chamando o `save`:\n\n```python\n>>> image.save(\"image_of_squirrel_painting.png\")\n```\n\n### Pipeline local\n\nVocê também pode utilizar o pipeline localmente. 
A única diferença é que você precisa baixar os pesos primeiro:\n\n```bash\n!git lfs install\n!git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5\n```\n\nEm seguida, carregue os pesos salvos no pipeline:\n\n```python\n>>> pipeline = DiffusionPipeline.from_pretrained(\"./stable-diffusion-v1-5\", use_safetensors=True)\n```\n\nAgora você pode rodar o pipeline como você faria na seção acima.\n\n### Troca dos agendadores\n\nAgendadores diferentes têm diferentes velocidades de retirada de ruído e compensações de qualidade. A melhor forma de descobrir qual funciona melhor para você é testá-los! Uma das principais características do 🧨 Diffusers é permitir que você troque facilmente entre agendadores. Por exemplo, para substituir o [`PNDMScheduler`] padrão pelo [`EulerDiscreteScheduler`], carregue-o com o método [`~diffusers.ConfigMixin.from_config`]:\n\n```py\n>>> from diffusers import EulerDiscreteScheduler\n\n>>> pipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\n>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)\n```\n\nTente gerar uma imagem com o novo agendador e veja se você nota alguma diferença!\n\nNa próxima seção, você irá dar uma olhada mais de perto nos componentes - o modelo e o agendador - que compõem o [`DiffusionPipeline`] e aprender como usar esses componentes para gerar uma imagem de um gato.\n\n## Modelos\n\nA maioria dos modelos recebe uma amostra com ruído e, em cada _timestep_, prevê o _noise residual_, a diferença entre uma imagem com menos ruído e a imagem de entrada (outros modelos aprendem a prever a amostra anterior diretamente, ou a velocidade, ou [`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)). 
Você pode misturar e combinar modelos para criar outros sistemas de difusão.\n\nModelos são inicializados com o método [`~ModelMixin.from_pretrained`] que também armazena em cache localmente os pesos do modelo para que seja mais rápido na próxima vez que você carregar o modelo. Para o tour rápido, você irá carregar o [`UNet2DModel`], um modelo básico de geração de imagem incondicional com um checkpoint treinado em imagens de gato:\n\n```py\n>>> from diffusers import UNet2DModel\n\n>>> repo_id = \"google/ddpm-cat-256\"\n>>> model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)\n```\n\nPara acessar os parâmetros do modelo, chame `model.config`:\n\n```py\n>>> model.config\n```\n\nA configuração do modelo é um dicionário 🧊 congelado 🧊, o que significa que esses parâmetros não podem ser mudados depois que o modelo é criado. Isso é intencional e garante que os parâmetros usados para definir a arquitetura do modelo no início permaneçam os mesmos, enquanto outros parâmetros ainda podem ser ajustados durante a geração.\n\nAlguns dos parâmetros mais importantes são:\n\n- `sample_size`: a dimensão da altura e largura da amostra de entrada.\n- `in_channels`: o número de canais de entrada da amostra de entrada.\n- `down_block_types` e `up_block_types`: o tipo de blocos de downsampling e upsampling usados para criar a arquitetura UNet.\n- `block_out_channels`: o número de canais de saída dos blocos de downsampling; também utilizado, em ordem reversa, como o número de canais de entrada dos blocos de upsampling.\n- `layers_per_block`: o número de blocos ResNet presentes em cada bloco UNet.\n\nPara usar o modelo para geração, crie a forma da imagem com ruído Gaussiano aleatório. 
Deve ter um eixo `batch` porque o modelo pode receber múltiplos ruídos aleatórios, um eixo `channel` correspondente ao número de canais de entrada, e um eixo `sample_size` para a altura e largura da imagem:\n\n```py\n>>> import torch\n\n>>> torch.manual_seed(0)\n\n>>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n>>> noisy_sample.shape\ntorch.Size([1, 3, 256, 256])\n```\n\nPara geração, passe a imagem com ruído para o modelo e um `timestep`. O `timestep` indica o quão ruidosa a imagem de entrada é, com mais ruído no início e menos no final. Isso ajuda o modelo a determinar sua posição no processo de difusão, se está mais perto do início ou do final. Use o atributo `sample` para obter a saída do modelo:\n\n```py\n>>> with torch.no_grad():\n...     noisy_residual = model(sample=noisy_sample, timestep=2).sample\n```\n\nPara geração de exemplos reais, você precisará de um agendador para guiar o processo de retirada do ruído. Na próxima seção, você irá aprender como acoplar um modelo com um agendador.\n\n## Agendadores\n\nAgendadores gerenciam a retirada do ruído de uma amostra ruidosa para uma amostra menos ruidosa dada a saída do modelo - nesse caso, é o `noisy_residual`.\n\n> [!TIP]\n> 🧨 Diffusers é uma caixa de ferramentas para construir sistemas de difusão. 
Enquanto o [`DiffusionPipeline`] é uma forma conveniente de começar com um sistema de difusão pré-construído, você também pode escolher seus próprios modelos e agendadores separadamente para construir um sistema de difusão personalizado.\n\nPara o tour rápido, você irá instanciar o [`DDPMScheduler`] com o método [`~diffusers.ConfigMixin.from_config`]:\n\n```py\n>>> from diffusers import DDPMScheduler\n\n>>> scheduler = DDPMScheduler.from_config(repo_id)\n>>> scheduler\nDDPMScheduler {\n  \"_class_name\": \"DDPMScheduler\",\n  \"_diffusers_version\": \"0.13.1\",\n  \"beta_end\": 0.02,\n  \"beta_schedule\": \"linear\",\n  \"beta_start\": 0.0001,\n  \"clip_sample\": true,\n  \"clip_sample_range\": 1.0,\n  \"num_train_timesteps\": 1000,\n  \"prediction_type\": \"epsilon\",\n  \"trained_betas\": null,\n  \"variance_type\": \"fixed_small\"\n}\n```\n\n> [!TIP]\n> 💡 Perceba como o agendador é instanciado de uma configuração. Diferentemente de um modelo, um agendador não tem pesos treináveis e é livre de parâmetros!\n\nAlguns dos parâmetros mais importantes são:\n\n- `num_train_timesteps`: a duração do processo de retirada de ruído ou, em outras palavras, o número de _timesteps_ necessários para transformar ruído Gaussiano aleatório em uma amostra de dados.\n- `beta_schedule`: o tipo de agendamento de ruído usado para inferência e treinamento.\n- `beta_start` e `beta_end`: os valores inicial e final de ruído do agendamento de ruído.\n\nPara predizer uma imagem com um pouco menos de ruído, passe o seguinte para o método [`~diffusers.DDPMScheduler.step`] do agendador: a saída do modelo, o `timestep` e a `sample` atual.\n\n```py\n>>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample\n>>> less_noisy_sample.shape\n```\n\nO `less_noisy_sample` pode ser passado para o próximo `timestep`, no qual ele ficará com ainda menos ruído! 
Vamos juntar tudo agora e visualizar o processo inteiro de retirada de ruído.\n\nComece criando uma função que faz o pós-processamento e mostra a imagem sem ruído como uma `PIL.Image`:\n\n```py\n>>> import PIL.Image\n>>> import numpy as np\n\n\n>>> def display_sample(sample, i):\n...     image_processed = sample.cpu().permute(0, 2, 3, 1)\n...     image_processed = (image_processed + 1.0) * 127.5\n...     image_processed = image_processed.numpy().astype(np.uint8)\n\n...     image_pil = PIL.Image.fromarray(image_processed[0])\n...     display(f\"Image at step {i}\")\n...     display(image_pil)\n```\n\nPara acelerar o processo de retirada de ruído, mova a entrada e o modelo para uma GPU:\n\n```py\n>>> model.to(\"cuda\")\n>>> noisy_sample = noisy_sample.to(\"cuda\")\n```\n\nAgora, crie um loop de retirada de ruído que prediz o residual da amostra menos ruidosa e calcula a amostra menos ruidosa com o agendador:\n\n```py\n>>> import tqdm\n\n>>> sample = noisy_sample\n\n>>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):\n...     # 1. predict noise residual\n...     with torch.no_grad():\n...         residual = model(sample, t).sample\n\n...     # 2. compute less noisy image and set x_t -> x_t-1\n...     sample = scheduler.step(residual, t, sample).prev_sample\n\n...     # 3. optionally look at image\n...     if (i + 1) % 50 == 0:\n...         display_sample(sample, i + 1)\n```\n\nSente-se e assista a um gato ser gerado a partir de nada além de ruído! 😻\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/diffusion-quicktour.png\"/>\n</div>\n\n## Próximos passos\n\nEsperamos que você tenha gerado algumas imagens legais com o 🧨 Diffusers neste tour rápido! 
Como próximos passos, você pode:\n\n- Treinar ou fazer o ajuste fino de um modelo para gerar suas próprias imagens no tutorial de [treinamento](./tutorials/basic_training).\n- Ver exemplos oficiais e da comunidade de [scripts de treinamento ou ajuste fino](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples) para os mais variados casos de uso.\n- Aprender como carregar, acessar, mudar e comparar agendadores no guia [Usando diferentes agendadores](./using-diffusers/schedulers).\n- Explorar engenharia de prompt, otimizações de velocidade e memória, e dicas e truques para gerar imagens de maior qualidade com o guia [Stable Diffusion](./stable_diffusion).\n- Se aprofundar em como acelerar o 🧨 Diffusers com guias sobre [PyTorch otimizado em uma GPU](./optimization/fp16), e guias de inferência para rodar [Stable Diffusion em Apple Silicon (M1/M2)](./optimization/mps) e [ONNX Runtime](./optimization/onnx).\n"
  },
  {
    "path": "docs/source/pt/stable_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# Desempenho básico\n\nDifusão é um processo aleatório que demanda muito processamento. Você pode precisar executar o [`DiffusionPipeline`] várias vezes antes de obter o resultado desejado. Por isso é importante equilibrar cuidadosamente a velocidade de geração e o uso de memória para iterar mais rápido.\n\nEste guia recomenda algumas dicas básicas de desempenho para usar o [`DiffusionPipeline`]. 
Consulte a seção de documentação sobre Otimização de Inferência, como [Acelerar inferência](./optimization/fp16) ou [Reduzir uso de memória](./optimization/memory), para guias de desempenho mais detalhados.\n\n## Uso de memória\n\nReduzir a quantidade de memória usada indiretamente acelera a geração e pode ajudar um modelo a caber no dispositivo.\n\nO método [`~DiffusionPipeline.enable_model_cpu_offload`] move um modelo para a CPU quando não está em uso para economizar memória da GPU.\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.bfloat16,\n  device_map=\"cuda\"\n)\npipeline.enable_model_cpu_offload()\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\npipeline(prompt).images[0]\nprint(f\"Memória máxima alocada: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n## Velocidade de inferência\n\nO processo de remoção de ruído é o mais exigente computacionalmente durante a difusão. Métodos que otimizam este processo aceleram a velocidade de inferência. Experimente os seguintes métodos para acelerar.\n\n- Adicione `device_map=\"cuda\"` para colocar o pipeline em uma GPU. Colocar um modelo em um acelerador, como uma GPU, aumenta a velocidade porque realiza computações em paralelo.\n- Defina `torch_dtype=torch.bfloat16` para executar o pipeline em meia-precisão. 
Reduzir a precisão do tipo de dado aumenta a velocidade porque leva menos tempo para realizar computações em precisão mais baixa.\n\n```py\nimport torch\nimport time\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler\n\npipeline = DiffusionPipeline.from_pretrained(\n  \"stabilityai/stable-diffusion-xl-base-1.0\",\n  torch_dtype=torch.bfloat16,\n  device_map=\"cuda\"\n)\n```\n\n- Use um agendador mais rápido, como [`DPMSolverMultistepScheduler`], que requer apenas ~20-25 passos.\n- Defina `num_inference_steps` para um valor menor. Reduzir o número de passos de inferência reduz o número total de computações. No entanto, isso pode resultar em menor qualidade de geração.\n\n```py\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n\nprompt = \"\"\"\ncinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\nhighly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\n\nstart_time = time.perf_counter()\nimage = pipeline(prompt).images[0]\nend_time = time.perf_counter()\n\nprint(f\"Geração de imagem levou {end_time - start_time:.3f} segundos\")\n```\n\n## Qualidade de geração\n\nMuitos modelos de difusão modernos entregam imagens de alta qualidade imediatamente. No entanto, você ainda pode melhorar a qualidade de geração experimentando o seguinte.\n\n- Experimente um prompt mais detalhado e descritivo. Inclua detalhes como o meio da imagem, assunto, estilo e estética. 
Um prompt negativo também pode ajudar, guiando um modelo para longe de características indesejáveis usando palavras como baixa qualidade ou desfocado.\n\n    ```py\n    import torch\n    from diffusers import DiffusionPipeline\n\n    pipeline = DiffusionPipeline.from_pretrained(\n        \"stabilityai/stable-diffusion-xl-base-1.0\",\n        torch_dtype=torch.bfloat16,\n        device_map=\"cuda\"\n    )\n\n    prompt = \"\"\"\n    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\n    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n    \"\"\"\n    negative_prompt = \"low quality, blurry, ugly, poor details\"\n    pipeline(prompt, negative_prompt=negative_prompt).images[0]\n    ```\n\n    Para mais detalhes sobre como criar prompts melhores, consulte a documentação sobre [Técnicas de prompt](./using-diffusers/weighted_prompts).\n\n- Experimente um agendador diferente, como [`HeunDiscreteScheduler`] ou [`LMSDiscreteScheduler`], que sacrifica velocidade de geração por qualidade.\n\n    ```py\n    import torch\n    from diffusers import DiffusionPipeline, HeunDiscreteScheduler\n\n    pipeline = DiffusionPipeline.from_pretrained(\n        \"stabilityai/stable-diffusion-xl-base-1.0\",\n        torch_dtype=torch.bfloat16,\n        device_map=\"cuda\"\n    )\n    pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)\n\n    prompt = \"\"\"\n    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\n    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n    \"\"\"\n    negative_prompt = \"low quality, blurry, ugly, poor details\"\n    pipeline(prompt, negative_prompt=negative_prompt).images[0]\n    ```\n\n## Próximos passos\n\nDiffusers oferece otimizações mais avançadas e poderosas, como [group-offloading](./optimization/memory#group-offloading) e [compilação 
regional](./optimization/fp16#regional-compilation). Para saber mais sobre como maximizar o desempenho, consulte a seção sobre Otimização de Inferência.\n"
  },
  {
    "path": "docs/source/zh/_toctree.yml",
    "content": "- title: 开始Diffusers\n  sections:\n  - local: index\n    title: Diffusers\n  - local: installation\n    title: 安装\n  - local: quicktour\n    title: 快速入门\n  - local: stable_diffusion\n    title: 有效和高效的扩散\n\n- title: DiffusionPipeline\n  isExpanded: false\n  sections:\n  - local: using-diffusers/schedulers\n    title: Load schedulers and models\n  - local: using-diffusers/guiders\n    title: Guiders\n\n- title: Inference\n  isExpanded: false\n  sections:\n  - local: training/distributed_inference\n    title: Distributed inference\n\n- title: Inference optimization\n  isExpanded: false\n  sections:\n  - local: optimization/fp16\n    title: Accelerate inference\n  - local: optimization/cache\n    title: Caching\n  - local: optimization/memory\n    title: Reduce memory usage\n  - local: optimization/speed-memory-optims\n    title: Compile and offloading quantized models\n  - title: Community optimizations\n    sections:\n    - local: optimization/pruna\n      title: Pruna\n    - local: optimization/xformers\n      title: xFormers\n    - local: optimization/tome\n      title: Token merging\n    - local: optimization/deepcache\n      title: DeepCache\n    - local: optimization/tgate\n      title: TGATE\n    - local: optimization/xdit\n      title: xDiT\n    - local: optimization/para_attn\n      title: ParaAttention\n\n- title: Hybrid Inference\n  isExpanded: false\n  sections:\n  - local: hybrid_inference/overview\n    title: Overview\n  - local: hybrid_inference/vae_encode\n    title: VAE Encode\n  - local: hybrid_inference/api_reference\n    title: API Reference\n\n- title: Modular Diffusers\n  isExpanded: false\n  sections:\n  - local: modular_diffusers/overview\n    title: Overview\n  - local: modular_diffusers/quickstart\n    title: Quickstart\n  - local: modular_diffusers/modular_diffusers_states\n    title: States\n  - local: modular_diffusers/pipeline_block\n    title: ModularPipelineBlocks\n  - local: 
modular_diffusers/sequential_pipeline_blocks\n    title: SequentialPipelineBlocks\n  - local: modular_diffusers/loop_sequential_pipeline_blocks\n    title: LoopSequentialPipelineBlocks\n  - local: modular_diffusers/auto_pipeline_blocks\n    title: AutoPipelineBlocks\n  - local: modular_diffusers/modular_pipeline\n    title: ModularPipeline\n  - local: modular_diffusers/components_manager\n    title: ComponentsManager\n\n- title: Training\n  isExpanded: false\n  sections:\n  - local: training/overview\n    title: Overview\n  - local: training/adapt_a_model\n    title: Adapt a model to a new task\n  - title: Models\n    sections:\n    - local: training/text2image\n      title: Text-to-image\n    - local: training/kandinsky\n      title: Kandinsky 2.2\n    - local: training/wuerstchen\n      title: Wuerstchen\n    - local: training/controlnet\n      title: ControlNet\n    - local: training/instructpix2pix\n      title: InstructPix2Pix\n  - title: Methods\n    sections:\n    - local: training/text_inversion\n      title: Textual Inversion\n    - local: training/dreambooth\n      title: DreamBooth\n    - local: training/lora\n      title: LoRA\n\n- title: Model accelerators and hardware\n  isExpanded: false\n  sections:\n  - local: optimization/onnx\n    title: ONNX\n  - local: optimization/open_vino\n    title: OpenVINO\n  - local: optimization/coreml\n    title: Core ML\n  - local: optimization/mps\n    title: Metal Performance Shaders (MPS)\n  - local: optimization/habana\n    title: Intel Gaudi\n  - local: optimization/neuron\n    title: AWS Neuron\n\n- title: Specific pipeline examples\n  isExpanded: false\n  sections:\n  - local: using-diffusers/consisid\n    title: ConsisID\n  - local: using-diffusers/helios\n    title: Helios\n\n- title: Resources\n  isExpanded: false\n  sections:\n  - title: Task recipes\n    sections:\n    - local: community_projects\n      title: Projects built with Diffusers\n    - local: conceptual/philosophy\n      title: Philosophy\n    - 
local: conceptual/contribution\n      title: How to contribute?\n    - local: conceptual/ethical_guidelines\n      title: Diffusers' Ethical Guidelines\n    - local: conceptual/evaluation\n      title: Evaluating Diffusion Models\n"
  },
  {
    "path": "docs/source/zh/community_projects.md",
    "content": "<!--版权 2025 The HuggingFace Team。保留所有权利。\n\n根据Apache许可证，版本2.0（\"许可证\"）授权；除非符合许可证，否则不得使用此文件。您可以在\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n获取许可证的副本。\n\n除非适用法律要求或书面同意，根据许可证分发的软件是按\"原样\"分发的，没有任何形式的明示或暗示的担保或条件。有关许可证的特定语言，请参阅许可证。\n-->\n\n# 社区项目\n\n欢迎来到社区项目。这个空间致力于展示我们充满活力的社区使用`diffusers`库创建的令人难以置信的工作和创新应用。\n\n本节旨在：\n\n- 突出使用`diffusers`构建的多样化和鼓舞人心的项目\n- 促进我们社区内的知识共享\n- 提供如何利用`diffusers`的实际例子\n\n探索愉快，感谢您成为Diffusers社区的一部分！\n\n<table>\n    <tr>\n        <th>项目名称</th>\n        <th>描述</th>\n    </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/PKU-YuanGroup/Helios\"> helios </a></td>\n    <td>Helios：比1.3B更低开销、更快且更强的14B的实时长视频生成模型</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/PKU-YuanGroup/ConsisID\"> consisid </a></td>\n    <td>ConsisID：零样本身份保持的文本到视频生成模型</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/carson-katri/dream-textures\"> dream-textures </a></td>\n    <td>Stable Diffusion内置到Blender</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/megvii-research/HiDiffusion\"> HiDiffusion </a></td>\n    <td>仅通过添加一行代码即可提高扩散模型的分辨率和速度</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/lllyasviel/IC-Light\"> IC-Light </a></td>\n    <td>IC-Light是一个用于操作图像照明的项目</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/InstantID/InstantID\"> InstantID </a></td>\n    <td>InstantID：零样本身份保留生成在几秒钟内</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/Sanster/IOPaint\"> IOPaint </a></td>\n    <td>由SOTA AI模型驱动的图像修复工具。从您的图片中移除任何不需要的物体、缺陷、人物，或擦除并替换（由stable_diffusion驱动）图片上的任何内容。</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/bmaltais/kohya_ss\"> Kohya </a></td>\n    <td>Kohya的Stable Diffusion训练器的Gradio 
GUI</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/magic-research/magic-animate\"> MagicAnimate </a></td>\n    <td>MagicAnimate：使用扩散模型进行时间一致的人体图像动画</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/levihsu/OOTDiffusion\"> OOTDiffusion </a></td>\n    <td>基于潜在扩散的虚拟试穿控制</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/vladmandic/automatic\"> SD.Next </a></td>\n    <td>SD.Next: Stable Diffusion 和其他基于Diffusion的生成图像模型的高级实现</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/ashawkey/stable-dreamfusion\"> stable-dreamfusion </a></td>\n    <td>使用 NeRF + Diffusion 进行文本到3D & 图像到3D & 网格导出</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/HVision-NKU/StoryDiffusion\"> StoryDiffusion </a></td>\n    <td>StoryDiffusion 可以通过生成一致的图像和视频来创造一个神奇的故事。</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/cumulo-autumn/StreamDiffusion\"> StreamDiffusion </a></td>\n    <td>实时交互生成的管道级解决方案</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/Netwrck/stable-diffusion-server\"> Stable Diffusion Server </a></td>\n    <td>配置用于使用一个 stable diffusion 模型进行修复/生成/img2img 的服务器</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/suzukimain/auto_diffusers\"> Model Search </a></td>\n    <td>在 Civitai 和 Hugging Face 上搜索模型</td>\n  </tr>\n  <tr style=\"border-top: 2px solid black\">\n    <td><a href=\"https://github.com/beinsezii/skrample\"> Skrample </a></td>\n    <td>完全模块化的调度器功能，具有一流的 diffusers 集成。</td>\n  </tr>\n</table>\n"
  },
  {
    "path": "docs/source/zh/conceptual/contribution.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. 保留所有权利。\n\n根据Apache许可证2.0版（\"许可证\"）授权；除非符合许可证要求，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件均按\"原样\"分发，不附带任何明示或暗示的担保或条件。有关许可证下特定语言规定的权限和限制，请参阅许可证。\n-->\n\n# 如何为Diffusers 🧨做贡献\n\n我们❤️来自开源社区的贡献！欢迎所有人参与，所有类型的贡献——不仅仅是代码——都受到重视和赞赏。回答问题、帮助他人、主动交流以及改进文档对社区都极具价值，所以如果您愿意参与，请不要犹豫！\n\n我们鼓励每个人先在公开Discord频道里打招呼👋。在那里我们讨论扩散模型的最新趋势、提出问题、展示个人项目、互相协助贡献，或者只是闲聊☕。<a href=\"https://Discord.gg/G7tWnz98XR\"><img alt=\"加入Discord社区\" src=\"https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white\"></a>\n\n无论您选择以何种方式贡献，我们都致力于成为一个开放、友好、善良的社区。请阅读我们的[行为准则](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md)，并在互动时注意遵守。我们也建议您了解指导本项目的[伦理准则](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines)，并请您遵循同样的透明度和责任原则。\n\n我们高度重视社区的反馈，所以如果您认为自己有能帮助改进库的有价值反馈，请不要犹豫说出来——每条消息、评论、issue和拉取请求（PR）都会被阅读和考虑。\n\n## 概述\n\n您可以通过多种方式做出贡献，从在issue和讨论区回答问题，到向核心库添加新的diffusion模型。\n\n下面我们按难度升序列出不同的贡献方式，所有方式对社区都很有价值：\n\n* 1. 在[Diffusers讨论论坛](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers)或[Discord](https://discord.gg/G7tWnz98XR)上提问和回答问题\n* 2. 在[GitHub Issues标签页](https://github.com/huggingface/diffusers/issues/new/choose)提交新issue，或在[GitHub Discussions标签页](https://github.com/huggingface/diffusers/discussions/new/choose)发起新讨论\n* 3. 在[GitHub Issues标签页](https://github.com/huggingface/diffusers/issues)解答issue，或在[GitHub Discussions标签页](https://github.com/huggingface/diffusers/discussions)参与讨论\n* 4. 解决标记为\"Good first issue\"的简单问题，详见[此处](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n* 5. 参与[文档](https://github.com/huggingface/diffusers/tree/main/docs/source)建设\n* 6. 贡献[社区Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples)\n* 7. 完善[示例代码](https://github.com/huggingface/diffusers/tree/main/examples)\n* 8. 
解决标记为\"Good second issue\"的中等难度问题，详见[此处](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22)\n* 9. 添加新pipeline/模型/调度器，参见[\"New Pipeline/Model\"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)和[\"New scheduler\"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)类issue。此类贡献请先阅读[设计哲学](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md)\n\n重申：**所有贡献对社区都具有重要价值。**下文将详细说明各类贡献方式。\n\n对于4-9类贡献，您需要提交PR（拉取请求），具体操作详见[如何提交PR](#how-to-open-a-pr)章节。\n\n### 1. 在Diffusers讨论区或Discord提问与解答\n\n任何与Diffusers库相关的问题或讨论都可以发布在[官方论坛](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/)或[Discord频道](https://discord.gg/G7tWnz98XR)，包括但不限于：\n- 分享训练/推理实验报告\n- 展示个人项目\n- 咨询非官方训练示例\n- 项目提案\n- 通用反馈\n- 论文解读\n- 基于Diffusers库的个人项目求助\n- 一般性问题\n- 关于diffusion模型的伦理讨论\n- ...\n\n论坛/Discord上的每个问题都能促使社区公开分享知识，很可能帮助未来遇到相同问题的初学者。请务必提出您的疑问。\n同样地，通过回答问题您也在为社区创造公共知识文档，这种贡献极具价值。\n\n**请注意**：提问/回答时投入的精力越多，产生的公共知识质量就越高。精心构建的问题与专业解答能形成高质量知识库，而表述不清的问题则可能降低讨论价值。\n\n低质量的问题或回答会降低公共知识库的整体质量。  \n简而言之，高质量的问题或回答应具备*精确性*、*简洁性*、*相关性*、*易于理解*、*可访问性*和*格式规范/表述清晰*等特质。更多详情请参阅[如何提交优质议题](#how-to-write-a-good-issue)章节。\n\n**关于渠道的说明**：  \n[*论坛*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)的内容能被谷歌等搜索引擎更好地收录，且帖子按热度而非时间排序，便于查找历史问答。此外，论坛内容更容易被直接链接引用。  \n而*Discord*采用即时聊天模式，适合快速交流。虽然在Discord上可能更快获得解答，但信息会随时间淹没，且难以回溯历史讨论。因此我们强烈建议在论坛发布优质问答，以构建可持续的社区知识库。若Discord讨论产生有价值结论，建议将成果整理发布至论坛以惠及更多读者。\n\n### 2. 
在GitHub议题页提交新议题\n\n🧨 Diffusers库的稳健性离不开用户的问题反馈，感谢您的报错。\n\n请注意：GitHub议题仅限处理与Diffusers库代码直接相关的技术问题、错误报告、功能请求或库设计反馈。  \n简言之，**与Diffusers库代码（含文档）无关**的内容应发布至[论坛](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)或[Discord](https://discord.gg/G7tWnz98XR)。\n\n**提交新议题时请遵循以下准则**：\n- 确认是否已有类似议题（使用GitHub议题页的搜索栏）\n- 请勿在现有议题下追加新问题。若存在高度关联议题，应新建议题并添加相关链接\n- 确保使用英文提交。非英语用户可通过[DeepL](https://www.deepl.com/translator)等免费工具翻译\n- 检查升级至最新Diffusers版本是否能解决问题。提交前请确认`python -c \"import diffusers; print(diffusers.__version__)\"`显示的版本号不低于最新版本\n- 请记住，你在提交新issue时投入的精力越多，得到的回答质量就越高，Diffusers项目的整体issue质量也会越好。\n\n新issue通常包含以下内容：\n\n#### 2.1 可复现的最小化错误报告\n\n错误报告应始终包含可复现的代码片段，并尽可能简洁明了。具体而言：\n- 尽量缩小问题范围，**不要直接粘贴整个代码文件**\n- 规范代码格式\n- 除Diffusers依赖库外，不要包含其他外部库\n- **务必**提供环境信息：可在终端运行`diffusers-cli env`命令，然后将显示的信息复制到issue中\n- 详细说明问题。如果读者不清楚问题所在及其影响，就无法解决问题\n- **确保**读者能以最小成本复现问题。如果代码片段因缺少库或未定义变量而无法运行，读者将无法提供帮助。请确保提供的可复现代码尽可能精简，可直接复制到Python shell运行\n- 如需特定模型/数据集复现问题，请确保读者能获取这些资源。可将模型/数据集上传至[Hub](https://huggingface.co)便于下载。尽量保持模型和数据集体积最小化，降低复现难度\n\n更多信息请参阅[如何撰写优质issue](#how-to-write-a-good-issue)章节。\n\n提交错误报告请点击[此处](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml)。\n\n#### 2.2 功能请求\n\n优质的功能请求应包含以下要素：\n\n1. 首先说明动机：\n* 是否与库的使用痛点相关？若是，请解释原因，最好提供演示问题的代码片段\n* 是否因项目需求产生？我们很乐意了解详情！\n* 是否是你已实现且认为对社区有价值的功能？请说明它为你解决了什么问题\n2. 用**完整段落**描述功能特性\n3. 提供**代码片段**演示预期用法\n4. 如涉及论文，请附上链接\n5. 
可补充任何有助于理解的辅助材料（示意图、截图等）\n\n提交功能请求请点击[此处](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=)。\n\n#### 2.3 设计反馈\n\n关于库设计的反馈（无论正面还是负面）能极大帮助核心维护者打造更友好的库。要了解当前设计理念，请参阅[此文档](https://huggingface.co/docs/diffusers/conceptual/philosophy)如果您认为某个设计选择与当前理念不符，请说明原因及改进建议。如果某个设计选择因过度遵循理念而限制了使用场景，也请解释原因并提出调整方案。  \n若某个设计对您特别实用，请同样留下备注——这对未来的设计决策极具参考价值。\n\n您可通过[此链接](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=)提交设计反馈。\n\n#### 2.4 技术问题\n\n技术问题主要涉及库代码的实现逻辑或特定功能模块的作用。提问时请务必：  \n- 附上相关代码链接  \n- 详细说明难以理解的具体原因  \n\n技术问题提交入口：[点击此处](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml)\n\n#### 2.5 新模型/调度器/pipeline提案\n\n若diffusion模型社区发布了您希望集成到Diffusers库的新模型、pipeline或调度器，请提供以下信息：  \n* 简要说明并附论文或发布链接  \n* 开源实现链接（如有）  \n* 模型权重下载链接（如已公开）  \n\n若您愿意参与开发，请告知我们以便指导。另请尝试通过GitHub账号标记原始组件作者。  \n\n提案提交地址：[新建请求](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml)\n\n### 3. 解答GitHub问题\n\n回答GitHub问题可能需要Diffusers的技术知识，但我们鼓励所有人尝试参与——即使您对答案不完全正确。高质量回答的建议：  \n- 保持简洁精炼  \n- 严格聚焦问题本身  \n- 提供代码/论文等佐证材料  \n- 优先用代码说话：若代码片段能解决问题，请提供完整可复现代码  \n\n许多问题可能存在离题、重复或无关情况。您可以通过以下方式协助维护者：  \n- 引导提问者精确描述问题  \n- 标记重复issue并附原链接  \n- 推荐用户至[论坛](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)或[Discord](https://discord.gg/G7tWnz98XR)  \n\n在确认提交的Bug报告正确且需要修改源代码后，请继续阅读以下章节内容。\n\n以下所有贡献都需要提交PR（拉取请求）。具体操作步骤详见[如何提交PR](#how-to-open-a-pr)章节。\n\n### 4. 修复\"Good first issue\"类问题\n\n标有[Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)标签的问题通常已说明解决方案建议，便于修复。若该问题尚未关闭且您想尝试解决，只需留言\"我想尝试解决这个问题\"。通常有三种情况：\n- a.) 问题描述已提出解决方案。若您认可该方案，可直接提交PR或草稿PR进行修复\n- b.) 问题描述未提出解决方案。您可询问修复建议，Diffusers团队会尽快回复。若有成熟解决方案，也可直接提交PR\n- c.) 已有PR但问题未关闭。若原PR停滞，可新开PR并关联原PR（开源社区常见现象）。若PR仍活跃，您可通过建议、审查或协作等方式帮助原作者\n\n### 5. 
文档贡献\n\n优秀库**必然**拥有优秀文档！官方文档是新用户的首要接触点，因此文档贡献具有**极高价值**。贡献形式包括：\n- 修正拼写/语法错误\n- 修复文档字符串格式错误（如显示异常或链接失效）\n- 修正文档字符串中张量的形状/维度描述\n- 优化晦涩或错误的说明\n- 更新过时代码示例\n- 文档翻译\n\n[官方文档页面](https://huggingface.co/docs/diffusers/index)所有内容均属可修改范围，对应[文档源文件](https://github.com/huggingface/diffusers/tree/main/docs/source)可进行编辑。修改前请查阅[验证说明](https://github.com/huggingface/diffusers/tree/main/docs)。\n\n### 6. 贡献社区流程\n\n> [!TIP]\n> 阅读[社区流程](../using-diffusers/custom_pipeline_overview#community-pipelines)指南了解GitHub与Hugging Face Hub社区流程的区别。若想了解我们设立社区流程的原因，请查看GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841)（简而言之，我们无法维护diffusion模型所有可能的推理使用方式，但也不希望限制社区构建这些流程）。\n\n贡献社区流程是向社区分享创意与成果的绝佳方式。您可以在[`DiffusionPipeline`]基础上构建流程，任何人都能通过设置`custom_pipeline`参数加载使用。本节将指导您创建一个简单的\"单步\"流程——UNet仅执行单次前向传播并调用调度器一次。\n\n1. 为社区流程创建one_step_unet.py文件。只要用户已安装相关包，该文件可包含任意所需包。确保仅有一个继承自[`DiffusionPipeline`]的流程类，用于从Hub加载模型权重和调度器配置。在`__init__`函数中添加UNet和调度器。\n\n    同时添加`register_modules`函数，确保您的流程及其组件可通过[`~DiffusionPipeline.save_pretrained`]保存。\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nclass UnetSchedulerOneForwardPipeline(DiffusionPipeline):\n    def __init__(self, unet, scheduler):\n        super().__init__()\n\n        self.register_modules(unet=unet, scheduler=scheduler)\n```\n\n2. 
在前向传播中（建议定义为`__call__`），可添加任意功能。对于\"单步\"流程，创建随机图像并通过设置`timestep=1`调用UNet和调度器一次。\n\n```py\n  from diffusers import DiffusionPipeline\n  import torch\n\n  class UnetSchedulerOneForwardPipeline(DiffusionPipeline):\n      def __init__(self, unet, scheduler):\n          super().__init__()\n\n          self.register_modules(unet=unet, scheduler=scheduler)\n\n      def __call__(self):\n          image = torch.randn(\n              (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),\n          )\n          timestep = 1\n\n          model_output = self.unet(image, timestep).sample\n          scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample\n\n          return scheduler_output\n```\n\n现在您可以通过传入UNet和调度器来运行流程，若流程结构相同也可加载预训练权重。\n\n```python\nfrom diffusers import DDPMScheduler, UNet2DModel\n\nscheduler = DDPMScheduler()\nunet = UNet2DModel()\n\npipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)\noutput = pipeline()\n# 加载预训练权重\npipeline = UnetSchedulerOneForwardPipeline.from_pretrained(\"google/ddpm-cifar10-32\", use_safetensors=True)\noutput = pipeline()\n```\n\n您可以选择将pipeline作为GitHub社区pipeline或Hub社区pipeline进行分享。\n\n<hfoptions id=\"pipeline类型\">\n<hfoption id=\"GitHub pipeline\">\n\n通过向Diffusers[代码库](https://github.com/huggingface/diffusers)提交拉取请求来分享GitHub pipeline，将one_step_unet.py文件添加到[examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community)子文件夹中。\n\n</hfoption>\n<hfoption id=\"Hub pipeline\">\n\n通过在Hub上创建模型仓库并上传one_step_unet.py文件来分享Hub pipeline。\n\n</hfoption>\n</hfoptions>\n\n### 7. 
贡献训练示例\n\nDiffusers训练示例是位于[examples](https://github.com/huggingface/diffusers/tree/main/examples)目录下的训练脚本集合。\n\n我们支持两种类型的训练示例：\n\n- 官方训练示例\n- 研究型训练示例\n\n研究型训练示例位于[examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects)，而官方训练示例包含[examples](https://github.com/huggingface/diffusers/tree/main/examples)目录下除`research_projects`和`community`外的所有文件夹。\n官方训练示例由Diffusers核心维护者维护，研究型训练示例则由社区维护。\n这与[6. 贡献社区pipeline](#6-contribute-a-community-pipeline)中关于官方pipeline与社区pipeline的原因相同：核心维护者不可能维护diffusion模型的所有可能训练方法。\n如果Diffusers核心维护者和社区认为某种训练范式过于实验性或不够普及，相应训练代码应放入`research_projects`文件夹并由作者维护。\n\n官方训练和研究型示例都包含一个目录，其中含有一个或多个训练脚本、`requirements.txt`文件和`README.md`文件。用户使用时需要先克隆代码库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\n```\n\n并安装训练所需的所有额外依赖：\n\n```bash\ncd diffusers\npip install -r examples/<your-example-folder>/requirements.txt\n```\n\n因此添加示例时，`requirements.txt`文件应定义训练示例所需的所有pip依赖项，安装完成后用户即可运行示例训练脚本。可参考[DreamBooth的requirements.txt文件](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt)。\n- 运行示例所需的所有代码应集中在单个Python文件中  \n- 用户应能通过命令行`python <your-example>.py --args`直接运行示例  \n- **示例**应保持简洁，主要展示如何使用Diffusers进行训练。示例脚本的目的**不是**创建最先进的diffusion模型，而是复现已知训练方案，避免添加过多自定义逻辑。因此，这些示例也力求成为优质的教学材料。\n\n提交示例时，强烈建议参考现有示例（如[dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)）来了解规范格式。  \n我们强烈建议贡献者使用[Accelerate库](https://github.com/huggingface/accelerate)，因其与Diffusers深度集成。  \n当示例脚本完成后，请确保添加详细的`README.md`说明使用方法，包括：  \n- 运行示例的具体命令（示例参见[此处](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch)）  \n- 训练结果链接（日志/模型等），展示用户可预期的效果（示例参见[此处](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5)）  \n- 若添加非官方/研究性训练示例，**必须注明**维护者信息（含Git账号），格式参照[此处](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations)  
\n\n贡献官方训练示例时，还需在对应目录添加测试文件（如[examples/dreambooth/test_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/test_dreambooth.py)），非官方示例无需此步骤。\n\n### 8. 处理\"Good second issue\"类问题\n\n标有[Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22)标签的问题通常比[Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)更复杂。  \n这类问题的描述通常不会提供详细解决指引，需要贡献者对库有较深理解。  \n若您想解决此类问题，可直接提交PR并关联对应issue。若已有未合并的PR，请分析原因后提交改进版。需注意，Good second issue类PR的合并难度通常高于good first issues。在需要帮助的时候请不要犹豫，大胆的向核心维护者询问。\n\n### 9. 添加管道、模型和调度器\n\n管道（pipelines）、模型（models）和调度器（schedulers）是Diffusers库中最重要的组成部分。它们提供了对最先进diffusion技术的便捷访问，使得社区能够构建强大的生成式AI应用。\n\n通过添加新的模型、管道或调度器，您可能为依赖Diffusers的任何用户界面开启全新的强大用例，这对整个生成式AI生态系统具有巨大价值。\n\nDiffusers针对这三类组件都有一些开放的功能请求——如果您还不确定要添加哪个具体组件，可以浏览以下链接：\n- [模型或管道](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)\n- [调度器](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)\n\n在添加任何组件之前，强烈建议您阅读[设计哲学指南](philosophy)，以更好地理解这三类组件的设计理念。请注意，如果添加的模型、调度器或管道与我们的设计理念存在严重分歧，我们将无法合并，因为这会导致API不一致。如果您从根本上不同意某个设计选择，请改为提交[反馈问题](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=)，以便讨论是否应该更改库中的特定设计模式/选择，以及是否更新我们的设计哲学。保持库内的一致性对我们非常重要。\n\n请确保在PR中添加原始代码库/论文的链接，并最好直接在PR中@原始作者，以便他们可以跟踪进展并在有疑问时提供帮助。\n\n如果您在PR过程中遇到不确定或卡住的情况，请随时留言请求初步审查或帮助。\n\n#### 复制机制（Copied from）\n\n在添加任何管道、模型或调度器代码时，理解`# Copied from`机制是独特且重要的。您会在整个Diffusers代码库中看到这种机制，我们使用它的原因是为了保持代码库易于理解和维护。用`# Copied from`机制标记代码会强制标记的代码与复制来源的代码完全相同。这使得每当您运行`make fix-copies`时，可以轻松更新并将更改传播到多个文件。\n\n例如，在下面的代码示例中，[`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`]是原始代码，而`AltDiffusionPipelineOutput`使用`# Copied from`机制来复制它。唯一的区别是将类前缀从`Stable`改为`Alt`。\n\n```py\n# 从 
diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput 复制并将 Stable 替换为 Alt\nclass AltDiffusionPipelineOutput(BaseOutput):\n    \"\"\"\n    Output class for Alt Diffusion pipelines.\n\n    Args:\n        images (`List[PIL.Image.Image]` or `np.ndarray`)\n            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,\n            num_channels)`.\n        nsfw_content_detected (`List[bool]`)\n            List indicating whether the corresponding generated image contains \"not-safe-for-work\" (nsfw) content or\n            `None` if safety checking could not be performed.\n    \"\"\"\n```\n\n要了解更多信息，请阅读[~不要~重复自己*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static)博客文章的相应部分。\n\n## 如何撰写优质问题\n\n**问题描述越清晰，被快速解决的可能性就越高。**\n\n1. 确保使用了正确的issue模板。您可以选择*错误报告*、*功能请求*、*API设计反馈*、*新模型/流水线/调度器添加*、*论坛*或空白issue。在[新建issue](https://github.com/huggingface/diffusers/issues/new/choose)时务必选择正确的模板。\n2. **精确描述**：为issue起一个恰当的标题。尽量用最简练的语言描述问题。提交issue时越精确，理解问题和潜在解决方案所需的时间就越少。确保一个issue只针对一个问题，不要将多个问题放在同一个issue中。如果发现多个问题，请分别创建多个issue。如果是错误报告，请尽可能精确描述错误类型——不应只写\"diffusers出错\"。\n3. **可复现性**：无法复现的代码片段 == 无法解决问题。如果遇到错误，维护人员必须能够**复现**它。确保包含一个可以复制粘贴到Python解释器中复现问题的代码片段。确保您的代码片段是可运行的，即没有缺少导入或图像链接等问题。issue应包含错误信息和可直接复制粘贴以复现相同错误的代码片段。如果issue涉及本地模型权重或无法被读者访问的本地数据，则问题无法解决。如果无法共享数据或模型，请尝试创建虚拟模型或虚拟数据。\n4. **最小化原则**：通过尽可能简洁的描述帮助读者快速理解问题。删除所有与问题无关的代码/信息。如果发现错误，请创建最简单的代码示例来演示问题，不要一发现错误就把整个工作流程都转储到issue中。例如，如果在训练模型时某个阶段出现错误或训练过程中遇到问题时，应首先尝试理解训练代码的哪部分导致了错误，并用少量代码尝试复现。建议使用模拟数据替代完整数据集进行测试。\n5. 添加引用链接。当提及特定命名、方法或模型时，请务必提供引用链接以便读者理解。若涉及具体PR或issue，请确保添加对应链接。不要假设读者了解你所指内容。issue中引用链接越丰富越好。\n6. 规范格式。请确保规范格式化issue内容：Python代码使用代码语法块，错误信息使用标准代码语法。详见[GitHub官方格式文档](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax)。\n7. 
请将issue视为百科全书的精美词条，而非待解决的工单。每个规范撰写的issue不仅是向维护者有效传递问题的方式，更是帮助社区深入理解库特性的公共知识贡献。\n\n## 优质PR编写规范\n\n1. 保持风格统一。理解现有设计模式和语法规范，确保新增代码与代码库现有结构无缝衔接。显著偏离现有设计模式或用户界面的PR将不予合并。\n2. 聚焦单一问题。每个PR应当只解决一个明确问题，避免\"顺手修复其他问题\"的陷阱。包含多个无关修改的PR会极大增加审查难度。\n3. 如适用，建议添加代码片段演示新增功能的使用方法。\n4. PR标题应准确概括其核心贡献。\n5. 若PR针对某个issue，请在描述中注明issue编号以建立关联（也让关注该issue的用户知晓有人正在处理）；\n6. 进行中的PR请在标题添加`[WIP]`前缀。这既能避免重复劳动，也可与待合并PR明确区分；\n7. 文本表述与格式要求请参照[优质issue编写规范](#how-to-write-a-good-issue)；\n8. 确保现有测试用例全部通过；\n9. 必须添加高覆盖率测试。未经充分测试的代码不予合并。\n- 若新增`@slow`测试，请使用`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`确保通过。\nCircleCI不执行慢速测试，但GitHub Actions会每日夜间运行！\n10. 所有公开方法必须包含格式规范、兼容markdown的说明文档。可参考[`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) \n11. 由于代码库快速增长，必须确保不会添加明显增加仓库体积的文件（如图片、视频等非文本文件）。建议优先使用托管在hf.co的`dataset`（例如[`hf-internal-testing`](https://huggingface.co/hf-internal-testing)或[huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images)）存放这类文件。若为外部贡献，可将图片添加到PR中并请Hugging Face成员将其迁移至该数据集。\n\n## 提交PR流程\n\n编写代码前，强烈建议先搜索现有PR或issue，确认没有重复工作。如有疑问，建议先创建issue获取反馈。\n\n贡献至🧨 Diffusers需要基本的`git`技能。虽然`git`学习曲线较高，但其拥有最完善的手册。在终端输入`git --help`即可查阅，或参考书籍[Pro Git](https://git-scm.com/book/en/v2)。\n\n请按以下步骤操作（[支持的Python版本](https://github.com/huggingface/diffusers/blob/83bc6c94eaeb6f7704a2a428931cf2d9ad973ae9/setup.py#L270)）：\n\n1. 在[仓库页面](https://github.com/huggingface/diffusers)点击\"Fork\"按钮创建代码副本至您的GitHub账户\n\n2. 克隆fork到本地，并添加主仓库为远程源：\n ```bash\n $ git clone git@github.com:<您的GitHub账号>/diffusers.git\n $ cd diffusers\n $ git remote add upstream https://github.com/huggingface/diffusers.git\n ```\n\n3. 创建新分支进行开发：\n ```bash\n $ git checkout -b 您的开发分支名称\n ```\n**禁止**直接在`main`分支上修改\n\n4. 在虚拟环境中运行以下命令配置开发环境：\n ```bash\n $ pip install -e \".[dev]\"\n ```\n若已克隆仓库，可能需要先执行`git pull`获取最新代码\n\n5. 
在您的分支上开发功能\n\n开发过程中应确保测试通过。可运行受影响测试：\n ```bash\n $ pytest tests/<待测文件>.py\n ```\n执行测试前请安装测试依赖：\n ```bash\n $ pip install -e \".[test]\"\n ```\n也可运行完整测试套件（需高性能机器）：\n ```bash\n $ make test\n ```\n\n🧨 Diffusers使用`black`和`isort`工具保持代码风格统一。修改代码后，请运行以下命令来自动修正格式并检查代码：\n\n```bash\n$ make style\n```\n\n🧨 Diffusers 还使用 `ruff` 和一些自定义脚本来检查代码错误。虽然质量控制流程会在 CI 中运行，但您也可以通过以下命令手动执行相同的检查：\n\n```bash\n$ make quality\n```\n\n当您对修改满意后，使用 `git add` 添加更改的文件，并通过 `git commit` 在本地记录这些更改：\n\n```bash\n$ git add modified_file.py\n$ git commit -m \"关于您所做更改的描述性信息。\"\n```\n\n定期将您的代码副本与原始仓库同步是一个好习惯。这样可以快速适应上游变更：\n\n```bash\n$ git pull upstream main\n```\n\n使用以下命令将更改推送到您的账户：\n\n```bash\n$ git push -u origin 此处替换为您的描述性分支名称\n```\n\n6. 确认无误后，请访问您 GitHub 账户中的派生仓库页面。点击「Pull request」将您的更改提交给项目维护者审核。\n\n7. 如果维护者要求修改，这很正常——核心贡献者也会遇到这种情况！为了让所有人能在 Pull request 中看到变更，请在本地分支继续工作并将修改推送到您的派生仓库，这些变更会自动出现在 Pull request 中。\n\n### 测试\n\n我们提供了全面的测试套件来验证库行为和多个示例。库测试位于 [tests 文件夹](https://github.com/huggingface/diffusers/tree/main/tests)。\n\n我们推荐使用 `pytest` 和 `pytest-xdist`，因为它们速度更快。在仓库根目录下运行以下命令执行库测试：\n\n```bash\n$ python -m pytest -n auto --dist=loadfile -s -v ./tests/\n```\n\n实际上，这就是 `make test` 的实现方式！\n\n您可以指定更小的测试范围来仅验证您正在开发的功能。\n\n默认情况下会跳过耗时测试。设置 `RUN_SLOW` 环境变量为 `yes` 可运行这些测试。注意：这将下载数十 GB 的模型文件——请确保您有足够的磁盘空间、良好的网络连接或充足的耐心！\n\n```bash\n$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/\n```\n\n我们也完全支持 `unittest`，运行方式如下：\n\n```bash\n$ python -m unittest discover -s tests -t . -v\n$ python -m unittest discover -s examples -t examples -v\n```\n\n### 将派生仓库的 main 分支与上游（HuggingFace）main 分支同步\n\n为避免向上游仓库发送引用通知（这会给相关 PR 添加注释并向开发者发送不必要的通知），在同步派生仓库的 main 分支时，请遵循以下步骤：\n1. 尽可能避免通过派生仓库的分支和 PR 来同步上游，而是直接合并到派生仓库的 main 分支\n2. 
如果必须使用 PR，请在检出分支后执行以下操作：\n```bash\n$ git checkout -b 您的同步分支名称\n$ git pull --squash --no-commit upstream main\n$ git commit -m '提交信息（不要包含 GitHub 引用）'\n$ git push --set-upstream origin 您的分支名称\n```\n\n### 风格指南\n\n对于文档字符串，🧨 Diffusers 遵循 [Google 风格指南](https://google.github.io/styleguide/pyguide.html)。\n"
  },
  {
    "path": "docs/source/zh/conceptual/ethical_guidelines.md",
    "content": "<!--版权归2025年HuggingFace团队所有。保留所有权利。\n\n根据Apache许可证2.0版（\"许可证\"）授权；除非符合许可证要求，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，本软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。详见许可证中规定的特定语言权限和限制。\n-->\n\n# 🧨 Diffusers伦理准则\n\n## 前言\n\n[Diffusers](https://huggingface.co/docs/diffusers/index)不仅提供预训练的diffusion模型，还是一个模块化工具箱，支持推理和训练功能。\n\n鉴于该技术在实际场景中的应用及其可能对社会产生的负面影响，我们认为有必要制定项目伦理准则，以指导Diffusers库的开发、用户贡献和使用规范。\n\n该技术涉及的风险仍在持续评估中，主要包括但不限于：艺术家版权问题、深度伪造滥用、不当情境下的色情内容生成、非自愿的人物模仿、以及加剧边缘群体压迫的有害社会偏见。我们将持续追踪风险，并根据社区反馈动态调整本准则。\n\n## 适用范围\n\nDiffusers社区将在项目开发中贯彻以下伦理准则，并协调社区贡献的整合方式，特别是在涉及伦理敏感议题的技术决策时。\n\n## 伦理准则\n\n以下准则具有普遍适用性，但我们主要在处理涉及伦理敏感问题的技术决策时实施。同时，我们承诺将根据技术发展带来的新兴风险持续调整这些原则：\n\n- **透明度**：我们承诺以透明方式管理PR（拉取请求），向用户解释决策依据，并公开技术选择过程。\n\n- **一致性**：我们承诺为用户提供统一标准的项目管理，保持技术稳定性和连贯性。\n\n- **简洁性**：为了让Diffusers库更易使用和开发，我们承诺保持项目目标精简且逻辑自洽。\n\n- **可及性**：本项目致力于降低贡献门槛，即使非技术人员也能参与运营，从而使研究资源更广泛地服务于社区。\n\n- **可复现性**：对于通过Diffusers库发布的上游代码、模型和数据集，我们将明确说明其可复现性。\n\n- **责任性**：作为社区和团队，我们共同承担用户责任，通过风险预判和缓解措施来应对技术潜在危害。\n\n## 实施案例：安全功能与机制\n\n团队持续开发技术和非技术工具，以应对diffusion技术相关的伦理与社会风险。社区反馈对于功能实施和风险意识提升具有不可替代的价值：\n\n- [**社区讨论区**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions)：促进社区成员就项目开展协作讨论。\n\n- **偏见探索与评估**：Hugging Face团队提供[交互空间](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)展示Stable Diffusion中的偏见。我们支持并鼓励此类偏见探索与评估工作。\n\n- **部署安全强化**：\n  \n  - [**Safe Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe)：解决Stable Diffusion等基于未过滤网络爬取数据训练的模型容易产生不当内容的问题。相关论文：[Safe Latent Diffusion：缓解diffusion模型中的不当退化](https://huggingface.co/papers/2211.05105)。\n\n  - [**安全检测器**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py)：通过比对图像生成后嵌入空间中硬编码有害概念集的类别概率进行检测。有害概念列表经特殊处理以防逆向工程。\n\n- **分阶段模型发布**：对于高度敏感的仓库，采用分级访问控制。这种阶段性发布机制让作者能更好地管控使用场景。\n\n- 
**许可证制度**：采用新型[OpenRAILs](https://huggingface.co/blog/open_rail)许可协议，在保障开放访问的同时设置使用限制以确保更负责任的应用。\n"
  },
  {
    "path": "docs/source/zh/conceptual/evaluation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\n根据 Apache License 2.0 版本（\"许可证\"）授权，除非符合许可证要求，否则不得使用本文件。\n您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，本软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。详见许可证中规定的特定语言权限和限制。\n-->\n\n# Diffusion模型评估指南\n\n<a target=\"_blank\" href=\"https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/evaluation.ipynb\">\n    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"在 Colab 中打开\"/>\n</a>\n\n> [!TIP]\n> 鉴于当前已出现针对图像生成Diffusion模型的成熟评估框架（如[HEIM](https://crfm.stanford.edu/helm/heim/latest/)、[T2I-Compbench](https://huggingface.co/papers/2307.06350)、[GenEval](https://huggingface.co/papers/2310.11513)），本文档部分内容已过时。\n\n像 [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) 这类生成模型的评估本质上是主观的。但作为开发者和研究者，我们经常需要在众多可能性中做出审慎选择。那么当面对不同生成模型（如 GANs、Diffusion 等）时，该如何决策？\n\n定性评估容易产生偏差，可能导致错误结论；而定量指标又未必能准确反映图像质量。因此，通常需要结合定性与定量评估来获得更可靠的模型选择依据。\n\n本文档将系统介绍扩散模型的定性与定量评估方法（非穷尽列举）。对于定量方法，我们将重点演示如何结合 `diffusers` 库实现这些评估。\n\n文档所示方法同样适用于评估不同[噪声调度器](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview)在固定生成模型下的表现差异。\n\n## 评估场景\n\n我们涵盖以下Diffusion模型管线的评估：\n\n- 文本引导图像生成（如 [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img)）\n- 基于文本和输入图像的引导生成（如 [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/img2img) 和 [`StableDiffusionInstructPix2PixPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix)）\n- 类别条件图像生成模型(如 [`DiTPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipe))\n\n## 定性评估\n\n定性评估通常涉及对生成图像的人工评判。评估维度包括构图质量、图文对齐度和空间关系等方面。标准化的提示词能为这些主观指标提供统一基准。DrawBench和PartiPrompts是常用的定性评估提示词数据集，分别由[Imagen](https://imagen.research.google/)和[Parti](https://parti.research.google/)团队提出。\n\n根据[Parti官方网站](https://parti.research.google/)说明：\n\n> PartiPrompts 
(P2)是我们发布的包含1600多个英文提示词的丰富集合，可用于测量模型在不同类别和挑战维度上的能力。\n\n![parti-prompts](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts.png)\n\nPartiPrompts包含以下字段：\n- Prompt（提示词）\n- Category（类别，如\"抽象\"、\"世界知识\"等）\n- Challenge（难度等级，如\"基础\"、\"复杂\"、\"文字与符号\"等）\n\n这些基准测试支持对不同图像生成模型进行并排人工对比评估。为此，🧨 Diffusers团队构建了**Open Parti Prompts**——一个基于Parti Prompts的社区驱动型定性评估基准，用于比较顶尖开源diffusion模型：\n- [Open Parti Prompts游戏](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts)：展示10个parti提示词对应的4张生成图像，用户选择最符合提示的图片\n- [Open Parti Prompts排行榜](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard)：对比当前最优开源diffusion模型的性能榜单\n\n为进行手动图像对比，我们演示如何使用`diffusers`处理部分PartiPrompts提示词。\n\n以下是从不同挑战维度（基础、复杂、语言结构、想象力、文字与符号）采样的提示词示例（使用[PartiPrompts作为数据集](https://huggingface.co/datasets/nateraw/parti-prompts)）：\n\n```python\nfrom datasets import load_dataset\n\n# prompts = load_dataset(\"nateraw/parti-prompts\", split=\"train\")\n# prompts = prompts.shuffle()\n# sample_prompts = [prompts[i][\"Prompt\"] for i in range(5)]\n\n# Fixing these sample prompts in the interest of reproducibility.\nsample_prompts = [\n    \"a corgi\",\n    \"a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky\",\n    \"a car with no windows\",\n    \"a cube made of porcupine\",\n    'The saying \"BE EXCELLENT TO EACH OTHER\" written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. 
A yellow fire hydrant is on a sidewalk in the foreground.',\n]\n```\n\n现在我们可以使用Stable Diffusion（[v1-4 checkpoint](https://huggingface.co/CompVis/stable-diffusion-v1-4)）生成这些提示词对应的图像：\n\n```python\nimport torch\n\nseed = 0\ngenerator = torch.manual_seed(seed)\n\nimages = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generator).images\n```\n\n![parti-prompts-14](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-14.png)\n\n我们也可以通过设置`num_images_per_prompt`参数来比较同一提示词生成的不同图像。使用不同检查点([v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5))运行相同流程后，结果如下：\n\n![parti-prompts-15](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-15.png)\n\n当使用多个待评估模型为所有提示词生成若干图像后，这些结果将提交给人类评估员进行打分。有关DrawBench和PartiPrompts基准测试的更多细节，请参阅各自的论文。\n\n> [!TIP]\n> 在模型训练过程中查看推理样本有助于评估训练进度。我们的[训练脚本](https://github.com/huggingface/diffusers/tree/main/examples/)支持此功能，并额外提供TensorBoard和Weights & Biases日志记录功能。\n\n## 定量评估\n\n本节将指导您如何评估三种不同的扩散流程，使用以下指标：\n- CLIP分数\n- CLIP方向相似度\n- FID（弗雷歇起始距离）\n\n### 文本引导图像生成\n\n[CLIP分数](https://huggingface.co/papers/2104.08718)用于衡量图像-标题对的匹配程度。CLIP分数越高表明匹配度越高🔼。该分数是对\"匹配度\"这一定性概念的量化测量，也可以理解为图像与标题之间的语义相似度。研究发现CLIP分数与人类判断具有高度相关性。\n\n首先加载[`StableDiffusionPipeline`]：\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_ckpt = \"CompVis/stable-diffusion-v1-4\"\nsd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to(\"cuda\")\n```\n\n使用多个提示词生成图像：\n\n```python\nprompts = [\n    \"a photo of an astronaut riding a horse on mars\",\n    \"A high tech solarpunk utopia in the Amazon rainforest\",\n    \"A pikachu fine dining with a view to the Eiffel Tower\",\n    \"A mecha robot in a favela in expressionist style\",\n    \"an insect robot preparing a delicious meal\",\n    \"A small cabin on top of a snowy mountain in the style of Disney, artstation\",\n]\n\nimages = 
sd_pipeline(prompts, num_images_per_prompt=1, output_type=\"np\").images\n\nprint(images.shape)\n# (6, 512, 512, 3)\n```\n\n然后计算CLIP分数：\n\n```python\nfrom torchmetrics.functional.multimodal import clip_score\nfrom functools import partial\n\nclip_score_fn = partial(clip_score, model_name_or_path=\"openai/clip-vit-base-patch16\")\n\ndef calculate_clip_score(images, prompts):\n    images_int = (images * 255).astype(\"uint8\")\n    clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()\n    return round(float(clip_score), 4)\n\nsd_clip_score = calculate_clip_score(images, prompts)\nprint(f\"CLIP分数: {sd_clip_score}\")\n# CLIP分数: 35.7038\n```\n\n上述示例中，我们为每个提示生成一张图像。如果为每个提示生成多张图像，则需要计算每个提示生成图像的平均分数。\n\n当需要比较两个兼容[`StableDiffusionPipeline`]的检查点时，应在调用管道时传入生成器。首先使用[v1-4 Stable Diffusion检查点](https://huggingface.co/CompVis/stable-diffusion-v1-4)以固定种子生成图像：\n\n```python\nseed = 0\ngenerator = torch.manual_seed(seed)\n\nimages = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type=\"np\").images\n```\n\n然后加载[v1-5检查点](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)生成图像：\n\n```python\nmodel_ckpt_1_5 = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nsd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to(\"cuda\")\n\nimages_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type=\"np\").images\n```\n\n最后比较两者的CLIP分数：\n\n```python\nsd_clip_score_1_4 = calculate_clip_score(images, prompts)\nprint(f\"v-1-4版本的CLIP分数: {sd_clip_score_1_4}\")\n# v-1-4版本的CLIP分数: 34.9102\n\nsd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts)\nprint(f\"v-1-5版本的CLIP分数: {sd_clip_score_1_5}\")\n# v-1-5版本的CLIP分数: 36.2137\n```\n\n结果表明[v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)检查点性能优于前代。但需注意，我们用于计算CLIP分数的提示词数量较少。实际评估时应使用更多样化且数量更大的提示词集。\n\n> [!WARNING]\n> 
该分数存在固有局限性：训练数据中的标题是从网络爬取，并提取自图片关联的`alt`等标签。这些描述未必符合人类描述图像的方式，因此我们需要人工\"设计\"部分提示词。\n\n### 图像条件式文本生成图像\n\n这种情况下，生成管道同时接受输入图像和文本提示作为条件。以[`StableDiffusionInstructPix2PixPipeline`]为例，该管道接收编辑指令作为输入提示，并接受待编辑的输入图像。\n\n示例图示：\n\n![编辑指令](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png)\n\n评估此类模型的策略之一是测量两幅图像之间的变化（在[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)空间中）与两段图像描述之间的变化是否一致（如论文[《CLIP-Guided Domain Adaptation of Image Generators》](https://huggingface.co/papers/2108.00946)所示）。这被称为“**CLIP方向相似度**”。  \n\n- **描述1**对应输入图像（图像1），即待编辑的图像。  \n- **描述2**对应编辑后的图像（图像2），应反映编辑指令。  \n\n以下是示意图：  \n\n![edit-consistency](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-consistency.png)  \n\n我们准备了一个小型数据集来实现该指标。首先加载数据集：  \n\n```python\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"sayakpaul/instructpix2pix-demo\", split=\"train\")\ndataset.features\n```  \n\n```bash\n{'input': Value(dtype='string', id=None),\n 'edit': Value(dtype='string', id=None),\n 'output': Value(dtype='string', id=None),\n 'image': Image(decode=True, id=None)}\n```  \n\n数据字段说明：  \n\n- `input`：与`image`对应的原始描述。  \n- `edit`：编辑指令。  \n- `output`：反映`edit`指令的修改后描述。  \n\n查看一个样本：  \n\n```python\nidx = 0\nprint(f\"Original caption: {dataset[idx]['input']}\")\nprint(f\"Edit instruction: {dataset[idx]['edit']}\")\nprint(f\"Modified caption: {dataset[idx]['output']}\")\n```  \n\n```bash\nOriginal caption: 2. FAROE ISLANDS: An archipelago of 18 mountainous isles in the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'\nEdit instruction: make the isles all white marble\nModified caption: 2. 
WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles in the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'\n```  \n\n对应的图像：  \n\n```python\ndataset[idx][\"image\"]\n```  \n\n![edit-dataset](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-dataset.png)  \n\n我们将根据编辑指令修改数据集中的图像，并计算方向相似度。  \n\n首先加载[`StableDiffusionInstructPix2PixPipeline`]：  \n\n```python\nfrom diffusers import StableDiffusionInstructPix2PixPipeline\n\ninstruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\n    \"timbrooks/instruct-pix2pix\", torch_dtype=torch.float16\n).to(\"cuda\")\n```  \n\n执行编辑操作：  \n\n```python\nimport numpy as np\n\n\ndef edit_image(input_image, instruction):\n    image = instruct_pix2pix_pipeline(\n        instruction,\n        image=input_image,\n        output_type=\"np\",\n        generator=generator,\n    ).images[0]\n    return image\n\ninput_images = []\noriginal_captions = []\nmodified_captions = []\nedited_images = []\n\nfor idx in range(len(dataset)):\n    input_image = dataset[idx][\"image\"]\n    edit_instruction = dataset[idx][\"edit\"]\n    edited_image = edit_image(input_image, edit_instruction)\n\n    input_images.append(np.array(input_image))\n    original_captions.append(dataset[idx][\"input\"])\n    modified_captions.append(dataset[idx][\"output\"])\n    edited_images.append(edited_image)\n```\n\n为测量方向相似度，我们首先加载CLIP的图像和文本编码器：\n\n```python\nfrom transformers import (\n    CLIPTokenizer,\n    CLIPTextModelWithProjection,\n    CLIPVisionModelWithProjection,\n    CLIPImageProcessor,\n)\n\nclip_id = \"openai/clip-vit-large-patch14\"\ntokenizer = CLIPTokenizer.from_pretrained(clip_id)\ntext_encoder = 
CLIPTextModelWithProjection.from_pretrained(clip_id).to(\"cuda\")\nimage_processor = CLIPImageProcessor.from_pretrained(clip_id)\nimage_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(\"cuda\")\n```\n\n注意我们使用的是特定CLIP检查点——`openai/clip-vit-large-patch14`，因为Stable Diffusion预训练正是基于此CLIP变体。详见[文档](https://huggingface.co/docs/transformers/model_doc/clip)。\n\n接着准备计算方向相似度的PyTorch `nn.Module`：\n\n```python\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass DirectionalSimilarity(nn.Module):\n    def __init__(self, tokenizer, text_encoder, image_processor, image_encoder):\n        super().__init__()\n        self.tokenizer = tokenizer\n        self.text_encoder = text_encoder\n        self.image_processor = image_processor\n        self.image_encoder = image_encoder\n\n    def preprocess_image(self, image):\n        image = self.image_processor(image, return_tensors=\"pt\")[\"pixel_values\"]\n        return {\"pixel_values\": image.to(\"cuda\")}\n\n    def tokenize_text(self, text):\n        inputs = self.tokenizer(\n            text,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        return {\"input_ids\": inputs.input_ids.to(\"cuda\")}\n\n    def encode_image(self, image):\n        preprocessed_image = self.preprocess_image(image)\n        image_features = self.image_encoder(**preprocessed_image).image_embeds\n        image_features = image_features / image_features.norm(dim=1, keepdim=True)\n        return image_features\n\n    def encode_text(self, text):\n        tokenized_text = self.tokenize_text(text)\n        text_features = self.text_encoder(**tokenized_text).text_embeds\n        text_features = text_features / text_features.norm(dim=1, keepdim=True)\n        return text_features\n\n    def compute_directional_similarity(self, img_feat_one, img_feat_two, text_feat_one, text_feat_two):\n        
sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one)\n        return sim_direction\n\n    def forward(self, image_one, image_two, caption_one, caption_two):\n        img_feat_one = self.encode_image(image_one)\n        img_feat_two = self.encode_image(image_two)\n        text_feat_one = self.encode_text(caption_one)\n        text_feat_two = self.encode_text(caption_two)\n        directional_similarity = self.compute_directional_similarity(\n            img_feat_one, img_feat_two, text_feat_one, text_feat_two\n        )\n        return directional_similarity\n```\n\n现在让我们使用`DirectionalSimilarity`模块：\n\n```python\ndir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder)\nscores = []\n\nfor i in range(len(input_images)):\n    original_image = input_images[i]\n    original_caption = original_captions[i]\n    edited_image = edited_images[i]\n    modified_caption = modified_captions[i]\n\n    similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption)\n    scores.append(float(similarity_score.detach().cpu()))\n\nprint(f\"CLIP方向相似度: {np.mean(scores)}\")\n# CLIP方向相似度: 0.0797976553440094\n```\n\n与CLIP分数类似，CLIP方向相似度数值越高越好。\n\n需要注意的是，`StableDiffusionInstructPix2PixPipeline`提供了两个控制参数`image_guidance_scale`和`guidance_scale`来调节最终编辑图像的质量。建议您尝试调整这两个参数，观察它们对方向相似度的影响。\n\n我们可以扩展这个度量标准来评估原始图像与编辑版本的相似度，只需计算`F.cosine_similarity(img_feat_two, img_feat_one)`。对于这类编辑任务，我们仍希望尽可能保留图像的主要语义特征（即保持较高的相似度分数）。\n\n该度量方法同样适用于类似流程，例如[`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline)。\n\n> [!TIP]\n> CLIP分数和CLIP方向相似度都依赖CLIP模型，可能导致评估结果存在偏差。\n\n***扩展IS、FID（后文讨论）或KID等指标存在困难***，当被评估模型是在大型图文数据集（如[LAION-5B数据集](https://laion.ai/blog/laion-5b/)）上预训练时。因为这些指标的底层都使用了在ImageNet-1k数据集上预训练的InceptionNet来提取图像特征。Stable 
Diffusion的预训练数据集与InceptionNet的预训练数据集可能重叠有限，因此不适合作为特征提取器。\n\n***上述指标更适合评估类别条件模型***，例如[DiT](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit)。该模型是在ImageNet-1k类别条件下预训练的。\n\n### 基于类别的图像生成\n\n基于类别的生成模型通常是在带有类别标签的数据集（如[ImageNet-1k](https://huggingface.co/datasets/imagenet-1k)）上进行预训练的。评估这些模型的常用指标包括Fréchet Inception Distance（FID）、Kernel Inception Distance（KID）和Inception Score（IS）。本文档重点介绍FID（[Heusel等人](https://huggingface.co/papers/1706.08500)），并展示如何使用[`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit)计算该指标，该管道底层使用了[DiT模型](https://huggingface.co/papers/2212.09748)。\n\nFID旨在衡量两组图像数据集的相似程度。根据[此资源](https://mmgeneration.readthedocs.io/en/latest/quick_run.html#fid)：\n\n> Fréchet Inception Distance是衡量两组图像数据集相似度的指标。研究表明其与人类对视觉质量的主观判断高度相关，因此最常用于评估生成对抗网络（GAN）生成样本的质量。FID通过计算Inception网络特征表示所拟合的两个高斯分布之间的Fréchet距离来实现。\n\n这两个数据集本质上是真实图像数据集和生成图像数据集（本例中为人工生成的图像）。FID通常基于两个大型数据集计算，但本文档将使用两个小型数据集进行演示。\n\n首先下载ImageNet-1k训练集中的部分图像：\n\n```python\nfrom zipfile import ZipFile\nimport requests\n\n\ndef download(url, local_filepath):\n    r = requests.get(url)\n    with open(local_filepath, \"wb\") as f:\n        f.write(r.content)\n    return local_filepath\n\ndummy_dataset_url = \"https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip\"\nlocal_filepath = download(dummy_dataset_url, dummy_dataset_url.split(\"/\")[-1])\n\nwith ZipFile(local_filepath, \"r\") as zipper:\n    zipper.extractall(\".\")\n```\n\n```python\nfrom PIL import Image\nimport os\nimport numpy as np\n\ndataset_path = \"sample-imagenet-images\"\nimage_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])\n\nreal_images = [np.array(Image.open(path).convert(\"RGB\")) for path in image_paths]\n```\n\n这些是来自以下ImageNet-1k类别的10张图像：\"cassette_player\"、\"chain_saw\"（2张）、\"church\"、\"gas_pump\"（3张）、\"parachute\"（2张）和\"tench\"。\n\n<p align=\"center\">\n    <img 
src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/real-images.png\" alt=\"真实图像\"><br>\n    <em>真实图像</em>\n</p>\n\n加载图像后，我们对其进行轻量级预处理以便用于FID计算：\n\n```python\nfrom torchvision.transforms import functional as F\nimport torch\n\n\ndef preprocess_image(image):\n    image = torch.tensor(image).unsqueeze(0)\n    image = image.permute(0, 3, 1, 2) / 255.0\n    return F.center_crop(image, (256, 256))\n\n# 每张图像预处理后的形状为 (1, 3, 256, 256)，用 torch.cat 沿批次维度拼接\nreal_images = torch.cat([preprocess_image(image) for image in real_images])\nprint(real_images.shape)\n# torch.Size([10, 3, 256, 256])\n```\n\n我们现在加载[`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit)来生成基于上述类别的条件图像。\n\n```python\nfrom diffusers import DiTPipeline, DPMSolverMultistepScheduler\n\ndit_pipeline = DiTPipeline.from_pretrained(\"facebook/DiT-XL-2-256\", torch_dtype=torch.float16)\ndit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)\ndit_pipeline = dit_pipeline.to(\"cuda\")\n\nseed = 0\ngenerator = torch.manual_seed(seed)\n\n\nwords = [\n    \"cassette player\",\n    \"chainsaw\",\n    \"chainsaw\",\n    \"church\",\n    \"gas pump\",\n    \"gas pump\",\n    \"gas pump\",\n    \"parachute\",\n    \"parachute\",\n    \"tench\",\n]\n\nclass_ids = dit_pipeline.get_label_ids(words)\noutput = dit_pipeline(class_labels=class_ids, generator=generator, output_type=\"np\")\n\nfake_images = output.images\nfake_images = torch.tensor(fake_images)\nfake_images = fake_images.permute(0, 3, 1, 2)\nprint(fake_images.shape)\n# torch.Size([10, 3, 256, 256])\n```\n\n现在，我们可以使用[`torchmetrics`](https://torchmetrics.readthedocs.io/)计算FID分数。\n\n```python\nfrom torchmetrics.image.fid import FrechetInceptionDistance\n\nfid = FrechetInceptionDistance(normalize=True)\nfid.update(real_images, real=True)\nfid.update(fake_images, real=False)\n\nprint(f\"FID分数: {float(fid.compute())}\")\n# FID分数: 177.7147216796875\n```\n\nFID分数越低越好。以下因素会影响FID结果：\n\n- 图像数量（包括真实图像和生成图像）\n- 
扩散过程中引入的随机性\n- 扩散过程的推理步数\n- 扩散过程中使用的调度器\n\n对于最后两点，最佳实践是使用不同的随机种子和推理步数进行多次评估，然后报告平均结果。\n\n> [!WARNING]\n> FID结果往往具有脆弱性，因为它依赖于许多因素：\n>\n> * 计算过程中使用的特定Inception模型\n> * 计算实现的准确性\n> * 图像格式（PNG和JPG的起点不同）\n>\n> 需要注意的是，FID通常在比较相似实验时最有用，但除非作者仔细公开FID测量代码，否则很难复现论文结果。\n>\n> 这些注意事项同样适用于其他相关指标，如KID和IS。\n\n最后，让我们可视化检查这些`fake_images`。\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/fake-images.png\" alt=\"生成图像\"><br>\n    <em>生成图像示例</em>\n</p>\n"
  },
  {
    "path": "docs/source/zh/conceptual/philosophy.md",
    "content": "<!--版权 2025 HuggingFace 团队。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（\"许可证\"）授权；\n除非符合许可证要求，否则不得使用本文件。\n您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，本软件按\"原样\"分发，\n无任何明示或暗示的担保或条件。详见许可证中\n的特定语言规定和限制。\n-->\n\n# 设计哲学\n\n🧨 Diffusers 提供**最先进**的预训练扩散模型支持多模态任务。\n其目标是成为推理和训练通用的**模块化工具箱**。\n\n我们致力于构建一个经得起时间考验的库，因此对API设计极为重视。\n\n简而言之，Diffusers 被设计为 PyTorch 的自然延伸。因此，我们的多数设计决策都基于 [PyTorch 设计原则](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy)。以下是核心原则：\n\n## 可用性优先于性能\n\n- 尽管 Diffusers 包含众多性能优化特性（参见[内存与速度优化](https://huggingface.co/docs/diffusers/optimization/fp16)），模型默认总是以最高精度和最低优化级别加载。因此除非用户指定，扩散流程(pipeline)默认在CPU上以float32精度初始化。这确保了跨平台和加速器的可用性，意味着运行本库无需复杂安装。\n- Diffusers 追求**轻量化**，仅有少量必需依赖，但提供诸多可选依赖以提升性能（如`accelerate`、`safetensors`、`onnx`等）。我们竭力保持库的轻量级特性，使其能轻松作为其他包的依赖项。\n- Diffusers 偏好简单、自解释的代码而非浓缩的\"魔法\"代码。这意味着lambda函数等简写语法和高级PyTorch操作符通常不被采用。\n\n## 简洁优于简易\n\n正如PyTorch所言：**显式优于隐式**，**简洁优于复杂**。这一哲学体现在库的多个方面：\n- 我们遵循PyTorch的API设计，例如使用[`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to)让用户自主管理设备。\n- 明确的错误提示优于静默纠正错误输入。Diffusers 旨在教育用户，而非单纯降低使用难度。\n- 暴露复杂的模型与调度器(scheduler)交互逻辑而非内部魔法处理。调度器/采样器与扩散模型分离且相互依赖最小化，迫使用户编写展开的去噪循环。但这种分离便于调试，并赋予用户更多控制权来调整去噪过程或切换模型/调度器。\n- 扩散流程中独立训练的组件（如文本编码器、UNet、变分自编码器）各有专属模型类。这要求用户处理组件间交互，且序列化格式将组件分存不同文件。但此举便于调试和定制，得益于组件分离，DreamBooth或Textual Inversion训练变得极为简单。\n\n## 可定制与贡献友好优于抽象\n\n库的大部分沿用了[Transformers库](https://github.com/huggingface/transformers)的重要设计原则：宁要重复代码，勿要仓促抽象。这一原则与[DRY原则](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself)形成鲜明对比。\n\n简言之，正如Transformers对建模文件的做法，Diffusers对流程(pipeline)和调度器(scheduler)保持极低抽象度与高度自包含代码。函数、长代码块甚至类可能在多文件中重复，初看像是糟糕的松散设计。但该设计已被Transformers证明极其成功，对社区驱动的开源机器学习库意义重大：\n- 机器学习领域发展迅猛，范式、模型架构和算法快速迭代，难以定义长效代码抽象。\n- ML从业者常需快速修改现有代码进行研究，因此偏好自包含代码而非多重抽象。\n- 开源库依赖社区贡献，必须构建易于参与的代码库。抽象度越高、依赖越复杂、可读性越差，贡献难度越大。过度抽象的库会吓退贡献者。若贡献不会破坏核心功能，不仅吸引新贡献者，也更便于并行审查和修改。\n\nHugging 
Face称此设计为**单文件政策**——即某个类的几乎所有代码都应写在单一自包含文件中。更多哲学探讨可参阅[此博文](https://huggingface.co/blog/transformers-design-philosophy)。\n\nDiffusers对流程和调度器完全遵循该哲学，但对diffusion模型仅部分适用。原因在于多数扩散流程（如[DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm)、[Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines)、[unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip)和[Imagen](https://imagen.research.google/)）都基于相同扩散模型——[UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond)。\n\n现在您应已理解🧨 Diffusers的设计理念🤗。我们力求在全库贯彻这些原则，但仍存在少数例外或欠佳设计。如有反馈，我们❤️欢迎在[GitHub提交](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=)。\n\n## 设计哲学细节\n\n现在深入探讨设计细节。Diffusers主要包含三类：[流程(pipeline)](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)、[模型](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)和[调度器(scheduler)](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)。以下是各类的具体设计决策。\n\n### 流程(Pipelines)\n\n流程设计追求易用性（因此不完全遵循[*简洁优于简易*](#简洁优于简易)），不要求功能完备，应视为使用[模型](#模型)和[调度器](#调度器schedulers)进行推理的示例。\n\n遵循原则：\n- 采用单文件政策。所有流程位于src/diffusers/pipelines下的独立目录。一个流程文件夹对应一篇扩散论文/项目/发布。如[`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion)可包含多个流程文件。若流程功能相似，可使用[# Copied from机制](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251)。\n- 所有流程继承[`DiffusionPipeline`]。\n- 每个流程由不同模型和调度器组件构成，这些组件记录于[`model_index.json`文件](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json)，可通过同名属性访问，并可用[`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components)在流程间共享。\n- 
所有流程应能通过[`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained)加载。\n- 流程**仅**用于推理。\n- 流程代码应具备高可读性、自解释性和易修改性。\n- 流程应设计为可相互构建，便于集成到高层API。\n- 流程**非**功能完备的用户界面。完整UI推荐[InvokeAI](https://github.com/invoke-ai/InvokeAI)、[Diffuzers](https://github.com/abhishekkrthakur/diffuzers)或[lama-cleaner](https://github.com/Sanster/lama-cleaner)。\n- 每个流程应通过唯一的`__call__`方法运行，且参数命名应跨流程统一。\n- 流程应以其解决的任务命名。\n- 几乎所有新diffusion流程都应在新文件夹/文件中实现。\n\n### 模型\n\n模型设计为可配置的工具箱，是[PyTorch Module类](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)的自然延伸，仅部分遵循**单文件政策**。\n\n遵循原则：\n- 模型对应**特定架构类型**。如[`UNet2DConditionModel`]类适用于所有需要2D图像输入且受上下文调节的UNet变体。\n- 所有模型位于[`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)，每种架构应有独立文件，如[`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py)、[`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py)等。\n- 模型**不**采用单文件政策，应使用小型建模模块如[`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py)、[`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py)、[`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py)等。**注意**：这与Transformers的建模文件截然不同，表明模型未完全遵循单文件政策。\n- 模型意图暴露复杂度（类似PyTorch的`Module`类），并提供明确错误提示。\n- 所有模型继承`ModelMixin`和`ConfigMixin`。\n- 当不涉及重大代码变更、保持向后兼容性且显著提升内存/计算效率时，可对模型进行性能优化。\n- 模型默认应具备最高精度和最低性能设置。\n- 若新模型检查点可归类为现有架构，应适配现有架构而非新建文件。仅当架构根本性不同时才创建新文件。\n- 模型设计应便于未来扩展。可通过限制公开函数参数、配置参数和\"预见\"变更实现。例如：优先采用可扩展的`string`类型参数而非布尔型`is_..._type`参数。对现有架构的修改应保持最小化。\n- 
模型设计需在代码可读性与多检查点支持间权衡。多数情况下应适配现有类，但某些例外（如[UNet块](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py)和[注意力处理器](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)）需新建类以保证长期可读性。\n\n### 调度器(Schedulers)\n\n调度器负责引导推理去噪过程及定义训练噪声计划。它们设计为独立的可加载配置类，严格遵循**单文件政策**。\n\n遵循原则：\n- 所有调度器位于[`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)。\n- 调度器**禁止**从大型工具文件导入，必须保持高度自包含。\n- 一个调度器Python文件对应一种算法（如论文定义的算法）。\n- 若调度器功能相似，可使用`# Copied from`机制。\n- 所有调度器继承`SchedulerMixin`和`ConfigMixin`。\n- 调度器可通过[`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config)轻松切换（详见[此处](../using-diffusers/schedulers)）。\n- 每个调度器必须包含`set_num_inference_steps`和`step`函数。在每次去噪过程前（即调用`step(...)`前）必须调用`set_num_inference_steps(...)`。\n- 每个调度器通过`timesteps`属性暴露需要\"循环\"的时间步，这是模型将被调用的时间步数组。\n- `step(...)`函数接收模型预测输出和\"当前\"样本(x_t)，返回\"前一个\"略去噪的样本(x_t-1)。\n- 鉴于扩散调度器的复杂性，`step`函数不暴露全部细节，可视为\"黑盒\"。\n- 几乎所有新调度器都应在新文件中实现。"
  },
  {
    "path": "docs/source/zh/hybrid_inference/api_reference.md",
    "content": "# 混合推理 API 参考\n\n## 远程解码\n\n[[autodoc]] utils.remote_utils.remote_decode\n\n## 远程编码\n\n[[autodoc]] utils.remote_utils.remote_encode"
  },
  {
    "path": "docs/source/zh/hybrid_inference/overview.md",
    "content": "<!--版权 2025 HuggingFace 团队。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（\"许可证\"）授权；除非遵守许可证，否则不得使用此文件。\n您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，否则根据许可证分发的软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。请参阅许可证以了解有关权限和限制的具体条款。\n-->\n\n# 混合推理\n\n**通过混合推理赋能本地 AI 构建者**\n\n> [!TIP]\n> 混合推理是一项[实验性功能](https://huggingface.co/blog/remote_vae)。\n> 可以在[此处](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml)提供反馈。\n\n## 为什么使用混合推理？\n\n混合推理提供了一种快速简单的方式来卸载本地生成需求。\n\n- 🚀 **降低要求：** 无需昂贵硬件即可访问强大模型。\n- 💎 **无妥协：** 在不牺牲性能的情况下实现最高质量。\n- 💰 **成本效益高：** 它是免费的！🤑\n- 🎯 **多样化用例：** 与 Diffusers 🧨 和更广泛的社区完全兼容。\n- 🔧 **开发者友好：** 简单请求，快速响应。\n\n---\n\n## 可用模型\n\n* **VAE 解码 🖼️：** 快速将潜在表示解码为高质量图像，不影响性能或工作流速度。\n* **VAE 编码 🔢：** 高效将图像编码为潜在表示，用于生成和训练。\n* **文本编码器 📃（即将推出）：** 快速准确地计算提示的文本嵌入，确保流畅高质量的工作流。\n\n---\n\n## 集成\n\n* **[SD.Next](https://github.com/vladmandic/sdnext)：** 一体化 UI，直接支持混合推理。\n* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae)：** 用于混合推理的 ComfyUI 节点。\n\n## 更新日志\n\n- 2025 年 3 月 10 日：添加了 VAE 编码\n- 2025 年 3 月 2 日：初始发布，包含 VAE 解码\n\n## 内容\n\n文档分为三个部分：\n\n* **VAE 解码** 学习如何使用混合推理进行 VAE 解码的基础知识。\n* **VAE 编码** 学习如何使用混合推理进行 VAE 编码的基础知识。\n* **API 参考** 深入了解任务特定设置和参数。"
  },
  {
    "path": "docs/source/zh/hybrid_inference/vae_encode.md",
    "content": "# 入门：使用混合推理进行 VAE 编码\n\nVAE 编码用于训练、图像到图像和图像到视频——将图像或视频转换为潜在表示。\n\n## 内存\n\n这些表格展示了在不同 GPU 上使用 SD v1 和 SD XL 进行 VAE 编码的 VRAM 需求。\n\n对于这些 GPU 中的大多数，内存使用百分比决定了其他模型（文本编码器、UNet/Transformer）必须被卸载，或者必须使用分块编码，这会增加时间并影响质量。\n\n<details><summary>SD v1.5</summary>\n\n| GPU                           | 分辨率   |   时间（秒） |   内存（%） |   分块时间（秒） |   分块内存（%） |\n|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|\n| NVIDIA GeForce RTX 4090       | 512x512      |            0.015 |      3.51901 |               0.015 |            3.51901 |\n| NVIDIA GeForce RTX 4090       | 256x256      |            0.004 |      1.3154  |               0.005 |            1.3154  |\n| NVIDIA GeForce RTX 4090       | 2048x2048    |            0.402 |     47.1852  |               0.496 |            3.51901 |\n| NVIDIA GeForce RTX 4090       | 1024x1024    |            0.078 |     12.2658  |               0.094 |            3.51901 |\n| NVIDIA GeForce RTX 4080 SUPER | 512x512      |            0.023 |      5.30105 |               0.023 |            5.30105 |\n| NVIDIA GeForce RTX 4080 SUPER | 256x256      |            0.006 |      1.98152 |               0.006 |            1.98152 |\n| NVIDIA GeForce RTX 4080 SUPER | 2048x2048    |            0.574 |     71.08    |               0.656 |            5.30105 |\n| NVIDIA GeForce RTX 4080 SUPER | 1024x1024    |            0.111 |     18.4772  |               0.14  |            5.30105 |\n| NVIDIA GeForce RTX 3090       | 512x512      |            0.032 |      3.52782 |               0.032 |            3.52782 |\n| NVIDIA GeForce RTX 3090       | 256x256      |            0.01  |      1.31869 |               0.009 |            1.31869 |\n| NVIDIA GeForce RTX 3090       | 2048x2048    |            0.742 |     47.3033  |               0.954 |            3.52782 |\n| NVIDIA GeForce RTX 3090       | 1024x1024    |            0.136 |     12.2965  |               0.207 |            
3.52782 |\n| NVIDIA GeForce RTX 3080       | 512x512      |            0.036 |      8.51761 |               0.036 |            8.51761 |\n| NVIDIA GeForce RTX 3080       | 256x256      |            0.01  |      3.18387 |               0.01  |            3.18387 |\n| NVIDIA GeForce RTX 3080       | 2048x2048    |            0.863 |     86.7424  |               1.191 |            8.51761 |\n| NVIDIA GeForce RTX 3080       | 1024x1024    |            0.157 |     29.6888  |               0.227 |            8.51761 |\n| NVIDIA GeForce RTX 3070       | 512x512      |            0.051 |     10.6941  |               0.051 |           10.6941  |\n| NVIDIA GeForce RTX 3070       | 256x256      |            0.015 |      3.99743 |               0.015 |            3.99743 |\n| NVIDIA GeForce RTX 3070       | 2048x2048    |            1.217 |     96.054   |               1.482 |           10.6941  |\n| NVIDIA GeForce RTX 3070       | 1024x1024    |            0.223 |     37.2751  |               0.327 |           10.6941  |\n\n</details>\n\n<details><summary>SDXL</summary>\n\n| GPU                           | Resolution   |   Time (seconds) |   Memory Consumed (%) |   Tiled Time (seconds) |   Tiled Memory (%) |\n|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|\n| NVIDIA GeForce RTX 4090       | 512x512      |            0.029 |               4.95707 |                  0.029 |            4.95707 |\n| NVIDIA GeForce RTX 4090       | 256x256      |            0.007 |               2.29666 |                  0.007 |            2.29666 |\n| NVIDIA GeForce RTX 4090       | 2048x2048    |            0.873 |              66.3452  |                  0.863 |           15.5649  |\n| NVIDIA GeForce RTX 4090       | 1024x1024    |            0.142 |              15.5479  |                  0.143 |           15.5479  |\n| NVIDIA GeForce RTX 4080 SUPER | 512x512      |            0.044 |               
7.46735 |                  0.044 |            7.46735 |\n| NVIDIA GeForce RTX 4080 SUPER | 256x256      |            0.01  |               3.4597  |                  0.01  |            3.4597  |\n| NVIDIA GeForce RTX 4080 SUPER | 2048x2048    |            1.317 |              87.1615  |                  1.291 |           23.447   |\n| NVIDIA GeForce RTX 4080 SUPER | 1024x1024    |            0.213 |              23.4215  |                  0.214 |           23.4215  |\n| NVIDIA GeForce RTX 3090       | 512x512      |            0.058 |               5.65638 |                  0.058 |            5.65638 |\n| NVIDIA GeForce RTX 3090       | 256x256      |            0.016 |               2.45081 |                  0.016 |            2.45081 |\n| NVIDIA GeForce RTX 3090       | 2048x2048    |            1.755 |              77.8239  |                  1.614 |           18.4193  |\n| NVIDIA GeForce RTX 3090       | 1024x1024    |            0.265 |              18.4023  |                  0.265 |           18.4023  |\n| NVIDIA GeForce RTX 3080       | 512x512      |            0.064 |              13.6568  |                  0.064 |           13.6568  |\n| NVIDIA GeForce RTX 3080       | 256x256      |            0.018 |               5.91728 |                  0.018 |            5.91728 |\n| NVIDIA GeForce RTX 3080       | 2048x2048    |          OOM     |             OOM       |                  1.866 |           44.4717  |\n| NVIDIA GeForce RTX 3080       | 1024x1024    |            0.302 |              44.4308  |                  0.302 |           44.4308  |\n| NVIDIA GeForce RTX 3070       | 512x512      |            0.093 |              17.1465  |                  0.093 |           17.1465  |\n| NVIDIA GeForce RTX 3070       | 256x256      |            0.025 |               7.42931 |                  0.026 |            7.42931 |\n| NVIDIA GeForce RTX 3070       | 2048x2048    |          OOM     |             OOM       |                  
2.674 |           55.8355  |\n| NVIDIA GeForce RTX 3070       | 1024x1024    |            0.443 |              55.7841  |                  0.443 |           55.7841  |\n\n</details>\n\n## 可用 VAE\n\n|   | **端点** | **模型** |\n|:-:|:-----------:|:--------:|\n| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |\n| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |\n| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |\n\n\n> [!TIP]\n> 如需支持更多模型，可以在[这里](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml)提出请求。\n\n\n## 代码\n\n> [!TIP]\n> 从 `main` 安装 `diffusers` 以运行代码：`pip install git+https://github.com/huggingface/diffusers@main`\n\n\n一个辅助方法简化了与混合推理的交互。\n\n```python\nfrom diffusers.utils.remote_utils import remote_encode\n```\n\n### 基本示例\n\n让我们编码一张图像，然后解码以演示。\n\n<figure class=\"image flex flex-col items-center justify-center text-center m-0 w-full\">\n<img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg\"/>\n</figure>\n\n<details><summary>代码</summary>\n\n```python\nfrom diffusers.utils import load_image\nfrom diffusers.utils.remote_utils import remote_decode, remote_encode\n\nimage = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true\")\n\nlatent = remote_encode(\n    endpoint=\"https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/\",\n    image=image,\n    
scaling_factor=0.3611,\n    shift_factor=0.1159,\n)\n\ndecoded = remote_decode(\n    endpoint=\"https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/\",\n    tensor=latent,\n    scaling_factor=0.3611,\n    shift_factor=0.1159,\n)\n```\n\n</details>\n\n<figure class=\"image flex flex-col items-center justify-center text-center m-0 w-full\">\n<img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png\"/>\n</figure>\n\n\n### 生成\n\n现在让我们看一个生成示例，我们将编码图像，生成，然后远程解码！\n\n<details><summary>代码</summary>\n\n```python\nimport torch\nfrom diffusers import StableDiffusionImg2ImgPipeline\nfrom diffusers.utils import load_image\nfrom diffusers.utils.remote_utils import remote_decode, remote_encode\n\npipe = StableDiffusionImg2ImgPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n    vae=None,\n).to(\"cuda\")\n\ninit_image = load_image(\n    \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n)\ninit_image = init_image.resize((768, 512))\n\ninit_latent = remote_encode(\n    endpoint=\"https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/\",\n    image=init_image,\n    scaling_factor=0.18215,\n)\n\nprompt = \"A fantasy landscape, trending on artstation\"\nlatent = pipe(\n    prompt=prompt,\n    image=init_latent,\n    strength=0.75,\n    output_type=\"latent\",\n).images\n\nimage = remote_decode(\n    endpoint=\"https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/\",\n    tensor=latent,\n    scaling_factor=0.18215,\n)\nimage.save(\"fantasy_landscape.jpg\")\n```\n\n</details>\n\n<figure class=\"image flex flex-col items-center justify-center text-center m-0 w-full\">\n<img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png\"/>\n</figure>\n\n## 集成\n\n* 
**[SD.Next](https://github.com/vladmandic/sdnext):** 具有直接支持混合推理功能的一体化用户界面。\n* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** 用于混合推理的 ComfyUI 节点。"
  },
  {
    "path": "docs/source/zh/index.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n<p align=\"center\">\n    <br>\n    <img src=\"https://raw.githubusercontent.com/huggingface/diffusers/77aadfee6a891ab9fcfb780f87c693f7a5beeb8e/docs/source/imgs/diffusers_library.jpg\" width=\"400\"/>\n    <br>\n</p>\n\n# 🧨 Diffusers\n\n🤗 Diffusers 是一个值得首选用于生成图像、音频甚至 3D 分子结构的，最先进的预训练扩散模型库。\n无论您是在寻找简单的推理解决方案，还是想训练自己的扩散模型，🤗 Diffusers 这一模块化工具箱都能对其提供支持。\n本库的设计更偏重于[可用而非高性能](conceptual/philosophy#usability-over-performance)、[简明而非简单](conceptual/philosophy#simple-over-easy)以及[易用而非抽象](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction)。\n\n\n本库包含三个主要组件：\n\n- 最先进的扩散管道 [diffusion pipelines](api/pipelines/overview)，只需几行代码即可进行推理。\n- 可交替使用的各种噪声调度器 [noise schedulers](api/schedulers/overview)，用于平衡生成速度和质量。\n- 预训练模型 [models](api/models)，可作为构建模块，并与调度程序结合使用，来创建您自己的端到端扩散系统。\n\n<div class=\"mt-10\">\n  <div class=\"w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5\">\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./tutorials/tutorial_overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Tutorials</div>\n      <p class=\"text-gray-700\">Learn the fundamental skills you need to start generating outputs, build your own diffusion system, and train a 
diffusion model. We recommend starting here if you're using 🤗 Diffusers for the first time!</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./using-diffusers/loading_overview\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-indigo-400 to-indigo-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">How-to guides</div>\n      <p class=\"text-gray-700\">Practical guides for helping you load pipelines, models, and schedulers. You'll also learn how to use pipelines for specific tasks, control how outputs are generated, optimize for inference speed, and different training techniques.</p>\n    </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./conceptual/philosophy\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-pink-400 to-pink-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Conceptual guides</div>\n      <p class=\"text-gray-700\">Understand why the library was designed the way it was, and learn more about the ethical guidelines and safety implementations for using the library.</p>\n   </a>\n    <a class=\"!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg\" href=\"./api/models\"\n      ><div class=\"w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed\">Reference</div>\n      <p class=\"text-gray-700\">Technical descriptions of how 🤗 Diffusers classes and methods work.</p>\n    </a>\n  </div>\n</div>\n\n## 🧨 Diffusers pipelines\n\n下表汇总了当前所有官方支持的pipelines及其对应的论文.\n\n| 管道 | 论文/仓库 | 任务 |\n|---|---|:---:|\n| [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://huggingface.co/papers/2211.06679) | Image-to-Image Text-Guided Generation |\n| 
[audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation |\n| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) | Image-to-Image Text-Guided Generation |\n| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://huggingface.co/papers/2210.05559) | Image-to-Image Text-Guided Generation |\n| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |\n| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://huggingface.co/papers/2006.11239) | Unconditional Image Generation |\n| [ddim](./api/pipelines/ddim) | [Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502) | Unconditional Image Generation |\n| [if](./if) | [**IF**](./api/pipelines/if) | Image Generation |\n| [if_img2img](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation |\n| [if_inpainting](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation |\n| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752)| Text-to-Image Generation |\n| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752)| Super Resolution Image-to-Image |\n| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) | Unconditional Image Generation |\n| [paint_by_example](./api/pipelines/paint_by_example) | [Paint by Example: Exemplar-based Image Editing 
with Diffusion Models](https://huggingface.co/papers/2211.13227) | Image-Guided Image Inpainting |\n| [pndm](./api/pipelines/pndm) | [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://huggingface.co/papers/2202.09778) | Unconditional Image Generation |\n| [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |\n| [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |\n| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://huggingface.co/papers/2301.12247) | Text-Guided Generation |\n| [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation |\n| [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation |\n| [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting |\n| [stable_diffusion_panorama](./api/pipelines/stable_diffusion/panorama) | [MultiDiffusion](https://multidiffusion.github.io/) | Text-to-Panorama Generation |\n| [stable_diffusion_pix2pix](./api/pipelines/stable_diffusion/pix2pix) | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/papers/2211.09800)  | Text-Guided Image Editing|\n| [stable_diffusion_pix2pix_zero](./api/pipelines/stable_diffusion/pix2pix_zero) | [Zero-shot Image-to-Image Translation](https://pix2pixzero.github.io/) | Text-Guided Image Editing |\n| 
[stable_diffusion_attend_and_excite](./api/pipelines/stable_diffusion/attend_and_excite) | [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://huggingface.co/papers/2301.13826) | Text-to-Image Generation |\n| [stable_diffusion_self_attention_guidance](./api/pipelines/stable_diffusion/self_attention_guidance) | [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) | Text-to-Image Generation Unconditional Image Generation |\n| [stable_diffusion_image_variation](./stable_diffusion/image_variation) | [Stable Diffusion Image Variations](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) | Image-to-Image Generation |\n| [stable_diffusion_latent_upscale](./stable_diffusion/latent_upscale) | [Stable Diffusion Latent Upscaler](https://twitter.com/StabilityAI/status/1590531958815064065) | Text-Guided Super Resolution Image-to-Image |\n| [stable_diffusion_model_editing](./api/pipelines/stable_diffusion/model_editing) | [Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://time-diffusion.github.io/) | Text-to-Image Model Editing |\n| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |\n| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |\n| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Depth-Conditional Stable Diffusion](https://github.com/Stability-AI/stablediffusion#depth-conditional-stable-diffusion) | Depth-to-Image Generation |\n| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |\n| 
[stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [Safe Stable Diffusion](https://huggingface.co/papers/2211.05105) | Text-Guided Generation |\n| [stable_unclip](./stable_unclip) | Stable unCLIP | Text-to-Image Generation |\n| [stable_unclip](./stable_unclip) | Stable unCLIP | Image-to-Image Text-Guided Generation |\n| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364) | Unconditional Image Generation |\n| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation |\n| [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125)(implementation by [kakaobrain](https://github.com/kakaobrain/karlo)) | Text-to-Image Generation |\n| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://huggingface.co/papers/2211.08332) | Text-to-Image Generation |\n| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://huggingface.co/papers/2211.08332) | Image Variations Generation |\n| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://huggingface.co/papers/2211.08332) | Dual Image and Text Guided Generation |\n| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://huggingface.co/papers/2111.14822) | Text-to-Image Generation |\n"
  },
  {
    "path": "docs/source/zh/installation.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 安装\n\n在你正在使用的任意深度学习框架中安装 🤗 Diffusers 。\n\n🤗 Diffusers已在Python 3.8+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明，针对你正在使用的深度学习框架进行安装：\n\n- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.\n- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.\n\n## 使用pip安装\n\n你需要在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Diffusers 。\n\n如果你对 Python 虚拟环境不熟悉，可以看看这个[教程](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).\n\n在虚拟环境中，你可以轻松管理不同的项目，避免依赖项之间的兼容性问题。\n\n首先，在你的项目目录下创建一个虚拟环境：\n\n```bash\npython -m venv .env\n```\n\n激活虚拟环境：\n\n```bash\nsource .env/bin/activate\n```\n\n现在，你就可以安装 🤗 Diffusers了！使用下边这个命令：\n\n**PyTorch**\n\n```bash\npip install diffusers[\"torch\"]\n```\n\n**Flax**\n\n```bash\npip install diffusers[\"flax\"]\n```\n\n## 从源代码安装\n\n在从源代码安装 `diffusers` 之前，确保你已经安装了 `torch` 和 `accelerate`。\n\n`torch`的安装教程可以看 `torch` [文档](https://pytorch.org/get-started/locally/#start-locally).\n\n安装 `accelerate`\n\n```bash\npip install accelerate\n```\n\n从源码安装 🤗 Diffusers 需要使用以下命令:\n\n```bash\npip install git+https://github.com/huggingface/diffusers\n```\n\n这个命令安装的是最新的 `main`版本，而不是最近的`stable`版。\n`main`是一直和最新进展保持一致的。比如，上次发布的正式版中有bug，在`main`中可以看到这个bug被修复了，但是新的正式版此时尚未推出。\n但是这也意味着 `main`版本不保证是稳定的。\n\n我们努力保持`main`版本正常运行，大多数问题都能在几个小时或一天之内解决\n\n如果你遇到了问题，可以提 
[Issue](https://github.com/huggingface/diffusers/issues)，这样我们就能更快修复问题了。\n\n## 可修改安装\n\n如果你想做以下任意一件事，那你可能需要一个可修改代码的安装方式：\n\n* 使用 `main`版本的源代码。\n* 为 🤗 Diffusers 贡献，需要测试代码中的变化。\n\n使用以下命令克隆并安装 🤗 Diffusers:\n\n```bash\ngit clone https://github.com/huggingface/diffusers.git\ncd diffusers\n```\n\n**PyTorch**\n\n```sh\npip install -e \".[torch]\"\n```\n\n**Flax**\n\n```sh\npip install -e \".[flax]\"\n```\n\n这些命令将连接到你克隆的版本库和你的 Python 库路径。\n现在，不只是在通常的库路径，Python 还会在你克隆的文件夹内寻找包。\n例如，如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.10/site-packages/`，Python 也会搜索你克隆到的文件夹 `~/diffusers/`。\n\n> [!WARNING]\n> 如果你想继续使用这个库，你必须保留 `diffusers` 文件夹。\n\n\n现在你可以用下面的命令轻松地将你克隆的 🤗 Diffusers 库更新到最新版本。\n\n```bash\ncd ~/diffusers/\ngit pull\n```\n\n你的Python环境将在下次运行时找到`main`版本的 🤗 Diffusers。\n\n## 注意 Telemetry 日志\n\n我们的库会在使用`from_pretrained()`请求期间收集 telemetry 信息。这些数据包括Diffusers和PyTorch/Flax的版本，请求的模型或管道类，以及预训练检查点的路径（如果它被托管在Hub上的话）。\n这些使用数据有助于我们调试问题并确定新功能的开发优先级。\nTelemetry 数据仅在从 HuggingFace Hub 中加载模型和管道时发送，而不会在本地使用期间收集。\n\n我们知道，并不是每个人都想分享这些信息，我们尊重您的隐私，\n因此您可以通过在终端中设置 `DISABLE_TELEMETRY` 环境变量从而禁用 Telemetry 数据收集：\n\n\nLinux/MacOS :\n```bash\nexport DISABLE_TELEMETRY=YES\n```\n\nWindows :\n```bash\nset DISABLE_TELEMETRY=YES\n```"
  },
  {
    "path": "docs/source/zh/modular_diffusers/auto_pipeline_blocks.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据Apache许可证2.0版（\"许可证\"）授权；除非符合许可证，否则不得使用此文件。您可以在\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n获取许可证的副本。\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。有关许可证的特定语言管理权限和限制，请参阅许可证。\n-->\n\n# AutoPipelineBlocks\n\n[`~modular_pipelines.AutoPipelineBlocks`] 是一种包含支持不同工作流程的块的多块类型。它根据运行时提供的输入自动选择要运行的子块。这通常用于将多个工作流程（文本到图像、图像到图像、修复）打包到一个管道中以便利。\n\n本指南展示如何创建 [`~modular_pipelines.AutoPipelineBlocks`]。\n\n创建三个 [`~modular_pipelines.ModularPipelineBlocks`] 用于文本到图像、图像到图像和修复。这些代表了管道中可用的不同工作流程。\n\n<hfoptions id=\"auto\">\n<hfoption id=\"text-to-image\">\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam\n\nclass TextToImageBlock(ModularPipelineBlocks):\n    model_name = \"text2img\"\n\n    @property\n    def inputs(self):\n        return [InputParam(name=\"prompt\")]\n\n    @property\n    def intermediate_outputs(self):\n        return []\n\n    @property\n    def description(self):\n        return \"我是一个文本到图像的工作流程！\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        print(\"运行文本到图像工作流程\")\n        # 在这里添加你的文本到图像逻辑\n        # 例如：根据提示生成图像\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n\n</hfoption>\n<hfoption id=\"image-to-image\">\n\n```py\nclass ImageToImageBlock(ModularPipelineBlocks):\n    model_name = \"img2img\"\n\n    @property\n    def inputs(self):\n        return [InputParam(name=\"prompt\"), InputParam(name=\"image\")]\n\n    @property\n    def intermediate_outputs(self):\n        return []\n\n    @property\n    def description(self):\n        return \"我是一个图像到图像的工作流程！\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        print(\"运行图像到图像工作流程\")\n        # 在这里添加你的图像到图像逻辑\n        # 例如：根据提示转换输入图像\n        self.set_block_state(state, block_state)\n        return components, 
state\n```\n\n\n</hfoption>\n<hfoption id=\"inpaint\">\n\n```py\nclass InpaintBlock(ModularPipelineBlocks):\n    model_name = \"inpaint\"\n\n    @property\n    def inputs(self):\n        return [InputParam(name=\"prompt\"), InputParam(name=\"image\"), InputParam(name=\"mask\")]\n\n    @property\n    def intermediate_outputs(self):\n        return []\n\n    @property\n    def description(self):\n        return \"我是一个修复工作流！\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        print(\"运行修复工作流\")\n        # 在这里添加你的修复逻辑\n        # 例如：根据提示填充被遮罩的区域\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n</hfoption>\n</hfoptions>\n\n创建一个包含子块类及其对应块名称列表的[`~modular_pipelines.AutoPipelineBlocks`]类。\n\n你还需要包括`block_trigger_inputs`，一个触发相应块的输入名称列表。如果在运行时提供了触发输入，则选择该块运行。使用`None`来指定如果未检测到触发输入时运行的默认块。\n\n最后，重要的是包括一个`description`，清楚地解释哪些输入触发哪些工作流。这有助于用户理解如何运行特定的工作流。\n\n```py\nfrom diffusers.modular_pipelines import AutoPipelineBlocks\n\nclass AutoImageBlocks(AutoPipelineBlocks):\n    # 可供选择的子块类列表（即上文定义的三个块）\n    block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]\n    # 每个块的名称，顺序相同\n    block_names = [\"inpaint\", \"img2img\", \"text2img\"]\n    # 决定运行哪个块的触发输入\n    # - \"mask\" 触发修复工作流\n    # - \"image\" 触发img2img工作流（但仅在未提供mask时）\n    # - 如果以上都没有，运行text2img工作流（默认）\n    block_trigger_inputs = [\"mask\", \"image\", None]\n\n    # 对于AutoPipelineBlocks来说，描述极其重要\n    @property\n    def description(self):\n        return (\n            \"Pipeline generates images given different types of conditions!\\n\"\n            + \"This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\\n\"\n            + \" - inpaint workflow is run when `mask` is provided.\\n\"\n            + \" - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\\n\"\n            + \" - text2img workflow is run when neither `image` nor `mask` is provided.\\n\"\n        
)\n```\n\n包含`description`以避免任何关于如何运行块和需要什么输入的混淆**非常**重要。虽然[`~modular_pipelines.AutoPipelineBlocks`]很方便，但如果它没有正确解释，其条件逻辑可能难以理解。\n\n创建`AutoImageBlocks`的一个实例。\n\n```py\nauto_blocks = AutoImageBlocks()\n```\n\n对于更复杂的组合，例如在更大的管道中作为子块使用的嵌套[`~modular_pipelines.AutoPipelineBlocks`]块，使用[`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`]方法根据你的输入提取实际运行的块。\n\n```py\nauto_blocks.get_execution_blocks(\"mask\")\n```\n"
  },
  {
    "path": "docs/source/zh/modular_diffusers/components_manager.md",
    "content": "<!--版权所有 2025 HuggingFace 团队。保留所有权利。\n\n根据 Apache 许可证 2.0 版（\"许可证\"）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。请参阅许可证以了解特定语言管理权限和限制。\n-->\n\n# 组件管理器\n\n[`ComponentsManager`] 是 Modular Diffusers 的模型注册和管理系统。它添加和跟踪模型，存储有用的元数据（模型大小、设备放置、适配器），防止重复模型实例，并支持卸载。\n\n本指南将展示如何使用 [`ComponentsManager`] 来管理组件和设备内存。\n\n## 添加组件\n\n[`ComponentsManager`] 应与 [`ModularPipeline`] 一起创建，在 [`~ModularPipeline.from_pretrained`] 或 [`~ModularPipelineBlocks.init_pipeline`] 中。\n\n> [!TIP]\n> `collection` 参数是可选的，但可以更轻松地组织和管理组件。\n\n<hfoptions id=\"create\">\n<hfoption id=\"from_pretrained\">\n\n```py\nfrom diffusers import ModularPipeline, ComponentsManager\n\ncomp = ComponentsManager()\npipe = ModularPipeline.from_pretrained(\"YiYiXu/modular-demo-auto\", components_manager=comp, collection=\"test1\")\n```\n\n</hfoption>\n<hfoption id=\"init_pipeline\">\n\n```py\nfrom diffusers import ComponentsManager\nfrom diffusers.modular_pipelines import SequentialPipelineBlocks\nfrom diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS\n\nt2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)\n\nmodular_repo_id = \"YiYiXu/modular-loader-t2i-0704\"\ncomponents = ComponentsManager()\nt2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)\n```\n\n</hfoption>\n</hfoptions>\n\n组件仅在调用 [`~ModularPipeline.load_components`] 或 [`~ModularPipeline.load_components`] 时加载和注册。以下示例使用 [`~ModularPipeline.load_components`] 创建第二个管道，重用第一个管道的所有组件，并将其分配到不同的集合。\n\n```py\npipe.load_components()\npipe2 = ModularPipeline.from_pretrained(\"YiYiXu/modular-demo-auto\", components_manager=comp, collection=\"test2\")\n```\n\n使用 [`~ModularPipeline.null_component_names`] 属性来识别需要加载的任何组件，使用 [`~ComponentsManager.get_components_by_names`] 检索它们，然后调用 [`~ModularPipeline.update_components`] 来添加缺失的组件。\n\n```py\npipe2.null_component_names \n['text_encoder', 'text_encoder_2', 
'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']\n\ncomp_dict = comp.get_components_by_names(names=pipe2.null_component_names)\npipe2.update_components(**comp_dict)\n```\n\n要添加单个组件，请使用 [`~ComponentsManager.add`] 方法。这会使用唯一 id 注册一个组件。\n\n```py\nfrom diffusers import AutoModel\n\ntext_encoder = AutoModel.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"text_encoder\")\ncomponent_id = comp.add(\"text_encoder\", text_encoder)\ncomp\n```\n\n使用 [`~ComponentsManager.remove`] 通过其 id 移除一个组件。\n\n```py\ncomp.remove(\"text_encoder_139917733042864\")\n```\n\n## 检索组件\n\n[`ComponentsManager`] 提供了几种方法来检索已注册的组件。\n\n### get_one\n\n[`~ComponentsManager.get_one`] 方法返回单个组件，并支持对 `name` 参数进行模式匹配。如果多个组件匹配，[`~ComponentsManager.get_one`] 会返回错误。\n\n| 模式       | 示例                             | 描述                                   |\n|-------------|----------------------------------|-------------------------------------------|\n| exact       | `comp.get_one(name=\"unet\")`      | 精确名称匹配                          |\n| wildcard    | `comp.get_one(name=\"unet*\")`     | 名称以 \"unet\" 开头                |\n| exclusion   | `comp.get_one(name=\"!unet\")`     | 排除名为 \"unet\" 的组件           |\n| or          | `comp.get_one(name=\"unet&#124;vae\")`  | 名称为 \"unet\" 或 \"vae\"                   |\n\n[`~ComponentsManager.get_one`] 还通过 `collection` 参数或 `load_id` 参数过滤组件。\n\n```py\ncomp.get_one(name=\"unet\", collection=\"sdxl\")\n```\n\n### get_components_by_names\n\n[`~ComponentsManager.get_components_by_names`] 方法接受一个名称列表，并返回一个将名称映射到组件的字典。这在 [`ModularPipeline`] 中特别有用，因为它们提供了所需组件名称的列表，并且返回的字典可以直接传递给 [`~ModularPipeline.update_components`]。\n\n```py\ncomponent_dict = comp.get_components_by_names(names=[\"text_encoder\", \"unet\", \"vae\"])\n{\"text_encoder\": component1, \"unet\": component2, \"vae\": component3}\n```\n\n## 重复检测\n\n建议使用 [`ComponentSpec`] 加载模型组件，以分配具有唯一 id 的组件，该 id 编码了它们的加载参数。这允许 [`ComponentsManager`] 
自动检测并防止重复的模型实例，即使不同的对象代表相同的底层检查点。\n\n```py\nfrom diffusers import AutoModel, ComponentSpec, ComponentsManager\nfrom transformers import CLIPTextModel\n\ncomp = ComponentsManager()\n\n# 为第一个文本编码器创建 ComponentSpec\nspec = ComponentSpec(name=\"text_encoder\", repo=\"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"text_encoder\", type_hint=AutoModel)\n# 为重复的文本编码器创建 ComponentSpec（它是相同的检查点，来自相同的仓库/子文件夹）\nspec_duplicated = ComponentSpec(name=\"text_encoder_duplicated\", repo=\"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"text_encoder\", type_hint=CLIPTextModel)\n\n# 加载并添加两个组件 - 管理器会检测到它们是同一个模型\ncomp.add(\"text_encoder\", spec.load())\ncomp.add(\"text_encoder_duplicated\", spec_duplicated.load())\n```\n\n这会返回一个警告，附带移除重复项的说明。\n\n```py\nComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('<component_id>')`.\n'text_encoder_duplicated_139917580682672'\n```\n\n您也可以不使用 [`ComponentSpec`] 添加组件，并且在大多数情况下，即使您以不同名称添加相同组件，重复检测仍然有效。\n\n然而，当您将相同组件加载到不同对象时，[`ComponentsManager`] 无法检测重复项。在这种情况下，您应该使用 [`ComponentSpec`] 加载模型。\n\n```py\ntext_encoder_2 = AutoModel.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"text_encoder\")\ncomp.add(\"text_encoder\", text_encoder_2)\n'text_encoder_139917732983664'\n```\n\n## 集合\n\n集合是为组件分配的标签，用于更好的组织和管理。使用 [`~ComponentsManager.add`] 中的 `collection` 参数将组件添加到集合中。\n\n每个集合中只允许每个名称有一个组件。添加第二个同名组件会自动移除第一个组件。\n\n```py\nfrom diffusers import AutoModel, ComponentSpec, ComponentsManager\n\ncomp = ComponentsManager()\n# 为第一个 UNet 创建 ComponentSpec\nspec = ComponentSpec(name=\"unet\", repo=\"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"unet\", type_hint=AutoModel)\n# 为另一个 UNet 创建 ComponentSpec\nspec2 = ComponentSpec(name=\"unet\", repo=\"RunDiffusion/Juggernaut-XL-v9\", subfolder=\"unet\", 
type_hint=AutoModel, variant=\"fp16\")\n\n# 将两个 UNet 添加到同一个集合 - 第二个将替换第一个\ncomp.add(\"unet\", spec.load(), collection=\"sdxl\")\ncomp.add(\"unet\", spec2.load(), collection=\"sdxl\")\n```\n\n这使得在基于节点的系统中工作变得方便，因为您可以：\n\n- 使用 `collection` 标签标记所有从一个节点加载的模型。\n- 当新检查点以相同名称加载时自动替换模型。\n- 当节点被移除时批量删除集合中的所有模型。\n\n## 卸载\n\n[`~ComponentsManager.enable_auto_cpu_offload`] 方法是一种全局卸载策略，适用于所有模型，无论哪个管道在使用它们。一旦启用，您无需担心设备放置，如果您添加或移除组件。\n\n```py\ncomp.enable_auto_cpu_offload(device=\"cuda\")\n```\n\n所有模型开始时都在 CPU 上，[`ComponentsManager`] 在需要它们之前将它们移动到适当的设备，并在 GPU 内存不足时将其他模型移回 CPU。\n\n您可以设置自己的规则来决定哪些模型要卸载。\n"
  },
  {
    "path": "docs/source/zh/modular_diffusers/loop_sequential_pipeline_blocks.md",
    "content": "<!--版权 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版（\"许可证\"）授权；除非符合许可证，否则不得使用此文件。您可以在\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n获取许可证的副本。\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。请参阅许可证了解\n特定语言下的权限和限制。\n-->\n\n# LoopSequentialPipelineBlocks\n\n[`~modular_pipelines.LoopSequentialPipelineBlocks`] 是一种多块类型，它将其他 [`~modular_pipelines.ModularPipelineBlocks`] 以循环方式组合在一起。数据循环流动，使用 `intermediate_inputs` 和 `intermediate_outputs`，并且每个块都是迭代运行的。这通常用于创建一个默认是迭代的去噪循环。\n\n本指南向您展示如何创建 [`~modular_pipelines.LoopSequentialPipelineBlocks`]。\n\n## 循环包装器\n\n[`~modular_pipelines.LoopSequentialPipelineBlocks`]，也被称为 *循环包装器*，因为它定义了循环结构、迭代变量和配置。在循环包装器内，您需要以下变量。\n\n- `loop_inputs` 是用户提供的值，等同于 [`~modular_pipelines.ModularPipelineBlocks.inputs`]。\n- `loop_intermediate_inputs` 是来自 [`~modular_pipelines.PipelineState`] 的中间变量，等同于 [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`]。\n- `loop_intermediate_outputs` 是由块创建并添加到 [`~modular_pipelines.PipelineState`] 的新中间变量。它等同于 [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`]。\n- `__call__` 方法定义了循环结构和迭代逻辑。\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import LoopSequentialPipelineBlocks, ModularPipelineBlocks, InputParam, OutputParam\n\nclass LoopWrapper(LoopSequentialPipelineBlocks):\n    model_name = \"test\"\n    @property\n    def description(self):\n        return \"I'm a loop!!\"\n    @property\n    def loop_inputs(self):\n        return [InputParam(name=\"num_steps\")]\n    @torch.no_grad()\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        # 循环结构 - 可以根据您的需求定制\n        for i in range(block_state.num_steps):\n            # loop_step 按顺序执行所有注册的块\n            components, block_state = self.loop_step(components, block_state, i=i)\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n循环包装器可以传递额外的参数，如当前迭代索引，到循环块。\n\n## 循环块\n\n循环块是一个 [`~modular_pipelines.ModularPipelineBlocks`]，但 
`__call__` 方法的行为不同。\n\n- 它从循环包装器接收迭代变量（例如 `i`）。\n- 它直接与[`~modular_pipelines.BlockState`]一起工作，而不是[`~modular_pipelines.PipelineState`]。\n- 它不需要检索或更新[`~modular_pipelines.BlockState`]。\n\n循环块共享相同的[`~modular_pipelines.BlockState`]，以允许值在循环的每次迭代中累积和变化。\n\n```py\nclass LoopBlock(ModularPipelineBlocks):\n    model_name = \"test\"\n    @property\n    def inputs(self):\n        return [InputParam(name=\"x\")]\n    @property\n    def intermediate_outputs(self):\n        # 这个块产生的输出\n        return [OutputParam(name=\"x\")]\n    @property\n    def description(self):\n        return \"我是一个在`LoopWrapper`类内部使用的块\"\n    def __call__(self, components, block_state, i: int):\n        block_state.x += 1\n        return components, block_state\n```\n\n## LoopSequentialPipelineBlocks\n\n使用[`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`]方法将循环块添加到循环包装器中，以创建[`~modular_pipelines.LoopSequentialPipelineBlocks`]。\n\n```py\nloop = LoopWrapper.from_blocks_dict({\"block1\": LoopBlock})\n```\n\n使用[`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`]添加更多循环块，让它们在每次迭代中依次运行。这允许您在不改变循环逻辑本身的情况下修改块。\n\n```py\nloop = LoopWrapper.from_blocks_dict({\"block1\": LoopBlock, \"block2\": LoopBlock})\n```\n"
  },
  {
    "path": "docs/source/zh/modular_diffusers/modular_diffusers_states.md",
    "content": "<!--版权 2025 The HuggingFace Team。保留所有权利。\n\n根据Apache许可证2.0版（\"许可证\"）授权；除非符合许可证的规定，否则不得使用此文件。\n您可以在以下网址获取许可证的副本\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件是基于\"按原样\"分发的，没有任何形式的明示或暗示的担保或条件。有关许可证下特定的语言管理权限和限制，请参阅许可证。\n-->\n\n# 状态\n\n块依赖于[`~modular_pipelines.PipelineState`]和[`~modular_pipelines.BlockState`]数据结构进行通信和数据共享。\n\n| 状态 | 描述 |\n|-------|-------------|\n| [`~modular_pipelines.PipelineState`] | 维护管道执行所需的整体数据，并允许块读取和更新其数据。 |\n| [`~modular_pipelines.BlockState`] | 允许每个块使用来自`inputs`的必要数据执行其计算 |\n\n本指南解释了状态如何工作以及它们如何连接块。\n\n## PipelineState\n\n[`~modular_pipelines.PipelineState`]是所有块的全局状态容器。它维护管道的完整运行时状态，并为块提供了一种结构化的方式来读取和写入共享数据。\n\n[`~modular_pipelines.PipelineState`]中有两个字典用于结构化数据。\n\n- `values`字典是一个**可变**状态，包含用户提供的输入值的副本和由块生成的中间输出值。如果一个块修改了一个`input`，它将在调用`set_block_state`后反映在`values`字典中。\n\n```py\nPipelineState(\n  values={\n    'prompt': 'a cat'\n    'guidance_scale': 7.0\n    'num_inference_steps': 25\n    'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))\n    'negative_prompt_embeds': None\n  },\n)\n```\n\n## BlockState\n\n[`~modular_pipelines.BlockState`]是[`~modular_pipelines.PipelineState`]中相关变量的局部视图，单个块需要这些变量来执行其计算。\n\n直接作为属性访问这些变量，如`block_state.image`。\n\n```py\nBlockState(\n    image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494640>\n)\n```\n\n当一个块的`__call__`方法被执行时，它用`self.get_block_state(state)`检索[`BlockState`]，执行其操作，并用`self.set_block_state(state, block_state)`更新[`~modular_pipelines.PipelineState`]。\n\n```py\ndef __call__(self, components, state):\n    # 检索BlockState\n    block_state = self.get_block_state(state)\n\n    # 对输入进行计算的逻辑\n\n    # 更新PipelineState\n    self.set_block_state(state, block_state)\n    return components, state\n```\n\n## 状态交互\n\n[`~modular_pipelines.PipelineState`]和[`~modular_pipelines.BlockState`]的交互由块的`inputs`和`intermediate_outputs`定义。\n\n- `inputs`，一个块可以修改输入（比如 `block_state.image`），并且这个改变可以通过调用 `set_block_state` 全局传播到 
[`~modular_pipelines.PipelineState`]。\n- `intermediate_outputs`，是一个块创建的新变量。它被添加到 [`~modular_pipelines.PipelineState`] 的 `values` 字典中，并且可以作为后续块的可用变量，或者由用户作为管道的最终输出访问。\n"
  },
  {
    "path": "docs/source/zh/modular_diffusers/modular_pipeline.md",
    "content": "<!--版权 2025 HuggingFace 团队。保留所有权利。\n\n根据 Apache 许可证 2.0 版（“许可证”）授权；除非符合许可证的规定，否则不得使用此文件。您可以在\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n获取许可证的副本。\n\n除非适用法律要求或书面同意，根据许可证分发的软件是基于“按原样”分发的，没有任何形式的明示或暗示的保证或条件。有关许可证的特定语言，请参阅许可证。\n-->\n\n# 模块化管道\n\n[`ModularPipeline`] 将 [`~modular_pipelines.ModularPipelineBlocks`] 转换为可执行的管道，加载模型并执行块中定义的计算步骤。它是运行管道的主要接口，与 [`DiffusionPipeline`] API 非常相似。\n\n主要区别在于调用管道时需要传入一个 `output` 参数，用于指定期望返回的输出。\n\n<hfoptions id=\"example\">\n<hfoption id=\"text-to-image\">\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import SequentialPipelineBlocks\nfrom diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS\n\nblocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)\n\nmodular_repo_id = \"YiYiXu/modular-loader-t2i-0704\"\npipeline = blocks.init_pipeline(modular_repo_id)\n\npipeline.load_components(torch_dtype=torch.float16)\npipeline.to(\"cuda\")\n\nimage = pipeline(prompt=\"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\", output=\"images\")[0]\nimage.save(\"modular_t2i_out.png\")\n```\n\n</hfoption>\n<hfoption id=\"image-to-image\">\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import SequentialPipelineBlocks\nfrom diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS\nfrom diffusers.utils import load_image\n\nblocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)\n\nmodular_repo_id = \"YiYiXu/modular-loader-t2i-0704\"\npipeline = blocks.init_pipeline(modular_repo_id)\n\npipeline.load_components(torch_dtype=torch.float16)\npipeline.to(\"cuda\")\n\nurl = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\"\ninit_image = load_image(url)\nprompt = \"a dog catching a frisbee in the jungle\"\nimage = pipeline(prompt=prompt, image=init_image, strength=0.8, output=\"images\")[0]\nimage.save(\"modular_i2i_out.png\")\n```\n\n</hfoption>\n<hfoption id=\"inpainting\">\n\n```py\nimport torch\nfrom 
diffusers.modular_pipelines import SequentialPipelineBlocks\nfrom diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS\nfrom diffusers.utils import load_image\n\nblocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)\n\nmodular_repo_id = \"YiYiXu/modular-loader-t2i-0704\"\npipeline = blocks.init_pipeline(modular_repo_id)\n\npipeline.load_components(torch_dtype=torch.float16)\npipeline.to(\"cuda\")\n\nimg_url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png\"\nmask_url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png\"\n\ninit_image = load_image(img_url)\nmask_image = load_image(mask_url)\n\nprompt = \"A deep sea diver floating\"\nimage = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output=\"images\")[0]\nimage.save(\"modular_inpaint_out.png\")\n```\n\n</hfoption>\n</hfoptions>\n\n本指南将向您展示如何创建一个[`ModularPipeline`]并管理其中的组件。\n\n## 添加块\n\n块是[`InsertableDict`]对象，可以在特定位置插入，提供了一种灵活的方式来混合和匹配块。\n\n使用[`~modular_pipelines.modular_pipeline_utils.InsertableDict.insert`]在块类或`sub_blocks`属性上添加一个块。\n\n```py\n# BLOCKS是块类的字典，您需要向其中添加类\nBLOCKS.insert(\"block_name\", BlockClass, index)\n# sub_blocks属性包含实例，向该属性添加一个块实例\nt2i_blocks.sub_blocks.insert(\"block_name\", block_instance, index)\n```\n\n使用[`~modular_pipelines.modular_pipeline_utils.InsertableDict.pop`]在块类或`sub_blocks`属性上移除一个块。\n\n```py\n# 从预设中移除一个块类\nBLOCKS.pop(\"text_encoder\")\n# 分离出一个块实例\ntext_encoder_block = t2i_blocks.sub_blocks.pop(\"text_encoder\")\n```\n\n通过将现有块设置为新块来交换块。\n\n```py\n# 在预设中替换块类\nBLOCKS[\"prepare_latents\"] = CustomPrepareLatents\n# 使用块实例在sub_blocks属性中替换\nt2i_blocks.sub_blocks[\"prepare_latents\"] = CustomPrepareLatents()\n```\n\n## 创建管道\n\n有两种方法可以创建一个[`ModularPipeline`]。从[`ModularPipelineBlocks`]组装并创建管道，或使用[`~ModularPipeline.from_pretrained`]加载现有管道。\n\n您还应该初始化一个[`ComponentsManager`]来处理设备放置和内存以及组件管理。\n\n> [!TIP]\n> 
有关它如何帮助管理不同工作流中的组件的更多详细信息，请参阅[ComponentsManager](./components_manager)文档。\n\n<hfoptions id=\"create\">\n<hfoption id=\"ModularPipelineBlocks\">\n\n使用[`~ModularPipelineBlocks.init_pipeline`]方法从组件和配置规范创建一个[`ModularPipeline`]。此方法从`modular_model_index.json`文件加载*规范*，但尚未加载*模型*。\n\n```py\nfrom diffusers import ComponentsManager\nfrom diffusers.modular_pipelines import SequentialPipelineBlocks\nfrom diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS\n\nt2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)\n\nmodular_repo_id = \"YiYiXu/modular-loader-t2i-0704\"\ncomponents = ComponentsManager()\nt2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)\n```\n\n</hfoption>\n<hfoption id=\"from_pretrained\">\n\n[`~ModularPipeline.from_pretrained`]方法创建一个[`ModularPipeline`]从Hub上的模块化仓库加载。\n\n```py\nfrom diffusers import ModularPipeline, ComponentsManager\n\ncomponents = ComponentsManager()\npipeline = ModularPipeline.from_pretrained(\"YiYiXu/modular-loader-t2i-0704\", components_manager=components)\n```\n\n添加`trust_remote_code`参数以加载自定义的[`ModularPipeline`]。\n\n```py\nfrom diffusers import ModularPipeline, ComponentsManager\n\ncomponents = ComponentsManager()\nmodular_repo_id = \"YiYiXu/modular-diffdiff-0704\"\ndiffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)\n```\n\n</hfoption>\n</hfoptions>\n\n## 加载组件\n\n一个[`ModularPipeline`]不会自动实例化组件。它只加载配置和组件规范。您可以使用[`~ModularPipeline.load_components`]加载所有组件，或仅使用[`~ModularPipeline.load_components`]加载特定组件。\n\n<hfoptions id=\"load\">\n<hfoption id=\"load_components\">\n\n```py\nimport torch\n\nt2i_pipeline.load_components(torch_dtype=torch.float16)\nt2i_pipeline.to(\"cuda\")\n```\n\n</hfoption>\n<hfoption id=\"load_components\">\n\n下面的例子仅加载UNet和VAE。\n\n```py\nimport torch\n\nt2i_pipeline.load_components(names=[\"unet\", \"vae\"], 
torch_dtype=torch.float16)\n```\n\n</hfoption>\n</hfoptions>\n\n打印管道以检查加载的预训练组件。\n\n```py\nt2i_pipeline\n```\n\n这应该与初始化管道所用的模块化仓库中的`modular_model_index.json`文件匹配。如果管道不需要某个组件，即使它在模块化仓库中存在，也不会被包含。\n\n要修改组件加载的来源，编辑仓库中的`modular_model_index.json`文件，并将其更改为您希望的加载路径。下面的例子从不同的仓库加载UNet。\n\n```json\n# 原始\n\"unet\": [\n  null, null,\n  {\n    \"repo\": \"stabilityai/stable-diffusion-xl-base-1.0\",\n    \"subfolder\": \"unet\",\n    \"variant\": \"fp16\"\n  }\n]\n\n# 修改后\n\"unet\": [\n  null, null,\n  {\n    \"repo\": \"RunDiffusion/Juggernaut-XL-v9\",\n    \"subfolder\": \"unet\",\n    \"variant\": \"fp16\"\n  }\n]\n```\n\n### 组件加载状态\n\n下面的管道属性提供了关于哪些组件被加载的更多信息。\n\n使用`component_names`返回所有预期的组件。\n\n```py\nt2i_pipeline.component_names\n['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']\n```\n\n使用`null_component_names`返回尚未加载的组件。使用[`~ModularPipeline.from_pretrained`]加载这些组件。\n\n```py\nt2i_pipeline.null_component_names\n['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']\n```\n\n使用`pretrained_component_names`返回将从预训练模型加载的组件。\n\n```py\nt2i_pipeline.pretrained_component_names\n['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']\n```\n\n使用 `config_component_names` 返回那些使用默认配置创建的组件（不是从模块化仓库加载的）。来自配置的组件不包括在内，因为它们已经在管道创建期间初始化。这就是为什么它们没有列在 `null_component_names` 中。\n\n```py\nt2i_pipeline.config_component_names\n['guider', 'image_processor']\n```\n\n## 更新组件\n\n组件的更新方式取决于它是*预训练组件*还是*配置组件*。\n\n> [!WARNING]\n> 在更新组件时，组件可能会从预训练变为配置。组件类型最初是在块的 `expected_components` 字段中定义的。\n\n预训练组件通过 [`ComponentSpec`] 更新，而配置组件则通过直接传递对象或使用 [`ComponentSpec`] 更新。\n\n[`ComponentSpec`] 对于预训练组件显示 `default_creation_method=\"from_pretrained\"`，对于配置组件显示 `default_creation_method=\"from_config\"`。\n\n要更新预训练组件，创建一个 [`ComponentSpec`]，指定组件的名称和从哪里加载它。使用 [`~ComponentSpec.load`] 方法来加载组件。\n\n```py\nimport torch\nfrom diffusers import ComponentSpec, UNet2DConditionModel\n\nunet_spec = 
ComponentSpec(name=\"unet\", type_hint=UNet2DConditionModel, repo=\"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"unet\", variant=\"fp16\")\nunet2 = unet_spec.load(torch_dtype=torch.float16)\n```\n\n[`~ModularPipeline.update_components`] 方法用一个新的组件替换原来的组件。\n\n```py\nt2i_pipeline.update_components(unet=unet2)\n```\n\n当组件被更新时，加载规范也会在管道配置中更新。\n\n### 组件提取和修改\n\n当你使用 [`~ComponentSpec.load`] 时，新组件保持其加载规范。这使得提取规范并重新创建组件成为可能。\n\n```py\nspec = ComponentSpec.from_component(\"unet\", unet2)\nspec\nComponentSpec(name='unet', type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')\nunet2_recreated = spec.load(torch_dtype=torch.float16)\n```\n\n[`~ModularPipeline.get_component_spec`] 方法获取当前组件规范的副本以进行修改或更新。\n\n```py\nunet_spec = t2i_pipeline.get_component_spec(\"unet\")\nunet_spec\nComponentSpec(\n    name='unet',\n    type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,\n    pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',\n    subfolder='unet',\n    variant='fp16',\n    default_creation_method='from_pretrained'\n)\n\n# 修改以从不同的仓库加载\nunet_spec.pretrained_model_name_or_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\n\n# 使用修改后的规范加载组件\nunet = unet_spec.load(torch_dtype=torch.float16)\n```\n\n## 模块化仓库\n\n如果管道块使用*预训练组件*，则需要一个仓库。该仓库提供了加载规范和元数据。\n\n[`ModularPipeline`]特别需要*模块化存储库*（参见[示例存储库](https://huggingface.co/YiYiXu/modular-diffdiff)），这比典型的存储库更灵活。它包含一个`modular_model_index.json`文件，包含以下3个元素。\n\n- `library`和`class`显示组件是从哪个库加载的及其类。如果是`null`，则表示组件尚未加载。\n- `loading_specs_dict`包含加载组件所需的信息，例如从中加载的存储库和子文件夹。\n\n与标准存储库不同，模块化存储库可以根据`loading_specs_dict`从不同的存储库获取组件。组件不需要存在于同一个存储库中。\n\n模块化存储库可能包含用于加载[`ModularPipeline`]的自定义代码。这允许您使用不是Diffusers原生的专用块。\n\n```\nmodular-diffdiff-0704/\n├── block.py                    # 自定义管道块实现\n├── config.json        
         # 管道配置和auto_map\n└── modular_model_index.json    # 组件加载规范\n```\n\n[config.json](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json)文件包含一个`auto_map`键，指向`block.py`中定义自定义块的位置。\n\n```json\n{\n  \"_class_name\": \"DiffDiffBlocks\",\n  \"auto_map\": {\n    \"ModularPipelineBlocks\": \"block.DiffDiffBlocks\"\n  }\n}\n```\n"
  },
  {
    "path": "docs/source/zh/modular_diffusers/overview.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据Apache许可证2.0版（\"许可证\"）授权；除非符合许可证，否则不得使用此文件。您可以在以下位置获取许可证的副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。有关许可证下特定语言的权限和限制，请参阅许可证。\n-->\n\n# 概述\n\n> [!WARNING]\n> 模块化Diffusers正在积极开发中，其API可能会发生变化。\n\n模块化Diffusers是一个统一的管道系统，通过*管道块*简化您的工作流程。\n\n- 块是可重用的，您只需要为您的管道创建独特的块。\n- 块可以混合搭配，以适应或为特定工作流程或多个工作流程创建管道。\n\n模块化Diffusers文档的组织如下所示。\n\n## 快速开始\n\n- 一个[快速开始](./quickstart)演示了如何使用模块化Diffusers实现一个示例工作流程。\n\n## ModularPipelineBlocks\n\n- [States](./modular_diffusers_states)解释了数据如何在块和[`ModularPipeline`]之间共享和通信。\n- [ModularPipelineBlocks](./pipeline_block)是[`ModularPipeline`]最基本的单位，本指南向您展示如何创建一个。\n- [SequentialPipelineBlocks](./sequential_pipeline_blocks)是一种类型的块，它将多个块链接起来，使它们一个接一个地运行，沿着链传递数据。本指南向您展示如何创建[`~modular_pipelines.SequentialPipelineBlocks`]以及它们如何连接和一起工作。\n- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks)是一种类型的块，它在循环中运行一系列块。本指南向您展示如何创建[`~modular_pipelines.LoopSequentialPipelineBlocks`]。\n- [AutoPipelineBlocks](./auto_pipeline_blocks)是一种类型的块，它根据输入自动选择要运行的块。本指南向您展示如何创建[`~modular_pipelines.AutoPipelineBlocks`]。\n\n## ModularPipeline\n\n- [ModularPipeline](./modular_pipeline)向您展示如何创建并将管道块转换为可执行的[`ModularPipeline`]。\n- [ComponentsManager](./components_manager)向您展示如何跨多个管道管理和重用组件。\n- [Guiders](./guiders)向您展示如何在管道中使用不同的指导方法。\n"
  },
  {
    "path": "docs/source/zh/modular_diffusers/pipeline_block.md",
    "content": "<!--版权 2025 The HuggingFace Team。保留所有权利。\n\n根据Apache许可证2.0版（“许可证”）授权；除非符合许可证，否则不得使用此文件。您可以在\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n获取许可证的副本。\n\n除非适用法律要求或书面同意，根据许可证分发的软件是基于“按原样”基础分发的，没有任何明示或暗示的保证或条件。请参阅许可证了解特定语言管理权限和限制。\n-->\n\n# ModularPipelineBlocks\n\n[`~modular_pipelines.ModularPipelineBlocks`] 是构建 [`ModularPipeline`] 的基本块。它定义了管道中特定步骤应执行的组件、输入/输出和计算。一个 [`~modular_pipelines.ModularPipelineBlocks`] 与其他块连接，使用 [状态](./modular_diffusers_states)，以实现工作流的模块化构建。\n\n单独的 [`~modular_pipelines.ModularPipelineBlocks`] 无法执行。它是管道中步骤应执行的操作的蓝图。要实际运行和执行管道，需要将 [`~modular_pipelines.ModularPipelineBlocks`] 转换为 [`ModularPipeline`]。\n\n本指南将向您展示如何创建 [`~modular_pipelines.ModularPipelineBlocks`]。\n\n## 输入和输出\n\n> [!TIP]\n> 如果您不熟悉Modular Diffusers中状态的工作原理，请参考 [States](./modular_diffusers_states) 指南。\n\n一个 [`~modular_pipelines.ModularPipelineBlocks`] 需要 `inputs` 和 `intermediate_outputs`。\n\n- `inputs` 是由用户提供并从 [`~modular_pipelines.PipelineState`] 中检索的值。这很有用，因为某些工作流会调整图像大小，但仍需要原始图像。 [`~modular_pipelines.PipelineState`] 维护原始图像。\n\n    使用 `InputParam` 定义 `inputs`。\n\n    ```py\n    from diffusers.modular_pipelines import InputParam\n\n    user_inputs = [\n        InputParam(name=\"image\", type_hint=\"PIL.Image\", description=\"要处理的原始输入图像\")\n    ]\n    ```\n\n- `intermediate_inputs` 通常是由前一个块创建的值，但如果前面的块没有生成它们，也可以直接提供。与 `inputs` 不同，`intermediate_inputs` 可以被修改。\n\n    使用 `InputParam` 定义 `intermediate_inputs`。\n\n    ```py\n    user_intermediate_inputs = [\n        InputParam(name=\"processed_image\", type_hint=\"torch.Tensor\", description=\"image that has been preprocessed and normalized\"),\n    ]\n    ```\n\n- `intermediate_outputs` 是由块创建并添加到 [`~modular_pipelines.PipelineState`] 的新值。`intermediate_outputs` 可作为后续块的 `intermediate_inputs` 使用，或作为运行管道的最终输出使用。\n\n    使用 `OutputParam` 定义 `intermediate_outputs`。\n\n    ```py\n    from diffusers.modular_pipelines import OutputParam\n\n    user_intermediate_outputs = [\n        OutputParam(name=\"image_latents\", 
description=\"latents representing the image\")\n    ]\n    ```\n\n中间输入和输出共享数据以连接块。它们可以在任何时候访问，允许你跟踪工作流的进度。\n\n## 计算逻辑\n\n一个块执行的计算在`__call__`方法中定义，它遵循特定的结构。\n\n1. 检索[`~modular_pipelines.BlockState`]以获取`inputs`和`intermediate_inputs`的局部视图。\n2. 在`inputs`和`intermediate_inputs`上实现计算逻辑。\n3. 更新[`~modular_pipelines.PipelineState`]以将局部[`~modular_pipelines.BlockState`]的更改推送回全局[`~modular_pipelines.PipelineState`]。\n4. 返回对下一个块可用的组件和状态。\n\n```py\ndef __call__(self, components, state):\n    # 获取该块需要的状态变量的局部视图\n    block_state = self.get_block_state(state)\n\n    # 你的计算逻辑在这里\n    # block_state包含你所有的inputs和intermediate_inputs\n    # 像这样访问它们: block_state.image, block_state.processed_image\n\n    # 用你更新的block_states更新管道状态\n    self.set_block_state(state, block_state)\n    return components, state\n```\n\n### 组件和配置\n\n块需要的组件和管道级别的配置在[`ComponentSpec`]和[`~modular_pipelines.ConfigSpec`]中指定。\n\n- [`ComponentSpec`]包含块使用的预期组件。你需要组件的`name`和理想情况下指定组件确切是什么的`type_hint`。\n- [`~modular_pipelines.ConfigSpec`]包含控制所有块行为的管道级别设置。\n\n```py\nfrom diffusers import ComponentSpec, ConfigSpec\n\nexpected_components = [\n    ComponentSpec(name=\"unet\", type_hint=UNet2DConditionModel),\n    ComponentSpec(name=\"scheduler\", type_hint=EulerDiscreteScheduler)\n]\n\nexpected_config = [\n    ConfigSpec(\"force_zeros_for_empty_prompt\", True)\n]\n```\n\n当块被转换为管道时，组件作为`__call__`中的第一个参数对块可用。\n\n```py\ndef __call__(self, components, state):\n    # 使用点符号访问组件\n    unet = components.unet\n    vae = components.vae\n    scheduler = components.scheduler\n```\n"
  },
  {
    "path": "docs/source/zh/modular_diffusers/quickstart.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据Apache许可证2.0版（\"许可证\"）授权；除非符合许可证，否则不得使用此文件。您可以在\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n获取许可证的副本。\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。有关许可证下特定语言的管理权限和限制，请参阅许可证。\n-->\n\n# 快速入门\n\n模块化Diffusers是一个快速构建灵活和可定制管道的框架。模块化Diffusers的核心是[`ModularPipelineBlocks`]，可以与其他块组合以适应新的工作流程。这些块被转换为[`ModularPipeline`]，一个开发者可以使用的友好用户界面。\n\n本文档将向您展示如何使用模块化框架实现[Differential Diffusion](https://differential-diffusion.github.io/)管道。\n\n## ModularPipelineBlocks\n\n[`ModularPipelineBlocks`]是*定义*，指定管道中单个步骤的组件、输入、输出和计算逻辑。有四种类型的块。\n\n- [`ModularPipelineBlocks`]是最基本的单一步骤块。\n- [`SequentialPipelineBlocks`]是一个多块，线性组合其他块。一个块的输出是下一个块的输入。\n- [`LoopSequentialPipelineBlocks`]是一个多块，迭代运行，专为迭代工作流程设计。\n- [`AutoPipelineBlocks`]是一个针对不同工作流程的块集合，它根据输入选择运行哪个块。它旨在方便地将多个工作流程打包到单个管道中。\n\n[Differential Diffusion](https://differential-diffusion.github.io/)是一个图像到图像的工作流程。从`IMAGE2IMAGE_BLOCKS`预设开始，这是一个用于图像到图像生成的`ModularPipelineBlocks`集合。\n\n```py\nfrom diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS\nIMAGE2IMAGE_BLOCKS = InsertableDict([\n    (\"text_encoder\", StableDiffusionXLTextEncoderStep),\n    (\"image_encoder\", StableDiffusionXLVaeEncoderStep),\n    (\"input\", StableDiffusionXLInputStep),\n    (\"set_timesteps\", StableDiffusionXLImg2ImgSetTimestepsStep),\n    (\"prepare_latents\", StableDiffusionXLImg2ImgPrepareLatentsStep),\n    (\"prepare_add_cond\", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),\n    (\"denoise\", StableDiffusionXLDenoiseStep),\n    (\"decode\", StableDiffusionXLDecodeStep)\n])\n```\n\n## 管道和块状态\n\n模块化Diffusers使用*状态*在块之间通信数据。有两种类型的状态。\n\n- [`PipelineState`]是一个全局状态，可用于跟踪所有块的所有输入和输出。\n- [`BlockState`]是[`PipelineState`]中相关变量的局部视图，用于单个块。\n\n## 自定义块\n\n[Differential Diffusion](https://differential-diffusion.github.io/) 与标准的图像到图像转换在其 `prepare_latents` 和 `denoise` 块上有所不同。所有其他块都可以重用，但你需要修改这两个。\n\n通过复制和修改现有的块，为 `prepare_latents` 和 `denoise` 创建占位符 `ModularPipelineBlocks`。\n\n打印 
`denoise` 块，可以看到它由 [`LoopSequentialPipelineBlocks`] 组成，包含三个子块，`before_denoiser`、`denoiser` 和 `after_denoiser`。只需要修改 `before_denoiser` 子块，根据变化图为去噪器准备潜在输入。\n\n```py\ndenoise_blocks = IMAGE2IMAGE_BLOCKS[\"denoise\"]()\nprint(denoise_blocks)\n```\n\n用新的 `SDXLDiffDiffLoopBeforeDenoiser` 块替换 `StableDiffusionXLLoopBeforeDenoiser` 子块。\n\n```py\n# 复制现有块作为占位符\nclass SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):\n    \"\"\"Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later\"\"\"\n    # ... 与 StableDiffusionXLImg2ImgPrepareLatentsStep 相同的实现\n\nclass SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):\n    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]\n    block_names = [\"before_denoiser\", \"denoiser\", \"after_denoiser\"]\n```\n\n### prepare_latents\n\n`prepare_latents` 块需要进行以下更改。\n\n- 一个处理器来处理变化图\n- 一个新的 `inputs` 来接受用户提供的变化图，`timestep` 用于预计算所有潜在变量和 `num_inference_steps` 来创建更新图像区域的掩码\n- 更新 `__call__` 方法中的计算，用于处理变化图和创建掩码，并将其存储在 [`BlockState`] 中\n\n```diff\nclass SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):\n    @property\n    def expected_components(self) -> List[ComponentSpec]:\n        return [\n            ComponentSpec(\"vae\", AutoencoderKL),\n            ComponentSpec(\"scheduler\", EulerDiscreteScheduler),\n+           ComponentSpec(\"mask_processor\", VaeImageProcessor, config=FrozenDict({\"do_normalize\": False, \"do_convert_grayscale\": True}))\n        ]\n    @property\n    def inputs(self) -> List[Tuple[str, Any]]:\n        return [\n            InputParam(\"generator\"),\n+           InputParam(\"diffdiff_map\", required=True),\n-           InputParam(\"latent_timestep\", required=True, type_hint=torch.Tensor),\n+           InputParam(\"timesteps\", type_hint=torch.Tensor),\n+           InputParam(\"num_inference_steps\", type_hint=int),\n        ]\n\n    @property\n    def intermediate_outputs(self) -> List[OutputParam]:\n        return [\n+ 
          OutputParam(\"original_latents\", type_hint=torch.Tensor),\n+           OutputParam(\"diffdiff_masks\", type_hint=torch.Tensor),\n        ]\n    def __call__(self, components, state: PipelineState):\n        # ... existing logic ...\n+       # Process change map and create masks\n+       diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)\n+       thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps\n+       block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))\n+       block_state.original_latents = block_state.latents\n```\n\n### 去噪\n\n`before_denoiser` 子块需要进行以下更改。\n\n- 新的 `inputs` 以接受 `denoising_start` 参数，`original_latents` 和 `diffdiff_masks` 来自 `prepare_latents` 块\n- 更新 `__call__` 方法中的计算以应用 Differential Diffusion\n\n```diff\nclass SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):\n    @property\n    def description(self) -> str:\n        return (\n            \"Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser\"\n        )\n\n    @property\n    def inputs(self) -> List[str]:\n        return [\n            InputParam(\"latents\", required=True, type_hint=torch.Tensor),\n+           InputParam(\"denoising_start\"),\n+           InputParam(\"original_latents\", type_hint=torch.Tensor),\n+           InputParam(\"diffdiff_masks\", type_hint=torch.Tensor),\n        ]\n\n    def __call__(self, components, block_state, i, t):\n+       # Apply differential diffusion logic\n+       if i == 0 and block_state.denoising_start is None:\n+           block_state.latents = block_state.original_latents[:1]\n+       else:\n+           block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)\n+           block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - 
block_state.mask)\n\n        # ... rest of existing logic ...\n```\n\n## 组装块\n\n此时，您应该拥有创建 [`ModularPipeline`] 所需的所有块。\n\n复制现有的 `IMAGE2IMAGE_BLOCKS` 预设，对于 `set_timesteps` 块，使用 `TEXT2IMAGE_BLOCKS` 中的 `set_timesteps`，因为 Differential Diffusion 不需要 `strength` 参数。\n\n将 `prepare_latents` 和 `denoise` 块设置为您刚刚修改的 `SDXLDiffDiffPrepareLatentsStep` 和 `SDXLDiffDiffDenoiseStep` 块。\n\n调用 [`SequentialPipelineBlocks.from_blocks_dict`]，基于这些块创建一个 `SequentialPipelineBlocks`。\n\n```py\nDIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()\nDIFFDIFF_BLOCKS[\"set_timesteps\"] = TEXT2IMAGE_BLOCKS[\"set_timesteps\"]\nDIFFDIFF_BLOCKS[\"prepare_latents\"] = SDXLDiffDiffPrepareLatentsStep\nDIFFDIFF_BLOCKS[\"denoise\"] = SDXLDiffDiffDenoiseStep\n\ndd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)\nprint(dd_blocks)\n```\n\n## ModularPipeline\n\n使用 [`ModularPipeline.init_pipeline`] 方法将 [`SequentialPipelineBlocks`] 转换为 [`ModularPipeline`]。这会初始化从 `modular_model_index.json` 文件加载的预期组件。调用 [`ModularPipeline.load_default_components`] 来显式加载这些组件。\n\n初始化pipeline时传入一个[`ComponentsManager`]是一个好主意，以帮助管理不同的组件。一旦调用[`~ModularPipeline.load_components`]，组件就会被注册到[`ComponentsManager`]中，并且可以在工作流之间共享。下面的例子使用`collection`参数为组件分配了一个`\"diffdiff\"`标签，以便更好地组织。\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import ComponentsManager\n\ncomponents = ComponentsManager()\n\ndd_pipeline = dd_blocks.init_pipeline(\"YiYiXu/modular-demo-auto\", components_manager=components, collection=\"diffdiff\")\ndd_pipeline.load_default_components(torch_dtype=torch.float16)\ndd_pipeline.to(\"cuda\")\n```\n\n## 添加工作流\n\n可以向[`ModularPipeline`]添加其他工作流以支持更多功能，而无需从头重写整个pipeline。\n\n本节演示如何添加IP-Adapter或ControlNet。\n\n### IP-Adapter\n\nStable Diffusion XL已经有一个预设的IP-Adapter块，你可以直接使用，并且不需要对现有的Differential Diffusion pipeline进行任何更改。\n\n```py\nfrom diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep\n\nip_adapter_block = 
StableDiffusionXLAutoIPAdapterStep()\n```\n\n使用[`sub_blocks.insert`]方法将其插入到[`ModularPipeline`]中。下面的例子在位置`0`插入了`ip_adapter_block`。打印pipeline可以看到`ip_adapter_block`被添加了，并且它需要一个`ip_adapter_image`。这也向pipeline添加了两个组件，`image_encoder`和`feature_extractor`。\n\n```py\ndd_blocks.sub_blocks.insert(\"ip_adapter\", ip_adapter_block, 0)\n```\n\n调用[`~ModularPipeline.init_pipeline`]来初始化一个[`ModularPipeline`]，并使用[`~ModularPipeline.load_components`]加载模型组件。加载并设置IP-Adapter以运行pipeline。\n\n```py\nfrom diffusers.utils import load_image\n\ndevice = \"cuda\"\n\ndd_pipeline = dd_blocks.init_pipeline(\"YiYiXu/modular-demo-auto\", collection=\"diffdiff\")\ndd_pipeline.load_components(torch_dtype=torch.float16)\ndd_pipeline.loader.load_ip_adapter(\"h94/IP-Adapter\", subfolder=\"sdxl_models\", weight_name=\"ip-adapter_sdxl.bin\")\ndd_pipeline.loader.set_ip_adapter_scale(0.6)\ndd_pipeline = dd_pipeline.to(device)\n\nip_adapter_image = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg\")\nimage = load_image(\"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true\")\nmask = load_image(\"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true\")\n\nprompt = \"a green pear\"\nnegative_prompt = \"blurry\"\ngenerator = torch.Generator(device=device).manual_seed(42)\n\nimage = dd_pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_inference_steps=25,\n    generator=generator,\n    ip_adapter_image=ip_adapter_image,\n    diffdiff_map=mask,\n    image=image,\n    output=\"images\"\n)[0]\n```\n\n### ControlNet\n\nStable Diffusion XL 已经预设了一个可以立即使用的 ControlNet 块。\n\n```py\nfrom diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep\n\ncontrol_input_block = StableDiffusionXLAutoControlNetInputStep()\n```\n\n然而，它需要修改 `denoise` 块，因为那是 ControlNet 将控制信息注入到 UNet 的地方。\n\n通过将 `StableDiffusionXLLoopDenoiser` 子块替换为 
`StableDiffusionXLControlNetLoopDenoiser` 来修改 `denoise` 块。\n\n```py\nclass SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):\n    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]\n    block_names = [\"before_denoiser\", \"denoiser\", \"after_denoiser\"]\n\ncontrolnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()\n```\n\n插入 `controlnet_input` 块并用新的 `controlnet_denoise_block` 替换 `denoise` 块。初始化一个 [`ModularPipeline`] 并将 [`~ModularPipeline.load_components`] 加载到其中。\n\n```py\ndd_blocks.sub_blocks.insert(\"controlnet_input\", control_input_block, 7)\ndd_blocks.sub_blocks[\"denoise\"] = controlnet_denoise_block\n\ndd_pipeline = dd_blocks.init_pipeline(\"YiYiXu/modular-demo-auto\", collection=\"diffdiff\")\ndd_pipeline.load_components(torch_dtype=torch.float16)\ndd_pipeline = dd_pipeline.to(device)\n\ncontrol_image = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg\")\nimage = load_image(\"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true\")\nmask = load_image(\"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true\")\n\nprompt = \"a green pear\"\nnegative_prompt = \"blurry\"\ngenerator = torch.Generator(device=device).manual_seed(42)\n\nimage = dd_pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_inference_steps=25,\n    generator=generator,\n    control_image=control_image,\n    controlnet_conditioning_scale=0.5,\n    diffdiff_map=mask,\n    image=image,\n    output=\"images\"\n)[0]\n```\n\n### AutoPipelineBlocks\n\n差分扩散、IP-Adapter 和 ControlNet 工作流可以通过使用 [`AutoPipelineBlocks`] 捆绑到一个单一的 [`ModularPipeline`] 中。这允许根据输入如 `control_image` 或 `ip_adapter_image` 自动选择要运行的子块。如果没有传递这些输入，则默认为差分扩散。\n\n使用 `block_trigger_inputs` 仅在提供 `control_image` 输入时运行 
`SDXLDiffDiffControlNetDenoiseStep` 块。否则，使用 `SDXLDiffDiffDenoiseStep`。\n\n```py\nclass SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):\n    block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]\n    block_names = [\"controlnet_denoise\", \"denoise\"]\n    block_trigger_inputs = [\"controlnet_cond\", None]\n```\n\n添加 `ip_adapter` 和 `controlnet_input` 块。\n\n```py\nDIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()\nDIFFDIFF_AUTO_BLOCKS[\"prepare_latents\"] = SDXLDiffDiffPrepareLatentsStep\nDIFFDIFF_AUTO_BLOCKS[\"set_timesteps\"] = TEXT2IMAGE_BLOCKS[\"set_timesteps\"]\nDIFFDIFF_AUTO_BLOCKS[\"denoise\"] = SDXLDiffDiffAutoDenoiseStep\nDIFFDIFF_AUTO_BLOCKS.insert(\"ip_adapter\", StableDiffusionXLAutoIPAdapterStep, 0)\nDIFFDIFF_AUTO_BLOCKS.insert(\"controlnet_input\", StableDiffusionXLControlNetAutoInput, 7)\n```\n\n调用 [`SequentialPipelineBlocks.from_blocks_dict`] 来创建一个 [`SequentialPipelineBlocks`]，然后创建一个 [`ModularPipeline`] 并加载模型组件以运行。\n\n```py\ndd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)\ndd_pipeline = dd_auto_blocks.init_pipeline(\"YiYiXu/modular-demo-auto\", collection=\"diffdiff\")\ndd_pipeline.load_components(torch_dtype=torch.float16)\n```\n\n## 分享\n\n使用 [`~ModularPipeline.save_pretrained`] 将您的 [`ModularPipeline`] 添加到 Hub，并将 `push_to_hub` 参数设置为 `True`。\n\n```py\ndd_pipeline.save_pretrained(\"YiYiXu/test_modular_doc\", push_to_hub=True)\n```\n\n其他用户可以使用 [`~ModularPipeline.from_pretrained`] 加载 [`ModularPipeline`]。\n\n```py\nimport torch\nfrom diffusers.modular_pipelines import ModularPipeline, ComponentsManager\n\ncomponents = ComponentsManager()\n\ndiffdiff_pipeline = ModularPipeline.from_pretrained(\"YiYiXu/modular-diffdiff-0704\", trust_remote_code=True, components_manager=components, collection=\"diffdiff\")\ndiffdiff_pipeline.load_components(torch_dtype=torch.float16)\n```\n"
  },
  {
    "path": "docs/source/zh/modular_diffusers/sequential_pipeline_blocks.md",
    "content": "<!--版权 2025 The HuggingFace Team。保留所有权利。\n\n根据Apache许可证2.0版（\"许可证\"）授权；除非符合许可证，否则不得使用此文件。您可以在\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n获取许可证的副本。\n\n除非适用法律要求或书面同意，根据许可证分发的软件是基于\"按原样\"基础分发的，没有任何形式的明示或暗示的保证或条件。有关许可证下特定语言的管理权限和限制，请参阅许可证。\n-->\n\n# 顺序管道块\n\n[`~modular_pipelines.SequentialPipelineBlocks`] 是一种多块类型，它将其他 [`~modular_pipelines.ModularPipelineBlocks`] 按顺序组合在一起。数据通过 `intermediate_inputs` 和 `intermediate_outputs` 线性地从一个块流向下一个块。[`~modular_pipelines.SequentialPipelineBlocks`] 中的每个块通常代表管道中的一个步骤，通过组合它们，您逐步构建一个管道。\n\n本指南向您展示如何将两个块连接成一个 [`~modular_pipelines.SequentialPipelineBlocks`]。\n\n创建两个 [`~modular_pipelines.ModularPipelineBlocks`]。第一个块 `InputBlock` 输出一个 `batch_size` 值，第二个块 `ImageEncoderBlock` 使用 `batch_size` 作为 `intermediate_inputs`。\n\n<hfoptions id=\"sequential\">\n<hfoption id=\"InputBlock\">\n\n```py\nfrom diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam\n\nclass InputBlock(ModularPipelineBlocks):\n\n    @property\n    def inputs(self):\n        return [\n            InputParam(name=\"prompt\", type_hint=list, description=\"list of text prompts\"),\n            InputParam(name=\"num_images_per_prompt\", type_hint=int, description=\"number of images per prompt\"),\n        ]\n\n    @property\n    def intermediate_outputs(self):\n        return [\n            OutputParam(name=\"batch_size\", description=\"calculated batch size\"),\n        ]\n\n    @property\n    def description(self):\n        return \"A block that determines batch_size based on the number of prompts and num_images_per_prompt argument.\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        batch_size = len(block_state.prompt)\n        block_state.batch_size = batch_size * block_state.num_images_per_prompt\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n</hfoption>\n<hfoption id=\"ImageEncoderBlock\">\n\n```py\nimport torch\nfrom 
diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam\n\nclass ImageEncoderBlock(ModularPipelineBlocks):\n\n    @property\n    def inputs(self):\n        return [\n            InputParam(name=\"image\", type_hint=\"PIL.Image\", description=\"raw input image to process\"),\n            InputParam(name=\"batch_size\", type_hint=int),\n        ]\n\n    @property\n    def intermediate_outputs(self):\n        return [\n            OutputParam(name=\"image_latents\", description=\"latents representing the image\"),\n        ]\n\n    @property\n    def description(self):\n        return \"Encode raw image into its latent representation\"\n\n    def __call__(self, components, state):\n        block_state = self.get_block_state(state)\n        # 模拟处理图像\n        # 这将改变所有块的图像状态，从PIL图像变为张量\n        block_state.image = torch.randn(1, 3, 512, 512)\n        block_state.batch_size = block_state.batch_size * 2\n        block_state.image_latents = torch.randn(1, 4, 64, 64)\n        self.set_block_state(state, block_state)\n        return components, state\n```\n\n</hfoption>\n</hfoptions>\n\n通过定义一个[`InsertableDict`]来连接两个块，将块名称映射到块实例。块按照它们在`blocks_dict`中注册的顺序执行。\n\n使用[`~modular_pipelines.SequentialPipelineBlocks.from_blocks_dict`]来创建一个[`~modular_pipelines.SequentialPipelineBlocks`]。\n\n```py\nfrom diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict\n\n# 实例化上面定义的两个块\ninput_block = InputBlock()\nimage_encoder_block = ImageEncoderBlock()\n\nblocks_dict = InsertableDict()\nblocks_dict[\"input\"] = input_block\nblocks_dict[\"image_encoder\"] = image_encoder_block\n\nblocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)\n```\n\n通过打印`blocks`来检查[`~modular_pipelines.SequentialPipelineBlocks`]中的子块，要获取更多关于输入和输出的详细信息，可以访问`doc`属性。\n\n```py\nprint(blocks)\nprint(blocks.doc)\n```\n"
  },
  {
    "path": "docs/source/zh/optimization/cache.md",
    "content": "<!-- 版权所有 2025 HuggingFace 团队。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（“许可证”）授权；除非符合许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，否则根据许可证分发的软件按“原样”分发，不附带任何明示或暗示的担保或条件。请参阅许可证以了解具体的语言管理权限和限制。 -->\n\n# 缓存\n\n缓存通过存储和重用不同层的中间输出（如注意力层和前馈层）来加速推理，而不是在每个推理步骤执行整个计算。它显著提高了生成速度，但以更多内存为代价，并且不需要额外的训练。\n\n本指南向您展示如何在 Diffusers 中使用支持的缓存方法。\n\n## 金字塔注意力广播\n\n[金字塔注意力广播 (PAB)](https://huggingface.co/papers/2408.12588) 基于这样一种观察：在生成过程的连续时间步之间，注意力输出差异不大。注意力差异在交叉注意力层中最小，并且通常在一个较长的时间步范围内被缓存。其次是时间注意力和空间注意力层。\n\n> [!TIP]\n> 并非所有视频模型都有三种类型的注意力（交叉、时间和空间）！\n\nPAB 可以与其他技术（如序列并行性和无分类器引导并行性（数据并行性））结合，实现近乎实时的视频生成。\n\n设置并传递一个 [`PyramidAttentionBroadcastConfig`] 到管道的变换器以启用它。`spatial_attention_block_skip_range` 控制跳过空间注意力块中注意力计算的频率，`spatial_attention_timestep_skip_range` 是要跳过的时间步范围。注意选择一个合适的范围，因为较小的间隔可能导致推理速度变慢，而较大的间隔可能导致生成质量降低。\n\n```python\nimport torch\nfrom diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig\n\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.bfloat16)\npipeline.to(\"cuda\")\n\nconfig = PyramidAttentionBroadcastConfig(\n    spatial_attention_block_skip_range=2,\n    spatial_attention_timestep_skip_range=(100, 800),\n    current_timestep_callback=lambda: pipeline.current_timestep,\n)\npipeline.transformer.enable_cache(config)\n```\n\n## FasterCache\n\n[FasterCache](https://huggingface.co/papers/2410.19355) 缓存并重用注意力特征，类似于 [PAB](#pyramid-attention-broadcast)，因为每个连续时间步的输出差异很小。\n\n在使用无分类器引导进行采样时（在大多数基础模型中常见），如果连续时间步之间的预测潜在输出存在显著冗余，此方法还可以跳过无条件分支预测，并根据条件分支预测来估计它。\n\n设置并将 [`FasterCacheConfig`] 传递给管道的 transformer 以启用它。\n\n```python\nimport torch\nfrom diffusers import CogVideoXPipeline, FasterCacheConfig\n\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.bfloat16)\npipeline.to(\"cuda\")\n\nconfig = FasterCacheConfig(\n    spatial_attention_block_skip_range=2,\n    spatial_attention_timestep_skip_range=(-1, 681),\n    
current_timestep_callback=lambda: pipeline.current_timestep,\n    attention_weight_callback=lambda _: 0.3,\n    unconditional_batch_skip_range=5,\n    unconditional_batch_timestep_skip_range=(-1, 781),\n    tensor_format=\"BFCHW\",\n)\npipeline.transformer.enable_cache(config)\n```"
  },
  {
    "path": "docs/source/zh/optimization/coreml.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（\"许可证\"）授权；除非符合许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。有关许可证的具体语言，请参阅许可证中的权限和限制。\n-->\n\n# 如何使用 Core ML 运行 Stable Diffusion\n\n[Core ML](https://developer.apple.com/documentation/coreml) 是 Apple 框架支持的模型格式和机器学习库。如果您有兴趣在 macOS 或 iOS/iPadOS 应用中运行 Stable Diffusion 模型，本指南将展示如何将现有的 PyTorch 检查点转换为 Core ML 格式，并使用 Python 或 Swift 进行推理。\n\nCore ML 模型可以利用 Apple 设备中所有可用的计算引擎：CPU、GPU 和 Apple Neural Engine（或 ANE，一种在 Apple Silicon Mac 和现代 iPhone/iPad 中可用的张量优化加速器）。根据模型及其运行的设备，Core ML 还可以混合和匹配计算引擎，例如，模型的某些部分可能在 CPU 上运行，而其他部分在 GPU 上运行。\n\n> [!TIP]\n> 您还可以使用 PyTorch 内置的 `mps` 加速器在 Apple Silicon Mac 上运行 `diffusers` Python 代码库。这种方法在 [mps 指南](mps) 中有详细解释，但它与原生应用不兼容。\n\n## Stable Diffusion Core ML 检查点\n\nStable Diffusion 权重（或检查点）以 PyTorch 格式存储，因此在使用它们之前，需要将它们转换为 Core ML 格式。\n\n幸运的是，Apple 工程师基于 `diffusers` 开发了 [一个转换工具](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml)，用于将 PyTorch 检查点转换为 Core ML。\n\n但在转换模型之前，花点时间探索 Hugging Face Hub——很可能您感兴趣的模型已经以 Core ML 格式提供：\n\n- [Apple](https://huggingface.co/apple) 组织包括 Stable Diffusion 版本 1.4、1.5、2.0 基础和 2.1 基础\n- [coreml community](https://huggingface.co/coreml-community) 包括自定义微调模型\n- 使用此 [过滤器](https://huggingface.co/models?pipeline_tag=text-to-image&library=coreml&p=2&sort=likes) 返回所有可用的 Core ML 检查点\n\n如果您找不到感兴趣的模型，我们建议您遵循 Apple 的 [Converting Models to Core ML](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) 说明。\n\n## 选择要使用的 Core ML 变体\n\nStable Diffusion 模型可以转换为不同的 Core ML 变体，用于不同目的：\n\n- 所使用的注意力块类型。注意力操作用于“关注”图像表示中不同区域之间的关系，并理解图像和文本表示如何相关。注意力的计算和内存消耗很大，因此存在不同的实现方式，以适应不同设备的硬件特性。对于Core ML Stable Diffusion模型，有两种注意力变体：\n* `split_einsum`（[由Apple引入](https://machinelearning.apple.com/research/neural-engine-transformers)）针对ANE设备进行了优化，这些设备在现代iPhone、iPad和M系列计算机中可用。\n* “原始”注意力（在`diffusers`中使用的基础实现）仅与CPU/GPU兼容，不与ANE兼容。在CPU + 
GPU上使用`original`注意力运行模型可能比ANE*更快*。请参阅[此性能基准](https://huggingface.co/blog/fast-mac-diffusers#performance-benchmarks)以及社区提供的[一些额外测量](https://github.com/huggingface/swift-coreml-diffusers/issues/31)以获取更多细节。\n\n- 支持的推理框架。\n* `packages`适用于Python推理。这可用于在尝试将转换后的Core ML模型集成到原生应用程序之前进行测试，或者如果您想探索Core ML性能但不需要支持原生应用程序。例如，具有Web UI的应用程序完全可以使用Python Core ML后端。\n* `compiled`模型是Swift代码所必需的。Hub中的`compiled`模型将大型UNet模型权重分成多个文件，以兼容iOS和iPadOS设备。这对应于[`--chunk-unet`转换选项](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml)。如果您想支持原生应用程序，则需要选择`compiled`变体。\n\n官方的Core ML Stable Diffusion[模型](https://huggingface.co/apple/coreml-stable-diffusion-v1-4/tree/main)包括这些变体，但社区的可能有所不同：\n\n```\ncoreml-stable-diffusion-v1-4\n├── README.md\n├── original\n│   ├── compiled\n│   └── packages\n└── split_einsum\n    ├── compiled\n    └── packages\n```\n\n您可以下载并使用所需的变体，如下所示。\n\n## Python中的Core ML推理\n\n安装以下库以在Python中运行Core ML推理：\n\n```bash\npip install huggingface_hub\npip install git+https://github.com/apple/ml-stable-diffusion\n```\n\n### 下载模型检查点\n\n要在Python中运行推理，请使用存储在`packages`文件夹中的版本之一，因为`compiled`版本仅与Swift兼容。您可以选择使用`original`或`split_einsum`注意力。\n\n以下示例展示如何从Hub下载`original`注意力变体到一个名为`models`的目录：\n\n```Python\nfrom huggingface_hub import snapshot_download\nfrom pathlib import Path\n\nrepo_id = \"apple/coreml-stable-diffusion-v1-4\"\nvariant = \"original/packages\"\n\nmodel_path = Path(\"./models\") / (repo_id.split(\"/\")[-1] + \"_\" + variant.replace(\"/\", \"_\"))\nsnapshot_download(repo_id, allow_patterns=f\"{variant}/*\", local_dir=model_path, local_dir_use_symlinks=False)\nprint(f\"Model downloaded at {model_path}\")\n```\n\n### 推理[[python-inference]]\n\n下载模型快照后，您可以使用 Apple 的 Python 脚本来测试它。\n\n```shell\npython -m python_coreml_stable_diffusion.pipeline --prompt \"a photo of an astronaut riding a horse on mars\" -i ./models/coreml-stable-diffusion-v1-4_original_packages/original/packages -o </path/to/output/image> --compute-unit CPU_AND_GPU --seed 93\n```\n\n使用 `-i` 
标志将下载的检查点路径传递给脚本。`--compute-unit` 表示您希望允许用于推理的硬件。它必须是以下选项之一：`ALL`、`CPU_AND_GPU`、`CPU_ONLY`、`CPU_AND_NE`。您也可以提供可选的输出路径和用于可重现性的种子。\n\n推理脚本假设您使用的是 Stable Diffusion 模型的原始版本，`CompVis/stable-diffusion-v1-4`。如果您使用另一个模型，您*必须*在推理命令行中使用 `--model-version` 选项指定其 Hub ID。这适用于已支持的模型以及您自己训练或微调的自定义模型。\n\n例如，如果您想使用 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)：\n\n```shell\npython -m python_coreml_stable_diffusion.pipeline --prompt \"a photo of an astronaut riding a horse on mars\" --compute-unit ALL -o output --seed 93 -i models/coreml-stable-diffusion-v1-5_original_packages --model-version stable-diffusion-v1-5/stable-diffusion-v1-5\n```\n\n## Core ML 在 Swift 中的推理\n\n在 Swift 中运行推理比在 Python 中稍快，因为模型已经以 `mlmodelc` 格式编译。这在应用启动时加载模型时很明显，但如果在之后运行多次生成，则不应明显。\n\n### 下载\n\n要在您的 Mac 上运行 Swift 推理，您需要一个 `compiled` 检查点版本。我们建议您使用类似于先前示例的 Python 代码在本地下载它们，但使用 `compiled` 变体之一：\n\n```Python\nfrom huggingface_hub import snapshot_download\nfrom pathlib import Path\n\nrepo_id = \"apple/coreml-stable-diffusion-v1-4\"\nvariant = \"original/compiled\"\n\nmodel_path = Path(\"./models\") / (repo_id.split(\"/\")[-1] + \"_\" + variant.replace(\"/\", \"_\"))\nsnapshot_download(repo_id, allow_patterns=f\"{variant}/*\", local_dir=model_path, local_dir_use_symlinks=False)\nprint(f\"Model downloaded at {model_path}\")\n```\n\n### 推理[[swift-inference]]\n\n要运行推理，请克隆 Apple 的仓库：\n\n```bash\ngit clone https://github.com/apple/ml-stable-diffusion\ncd ml-stable-diffusion\n```\n\n然后使用 Apple 的命令行工具，[Swift Package Manager](https://www.swift.org/package-manager/#)：\n\n```bash\nswift run StableDiffusionSample --resource-path models/coreml-stable-diffusion-v1-4_original_compiled --compute-units all \"a photo of an astronaut riding a horse on mars\"\n```\n\n您必须在 `--resource-path` 中指定上一步下载的检查点之一，请确保它包含扩展名为 `.mlmodelc` 的已编译 Core ML 包。`--compute-units` 必须是以下值之一：`all`、`cpuOnly`、`cpuAndGPU`、`cpuAndNeuralEngine`。\n\n有关更多详细信息，请参考 [Apple 
仓库中的说明](https://github.com/apple/ml-stable-diffusion)。\n\n## 支持的 Diffusers 功能\n\nCore ML 模型和推理代码不支持 🧨 Diffusers 的许多功能、选项和灵活性。以下是一些需要注意的限制：\n\n- Core ML 模型仅适用于推理。它们不能用于训练或微调。\n- 只有两个调度器已移植到 Swift：Stable Diffusion 使用的默认调度器和我们从 `diffusers` 实现移植到 Swift 的 `DPMSolverMultistepScheduler`。我们推荐您使用 `DPMSolverMultistepScheduler`，因为它在约一半的步骤中产生相同的质量。\n- 负面提示、无分类器引导尺度和图像到图像任务在推理代码中可用。高级功能如深度引导、ControlNet 和潜在上采样器尚不可用。\n\nApple 的 [转换和推理仓库](https://github.com/apple/ml-stable-diffusion) 和我们自己的 [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) 仓库旨在作为技术演示，以帮助其他开发者在此基础上构建。\n\n如果您对任何缺失功能有强烈需求，请随时提交功能请求或更好的是，贡献一个 PR 🙂。\n\n## 原生 Diffusers Swift 应用\n\n一个简单的方法来在您自己的 Apple 硬件上运行 Stable Diffusion 是使用 [我们的开源 Swift 仓库](https://github.com/huggingface/swift-coreml-diffusers)，它基于 `diffusers` 和 Apple 的转换和推理仓库。您可以研究代码，使用 [Xcode](https://developer.apple.com/xcode/) 编译它，并根据您的需求进行适配。为了方便，[App Store 中还有一个独立 Mac 应用](https://apps.apple.com/app/diffusers/id1666309574)，因此您无需处理代码或 IDE 即可使用它。如果您是开发者，并已确定 Core ML 是构建您的 Stable Diffusion 应用的最佳解决方案，那么您可以使用本指南的其余部分来开始您的项目。我们迫不及待想看看您会构建什么 🙂。"
  },
  {
    "path": "docs/source/zh/optimization/deepcache.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（\"许可证\"）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，否则根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。有关许可证的具体语言，请参阅许可证中的权限和限制。\n-->\n\n# DeepCache\n[DeepCache](https://huggingface.co/papers/2312.00858) 通过策略性地缓存和重用高级特征，同时利用 U-Net 架构高效更新低级特征，来加速 [`StableDiffusionPipeline`] 和 [`StableDiffusionXLPipeline`]。\n\n首先安装 [DeepCache](https://github.com/horseee/DeepCache)：\n```bash\npip install DeepCache\n```\n\n然后加载并启用 [`DeepCacheSDHelper`](https://github.com/horseee/DeepCache#usage)：\n\n```diff\n  import torch\n  from diffusers import StableDiffusionPipeline\n  pipe = StableDiffusionPipeline.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5', torch_dtype=torch.float16).to(\"cuda\")\n\n+ from DeepCache import DeepCacheSDHelper\n+ helper = DeepCacheSDHelper(pipe=pipe)\n+ helper.set_params(\n+     cache_interval=3,\n+     cache_branch_id=0,\n+ )\n+ helper.enable()\n\n  image = pipe(\"a photo of an astronaut on a moon\").images[0]\n```\n\n`set_params` 方法接受两个参数：`cache_interval` 和 `cache_branch_id`。`cache_interval` 表示特征缓存的频率，指定为每次缓存操作之间的步数。`cache_branch_id` 标识网络的哪个分支（从最浅层到最深层排序）负责执行缓存过程。\n选择较低的 `cache_branch_id` 或较大的 `cache_interval` 可以加快推理速度，但会降低图像质量（这些超参数的消融实验可以在[论文](https://huggingface.co/papers/2312.00858)中找到）。一旦设置了这些参数，使用 `enable` 或 `disable` 方法来激活或停用 `DeepCacheSDHelper`。\n\n<div class=\"flex justify-center\">\n    <img src=\"https://github.com/horseee/Diffusion_DeepCache/raw/master/static/images/example.png\">\n</div>\n\n您可以在 [WandB 报告](https://wandb.ai/horseee/DeepCache/runs/jwlsqqgt?workspace=user-horseee) 中找到更多生成的样本（原始管道 vs DeepCache）和相应的推理延迟。提示是从 [MS-COCO 2017](https://cocodataset.org/#home) 数据集中随机选择的。\n\n## 基准测试\n\n我们在 NVIDIA RTX A5000 上测试了 DeepCache 使用 50 个推理步骤加速 [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) 的速度，使用不同的配置，包括分辨率、批处理大小、缓存间隔（I）和缓存分支(B)。\n\n| **分辨率** | **批次大小** | **原始** | 
**DeepCache(I=3, B=0)** | **DeepCache(I=5, B=0)** | **DeepCache(I=5, B=1)** |\n|----------------|----------------|--------------|-------------------------|-------------------------|-------------------------|\n|             512|               8|         15.96|              6.88(2.32倍)|              5.03(3.18倍)|              7.27(2.20倍)|\n|                |               4|          8.39|              3.60(2.33倍)|              2.62(3.21倍)|              3.75(2.24倍)|\n|                |               1|          2.61|              1.12(2.33倍)|              0.81(3.24倍)|              1.11(2.35倍)|\n|             768|               8|         43.58|             18.99(2.29倍)|             13.96(3.12倍)|             21.27(2.05倍)|\n|                |               4|         22.24|              9.67(2.30倍)|              7.10(3.13倍)|             10.74(2.07倍)|\n|                |               1|          6.33|              2.72(2.33倍)|              1.97(3.21倍)|              2.98(2.12倍)|\n|            1024|               8|        101.95|             45.57(2.24倍)|             33.72(3.02倍)|             53.00(1.92倍)|\n|                |               4|         49.25|             21.86(2.25倍)|             16.19(3.04倍)|             25.78(1.91倍)|\n|                |               1|         13.83|              6.07(2.28倍)|              4.43(3.12倍)|              7.15(1.93倍)|"
  },
  {
    "path": "docs/source/zh/optimization/fp16.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 加速推理\n\nDiffusion模型在推理时速度较慢，因为生成是一个迭代过程，需要经过一定数量的\"步数\"逐步将噪声细化为图像或视频。要加速这一过程，您可以尝试使用不同的[调度器](../api/schedulers/overview)、降低模型权重的精度以加快计算、使用更高效的内存注意力机制等方法。\n\n将这些技术组合使用，可以比单独使用任何一种技术获得更快的推理速度。\n\n本指南将介绍如何加速推理。\n\n## 模型数据类型\n\n模型权重的精度和数据类型会影响推理速度，因为更高的精度需要更多内存来加载，也需要更多时间进行计算。PyTorch默认以float32或全精度加载模型权重，因此更改数据类型是快速获得更快推理速度的简单方法。\n\n<hfoptions id=\"dtypes\">\n<hfoption id=\"bfloat16\">\n\nbfloat16与float16类似，但对数值误差更稳健。硬件对bfloat16的支持各不相同，但大多数现代GPU都能支持bfloat16。\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\n</hfoption>\n<hfoption id=\"float16\">\n\nfloat16与bfloat16类似，但可能更容易出现数值误差。\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\n</hfoption>\n<hfoption id=\"TensorFloat-32\">\n\n[TensorFloat-32 
(tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)模式在NVIDIA Ampere GPU上受支持，它以tf32计算卷积和矩阵乘法运算。存储和其他操作保持在float32。与bfloat16或float16结合使用时，可以显著加快计算速度。\n\nPyTorch默认仅对卷积启用tf32模式，您需要显式启用矩阵乘法的tf32模式。\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\ntorch.backends.cuda.matmul.allow_tf32 = True\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\n更多详情请参阅[混合精度训练](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#mixed-precision)文档。\n\n</hfoption>\n</hfoptions>\n\n## 缩放点积注意力\n\n> [!TIP]\n> 内存高效注意力优化了推理速度*和*[内存使用](./memory#memory-efficient-attention)！\n\n[缩放点积注意力（SDPA）](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)实现了多种注意力后端，包括[FlashAttention](https://github.com/Dao-AILab/flash-attention)、[xFormers](https://github.com/facebookresearch/xformers)和原生C++实现。它会根据您的硬件自动选择最优的后端。\n\n如果您使用的是PyTorch >= 2.0，SDPA默认启用，无需对代码进行任何额外更改。不过，您也可以尝试使用其他注意力后端来自行选择。下面的示例使用[torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html)上下文管理器来启用高效注意力。\n\n```py\nfrom torch.nn.attention import SDPBackend, sdpa_kernel\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\n\nwith sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):\n  image = pipeline(prompt, num_inference_steps=30).images[0]\n```\n\n## 
torch.compile\n\n[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)通过将PyTorch代码和操作编译为优化的内核来加速推理。Diffusers通常会编译计算密集型的模型，如UNet、transformer或VAE。\n\n启用以下编译器设置以获得最大速度（更多选项请参阅[完整列表](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py)）。\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\ntorch._inductor.config.conv_1x1_as_mm = True\ntorch._inductor.config.coordinate_descent_tuning = True\ntorch._inductor.config.epilogue_fusion = False\ntorch._inductor.config.coordinate_descent_check_all_directions = True\n```\n\n加载并编译UNet和VAE。有几种不同的模式可供选择，但`\"max-autotune\"`通过编译为CUDA图来优化速度。CUDA图通过单个CPU操作启动多个GPU操作，有效减少了开销。\n\n> [!TIP]\n> 在PyTorch 2.3.1中，您可以控制torch.compile的缓存行为。这对于像`\"max-autotune\"`这样的编译模式特别有用，它会通过网格搜索多个编译标志来找到最优配置。更多详情请参阅[torch.compile中的编译时间缓存](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)教程。\n\n将内存布局更改为[channels_last](./memory#torchchannels_last)也可以优化内存和推理速度。\n\n```py\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.unet.to(memory_format=torch.channels_last)\npipeline.vae.to(memory_format=torch.channels_last)\npipeline.unet = torch.compile(\n    pipeline.unet, mode=\"max-autotune\", fullgraph=True\n)\npipeline.vae.decode = torch.compile(\n    pipeline.vae.decode,\n    mode=\"max-autotune\",\n    fullgraph=True\n)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\n第一次编译时速度较慢，但一旦编译完成，速度会显著提升。尽量只在相同类型的推理操作上使用编译后的管道。在不同尺寸的图像上调用编译后的管道会重新触发编译，这会很慢且效率低下。\n\n### 动态形状编译\n\n> [!TIP]\n> 确保始终使用PyTorch的nightly版本以获得更好的支持。\n\n`torch.compile`会跟踪输入形状和条件，如果这些不同，它会重新编译模型。例如，如果模型是在1024x1024分辨率的图像上编译的，而在不同分辨率的图像上使用，就会触发重新编译。\n\n为避免重新编译，添加`dynamic=True`以尝试生成更动态的内核，避免条件变化时重新编译。\n\n```diff\n+ torch.fx.experimental._config.use_duck_shape = False\n+ pipeline.unet = torch.compile(\n    
pipeline.unet, fullgraph=True, dynamic=True\n)\n```\n\n指定`use_duck_shape=False`会指示编译器不要使用相同的符号变量来表示恰好大小相同的输入。更多详情请参阅此[评论](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790)。\n\n并非所有模型都能开箱即用地从动态编译中受益，可能需要更改。参考此[PR](https://github.com/huggingface/diffusers/pull/11297/)，它改进了[`AuraFlowPipeline`]的实现以受益于动态编译。\n\n如果动态编译对Diffusers模型的效果不如预期，请随时提出问题。\n\n### 区域编译\n\n[区域编译](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)通过仅编译模型中*小而频繁重复的块*（通常是transformer层）来减少冷启动延迟，并为每个后续出现的块重用编译后的工件。对于许多diffusion架构，这提供了与全图编译相同的运行时加速，并将编译时间减少了8-10倍。\n\n在任何组件（如transformer模型）上使用[`~ModelMixin.compile_repeated_blocks`]方法（一个包装`torch.compile`的辅助函数），如下所示。\n\n```py\n# pip install -U diffusers\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\n# 仅编译UNet中重复的transformer层\npipeline.unet.compile_repeated_blocks(fullgraph=True)\n```\n\n要为新模型启用区域编译，请在模型类中添加一个`_repeated_blocks`属性，包含您想要编译的块的类名（作为字符串）。\n\n```py\nclass MyUNet(ModelMixin):\n    _repeated_blocks = (\"Transformer2DModel\",)  # ← 默认编译\n```\n\n> [!TIP]\n> 更多区域编译示例，请参阅参考[PR](https://github.com/huggingface/diffusers/pull/11705)。\n\n[Accelerate](https://huggingface.co/docs/accelerate/index)中还有一个[compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78)方法，可以自动选择模型中的候选块进行编译。其余图会单独编译。这对于快速实验很有用，因为您不需要设置哪些块要编译或调整编译标志。\n\n```py\n# pip install -U accelerate\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom accelerate.utils import compile_regions\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.unet = compile_regions(pipeline.unet, mode=\"reduce-overhead\", 
fullgraph=True)\n```\n\n[`~ModelMixin.compile_repeated_blocks`]是故意显式的。在`_repeated_blocks`中列出要重复的块，辅助函数仅编译这些块。它提供了可预测的行为，并且只需一行代码即可轻松推理缓存重用。\n\n### 图中断\n\n在torch.compile中指定`fullgraph=True`非常重要，以确保底层模型中没有图中断。这使您可以充分利用torch.compile而不会降低性能。对于UNet和VAE，这会改变您访问返回变量的方式。\n\n```diff\n- latents = unet(\n-   latents, timestep=timestep, encoder_hidden_states=prompt_embeds\n-).sample\n\n+ latents = unet(\n+   latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False\n+)[0]\n```\n\n### GPU同步\n\n每次去噪器做出预测后，调度器的`step()`函数会被[调用](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228)，并且`sigmas`变量会被[索引](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476)。当放在GPU上时，这会引入延迟，因为CPU和GPU之间需要进行通信同步。当去噪器已经编译时，这一点会更加明显。\n\n一般来说，`sigmas`应该[保持在CPU上](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240)，以避免通信同步和延迟。\n\n> [!TIP]\n> 参阅[torch.compile和Diffusers：峰值性能实践指南](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)博客文章，了解如何为扩散模型最大化`torch.compile`的性能。\n\n### 基准测试\n\n参阅[diffusers/benchmarks](https://huggingface.co/datasets/diffusers/benchmarks)数据集，查看编译管道的推理延迟和内存使用数据。\n\n[diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results)仓库还包含Flux和CogVideoX编译版本的基准测试结果。\n\n## 动态量化\n\n[动态量化](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html)通过降低精度以加快数学运算来提高推理速度。这种特定类型的量化在运行时根据数据确定如何缩放激活，而不是使用固定的缩放因子。因此，缩放因子与数据更准确地匹配。\n\n以下示例使用[torchao](../quantization/torchao)库对UNet和VAE应用[动态int8量化](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html)。\n\n> [!TIP]\n> 参阅我们的[torchao](../quantization/torchao)文档，了解更多关于如何使用Diffusers torchao集成的信息。\n\n配置编译器标志以获得最大速度。\n\n```py\nimport torch\nfrom torchao 
import apply_dynamic_quant\nfrom diffusers import StableDiffusionXLPipeline\n\ntorch._inductor.config.conv_1x1_as_mm = True\ntorch._inductor.config.coordinate_descent_tuning = True\ntorch._inductor.config.epilogue_fusion = False\ntorch._inductor.config.coordinate_descent_check_all_directions = True\ntorch._inductor.config.force_fuse_int_mm_with_mul = True\ntorch._inductor.config.use_mixed_mm = True\n```\n\n使用[dynamic_quant_filter_fn](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16)过滤掉UNet和VAE中一些不会从动态量化中受益的线性层。\n\n```py\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\napply_dynamic_quant(pipeline.unet, dynamic_quant_filter_fn)\napply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)\n\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, num_inference_steps=30).images[0]\n```\n\n## 融合投影矩阵\n\n> [!WARNING]\n> [fuse_qkv_projections](https://github.com/huggingface/diffusers/blob/58431f102cf39c3c8a569f32d71b2ea8caa461e1/src/diffusers/pipelines/pipeline_utils.py#L2034)方法是实验性的，目前主要支持Stable Diffusion管道。参阅此[PR](https://github.com/huggingface/diffusers/pull/6179)了解如何为其他管道启用它。\n\n在注意力块中，输入被投影到三个子空间，分别由投影矩阵Q、K和V表示。这些投影通常单独计算，但您可以水平组合这些矩阵为一个矩阵，并在单步中执行投影。这会增加输入投影的矩阵乘法大小，并提高量化的效果。\n\n```py\npipeline.fuse_qkv_projections()\n```\n\n## 资源\n\n- 阅读[Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/)博客文章，了解如何结合所有这些优化与[TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html)和[AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html)，使用[flux-fast](https://github.com/huggingface/flux-fast)的配方获得约2.5倍的加速。\n\n    这些配方支持AMD硬件和[Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)。\n- 
阅读[torch.compile和Diffusers：峰值性能实践指南](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)博客文章，了解如何在使用`torch.compile`时最大化性能。\n"
  },
  {
    "path": "docs/source/zh/optimization/habana.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（\"许可证\"）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。有关许可证管理权限和限制的具体语言，请参阅许可证。\n-->\n\n# Intel Gaudi\n\nIntel Gaudi AI 加速器系列包括 [Intel Gaudi 1](https://habana.ai/products/gaudi/)、[Intel Gaudi 2](https://habana.ai/products/gaudi2/) 和 [Intel Gaudi 3](https://habana.ai/products/gaudi3/)。每台服务器配备 8 个设备，称为 Habana 处理单元 (HPU)，在 Gaudi 3 上提供 128GB 内存，在 Gaudi 2 上提供 96GB 内存，在第一代 Gaudi 上提供 32GB 内存。有关底层硬件架构的更多详细信息，请查看 [Gaudi 架构](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html) 概述。\n\nDiffusers 管道可以利用 HPU 加速，即使管道尚未添加到 [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index)，也可以通过 [GPU 迁移工具包](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Model_Porting/GPU_Migration_Toolkit/GPU_Migration_Toolkit.html) 实现。\n\n在您的管道上调用 `.to(\"hpu\")` 以将其移动到 HPU 设备，如下所示为 Flux 示例：\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16)\npipeline.to(\"hpu\")\n\nimage = pipeline(\"一张松鼠在毕加索风格中的图像\").images[0]\n```\n\n> [!TIP]\n> 对于 Gaudi 优化的扩散管道实现，我们推荐使用 [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index)。"
  },
  {
    "path": "docs/source/zh/optimization/memory.md",
    "content": "<!--版权所有 2025 HuggingFace 团队。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（“许可证”）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按“原样”分发，无任何明示或暗示的担保或条件。有关许可证的特定语言管理权限和限制，请参阅许可证。\n-->\n\n# 减少内存使用\n\n现代diffusion models，如 [Flux](../api/pipelines/flux) 和 [Wan](../api/pipelines/wan)，拥有数十亿参数，在您的硬件上进行推理时会占用大量内存。这是一个挑战，因为常见的 GPU 通常没有足够的内存。为了克服内存限制，您可以使用多个 GPU（如果可用）、将一些管道组件卸载到 CPU 等。\n\n本指南将展示如何减少内存使用。\n\n> [!TIP]\n> 请记住，这些技术可能需要根据模型进行调整。例如，基于 transformer 的扩散模型可能不会像基于 UNet 的模型那样从这些内存优化中同等受益。\n\n## 多个 GPU\n\n如果您有多个 GPU 的访问权限，有几种选项可以高效地在硬件上加载和分发大型模型。这些功能由 [Accelerate](https://huggingface.co/docs/accelerate/index) 库支持，因此请确保先安装它。\n\n```bash\npip install -U accelerate\n```\n\n### 分片检查点\n\n将大型检查点加载到多个分片中很有用，因为分片会逐个加载。这保持了低内存使用，只需要足够的内存来容纳模型大小和最大分片大小。我们建议当 fp32 检查点大于 5GB 时进行分片。默认分片大小为 5GB。\n\n在 [`~DiffusionPipeline.save_pretrained`] 中使用 `max_shard_size` 参数对检查点进行分片。\n\n```py\nfrom diffusers import AutoModel\n\nunet = AutoModel.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"unet\"\n)\nunet.save_pretrained(\"sdxl-unet-sharded\", max_shard_size=\"5GB\")\n```\n\n现在您可以使用分片检查点，而不是常规检查点，以节省内存。\n\n```py\nimport torch\nfrom diffusers import AutoModel, StableDiffusionXLPipeline\n\nunet = AutoModel.from_pretrained(\n    \"username/sdxl-unet-sharded\", torch_dtype=torch.float16\n)\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    unet=unet,\n    torch_dtype=torch.float16\n).to(\"cuda\")\n```\n\n### 设备放置\n\n> [!WARNING]\n> 设备放置是一个实验性功能，API 可能会更改。目前仅支持 `balanced` 策略。我们计划在未来支持额外的映射策略。\n\n`device_map` 参数控制管道或模型中的组件如何\n单个模型中的层分布在多个设备上。\n\n<hfoptions id=\"device-map\">\n<hfoption id=\"pipeline level\">\n\n`balanced` 设备放置策略将管道均匀分割到所有可用设备上。\n\n```py\nimport torch\nfrom diffusers import AutoModel, StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    
torch_dtype=torch.float16,\n    device_map=\"balanced\"\n)\n```\n\n您可以使用 `hf_device_map` 检查管道的设备映射。\n\n```py\nprint(pipeline.hf_device_map)\n{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}\n```\n\n</hfoption>\n<hfoption id=\"model level\">\n\n`device_map` 对于加载大型模型非常有用，例如具有 125 亿参数的 Flux diffusion transformer。将其设置为 `\"auto\"` 可以自动将模型首先分布到最快的设备上，然后再移动到较慢的设备。有关更多详细信息，请参阅 [模型分片](../training/distributed_inference#model-sharding) 文档。\n\n```py\nimport torch\nfrom diffusers import AutoModel\n\ntransformer = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", \n    subfolder=\"transformer\",\n    device_map=\"auto\",\n    torch_dtype=torch.bfloat16\n)\n```\n\n您可以使用 `hf_device_map` 检查模型的设备映射。\n\n```py\nprint(transformer.hf_device_map)\n```\n\n</hfoption>\n</hfoptions>\n\n当设计您自己的 `device_map` 时，它应该是一个字典，包含模型的特定模块名称或层以及设备标识符（整数表示 GPU，`cpu` 表示 CPU，`disk` 表示磁盘）。\n\n在模型上调用 `hf_device_map` 以查看模型层如何分布，然后设计您自己的映射。\n\n```py\nprint(transformer.hf_device_map)\n{'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 'cpu', 'single_transformer_blocks.11': 'cpu', 'single_transformer_blocks.12': 'cpu', 'single_transformer_blocks.13': 'cpu', 'single_transformer_blocks.14': 'cpu', 'single_transformer_blocks.15': 'cpu', 'single_transformer_blocks.16': 'cpu', 'single_transformer_blocks.17': 'cpu', 'single_transformer_blocks.18': 'cpu', 'single_transformer_blocks.19': 'cpu', 'single_transformer_blocks.20': 'cpu', 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 
'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'}\n```\n\n例如，下面的 `device_map` 将 `single_transformer_blocks.10` 到 `single_transformer_blocks.20` 放置在第二个 GPU（`1`）上。\n\n```py\nimport torch\nfrom diffusers import AutoModel\n\ndevice_map = {\n    'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 1, 'single_transformer_blocks.11': 1, 'single_transformer_blocks.12': 1, 'single_transformer_blocks.13': 1, 'single_transformer_blocks.14': 1, 'single_transformer_blocks.15': 1, 'single_transformer_blocks.16': 1, 'single_transformer_blocks.17': 1, 'single_transformer_blocks.18': 1, 'single_transformer_blocks.19': 1, 'single_transformer_blocks.20': 1, 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 
'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'\n}\n\ntransformer = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", \n    subfolder=\"transformer\",\n    device_map=device_map,\n    torch_dtype=torch.bfloat16\n)\n```\n\n传递一个字典，将最大内存使用量映射到每个设备以强制执行限制。如果设备不在 `max_memory` 中，它将被忽略，管道组件不会分发到该设备。\n\n```py\nimport torch\nfrom diffusers import AutoModel, StableDiffusionXLPipeline\n\nmax_memory = {0:\"1GB\", 1:\"1GB\"}\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    device_map=\"balanced\",\n    max_memory=max_memory\n)\n```\n\nDiffusers 默认使用所有设备的最大内存，但如果它们无法适应 GPU，则需要使用单个 GPU 并通过以下方法卸载到 CPU。\n\n- [`~DiffusionPipeline.enable_model_cpu_offload`] 仅适用于单个 GPU，但非常大的模型可能无法适应它\n- 使用 [`~DiffusionPipeline.enable_sequential_cpu_offload`] 可能有效，但它极其缓慢，并且仅限于单个 GPU。\n\n使用 [`~DiffusionPipeline.reset_device_map`] 方法来重置 `device_map`。如果您想在已进行设备映射的管道上使用方法如 `.to()`、[`~DiffusionPipeline.enable_sequential_cpu_offload`] 和 [`~DiffusionPipeline.enable_model_cpu_offload`]，这是必要的。\n\n```py\npipeline.reset_device_map()\n```\n\n## VAE 切片\n\nVAE 切片通过将大批次输入拆分为单个数据批次并分别处理它们来节省内存。这种方法在同时生成多个图像时效果最佳。\n\n例如，如果您同时生成 4 个图像，解码会将峰值激活内存增加 4 倍。VAE 切片通过一次只解码 1 个图像而不是所有 4 个图像来减少这种情况。\n\n调用 [`~StableDiffusionPipeline.enable_vae_slicing`] 来启用切片 VAE。您可以预期在解码多图像批次时性能会有小幅提升，而在单图像批次时没有性能影响。\n\n```py\nimport torch\nfrom diffusers import AutoModel, StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline.enable_vae_slicing()\npipeline([\"An astronaut riding a horse on Mars\"]*32).images[0]\nprint(f\"Max memory reserved: 
{torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n> [!WARNING]\n> [`AutoencoderKLWan`] 和 [`AsymmetricAutoencoderKL`] 类不支持切片。\n\n## VAE 平铺\n\nVAE 平铺通过将图像划分为较小的重叠图块而不是一次性处理整个图像来节省内存。这也减少了峰值内存使用量，因为 GPU 一次只处理一个图块。\n\n调用 [`~StableDiffusionPipeline.enable_vae_tiling`] 来启用 VAE 平铺。生成的图像可能因图块到图块的色调变化而有所不同，因为它们被单独解码，但图块之间不应有明显的接缝。对于低于预设（但可配置）限制的分辨率，平铺被禁用。例如，对于 [`StableDiffusionPipeline`] 中的 VAE，此限制为 512x512。\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForImage2Image\nfrom diffusers.utils import load_image\n\npipeline = AutoPipelineForImage2Image.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n).to(\"cuda\")\npipeline.enable_vae_tiling()\n\ninit_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png\")\nprompt = \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\"\npipeline(prompt, image=init_image, strength=0.5).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n> [!WARNING]\n> [`AutoencoderKLWan`] 和 [`AsymmetricAutoencoderKL`] 不支持平铺。\n\n## 卸载\n\n卸载策略将当前未使用的层或模型移动到 CPU，以避免增加 GPU 内存占用。这些策略可以与量化和 torch.compile 结合使用，以平衡推理速度和内存使用。\n\n有关更多详细信息，请参考 [编译和卸载量化模型](./speed-memory-optims) 指南。\n\n### CPU 卸载\n\nCPU 卸载选择性地将权重从 GPU 移动到 CPU。当需要某个组件时，它被传输到 GPU；当不需要时，它被移动到 CPU。此方法作用于子模块而非整个模型。它通过避免将整个模型存储在 GPU 上来节省内存。\n\nCPU 卸载显著减少内存使用，但由于子模块在设备之间多次来回传递，它也非常慢。由于速度极慢，它通常不实用。\n\n> [!WARNING]\n> 在调用 [`~DiffusionPipeline.enable_sequential_cpu_offload`] 之前，不要将管道移动到 CUDA，否则节省的内存非常有限（更多细节请参考此 [issue](https://github.com/huggingface/diffusers/issues/1934)）。这是一个状态操作，会在模型上安装钩子。\n\n调用 [`~DiffusionPipeline.enable_sequential_cpu_offload`] 以在管道上启用它。\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-schnell\", 
torch_dtype=torch.bfloat16\n)\npipeline.enable_sequential_cpu_offload()\n\npipeline(\n    prompt=\"An astronaut riding a horse on Mars\",\n    guidance_scale=0.,\n    height=768,\n    width=1360,\n    num_inference_steps=4,\n    max_sequence_length=256,\n).images[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n### 模型卸载\n\n模型卸载将整个模型移动到 GPU，而不是选择性地移动某些层或模型组件。一个主要管道模型，通常是文本编码器、UNet 和 VAE，被放置在 GPU 上，而其他组件保持在 CPU 上。像 UNet 这样运行多次的组件会一直留在 GPU 上，直到完全完成且不再需要。这消除了 [CPU 卸载](#cpu-offloading) 的通信开销，使模型卸载成为一个更快的替代方案。权衡是内存节省不会那么大。\n\n> [!WARNING]\n> 请注意，如果在安装钩子后模型在管道外部被重用（更多细节请参考 [移除钩子](https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module)），您需要按预期顺序运行整个管道和模型以正确卸载它们。这是一个状态操作，会在模型上安装钩子。\n\n调用 [`~DiffusionPipeline.enable_model_cpu_offload`] 以在管道上启用它。\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16\n)\npipeline.enable_model_cpu_offload()\n\npipeline(\n    prompt=\"An astronaut riding a horse on Mars\",\n    guidance_scale=0.,\n    height=768,\n    width=1360,\n    num_inference_steps=4,\n    max_sequence_length=256,\n).images[0]\nprint(f\"最大内存保留: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\n```\n\n[`~DiffusionPipeline.enable_model_cpu_offload`] 在您单独使用 [`~StableDiffusionXLPipeline.encode_prompt`] 方法生成文本编码器隐藏状态时也有帮助。\n\n### 组卸载\n\n组卸载将内部层组（[torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) 或 [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)）移动到 CPU。它比[模型卸载](#model-offloading)使用更少的内存，并且比[CPU 卸载](#cpu-offloading)更快，因为它减少了通信开销。\n\n> [!WARNING]\n> 如果前向实现包含权重相关的输入设备转换，组卸载可能不适用于所有模型，因为它可能与组卸载的设备转换机制冲突。\n\n调用 [`~ModelMixin.enable_group_offload`] 为继承自 [`ModelMixin`] 的标准 Diffusers 模型组件启用它。对于不继承自 [`ModelMixin`] 的其他模型组件，例如通用 
[torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)，使用 [`~hooks.apply_group_offloading`] 代替。\n\n`offload_type` 参数可以设置为 `block_level` 或 `leaf_level`。\n\n- `block_level` 基于 `num_blocks_per_group` 参数卸载层组。例如，如果 `num_blocks_per_group=2` 在一个有 40 层的模型上，每次加载和卸载 2 层（总共 20 次加载/卸载）。这大大减少了内存需求。\n- `leaf_level` 在最低级别卸载单个层，等同于[CPU 卸载](#cpu-offloading)。但如果您使用流而不放弃推理速度，它可以更快。\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline\nfrom diffusers.hooks import apply_group_offloading\nfrom diffusers.utils import export_to_video\n\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.bfloat16)\n\n# 对 Diffusers 模型实现使用 enable_group_offload 方法\npipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=\"leaf_level\")\npipeline.vae.enable_group_offload(onload_device=onload_device, offload_type=\"leaf_level\")\n\n# 对其他模型组件使用 apply_group_offloading 方法\napply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=\"block_level\", num_blocks_per_group=2)\n\nprompt = (\n\"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. \"\n    \"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other \"\n    \"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, \"\n    \"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. 
\"\n    \"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical \"\n    \"atmosphere of this unique musical performance.\"\n)\nvideo = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\n#### CUDA 流\n`use_stream` 参数可以激活支持异步数据传输流的 CUDA 设备，以减少整体执行时间，与 [CPU 卸载](#cpu-offloading) 相比。它通过使用层预取重叠数据传输和计算。下一个要执行的层在当前层仍在执行时加载到 GPU 上。这会显著增加 CPU 内存，因此请确保您有模型大小的 2 倍内存。\n\n设置 `record_stream=True` 以获得更多速度提升，代价是内存使用量略有增加。请参阅 [torch.Tensor.record_stream](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) 文档了解更多信息。\n\n> [!TIP]\n> 当 `use_stream=True` 在启用平铺的 VAEs 上时，确保在推理前进行虚拟前向传递（可以使用虚拟输入），以避免设备不匹配错误。这可能不适用于所有实现，因此如果遇到任何问题，请随时提出问题。\n\n如果您在使用启用 `use_stream` 的 `block_level` 组卸载，`num_blocks_per_group` 参数应设置为 `1`，否则会引发警告。\n\n```py\npipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=\"leaf_level\", use_stream=True, record_stream=True)\n```\n\n`low_cpu_mem_usage` 参数可以设置为 `True`，以在使用流进行组卸载时减少 CPU 内存使用。它最适合 `leaf_level` 卸载和 CPU 内存瓶颈的情况。通过动态创建固定张量而不是预先固定它们来节省内存。然而，这可能会增加整体执行时间。\n\n#### 卸载到磁盘\n组卸载可能会消耗大量系统内存，具体取决于模型大小。在内存有限的系统上，尝试将组卸载到磁盘作为辅助内存。\n\n在 [`~ModelMixin.enable_group_offload`] 或 [`~hooks.apply_group_offloading`] 中设置 `offload_to_disk_path` 参数，将模型卸载到磁盘。\n\n```py\npipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type=\"leaf_level\", offload_to_disk_path=\"path/to/disk\")\n\napply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type=\"block_level\", num_blocks_per_group=2, offload_to_disk_path=\"path/to/disk\")\n```\n\n参考这些[两个](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363)[表格](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126)来比较速度和内存的权衡。\n\n## 
分层类型转换\n\n> [!TIP]\n> 将分层类型转换与[组卸载](#group-offloading)结合使用，以获得更多内存节省。\n\n分层类型转换将权重存储在较小的数据格式中（例如 `torch.float8_e4m3fn` 和 `torch.float8_e5m2`），以使用更少的内存，并在计算时将那些权重上转换为更高精度如 `torch.float16` 或 `torch.bfloat16`。某些层（归一化和调制相关权重）被跳过，因为将它们存储在 fp8 中可能会降低生成质量。\n\n> [!WARNING]\n> 如果前向实现包含权重的内部类型转换，分层类型转换可能不适用于所有模型。当前的分层类型转换实现假设前向传递独立于权重精度，并且输入数据类型始终在 `compute_dtype` 中指定（请参见[这里](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299)以获取不兼容的实现）。\n>\n> 分层类型转换也可能在使用[PEFT](https://huggingface.co/docs/peft/index)层的自定义建模实现上失败。有一些检查可用，但它们没有经过广泛测试或保证在所有情况下都能工作。\n\n调用 [`~ModelMixin.enable_layerwise_casting`] 来设置存储和计算数据类型。\n\n```py\nimport torch\nfrom diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel\nfrom diffusers.utils import export_to_video\n\ntransformer = CogVideoXTransformer3DModel.from_pretrained(\n    \"THUDM/CogVideoX-5b\",\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16\n)\ntransformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)\n\npipeline = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\",\n    transformer=transformer,\n    torch_dtype=torch.bfloat16\n).to(\"cuda\")\nprompt = (\n    \"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. \"\n    \"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other \"\n    \"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, \"\n    \"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. 
\"\n    \"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical \"\n    \"atmosphere of this unique musical performance.\"\n)\nvideo = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]\nprint(f\"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB\")\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\n[`~hooks.apply_layerwise_casting`] 方法也可以在您需要更多控制和灵活性时使用。它可以通过在特定内部模块上调用它来部分应用于模型层。使用 `skip_modules_pattern` 或 `skip_modules_classes` 参数来指定要避免的模块，例如归一化和调制层。\n\n```python\nimport torch\nfrom diffusers import CogVideoXTransformer3DModel\nfrom diffusers.hooks import apply_layerwise_casting\n\ntransformer = CogVideoXTransformer3DModel.from_pretrained(\n    \"THUDM/CogVideoX-5b\",\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16\n)\n\n# 跳过归一化层\napply_layerwise_casting(\n    transformer,\n    storage_dtype=torch.float8_e4m3fn,\n    compute_dtype=torch.bfloat16,\n    skip_modules_classes=[\"norm\"],\n    non_blocking=True,\n)\n```\n\n## torch.channels_last\n\n[torch.channels_last](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) 将张量的存储方式从 `(批次大小, 通道数, 高度, 宽度)` 翻转为 `(批次大小, 高度, 宽度, 通道数)`。这使张量与硬件如何顺序访问存储在内存中的张量对齐，并避免了在内存中跳转以访问像素值。\n\n并非所有运算符当前都支持通道最后格式，并且可能导致性能更差，但仍然值得尝试。\n\n```py\nprint(pipeline.unet.conv_out.state_dict()[\"weight\"].stride())  # (2880, 9, 3, 1)\npipeline.unet.to(memory_format=torch.channels_last)  # 原地操作\nprint(\n    pipeline.unet.conv_out.state_dict()[\"weight\"].stride()\n)  # (2880, 1, 960, 320) 第二个维度的跨度为1证明它有效\n```\n\n## torch.jit.trace\n\n[torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) 记录模型在样本输入上执行的操作，并根据记录的执行路径创建一个新的、优化的模型表示。在跟踪过程中，模型被优化以减少来自Python和动态控制流的开销，并且操作被融合在一起以提高效率。返回的可执行文件或 [ScriptFunction](https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html) 可以被编译。\n\n```py\nimport time\nimport torch\nfrom diffusers import StableDiffusionPipeline\nimport 
functools\n\n# torch 禁用梯度\ntorch.set_grad_enabled(False)\n\n# 设置变量\nn_experiments = 2\nunet_runs_per_experiment = 50\n\n# 加载样本输入\ndef generate_inputs():\n    sample = torch.randn((2, 4, 64, 64), device=\"cuda\", dtype=torch.float16)\n    timestep = torch.rand(1, device=\"cuda\", dtype=torch.float16) * 999\n    encoder_hidden_states = torch.randn((2, 77, 768), device=\"cuda\", dtype=torch.float16)\n    return sample, timestep, encoder_hidden_states\n\n\npipeline = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n).to(\"cuda\")\nunet = pipeline.unet\nunet.eval()\nunet.to(memory_format=torch.channels_last)  # 使用 channels_last 内存格式\nunet.forward = functools.partial(unet.forward, return_dict=False)  # 设置 return_dict=False 为默认\n\n# 预热\nfor _ in range(3):\n    with torch.inference_mode():\n        inputs = generate_inputs()\n        orig_output = unet(*inputs)\n\n# 追踪\nprint(\"tracing..\")\nunet_traced = torch.jit.trace(unet, inputs)\nunet_traced.eval()\nprint(\"done tracing\")\n\n# 预热和优化图\nfor _ in range(5):\n    with torch.inference_mode():\n        inputs = generate_inputs()\n        orig_output = unet_traced(*inputs)\n\n# 基准测试\nwith torch.inference_mode():\n    for _ in range(n_experiments):\n        torch.cuda.synchronize()\n        start_time = time.time()\n        for _ in range(unet_runs_per_experiment):\n            orig_output = unet_traced(*inputs)\n        torch.cuda.synchronize()\n        print(f\"unet traced inference took {time.time() - start_time:.2f} seconds\")\n    for _ in range(n_experiments):\n        torch.cuda.synchronize()\n        start_time = time.time()\n        for _ in range(unet_runs_per_experiment):\n            orig_output = unet(*inputs)\n        torch.cuda.synchronize()\n        print(f\"unet inference took {time.time() - start_time:.2f} seconds\")\n\n# 保存模型\nunet_traced.save(\"unet_traced.pt\")\n```\n\n替换管道的 UNet 
为追踪版本。\n\n```py\nimport torch\nfrom diffusers import StableDiffusionPipeline\nfrom dataclasses import dataclass\n\n@dataclass\nclass UNet2DConditionOutput:\n    sample: torch.Tensor\n\npipeline = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n).to(\"cuda\")\n\n# 使用 jitted unet\nunet_traced = torch.jit.load(\"unet_traced.pt\")\n\n# del pipeline.unet\nclass TracedUNet(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.in_channels = pipeline.unet.config.in_channels\n        self.device = pipeline.unet.device\n\n    def forward(self, latent_model_input, t, encoder_hidden_states):\n        sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]\n        return UNet2DConditionOutput(sample=sample)\n\npipeline.unet = TracedUNet()\n\n# 示例提示词\nprompt = \"a photo of an astronaut riding a horse on mars\"\nwith torch.inference_mode():\n    image = pipeline([prompt] * 1, num_inference_steps=50).images[0]\n```\n\n## 内存高效注意力\n\n> [!TIP]\n> 内存高效注意力优化内存使用 *和* [推理速度](./fp16#scaled-dot-product-attention)！\n\nTransformers 注意力机制是内存密集型的，尤其对于长序列，因此您可以尝试使用不同且更内存高效的注意力类型。\n\n默认情况下，如果安装了 PyTorch >= 2.0，则使用 [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)。您无需对代码进行任何额外更改。\n\nSDPA 还支持 [FlashAttention](https://github.com/Dao-AILab/flash-attention) 和 [xFormers](https://github.com/facebookresearch/xformers)，以及原生的 C++ PyTorch 实现。它会根据您的输入自动选择最优的实现。\n\n您可以使用 [`~ModelMixin.enable_xformers_memory_efficient_attention`] 方法显式地使用 xFormers。\n\n```py\n# pip install xformers\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\npipeline.enable_xformers_memory_efficient_attention()\n```\n\n调用 [`~ModelMixin.disable_xformers_memory_efficient_attention`] 
来禁用它。\n\n```py\npipeline.disable_xformers_memory_efficient_attention()\n```"
  },
  {
    "path": "docs/source/zh/optimization/mps.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（\"许可证\"）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。请参阅许可证了解具体的语言管理权限和限制。\n-->\n\n# Metal Performance Shaders (MPS)\n\n> [!TIP]\n> 带有 <img alt=\"MPS\" src=\"https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22\"> 徽章的管道表示模型可以利用 Apple silicon 设备上的 MPS 后端进行更快的推理。欢迎提交 [Pull Request](https://github.com/huggingface/diffusers/compare) 来为缺少此徽章的管道添加它。\n\n🤗 Diffusers 与 Apple silicon（M1/M2 芯片）兼容，使用 PyTorch 的 [`mps`](https://pytorch.org/docs/stable/notes/mps.html) 设备，该设备利用 Metal 框架来发挥 MacOS 设备上 GPU 的性能。您需要具备：\n\n- 配备 Apple silicon（M1/M2）硬件的 macOS 计算机\n- macOS 12.6 或更高版本（推荐 13.0 或更高）\n- arm64 版本的 Python\n- [PyTorch 2.0](https://pytorch.org/get-started/locally/)（推荐）或 1.13（支持 `mps` 的最低版本）\n\n`mps` 后端使用 PyTorch 的 `.to()` 接口将 Stable Diffusion 管道移动到您的 M1 或 M2 设备上：\n\n```python\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\npipe = pipe.to(\"mps\")\n\n# 如果您的计算机内存小于 64 GB，推荐使用\npipe.enable_attention_slicing()\n\nprompt = \"a photo of an astronaut riding a horse on mars\"\nimage = pipe(prompt).images[0]\nimage\n```\n\n> [!WARNING]\n> PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) 后端不支持大小超过 `2**32` 的 NDArray。如果您遇到此问题，请提交 [Issue](https://github.com/huggingface/diffusers/issues/new/choose) 以便我们调查。\n\n如果您使用 **PyTorch 1.13**，您需要通过管道进行一次额外的\"预热\"传递。这是一个临时解决方法，用于解决首次推理传递产生的结果与后续传递略有不同的问题。您只需要执行此传递一次，并且在仅进行一次推理步骤后可以丢弃结果。\n\n```diff\n  from diffusers import DiffusionPipeline\n\n  pipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\").to(\"mps\")\n  pipe.enable_attention_slicing()\n\n  prompt = \"a photo of an astronaut riding a horse on mars\"\n  # 如果 PyTorch 版本是 1.13，进行首次\"预热\"传递\n+ _ = pipe(prompt, num_inference_steps=1)\n\n  # 预热传递后，结果与 CPU 设备上的结果匹配。\n  
image = pipe(prompt).images[0]\n```\n\n## 故障排除\n\n本节列出了使用 `mps` 后端时的一些常见问题及其解决方法。\n\n### 注意力切片\n\nM1/M2 性能对内存压力非常敏感。当发生这种情况时，系统会自动交换内存，这会显著降低性能。\n\n为了防止这种情况发生，我们建议使用*注意力切片*来减少推理过程中的内存压力并防止交换。这在您的计算机系统内存少于 64GB 或生成非标准分辨率（大于 512×512 像素）的图像时尤其相关。在您的管道上调用 [`~DiffusionPipeline.enable_attention_slicing`] 函数：\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True).to(\"mps\")\npipeline.enable_attention_slicing()\n```\n\n注意力切片将昂贵的注意力操作分多个步骤执行，而不是一次性完成。在没有统一内存的计算机中，它通常能提高约 20% 的性能，但我们观察到在大多数 Apple 芯片计算机中，除非您有 64GB 或更多 RAM，否则性能会*更好*。\n\n### 批量推理\n\n批量生成多个提示可能会导致崩溃或无法可靠工作。如果是这种情况，请尝试迭代而不是批量处理。"
  },
  {
    "path": "docs/source/zh/optimization/neuron.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版（“许可证”）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按“原样”分发，无任何明示或暗示的担保或条件。请参阅许可证了解特定语言管理权限和限制。\n-->\n\n# AWS Neuron\n\nDiffusers 功能可在 [AWS Inf2 实例](https://aws.amazon.com/ec2/instance-types/inf2/)上使用，这些是由 [Neuron 机器学习加速器](https://aws.amazon.com/machine-learning/inferentia/)驱动的 EC2 实例。这些实例旨在提供更好的计算性能（更高的吞吐量、更低的延迟）和良好的成本效益，使其成为 AWS 用户将扩散模型部署到生产环境的良好选择。\n\n[Optimum Neuron](https://huggingface.co/docs/optimum-neuron/en/index) 是 Hugging Face 库与 AWS 加速器之间的接口，包括 AWS [Trainium](https://aws.amazon.com/machine-learning/trainium/) 和 AWS [Inferentia](https://aws.amazon.com/machine-learning/inferentia/)。它支持 Diffusers 中的许多功能，并具有类似的 API，因此如果您已经熟悉 Diffusers，学习起来更容易。一旦您创建了 AWS Inf2 实例，请安装 Optimum Neuron。\n\n```bash\npython -m pip install --upgrade-strategy eager optimum[neuronx]\n```\n\n> [!TIP]\n> 我们提供预构建的 [Hugging Face Neuron 深度学习 AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2)（DLAMI）和用于 Amazon SageMaker 的 Optimum Neuron 容器。建议正确设置您的环境。\n\n下面的示例演示了如何在 inf2.8xlarge 实例上使用 Stable Diffusion XL 模型生成图像（一旦模型编译完成，您可以切换到更便宜的 inf2.xlarge 实例）。要生成一些图像，请使用 [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] 类，该类类似于 Diffusers 中的 [`StableDiffusionXLPipeline`] 类。\n\n与 Diffusers 不同，您需要将管道中的模型编译为 Neuron 格式，即 `.neuron`。运行以下命令将模型导出为 `.neuron` 格式。\n\n```bash\noptimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \\\n  --batch_size 1 \\\n  --height 1024 `# 生成图像的高度（像素），例如 768, 1024` \\\n  --width 1024 `# 生成图像的宽度（像素），例如 768, 1024` \\\n  --num_images_per_prompt 1 `# 每个提示生成的图像数量，默认为 1` \\\n  --auto_cast matmul `# 仅转换矩阵乘法操作` \\\n  --auto_cast_type bf16 `# 将操作从 FP32 转换为 BF16` \\\n  sd_neuron_xl/\n```\n\n现在使用预编译的 SDXL 模型生成一些图像。\n\n```python\n>>> from optimum.neuron import Neu\nronStableDiffusionXLPipeline\n\n>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained(\"sd_neuron_xl/\")\n>>> prompt = \"a pig 
with wings flying in floating US dollar banknotes in the air, skyscrapers behind, warm color palette, muted colors, detailed, 8k\"\n>>> image = stable_diffusion_xl(prompt).images[0]\n```\n\n<img\n  src=\"https://huggingface.co/datasets/Jingya/document_images/resolve/main/optimum/neuron/sdxl_pig.png\"\n  width=\"256\"\n  height=\"256\"\n  alt=\"peggy generated by sdxl on inf2\"\n/>\n\n欢迎查看Optimum Neuron [文档](https://huggingface.co/docs/optimum-neuron/en/inference_tutorials/stable_diffusion#generate-images-with-stable-diffusion-models-on-aws-inferentia)中更多不同用例的指南和示例！"
  },
  {
    "path": "docs/source/zh/optimization/onnx.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\n根据 Apache License 2.0 许可证（以下简称\"许可证\"）授权，除非符合许可证要求，否则不得使用本文件。您可以通过以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或以书面形式同意，本软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。详见许可证中规定的特定语言权限和限制。\n-->\n\n# ONNX Runtime\n\n🤗 [Optimum](https://github.com/huggingface/optimum) 提供了兼容 ONNX Runtime 的 Stable Diffusion 流水线。您需要运行以下命令安装支持 ONNX Runtime 的 🤗 Optimum：\n\n```bash\npip install -q optimum[\"onnxruntime\"]\n```\n\n本指南将展示如何使用 ONNX Runtime 运行 Stable Diffusion 和 Stable Diffusion XL (SDXL) 流水线。\n\n## Stable Diffusion\n\n要加载并运行推理，请使用 [`~optimum.onnxruntime.ORTStableDiffusionPipeline`]。若需加载 PyTorch 模型并实时转换为 ONNX 格式，请设置 `export=True`：\n\n```python\nfrom optimum.onnxruntime import ORTStableDiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipeline = ORTStableDiffusionPipeline.from_pretrained(model_id, export=True)\nprompt = \"sailing ship in storm by Leonardo da Vinci\"\nimage = pipeline(prompt).images[0]\npipeline.save_pretrained(\"./onnx-stable-diffusion-v1-5\")\n```\n\n> [!WARNING]\n> 当前批量生成多个提示可能会占用过高内存。在问题修复前，建议采用迭代方式而非批量处理。\n\n如需离线导出 ONNX 格式流水线供后续推理使用，请使用 [`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) 命令：\n\n```bash\noptimum-cli export onnx --model stable-diffusion-v1-5/stable-diffusion-v1-5 sd_v15_onnx/\n```\n\n随后进行推理时（无需再次指定 `export=True`）：\n\n```python\nfrom optimum.onnxruntime import ORTStableDiffusionPipeline\n\nmodel_id = \"sd_v15_onnx\"\npipeline = ORTStableDiffusionPipeline.from_pretrained(model_id)\nprompt = \"sailing ship in storm by Leonardo da Vinci\"\nimage = pipeline(prompt).images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/optimum/documentation-images/resolve/main/onnxruntime/stable_diffusion_v1_5_ort_sail_boat.png\">\n</div>\n\n您可以在 🤗 Optimum [文档](https://huggingface.co/docs/optimum/) 
中找到更多示例，Stable Diffusion 支持文生图、图生图和图像修复任务。\n\n## Stable Diffusion XL\n\n要加载并运行 SDXL 推理，请使用 [`~optimum.onnxruntime.ORTStableDiffusionXLPipeline`]：\n\n```python\nfrom optimum.onnxruntime import ORTStableDiffusionXLPipeline\n\nmodel_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\npipeline = ORTStableDiffusionXLPipeline.from_pretrained(model_id)\nprompt = \"sailing ship in storm by Leonardo da Vinci\"\nimage = pipeline(prompt).images[0]\n```\n\n如需导出 ONNX 格式流水线供后续推理使用，请运行：\n\n```bash\noptimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl sd_xl_onnx/\n```\n\nSDXL 的 ONNX 格式目前支持文生图和图生图任务。\n"
  },
  {
    "path": "docs/source/zh/optimization/open_vino.md",
    "content": "<!--版权所有 2025 HuggingFace 团队。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（\"许可证\"）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，无任何明示或暗示的担保或条件。请参阅许可证以了解具体的语言管理权限和限制。\n-->\n\n# OpenVINO\n\n🤗 [Optimum](https://github.com/huggingface/optimum-intel) 提供与 OpenVINO 兼容的 Stable Diffusion 管道，可在各种 Intel 处理器上执行推理（请参阅支持的设备[完整列表](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html)）。\n\n您需要安装 🤗 Optimum Intel，并使用 `--upgrade-strategy eager` 选项以确保 [`optimum-intel`](https://github.com/huggingface/optimum-intel) 使用最新版本：\n\n```bash\npip install --upgrade-strategy eager optimum[\"openvino\"]\n```\n\n本指南将展示如何使用 Stable Diffusion 和 Stable Diffusion XL (SDXL) 管道与 OpenVINO。\n\n## Stable Diffusion\n\n要加载并运行推理，请使用 [`~optimum.intel.OVStableDiffusionPipeline`]。如果您想加载 PyTorch 模型并即时转换为 OpenVINO 格式，请设置 `export=True`：\n\n```python\nfrom optimum.intel import OVStableDiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipeline = OVStableDiffusionPipeline.from_pretrained(model_id, export=True)\nprompt = \"sailing ship in storm by Rembrandt\"\nimage = pipeline(prompt).images[0]\n\n# 别忘了保存导出的模型\npipeline.save_pretrained(\"openvino-sd-v1-5\")\n```\n\n为了进一步加速推理，静态重塑模型。如果您更改任何参数，例如输出高度或宽度，您需要再次静态重塑模型。\n\n```python\n# 定义与输入和期望输出相关的形状\nbatch_size, num_images, height, width = 1, 1, 512, 512\n\n# 静态重塑模型\npipeline.reshape(batch_size, height, width, num_images)\n# 在推理前编译模型\npipeline.compile()\n\nimage = pipeline(\n    prompt,\n    height=height,\n    width=width,\n    num_images_per_prompt=num_images,\n).images[0]\n```\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/optimum/documentation-images/resolve/main/intel/openvino/stable_diffusion_v1_5_sail_boat_rembrandt.png\">\n</div>\n\n您可以在 🤗 Optimum [文档](https://huggingface.co/docs/optimum/intel/inference#stable-diffusion) 中找到更多示例，Stable Diffusion 支持文本到图像、图像到图像和修复。\n\n## Stable 
Diffusion XL\n\n要加载并运行 SDXL 推理，请使用 [`~optimum.intel.OVStableDiffusionXLPipeline`]：\n\n```python\nfrom optimum.intel import OVStableDiffusionXLPipeline\n\nmodel_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\npipeline = OVStableDiffusionXLPipeline.from_pretrained(model_id)\nprompt = \"sailing ship in storm by Rembrandt\"\nimage = pipeline(prompt).images[0]\n```\n\n为了进一步加速推理，可以如Stable Diffusion部分所示[静态重塑](#stable-diffusion)模型。\n\n您可以在🤗 Optimum[文档](https://huggingface.co/docs/optimum/intel/inference#stable-diffusion-xl)中找到更多示例，并且在OpenVINO中运行SDXL支持文本到图像和图像到图像。"
  },
  {
    "path": "docs/source/zh/optimization/para_attn.md",
    "content": "# ParaAttention\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-performance.png\">\n</div>\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/hunyuan-video-performance.png\">\n</div>\n\n大型图像和视频生成模型，如 [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) 和 [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo)，由于其规模，可能对实时应用和部署构成推理挑战。\n\n[ParaAttention](https://github.com/chengzeyi/ParaAttention) 是一个实现了**上下文并行**和**第一块缓存**的库，可以与其他技术（如 torch.compile、fp8 动态量化）结合使用，以加速推理。\n\n本指南将展示如何在 NVIDIA L20 GPU 上对 FLUX.1-dev 和 HunyuanVideo 应用 ParaAttention。\n在我们的基线基准测试中，除了 HunyuanVideo 为避免内存不足错误外，未应用任何优化。\n\n我们的基线基准测试显示，FLUX.1-dev 能够在 28 步中生成 1024x1024 分辨率图像，耗时 26.36 秒；HunyuanVideo 能够在 30 步中生成 129 帧 720p 分辨率视频，耗时 3675.71 秒。\n\n> [!TIP]\n> 对于更快的上下文并行推理，请尝试使用支持 NVLink 的 NVIDIA A100 或 H100 GPU（如果可用），尤其是在 GPU 数量较多时。\n\n## 第一块缓存\n\n缓存模型中 transformer 块的输出并在后续推理步骤中重用它们，可以降低计算成本并加速推理。\n\n然而，很难决定何时重用缓存以确保生成图像或视频的质量。ParaAttention 直接使用**第一个 transformer 块输出的残差差异**来近似模型输出之间的差异。当差异足够小时，重用先前推理步骤的残差差异。换句话说，跳过去噪步骤。\n\n这在 FLUX.1-dev 和 HunyuanVideo 推理上实现了 2 倍加速，且质量非常好。\n\n<figure>\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/ada-cache.png\" alt=\"Cache in Diffusion Transformer\" />\n    <figcaption>AdaCache 的工作原理，第一块缓存是其变体</figcaption>\n</figure>\n\n<hfoptions id=\"first-block-cache\">\n<hfoption id=\"FLUX-1.dev\">\n\n要在 FLUX.1-dev 上应用第一块缓存，请调用 `apply_cache_on_pipe`，如下所示。0.08 是 FLUX 模型的默认残差差异值。\n\n```python\nimport time\nimport torch\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nfrom para_attn.first_block_cache.diffusers_adapters import 
apply_cache_on_pipe\n\napply_cache_on_pipe(pipe, residual_diff_threshold=0.08)\n\n# 启用内存节省\n# pipe.enable_model_cpu_offload()\n# pipe.enable_sequential_cpu_offload()\n\nbegin = time.time()\nimage = pipe(\n    \"A cat holding a sign that says hello world\",\n    num_inference_steps=28,\n).images[0]\nend = time.time()\nprint(f\"Time: {end - begin:.2f}s\")\n\nprint(\"Saving image to flux.png\")\nimage.save(\"flux.png\")\n```\n\n| 优化 | 原始 | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 |\n| - | - | - | - | - | - |\n| 预览 | ![Original](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-original.png) | ![FBCache rdt=0.06](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.06.png) | ![FBCache rdt=0.08](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.08.png) | ![FBCache rdt=0.10](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.10.png) | ![FBCache rdt=0.12](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/flux-fbc-0.12.png) |\n| 墙钟时间 (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 |\n\nFirst Block Cache 将推理时间缩短至 17.01 秒，比基线快 1.55 倍，同时几乎没有质量损失。\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\n要在 HunyuanVideo 上应用 First Block Cache，请使用 `apply_cache_on_pipe`，如下所示。0.06 是 HunyuanVideo 模型的默认残差差值。\n\n```python\nimport time\nimport torch\nfrom diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel\nfrom diffusers.utils import export_to_video\n\nmodel_id = \"tencent/HunyuanVideo\"\ntransformer = HunyuanVideoTransformer3DModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16,\n    revision=\"refs/pr/18\",\n)\npipe = HunyuanVideoPipeline.from_pretrained(\n    model_id,\n    transformer=transformer,\n    
torch_dtype=torch.float16,\n    revision=\"refs/pr/18\",\n).to(\"cuda\")\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(pipe, residual_diff_threshold=0.6)\n\npipe.vae.enable_tiling()\n\nbegin = time.time()\noutput = pipe(\n    prompt=\"A cat walks on the grass, realistic\",\n    height=720,\n    width=1280,\n    num_frames=129,\n    num_inference_steps=30,\n).frames[0]\nend = time.time()\nprint(f\"Time: {end - begin:.2f}s\")\n\nprint(\"Saving video to hunyuan_video.mp4\")\nexport_to_video(output, \"hunyuan_video.mp4\", fps=15)\n```\n\n<video controls>\n  <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/hunyuan-video-original.mp4\" type=\"video/mp4\">\n  您的浏览器不支持视频标签。\n</video>\n\n<small> HunyuanVideo 无 FBCache </small>\n\n<video controls>\n  <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/para-attn/hunyuan-video-fbc.mp4\" type=\"video/mp4\">\n  Your browser does not support the video tag.\n</video>\n\n<small> HunyuanVideo 与 FBCache </small>\n\nFirst Block Cache 将推理速度降低至 2271.06 秒，相比基线快了 1.62 倍，同时保持了几乎为零的质量损失。\n\n</hfoption>\n</hfoptions>\n\n## fp8 量化\n\nfp8 动态量化进一步加速推理并减少内存使用。为了使用 8 位 [NVIDIA Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/)，必须对激活和权重进行量化。\n\n使用 `float8_weight_only` 和 `float8_dynamic_activation_float8_weight` 来量化文本编码器和变换器模型。\n\n默认量化方法是逐张量量化，但如果您的 GPU 支持逐行量化，您也可以尝试它以获得更好的准确性。\n\n使用以下命令安装 [torchao](https://github.com/pytorch/ao/tree/main)。\n\n```bash\npip3 install -U torch torchao\n```\n\n[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) 使用 `mode=\"max-autotune-no-cudagraphs\"` 或 `mode=\"max-autotune\"` 选择最佳内核以获得性能。如果是第一次调用模型，编译可能会花费很长时间，但一旦模型编译完成，这是值得的。\n\n此示例仅量化变换器模型，但您也可以量化文本编码器以进一步减少内存使用。\n\n> [!TIP]\n> 动态量化可能会显著改变模型输出的分布，因此您需要将 `residual_diff_threshold` 设置为更大的值以使其生效。\n\n<hfoptions id=\"fp8-quantization\">\n<hfoption 
id=\"FLUX-1.dev\">\n\n```python\nimport time\nimport torch\nfrom diffusers import FluxPipeline\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(\n    pipe,\n    residual_diff_threshold=0.12,  # 使用更大的值以使缓存生效\n)\n\nfrom torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only\n\nquantize_(pipe.text_encoder, float8_weight_only())\nquantize_(pipe.transformer, float8_dynamic_activation_float8_weight())\npipe.transformer = torch.compile(\n   pipe.transformer, mode=\"max-autotune-no-cudagraphs\",\n)\n\n# 启用内存节省\n# pipe.enable_model_cpu_offload()\n# pipe.enable_sequential_cpu_offload()\n\nfor i in range(2):\n    begin = time.time()\n    image = pipe(\n        \"A cat holding a sign that says hello world\",\n        num_inference_steps=28,\n    ).images[0]\n    end = time.time()\n    if i == 0:\n        print(f\"预热时间: {end - begin:.2f}s\")\n    else:\n        print(f\"时间: {end - begin:.2f}s\")\n\nprint(\"保存图像到 flux.png\")\nimage.save(\"flux.png\")\n```\n\nfp8 动态量化和 torch.compile 将推理速度降低至 7.56 秒，相比基线快了 3.48 倍。\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\n```python\nimport time\nimport torch\nfrom diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel\nfrom diffusers.utils import export_to_video\n\nmodel_id = \"tencent/HunyuanVideo\"\ntransformer = HunyuanVideoTransformer3DModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16,\n    revision=\"refs/pr/18\",\n)\npipe = HunyuanVideoPipeline.from_pretrained(\n    model_id,\n    transformer=transformer,\n    torch_dtype=torch.float16,\n    revision=\"refs/pr/18\",\n).to(\"cuda\")\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(pipe)\n\nfrom torchao.quantization import quantize_, 
float8_dynamic_activation_float8_weight, float8_weight_only\n\nquantize_(pipe.text_encoder, float8_weight_only())\nquantize_(pipe.transformer, float8_dynamic_activation_float8_weight())\npipe.transformer = torch.compile(\n   pipe.transformer, mode=\"max-autotune-no-cudagraphs\",\n)\n\n# Enable memory savings\npipe.vae.enable_tiling()\n# pipe.enable_model_cpu_offload()\n# pipe.enable_sequential_cpu_offload()\n\nfor i in range(2):\n    begin = time.time()\n    output = pipe(\n        prompt=\"A cat walks on the grass, realistic\",\n        height=720,\n        width=1280,\n        num_frames=129,\n        num_inference_steps=1 if i == 0 else 30,\n    ).frames[0]\n    end = time.time()\n    if i == 0:\n        print(f\"Warm up time: {end - begin:.2f}s\")\n    else:\n        print(f\"Time: {end - begin:.2f}s\")\n\nprint(\"Saving video to hunyuan_video.mp4\")\nexport_to_video(output, \"hunyuan_video.mp4\", fps=15)\n```\n\nNVIDIA L20 GPU 仅有 48GB 内存，在编译后且如果未调用 `enable_model_cpu_offload` 时，可能会遇到内存不足（OOM）错误，因为 HunyuanVideo 在高分辨率和大量帧数运行时具有非常大的激活张量。对于内存少于 80GB 的 GPU，可以尝试降低分辨率和帧数来避免 OOM 错误。\n\n大型视频生成模型通常受注意力计算而非全连接层的瓶颈限制。这些模型不会从量化和 torch.compile 中显著受益。\n\n</hfoption>\n</hfoptions>\n\n## 上下文并行性\n\n上下文并行性并行化推理并随多个 GPU 扩展。ParaAttention 组合设计允许您将上下文并行性与第一块缓存和动态量化结合使用。\n\n> [!TIP]\n> 请参考 [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main) 仓库获取详细说明和如何使用多个 GPU 扩展推理的示例。\n\n如果推理过程需要持久化和可服务，建议使用 [torch.multiprocessing](https://pytorch.org/docs/stable/multiprocessing.html) 编写您自己的推理处理器。这可以消除启动进程以及加载和重新编译模型的开销。\n\n<hfoptions id=\"context-parallelism\">\n<hfoption id=\"FLUX-1.dev\">\n\n以下代码示例结合了第一块缓存、fp8动态量化、torch.compile和上下文并行，以实现最快的推理速度。\n\n```python\nimport time\nimport torch\nimport torch.distributed as dist\nfrom diffusers import FluxPipeline\n\ndist.init_process_group()\n\ntorch.cuda.set_device(dist.get_rank())\n\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nfrom 
para_attn.context_parallel import init_context_parallel_mesh\nfrom para_attn.context_parallel.diffusers_adapters import parallelize_pipe\nfrom para_attn.parallel_vae.diffusers_adapters import parallelize_vae\n\nmesh = init_context_parallel_mesh(\n    pipe.device.type,\n    max_ring_dim_size=2,\n)\nparallelize_pipe(\n    pipe,\n    mesh=mesh,\n)\nparallelize_vae(pipe.vae, mesh=mesh._flatten())\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(\n    pipe,\n    residual_diff_threshold=0.12,  # 使用较大的值以使缓存生效\n)\n\nfrom torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only\n\nquantize_(pipe.text_encoder, float8_weight_only())\nquantize_(pipe.transformer, float8_dynamic_activation_float8_weight())\ntorch._inductor.config.reorder_for_compute_comm_overlap = True\npipe.transformer = torch.compile(\n   pipe.transformer, mode=\"max-autotune-no-cudagraphs\",\n)\n\n# 启用内存节省\n# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())\n# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())\n\nfor i in range(2):\n    begin = time.time()\n    image = pipe(\n        \"A cat holding a sign that says hello world\",\n        num_inference_steps=28,\n        output_type=\"pil\" if dist.get_rank() == 0 else \"pt\",\n    ).images[0]\n    end = time.time()\n    if dist.get_rank() == 0:\n        if i == 0:\n            print(f\"预热时间: {end - begin:.2f}s\")\n        else:\n            print(f\"时间: {end - begin:.2f}s\")\n\nif dist.get_rank() == 0:\n    print(\"将图像保存到flux.png\")\n    image.save(\"flux.png\")\n\ndist.destroy_process_group()\n```\n\n保存到`run_flux.py`并使用[torchrun](https://pytorch.org/docs/stable/elastic/run.html)启动。\n\n```bash\n# 使用--nproc_per_node指定GPU数量\ntorchrun --nproc_per_node=2 run_flux.py\n```\n\n推理速度降至8.20秒，相比基线快了3.21倍，使用2个NVIDIA L20 GPU。在4个L20上，推理速度为3.90秒，快了6.75倍。\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\n以下代码示例结合了第一块缓存和上下文并行，以实现最快的推理速度。\n\n```python\nimport 
time\nimport torch\nimport torch.distributed as dist\nfrom diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel\nfrom diffusers.utils import export_to_video\n\ndist.init_process_group()\n\ntorch.cuda.set_device(dist.get_rank())\n\nmodel_id = \"tencent/HunyuanVideo\"\ntransformer = HunyuanVideoTransformer3DModel.from_pretrained(\n    model_id,\n    subfolder=\"transformer\",\n    torch_dtype=torch.bfloat16,\n    revision=\"refs/pr/18\",\n)\npipe = HunyuanVideoPipeline.from_pretrained(\n    model_id,\n    transformer=transformer,\n    torch_dtype=torch.float16,\n    revision=\"refs/pr/18\",\n).to(\"cuda\")\n\nfrom para_attn.context_parallel import init_context_parallel_mesh\nfrom para_attn.context_parallel.diffusers_adapters import parallelize_pipe\nfrom para_attn.parallel_vae.diffusers_adapters import parallelize_vae\n\nmesh = init_context_parallel_mesh(\n    pipe.device.type,\n)\nparallelize_pipe(\n    pipe,\n    mesh=mesh,\n)\nparallelize_vae(pipe.vae, mesh=mesh._flatten())\n\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(pipe)\n\n# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only\n#\n# torch._inductor.config.reorder_for_compute_comm_overlap = True\n#\n# quantize_(pipe.text_encoder, float8_weight_only())\n# quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())\n# pipe.transformer = torch.compile(\n#    pipe.transformer, mode=\"max-autotune-no-cudagraphs\",\n# )\n\n# 启用内存节省\npipe.vae.enable_tiling()\n# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())\n# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())\n\nfor i in range(2):\n    begin = time.time()\n    output = pipe(\n        prompt=\"A cat walks on the grass, realistic\",\n        height=720,\n        width=1280,\n        num_frames=129,\n        num_inference_steps=1 if i == 0 else 30,\n        output_type=\"pil\" if dist.get_rank() == 0 else \"pt\",\n    
).frames[0]\n    end = time.time()\n    if dist.get_rank() == 0:\n        if i == 0:\n            print(f\"预热时间: {end - begin:.2f}s\")\n        else:\n            print(f\"时间: {end - begin:.2f}s\")\n\nif dist.get_rank() == 0:\n    print(\"保存视频到 hunyuan_video.mp4\")\n    export_to_video(output, \"hunyuan_video.mp4\", fps=15)\n\ndist.destroy_process_group()\n```\n\n保存到 `run_hunyuan_video.py` 并使用 [torchrun](https://pytorch.org/docs/stable/elastic/run.html) 启动。\n\n```bash\n# 使用 --nproc_per_node 指定 GPU 数量\ntorchrun --nproc_per_node=8 run_hunyuan_video.py\n```\n\n使用 8 个 NVIDIA L20 GPU 时，推理时间缩短至 649.23 秒，比基线快 5.66 倍。\n\n</hfoption>\n</hfoptions>\n\n## 基准测试\n\n<hfoptions id=\"conclusion\">\n<hfoption id=\"FLUX-1.dev\">\n\n| GPU 类型 | GPU 数量 | 优化 | 墙钟时间 (s) | 加速比 |\n| - | - | - | - | - |\n| NVIDIA L20 | 1 | 基线 | 26.36 | 1.00x |\n| NVIDIA L20 | 1 | FBCache (rdt=0.08) | 17.01 | 1.55x |\n| NVIDIA L20 | 1 | FP8 DQ | 13.40 | 1.96x |\n| NVIDIA L20 | 1 | FBCache (rdt=0.12) + FP8 DQ | 7.56 | 3.48x |\n| NVIDIA L20 | 2 | FBCache (rdt=0.12) + FP8 DQ + CP | 4.92 | 5.35x |\n| NVIDIA L20 | 4 | FBCache (rdt=0.12) + FP8 DQ + CP | 3.90 | 6.75x |\n\n</hfoption>\n<hfoption id=\"HunyuanVideo\">\n\n| GPU 类型 | GPU 数量 | 优化 | 墙钟时间 (s) | 加速比 |\n| - | - | - | - | - |\n| NVIDIA L20 | 1 | 基线 | 3675.71 | 1.00x |\n| NVIDIA L20 | 1 | FBCache | 2271.06 | 1.62x |\n| NVIDIA L20 | 2 | FBCache + CP | 1132.90 | 3.24x |\n| NVIDIA L20 | 4 | FBCache + CP | 718.15 | 5.12x |\n| NVIDIA L20 | 8 | FBCache + CP | 649.23 | 5.66x |\n\n</hfoption>\n</hfoptions>"
  },
  {
    "path": "docs/source/zh/optimization/pruna.md",
    "content": "# Pruna\n\n[Pruna](https://github.com/PrunaAI/pruna) 是一个模型优化框架，提供多种优化方法——量化、剪枝、缓存、编译——以加速推理并减少内存使用。以下是优化方法的概览。\n\n| 技术       | 描述                                                                                   | 速度 | 内存 | 质量 |\n|------------|---------------------------------------------------------------------------------------|:----:|:----:|:----:|\n| `batcher`  | 将多个输入分组在一起同时处理，提高计算效率并减少处理时间。                                  | ✅   | ❌   | ➖   |\n| `cacher`   | 存储计算的中间结果以加速后续操作。                                                       | ✅   | ➖   | ➖   |\n| `compiler` | 为特定硬件优化模型指令。                                                                 | ✅   | ➖   | ➖   |\n| `distiller`| 训练一个更小、更简单的模型来模仿一个更大、更复杂的模型。                                   | ✅   | ✅   | ❌   |\n| `quantizer`| 降低权重和激活的精度，减少内存需求。                                                       | ✅   | ✅   | ❌   |\n| `pruner`   | 移除不重要或冗余的连接和神经元，产生一个更稀疏、更高效的网络。                               | ✅   | ✅   | ❌   |\n| `recoverer`| 在压缩后恢复模型的性能。                                                                 | ➖   | ➖   | ✅   |\n| `factorizer`| 将多个小矩阵乘法批处理为一个大型融合操作。                                                | ✅   | ➖   | ➖   |\n| `enhancer` | 通过应用后处理算法（如去噪或上采样）来增强模型输出。                                        | ❌   | -    | ✅   |\n\n✅ (改进), ➖ (大致相同), ❌ (恶化)\n\n在 [Pruna 文档](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms) 中探索所有优化方法。\n\n## 安装\n\n使用以下命令安装 Pruna。\n\n```bash\npip install pruna\n```\n\n## 优化 Diffusers 模型\n\nDiffusers 模型支持广泛的优化算法，如下所示。\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/diffusers_combinations.png\" alt=\"Diffusers 模型支持的优化算法概览\">\n</div>\n\n下面的示例使用 factorizer、compiler 和 cacher 算法的组合优化 [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)。这种组合将推理速度加速高达 4.2 倍，并将峰值 GPU 内存使用从 34.7GB 减少到 28.0GB，同时几乎保持相同的输出质量。\n\n> 
[!TIP]\n> 参考 [Pruna 优化](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) 文档，进一步了解本示例中使用的优化技术。\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_combination.png\" alt=\"用于FLUX.1-dev的优化技术展示，结合了因子分解器、编译器和缓存器算法\">\n</div>\n\n首先定义一个包含要使用的优化算法的`SmashConfig`。要优化模型，将管道和`SmashConfig`用`smash`包装，然后像往常一样使用管道进行推理。\n\n```python\nimport torch\nfrom diffusers import FluxPipeline\n\nfrom pruna import PrunaModel, SmashConfig, smash\n\n# 加载模型\n# GPU 显存较小时，可改用 segmind/Segmind-Vega 或 black-forest-labs/FLUX.1-schnell\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\n# 定义配置\nsmash_config = SmashConfig()\nsmash_config[\"factorizer\"] = \"qkv_diffusers\"\nsmash_config[\"compiler\"] = \"torch_compile\"\nsmash_config[\"torch_compile_target\"] = \"module_list\"\nsmash_config[\"cacher\"] = \"fora\"\nsmash_config[\"fora_interval\"] = 2\n\n# 为了获得最佳速度结果，可以添加这些配置\n# 但它们会将预热时间从1.5分钟增加到10分钟\n# smash_config[\"torch_compile_mode\"] = \"max-autotune-no-cudagraphs\"\n# smash_config[\"quantizer\"] = \"torchao\"\n# smash_config[\"torchao_quant_type\"] = \"fp8dq\"\n# smash_config[\"torchao_excluded_modules\"] = \"norm+embedding\"\n\n# 优化模型\nsmashed_pipe = smash(pipe, smash_config)\n\n# 运行模型\nsmashed_pipe(\"a knitted purple prune\").images[0]\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_smashed_comparison.png\">\n</div>\n\n优化后，我们可以使用Hugging Face Hub共享和加载优化后的模型。\n\n```python\n# 保存模型\nsmashed_pipe.save_to_hub(\"<username>/FLUX.1-dev-smashed\")\n\n# 加载模型\nsmashed_pipe = PrunaModel.from_hub(\"<username>/FLUX.1-dev-smashed\")\n```\n\n## 
评估和基准测试Diffusers模型\n\nPruna提供了[EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)来评估优化后模型的质量。\n\n我们可以定义我们关心的指标，如总时间和吞吐量，以及要评估的数据集。我们可以定义一个模型并将其传递给`EvaluationAgent`。\n\n<hfoptions id=\"eval\">\n<hfoption id=\"optimized model\">\n\n我们可以通过使用`EvaluationAgent`加载和评估优化后的模型，并将其传递给`Task`。\n\n```python\nimport torch\nfrom diffusers import FluxPipeline\n\nfrom pruna import PrunaModel\nfrom pruna.data.pruna_datamodule import PrunaDataModule\nfrom pruna.evaluation.evaluation_agent import EvaluationAgent\nfrom pruna.evaluation.metrics import (\n    ThroughputMetric,\n    TorchMetricWrapper,\n    TotalTimeMetric,\n)\nfrom pruna.evaluation.task import Task\n\n# define the device\ndevice = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n\n# 加载模型\n# 使用小GPU内存尝试 PrunaAI/Segmind-Vega-smashed 或 PrunaAI/FLUX.1-dev-smashed\nsmashed_pipe = PrunaModel.from_hub(\"PrunaAI/FLUX.1-dev-smashed\")\n\n# 定义指标\nmetrics = [\n    TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),\n    ThroughputMetric(n_iterations=20, n_warmup_iterations=5),\n    TorchMetricWrapper(\"clip\"),\n]\n\n# 定义数据模块\ndatamodule = PrunaDataModule.from_string(\"LAION256\")\ndatamodule.limit_datasets(10)\n\n# 定义任务和评估代理\ntask = Task(metrics, datamodule=datamodule, device=device)\neval_agent = EvaluationAgent(task)\n\n# 评估优化模型并卸载到CPU\nsmashed_pipe.move_to_device(device)\nsmashed_pipe_results = eval_agent.evaluate(smashed_pipe)\nsmashed_pipe.move_to_device(\"cpu\")\n```\n\n</hfoption>\n<hfoption id=\"standalone model\">\n\n除了比较优化模型与基础模型，您还可以评估独立的 `diffusers` 模型。这在您想评估模型性能而不考虑优化时非常有用。我们可以通过使用 `PrunaModel` 包装器并运行 `EvaluationAgent` 来实现。\n\n```python\nimport torch\nfrom diffusers import FluxPipeline\n\nfrom pruna import PrunaModel\n\n# 加载模型\n# 使用小GPU内存尝试 PrunaAI/Segmind-Vega-smashed 或 PrunaAI/FLUX.1-dev-smashed\npipe = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    
torch_dtype=torch.bfloat16\n).to(\"cpu\")\nwrapped_pipe = PrunaModel(model=pipe)\n```\n\n</hfoption>\n</hfoptions>\n\n现在您已经了解了如何优化和评估您的模型，可以开始使用 Pruna 来优化您自己的模型了。幸运的是，我们有许多示例来帮助您入门。\n\n> [!TIP]\n> 有关基准测试 Flux 的更多详细信息，请查看 [宣布 FLUX-Juiced：最快的图像生成端点（快 2.6 倍）！](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) 博客文章和 [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) 空间。\n\n## 参考\n\n- [Pruna](https://github.com/pruna-ai/pruna)\n- [Pruna 优化](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)\n- [Pruna 评估](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)\n- [Pruna 教程](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)"
  },
  {
    "path": "docs/source/zh/optimization/speed-memory-optims.md",
    "content": "<!--版权所有 2024 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版（“许可证”）授权；除非符合许可证，否则不得使用此文件。\n您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按“原样”分发，不附带任何明示或暗示的担保或条件。有关许可证的特定语言，请参阅许可证。\n-->\n\n# 编译和卸载量化模型\n\n优化模型通常涉及[推理速度](./fp16)和[内存使用](./memory)之间的权衡。例如，虽然[缓存](./cache)可以提高推理速度，但它也会增加内存消耗，因为它需要存储中间注意力层的输出。一种更平衡的优化策略结合了量化模型、[torch.compile](./fp16#torchcompile) 和各种[卸载方法](./memory#offloading)。\n\n> [!TIP]\n> 查看 [torch.compile](./fp16#torchcompile) 指南以了解更多关于编译以及如何在此处应用的信息。例如，区域编译可以显著减少编译时间，而不会放弃任何加速。\n\n对于图像生成，结合量化和[模型卸载](./memory#model-offloading)通常可以在质量、速度和内存之间提供最佳权衡。组卸载对于图像生成效果不佳，因为如果计算内核更快完成，通常不可能*完全*重叠数据传输。这会导致 CPU 和 GPU 之间的一些通信开销。\n\n对于视频生成，结合量化和[组卸载](./memory#group-offloading)往往更好，因为视频模型更受计算限制。\n\n下表提供了优化策略组合及其对 Flux 延迟和内存使用的影响的比较。\n\n| 组合 | 延迟 (s) | 内存使用 (GB) |\n|---|---|---|\n| 量化 | 32.602 | 14.9453 |\n| 量化, torch.compile | 25.847 | 14.9448 |\n| 量化, torch.compile, 模型 CPU 卸载 | 32.312 | 12.2369 |\n<small>这些结果是在 Flux 上使用 RTX 4090 进行基准测试的。transformer 和 text_encoder 组件已量化。如果您有兴趣评估自己的模型，请参考[基准测试脚本](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d)。</small>\n\n本指南将向您展示如何使用 [bitsandbytes](../quantization/bitsandbytes#torchcompile) 编译和卸载量化模型。确保您正在使用 [PyTorch nightly](https://pytorch.org/get-started/locally/) 和最新版本的 bitsandbytes。\n\n```bash\npip install -U bitsandbytes\n```\n\n## 量化和 torch.compile\n\n首先通过[量化](../quantization/overview)模型来减少存储所需的内存，并[编译](./fp16#torchcompile)它以加速推理。\n\n配置 [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `capture_dynamic_output_shape_ops = True` 以在编译 bitsandbytes 模型时处理动态输出。\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\n\ntorch._dynamo.config.capture_dynamic_output_shape_ops = True\n\n# 量化\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": 
\"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n    components_to_quantize=[\"transformer\", \"text_encoder_2\"],\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\n# 编译\npipeline.transformer.to(memory_format=torch.channels_last)\npipeline.transformer.compile(mode=\"max-autotune\", fullgraph=True)\npipeline(\"\"\"\n    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California\n    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\n\"\"\"\n).images[0]\n```\n\n## 量化、torch.compile 和卸载\n\n除了量化和 torch.compile，如果您需要进一步减少内存使用，可以尝试卸载。卸载根据需要将各种层或模型组件从 CPU 移动到 GPU 进行计算。\n\n在卸载期间配置 [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `cache_size_limit` 以避免过多的重新编译，并设置 `capture_dynamic_output_shape_ops = True` 以在编译 bitsandbytes 模型时处理动态输出。\n\n<hfoptions id=\"offloading\">\n<hfoption id=\"model CPU offloading\">\n\n[模型 CPU 卸载](./memory#model-offloading) 将单个管道组件（如 transformer 模型）在需要计算时移动到 GPU。否则，它会被卸载到 CPU。\n\n```py\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom diffusers.quantizers import PipelineQuantizationConfig\n\ntorch._dynamo.config.cache_size_limit = 1000\ntorch._dynamo.config.capture_dynamic_output_shape_ops = True\n\n# 量化\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": \"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n    components_to_quantize=[\"transformer\", \"text_encoder_2\"],\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\n# 模型 CPU 卸载\npipeline.enable_model_cpu_offload()\n\n# 编译\npipeline.transformer.compile()\npipeline(\n    \"cinematic film still of 
a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain\"\n).images[0]\n```\n\n</hfoption>\n<hfoption id=\"group offloading\">\n\n[组卸载](./memory#group-offloading) 将单个管道组件（如变换器模型）的内部层移动到 GPU 进行计算，并在不需要时将其卸载。同时，它使用 [CUDA 流](./memory#cuda-stream) 功能来预取下一层以执行。\n\n通过重叠计算和数据传输，它比模型 CPU 卸载更快，同时还能节省内存。\n\n```py\n# pip install ftfy\nimport torch\nfrom diffusers import AutoModel, DiffusionPipeline\nfrom diffusers.hooks import apply_group_offloading\nfrom diffusers.utils import export_to_video\nfrom diffusers.quantizers import PipelineQuantizationConfig\nfrom transformers import UMT5EncoderModel\n\ntorch._dynamo.config.cache_size_limit = 1000\ntorch._dynamo.config.capture_dynamic_output_shape_ops = True\n\n# 量化\npipeline_quant_config = PipelineQuantizationConfig(\n    quant_backend=\"bitsandbytes_4bit\",\n    quant_kwargs={\"load_in_4bit\": True, \"bnb_4bit_quant_type\": \"nf4\", \"bnb_4bit_compute_dtype\": torch.bfloat16},\n    components_to_quantize=[\"transformer\", \"text_encoder\"],\n)\n\ntext_encoder = UMT5EncoderModel.from_pretrained(\n    \"Wan-AI/Wan2.1-T2V-14B-Diffusers\", subfolder=\"text_encoder\", torch_dtype=torch.bfloat16\n)\npipeline = DiffusionPipeline.from_pretrained(\n    \"Wan-AI/Wan2.1-T2V-14B-Diffusers\",\n    quantization_config=pipeline_quant_config,\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\n# 组卸载\nonload_device = torch.device(\"cuda\")\noffload_device = torch.device(\"cpu\")\n\npipeline.transformer.enable_group_offload(\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True,\n    non_blocking=True\n)\npipeline.vae.enable_group_offload(\n    onload_device=onload_device,\n    offload_device=offload_device,\n    offload_type=\"leaf_level\",\n    use_stream=True,\n    non_blocking=True\n)\napply_group_offloading(\n    pipeline.text_encoder,\n    onload_device=onload_device,\n  
  offload_type=\"leaf_level\",\n    use_stream=True,\n    non_blocking=True\n)\n\n# 编译\npipeline.transformer.compile()\n\nprompt = \"\"\"\nThe camera rushes from far to near in a low-angle shot, \nrevealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in \nfor a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. \nBirch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic \nshadows and warm highlights. Medium composition, front view, low angle, with depth of field.\n\"\"\"\nnegative_prompt = \"\"\"\nBright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, \nlow quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, \nmisshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\n\"\"\"\n\noutput = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_frames=81,\n    guidance_scale=5.0,\n).frames[0]\nexport_to_video(output, \"output.mp4\", fps=16)\n```\n\n</hfoption>\n</hfoptions>"
  },
  {
    "path": "docs/source/zh/optimization/tgate.md",
    "content": "# T-GATE\n\n[T-GATE](https://github.com/HaozheLiu-ST/T-GATE/tree/main) 在交叉注意力计算收敛后跳过该计算，从而加速 [Stable Diffusion](../api/pipelines/stable_diffusion/overview)、[PixArt](../api/pipelines/pixart) 和 [Latent Consistency Model](../api/pipelines/latent_consistency_models.md) 管道的推理。此方法不需要任何额外训练，可以将推理速度提高 10-50%。T-GATE 还与 [DeepCache](./deepcache) 等其他优化方法兼容。\n\n开始之前，请确保安装 T-GATE。\n\n```bash\npip install tgate\npip install -U torch diffusers transformers accelerate DeepCache\n```\n\n要在管道中使用 T-GATE，您需要使用其对应的加载器。\n\n| 管道 | T-GATE 加载器 |\n|---|---|\n| PixArt | TgatePixArtLoader |\n| Stable Diffusion XL | TgateSDXLLoader |\n| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader |\n| Stable Diffusion | TgateSDLoader |\n| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader |\n\n接下来，创建一个 `TgateLoader`，传入管道、门控步骤（停止计算交叉注意力的时间步）和推理步骤数。然后在管道上调用 `tgate` 方法，提供提示、门控步骤和推理步骤数。\n\n让我们看看如何为几个不同的管道启用此功能。\n\n<hfoptions id=\"pipelines\">\n<hfoption id=\"PixArt\">\n\n使用 T-GATE 加速 `PixArtAlphaPipeline`：\n\n```py\nimport torch\nfrom diffusers import PixArtAlphaPipeline\nfrom tgate import TgatePixArtLoader\n\npipe = PixArtAlphaPipeline.from_pretrained(\"PixArt-alpha/PixArt-XL-2-1024-MS\", torch_dtype=torch.float16)\n\ngate_step = 8\ninference_step = 25\npipe = TgatePixArtLoader(\n       pipe,\n       gate_step=gate_step,\n       num_inference_steps=inference_step,\n).to(\"cuda\")\n\nimage = pipe.tgate(\n       \"An alpaca made of colorful building blocks, cyberpunk.\",\n       gate_step=gate_step,\n       num_inference_steps=inference_step,\n).images[0]\n```\n</hfoption>\n<hfoption id=\"Stable Diffusion XL\">\n\n使用 T-GATE 加速 `StableDiffusionXLPipeline`：\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers import DPMSolverMultistepScheduler\nfrom tgate import TgateSDXLLoader\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\",\n            torch_dtype=torch.float16,\n            
variant=\"fp16\",\n            use_safetensors=True,\n)\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n\ngate_step = 10\ninference_step = 25\npipe = TgateSDXLLoader(\n       pipe,\n       gate_step=gate_step,\n       num_inference_steps=inference_step,\n).to(\"cuda\")\n\nimage = pipe.tgate(\n       \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.\",\n       gate_step=gate_step,\n       num_inference_steps=inference_step\n).images[0]\n```\n</hfoption>\n<hfoption id=\"StableDiffusionXL with DeepCache\">\n\n使用 [DeepCache](https://github.com/horseee/DeepCache) 和 T-GATE 加速 `StableDiffusionXLPipeline`：\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers import DPMSolverMultistepScheduler\nfrom tgate import TgateSDXLDeepCacheLoader\n\npipe = StableDiffusionXLPipeline.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\",\n            torch_dtype=torch.float16,\n            variant=\"fp16\",\n            use_safetensors=True,\n)\npipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n\ngate_step = 10\ninference_step = 25\npipe = TgateSDXLDeepCacheLoader(\n       pipe,\n       cache_interval=3,\n       cache_branch_id=0,\n).to(\"cuda\")\n\nimage = pipe.tgate(\n       \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.\",\n       gate_step=gate_step,\n       num_inference_steps=inference_step\n).images[0]\n```\n</hfoption>\n<hfoption id=\"Latent Consistency Model\">\n\n使用 T-GATE 加速 `latent-consistency/lcm-sdxl`：\n\n```py\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers import UNet2DConditionModel, LCMScheduler\nfrom diffusers import DPMSolverMultistepScheduler\nfrom tgate import TgateSDXLLoader\n\nunet = UNet2DConditionModel.from_pretrained(\n    \"latent-consistency/lcm-sdxl\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe = 
StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    unet=unet,\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n)\npipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)\n\ngate_step = 1\ninference_step = 4\npipe = TgateSDXLLoader(\n       pipe,\n       gate_step=gate_step,\n       num_inference_steps=inference_step,\n       lcm=True\n).to(\"cuda\")\n\nimage = pipe.tgate(\n       \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.\",\n       gate_step=gate_step,\n       num_inference_steps=inference_step\n).images[0]\n```\n</hfoption>\n</hfoptions>\n\nT-GATE 还支持 [`StableDiffusionPipeline`] 和 [PixArt-alpha/PixArt-LCM-XL-2-1024-MS](https://hf.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS)。\n\n## 基准测试\n\n| 模型                 | MACs     | 参数     | 延迟 | 零样本 10K-FID on MS-COCO |\n|-----------------------|----------|-----------|---------|---------------------------|\n| SD-1.5                | 16.938T  | 859.520M  | 7.032s  | 23.927                    |\n| SD-1.5 w/ T-GATE       | 9.875T   | 815.557M  | 4.313s  | 20.789                    |\n| SD-2.1                | 38.041T  | 865.785M  | 16.121s | 22.609                    |\n| SD-2.1 w/ T-GATE       | 22.208T  | 815.433M  | 9.878s  | 19.940                    |\n| SD-XL                 | 149.438T | 2.570B    | 53.187s | 24.628                    |\n| SD-XL w/ T-GATE        | 84.438T  | 2.024B    | 27.932s | 22.738                    |\n| Pixart-Alpha          | 107.031T | 611.350M  | 61.502s | 38.669                    |\n| Pixart-Alpha w/ T-GATE | 65.318T  | 462.585M  | 37.867s | 35.825                    |\n| DeepCache (SD-XL)     | 57.888T  | -         | 19.931s | 23.755                    |\n| DeepCache 配合 T-GATE    | 43.868T  | -         | 14.666s | 23.999                    |\n| LCM (SD-XL)           | 11.955T  | 2.570B    | 3.805s  | 25.044                    |\n| LCM 配合 T-GATE          | 11.171T  | 2.024B    | 3.533s  | 25.028           
         |\n| LCM (Pixart-Alpha)    | 8.563T   | 611.350M  | 4.733s  | 36.086                    |\n| LCM 配合 T-GATE          | 7.623T   | 462.585M  | 4.543s  | 37.048                    |\n\n延迟测试基于 NVIDIA 1080TI，MACs 和参数量使用 [calflops](https://github.com/MrYxJ/calculate-flops.pytorch) 计算，FID 使用 [PytorchFID](https://github.com/mseitzer/pytorch-fid) 计算。"
  },
  {
    "path": "docs/source/zh/optimization/tome.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版（“许可证”）授权；除非遵守许可证，否则不得使用此文件。\n您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按“原样”分发，不附带任何明示或暗示的担保或条件。请参阅许可证以了解具体的语言管理权限和限制。\n-->\n\n# 令牌合并\n\n[令牌合并](https://huggingface.co/papers/2303.17604)（ToMe）在基于 Transformer 的网络的前向传递中逐步合并冗余令牌/补丁，这可以降低 [`StableDiffusionPipeline`] 的推理延迟。\n\n通过 `pip` 安装 ToMe：\n\n```bash\npip install tomesd\n```\n\n您可以使用 [`tomesd`](https://github.com/dbolya/tomesd) 库中的 [`apply_patch`](https://github.com/dbolya/tomesd?tab=readme-ov-file#usage) 函数：\n\n```diff\n  from diffusers import StableDiffusionPipeline\n  import torch\n  import tomesd\n\n  pipeline = StableDiffusionPipeline.from_pretrained(\n        \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True,\n  ).to(\"cuda\")\n+ tomesd.apply_patch(pipeline, ratio=0.5)\n\n  image = pipeline(\"a photo of an astronaut riding a horse on mars\").images[0]\n```\n\n`apply_patch` 函数公开了多个[参数](https://github.com/dbolya/tomesd#usage)，以帮助在管道推理速度和生成令牌的质量之间取得平衡。最重要的参数是 `ratio`，它控制在前向传递期间合并的令牌数量。\n\n如[论文](https://huggingface.co/papers/2303.17604)中所述，ToMe 可以在显著提升推理速度的同时，很大程度上保留生成图像的质量。通过增加 `ratio`，您可以进一步加速推理，但代价是图像质量有所下降。\n\n为了测试生成图像的质量，我们从 [Parti Prompts](https://parti.research.google/) 中采样了一些提示，并使用 [`StableDiffusionPipeline`] 进行了推理，设置如下：\n\n<div class=\"flex justify-center\">\n      <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/tome/tome_samples.png\">\n</div>\n\n我们没有注意到生成样本的质量有任何显著下降，您可以在此 [WandB 报告](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=)中查看生成的样本。如果您有兴趣重现此实验，请使用此[脚本](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd)。\n\n## 基准测试\n\n我们还在启用 [xFormers](https://huggingface.co/docs/diffusers/optimization/xformers) 的情况下，对 [`StableDiffusionPipeline`] 上 `tomesd` 的影响进行了基准测试，涵盖了多个图像分辨率。结果是在以下开发环境中的 A100 和 V100 GPU 上获得的：\n\n```bash\n- `diffusers` 版本：0.15.1\n- Python 版本：3.8.16\n- 
PyTorch 版本（GPU？）：1.13.1+cu116 (True)\n- Huggingface_hub 版本：0.13.2\n- Transformers 版本：4.27.2\n- Accelerate 版本：0.18.0\n- xFormers 版本：0.0.16\n- tomesd 版本：0.1.2\n```\n\n要重现此基准测试，请随意使用此[脚本](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335)。结果以秒为单位报告，并且在适用的情况下，我们报告了使用ToMe和ToMe + xFormers时相对于原始管道的加速百分比。\n\n| **GPU**  | **分辨率** | **批处理大小** | **原始** | **ToMe**       | **ToMe + xFormers** |\n|----------|----------------|----------------|-------------|----------------|---------------------|\n| **A100** |            512 |             10 |        6.88 | 5.26 (+23.55%) |      4.69 (+31.83%) |\n|          |            768 |             10 |         OOM |          14.71 |                  11 |\n|          |                |              8 |         OOM |          11.56 |                8.84 |\n|          |                |              4 |         OOM |           5.98 |                4.66 |\n|          |                |              2 |        4.99 | 3.24 (+35.07%) |       2.1 (+37.88%) |\n|          |                |              1 |        3.29 | 2.24 (+31.91%) |       2.03 (+38.3%) |\n|          |           1024 |             10 |         OOM |            OOM |                 OOM |\n|          |                |              8 |         OOM |            OOM |                 OOM |\n|          |                |              4 |         OOM |          12.51 |                9.09 |\n|          |                |              2 |         OOM |           6.52 |                4.96 |\n|          |                |              1 |         6.4 | 3.61 (+43.59%) |      2.81 (+56.09%) |\n| **V100** |            512 |             10 |         OOM |          10.03 |                9.29 |\n|          |                |              8 |         OOM |           8.05 |                7.47 |\n|          |                |              4 |         5.7 |  4.3 (+24.56%) |      3.98 (+30.18%) |\n|          |                |              2 |        3.14 | 2.43 (+22.61%) |     
 2.27 (+27.71%) |\n|          |                |              1 |        1.88 | 1.57 (+16.49%) |      1.57 (+16.49%) |\n|          |            768 |             10 |         OOM |            OOM |               23.67 |\n|          |                |              8 |         OOM |            OOM |               18.81 |\n|          |                |              4 |         OOM |          11.81 |                 9.7 |\n|          |                |              2 |         OOM |           6.27 |                 5.2 |\n|          |                |              1 |        5.43 | 3.38 (+37.75%) |      2.82 (+48.07%) |\n|          |           1024 |             10 |         OOM |            \n如上表所示，`tomesd` 带来的加速效果在更大的图像分辨率下变得更加明显。有趣的是，使用 `tomesd` 可以在更高分辨率如 1024x1024 上运行管道。您可能还可以通过 [`torch.compile`](fp16#torchcompile) 进一步加速推理。"
  },
  {
    "path": "docs/source/zh/optimization/xdit.md",
    "content": "# xDiT\n\n[xDiT](https://github.com/xdit-project/xDiT) 是一个推理引擎，专为大规模并行部署扩散变换器（DiTs）而设计。xDiT 提供了一套用于扩散模型的高效并行方法，以及 GPU 内核加速。\n\nxDiT 支持四种并行方法，包括[统一序列并行](https://huggingface.co/papers/2405.07719)、[PipeFusion](https://huggingface.co/papers/2405.14430)、CFG 并行和数据并行。xDiT 中的这四种并行方法可以以混合方式配置，优化通信模式以最适合底层网络硬件。\n\n与并行化正交的优化侧重于加速单个 GPU 的性能。除了利用知名的注意力优化库外，我们还利用编译加速技术，如 torch.compile 和 onediff。\n\nxDiT 的概述如下所示。\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/methods/xdit_overview.png\">\n</div>\n\n您可以使用以下命令安装 xDiT：\n\n```bash\npip install xfuser\n```\n\n以下是一个使用 xDiT 加速 Diffusers 模型推理的示例。\n\n```diff\n import torch\n from diffusers import StableDiffusion3Pipeline\n\n from xfuser import xFuserArgs, xDiTParallel\n from xfuser.config import FlexibleArgumentParser\n from xfuser.core.distributed import get_world_group\n\n def main():\n+    parser = FlexibleArgumentParser(description=\"xFuser Arguments\")\n+    args = xFuserArgs.add_cli_args(parser).parse_args()\n+    engine_args = xFuserArgs.from_cli_args(args)\n+    engine_config, input_config = engine_args.create_config()\n\n     local_rank = get_world_group().local_rank\n     pipe = StableDiffusion3Pipeline.from_pretrained(\n         pretrained_model_name_or_path=engine_config.model_config.model,\n         torch_dtype=torch.float16,\n     ).to(f\"cuda:{local_rank}\")\n\n     # 在这里对管道进行任何操作\n\n+    pipe = xDiTParallel(pipe, engine_config, input_config)\n\n     pipe(\n         height=input_config.height,\n         width=input_config.width,\n         prompt=input_config.prompt,\n         num_inference_steps=input_config.num_inference_steps,\n         output_type=input_config.output_type,\n         generator=torch.Generator(device=\"cuda\").manual_seed(input_config.seed),\n     )\n\n+    if input_config.output_type == \"pil\":\n+        pipe.save(\"results\", \"stable_diffusion_3\")\n\nif __name__ == \"__main__\":\n    
main()\n```\n\n如您所见，我们只需要使用 xDiT 中的 xFuserArgs 来获取配置参数，并将这些参数与来自 Diffusers 库的管道对象一起传递给 xDiTParallel，即可完成对 Diffusers 中特定管道的并行化。\n\nxDiT 运行时参数可以在命令行中使用 `-h` 查看，您可以参考此[使用](https://github.com/xdit-project/xDiT?tab=readme-ov-file#2-usage)示例以获取更多详细信息。\n\nxDiT 需要使用 torchrun 启动，以支持其多节点、多 GPU 并行能力。例如，以下命令可用于 8-GPU 并行推理：\n\n```bash\ntorchrun --nproc_per_node=8 ./inference.py --model models/FLUX.1-dev --data_parallel_degree 2 --ulysses_degree 2 --ring_degree 2 --prompt \"A snowy mountain\" \"A small dog\" --num_inference_steps 50\n```\n\n## 支持的模型\n\n在 xDiT 中支持 Diffusers 模型的一个子集，例如 Flux.1、Stable Diffusion 3 等。最新支持的模型可以在[这里](https://github.com/xdit-project/xDiT?tab=readme-ov-file#-supported-dits)找到。\n\n## 基准测试\n我们在不同机器上测试了各种模型，以下是一些基准数据。\n\n### Flux.1-schnell\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/flux/Flux-2k-L40.png\">\n</div>\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/flux/Flux-2K-A100.png\">\n</div>\n\n### Stable Diffusion 3\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/sd3/L40-SD3.png\">\n</div>\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/sd3/A100-SD3.png\">\n</div>\n\n### HunyuanDiT\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/L40-HunyuanDiT.png\">\n</div>\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/V100-HunyuanDiT.png\">\n</div>\n\n<div class=\"flex justify-center\">\n    <img 
src=\"https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/T4-HunyuanDiT.png\">\n</div>\n\n更详细的性能指标可以在我们的 [GitHub 页面](https://github.com/xdit-project/xDiT?tab=readme-ov-file#perf) 上找到。\n\n## 参考文献\n\n[xDiT-project](https://github.com/xdit-project/xDiT)\n\n[USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://huggingface.co/papers/2405.07719)\n\n[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://huggingface.co/papers/2405.14430)"
  },
  {
    "path": "docs/source/zh/optimization/xformers.md",
    "content": "<!--版权归2025年HuggingFace团队所有。保留所有权利。\n\n根据Apache许可证2.0版（\"许可证\"）授权；除非符合许可证要求，否则不得使用本文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，本软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。详见许可证中规定的特定语言及限制条款。\n-->\n\n# xFormers\n\n我们推荐在推理和训练过程中使用[xFormers](https://github.com/facebookresearch/xformers)。在我们的测试中，其对注意力模块的优化能同时提升运行速度并降低内存消耗。\n\n通过`pip`安装xFormers：\n\n```bash\npip install xformers\n```\n\n> [!TIP]\n> xFormers的`pip`安装包需要最新版本的PyTorch。如需使用旧版PyTorch，建议[从源码安装xFormers](https://github.com/facebookresearch/xformers#installing-xformers)。\n\n安装完成后，您可调用`enable_xformers_memory_efficient_attention()`来实现更快的推理速度和更低的内存占用，具体用法参见[此章节](memory#memory-efficient-attention)。\n\n> [!WARNING]\n> 根据[此问题](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)反馈，xFormers `v0.0.16`版本在某些GPU上无法用于训练（微调或DreamBooth）。如遇此问题，请按照该issue评论区指引安装开发版本。"
  },
  {
    "path": "docs/source/zh/quicktour.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n[[open-in-colab]]\n\n# 快速上手\n\n训练扩散模型，是为了对随机高斯噪声进行逐步去噪，以生成令人感兴趣的样本，比如图像或者语音。\n\n扩散模型的发展引起了人们对生成式人工智能的极大兴趣，你可能已经在网上见过扩散生成的图像了。🧨 Diffusers库的目的是让大家更易上手扩散模型。\n\n无论你是开发人员还是普通用户，本文将向你介绍🧨 Diffusers 并帮助你快速开始生成内容！该库有三个主要组件：\n\n* [`DiffusionPipeline`]是一个高级的端到端类，旨在通过预训练的扩散模型快速生成样本进行推理。\n* 流行的预训练[模型](./api/models)架构和模块，可用作构建扩散系统的组件。\n* 许多不同的[调度器](./api/schedulers/overview)：控制如何在训练过程中添加噪声的算法，以及如何在推理过程中生成去噪图像的算法。\n\n快速入门将告诉你如何使用[`DiffusionPipeline`]进行推理，然后指导你如何结合模型和调度器以复现[`DiffusionPipeline`]内部发生的事情。\n\n> [!TIP]\n> 快速入门是🧨[Diffusers入门](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)的简化版，可以帮助你快速上手。如果你想了解更多关于🧨 Diffusers的目标、设计理念以及关于它的核心API的更多细节，可以点击🧨[Diffusers入门](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)查看。\n\n在开始之前，确认一下你已经安装好了所需要的库：\n\n```bash\npip install --upgrade diffusers accelerate transformers\n```\n\n- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) 在推理和训练过程中加速模型加载。\n- [🤗 Transformers](https://huggingface.co/docs/transformers/index) 是运行最流行的扩散模型所必须的库，比如[Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview)。\n\n## 
扩散模型管道\n\n[`DiffusionPipeline`]是用预训练的扩散系统进行推理的最简单方法。它是一个包含模型和调度器的端到端系统。你可以直接使用[`DiffusionPipeline`]完成许多任务。请查看下面的表格以了解一些支持的任务，要获取完整的支持任务列表，请查看[🧨 Diffusers 总结](./api/pipelines/overview#diffusers-summary)。\n\n| **任务** | **描述** | **管道** |\n|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|\n| Unconditional Image Generation | 从高斯噪声中生成图片 | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |\n| Text-Guided Image Generation | 给定文本提示生成图像 | [conditional_image_generation](./using-diffusers/conditional_image_generation) |\n| Text-Guided Image-to-Image Translation | 在文本提示的指导下调整图像 | [img2img](./using-diffusers/img2img) |\n| Text-Guided Image-Inpainting | 给出图像、遮罩和文本提示，填充图像的遮罩部分 | [inpaint](./using-diffusers/inpaint) |\n| Text-Guided Depth-to-Image Translation | 在文本提示的指导下调整图像的部分内容，同时通过深度估计保留其结构 | [depth2img](./using-diffusers/depth2img) |\n\n首先创建一个[`DiffusionPipeline`]的实例，并指定要下载的pipeline检查点。\n你可以使用存储在Hugging Face Hub上的任何[`DiffusionPipeline`][检查点](https://huggingface.co/models?library=diffusers&sort=downloads)。\n在本快速入门中，你将加载[`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)检查点，用于文本到图像的生成。\n\n> [!WARNING]\n> 对于[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion)模型，在运行该模型之前，请先仔细阅读[许可证](https://huggingface.co/spaces/CompVis/stable-diffusion-license)。🧨 Diffusers实现了一个[`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py)，以防止有攻击性的或有害的内容，但Stable 
Diffusion模型改进图像的生成能力仍有可能产生潜在的有害内容。\n\n用[`~DiffusionPipeline.from_pretrained`]方法加载模型。\n\n```python\n>>> from diffusers import DiffusionPipeline\n\n>>> pipeline = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\n```\n[`DiffusionPipeline`]会下载并缓存所有的建模、标记化和调度组件。你可以看到Stable Diffusion的pipeline是由[`UNet2DConditionModel`]和[`PNDMScheduler`]等组件组成的：\n\n```py\n>>> pipeline\nStableDiffusionPipeline {\n  \"_class_name\": \"StableDiffusionPipeline\",\n  \"_diffusers_version\": \"0.13.1\",\n  ...,\n  \"scheduler\": [\n    \"diffusers\",\n    \"PNDMScheduler\"\n  ],\n  ...,\n  \"unet\": [\n    \"diffusers\",\n    \"UNet2DConditionModel\"\n  ],\n  \"vae\": [\n    \"diffusers\",\n    \"AutoencoderKL\"\n  ]\n}\n```\n\n我们强烈建议你在GPU上运行这个pipeline，因为该模型由大约14亿个参数组成。\n\n你可以像在Pytorch里那样把生成器对象移到GPU上：\n\n```python\n>>> pipeline.to(\"cuda\")\n```\n\n现在你可以向`pipeline`传递一个文本提示来生成图像，然后获得去噪的图像。默认情况下，图像输出被放在一个[`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class)对象中。\n\n```python\n>>> image = pipeline(\"An image of a squirrel in Picasso style\").images[0]\n>>> image\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png\"/>\n</div>\n\n\n调用`save`保存图像:\n\n```python\n>>> image.save(\"image_of_squirrel_painting.png\")\n```\n\n### 本地管道\n\n你也可以在本地使用管道。唯一的区别是你需提前下载权重：\n\n```\ngit lfs install\ngit clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5\n```\n\n将下载好的权重加载到管道中:\n\n```python\n>>> pipeline = DiffusionPipeline.from_pretrained(\"./stable-diffusion-v1-5\")\n```\n\n现在你可以像上一节中那样运行管道了。\n\n### 更换调度器\n\n不同的调度器对去噪速度和质量的权衡是不同的。要想知道哪种调度器最适合你，最好的办法就是试用一下。🧨 Diffusers的主要特点之一是允许你轻松切换不同的调度器。例如，要用[`EulerDiscreteScheduler`]替换默认的[`PNDMScheduler`]，用[`~diffusers.ConfigMixin.from_config`]方法加载即可：\n\n```py\n>>> from diffusers import EulerDiscreteScheduler\n\n>>> pipeline = 
DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\")\n>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)\n```\n\n\n试着用新的调度器生成一个图像，看看你能否发现不同之处。\n\n在下一节中，你将仔细观察组成[`DiffusionPipeline`]的组件——模型和调度器，并学习如何使用这些组件来生成猫咪的图像。\n\n## 模型\n\n大多数模型取一个噪声样本，在每个时间点预测*噪声残差*（其他模型则直接学习预测前一个样本或速度或[`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)），即噪声较小的图像与输入图像的差异。你可以混搭模型创建其他扩散系统。\n\n模型是用[`~ModelMixin.from_pretrained`]方法实例化的，该方法还在本地缓存了模型权重，所以下次加载模型时更快。在本快速入门中，你将加载[`UNet2DModel`]，这是一个基础的无条件图像生成模型，该模型有一个在猫咪图像上训练的检查点：\n\n\n```py\n>>> from diffusers import UNet2DModel\n\n>>> repo_id = \"google/ddpm-cat-256\"\n>>> model = UNet2DModel.from_pretrained(repo_id)\n```\n\n想知道模型的参数，调用 `model.config`:\n\n```py\n>>> model.config\n```\n\n模型配置是一个🧊冻结的🧊字典，意思是这些参数在模型创建后就不变了。这是特意设置的，确保在开始时用于定义模型架构的参数保持不变，其他参数仍然可以在推理过程中进行调整。\n\n一些最重要的参数：\n\n* `sample_size`：输入样本的高度和宽度尺寸。\n* `in_channels`：输入样本的输入通道数。\n* `down_block_types`和`up_block_types`：用于创建U-Net架构的下采样和上采样块的类型。\n* `block_out_channels`：下采样块的输出通道数；也以相反的顺序用于上采样块的输入通道数。\n* `layers_per_block`：每个U-Net块中存在的ResNet块的数量。\n\n为了使用该模型进行推理，用随机高斯噪声生成图像形状。它应该有一个`batch`轴，因为模型可以接收多个随机噪声，一个`channel`轴，对应于输入通道的数量，以及一个`sample_size`轴，对应图像的高度和宽度。\n\n\n```py\n>>> import torch\n\n>>> torch.manual_seed(0)\n\n>>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n>>> noisy_sample.shape\ntorch.Size([1, 3, 256, 256])\n```\n\n对于推理，将噪声图像和一个`timestep`传递给模型。`timestep` 表示输入图像的噪声程度，开始时噪声更多，结束时噪声更少。这有助于模型确定其在扩散过程中的位置，是更接近开始还是结束。使用 `sample` 获得模型输出：\n\n\n```py\n>>> with torch.no_grad():\n...     
noisy_residual = model(sample=noisy_sample, timestep=2).sample\n```\n\n想生成实际的样本，你需要一个调度器指导去噪过程。在下一节中，你将学习如何把模型与调度器结合起来。\n\n## 调度器\n\n给定模型输出（此处为`noisy_residual`），调度器负责把一个噪声样本处理成一个噪声较小的样本。\n\n> [!TIP]\n> 🧨 Diffusers是一个用于构建扩散系统的工具箱。预定义好的扩散系统[`DiffusionPipeline`]能方便你快速试用，你也可以单独选择自己的模型和调度器组件来建立一个自定义的扩散系统。\n\n在快速入门教程中，你将用[`~diffusers.SchedulerMixin.from_pretrained`]方法实例化[`DDPMScheduler`]：\n\n```py\n>>> from diffusers import DDPMScheduler\n\n>>> scheduler = DDPMScheduler.from_pretrained(repo_id)\n>>> scheduler\nDDPMScheduler {\n  \"_class_name\": \"DDPMScheduler\",\n  \"_diffusers_version\": \"0.13.1\",\n  \"beta_end\": 0.02,\n  \"beta_schedule\": \"linear\",\n  \"beta_start\": 0.0001,\n  \"clip_sample\": true,\n  \"clip_sample_range\": 1.0,\n  \"num_train_timesteps\": 1000,\n  \"prediction_type\": \"epsilon\",\n  \"trained_betas\": null,\n  \"variance_type\": \"fixed_small\"\n}\n```\n\n> [!TIP]\n> 💡 注意调度器是如何从配置中实例化的。与模型不同，调度器没有可训练的权重，而且是无参数的。\n\n* `num_train_timesteps`：去噪过程的长度，或者换句话说，将随机高斯噪声处理成数据样本所需的时间步数。\n* `beta_schedule`：用于推理和训练的噪声表。\n* `beta_start`和`beta_end`：噪声表的开始和结束噪声值。\n\n要预测一个噪声稍小的图像，请将模型输出、`timestep`和当前`sample`传递给调度器的[`~diffusers.DDPMScheduler.step`]方法：\n\n\n```py\n>>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample\n>>> less_noisy_sample.shape\n```\n\n这个去噪后的`less_noisy_sample`可以被传递到下一个`timestep`，处理后噪声会变得更小。现在让我们把所有步骤合起来，可视化整个去噪过程。\n\n首先，创建一个函数，对去噪后的图像进行后处理并显示为`PIL.Image`：\n\n```py\n>>> import PIL.Image\n>>> import numpy as np\n\n\n>>> def display_sample(sample, i):\n...     image_processed = sample.cpu().permute(0, 2, 3, 1)\n...     image_processed = (image_processed + 1.0) * 127.5\n...     image_processed = image_processed.numpy().astype(np.uint8)\n\n...     image_pil = PIL.Image.fromarray(image_processed[0])\n...     display(f\"Image at step {i}\")\n...     
display(image_pil)\n```\n\n将输入和模型移到GPU上加速去噪过程：\n\n```py\n>>> model.to(\"cuda\")\n>>> noisy_sample = noisy_sample.to(\"cuda\")\n```\n\n现在创建一个去噪循环，该循环预测噪声较少样本的残差，并使用调度器计算噪声较少的样本：\n\n```py\n>>> import tqdm\n\n>>> sample = noisy_sample\n\n>>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):\n...     # 1. predict noise residual\n...     with torch.no_grad():\n...         residual = model(sample, t).sample\n\n...     # 2. compute less noisy image and set x_t -> x_t-1\n...     sample = scheduler.step(residual, t, sample).prev_sample\n\n...     # 3. optionally look at image\n...     if (i + 1) % 50 == 0:\n...         display_sample(sample, i + 1)\n```\n\n看！这样就从噪声中生成出一只猫了！😻\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/diffusion-quicktour.png\"/>\n</div>\n\n## 下一步\n\n希望你在这次快速入门教程中用🧨 Diffusers 生成了一些很酷的图像! 下一步你可以:\n\n* 在[训练](./tutorials/basic_training)教程中训练或微调一个模型来生成你自己的图像。\n* 查看官方和社区的[训练或微调脚本](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples)的例子，了解更多使用情况。\n* 在[使用不同的调度器](./using-diffusers/schedulers)指南中了解更多关于加载、访问、更改和比较调度器的信息。\n* 在[Stable Diffusion](./stable_diffusion)教程中探索提示工程、速度和内存优化，以及生成更高质量图像的技巧。\n* 通过[在GPU上优化PyTorch](./optimization/fp16)指南，以及运行[Apple (M1/M2)上的Stable Diffusion](./optimization/mps)和[ONNX Runtime](./optimization/onnx)的教程，更深入地了解如何加速🧨 Diffusers。"
  },
  {
    "path": "docs/source/zh/stable_diffusion.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 有效且高效的扩散\n\n[[open-in-colab]]\n\n让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。通常情况下，你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程，特别是如果你要一遍又一遍地进行推理运算。\n\n这就是为什么从 pipeline 中获得最高的*计算*（速度）和*内存*（GPU RAM）效率非常重要，这样可以减少推理周期之间的时间，从而使迭代速度更快。\n\n\n本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。\n\n\n首先，加载 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 模型:\n\n```python\nfrom diffusers import DiffusionPipeline\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\npipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)\n```\n\n本教程将使用的提示词是 `portrait photo of a old warrior chief`，但是你可以随心所欲地想象和构造自己的提示词：\n\n```python\nprompt = \"portrait photo of a old warrior chief\"\n```\n\n## 速度\n\n> [!TIP]\n> 💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !\n\n加速推理的最简单方法之一是将 pipeline 放在 GPU 上，就像使用任何 PyTorch 模块一样：\n\n```python\npipeline = pipeline.to(\"cuda\")\n```\n\n为了确保您可以使用相同的图像并对其进行改进，使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 并设置一个随机数种子，以确保其[复现性](./using-diffusers/reusing_seeds):\n\n```python\nimport torch\n\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n```\n\n现在，你可以生成一个图像：\n\n```python\nimage = pipeline(prompt, 
generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_1.png\">\n</div>\n\n在 T4 GPU 上，这个过程大概要30秒（如果你的 GPU 比 T4 好，可能会更快）。在默认情况下，[`DiffusionPipeline`] 使用完整的 `float32` 精度进行 50 步推理。你可以通过降低精度（如 `float16` ）或者减少推理步数来加速整个过程\n\n\n让我们把模型的精度降低至 `float16` ，然后生成一张图像：\n\n```python\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)\npipeline = pipeline.to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_2.png\">\n</div>\n\n这一次，生成图像只花了约 11 秒，比之前快了近 3 倍！\n\n> [!TIP]\n> 💡 我们强烈建议把 pipeline 精度降低至 `float16` , 到目前为止, 我们很少看到输出质量有任何下降。\n\n另一个选择是减少推理步数。 你可以选择一个更高效的调度器 (*scheduler*) 可以减少推理步数同时保证输出质量。您可以在 [DiffusionPipeline] 中通过调用compatibles方法找到与当前模型兼容的调度器 (*scheduler*)。\n\n```python\npipeline.scheduler.compatibles\n[\n    diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,\n    diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,\n    diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,\n    diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,\n    diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,\n    diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,\n    diffusers.schedulers.scheduling_ddpm.DDPMScheduler,\n    diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,\n    diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,\n    diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,\n    diffusers.schedulers.scheduling_pndm.PNDMScheduler,\n    
diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,\n    diffusers.schedulers.scheduling_ddim.DDIMScheduler,\n]\n```\n\nStable Diffusion 模型默认使用的是 [`PNDMScheduler`] ，通常要大概50步推理, 但是像 [`DPMSolverMultistepScheduler`] 这样更高效的调度器只要大概 20 或 25 步推理. 使用 [`ConfigMixin.from_config`] 方法加载新的调度器:\n\n```python\nfrom diffusers import DPMSolverMultistepScheduler\n\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n```\n\n现在将 `num_inference_steps` 设置为 20:\n\n```python\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\nimage = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]\nimage\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_3.png\">\n</div>\n\n太棒了！你成功把推理时间缩短到 4 秒！⚡️\n\n## 内存\n\n改善 pipeline 性能的另一个关键是减少内存的使用量，这间接意味着速度更快，因为你经常试图最大化每秒生成的图像数量。要想知道你一次可以生成多少张图片，最简单的方法是尝试不同的batch size，直到出现`OutOfMemoryError` (OOM)。\n\n创建一个函数，为每一批要生成的图像分配提示词和 `Generators` 。请务必为每个`Generator` 分配一个种子，以便于复现良好的结果。\n\n\n```python\ndef get_inputs(batch_size=1):\n    generator = [torch.Generator(\"cuda\").manual_seed(i) for i in range(batch_size)]\n    prompts = batch_size * [prompt]\n    num_inference_steps = 20\n\n    return {\"prompt\": prompts, \"generator\": generator, \"num_inference_steps\": num_inference_steps}\n```\n\n设置 `batch_size=4` ，然后看一看我们消耗了多少内存:\n\n```python\nfrom diffusers.utils import make_image_grid\n\nimages = pipeline(**get_inputs(batch_size=4)).images\nmake_image_grid(images, 2, 2)\n```\n\n除非你有一个更大内存的GPU, 否则上述代码会返回 `OOM` 错误! 
大部分内存被 cross-attention 层使用。让这些操作按顺序执行，而不是一次性批量执行，可以节省大量内存。你可以为 pipeline 启用 [`~DiffusionPipeline.enable_attention_slicing`] 函数:\n\n```python\npipeline.enable_attention_slicing()\n```\n\n现在尝试把 `batch_size` 增加到 8!\n\n```python\nimages = pipeline(**get_inputs(batch_size=8)).images\nmake_image_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_5.png\">\n</div>\n\n以前你甚至不能一批生成 4 张图片，而现在你可以一批生成 8 张图片，并且每张图片只需要大概 3.5 秒！这可能是 T4 GPU 在不牺牲质量的情况下所能达到的最快速度。\n\n## 质量\n\n在前两节中, 你学习了如何通过 `fp16` 来优化 pipeline 的速度, 通过使用性能更高的调度器来减少推理步数, 以及通过启用注意力切片（*attention slicing*）来节省内存。现在，你将关注的是如何提高图像的质量。\n\n### 更好的 checkpoints\n\n有个显而易见的方法是使用更好的 checkpoints。 Stable Diffusion 模型是一个很好的起点, 自正式发布以来，还发布了几个改进版本。然而, 使用更新的版本并不意味着你会得到更好的结果。你仍然需要尝试不同的 checkpoints ，并做一些研究 (例如使用 [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) 来获得更好的结果。\n\n随着该领域的发展, 有越来越多经过微调的高质量的 checkpoints 用来生成不一样的风格. 
在 [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 和 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) 寻找你感兴趣的一种!\n\n### 更好的 pipeline 组件\n\n也可以尝试用新版本替换当前 pipeline 组件。让我们把 Stability AI 最新的 [autoencoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) 加载到 pipeline 中, 并生成一些图像:\n\n```python\nfrom diffusers import AutoencoderKL\n\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.vae = vae\nimages = pipeline(**get_inputs(batch_size=8)).images\nmake_image_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_6.png\">\n</div>\n\n### 更好的提示词工程\n\n用于生成图像的文本非常重要, 因此被称为 *提示词工程*。 在设计提示词时应考虑如下问题:\n\n- 我想生成的图像或类似图像在互联网上是如何存储的？\n- 我可以提供哪些额外的细节来引导模型朝着我想要的风格生成？\n\n考虑到这一点，让我们改进提示词，以包含颜色和更高质量的细节：\n\n```python\nprompt += \", tribal panther make up, blue on red, side profile, looking away, serious eyes\"\nprompt += \" 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\"\n```\n\n使用新的提示词生成一批图像:\n\n```python\nimages = pipeline(**get_inputs(batch_size=8)).images\nmake_image_grid(images, rows=2, cols=4)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_7.png\">\n</div>\n\n非常令人印象深刻! 
让我们进一步调整第二张图片（对应种子为 `1` 的 `Generator`），为主体添加一些关于年龄的文本:\n\n```python\nprompts = [\n    \"portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n    \"portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\",\n]\n\ngenerator = [torch.Generator(\"cuda\").manual_seed(1) for _ in range(len(prompts))]\nimages = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images\nmake_image_grid(images, 2, 2)\n```\n\n<div class=\"flex justify-center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_8.png\">\n</div>\n\n## 最后\n\n在本教程中, 您学习了如何优化 [`DiffusionPipeline`] 以提高计算和内存效率，以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:\n\n- 了解 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) 如何让推理速度提高 5 - 300%. 在 A100 GPU 上, 推理速度最高可以提高 50%!\n- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制（*memory-efficient attention mechanism*）与 PyTorch 1.13.1 配合使用，速度更快，内存消耗更少。\n- 其他的优化技术, 如：模型卸载（*model offloading*）, 包含在 [这份指南](./optimization/fp16) 中.\n"
  },
  {
    "path": "docs/source/zh/training/adapt_a_model.md",
    "content": "# 将模型适配至新任务\n\n许多扩散系统共享相同的组件架构，这使得您能够将针对某一任务预训练的模型调整适配至完全不同的新任务。\n\n本指南将展示如何通过初始化并修改预训练 [`UNet2DConditionModel`] 的架构，将文生图预训练模型改造为图像修复(inpainting)模型。\n\n## 配置 UNet2DConditionModel 参数\n\n默认情况下，[`UNet2DConditionModel`] 的[输入样本](https://huggingface.co/docs/diffusers/v0.16.0/en/api/models#diffusers.UNet2DConditionModel.in_channels)接受4个通道。例如加载 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 这样的文生图预训练模型，查看其 `in_channels` 参数值：\n\n```python\nfrom diffusers import StableDiffusionPipeline\n\npipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\npipeline.unet.config[\"in_channels\"]\n4\n```\n\n而图像修复任务需要输入样本具有9个通道。您可以在 [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting) 这样的预训练修复模型中验证此参数：\n\n```python\nfrom diffusers import StableDiffusionPipeline\n\npipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-inpainting\", use_safetensors=True)\npipeline.unet.config[\"in_channels\"]\n9\n```\n\n要将文生图模型改造为修复模型，您需要将 `in_channels` 参数从4调整为9。\n\n初始化一个加载了文生图预训练权重的 [`UNet2DConditionModel`]，并将 `in_channels` 设为9。由于输入通道数变化导致张量形状改变，需要设置 `ignore_mismatched_sizes=True` 和 `low_cpu_mem_usage=False` 来避免尺寸不匹配错误。\n\n```python\nfrom diffusers import AutoModel\n\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nunet = AutoModel.from_pretrained(\n    model_id,\n    subfolder=\"unet\",\n    in_channels=9,\n    low_cpu_mem_usage=False,\n    ignore_mismatched_sizes=True,\n    use_safetensors=True,\n)\n```\n\n此时文生图模型的其他组件权重仍保持预训练状态，但UNet的输入卷积层权重(`conv_in.weight`)会随机初始化。由于这一关键变化，必须对模型进行修复任务的微调，否则模型将仅会输出噪声。\n"
  },
  {
    "path": "docs/source/zh/training/controlnet.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# ControlNet\n\n[ControlNet](https://hf.co/papers/2302.05543) 是一种基于预训练模型的适配器架构。它通过额外输入的条件图像（如边缘检测图、深度图、人体姿态图等），实现对生成图像的精细化控制。\n\n在显存有限的GPU上训练时，建议启用训练命令中的 `gradient_checkpointing`（梯度检查点）、`gradient_accumulation_steps`（梯度累积步数）和 `mixed_precision`（混合精度）参数。还可使用 [xFormers](../optimization/xformers) 的内存高效注意力机制进一步降低显存占用。虽然JAX/Flax训练支持在TPU和GPU上高效运行，但不支持梯度检查点和xFormers。若需通过Flax加速训练，建议使用显存大于30GB的GPU。\n\n本指南将解析 [train_controlnet.py](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) 训练脚本，帮助您理解其逻辑并适配自定义需求。\n\n运行脚本前，请确保从源码安装库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n然后进入包含训练脚本的示例目录，安装所需依赖：\n\n<hfoptions id=\"installation\">\n<hfoption id=\"PyTorch\">\n```bash\ncd examples/controlnet\npip install -r requirements.txt\n```\n</hfoption>\n<hfoption id=\"Flax\">\n\n若可访问TPU设备，Flax训练脚本将运行得更快！以下是在 [Google Cloud TPU VM](https://cloud.google.com/tpu/docs/run-calculation-jax) 上的配置流程。创建单个TPU v4-8虚拟机并连接：\n\n```bash\nZONE=us-central2-b\nTPU_TYPE=v4-8\nVM_NAME=hg_flax\n\ngcloud alpha compute tpus tpu-vm create $VM_NAME \\\n --zone $ZONE \\\n --accelerator-type $TPU_TYPE \\\n --version  tpu-vm-v4-base\n\ngcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \\\n```\n\n安装JAX 0.4.5：\n\n```bash\npip install \"jax[tpu]==0.4.5\" -f 
https://storage.googleapis.com/jax-releases/libtpu_releases.html\n```\n\n然后安装Flax脚本的依赖：\n\n```bash\ncd examples/controlnet\npip install -r requirements_flax.txt\n```\n\n</hfoption>\n</hfoptions>\n\n> [!TIP]\n> 🤗 Accelerate 是一个支持多GPU/TPU训练和混合精度的库，它能根据硬件环境自动配置训练方案。参阅 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 了解更多。\n\n初始化🤗 Accelerate环境：\n\n```bash\naccelerate config\n```\n\n若要创建默认配置（不进行交互式选择）：\n\n```bash\naccelerate config default\n```\n\n若环境不支持交互式shell（如notebook），可使用：\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n最后，如需训练自定义数据集，请参阅 [创建训练数据集](create_dataset) 指南了解数据准备方法。\n\n> [!TIP]\n> 下文重点解析脚本中的关键模块，但不会覆盖所有实现细节。如需深入了解，建议直接阅读 [脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py)，如有疑问欢迎反馈。\n\n## 脚本参数\n\n训练脚本提供了丰富的可配置参数，所有参数及其说明详见 [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L231) 函数。虽然该函数已为每个参数提供默认值（如训练批大小、学习率等），但您可以通过命令行参数覆盖这些默认值。\n\n例如，使用fp16混合精度加速训练, 可使用`--mixed_precision`参数\n\n```bash\naccelerate launch train_controlnet.py \\\n  --mixed_precision=\"fp16\"\n```\n\n基础参数说明可参考 [文生图](text2image#script-parameters) 训练指南，此处重点介绍ControlNet相关参数：\n\n- `--max_train_samples`: 训练样本数量，减少该值可加快训练，但对超大数据集需配合 `--streaming` 参数使用\n- `--gradient_accumulation_steps`: 梯度累积步数，通过分步计算实现显存受限情况下的更大批次训练\n\n### Min-SNR加权策略\n\n[Min-SNR](https://huggingface.co/papers/2303.09556) 加权策略通过重新平衡损失函数加速模型收敛。虽然训练脚本支持预测 `epsilon`（噪声）或 `v_prediction`，但Min-SNR对两种预测类型均兼容。该策略仅适用于PyTorch版本，Flax训练脚本暂不支持。\n\n推荐值设为5.0：\n\n```bash\naccelerate launch train_controlnet.py \\\n  --snr_gamma=5.0\n```\n\n## 训练脚本\n\n与参数说明类似，训练流程的通用解析可参考 [文生图](text2image#training-script) 指南。此处重点分析ControlNet特有的实现。\n\n脚本中的 [`make_train_dataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L582) 函数负责数据预处理，除常规的文本标注分词和图像变换外，还包含条件图像的特效处理：\n\n> [!TIP]\n> 
在TPU上流式加载数据集时，🤗 Datasets库可能成为性能瓶颈（因其未针对图像数据优化）。建议考虑 [WebDataset](https://webdataset.github.io/webdataset/)、[TorchData](https://github.com/pytorch/data) 或 [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) 等高效数据格式。\n\n```py\nconditioning_image_transforms = transforms.Compose(\n    [\n        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n        transforms.CenterCrop(args.resolution),\n        transforms.ToTensor(),\n    ]\n)\n```\n\n在 [`main()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L713) 函数中，代码会加载分词器、文本编码器、调度器和模型。此处也是ControlNet模型的加载点（支持从现有权重加载或从UNet随机初始化）：\n\n```py\nif args.controlnet_model_name_or_path:\n    logger.info(\"Loading existing controlnet weights\")\n    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)\nelse:\n    logger.info(\"Initializing controlnet weights from unet\")\n    controlnet = ControlNetModel.from_unet(unet)\n```\n\n[优化器](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L871) 专门针对ControlNet参数进行更新：\n\n```py\nparams_to_optimize = controlnet.parameters()\noptimizer = optimizer_class(\n    params_to_optimize,\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\n在 [训练循环](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L943) 中，条件文本嵌入和图像被输入到ControlNet的下采样和中层模块：\n\n```py\nencoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\ncontrolnet_image = batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\n\ndown_block_res_samples, mid_block_res_sample = controlnet(\n    noisy_latents,\n    timesteps,\n    encoder_hidden_states=encoder_hidden_states,\n    controlnet_cond=controlnet_image,\n    
return_dict=False,\n)\n```\n\n若想深入理解训练循环机制，可参阅 [理解管道、模型与调度器](../using-diffusers/write_own_pipeline) 教程，该教程详细解析了去噪过程的基本原理。\n\n## 启动训练\n\n现在可以启动训练脚本了！🚀\n\n本指南使用 [fusing/fill50k](https://huggingface.co/datasets/fusing/fill50k) 数据集，当然您也可以按照 [创建训练数据集](create_dataset) 指南准备自定义数据。\n\n设置环境变量 `MODEL_DIR` 为Hub模型ID或本地路径，`OUTPUT_DIR` 为模型保存路径。\n\n下载训练用的条件图像：\n\n```bash\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\n根据GPU型号，可能需要启用特定优化。默认配置需要约38GB显存。若使用多GPU训练，请在 `accelerate launch` 命令中添加 `--multi_gpu` 参数。\n\n<hfoptions id=\"gpu-select\">\n<hfoption id=\"16GB\">\n\n16GB显卡可使用bitsandbytes 8-bit优化器和梯度检查点：\n\n```bash\npip install bitsandbytes\n```\n\n训练命令添加以下参数：\n\n```bash\naccelerate launch train_controlnet.py \\\n  --gradient_checkpointing \\\n  --use_8bit_adam\n```\n\n</hfoption>\n<hfoption id=\"12GB\">\n\n12GB显卡需组合使用bitsandbytes 8-bit优化器、梯度检查点、xFormers，并将梯度置为None而非0：\n\n```bash\naccelerate launch train_controlnet.py \\\n  --use_8bit_adam \\\n  --gradient_checkpointing \\\n  --enable_xformers_memory_efficient_attention \\\n  --set_grads_to_none\n```\n\n</hfoption>\n<hfoption id=\"8GB\">\n\n8GB显卡需使用 [DeepSpeed](https://www.deepspeed.ai/) 将张量卸载到CPU或NVME：\n\n运行以下命令配置环境：\n\n```bash\naccelerate config\n```\n\n选择DeepSpeed stage 2，结合fp16混合精度和参数卸载到CPU的方案。注意这会增加约25GB内存占用。配置示例如下：\n\n```bash\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  gradient_accumulation_steps: 4\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\n```\n\n建议将优化器替换为DeepSpeed特化版 
[`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu)，注意CUDA工具链版本需与PyTorch匹配。\n\n当前bitsandbytes与DeepSpeed存在兼容性问题。\n\n无需额外添加训练参数。\n\n</hfoption>\n</hfoptions>\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"PyTorch\">\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path/to/save/model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --push_to_hub\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\nFlax版本支持通过 `--profile_steps=5` 参数进行性能分析：\n\n```bash\npip install tensorflow tensorboard-plugin-profile\ntensorboard --logdir runs/fill-circle-100steps-20230411_165612/\n```\n\n在 [http://localhost:6006/#profile](http://localhost:6006/#profile) 查看分析结果。\n\n> [!WARNING]\n> 若遇到插件版本冲突，建议重新安装TensorFlow和Tensorboard。注意性能分析插件仍处实验阶段，部分视图可能不完整。`trace_viewer` 会截断超过1M的事件记录，在编译步骤分析时可能导致设备轨迹丢失。\n\n```bash\npython3 train_controlnet_flax.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --validation_steps=1000 \\\n --train_batch_size=2 \\\n --revision=\"non-ema\" \\\n --from_pt \\\n --report_to=\"wandb\" \\\n --tracker_project_name=$HUB_MODEL_ID \\\n --num_train_epochs=11 \\\n --push_to_hub \\\n 
--hub_model_id=$HUB_MODEL_ID\n```\n\n</hfoption>\n</hfoptions>\n\n训练完成后即可进行推理：\n\n```py\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel\nfrom diffusers.utils import load_image\nimport torch\n\ncontrolnet = ControlNetModel.from_pretrained(\"path/to/controlnet\", torch_dtype=torch.float16)\npipeline = StableDiffusionControlNetPipeline.from_pretrained(\n    \"path/to/base/model\", controlnet=controlnet, torch_dtype=torch.float16\n).to(\"cuda\")\n\ncontrol_image = load_image(\"./conditioning_image_1.png\")\nprompt = \"pale golden rod circle with old lace background\"\n\ngenerator = torch.manual_seed(0)\nimage = pipeline(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]\nimage.save(\"./output.png\")\n```\n\n## Stable Diffusion XL\n\nStable Diffusion XL (SDXL) 是新一代文生图模型，通过添加第二文本编码器支持生成更高分辨率图像。使用 [`train_controlnet_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet_sdxl.py) 脚本可为SDXL训练ControlNet适配器。\n\nSDXL训练脚本的详细解析请参阅 [SDXL训练](sdxl) 指南。\n\n## 后续步骤\n\n恭喜完成ControlNet训练！如需进一步了解模型应用，以下指南可能有所帮助：\n\n- 学习如何 [使用ControlNet](../using-diffusers/controlnet) 进行多样化任务的推理\n"
  },
  {
    "path": "docs/source/zh/training/distributed_inference.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（“许可证”）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按“原样”分发，不附带任何明示或暗示的担保或条件。请参阅许可证以了解其中关于权限和限制的具体条款。\n-->\n\n# 分布式推理\n\n在分布式设置中，您可以使用 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) 或 [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) 在多个 GPU 上运行推理，这对于并行生成多个提示非常有用。\n\n本指南将向您展示如何使用 🤗 Accelerate 和 PyTorch Distributed 进行分布式推理。\n\n## 🤗 Accelerate\n\n🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) 是一个旨在简化在分布式设置中训练或运行推理的库。它简化了设置分布式环境的过程，让您可以专注于您的 PyTorch 代码。\n\n首先，创建一个 Python 文件并初始化一个 [`accelerate.PartialState`] 来创建分布式环境；您的设置会被自动检测，因此您无需明确定义 `rank` 或 `world_size`。将 [`DiffusionPipeline`] 移动到 `distributed_state.device` 以为每个进程分配一个 GPU。\n\n现在使用 [`~accelerate.PartialState.split_between_processes`] 实用程序作为上下文管理器，自动在各进程之间分发提示。\n\n```py\nimport torch\nfrom accelerate import PartialState\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True\n)\ndistributed_state = PartialState()\npipeline.to(distributed_state.device)\n\nwith distributed_state.split_between_processes([\"a dog\", \"a cat\"]) as prompt:\n    result = pipeline(prompt).images[0]\n    result.save(f\"result_{distributed_state.process_index}.png\")\n```\n\n使用 `--num_processes` 参数指定要使用的 GPU 数量，并调用 `accelerate launch` 来运行脚本（启动器参数需要放在脚本名之前）：\n\n```bash\naccelerate launch --num_processes=2 run_distributed.py\n```\n\n> [!TIP]\n> 参考这个最小示例 [脚本](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) 以在多个 GPU 上运行推理。要了解更多信息，请查看 [使用 🤗 Accelerate 进行分布式推理](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 指南。\n\n## PyTorch Distributed\n\nPyTorch 支持 
[`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)，它可以实现数据并行。\n\n首先，创建一个 Python 文件并导入 `torch.distributed` 和 `torch.multiprocessing` 来设置分布式进程组，并为每个 GPU 上的推理生成进程。您还应该初始化一个 [`DiffusionPipeline`]：\n\n```py\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom diffusers import DiffusionPipeline\n\nsd = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True\n)\n```\n\n您需要创建一个函数来运行推理；[`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) 负责创建分布式环境，并指定要使用的后端类型、当前进程的 `rank` 以及参与进程的数量 `world_size`。如果您在 2 个 GPU 上并行运行推理，那么 `world_size` 就是 2。\n\n将 [`DiffusionPipeline`] 移动到 `rank`，并使用 `get_rank` 为每个进程分配一个 GPU，其中每个进程处理不同的提示：\n\n```py\ndef run_inference(rank, world_size):\n    dist.init_process_group(\"nccl\", rank=rank, world_size=world_size)\n\n    # 将 pipeline 移动到当前进程对应的 GPU\n    sd.to(rank)\n\n    # 每个进程根据自己的 rank 处理不同的提示\n    if torch.distributed.get_rank() == 0:\n        prompt = \"a dog\"\n    elif torch.distributed.get_rank() == 1:\n        prompt = \"a cat\"\n\n    image = sd(prompt).images[0]\n    # 用下划线替换提示中的空格作为文件名，例如 a_dog.png\n    image.save(f\"./{prompt.replace(' ', '_')}.png\")\n```\n\n要运行分布式推理，调用 [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) 在 `world_size` 定义的 GPU 数量上运行 `run_inference` 函数：\n\n```py\ndef main():\n    world_size = 2\n    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)\n\n\nif __name__ == \"__main__\":\n    main()\n```\n\n完成推理脚本后，使用 `--nproc_per_node` 参数指定要使用的 GPU 数量，并调用 `torchrun` 来运行脚本（启动器参数需要放在脚本名之前）：\n\n```bash\ntorchrun --nproc_per_node=2 run_distributed.py\n```\n\n> [!TIP]\n> 您可以在 [`DiffusionPipeline`] 中使用 `device_map` 将其模型级组件分布在多个设备上。请参考 [设备放置](../tutorials/inference_with_big_models#device-placement) 指南了解更多信息。\n\n## 模型分片\n\n现代扩散系统，如 
[Flux](../api/pipelines/flux)，非常大且包含多个模型。例如，[Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) 由两个文本编码器 - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) 和 [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - 一个 [扩散变换器](../api/models/flux_transformer)，以及一个 [VAE](../api/models/autoencoderkl) 组成。对于如此大的模型，在消费级 GPU 上运行推理可能具有挑战性。\n\n模型分片是一种技术，当模型无法容纳在单个 GPU 上时，将模型分布在多个 GPU 上。下面的示例假设有两个 16GB GPU 可用于推理。\n\n开始使用文本编码器计算文本嵌入。通过设置 `device_map=\"balanced\"` 将文本编码器保持在两个GPU上。`balanced` 策略将模型均匀分布在所有可用GPU上。使用 `max_memory` 参数为每个GPU上的每个文本编码器分配最大内存量。\n\n> [!TIP]\n> **仅** 在此步骤加载文本编码器！扩散变换器和VAE在后续步骤中加载以节省内存。\n\n```py\nfrom diffusers import FluxPipeline\nimport torch\n\nprompt = \"a photo of a dog with cat-like look\"\n\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    transformer=None,\n    vae=None,\n    device_map=\"balanced\",\n    max_memory={0: \"16GB\", 1: \"16GB\"},\n    torch_dtype=torch.bfloat16\n)\nwith torch.no_grad():\n    print(\"Encoding prompts.\")\n    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(\n        prompt=prompt, prompt_2=None, max_sequence_length=512\n    )\n```\n\n一旦文本嵌入计算完成，从GPU中移除它们以为扩散变换器腾出空间。\n\n```py\nimport gc \n\ndef flush():\n    gc.collect()\n    torch.cuda.empty_cache()\n    torch.cuda.reset_max_memory_allocated()\n    torch.cuda.reset_peak_memory_stats()\n\ndel pipeline.text_encoder\ndel pipeline.text_encoder_2\ndel pipeline.tokenizer\ndel pipeline.tokenizer_2\ndel pipeline\n\nflush()\n```\n\n接下来加载扩散变换器，它有125亿参数。这次，设置 `device_map=\"auto\"` 以自动将模型分布在两个16GB GPU上。`auto` 策略由 [Accelerate](https://hf.co/docs/accelerate/index) 支持，并作为 [大模型推理](https://hf.co/docs/accelerate/concept_guides/big_model_inference) 功能的一部分可用。它首先将模型分布在最快的设备（GPU）上，然后在需要时移动到较慢的设备如CPU和硬盘。将模型参数存储在较慢设备上的权衡是推理延迟较慢。\n\n```py\nfrom diffusers import AutoModel\nimport torch \n\ntransformer = AutoModel.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\", \n    subfolder=\"transformer\",\n    device_map=\"auto\",\n    
torch_dtype=torch.bfloat16\n)\n```\n\n> [!TIP]\n> 在任何时候，您可以尝试 `print(pipeline.hf_device_map)` 来查看各种模型如何在设备上分布。这对于跟踪模型的设备放置很有用。您也可以尝试 `print(transformer.hf_device_map)` 来查看变换器模型如何在设备上分片。\n\n将变换器模型添加到管道中以进行去噪，但将其他模型级组件如文本编码器和VAE设置为 `None`，因为您还不需要它们。\n\n```py\npipeline = FluxPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    text_encoder=None,\n    text_encoder_2=None,\n    tokenizer=None,\n    tokenizer_2=None,\n    vae=None,\n    transformer=transformer,\n    torch_dtype=torch.bfloat16\n)\n\nprint(\"Running denoising.\")\nheight, width = 768, 1360\nlatents = pipeline(\n    prompt_embeds=prompt_embeds,\n    pooled_prompt_embeds=pooled_prompt_embeds,\n    num_inference_steps=50,\n    guidance_scale=3.5,\n    height=height,\n    width=width,\n    output_type=\"latent\",\n).images\n```\n\n从内存中移除管道和变换器，因为它们不再需要。\n\n```py\ndel pipeline.transformer\ndel pipeline\n\nflush()\n```\n\n最后，使用变分自编码器（VAE）将潜在表示解码为图像。VAE通常足够小，可以在单个GPU上加载。\n\n```py\nfrom diffusers import AutoencoderKL\nfrom diffusers.image_processor import VaeImageProcessor\nimport torch\n\n# 从与前面步骤相同的 checkpoint 加载 VAE\nvae = AutoencoderKL.from_pretrained(\"black-forest-labs/FLUX.1-dev\", subfolder=\"vae\", torch_dtype=torch.bfloat16).to(\"cuda\")\nvae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)\nimage_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)\n\nwith torch.no_grad():\n    print(\"运行解码中。\")\n    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)\n    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor\n\n    image = vae.decode(latents, return_dict=False)[0]\n    image = image_processor.postprocess(image, output_type=\"pil\")\n    image[0].save(\"split_transformer.png\")\n```\n\n通过选择性加载和卸载在特定阶段所需的模型，并将最大模型分片到多个GPU上，可以在消费级GPU上运行大型模型的推理。"
  },
  {
    "path": "docs/source/zh/training/dreambooth.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版（“许可证”）授权；除非遵守许可证，否则不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，否则根据许可证分发的软件按“原样”分发，不附带任何明示或暗示的担保或条件。请参阅许可证以了解其中关于权限和限制的具体条款。\n-->\n\n# DreamBooth\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) 是一种训练技术，只需使用某个主题或风格的少量图像进行训练，即可更新整个扩散模型。它通过在提示中将一个特殊词与示例图像关联起来来工作。\n\n如果您在显存有限的 GPU 上训练，应尝试在训练命令中启用 `gradient_checkpointing` 和 `mixed_precision` 参数。您还可以通过使用 [xFormers](../optimization/xformers) 的内存高效注意力来减少内存占用。JAX/Flax 训练也支持在 TPU 和 GPU 上进行高效训练，但不支持梯度检查点或 xFormers。如果您想使用 Flax 更快地训练，应拥有内存 >30GB 的 GPU。\n\n本指南将探索 [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) 脚本，帮助您更熟悉它，以及如何根据您的用例进行适配。\n\n在运行脚本之前，请确保从源代码安装库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n导航到包含训练脚本的示例文件夹，并安装脚本所需的依赖项：\n\n<hfoptions id=\"installation\">\n<hfoption id=\"PyTorch\">\n\n```bash\ncd examples/dreambooth\npip install -r requirements.txt\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\n```bash\ncd examples/dreambooth\npip install -r requirements_flax.txt\n```\n\n</hfoption>\n</hfoptions>\n\n> [!TIP]\n> 🤗 Accelerate 是一个库，用于帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。\n\n初始化 🤗 Accelerate 环境：\n\n```bash\naccelerate config\n```\n\n要设置默认的 🤗 Accelerate 环境而不选择任何配置：\n\n```bash\naccelerate config default\n```\n\n或者，如果您的环境不支持交互式 shell，例如 notebook，您可以使用：\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n最后，如果您想在自己的数据集上训练模型，请查看 [创建用于训练的数据集](create_dataset) 指南，了解如何创建与训练脚本兼容的数据集。\n\n> [!TIP]\n> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分，但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多，请随时阅读[脚本](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)；如果您有任何问题或疑虑，请告诉我们。\n\n## 脚本参数\n\n> [!WARNING]\n> DreamBooth 对训练超参数非常敏感，容易过拟合。阅读 [使用 🧨 Diffusers 训练 Stable Diffusion 与 
Dreambooth](https://huggingface.co/blog/dreambooth) 博客文章，了解针对不同主题的推荐设置，以帮助您选择合适的超参数。\n\n训练脚本提供了许多参数来自定义您的训练运行。所有参数及其描述都可以在 [`parse_args()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L228) 函数中找到。参数设置了默认值，这些默认值应该开箱即用效果不错，但如果您愿意，也可以在训练命令中设置自己的值。\n\n例如，要以 bf16 格式进行训练：\n\n```bash\naccelerate launch train_dreambooth.py \\\n    --mixed_precision=\"bf16\"\n```\n\n一些基本且重要的参数需要了解和指定：\n\n- `--pretrained_model_name_or_path`: Hub 上的模型名称或预训练模型的本地路径\n- `--instance_data_dir`: 包含训练数据集（示例图像）的文件夹路径\n- `--instance_prompt`: 包含示例图像特殊单词的文本提示\n- `--train_text_encoder`: 是否也训练文本编码器\n- `--output_dir`: 保存训练后模型的位置\n- `--push_to_hub`: 是否将训练后的模型推送到 Hub\n- `--checkpointing_steps`: 模型训练时保存检查点的频率；这在训练因某种原因中断时很有用，您可以通过在训练命令中添加 `--resume_from_checkpoint` 来从该检查点继续训练\n\n### Min-SNR 加权\n\n[Min-SNR](https://huggingface.co/papers/2303.09556) 加权策略可以通过重新平衡损失来帮助训练，以实现更快的收敛。训练脚本支持预测 `epsilon`（噪声）或 `v_prediction`，Min-SNR 与这两种预测类型都兼容。此加权策略仅由 PyTorch 支持，在 Flax 训练脚本中不可用。\n\n添加 `--snr_gamma` 参数并将其设置为推荐值 5.0：\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --snr_gamma=5.0\n```\n\n### 先验保持损失\n\n先验保持损失是一种使用模型自身生成的样本来帮助它学习如何生成更多样化图像的方法。因为这些生成的样本图像与您提供的图像属于同一类别，它们帮助模型保留它已经学到的关于该类别的知识，并让模型利用这些已知的类别信息来创建新的组合。\n\n- `--with_prior_preservation`: 是否使用先验保留损失\n- `--prior_loss_weight`: 控制先验保留损失对模型的影响程度\n- `--class_data_dir`: 包含生成的类别样本图像的文件夹路径\n- `--class_prompt`: 描述生成的样本图像类别的文本提示\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --with_prior_preservation \\\n  --prior_loss_weight=1.0 \\\n  --class_data_dir=\"path/to/class/images\" \\\n  --class_prompt=\"text prompt describing class\"\n```\n\n### 训练文本编码器\n\n为了提高生成输出的质量，除了 UNet 之外，您还可以训练文本编码器。这需要额外的内存，并且您需要一个至少有 24GB 显存的 GPU。如果您拥有必要的硬件，那么训练文本编码器会产生更好的结果，尤其是在生成面部图像时。通过以下方式启用此选项：\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --train_text_encoder\n```\n\n## 训练脚本\n\nDreamBooth 附带了自己的数据集类：\n\n- 
[`DreamBoothDataset`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L604): 预处理图像和类别图像，并对提示进行分词以用于训练\n- [`PromptDataset`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L738): 生成提示嵌入以生成类别图像\n\n如果您启用了[先验保留损失](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L842)，类别图像在此处生成：\n\n```py\nsample_dataset = PromptDataset(args.class_prompt, num_new_images)\nsample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\nsample_dataloader = accelerator.prepare(sample_dataloader)\npipeline.to(accelerator.device)\n\nfor example in tqdm(\n    sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n):\n    images = pipeline(example[\"prompt\"]).images\n```\n\n接下来是 [`main()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L799) 函数，它处理设置训练数据集和训练循环本身。脚本加载 [tokenizer](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L898)、[scheduler 和 models](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L912C1-L912C1)：\n\n```py\n# Load the tokenizer\nif args.tokenizer_name:\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\nelif args.pretrained_model_name_or_path:\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n# 加载调度器和模型\nnoise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\ntext_encoder = 
text_encoder_cls.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n)\n\nif model_has_vae(args):\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision\n    )\nelse:\n    vae = None\n\nunet = UNet2DConditionModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n)\n```\n\n然后，是时候[创建训练数据集](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L1073)和从`DreamBoothDataset`创建DataLoader：\n\n```py\ntrain_dataset = DreamBoothDataset(\n    instance_data_root=args.instance_data_dir,\n    instance_prompt=args.instance_prompt,\n    class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n    class_prompt=args.class_prompt,\n    class_num=args.num_class_images,\n    tokenizer=tokenizer,\n    size=args.resolution,\n    center_crop=args.center_crop,\n    encoder_hidden_states=pre_computed_encoder_hidden_states,\n    class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,\n    tokenizer_max_length=args.tokenizer_max_length,\n)\n\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset,\n    batch_size=args.train_batch_size,\n    shuffle=True,\n    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n    num_workers=args.dataloader_num_workers,\n)\n```\n\n最后，[训练循环](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L1151)处理剩余步骤，例如将图像转换为潜在空间、向输入添加噪声、预测噪声残差和计算损失。\n\n如果您想了解更多关于训练循环的工作原理，请查看[理解管道、模型和调度器](../using-diffusers/write_own_pipeline)教程，该教程分解了去噪过程的基本模式。\n\n## 启动脚本\n\n您现在准备好启动训练脚本了！🚀\n\n对于本指南，您将下载一些[狗的图片](https://huggingface.co/datasets/diffusers/dog-example)的图像并将它们存储在一个目录中。但请记住，您可以根据需要创建和使用自己的数据集（请参阅[创建用于训练的数据集](create_dataset)指南）。\n\n```py\nfrom huggingface_hub 
import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir,\n    repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\n设置环境变量 `MODEL_NAME` 为 Hub 上的模型 ID 或本地模型路径，`INSTANCE_DIR` 为您刚刚下载狗图像的路径，`OUTPUT_DIR` 为您想保存模型的位置。您将使用 `sks` 作为与训练图像绑定的特殊词。\n\n如果您想跟踪训练过程，可以在训练期间定期保存生成的图像。将以下参数添加到训练命令中：\n\n```bash\n--validation_prompt=\"a photo of a sks dog\"\n--num_validation_images=4\n--validation_steps=100\n```\n\n在启动脚本之前，还有一件事！根据您拥有的 GPU，您可能需要启用某些优化来训练 DreamBooth。\n\n<hfoptions id=\"gpu-select\">\n<hfoption id=\"16GB\">\n\n在 16GB GPU 上，您可以使用 bitsandbytes 8 位优化器和梯度检查点来帮助训练 DreamBooth 模型。安装 bitsandbytes：\n\n```bash\npip install bitsandbytes\n```\n\n然后，将以下参数添加到您的训练命令中：\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --gradient_checkpointing \\\n  --use_8bit_adam\n```\n\n</hfoption>\n<hfoption id=\"12GB\">\n\n在 12GB GPU 上，您需要 bitsandbytes 8 位优化器、梯度检查点、xFormers，并将梯度设置为 `None` 而不是零以减少内存使用。\n\n```bash\naccelerate launch train_dreambooth.py \\\n  --use_8bit_adam \\\n  --gradient_checkpointing \\\n  --enable_xformers_memory_efficient_attention \\\n  --set_grads_to_none\n```\n\n</hfoption>\n<hfoption id=\"8GB\">\n\n在 8GB GPU 上，您需要 [DeepSpeed](https://www.deepspeed.ai/) 将一些张量从 vRAM 卸载到 CPU 或 NVMe，以便在更少的 GPU 内存下进行训练。\n\n运行以下命令来配置您的 🤗 Accelerate 环境：\n\n```bash\naccelerate config\n```\n\n在配置过程中，确认您想使用 DeepSpeed。现在，通过结合 DeepSpeed 阶段 2、fp16 混合精度以及将模型参数和优化器状态卸载到 CPU，应该可以在低于 8GB vRAM 的情况下进行训练。缺点是这需要更多的系统 RAM（约 25 GB）。有关更多配置选项，请参阅 [DeepSpeed 文档](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)。\n\n您还应将默认的 Adam 优化器更改为 DeepSpeed 的优化版本 [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu) 以获得显著的速度提升。启用 `DeepSpeedCPUAdam` 要求您的系统 CUDA 工具链版本与 PyTorch 安装的版本相同。\n\n目前，bitsandbytes 8 位优化器似乎与 DeepSpeed 不兼容。\n\n就是这样！您不需要向训练命令添加任何额外参数。\n\n</hfoption>\n</hfoptions>\n\n<hfoptions id=\"training-inference\">\n<hfoption 
id=\"PyTorch\">\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport INSTANCE_DIR=\"./dog\"\nexport OUTPUT_DIR=\"path_to_saved_model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=400 \\\n  --push_to_hub\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport INSTANCE_DIR=\"./dog\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\npython train_dreambooth_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=5e-6 \\\n  --max_train_steps=400 \\\n  --push_to_hub\n```\n\n</hfoption>\n</hfoptions>\n\n训练完成后，您可以使用新训练的模型进行推理！\n\n> [!TIP]\n> 等不及训练完成就想用您的模型进行推理？🤭 请确保安装了最新版本的 🤗 Accelerate。\n>\n> ```py\n> from diffusers import DiffusionPipeline, UNet2DConditionModel\n> from transformers import CLIPTextModel\n> import torch\n>\n> unet = UNet2DConditionModel.from_pretrained(\"path/to/model/checkpoint-100/unet\")\n>\n> # 如果您使用了 `--train_text_encoder` 进行训练，请确保也加载文本编码器\n> text_encoder = CLIPTextModel.from_pretrained(\"path/to/model/checkpoint-100/text_encoder\")\n>\n> pipeline = DiffusionPipeline.from_pretrained(\n>     \"stable-diffusion-v1-5/stable-diffusion-v1-5\", unet=unet, text_encoder=text_encoder, torch_dtype=torch.float16,\n> ).to(\"cuda\")\n>\n> image = pipeline(\"A photo of sks dog in a bucket\", num_inference_steps=50, guidance_scale=7.5).images[0]\n> image.save(\"dog-bucket.png\")\n> ```\n\n<hfoptions 
id=\"training-inference\">\n<hfoption id=\"PyTorch\">\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"path_to_saved_model\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\nimage = pipeline(\"A photo of sks dog in a bucket\", num_inference_steps=50, guidance_scale=7.5).images[0]\nimage.save(\"dog-bucket.png\")\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\n```py\nimport jax\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom flax.training.common_utils import shard\nfrom diffusers import FlaxStableDiffusionPipeline\n\npipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\"path-to-your-trained-model\", dtype=jax.numpy.bfloat16)\n\nprompt = \"A photo of sks dog in a bucket\"\nprng_seed = jax.random.PRNGKey(0)\nnum_inference_steps = 50\n\nnum_samples = jax.device_count()\nprompt = num_samples * [prompt]\nprompt_ids = pipeline.prepare_inputs(prompt)\n\n# 分片输入和随机数生成器\nparams = replicate(params)\nprng_seed = jax.random.split(prng_seed, jax.device_count())\nprompt_ids = shard(prompt_ids)\n\nimages = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images\nimages = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))\n# 保存第一张生成的图像\nimages[0].save(\"dog-bucket.png\")\n```\n\n</hfoption>\n</hfoptions>\n\n## LoRA\n\nLoRA 是一种训练技术，可显著减少可训练参数的数量。因此，训练速度更快，并且更容易存储生成的权重，因为它们小得多（约 100MB）。使用 [train_dreambooth_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py) 脚本通过 LoRA 进行训练。\n\nLoRA 训练脚本在 [LoRA 训练](lora) 指南中有更详细的讨论。\n\n## Stable Diffusion XL\n\nStable Diffusion XL (SDXL) 是一个强大的文本到图像模型，可生成高分辨率图像，并在其架构中添加了第二个文本编码器。使用 [train_dreambooth_lora_sdxl.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py) 脚本通过 LoRA 训练 SDXL 模型。\n\nSDXL 训练脚本在 [SDXL 训练](sdxl) 指南中有更详细的讨论。\n\n## DeepFloyd IF\n\nDeepFloyd IF 
是一个级联像素扩散模型，包含三个阶段。第一阶段生成基础图像，第二和第三阶段逐步将基础图像放大为高分辨率 1024x1024 图像。使用 [train_dreambooth_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py) 或 [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) 脚本，以 LoRA 或完整模型的方式训练 DeepFloyd IF 模型。\n\nDeepFloyd IF 使用预测方差，但 Diffusers 训练脚本使用预测误差，因此训练的 DeepFloyd IF 模型被切换到固定方差调度。训练脚本将为您更新完全训练模型的调度器配置。但是，当您加载保存的 LoRA 权重时，还必须更新管道的调度器配置。\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\", use_safetensors=True)\n\npipe.load_lora_weights(\"<lora weights path>\")\n\n# 更新调度器配置为固定方差调度\npipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type=\"fixed_small\")\n```\n\n第二阶段模型需要额外的验证图像进行放大。您可以下载并使用训练图像的缩小版本。\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog_downsized\"\nsnapshot_download(\n    \"diffusers/dog-example-downsized\",\n    local_dir=local_dir,\n    repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\n以下代码示例简要概述了如何结合 DreamBooth 和 LoRA 训练 DeepFloyd IF 模型。一些需要注意的重要参数包括：\n\n* `--resolution=64`，需要更小的分辨率，因为 DeepFloyd IF 是一个像素扩散模型，直接处理未压缩的像素，所以输入图像必须更小\n* `--pre_compute_text_embeddings`，提前计算文本嵌入以节省内存，因为 [`~transformers.T5Model`] 可能占用大量内存\n* `--tokenizer_max_length=77`，使用 T5 作为文本编码器时可以采用更长的默认文本长度，但默认的模型编码过程使用较短的文本长度\n* `--text_encoder_use_attention_mask`，将注意力掩码传递给文本编码器\n\n<hfoptions id=\"IF-DreamBooth\">\n<hfoption id=\"Stage 1 LoRA DreamBooth\">\n\n使用 LoRA 和 DreamBooth 训练 DeepFloyd IF 的第 1 阶段需要约 28GB 内存。\n\n```bash\nexport MODEL_NAME=\"DeepFloyd/IF-I-XL-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_lora\"\n\naccelerate launch train_dreambooth_lora.py \\\n  --report_to wandb \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a sks dog\" \\\n  --resolution=64 \\\n  
--train_batch_size=4 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --scale_lr \\\n  --max_train_steps=1200 \\\n  --validation_prompt=\"a sks dog\" \\\n  --validation_epochs=25 \\\n  --checkpointing_steps=100 \\\n  --pre_compute_text_embeddings \\\n  --tokenizer_max_length=77 \\\n  --text_encoder_use_attention_mask\n```\n\n</hfoption>\n<hfoption id=\"Stage 2 LoRA DreamBooth\">\n\n对于使用 LoRA 和 DreamBooth 的 DeepFloyd IF 第 2 阶段，请注意这些参数：\n\n* `--validation_images`，验证期间用于上采样的图像\n* `--class_labels_conditioning=timesteps`，按第 2 阶段的要求对 UNet 进行额外的条件化\n* `--learning_rate=1e-6`，与第 1 阶段相比使用较低的学习率\n* `--resolution=256`，上采样器的预期分辨率\n\n```bash\nexport MODEL_NAME=\"DeepFloyd/IF-II-L-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_upscale\"\nexport VALIDATION_IMAGES=\"dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png\"\n\npython train_dreambooth_lora.py \\\n    --report_to wandb \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --instance_data_dir=$INSTANCE_DIR \\\n    --output_dir=$OUTPUT_DIR \\\n    --instance_prompt=\"a sks dog\" \\\n    --resolution=256 \\\n    --train_batch_size=4 \\\n    --gradient_accumulation_steps=1 \\\n    --learning_rate=1e-6 \\\n    --max_train_steps=2000 \\\n    --validation_prompt=\"a sks dog\" \\\n    --validation_epochs=100 \\\n    --checkpointing_steps=500 \\\n    --pre_compute_text_embeddings \\\n    --tokenizer_max_length=77 \\\n    --text_encoder_use_attention_mask \\\n    --validation_images $VALIDATION_IMAGES \\\n    --class_labels_conditioning=timesteps\n```\n\n</hfoption>\n<hfoption id=\"Stage 1 DreamBooth\">\n\n对于使用 DreamBooth 的 DeepFloyd IF 第 1 阶段，请注意这些参数：\n\n* `--skip_save_text_encoder`，跳过将完整的 T5 文本编码器与微调模型一起保存\n* `--use_8bit_adam`，使用 8 位 Adam 优化器以节省内存，因为训练完整模型时优化器状态占用的内存很大\n* `--learning_rate=1e-7`，对于完整模型训练应使用非常低的学习率，否则模型质量会下降（搭配更大的批次大小时可以使用更高的学习率）\n\n使用 8 位 Adam 和批次大小为 4 进行训练，完整模型可以在约 48GB 内存下训练。\n\n```bash\nexport 
MODEL_NAME=\"DeepFloyd/IF-I-XL-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_if\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=64 \\\n  --train_batch_size=4 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-7 \\\n  --max_train_steps=150 \\\n  --validation_prompt \"a photo of sks dog\" \\\n  --validation_steps 25 \\\n  --text_encoder_use_attention_mask \\\n  --tokenizer_max_length 77 \\\n  --pre_compute_text_embeddings \\\n  --use_8bit_adam \\\n  --set_grads_to_none \\\n  --skip_save_text_encoder \\\n  --push_to_hub\n```\n\n</hfoption>\n<hfoption id=\"Stage 2 DreamBooth\">\n\n对于DeepFloyd IF的第二阶段DreamBooth，请注意这些参数：\n\n* `--learning_rate=5e-6`，使用较低的学习率和较小的有效批次大小\n* `--resolution=256`，上采样器的预期分辨率\n* `--train_batch_size=2` 和 `--gradient_accumulation_steps=6`，为了有效训练包含面部的图像，需要更大的批次大小\n\n```bash\nexport MODEL_NAME=\"DeepFloyd/IF-II-L-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_upscale\"\nexport VALIDATION_IMAGES=\"dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png\"\n\naccelerate launch train_dreambooth.py \\\n  --report_to wandb \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a sks dog\" \\\n  --resolution=256 \\\n  --train_batch_size=2 \\\n  --gradient_accumulation_steps=6 \\\n  --learning_rate=5e-6 \\\n  --max_train_steps=2000 \\\n  --validation_prompt=\"a sks dog\" \\\n  --validation_steps=150 \\\n  --checkpointing_steps=500 \\\n  --pre_compute_text_embeddings \\\n  --tokenizer_max_length=77 \\\n  --text_encoder_use_attention_mask \\\n  --validation_images $VALIDATION_IMAGES \\\n  --class_labels_conditioning timesteps \\\n  --push_to_hub\n```\n\n</hfoption>\n</hfoptions>\n\n### 
训练技巧\n\n训练DeepFloyd IF模型可能具有挑战性，但以下是我们发现有用的技巧：\n\n- LoRA对于训练第一阶段模型已足够，因为模型的低分辨率本身就使得表示更精细的细节变得困难。\n- 对于常见或简单的对象，您不一定需要微调上采样器。确保传递给上采样器的提示被调整以移除实例提示中的新令牌。例如，如果您第一阶段提示是\"a sks dog\"，那么您第二阶段的提示应该是\"a dog\"。\n- 对于更精细的细节，如面部，完全训练阶段2上采样器比使用LoRA训练阶段2模型更好。使用更大的批次大小和较低的学习率也有帮助。\n- 应使用较低的学习率来训练阶段2模型。\n- [`DDPMScheduler`] 比训练脚本中使用的DPMSolver效果更好。\n\n## 下一步\n\n恭喜您训练了您的DreamBooth模型！要了解更多关于如何使用您的新模型的信息，以下指南可能有所帮助：\n- 如果您使用LoRA训练了您的模型，请学习如何[加载DreamBooth](../using-diffusers/loading_adapters)模型进行推理。"
  },
  {
    "path": "docs/source/zh/training/instructpix2pix.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# InstructPix2Pix\n\n[InstructPix2Pix](https://hf.co/papers/2211.09800) 是一个基于 Stable Diffusion 训练的模型，用于根据人类提供的指令编辑图像。例如，您的提示可以是“将云变成雨天”，模型将相应编辑输入图像。该模型以文本提示（或编辑指令）和输入图像为条件。\n\n本指南将探索 [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) 训练脚本，帮助您熟悉它，以及如何将其适应您自己的用例。\n\n在运行脚本之前，请确保从源代码安装库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n然后导航到包含训练脚本的示例文件夹，并安装脚本所需的依赖项：\n\n```bash\ncd examples/instruct_pix2pix\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate 是一个库，用于帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它将根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速导览](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。\n\n初始化一个 🤗 Accelerate 环境：\n\n```bash\naccelerate config\n```\n\n要设置一个默认的 🤗 Accelerate 环境，无需选择任何配置：\n\n```bash\naccelerate config default\n```\n\n或者，如果您的环境不支持交互式 shell，例如笔记本，您可以使用：\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n最后，如果您想在自己的数据集上训练模型，请查看 [创建用于训练的数据集](create_dataset) 指南，了解如何创建与训练脚本兼容的数据集。\n\n> [!TIP]\n> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分，但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多，请随时阅读 [脚本](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py)，并告诉我们如果您有任何问题或疑虑。\n\n## 脚本参数\n\n训练脚本有许多参数可帮助您自定义训练运行。所有\n参数及其描述可在 
[`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L65) 函数中找到。大多数参数都提供了默认值，这些值效果相当不错，但如果您愿意，也可以在训练命令中设置自己的值。\n\n例如，要增加输入图像的分辨率：\n\n```bash\naccelerate launch train_instruct_pix2pix.py \\\n  --resolution=512 \\\n```\n\n许多基本和重要的参数在 [文本到图像](text2image#script-parameters) 训练指南中已有描述，因此本指南仅关注与 InstructPix2Pix 相关的参数：\n\n- `--original_image_column`：编辑前的原始图像\n- `--edited_image_column`：编辑后的图像\n- `--edit_prompt_column`：编辑图像的指令\n- `--conditioning_dropout_prob`：训练期间编辑图像和编辑提示的 dropout 概率，这为一种或两种条件输入启用了无分类器引导（CFG）\n\n## 训练脚本\n\n数据集预处理代码和训练循环可在 [`main()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L374) 函数中找到。这是您将修改训练脚本以适应自己用例的地方。\n\n与脚本参数类似，[文本到图像](text2image#training-script) 训练指南提供了训练脚本的逐步说明。相反，本指南将查看脚本中与 InstructPix2Pix 相关的部分。\n\n脚本首先修改 UNet 的第一个卷积层中的 [输入通道数](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L445)，以适应 InstructPix2Pix 的额外条件图像：\n\n```py\nin_channels = 8\nout_channels = unet.conv_in.out_channels\nunet.register_to_config(in_channels=in_channels)\n\nwith torch.no_grad():\n    new_conv_in = nn.Conv2d(\n        in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding\n    )\n    new_conv_in.weight.zero_()\n    new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)\n    unet.conv_in = new_conv_in\n```\n\n这些 UNet 参数由优化器 [更新](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L545C1-L551C6)：\n\n```py\noptimizer = optimizer_cls(\n    unet.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\n接下来，编辑后的图像和编辑指令被 
[预处理](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624)并被[tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24)。重要的是，对原始图像和编辑后的图像应用相同的图像变换。\n\n```py\ndef preprocess_train(examples):\n    preprocessed_images = preprocess_images(examples)\n\n    original_images, edited_images = preprocessed_images.chunk(2)\n    original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)\n    edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)\n\n    examples[\"original_pixel_values\"] = original_images\n    examples[\"edited_pixel_values\"] = edited_images\n\n    captions = list(examples[edit_prompt_column])\n    examples[\"input_ids\"] = tokenize_captions(captions)\n    return examples\n```\n\n最后，在[训练循环](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L730)中，它首先将编辑后的图像编码到潜在空间：\n\n```py\nlatents = vae.encode(batch[\"edited_pixel_values\"].to(weight_dtype)).latent_dist.sample()\nlatents = latents * vae.config.scaling_factor\n```\n\n然后，脚本对原始图像和编辑指令嵌入应用 dropout 以支持 CFG（Classifier-Free Guidance）。这使得模型能够调节编辑指令和原始图像对编辑后图像的影响。\n\n```py\nencoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\noriginal_image_embeds = vae.encode(batch[\"original_pixel_values\"].to(weight_dtype)).latent_dist.mode()\n\nif args.conditioning_dropout_prob is not None:\n    random_p = torch.rand(bsz, device=latents.device, generator=generator)\n    prompt_mask = random_p < 2 * args.conditioning_dropout_prob\n    prompt_mask = prompt_mask.reshape(bsz, 1, 1)\n    null_conditioning = text_encoder(tokenize_captions([\"\"]).to(accelerator.device))[0]\n    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)\n\n    image_mask_dtype = 
original_image_embeds.dtype\n    image_mask = 1 - (\n        (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)\n        * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)\n    )\n    image_mask = image_mask.reshape(bsz, 1, 1, 1)\n    original_image_embeds = image_mask * original_image_embeds\n```\n\n差不多就是这样了！除了这里描述的不同之处，脚本的其余部分与[文本到图像](text2image#training-script)训练脚本非常相似，所以请随意查看以获取更多细节。如果您想了解更多关于训练循环如何工作的信息，请查看[理解管道、模型和调度器](../using-diffusers/write_own_pipeline)教程，该教程分解了去噪过程的基本模式。\n\n## 启动脚本\n\n一旦您对脚本的更改感到满意，或者您接受默认配置，就可以启动训练脚本了！🚀\n\n本指南使用 [fusing/instructpix2pix-1000-samples](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) 数据集，这是 [原始数据集](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) 的一个较小版本。您也可以创建并使用自己的数据集（请参阅 [创建用于训练的数据集](create_dataset) 指南）。\n\n将 `MODEL_NAME` 环境变量设置为模型名称（可以是 Hub 上的模型 ID 或本地模型的路径），并将 `DATASET_ID` 设置为 Hub 上数据集的名称。脚本会创建并保存所有组件（特征提取器、调度器、文本编码器、UNet 等）到您的仓库中的一个子文件夹。\n\n> [!TIP]\n> 为了获得更好的结果，尝试使用更大的数据集进行更长时间的训练。我们只在较小规模的数据集上测试过此训练脚本。\n>\n> <br>\n>\n> 要使用 Weights and Biases 监控训练进度，请将 `--report_to=wandb` 参数添加到训练命令中，并使用 `--val_image_url` 指定验证图像，使用 `--validation_prompt` 指定验证提示。这对于调试模型非常有用。\n\n如果您在多个 GPU 上训练，请将 `--multi_gpu` 参数添加到 `accelerate launch` 命令中。\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" train_instruct_pix2pix.py \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --dataset_name=$DATASET_ID \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=256 \\\n    --random_flip \\\n    --train_batch_size=4 \\\n    --gradient_accumulation_steps=4 \\\n    --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 \\\n    --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 \\\n    --max_grad_norm=1 \\\n    --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --mixed_precision=fp16 \\\n    --seed=42 \\\n    --push_to_hub\n```\n\n训练完成后，您可以使用您的新 InstructPix2Pix 
进行推理：\n\n```py\nimport PIL\nimport requests\nimport torch\nfrom diffusers import StableDiffusionInstructPix2PixPipeline\nfrom diffusers.utils import load_image\n\npipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\"your_cool_model\", torch_dtype=torch.float16).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n\nimage = load_image(\"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png\")\nprompt = \"add some ducks to the lake\"\nnum_inference_steps = 20\nimage_guidance_scale = 1.5\nguidance_scale = 10\n\nedited_image = pipeline(\n   prompt,\n   image=image,\n   num_inference_steps=num_inference_steps,\n   image_guidance_scale=image_guidance_scale,\n   guidance_scale=guidance_scale,\n   generator=generator,\n).images[0]\nedited_image.save(\"edited_image.png\")\n```\n\n您应该尝试不同的 `num_inference_steps`、`image_guidance_scale` 和 `guidance_scale` 值，以查看它们如何影响推理速度和质量。两个引导比例参数尤其重要，因为它们控制原始图像和编辑指令对编辑后图像的影响程度。\n\n## Stable Diffusion XL\n\nStable Diffusion XL (SDXL) 是一个强大的文本到图像模型，能够生成高分辨率图像，并在其架构中添加了第二个文本编码器。使用 [`train_instruct_pix2pix_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py) 脚本来训练 SDXL 模型以遵循图像编辑指令。\n\nSDXL 训练脚本在 [SDXL 训练](sdxl) 指南中有更详细的讨论。\n\n## 后续步骤\n\n恭喜您训练了自己的 InstructPix2Pix 模型！🥳 要了解更多关于该模型的信息，可能有助于：\n\n- 阅读 [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) 博客文章，了解更多我们使用 InstructPix2Pix 进行的一些实验、数据集准备以及不同指令的结果。"
  },
  {
    "path": "docs/source/zh/training/kandinsky.md",
    "content": "<!--版权所有 2025 HuggingFace 团队。保留所有权利。\n\n根据 Apache 许可证 2.0 版本（\"许可证\"）授权；除非遵守许可证，否则您不得使用此文件。您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，否则根据许可证分发的软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。请参阅许可证以了解具体的语言管理权限和限制。\n-->\n\n# Kandinsky 2.2\n\n> [!WARNING]\n> 此脚本是实验性的，容易过拟合并遇到灾难性遗忘等问题。尝试探索不同的超参数以在您的数据集上获得最佳结果。\n\nKandinsky 2.2 是一个多语言文本到图像模型，能够生成更逼真的图像。该模型包括一个图像先验模型，用于从文本提示创建图像嵌入，以及一个解码器模型，基于先验模型的嵌入生成图像。这就是为什么在 Diffusers 中您会找到两个独立的脚本用于 Kandinsky 2.2，一个用于训练先验模型，另一个用于训练解码器模型。您可以分别训练这两个模型，但为了获得最佳结果，您应该同时训练先验和解码器模型。\n\n根据您的 GPU，您可能需要启用 `gradient_checkpointing`（⚠️ 不支持先验模型！）、`mixed_precision` 和 `gradient_accumulation_steps` 来帮助将模型装入内存并加速训练。您可以通过启用 [xFormers](../optimization/xformers) 的内存高效注意力来进一步减少内存使用（版本 [v0.0.16](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212) 在某些 GPU 上训练时失败，因此您可能需要安装开发版本）。\n\n本指南探讨了 [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py) 和 [train_text_to_image_decoder.py](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py) 脚本，以帮助您更熟悉它，以及如何根据您的用例进行调整。\n\n在运行脚本之前，请确保从源代码安装库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n然后导航到包含训练脚本的示例文件夹，并安装脚本所需的依赖项：\n\n```bash\ncd examples/kandinsky2_2/text_to_image\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate 的 [快速入门](https://huggingface.co/docs/accelerate/quicktour\n> ) 了解更多。\n\n初始化一个 🤗 Accelerate 环境：\n\n```bash\naccelerate config\n```\n\n要设置一个默认的 🤗 Accelerate 环境而不选择任何配置：\n\n```bash\naccelerate config default\n```\n\n或者，如果您的环境不支持交互式 shell，比如 notebook，您可以使用：\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n最后，如果您想在自己的数据集上训练模型，请查看 [创建用于训练的数据集](create_dataset) 指南，了解如何创建与训练脚本兼容的数据集。\n\n> [!TIP]\n> 
以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分，但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多，请随时阅读脚本；如果您有任何疑问或顾虑，请告诉我们。\n\n## 脚本参数\n\n训练脚本提供了许多参数来帮助您自定义训练运行。所有参数及其描述都可以在 [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L190) 函数中找到。训练脚本为每个参数提供了默认值，例如训练批次大小和学习率，但如果您愿意，也可以在训练命令中设置自己的值。\n\n例如，要使用 fp16 格式的混合精度加速训练，请在训练命令中添加 `--mixed_precision` 参数：\n\n```bash\naccelerate launch train_text_to_image_prior.py \\\n  --mixed_precision=\"fp16\"\n```\n\n大多数参数与 [文本到图像](text2image#script-parameters) 训练指南中的参数相同，所以让我们直接进入 Kandinsky 训练脚本的详细讲解！\n\n### Min-SNR 加权\n\n[Min-SNR](https://huggingface.co/papers/2303.09556) 加权策略可以通过重新平衡损失来帮助训练，实现更快的收敛。训练脚本支持预测 `epsilon`（噪声）或 `v_prediction`，而 Min-SNR 与这两种预测类型都兼容。此加权策略仅由 PyTorch 支持，在 Flax 训练脚本中不可用。\n\n添加 `--snr_gamma` 参数并将其设置为推荐值 5.0：\n\n```bash\naccelerate launch train_text_to_image_prior.py \\\n  --snr_gamma=5.0\n```\n\n## 训练脚本\n\n训练脚本也类似于 [文本到图像](text2image#training-script) 训练指南，但已修改以支持训练 prior 和 decoder 模型。本指南重点介绍 Kandinsky 2.2 训练脚本中独特的代码。\n\n<hfoptions id=\"script\">\n<hfoption id=\"prior model\">\n\n[`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L441) 函数包含用于准备数据集和训练模型的代码。\n\n您会立即注意到的主要区别之一是，训练脚本除了调度器和分词器外，还加载了一个 [`~transformers.CLIPImageProcessor`] 用于预处理图像，以及一个 [`~transformers.CLIPVisionModelWithProjection`] 模型用于编码图像：\n\n```py\nnoise_scheduler = DDPMScheduler(beta_schedule=\"squaredcos_cap_v2\", prediction_type=\"sample\")\nimage_processor = CLIPImageProcessor.from_pretrained(\n    args.pretrained_prior_model_name_or_path, subfolder=\"image_processor\"\n)\ntokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"tokenizer\")\n\nwith ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n        
args.pretrained_prior_model_name_or_path, subfolder=\"image_encoder\", torch_dtype=weight_dtype\n    ).eval()\n    text_encoder = CLIPTextModelWithProjection.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"text_encoder\", torch_dtype=weight_dtype\n    ).eval()\n```\n\nKandinsky 使用一个 [`PriorTransformer`] 来生成图像嵌入，因此您需要设置优化器来学习先验模型的参数。\n\n```py\nprior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\nprior.train()\noptimizer = optimizer_cls(\n    prior.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\n接下来，输入标题被分词，图像由 [`~transformers.CLIPImageProcessor`] [预处理](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L632)：\n\n```py\ndef preprocess_train(examples):\n    images = [image.convert(\"RGB\") for image in examples[image_column]]\n    examples[\"clip_pixel_values\"] = image_processor(images, return_tensors=\"pt\").pixel_values\n    examples[\"text_input_ids\"], examples[\"text_mask\"] = tokenize_captions(examples)\n    return examples\n```\n\n最后，[训练循环](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L718) 将输入图像转换为潜在表示，向图像嵌入添加噪声，并进行预测：\n\n```py\nmodel_pred = prior(\n    noisy_latents,\n    timestep=timesteps,\n    proj_embedding=prompt_embeds,\n    encoder_hidden_states=text_encoder_hidden_states,\n    attention_mask=text_mask,\n).predicted_image_embedding\n```\n\n如果您想了解更多关于训练循环的工作原理，请查看 [理解管道、模型和调度器](../using-diffusers/write_own_pipeline) 教程，该教程分解了去噪过程的基本模式。\n\n</hfoption>\n<hfoption id=\"decoder model\">\n\n
[`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L440) 函数包含准备数据集和训练模型的代码。\n\n与先验模型不同，解码器初始化一个 [`VQModel`] 来将潜在变量解码为图像，并使用一个 [`UNet2DConditionModel`]：\n\n```py\nwith ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n    vae = VQModel.from_pretrained(\n        args.pretrained_decoder_model_name_or_path, subfolder=\"movq\", torch_dtype=weight_dtype\n    ).eval()\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_encoder\", torch_dtype=weight_dtype\n    ).eval()\nunet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder=\"unet\")\n```\n\n接下来，脚本包括几个图像变换和一个用于对图像应用变换并返回像素值的[预处理](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L622)函数：\n\n```py\ndef preprocess_train(examples):\n    images = [image.convert(\"RGB\") for image in examples[image_column]]\n    examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n    examples[\"clip_pixel_values\"] = image_processor(images, return_tensors=\"pt\").pixel_values\n    return examples\n```\n\n最后，[训练循环](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L706)处理将图像转换为潜在变量、添加噪声和预测噪声残差。\n\n如果您想了解更多关于训练循环如何工作的信息，请查看[理解管道、模型和调度器](../using-diffusers/write_own_pipeline)教程，该教程分解了去噪过程的基本模式。\n\n```py\nmodel_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]\n```\n\n</hfoption>\n</hfoptions>\n\n## 启动脚本\n\n一旦您完成了所有更改或接受默认配置，就可以启动训练脚本了！🚀\n\n您将在[Naruto BLIP 字幕](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions)数据集上进行训练，以生成您自己的Naruto角色，但您也可以按照[创建用于训练的数据集](create_dataset)指南创建自己的数据集并在其上训练。将环境变量 `DATASET_NAME` 
设置为Hub上数据集的名称，或者如果您在自己的文件上训练，将环境变量 `TRAIN_DIR` 设置为数据集的路径。\n\n如果您在多个GPU上训练，请在 `accelerate launch` 命令中添加 `--multi_gpu` 参数。\n\n> [!TIP]\n> 要使用Weights & Biases监控训练进度，请在训练命令中添加 `--report_to=wandb` 参数。\n> 此外，建议在训练命令中添加 `--validation_prompts` 以跟踪结果。这对于调试模型和查看中间结果非常有用。\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"prior model\">\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image_prior.py \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"kandi2-prior-naruto-model\"\n```\n\n</hfoption>\n<hfoption id=\"decoder model\">\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image_decoder.py \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"kandi2-decoder-naruto-model\"\n```\n\n</hfoption>\n</hfoptions>\n\n训练完成后，您可以使用新训练的模型进行推理！\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"prior model\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image, DiffusionPipeline\nimport torch\n\nprior_pipeline = DiffusionPipeline.from_pretrained(\"path/to/saved/model\", torch_dtype=torch.float16)\nprior_components = {\"prior_\" + k: v for k,v in 
prior_pipeline.components.items()}\npipeline = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", **prior_components, torch_dtype=torch.float16)\n\npipeline.enable_model_cpu_offload()\nprompt = \"A robot naruto, 4k photo\"\nimage = pipeline(prompt=prompt).images[0]\n```\n\n> [!TIP]\n> 可以随意将 `kandinsky-community/kandinsky-2-2-decoder` 替换为您自己训练的 decoder 检查点！\n\n</hfoption>\n<hfoption id=\"decoder model\">\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"path/to/saved/model\", torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nprompt=\"A robot naruto, 4k photo\"\nimage = pipeline(prompt=prompt).images[0]\n```\n\n对于 decoder 模型，您还可以从保存的检查点进行推理，这对于查看中间结果很有用。在这种情况下，将检查点加载到 UNet 中：\n\n```py\nfrom diffusers import AutoPipelineForText2Image, UNet2DConditionModel\n\nunet = UNet2DConditionModel.from_pretrained(\"path/to/saved/model\" + \"/checkpoint-<N>/unet\")\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", unet=unet, torch_dtype=torch.float16)\npipeline.enable_model_cpu_offload()\n\nimage = pipeline(prompt=\"A robot naruto, 4k photo\").images[0]\n```\n\n</hfoption>\n</hfoptions>\n\n## 后续步骤\n\n恭喜您训练了一个 Kandinsky 2.2 模型！要了解更多关于如何使用您的新模型的信息，以下指南可能会有所帮助：\n\n- 阅读 [Kandinsky](../using-diffusers/kandinsky) 指南，学习如何将其用于各种不同的任务（文本到图像、图像到图像、修复、插值），以及如何与 ControlNet 结合使用。\n- 查看 [DreamBooth](dreambooth) 和 [LoRA](lora) 训练指南，学习如何使用少量示例图像训练个性化的 Kandinsky 模型。这两种训练技术甚至可以结合使用！"
  },
  {
    "path": "docs/source/zh/training/lora.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# LoRA 低秩适配\n\n> [!WARNING]\n> 当前功能处于实验阶段，API可能在未来版本中变更。\n\n[LoRA（大语言模型的低秩适配）](https://hf.co/papers/2106.09685) 是一种轻量级训练技术，能显著减少可训练参数量。其原理是通过向模型注入少量新权重参数，仅训练这些新增参数。这使得LoRA训练速度更快、内存效率更高，并生成更小的模型权重文件（通常仅数百MB），便于存储和分享。LoRA还可与DreamBooth等其他训练技术结合以加速训练过程。\n\n> [!TIP]\n> LoRA具有高度通用性，目前已支持以下应用场景：[DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)、[Kandinsky 2.2](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py)、[Stable Diffusion XL](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py)、[文生图](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)以及[Wuerstchen](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py)。\n\n本指南将通过解析[train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)脚本，帮助您深入理解其工作原理，并掌握如何针对具体需求进行定制化修改。\n\n运行脚本前，请确保从源码安装库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n进入包含训练脚本的示例目录，并安装所需依赖：\n\n<hfoptions id=\"installation\">\n<hfoption id=\"PyTorch\">\n\n```bash\ncd examples/text_to_image\npip install -r 
requirements.txt\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\n```bash\ncd examples/text_to_image\npip install -r requirements_flax.txt\n```\n\n</hfoption>\n</hfoptions>\n\n> [!TIP]\n> 🤗 Accelerate是一个支持多GPU/TPU训练和混合精度计算的库，它能根据硬件环境自动配置训练方案。参阅🤗 Accelerate[快速入门](https://huggingface.co/docs/accelerate/quicktour)了解更多。\n\n初始化🤗 Accelerate环境：\n\n```bash\naccelerate config\n```\n\n若要创建默认配置环境（不进行交互式设置）：\n\n```bash\naccelerate config default\n```\n\n若在非交互环境（如Jupyter notebook）中使用：\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n如需训练自定义数据集，请参考[创建训练数据集指南](create_dataset)了解数据准备流程。\n\n> [!TIP]\n> 以下章节重点解析训练脚本中与LoRA相关的核心部分，但不会涵盖所有实现细节。如需完整理解，建议直接阅读[脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)，如有疑问欢迎反馈。\n\n## 脚本参数\n\n训练脚本提供众多参数用于定制训练过程。所有参数及其说明均定义在[`parse_args()`](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L85)函数中。多数参数设有默认值，您也可以通过命令行参数覆盖。\n\n例如增加训练轮次：\n\n```bash\naccelerate launch train_text_to_image_lora.py \\\n  --num_train_epochs=150\n```\n\n基础参数说明可参考[文生图训练指南](text2image#script-parameters)，此处重点介绍LoRA相关参数：\n\n- `--rank`：低秩矩阵的内部维度，数值越高可训练参数越多\n- `--learning_rate`：默认学习率为1e-4，但使用LoRA时可适当提高\n\n## 训练脚本实现\n\n数据集预处理和训练循环逻辑位于[`main()`](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L371)函数，如需定制训练流程，可在此处进行修改。\n\n与参数说明类似，训练流程的完整解析请参考[文生图指南](text2image#training-script)，下文重点介绍LoRA相关实现。\n\n<hfoptions id=\"lora\">\n<hfoption id=\"UNet\">\n\nDiffusers使用[PEFT](https://hf.co/docs/peft)库的[`~peft.LoraConfig`]配置LoRA适配器参数，包括秩(rank)、alpha值以及目标模块。适配器被注入UNet后，通过`lora_layers`筛选出需要优化的LoRA层。\n\n```py\nunet_lora_config = LoraConfig(\n    r=args.rank,\n    lora_alpha=args.rank,\n    init_lora_weights=\"gaussian\",\n    target_modules=[\"to_k\", \"to_q\", \"to_v\", 
\"to_out.0\"],\n)\n\nunet.add_adapter(unet_lora_config)\nlora_layers = filter(lambda p: p.requires_grad, unet.parameters())\n```\n\n</hfoption>\n<hfoption id=\"text encoder\">\n\n当需要微调文本编码器时（如SDXL模型），Diffusers同样支持通过[PEFT](https://hf.co/docs/peft)库实现。[`~peft.LoraConfig`]配置适配器参数后注入文本编码器，并筛选LoRA层进行训练。\n\n```py\ntext_lora_config = LoraConfig(\n    r=args.rank,\n    lora_alpha=args.rank,\n    init_lora_weights=\"gaussian\",\n    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n)\n\ntext_encoder_one.add_adapter(text_lora_config)\ntext_encoder_two.add_adapter(text_lora_config)\ntext_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\ntext_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))\n```\n\n</hfoption>\n</hfoptions>\n\n[优化器](https://github.com/huggingface/diffusers/blob/e4b8f173b97731686e290b2eb98e7f5df2b1b322/examples/text_to_image/train_text_to_image_lora.py#L529)仅对`lora_layers`参数进行优化：\n\n```py\noptimizer = optimizer_cls(\n    lora_layers,\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\n除LoRA层设置外，该训练脚本与标准train_text_to_image.py基本相同！\n\n## 启动训练\n\n完成所有配置后，即可启动训练脚本！🚀\n\n以下示例使用[Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions)训练生成火影角色。请设置环境变量`MODEL_NAME`和`DATASET_NAME`指定基础模型和数据集，`OUTPUT_DIR`设置输出目录，`HUB_MODEL_ID`指定Hub存储库名称。脚本运行后将生成以下文件：\n\n- 模型检查点\n- `pytorch_lora_weights.safetensors`（训练好的LoRA权重）\n\n多GPU训练请添加`--multi_gpu`参数。\n\n> [!WARNING]\n> 在11GB显存的2080 Ti显卡上完整训练约需5小时。\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"/sddata/finetune/lora/naruto\"\nexport HUB_MODEL_ID=\"naruto-lora\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image_lora.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  
--dataset_name=$DATASET_NAME \\\n  --dataloader_num_workers=8 \\\n  --resolution=512 \\\n  --center_crop \\\n  --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-04 \\\n  --max_grad_norm=1 \\\n  --lr_scheduler=\"cosine\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=${OUTPUT_DIR} \\\n  --push_to_hub \\\n  --hub_model_id=${HUB_MODEL_ID} \\\n  --report_to=wandb \\\n  --checkpointing_steps=500 \\\n  --validation_prompt=\"蓝色眼睛的火影忍者角色\" \\\n  --seed=1337\n```\n\n训练完成后，您可以通过以下方式进行推理：\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.load_lora_weights(\"path/to/lora/model\", weight_name=\"pytorch_lora_weights.safetensors\")\nimage = pipeline(\"A naruto with blue eyes\").images[0]\n```\n\n## 后续步骤\n\n恭喜完成LoRA模型训练！如需进一步了解模型使用方法，可参考以下指南：\n\n- 学习如何加载[不同格式的LoRA权重](../using-diffusers/loading_adapters#LoRA)（如Kohya或TheLastBen训练的模型）\n- 掌握使用PEFT进行[多LoRA组合推理](../tutorials/using_peft_for_inference)的技巧"
  },
  {
    "path": "docs/source/zh/training/overview.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\n根据 Apache License 2.0 版本（\"许可证\"）授权，除非符合许可证要求，否则不得使用此文件。您可以通过以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，本软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。详见许可证中规定的特定语言权限和限制。\n-->\n\n# 概述\n\n🤗 Diffusers 提供了一系列训练脚本供您训练自己的diffusion模型。您可以在 [diffusers/examples](https://github.com/huggingface/diffusers/tree/main/examples) 找到所有训练脚本。\n\n每个训练脚本具有以下特点：\n\n- **独立完整**：训练脚本不依赖任何本地文件，所有运行所需的包都通过 `requirements.txt` 文件安装\n- **易于调整**：这些脚本是针对特定任务的训练示例，并不能开箱即用地适用于所有训练场景。您可能需要根据具体用例调整脚本。为此，我们完全公开了数据预处理代码和训练循环，方便您进行修改\n- **新手友好**：脚本设计注重易懂性和入门友好性，而非包含最新最优方法以获得最具竞争力的结果。我们有意省略了过于复杂的训练方法\n- **单一用途**：每个脚本仅针对一个任务设计，确保代码可读性和可理解性\n\n当前提供的训练脚本包括：\n\n| 训练类型 | 支持SDXL | 支持LoRA | 支持Flax |\n|---|---|---|---|\n| [unconditional image generation](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) |  |  |  |\n| [text-to-image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) | 👍 | 👍 | 👍 |\n| [textual inversion](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) |  |  | 👍 |\n| [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) | 👍 | 👍 | 👍 |\n| [ControlNet](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) | 👍 |  | 👍 |\n| 
[InstructPix2Pix](https://github.com/huggingface/diffusers/tree/main/examples/instruct_pix2pix) | 👍 |  |  |\n| [Custom Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion) |  |  |  |\n| [T2I-Adapters](https://github.com/huggingface/diffusers/tree/main/examples/t2i_adapter) | 👍 |  |  |\n| [Kandinsky 2.2](https://github.com/huggingface/diffusers/tree/main/examples/kandinsky2_2/text_to_image) |  | 👍 |  |\n| [Wuerstchen](https://github.com/huggingface/diffusers/tree/main/examples/wuerstchen/text_to_image) |  | 👍 |  |\n\n这些示例处于**积极维护**状态，如果遇到问题请随时提交issue。如果您认为应该添加其他训练示例，欢迎创建[功能请求](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=)与我们讨论，我们将评估其是否符合独立完整、易于调整、新手友好和单一用途的标准。\n\n## 安装\n\n请按照以下步骤在新虚拟环境中从源码安装库，确保能成功运行最新版本的示例脚本：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n然后进入具体训练脚本目录（例如[DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)），安装对应的`requirements.txt`文件。部分脚本针对SDXL、LoRA或Flax有特定要求文件，使用时请确保安装对应文件。\n\n```bash\ncd examples/dreambooth\npip install -r requirements.txt\n# 如需用DreamBooth训练SDXL\npip install -r requirements_sdxl.txt\n```\n\n为加速训练并降低内存消耗，我们建议：\n\n- 使用PyTorch 2.0或更高版本，自动启用[缩放点积注意力](../optimization/fp16#scaled-dot-product-attention)（无需修改训练代码）\n- 安装[xFormers](../optimization/xformers)以启用内存高效注意力机制"
  },
  {
    "path": "docs/source/zh/training/text2image.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# 文生图\n\n> [!WARNING]\n> 文生图训练脚本目前处于实验阶段，容易出现过拟合和灾难性遗忘等问题。建议尝试不同超参数以获得最佳数据集适配效果。\n\nStable Diffusion 等文生图模型能够根据文本提示生成对应图像。\n\n模型训练对硬件要求较高，但启用 `gradient_checkpointing` 和 `mixed_precision` 后，可在单块24GB显存GPU上完成训练。如需更大批次或更快训练速度，建议使用30GB以上显存的GPU设备。通过启用 [xFormers](../optimization/xformers) 内存高效注意力机制可降低显存占用。JAX/Flax 训练方案也支持TPU/GPU高效训练，但不支持梯度检查点、梯度累积和xFormers。使用Flax训练时建议配备30GB以上显存GPU或TPU v3。\n\n本指南将详解 [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) 训练脚本，助您掌握其原理并适配自定义需求。\n\n运行脚本前请确保已从源码安装库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n然后进入包含训练脚本的示例目录，安装对应依赖：\n\n<hfoptions id=\"installation\">\n<hfoption id=\"PyTorch\">\n```bash\ncd examples/text_to_image\npip install -r requirements.txt\n```\n</hfoption>\n<hfoption id=\"Flax\">\n```bash\ncd examples/text_to_image\npip install -r requirements_flax.txt\n```\n</hfoption>\n</hfoptions>\n\n> [!TIP]\n> 🤗 Accelerate 是支持多GPU/TPU训练和混合精度的工具库，能根据硬件环境自动配置训练参数。参阅 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 了解更多。\n\n初始化 🤗 Accelerate 环境：\n\n```bash\naccelerate config\n```\n\n要创建默认配置环境（不进行交互式选择）：\n\n```bash\naccelerate config default\n```\n\n若环境不支持交互式shell（如notebook），可使用：\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n最后，如需在自定义数据集上训练，请参阅 
[创建训练数据集](create_dataset) 指南了解如何准备适配脚本的数据集。\n\n## 脚本参数\n\n> [!TIP]\n> 以下重点介绍脚本中影响训练效果的关键参数，如需完整参数说明可查阅 [脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)。如有疑问欢迎反馈。\n\n训练脚本提供丰富参数供自定义训练流程，所有参数及说明详见 [`parse_args()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L193) 函数。该函数为每个参数提供默认值（如批次大小、学习率等），也可通过命令行参数覆盖。\n\n例如使用fp16混合精度加速训练：\n\n```bash\naccelerate launch train_text_to_image.py \\\n  --mixed_precision=\"fp16\"\n```\n\n基础重要参数包括：\n\n- `--pretrained_model_name_or_path`: Hub模型名称或本地预训练模型路径\n- `--dataset_name`: Hub数据集名称或本地训练数据集路径\n- `--image_column`: 数据集中图像列名\n- `--caption_column`: 数据集中文本列名\n- `--output_dir`: 模型保存路径\n- `--push_to_hub`: 是否将训练模型推送至Hub\n- `--checkpointing_steps`: 模型检查点保存步数；训练中断时可添加 `--resume_from_checkpoint` 从该检查点恢复训练\n\n### Min-SNR加权策略\n\n[Min-SNR](https://huggingface.co/papers/2303.09556) 加权策略通过重新平衡损失函数加速模型收敛。训练脚本支持预测 `epsilon`（噪声）或 `v_prediction`，而Min-SNR兼容两种预测类型。该策略仅限PyTorch版本，Flax训练脚本不支持。\n\n添加 `--snr_gamma` 参数并设为推荐值5.0：\n\n```bash\naccelerate launch train_text_to_image.py \\\n  --snr_gamma=5.0\n```\n\n可通过此 [Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) 报告比较不同 `snr_gamma` 值的损失曲面。小数据集上Min-SNR效果可能不如大数据集显著。\n\n## 训练脚本解析\n\n数据集预处理代码和训练循环位于 [`main()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L490) 函数，自定义修改需在此处进行。\n\n`train_text_to_image` 脚本首先 [加载调度器](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L543) 和分词器，此处可替换其他调度器：\n\n```py\nnoise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\ntokenizer = CLIPTokenizer.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n)\n```\n\n接着 
[加载UNet模型](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L619)：\n\n```py\nload_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\nmodel.register_to_config(**load_model.config)\n\nmodel.load_state_dict(load_model.state_dict())\n```\n\n随后对数据集的文本和图像列进行预处理。[`tokenize_captions`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L724) 函数处理文本分词，[`train_transforms`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L742) 定义图像增强策略，二者集成于 `preprocess_train`：\n\n```py\ndef preprocess_train(examples):\n    images = [image.convert(\"RGB\") for image in examples[image_column]]\n    examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n    examples[\"input_ids\"] = tokenize_captions(examples)\n    return examples\n```\n\n最后，[训练循环](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L878) 处理剩余流程：图像编码为潜空间、添加噪声、计算文本嵌入条件、更新模型参数、保存并推送模型至Hub。想深入了解训练循环原理，可参阅 [理解管道、模型与调度器](../using-diffusers/write_own_pipeline) 教程，该教程解析了去噪过程的核心逻辑。\n\n## 启动脚本\n\n完成所有配置后，即可启动训练脚本！🚀\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"PyTorch\">\n\n以 [火影忍者BLIP标注数据集](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) 为例训练生成火影角色。设置环境变量 `MODEL_NAME` 和 `dataset_name` 指定模型和数据集（Hub或本地路径）。多GPU训练需在 `accelerate launch` 命令中添加 `--multi_gpu` 参数。\n\n> [!TIP]\n> 使用本地数据集时，设置 `TRAIN_DIR` 和 `OUTPUT_DIR` 环境变量为数据集路径和模型保存路径。\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport dataset_name=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$dataset_name \\\n  --use_ema \\\n  
--resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --enable_xformers_memory_efficient_attention \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --output_dir=\"sd-naruto-model\" \\\n  --push_to_hub\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\nFlax训练方案在TPU/GPU上效率更高（由 [@duongna21](https://github.com/duongna21) 实现），TPU性能更优但GPU表现同样出色。\n\n设置环境变量 `MODEL_NAME` 和 `dataset_name` 指定模型和数据集（Hub或本地路径）。\n\n> [!TIP]\n> 使用本地数据集时，设置 `TRAIN_DIR` 和 `OUTPUT_DIR` 环境变量为数据集路径和模型保存路径。\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport dataset_name=\"lambdalabs/naruto-blip-captions\"\n\npython train_text_to_image_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$dataset_name \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --output_dir=\"sd-naruto-model\" \\\n  --push_to_hub\n```\n\n</hfoption>\n</hfoptions>\n\n训练完成后，即可使用新模型进行推理：\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"PyTorch\">\n\n```py\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\npipeline = StableDiffusionPipeline.from_pretrained(\"path/to/saved_model\", torch_dtype=torch.float16, use_safetensors=True).to(\"cuda\")\n\nimage = pipeline(prompt=\"yoda\").images[0]\nimage.save(\"yoda-naruto.png\")\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\n```py\nimport jax\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom flax.training.common_utils import shard\nfrom diffusers import FlaxStableDiffusionPipeline\n\npipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\"path/to/saved_model\", dtype=jax.numpy.bfloat16)\n\nprompt = \"yoda naruto\"\nprng_seed = jax.random.PRNGKey(0)\nnum_inference_steps = 50\n\nnum_samples = 
jax.device_count()\nprompt = num_samples * [prompt]\nprompt_ids = pipeline.prepare_inputs(prompt)\n\n# 分片输入和随机数\nparams = replicate(params)\nprng_seed = jax.random.split(prng_seed, jax.device_count())\nprompt_ids = shard(prompt_ids)\n\nimages = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images\nimages = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))\n# numpy_to_pil 返回PIL图像列表，这里保存第一张\nimages[0].save(\"yoda-naruto.png\")\n```\n\n</hfoption>\n</hfoptions>\n\n## 后续步骤\n\n恭喜完成文生图模型训练！如需进一步使用模型，以下指南可能有所帮助：\n\n- 了解如何加载 [LoRA权重](../using-diffusers/loading_adapters#LoRA) 进行推理（如果训练时使用了LoRA）\n- 在 [文生图](../using-diffusers/conditional_image_generation) 任务指南中，了解引导尺度等参数或提示词加权等技术如何控制生成效果"
  },
  {
    "path": "docs/source/zh/training/text_inversion.md",
    "content": "<!--版权声明 2025 由 HuggingFace 团队所有。保留所有权利。\n\n根据 Apache 许可证 2.0 版（\"许可证\"）授权；除非符合许可证要求，否则不得使用本文件。\n您可以通过以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，本软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。详见许可证中规定的特定语言权限和限制。\n-->\n\n# 文本反转（Textual Inversion）\n\n[文本反转](https://hf.co/papers/2208.01618)是一种训练技术，仅需少量示例图像即可个性化图像生成模型。该技术通过学习和更新文本嵌入（新嵌入会绑定到提示中必须使用的特殊词汇）来匹配您提供的示例图像。\n\n如果在显存有限的GPU上训练，建议在训练命令中启用`gradient_checkpointing`和`mixed_precision`参数。您还可以通过[xFormers](../optimization/xformers)使用内存高效注意力机制来减少内存占用。JAX/Flax训练也支持在TPU和GPU上进行高效训练，但不支持梯度检查点或xFormers。在配置与PyTorch相同的情况下，Flax训练脚本的速度至少应快70%！\n\n本指南将探索[textual_inversion.py](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py)脚本，帮助您更熟悉其工作原理，并了解如何根据自身需求进行调整。\n\n运行脚本前，请确保从源码安装库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n进入包含训练脚本的示例目录，并安装所需依赖：\n\n<hfoptions id=\"installation\">\n<hfoption id=\"PyTorch\">\n\n```bash\ncd examples/textual_inversion\npip install -r requirements.txt\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\n```bash\ncd examples/textual_inversion\npip install -r requirements_flax.txt\n```\n\n</hfoption>\n</hfoptions>\n\n> [!TIP]\n> 🤗 Accelerate 是一个帮助您在多GPU/TPU或混合精度环境下训练的工具库。它会根据硬件和环境自动配置训练设置。查看🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour)了解更多。\n\n初始化🤗 Accelerate环境：\n\n```bash\naccelerate config\n```\n\n要设置默认的🤗 Accelerate环境（不选择任何配置）：\n\n```bash\naccelerate config default\n```\n\n如果您的环境不支持交互式shell（如notebook），可以使用：\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n最后，如果想在自定义数据集上训练模型，请参阅[创建训练数据集](create_dataset)指南，了解如何创建适用于训练脚本的数据集。\n\n> [!TIP]\n> 以下部分重点介绍训练脚本中需要理解的关键修改点，但未涵盖脚本所有细节。如需深入了解，可随时查阅[脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py)，如有疑问欢迎反馈。\n\n## 
脚本参数\n\n训练脚本包含众多参数，便于您定制训练过程。所有参数及其说明都列在[`parse_args()`](https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L176)函数中。Diffusers为每个参数提供了默认值（如训练批次大小和学习率），但您可以通过训练命令自由调整这些值。\n\n例如，将梯度累积步数增加到默认值1以上：\n\n```bash\naccelerate launch textual_inversion.py \\\n  --gradient_accumulation_steps=4\n```\n\n其他需要指定的基础重要参数包括：\n\n- `--pretrained_model_name_or_path`：Hub上的模型名称或本地预训练模型路径\n- `--train_data_dir`：包含训练数据集（示例图像）的文件夹路径\n- `--output_dir`：训练模型保存位置\n- `--push_to_hub`：是否将训练好的模型推送至Hub\n- `--checkpointing_steps`：训练过程中保存检查点的频率；若训练意外中断，可通过在命令中添加`--resume_from_checkpoint`从该检查点恢复训练\n- `--num_vectors`：学习嵌入的向量数量；增加此参数可提升模型效果，但会提高训练成本\n- `--placeholder_token`：绑定学习嵌入的特殊词汇（推理时需在提示中使用该词）\n- `--initializer_token`：大致描述训练目标的单字词汇（如物体或风格）\n- `--learnable_property`：训练目标是学习新\"风格\"（如梵高画风）还是\"物体\"（如您的宠物狗）\n\n## 训练脚本\n\n与其他训练脚本不同，textual_inversion.py包含自定义数据集类[`TextualInversionDataset`](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L487)，用于创建数据集。您可以自定义图像尺寸、占位符词汇、插值方法、是否裁剪图像等。如需修改数据集创建方式，可调整`TextualInversionDataset`类。\n\n接下来，在[`main()`](https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L573)函数中可找到数据集预处理代码和训练循环。\n\n脚本首先加载[tokenizer](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L616)、[scheduler和模型](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L622)：\n\n```py\n# 加载tokenizer\nif args.tokenizer_name:\n    tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\nelif args.pretrained_model_name_or_path:\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n# 加载scheduler和模型\nnoise_scheduler = 
DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\ntext_encoder = CLIPTextModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n)\nvae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\nunet = UNet2DConditionModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n)\n```\n\n随后将特殊[占位符词汇](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L632)加入tokenizer，并调整嵌入层以适配新词汇。\n\n接着，脚本通过`TextualInversionDataset`[创建数据集](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L716)：\n\n```py\ntrain_dataset = TextualInversionDataset(\n    data_root=args.train_data_dir,\n    tokenizer=tokenizer,\n    size=args.resolution,\n    placeholder_token=(\" \".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),\n    repeats=args.repeats,\n    learnable_property=args.learnable_property,\n    center_crop=args.center_crop,\n    set=\"train\",\n)\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n)\n```\n\n最后，[训练循环](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L784)处理从预测噪声残差到更新特殊占位符词汇嵌入权重的所有流程。\n\n如需深入了解训练循环工作原理，请参阅[理解管道、模型与调度器](../using-diffusers/write_own_pipeline)教程，该教程解析了去噪过程的基本模式。\n\n## 启动脚本\n\n完成所有修改或确认默认配置后，即可启动训练脚本！🚀\n\n本指南将下载[猫玩具](https://huggingface.co/datasets/diffusers/cat_toy_example)的示例图像并存储在目录中。当然，您也可以创建和使用自己的数据集（参见[创建训练数据集](create_dataset)指南）。\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./cat\"\nsnapshot_download(\n    \"diffusers/cat_toy_example\", local_dir=local_dir, 
repo_type=\"dataset\", ignore_patterns=\".gitattributes\"\n)\n```\n\n设置环境变量`MODEL_NAME`为Hub上的模型ID或本地模型路径，`DATA_DIR`为刚下载的猫图像路径。脚本会将以下文件保存至您的仓库：\n\n- `learned_embeds.bin`：与示例图像对应的学习嵌入向量\n- `token_identifier.txt`：特殊占位符词汇\n- `type_of_concept.txt`：训练概念类型（\"object\"或\"style\"）\n\n> [!WARNING]\n> 在单块V100 GPU上完整训练约需1小时。\n\n启动脚本前还有最后一步。如果想实时观察训练过程，可以定期保存生成图像。在训练命令中添加以下参数：\n\n```bash\n--validation_prompt=\"A <cat-toy> train\"\n--num_validation_images=4\n--validation_steps=100\n```\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"PyTorch\">\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport DATA_DIR=\"./cat\"\n\naccelerate launch textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" \\\n  --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 \\\n  --scale_lr \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=\"textual_inversion_cat\" \\\n  --push_to_hub\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport DATA_DIR=\"./cat\"\n\npython textual_inversion_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" \\\n  --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 \\\n  --scale_lr \\\n  --output_dir=\"textual_inversion_cat\" \\\n  --push_to_hub\n```\n\n</hfoption>\n</hfoptions>\n\n训练完成后，可以像这样使用新模型进行推理：\n\n<hfoptions id=\"training-inference\">\n<hfoption id=\"PyTorch\">\n\n```py\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\npipeline = 
StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16).to(\"cuda\")\npipeline.load_textual_inversion(\"sd-concepts-library/cat-toy\")\nimage = pipeline(\"A <cat-toy> train\", num_inference_steps=50).images[0]\nimage.save(\"cat-train.png\")\n```\n\n</hfoption>\n<hfoption id=\"Flax\">\n\nFlax不支持[`~loaders.TextualInversionLoaderMixin.load_textual_inversion`]方法，但textual_inversion_flax.py脚本会在训练后[保存](https://github.com/huggingface/diffusers/blob/c0f058265161178f2a88849e92b37ffdc81f1dcc/examples/textual_inversion/textual_inversion_flax.py#L636C2-L636C2)学习到的嵌入作为模型的一部分。这意味着您可以像使用其他Flax模型一样进行推理：\n\n```py\nimport jax\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom flax.training.common_utils import shard\nfrom diffusers import FlaxStableDiffusionPipeline\n\nmodel_path = \"path-to-your-trained-model\"\npipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)\n\nprompt = \"A <cat-toy> train\"\nprng_seed = jax.random.PRNGKey(0)\nnum_inference_steps = 50\n\nnum_samples = jax.device_count()\nprompt = num_samples * [prompt]\nprompt_ids = pipeline.prepare_inputs(prompt)\n\n# 分片输入和随机数生成器\nparams = replicate(params)\nprng_seed = jax.random.split(prng_seed, jax.device_count())\nprompt_ids = shard(prompt_ids)\n\nimages = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images\nimages = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))\n# numpy_to_pil 返回PIL图像列表，这里保存第一张\nimages[0].save(\"cat-train.png\")\n```\n\n</hfoption>\n</hfoptions>\n\n## 后续步骤\n\n恭喜您成功训练了自己的文本反转模型！🎉 如需了解更多使用技巧，以下指南可能会有所帮助：\n\n- 学习如何[加载文本反转嵌入](../using-diffusers/loading_adapters)，并将其用作负面嵌入\n- 学习如何将[文本反转](textual_inversion_inference)应用于Stable Diffusion 1/2和Stable Diffusion XL的推理\n"
  },
  {
    "path": "docs/source/zh/training/wuerstchen.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n\n# Wuerstchen\n\n[Wuerstchen](https://hf.co/papers/2306.00637) 模型通过将潜在空间压缩 42 倍，在不影响图像质量的情况下大幅降低计算成本并加速推理。在训练过程中，Wuerstchen 使用两个模型（VQGAN + 自动编码器）来压缩潜在表示，然后第三个模型（文本条件潜在扩散模型）在这个高度压缩的空间上进行条件化以生成图像。\n\n为了将先验模型放入 GPU 内存并加速训练，尝试分别启用 `gradient_accumulation_steps`、`gradient_checkpointing` 和 `mixed_precision`。\n\n本指南探讨 [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) 脚本，帮助您更熟悉它，以及如何根据您的用例进行适配。\n\n在运行脚本之前，请确保从源代码安装库：\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\n然后导航到包含训练脚本的示例文件夹，并安装脚本所需的依赖项：\n\n```bash\ncd examples/wuerstchen/text_to_image\npip install -r requirements.txt\n```\n\n> [!TIP]\n> 🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。\n\n初始化一个 🤗 Accelerate 环境：\n\n```bash\naccelerate config\n```\n\n要设置一个默认的 🤗 Accelerate 环境而不选择任何配置：\n\n```bash\naccelerate config default\n```\n\n或者，如果您的环境不支持交互式 shell，例如笔记本，您可以使用：\n\n```py\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n最后，如果您想在自己的数据集上训练模型，请查看 [创建训练数据集](create_dataset) 指南，了解如何创建与训练脚本兼容的数据集。\n\n> [!TIP]\n> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分，但并未涵盖 
[脚本](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) 的详细信息。如果您有兴趣了解更多，请随时阅读脚本，并告诉我们您是否有任何问题或疑虑。\n\n## 脚本参数\n\n训练脚本提供了许多参数来帮助您自定义训练运行。所有参数及其描述都可以在 [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L192) 函数中找到。它为每个参数提供了默认值，例如训练批次大小和学习率，但如果您愿意，也可以在训练命令中设置自己的值。\n\n例如，要使用 fp16 格式的混合精度加速训练，请在训练命令中添加 `--mixed_precision` 参数：\n\n```bash\naccelerate launch train_text_to_image_prior.py \\\n  --mixed_precision=\"fp16\"\n```\n\n大多数参数与 [文本到图像](text2image#script-parameters) 训练指南中的参数相同，因此让我们直接深入 Wuerstchen 训练脚本！\n\n## 训练脚本\n\n训练脚本也与 [文本到图像](text2image#training-script) 训练指南类似，但已修改以支持 Wuerstchen。本指南重点介绍 Wuerstchen 训练脚本中独特的代码。\n\n[`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L441) 函数首先初始化图像编码器 - 一个 [EfficientNet](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py) - 以及通常的调度器和分词器。\n\n```py\nwith ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n    pretrained_checkpoint_file = hf_hub_download(\"dome272/wuerstchen\", filename=\"model_v2_stage_b.pt\")\n    state_dict = torch.load(pretrained_checkpoint_file, map_location=\"cpu\")\n    image_encoder = EfficientNetEncoder()\n    image_encoder.load_state_dict(state_dict[\"effnet_state_dict\"])\n    image_encoder.eval()\n```\n\n您还将加载 [`WuerstchenPrior`] 模型以进行优化。\n\n```py\nprior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\n\noptimizer = optimizer_cls(\n    prior.parameters(),\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n```\n\n接下来，您将对图像应用一些 
[transforms](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656) 并对标题进行 [tokenize](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L637)：\n\n```py\ndef preprocess_train(examples):\n    images = [image.convert(\"RGB\") for image in examples[image_column]]\n    examples[\"effnet_pixel_values\"] = [effnet_transforms(image) for image in images]\n    examples[\"text_input_ids\"], examples[\"text_mask\"] = tokenize_captions(examples)\n    return examples\n```\n\n最后，[训练循环](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656)负责使用`EfficientNetEncoder`将图像压缩到潜在空间、向潜在表示添加噪声，并使用[`WuerstchenPrior`]模型预测噪声残差。\n\n```py\npred_noise = prior(noisy_latents, timesteps, prompt_embeds)\n```\n\n如果您想了解更多关于训练循环的工作原理，请查看[理解管道、模型和调度器](../using-diffusers/write_own_pipeline)教程，该教程分解了去噪过程的基本模式。\n\n## 启动脚本\n\n一旦您完成了所有更改或对默认配置满意，就可以启动训练脚本了！🚀\n\n设置`DATASET_NAME`环境变量为Hub中的数据集名称。本指南使用[Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions)数据集，但您也可以创建和训练自己的数据集（参见[创建用于训练的数据集](create_dataset)指南）。\n\n> [!TIP]\n> 要使用Weights & Biases监控训练进度，请在训练命令中添加`--report_to=wandb`参数。您还需要在训练命令中添加`--validation_prompts`以跟踪结果。这对于调试模型和查看中间结果非常有用。\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch train_text_to_image_prior.py \\\n  --mixed_precision=\"fp16\" \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=4 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --dataloader_num_workers=4 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot 
naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"wuerstchen-prior-naruto-model\"\n```\n\n训练完成后，您可以使用新训练的模型进行推理！\n\n```py\nimport torch\nfrom diffusers import AutoPipelineForText2Image\nfrom diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"path/to/saved/model\", torch_dtype=torch.float16).to(\"cuda\")\n\ncaption = \"A cute bird naruto holding a shield\"\nimages = pipeline(\n    caption,\n    width=1024,\n    height=1536,\n    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,\n    prior_guidance_scale=4.0,\n    num_images_per_prompt=2,\n).images\n```\n\n## 下一步\n\n恭喜您训练了一个Wuerstchen模型！要了解更多关于如何使用新模型的信息，以下内容可能有所帮助：\n\n- 查看 [Wuerstchen](../api/pipelines/wuerstchen#text-to-image-generation) API 文档，了解更多关于如何使用该管道进行文本到图像生成及其限制的信息。"
  },
  {
    "path": "docs/source/zh/using-diffusers/consisid.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n# ConsisID\n\n[ConsisID](https://github.com/PKU-YuanGroup/ConsisID)是一种身份保持的文本到视频生成模型，其通过频率分解在生成的视频中保持面部一致性。它具有以下特点：\n\n- 基于频率分解：将人物ID特征解耦为高频和低频部分，从频域的角度分析DIT架构的特性，并且基于此特性设计合理的控制信息注入方式。\n\n- 一致性训练策略：我们提出粗到细训练策略、动态掩码损失、动态跨脸损失，进一步提高了模型的泛化能力和身份保持效果。\n\n\n- 推理无需微调：之前的方法在推理前，需要对输入id进行case-by-case微调，时间和算力开销较大，而我们的方法是tuning-free的。\n\n\n本指南将指导您使用 ConsisID 生成身份保持的视频。\n\n## Load Model Checkpoints\n模型权重可以存储在Hub上或本地的单独子文件夹中，在这种情况下，您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。\n\n\n```python\n# !pip install consisid_eva_clip insightface facexlib\nimport torch\nfrom diffusers import ConsisIDPipeline\nfrom diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer\nfrom huggingface_hub import snapshot_download\n\n# Download ckpts\nsnapshot_download(repo_id=\"BestWishYsh/ConsisID-preview\", local_dir=\"BestWishYsh/ConsisID-preview\")\n\n# Load face helper model to preprocess input face image\nface_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models(\"BestWishYsh/ConsisID-preview\", device=\"cuda\", dtype=torch.bfloat16)\n\n# Load consisid base model\npipe = ConsisIDPipeline.from_pretrained(\"BestWishYsh/ConsisID-preview\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n```\n\n## Identity-Preserving 
Text-to-Video\n对于身份保持的文本到视频生成，需要输入文本提示和包含清晰面部（例如，最好是半身或全身）的图像。默认情况下，ConsisID 会生成 720x480 的视频以获得最佳效果。\n\n```python\nfrom diffusers.utils import export_to_video\n\nprompt = \"The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.\"\nimage = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true\"\n\nid_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2, eva_transform_mean, eva_transform_std, face_main_model, \"cuda\", torch.bfloat16, image, is_align_face=True)\n\nvideo = pipe(image=image, prompt=prompt, num_inference_steps=50, guidance_scale=6.0, use_dynamic_cfg=False, id_vit_hidden=id_vit_hidden, id_cond=id_cond, kps_cond=face_kps, generator=torch.Generator(\"cuda\").manual_seed(42))\nexport_to_video(video.frames[0], \"output.mp4\", fps=8)\n```\n<table>\n  <tr>\n    <th style=\"text-align: center;\">Face Image</th>\n    <th style=\"text-align: center;\">Video</th>\n    <th style=\"text-align: center;\">Description</th>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_0.png?download=true\" 
style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_0.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The video, in a beautifully crafted animated style, features a confident woman riding a horse through a lush forest clearing. Her expression is focused yet serene as she adjusts her wide-brimmed hat with a practiced hand. She wears a flowy bohemian dress, which moves gracefully with the rhythm of the horse, the fabric flowing fluidly in the animated motion. The dappled sunlight filters through the trees, casting soft, painterly patterns on the forest floor. Her posture is poised, showing both control and elegance as she guides the horse with ease. The animation's gentle, fluid style adds a dreamlike quality to the scene, with the woman’s calm demeanor and the peaceful surroundings evoking a sense of freedom and harmony.</td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_1.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_1.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The video, in a captivating animated style, shows a woman standing in the center of a snowy forest, her eyes narrowed in concentration as she extends her hand forward. She is dressed in a deep blue cloak, her breath visible in the cold air, which is rendered with soft, ethereal strokes. A faint smile plays on her lips as she summons a wisp of ice magic, watching with focus as the surrounding trees and ground begin to shimmer and freeze, covered in delicate ice crystals. 
The animation’s fluid motion brings the magic to life, with the frost spreading outward in intricate, sparkling patterns. The environment is painted with soft, watercolor-like hues, enhancing the magical, dreamlike atmosphere. The overall mood is serene yet powerful, with the quiet winter air amplifying the delicate beauty of the frozen scene.</td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_2.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_2.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The animation features a whimsical portrait of a balloon seller standing in a gentle breeze, captured with soft, hazy brushstrokes that evoke the feel of a serene spring day. His face is framed by a gentle smile, his eyes squinting slightly against the sun, while a few wisps of hair flutter in the wind. He is dressed in a light, pastel-colored shirt, and the balloons around him sway with the wind, adding a sense of playfulness to the scene. The background blurs softly, with hints of a vibrant market or park, enhancing the light-hearted, yet tender mood of the moment.</td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_3.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_3.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. 
The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.</td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_4.png?download=true\" style=\"height: auto; width: 600px;\"></td>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_output_4.gif?download=true\" style=\"height: auto; width: 2000px;\"></td>\n    <td>The video features a baby wearing a bright superhero cape, standing confidently with arms raised in a powerful pose. The baby has a determined look on their face, with eyes wide and lips pursed in concentration, as if ready to take on a challenge. The setting appears playful, with colorful toys scattered around and a soft rug underfoot, while sunlight streams through a nearby window, highlighting the fluttering cape and adding to the impression of heroism. The overall atmosphere is lighthearted and fun, with the baby's expressions capturing a mix of innocence and an adorable attempt at bravery, as if truly ready to save the day.</td>\n  </tr>\n</table>\n\n## Resources\n\n通过以下资源了解有关 ConsisID 的更多信息：\n\n- 一段 [视频](https://www.youtube.com/watch?v=PhlgC-bI5SQ) 演示了 ConsisID 的主要功能；\n- 有关更多详细信息，请参阅研究论文 [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440)。\n"
  },
  {
    "path": "docs/source/zh/using-diffusers/guiders.md",
    "content": "<!--版权所有 2025 The HuggingFace Team。保留所有权利。\n\n根据 Apache 许可证 2.0 版（\"许可证\"）授权；除非遵守许可证，否则不得使用此文件。\n您可以在以下网址获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，根据许可证分发的软件按\"原样\"分发，不附带任何明示或暗示的担保或条件。请参阅许可证，了解其中关于权限和限制的具体规定。\n-->\n\n# 引导器\n\n[Classifier-free guidance](https://huggingface.co/papers/2207.12598) 引导模型生成更好地匹配提示，通常用于提高生成质量、可控性和对提示的遵循度。引导方法有多种类型，在 Diffusers 中，它们被称为*引导器*。与块类似，可以轻松切换和使用不同的引导器以适应不同的用例，而无需重写管道。\n\n本指南将向您展示如何切换引导器、调整引导器参数，以及将它们加载并共享到 Hub。\n\n## 切换引导器\n\n[`ClassifierFreeGuidance`] 是默认引导器，在使用 [`~ModularPipelineBlocks.init_pipeline`] 初始化管道时创建。它通过 `from_config` 创建，这意味着它不需要从模块化存储库加载规范。引导器不会列在 `modular_model_index.json` 中。\n\n使用 [`~ModularPipeline.get_component_spec`] 来检查引导器。\n\n```py\nt2i_pipeline.get_component_spec(\"guider\")\nComponentSpec(name='guider', type_hint=<class 'diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance'>, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')\n```\n\n通过将新引导器传递给 [`~ModularPipeline.update_components`] 来切换到不同的引导器。\n\n> [!TIP]\n> 更改引导器时会输出一条日志，提示您正在更改引导器类型。\n> ```bash\n> ModularPipeline.update_components: adding guider with new type: PerturbedAttentionGuidance, previous type: ClassifierFreeGuidance\n> ```\n\n```py\nfrom diffusers import LayerSkipConfig, PerturbedAttentionGuidance\n\nconfig = LayerSkipConfig(indices=[2, 9], fqn=\"mid_block.attentions.0.transformer_blocks\", skip_attention=False, skip_attention_scores=True, skip_ff=False)\nguider = PerturbedAttentionGuidance(\n    guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config\n)\nt2i_pipeline.update_components(guider=guider)\n```\n\n再次使用 [`~ModularPipeline.get_component_spec`] 
来验证引导器类型是否不同。\n\n```py\nt2i_pipeline.get_component_spec(\"guider\")\nComponentSpec(name='guider', type_hint=<class 'diffusers.guiders.perturbed_attention_guidance.PerturbedAttentionGuidance'>, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')\n```\n\n## 加载自定义引导器\n\n已经在 Hub 上保存并带有 `modular_model_index.json` 文件的引导器现在被视为 `from_pretrained` 组件，而不是 `from_config` 组件。\n\n```json\n{\n  \"guider\": [\n    null,\n    null,\n    {\n      \"repo\": \"YiYiXu/modular-loader-t2i-guider\",\n      \"revision\": null,\n      \"subfolder\": \"pag_guider\",\n      \"type_hint\": [\n        \"diffusers\",\n        \"PerturbedAttentionGuidance\"\n      ],\n      \"variant\": null\n    }\n  ]\n}\n```\n\n引导器只有在调用 [`~ModularPipeline.load_components`] 之后才会创建，基于 `modular_model_index.json` 中的加载规范。\n\n```py\nt2i_pipeline = t2i_blocks.init_pipeline(\"YiYiXu/modular-doc-guider\")\n# 在初始化时未创建\nassert t2i_pipeline.guider is None\nt2i_pipeline.load_components()\n# 加载为 PAG 引导器\nt2i_pipeline.guider\n```\n\n## 更改引导器参数\n\n引导器参数可以通过 [`~ComponentSpec.create`] 方法以及 [`~ModularPipeline.update_components`] 方法进行调整。下面的示例更改了 `guidance_scale` 值。\n\n```py\nguider_spec = t2i_pipeline.get_component_spec(\"guider\")\nguider = 
guider_spec.create(guidance_scale=10)\nt2i_pipeline.update_components(guider=guider)\n```\n\n## 上传自定义引导器\n\n在自定义引导器上调用 [`~utils.PushToHubMixin.push_to_hub`] 方法，将其分享到 Hub。\n\n```py\nguider.push_to_hub(\"YiYiXu/modular-loader-t2i-guider\", subfolder=\"pag_guider\")\n```\n\n要使此引导器可用于管道，可以修改 `modular_model_index.json` 文件或使用 [`~ModularPipeline.update_components`] 方法。\n\n<hfoptions id=\"upload\">\n<hfoption id=\"modular_model_index.json\">\n\n编辑 `modular_model_index.json` 文件，并添加引导器的加载规范，使其指向包含引导器配置的文件夹。\n\n```json\n{\n  \"guider\": [\n    \"diffusers\",\n    \"PerturbedAttentionGuidance\",\n    {\n      \"repo\": \"YiYiXu/modular-loader-t2i-guider\",\n      \"revision\": null,\n      \"subfolder\": \"pag_guider\",\n      \"type_hint\": [\n        \"diffusers\",\n        \"PerturbedAttentionGuidance\"\n      ],\n      \"variant\": null\n    }\n  ]\n}\n```\n\n</hfoption>\n<hfoption id=\"update_components\">\n\n将 [`~ComponentSpec.default_creation_method`] 更改为 `from_pretrained` 并使用 [`~ModularPipeline.update_components`] 来更新引导器和组件规范以及管道配置。\n\n> [!TIP]\n> 更改创建方法时会输出一条日志，告知您正在将创建类型更改为 `from_pretrained`。\n> ```bash\n> ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained.\n> ```\n\n```py\nguider_spec = t2i_pipeline.get_component_spec(\"guider\")\nguider_spec.default_creation_method = \"from_pretrained\"\nguider_spec.pretrained_model_name_or_path = \"YiYiXu/modular-loader-t2i-guider\"\nguider_spec.subfolder = \"pag_guider\"\npag_guider = guider_spec.load()\nt2i_pipeline.update_components(guider=pag_guider)\n```\n\n要使其成为管道的默认引导器，请调用 [`~utils.PushToHubMixin.push_to_hub`]。这是一个可选步骤，如果您仅在本地进行实验，则不需要。\n\n```py\nt2i_pipeline.push_to_hub(\"YiYiXu/modular-doc-guider\")\n```\n\n</hfoption>\n</hfoptions>\n"
  },
  {
    "path": "docs/source/zh/using-diffusers/helios.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\nthe License. You may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\nspecific language governing permissions and limitations under the License.\n-->\n# Helios\n\n[Helios](https://github.com/PKU-YuanGroup/Helios) 是首个能够在单张 NVIDIA H100 GPU 上以 19.5 FPS 运行的 14B 视频生成模型。它在支持分钟级视频生成的同时，拥有媲美强大基线模型的生成质量，并在统一架构下原生集成了文生视频（T2V）、图生视频（I2V）和视频生视频（V2V）任务。Helios 的主要特性包括：\n\n- 无需常用的防漂移策略（例如：自强制/self-forcing、误差库/error-banks、关键帧采样或逆采样），我们的模型即可生成高质量且高度连贯的分钟级视频。\n- 无需标准的加速技术（例如：KV 缓存、因果掩码、稀疏/线性注意力机制、TinyVAE、渐进式噪声调度、隐藏状态缓存或量化），作为一款 14B 规模的视频生成模型，我们在单张 H100 GPU 上的端到端推理速度便达到了 19.5 FPS。\n- 引入了多项优化方案，在降低显存消耗的同时，显著提升了训练与推理的吞吐量。这些改进使得我们无需借助并行或分片（sharding）等基础设施，即可使用与图像模型相当的批大小（batch sizes）来训练 14B 的视频生成模型。\n\n本指南将引导您完成 Helios 在不同场景下的使用。\n\n## Load Model Checkpoints\n\n模型权重可以存储在Hub上或本地的单独子文件夹中，在这种情况下，您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。\n\n```python\nimport torch\nfrom diffusers import HeliosPipeline, HeliosPyramidPipeline\nfrom huggingface_hub import snapshot_download\n\n# For Best Quality\nsnapshot_download(repo_id=\"BestWishYsh/Helios-Base\", local_dir=\"BestWishYsh/Helios-Base\")\npipe = HeliosPipeline.from_pretrained(\"BestWishYsh/Helios-Base\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# Intermediate Weight\nsnapshot_download(repo_id=\"BestWishYsh/Helios-Mid\", local_dir=\"BestWishYsh/Helios-Mid\")\npipe = HeliosPyramidPipeline.from_pretrained(\"BestWishYsh/Helios-Mid\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# For Best Efficiency\nsnapshot_download(repo_id=\"BestWishYsh/Helios-Distilled\", 
local_dir=\"BestWishYsh/Helios-Distilled\")\npipe = HeliosPyramidPipeline.from_pretrained(\"BestWishYsh/Helios-Distilled\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n```\n\n## Text-to-Video Showcases\n\n<table>\n  <tr>\n    <th style=\"text-align: center;\">Prompt</th>\n    <th style=\"text-align: center;\">Generated Video</th>\n  </tr>\n  <tr>\n    <td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.\n    </small></td>\n    <td>\n      <video width=\"4000\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n  <tr>\n    <td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. 
A close-up shot from a slightly elevated angle.\n    </small></td>\n    <td>\n      <video width=\"4000\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n</table>\n\n## Image-to-Video Showcases\n\n<table>\n  <tr>\n    <th style=\"text-align: center;\">Image</th>\n    <th style=\"text-align: center;\">Prompt</th>\n    <th style=\"text-align: center;\">Generated Video</th>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg\" style=\"height: auto; width: 300px;\"></td>\n    <td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads \"KIA 626,\" and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. 
A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.\n    </small></td>\n    <td>\n      <video width=\"2000\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n  <tr>\n    <td><img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg\" style=\"height: auto; width: 300px;\"></td>\n    <td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. 
A close-up shot from a slightly elevated perspective.\n    </small></td>\n    <td>\n      <video width=\"2000\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n</table>\n\n## Interactive-Video Showcases\n\n<table>\n  <tr>\n    <th style=\"text-align: center;\">Prompt</th>\n    <th style=\"text-align: center;\">Generated Video</th>\n  </tr>\n  <tr>\n    <td><small>The prompt can be found <a href=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt\">here</a></small></td>\n    <td>\n      <video width=\"680\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n  <tr>\n    <td><small>The prompt can be found <a href=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt\">here</a></small></td>\n    <td>\n      <video width=\"680\" controls>\n        <source src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4\" type=\"video/mp4\">\n      </video>\n    </td>\n  </tr>\n</table>\n\n## Resources\n\n通过以下资源了解有关 Helios 的更多信息：\n\n- [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能;\n- 有关更多详细信息，请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/2603.04379)。\n"
  },
  {
    "path": "docs/source/zh/using-diffusers/schedulers.md",
    "content": "<!--Copyright 2025 The HuggingFace Team. All rights reserved.\n\n根据 Apache License 2.0 许可证（以下简称\"许可证\"）授权；\n除非符合许可证要求，否则不得使用本文件。\n您可以通过以下链接获取许可证副本：\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\n除非适用法律要求或书面同意，本软件按\"原样\"分发，\n无任何明示或暗示的担保或条件。详见许可证中关于权限和限制的具体规定。\n-->\n\n# 加载调度器与模型\n\n[[open-in-colab]]\n\nDiffusion管道是由可互换的调度器(schedulers)和模型(models)组成的集合，可通过混合搭配来定制特定用例的流程。调度器封装了整个去噪过程（如去噪步数和寻找去噪样本的算法），其本身不包含可训练参数，因此内存占用极低。模型则主要负责从含噪输入到较纯净样本的前向传播过程。\n\n本指南将展示如何加载调度器和模型来自定义流程。我们将全程使用[stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5)检查点，首先加载基础管道：\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True\n).to(\"cuda\")\n```\n\n通过`pipeline.scheduler`属性可查看当前管道使用的调度器：\n\n```python\npipeline.scheduler\nPNDMScheduler {\n  \"_class_name\": \"PNDMScheduler\",\n  \"_diffusers_version\": \"0.21.4\",\n  \"beta_end\": 0.012,\n  \"beta_schedule\": \"scaled_linear\",\n  \"beta_start\": 0.00085,\n  \"clip_sample\": false,\n  \"num_train_timesteps\": 1000,\n  \"set_alpha_to_one\": false,\n  \"skip_prk_steps\": true,\n  \"steps_offset\": 1,\n  \"timestep_spacing\": \"leading\",\n  \"trained_betas\": null\n}\n```\n\n## 加载调度器\n\n调度器通过配置文件定义，同一配置文件可被多种调度器共享。使用[`SchedulerMixin.from_pretrained`]方法加载时，需指定`subfolder`参数以定位配置文件在仓库中的正确子目录。\n\n例如加载[`DDIMScheduler`]：\n\n```python\nfrom diffusers import DDIMScheduler, DiffusionPipeline\n\nddim = DDIMScheduler.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"scheduler\")\n```\n\n然后将新调度器传入管道：\n\n```python\npipeline = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True\n).to(\"cuda\")\n```\n\n## 
调度器对比\n\n不同调度器各有优劣，难以定量评估哪个最适合您的流程。通常需要在去噪速度与质量之间权衡。我们建议尝试多种调度器以找到最佳方案。通过`pipeline.scheduler.compatibles`属性可查看兼容当前管道的所有调度器。\n\n下面我们使用相同提示词和随机种子，对比[`LMSDiscreteScheduler`]、[`EulerDiscreteScheduler`]、[`EulerAncestralDiscreteScheduler`]和[`DPMSolverMultistepScheduler`]的表现：\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, use_safetensors=True\n).to(\"cuda\")\n\nprompt = \"A photograph of an astronaut riding a horse on Mars, high resolution, high definition.\"\ngenerator = torch.Generator(device=\"cuda\").manual_seed(8)\n```\n\n使用[`~ConfigMixin.from_config`]方法加载不同调度器的配置来切换管道调度器：\n\n<hfoptions id=\"schedulers\">\n<hfoption id=\"LMSDiscreteScheduler\">\n\n[`LMSDiscreteScheduler`]通常能生成比默认调度器更高质量的图像。\n\n```python\nfrom diffusers import LMSDiscreteScheduler\n\npipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n</hfoption>\n<hfoption id=\"EulerDiscreteScheduler\">\n\n[`EulerDiscreteScheduler`]仅需30步即可生成高质量图像。\n\n```python\nfrom diffusers import EulerDiscreteScheduler\n\npipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n</hfoption>\n<hfoption id=\"EulerAncestralDiscreteScheduler\">\n\n[`EulerAncestralDiscreteScheduler`]同样可在30步内生成高质量图像。\n\n```python\nfrom diffusers import EulerAncestralDiscreteScheduler\n\npipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n</hfoption>\n<hfoption id=\"DPMSolverMultistepScheduler\">\n\n[`DPMSolverMultistepScheduler`]在速度与质量间取得平衡，仅需20步即可生成优质图像。\n\n```python\nfrom diffusers import DPMSolverMultistepScheduler\n\npipeline.scheduler = 
DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\nimage = pipeline(prompt, generator=generator).images[0]\nimage\n```\n\n</hfoption>\n</hfoptions>\n\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png\" />\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">LMSDiscreteScheduler</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png\" />\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">EulerDiscreteScheduler</figcaption>\n  </div>\n</div>\n<div class=\"flex gap-4\">\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png\" />\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">EulerAncestralDiscreteScheduler</figcaption>\n  </div>\n  <div>\n    <img class=\"rounded-xl\" src=\"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png\" />\n    <figcaption class=\"mt-2 text-center text-sm text-gray-500\">DPMSolverMultistepScheduler</figcaption>\n  </div>\n</div>\n\n多数生成图像质量相近，实际选择需根据具体场景测试多种调度器进行比较。\n\n### Flax调度器\n\n对比Flax调度器时，需额外将调度器状态加载到模型参数中。例如将[`FlaxStableDiffusionPipeline`]的默认调度器切换为超高效的[`FlaxDPMSolverMultistepScheduler`]：\n\n> [!警告]\n> [`FlaxLMSDiscreteScheduler`]和[`FlaxDDPMScheduler`]目前暂不兼容[`FlaxStableDiffusionPipeline`]。\n\n```python\nimport jax\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom flax.training.common_utils import shard\nfrom diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler\n\nscheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    subfolder=\"scheduler\"\n)\npipeline, 
params = FlaxStableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    scheduler=scheduler,\n    variant=\"bf16\",\n    dtype=jax.numpy.bfloat16,\n)\nparams[\"scheduler\"] = scheduler_state\n```\n\n利用Flax对TPU的兼容性实现并行图像生成。需为每个设备复制模型参数，并分配输入数据：\n\n```python\n# 每个并行设备生成1张图像（TPUv2-8/TPUv3-8支持8设备并行）\nprompt = \"一张宇航员在火星上骑马的高清照片，高分辨率，高画质。\"\nnum_samples = jax.device_count()\nprompt_ids = pipeline.prepare_inputs([prompt] * num_samples)\n\nprng_seed = jax.random.PRNGKey(0)\nnum_inference_steps = 25\n\n# 分配输入和随机种子\nparams = replicate(params)\nprng_seed = jax.random.split(prng_seed, jax.device_count())\nprompt_ids = shard(prompt_ids)\n\nimages = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images\nimages = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))\n```\n\n## 模型加载\n\n通过[`ModelMixin.from_pretrained`]方法加载模型，该方法会下载并缓存模型权重和配置的最新版本。若本地缓存已存在最新文件，则直接复用缓存而非重复下载。\n\n通过`subfolder`参数可从子目录加载模型。例如[stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5)的模型权重存储在[unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet)子目录中：\n\n```python\nfrom diffusers import UNet2DConditionModel\n\nunet = UNet2DConditionModel.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"unet\", use_safetensors=True)\n```\n\n也可直接从[仓库](https://huggingface.co/google/ddpm-cifar10-32/tree/main)加载：\n\n```python\nfrom diffusers import UNet2DModel\n\nunet = UNet2DModel.from_pretrained(\"google/ddpm-cifar10-32\", use_safetensors=True)\n```\n\n加载和保存模型变体时，需在[`ModelMixin.from_pretrained`]和[`ModelMixin.save_pretrained`]中指定`variant`参数：\n\n```python\nfrom diffusers import UNet2DConditionModel\n\nunet = UNet2DConditionModel.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"unet\", variant=\"non_ema\", use_safetensors=True\n)\nunet.save_pretrained(\"./local-unet\", 
variant=\"non_ema\")\n```\n\n使用[`~ModelMixin.from_pretrained`]的`torch_dtype`参数指定模型加载精度：\n\n```python\nfrom diffusers import AutoModel\n\nunet = AutoModel.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", subfolder=\"unet\", torch_dtype=torch.float16\n)\n```\n\n也可使用[torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html)方法即时转换精度，但会转换所有权重（不同于`torch_dtype`参数会保留`_keep_in_fp32_modules`中的层）。这对某些必须保持fp32精度的层尤为重要（参见[示例](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)）。\n"
  },
  {
    "path": "examples/README.md",
    "content": "<!---\nCopyright 2025 The HuggingFace Team. All rights reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# 🧨 Diffusers Examples\n\nDiffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library\nfor a variety of use cases involving training or fine-tuning.\n\n**Note**: If you are looking for **official** examples on how to use `diffusers` for inference, please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).\n\nOur examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.\nMore specifically, this means:\n\n- **Self-contained**: An example script shall only depend on \"pip-install-able\" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt) and execute the example script.\n- **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. 
It is expected that they won't work out-of-the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop to allow you to tweak and edit them as required.\n- **Beginner-friendly**: We do not aim for providing state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners.\n- **One-purpose-only**: Examples should show one task and one task only. Even if a task is from a modeling point of view very similar, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy-to-understand as possible.\n\nWe provide **official** examples that cover the most popular tasks of diffusion models.\n*Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.\nIf you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you!\n\nTraining examples show how to pretrain or fine-tune diffusion models for a variety of tasks. 
Currently we support:\n\n| Task | 🤗 Accelerate | 🤗 Datasets | Colab\n|---|---|:---:|:---:|\n| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)\n| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |\n| [**Textual Inversion**](./textual_inversion) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)\n| [**Dreambooth**](./dreambooth) | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)\n| [**ControlNet**](./controlnet) | ✅ | ✅ | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)\n| [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/InstructPix2Pix_using_diffusers.ipynb)\n| [**Reinforcement Learning for Control**](./reinforcement_learning)                    | - | - | [Notebook1](https://github.com/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_for_control.ipynb),  [Notebook2](https://github.com/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)\n\n## Community\n\nIn addition, we provide **community** examples, which are examples added and maintained by our community.\nCommunity examples can consist of both *training* examples or *inference* pipelines.\nFor such examples, we are more lenient regarding the philosophy defined above and also cannot guarantee to provide maintenance for every issue.\nExamples that are useful for the community, but are either not yet deemed popular or not yet 
following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.\n**Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄.\n\n## Research Projects\n\nWe also provide **research_projects** examples that are maintained by the community as defined in the respective research project folders. These examples are useful and offer the extended capabilities which are complementary to the official examples. You may refer to [research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) for details.\n\n## Important note\n\nTo make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\nThen cd in the example folder of your choice and run\n```bash\npip install -r requirements.txt\n```\n"
  },
  {
    "path": "examples/advanced_diffusion_training/README.md",
    "content": "# Advanced diffusion training examples\n\n## Train Dreambooth LoRA with Stable Diffusion XL\n> [!TIP]\n> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.\n\nLoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*\nIn a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:\n- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)\n- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.\n[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in\nthe popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.\n\nThe `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py`, with\nadvanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), 
[Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://huggingface.co/papers/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),\n[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️\n\n> [!NOTE]\n> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳\n> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)\n\n📚 Read more about the advanced features and best practices in this community derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)\n\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/advanced_diffusion_training` folder and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\nLastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:\n```bash\nhf auth login\n```\nThis command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.\n\n> [!NOTE]\n> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:\n> `pip install wandb`\n> Alternatively, you can use other tools / train without reporting by modifying the flag  `--report_to=\"wandb\"`.\n\n### Pivotal Tuning\n**Training with text encoder(s)**\n\nAlongside the UNet, LoRA fine-tuning of the text encoders is also supported. 
In addition to the text encoder optimization\navailable with `train_dreambooth_lora_sdxl_advanced.py`, **pivotal tuning** is also supported.\n[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -\nwe insert new tokens into the text encoders of the model, instead of reusing existing ones.\nWe then optimize the newly-inserted token embeddings to represent the new concept.\n\nTo do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).\nPlease keep the following points in mind:\n\n* SDXL has two text encoders. So, we fine-tune both using LoRA.\n* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.\n\n### 3D icon example\n\nNow let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./3d_icon\"\nsnapshot_download(\n    \"LinoyTsaban/3d_icon\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nLet's review some of the advanced features we're going to be using for this example:\n- **custom captions**:\nTo use custom captioning, first ensure that you have the `datasets` library installed; if not, install it with:\n```bash\npip install datasets\n```\n\nNow we'll simply specify the name of the dataset and caption column (in this case it's \"prompt\")\n\n```\n--dataset_name=./3d_icon\n--caption_column=prompt\n```\n\nYou can also load a dataset straight from the Hub by specifying its name in `dataset_name`.\nLook [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset.\n\n- **optimizer**: for this example, 
we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer\n  - To use Prodigy, please make sure to install the `prodigyopt` library: `pip install prodigyopt`\n- **pivotal tuning**\n- **min SNR gamma**\n\n**Now, we can launch training:**\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport DATASET_NAME=\"./3d_icon\"\nexport OUTPUT_DIR=\"3d-icon-SDXL-LoRA\"\nexport VAE_PATH=\"madebyollin/sdxl-vae-fp16-fix\"\n\naccelerate launch train_dreambooth_lora_sdxl_advanced.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --pretrained_vae_model_name_or_path=$VAE_PATH \\\n  --dataset_name=$DATASET_NAME \\\n  --instance_prompt=\"3d icon in the style of TOK\" \\\n  --validation_prompt=\"a TOK icon of an astronaut riding a horse, in the style of TOK\" \\\n  --output_dir=$OUTPUT_DIR \\\n  --caption_column=\"prompt\" \\\n  --mixed_precision=\"bf16\" \\\n  --resolution=1024 \\\n  --train_batch_size=3 \\\n  --repeats=1 \\\n  --report_to=\"wandb\" \\\n  --gradient_accumulation_steps=1 \\\n  --gradient_checkpointing \\\n  --learning_rate=1.0 \\\n  --text_encoder_lr=1.0 \\\n  --optimizer=\"prodigy\" \\\n  --train_text_encoder_ti \\\n  --train_text_encoder_ti_frac=0.5 \\\n  --snr_gamma=5.0 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --rank=8 \\\n  --max_train_steps=1000 \\\n  --checkpointing_steps=2000 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.\n* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\nOur experiments were conducted on a single 40GB A100 GPU.\n\n\n### Inference\n\nOnce training is done, we can perform inference like so:\n1. starting with loading the unet lora weights\n```python\nimport torch\nfrom huggingface_hub import hf_hub_download, upload_file\nfrom diffusers import DiffusionPipeline\nfrom diffusers.models import AutoencoderKL\nfrom safetensors.torch import load_file\n\nusername = \"linoyts\"\nrepo_id = f\"{username}/3d-icon-SDXL-LoRA\"\n\npipe = DiffusionPipeline.from_pretrained(\n        \"stabilityai/stable-diffusion-xl-base-1.0\",\n        torch_dtype=torch.float16,\n        variant=\"fp16\",\n).to(\"cuda\")\n\n\npipe.load_lora_weights(repo_id, weight_name=\"pytorch_lora_weights.safetensors\")\n```\n2. now we load the pivotal tuning embeddings\n\n```python\ntext_encoders = [pipe.text_encoder, pipe.text_encoder_2]\ntokenizers = [pipe.tokenizer, pipe.tokenizer_2]\n\nembedding_path = hf_hub_download(repo_id=repo_id, filename=\"3d-icon-SDXL-LoRA_emb.safetensors\", repo_type=\"model\")\n\nstate_dict = load_file(embedding_path)\n# load embeddings of text_encoder 1 (CLIP ViT-L/14)\npipe.load_textual_inversion(state_dict[\"clip_l\"], token=[\"<s0>\", \"<s1>\"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)\n# load embeddings of text_encoder 2 (CLIP ViT-G/14)\npipe.load_textual_inversion(state_dict[\"clip_g\"], token=[\"<s0>\", \"<s1>\"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)\n```\n\n3. 
let's generate images\n\n```python\ninstance_token = \"<s0><s1>\"\nprompt = f\"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}\"\n\nimage = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={\"scale\": 1.0}).images[0]\nimage.save(\"llama.png\")\n```\n\n### Comfy UI / AUTOMATIC1111 Inference\nThe new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!\n\n**AUTOMATIC1111 / SD.Next** \\\nIn AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.\n- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.\n- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.\n\nYou can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.\n\n**ComfyUI** \\\nIn ComfyUI we will load a LoRA and a textual embedding at the same time.\n- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)\n- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. 
[Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).\n\n### Specifying a better VAE\n\nSDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, namely `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).\n\n### DoRA training\nThe advanced script supports DoRA training too!\n> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://huggingface.co/papers/2402.09353),\n**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction**, and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.\nThe authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.\n\n> [!NOTE]\n> 💡DoRA training is still _experimental_\n> and is likely to require different hyperparameter values to perform best compared to a LoRA.\n> Specifically, we've noticed 2 differences to take into account in your training:\n> 1. **LoRA seems to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may work well for a DoRA)\n> 2. **DoRA quality is superior to LoRA, especially at lower ranks**: the difference in quality between DoRA of rank 8 and LoRA of rank 8 appears to be more significant than when training ranks of 32 or 64, for example.\n> This is also aligned with some of the quantitative analysis shown in the paper.\n\n**Usage**\n1. To use DoRA you need to install `peft` from main:\n```bash\npip install git+https://github.com/huggingface/peft.git\n```\n2. 
Enable DoRA training by adding this flag\n```bash\n--use_dora\n```\n**Inference**\nThe inference is the same as if you train a regular LoRA 🤗\n\n## Conducting EDM-style training\n\nIt's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364).\n\nsimply set:\n\n```diff\n+  --do_edm_style_training \\\n```\n\nOther SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. Below is an example command:\n\n```bash\naccelerate launch train_dreambooth_lora_sdxl_advanced.py \\\n  --pretrained_model_name_or_path=\"playgroundai/playground-v2.5-1024px-aesthetic\"  \\\n  --dataset_name=\"linoyts/3d_icon\" \\\n  --instance_prompt=\"3d icon in the style of TOK\" \\\n  --validation_prompt=\"a TOK icon of an astronaut riding a horse, in the style of TOK\" \\\n  --output_dir=\"3d-icon-SDXL-LoRA\" \\\n  --do_edm_style_training \\\n  --caption_column=\"prompt\" \\\n  --mixed_precision=\"bf16\" \\\n  --resolution=1024 \\\n  --train_batch_size=3 \\\n  --repeats=1 \\\n  --report_to=\"wandb\"\\\n  --gradient_accumulation_steps=1 \\\n  --gradient_checkpointing \\\n  --learning_rate=1.0 \\\n  --text_encoder_lr=1.0 \\\n  --optimizer=\"prodigy\"\\\n  --train_text_encoder_ti\\\n  --train_text_encoder_ti_frac=0.5\\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --rank=8 \\\n  --max_train_steps=1000 \\\n  --checkpointing_steps=2000 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n> [!CAUTION]\n> Min-SNR gamma is not supported with the EDM-style training yet. 
When training with the PlaygroundAI model, we recommend not passing any \"variant\".\n\n### B-LoRA training\nThe advanced script now supports B-LoRA training too!\n> Proposed in [Implicit Style-Content Separation using B-LoRA](https://huggingface.co/papers/2403.14572),\nB-LoRA is a method that leverages LoRA to implicitly separate the style and content components of a **single** image.\nIt was shown that learning the LoRA weights of two specific blocks (referred to as B-LoRAs)\nachieves style-content separation that cannot be achieved by training each B-LoRA independently.\nOnce trained, the two B-LoRAs can be used as independent components to allow various image stylization tasks.\n\n**Usage**\nEnable B-LoRA training by adding this flag:\n```bash\n--use_blora\n```\nYou can train a B-LoRA with as little as 1 image and 1000 steps. Try this default configuration as a start:\n```bash\naccelerate launch train_dreambooth_b-lora_sdxl.py \\\n --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-xl-base-1.0\" \\\n --instance_data_dir=\"linoyts/B-LoRA_teddy_bear\" \\\n --output_dir=\"B-LoRA_teddy_bear\" \\\n --instance_prompt=\"a [v18]\" \\\n --resolution=1024 \\\n --rank=64 \\\n --train_batch_size=1 \\\n --learning_rate=5e-5 \\\n --lr_scheduler=\"constant\" \\\n --lr_warmup_steps=0 \\\n --max_train_steps=1000 \\\n --checkpointing_steps=2000 \\\n --seed=\"0\" \\\n --gradient_checkpointing \\\n --mixed_precision=\"fp16\"\n```\n**Inference**\nThe inference is a bit different:\n1. we need to load *specific* unet layers (as opposed to a regular LoRA/DoRA)\n2. the trained layers we load change based on our objective (e.g. 
style/content)\n\n```python\nimport torch\nfrom diffusers import StableDiffusionXLPipeline, AutoencoderKL\n\n# taken & modified from B-LoRA repo - https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py\ndef is_belong_to_blocks(key, blocks):\n    try:\n        for g in blocks:\n            if g in key:\n                return True\n        return False\n    except Exception as e:\n        raise type(e)(f'failed to is_belong_to_block, due to: {e}')\n\ndef lora_lora_unet_blocks(lora_path, alpha, target_blocks):\n  state_dict, _ = pipeline.lora_state_dict(lora_path)\n  filtered_state_dict = {k: v * alpha for k, v in state_dict.items() if is_belong_to_blocks(k, target_blocks)}\n  return filtered_state_dict\n\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    vae=vae,\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\n# pick a blora for content/style (you can also set one to None)\ncontent_B_lora_path  = \"lora-library/B-LoRA-teddybear\"\nstyle_B_lora_path= \"lora-library/B-LoRA-pen_sketch\"\n\n\ncontent_B_LoRA = lora_lora_unet_blocks(content_B_lora_path,alpha=1,target_blocks=[\"unet.up_blocks.0.attentions.0\"])\nstyle_B_LoRA = lora_lora_unet_blocks(style_B_lora_path,alpha=1.1,target_blocks=[\"unet.up_blocks.0.attentions.1\"])\ncombined_lora = {**content_B_LoRA, **style_B_LoRA}\n\n# Load both loras\npipeline.load_lora_into_unet(combined_lora, None, pipeline.unet)\n\n#generate\nprompt = \"a [v18] in [v30] style\"\npipeline(prompt, num_images_per_prompt=4).images\n```\n### LoRA training of Targeted U-net Blocks\nThe advanced script now supports custom choice of U-net blocks to train during Dreambooth LoRA tuning.\n> [!NOTE]\n> This feature is still experimental\n\n> Recently, works like B-LoRA showed the potential advantages of learning the LoRA weights of specific U-net blocks, not only in speed & 
memory,\n> but also in reducing the amount of needed data, improving style manipulation and overcoming overfitting issues.\n> In light of this, we're introducing a new feature to the advanced script to allow for configurable U-net learned blocks.\n\n**Usage**\nConfigure LoRA learned U-net blocks adding a `lora_unet_blocks` flag, with a comma separated string specifying the targeted blocks.\ne.g:\n```bash\n--lora_unet_blocks=\"unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1\"\n```\n\n> [!NOTE]\n> if you specify both `--use_blora` and `--lora_unet_blocks`, values given in --lora_unet_blocks will be ignored.\n> When enabling --use_blora, targeted U-net blocks are automatically set to be \"unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1\" as discussed in the paper.\n> If you wish to experiment with different blocks, specify `--lora_unet_blocks` only.\n\n**Inference**\nInference is the same as for B-LoRAs, except the input targeted blocks should be modified based on your training configuration.\n```python\nimport torch\nfrom diffusers import StableDiffusionXLPipeline, AutoencoderKL\n\n# taken & modified from B-LoRA repo - https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py\ndef is_belong_to_blocks(key, blocks):\n    try:\n        for g in blocks:\n            if g in key:\n                return True\n        return False\n    except Exception as e:\n        raise type(e)(f'failed to is_belong_to_block, due to: {e}')\n\ndef lora_lora_unet_blocks(lora_path, alpha, target_blocks):\n  state_dict, _ = pipeline.lora_state_dict(lora_path)\n  filtered_state_dict = {k: v * alpha for k, v in state_dict.items() if is_belong_to_blocks(k, target_blocks)}\n  return filtered_state_dict\n\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    vae=vae,\n    
torch_dtype=torch.float16,\n).to(\"cuda\")\n\nlora_path = \"lora-library/B-LoRA-pen_sketch\"\n\nstate_dict = lora_lora_unet_blocks(lora_path,alpha=1,target_blocks=[\"unet.up_blocks.0.attentions.0\"])\n\n# Load trained lora layers into the unet\npipeline.load_lora_into_unet(state_dict, None, pipeline.unet)\n\n#generate\nprompt = \"a dog in [v30] style\"\npipeline(prompt, num_images_per_prompt=4).images\n```\n\n\n### Tips and Tricks\nCheck out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)\n\n## Running on Colab Notebook\nCheck out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb)\nto train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free colab, using some of the advanced features (excluding pivotal tuning).\n\n"
  },
  {
    "path": "examples/advanced_diffusion_training/README_flux.md",
    "content": "# Advanced diffusion training examples\n\n## Train Dreambooth LoRA with Flux.1 Dev\n> [!TIP]\n> 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). \n> As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text-to-image models like flux, stable diffusion given just a few(3~5) images of a subject.\n\nLoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*\nIn a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. 
This has a couple of advantages:\n- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)\n- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.\n[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in\nthe popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.\n\nThe `train_dreambooth_lora_flux_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_flux.py`, with\nadvanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://huggingface.co/papers/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),\n[ostris](https://x.com/ostrisai):[ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira):[SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️\n\n> [!NOTE]\n> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳\n> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning 
blog](https://huggingface.co/blog/lora)\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/advanced_diffusion_training` folder and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\nLastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:\n```bash\nhf auth login\n```\nThis command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.\n\n> [!NOTE]\n> In the examples below we use `wandb` to document the training runs. 
To do the same, make sure to install `wandb`:\n> `pip install wandb`\n> Alternatively, you can use other tools / train without reporting by modifying the flag  `--report_to=\"wandb\"`.\n\n### LoRA Rank and Alpha\nTwo key LoRA hyperparameters are LoRA rank and LoRA alpha. \n- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).\n- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.\n- lora_alpha vs. rank:\nThis ratio dictates the LoRA's effective strength:\nlora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)\nlora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)\nlora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)\n\n> [!TIP]\n> A common starting point is to set `lora_alpha` equal to `rank`. \n> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16) \n> to give the LoRA updates more influence without increasing parameter count. \n> If you find your LoRA is \"overcooking\" or learning too aggressively, consider setting `lora_alpha` to half of `rank` \n> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.\n\n\n### Target Modules\nWhen LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. \nMore recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore \napplying LoRA training onto different types of layers and blocks. 
To allow more flexibility and control over the targeted modules we added `--lora_layers`, in which you can specify, as a comma-separated string,\nthe exact modules for LoRA training. Here are some examples of target modules you can provide: \n- for attention only layers: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0\"`\n- to train the same modules as in the fal trainer: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2\"`\n- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out\"`\n\n> [!NOTE]\n> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:\n> - **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. `single_transformer_blocks.i.attn.to_k`\n> - **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. `transformer_blocks.i.attn.to_k`\n\n> [!NOTE]\n> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.\n\n### Pivotal Tuning (and more)\n**Training with text encoder(s)**\n\nAlongside the Transformer, LoRA fine-tuning of the text encoders is also supported. 
In addition to the text encoder optimization\navailable with `train_dreambooth_lora_flux.py`, in the advanced script **pivotal tuning** is also supported.\n[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -\nwe insert new tokens into the text encoders of the model, instead of reusing existing ones.\nWe then optimize the newly-inserted token embeddings to represent the new concept.\n\nTo do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).\nPlease keep the following points in mind:\n\n* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2). By default, `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only.\nTo activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`. \n* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.\n* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion we introduce `--train_transformer_frac`, which controls the fraction of epochs in which the transformer LoRA layers are trained. By default, `--train_transformer_frac=1`; to trigger a textual inversion run, set `--train_transformer_frac=0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment with different settings and share the results!\n* **token initializer** - similar to the original textual inversion work, you can specify a concept of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the newly inserted tokens are initialized randomly. 
You can specify a token in `--initializer_concept` such that the starting point for the trained embeddings will be the embeddings associated with your chosen `--initializer_concept`.\n\n## Training examples\n\nNow let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./3d_icon\"\nsnapshot_download(\n    \"LinoyTsaban/3d_icon\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nLet's review some of the advanced features we're going to be using for this example:\n- **custom captions**:\nTo use custom captioning, first make sure the `datasets` library is installed. If it is not, install it with\n```bash\npip install datasets\n```\n\nNow we'll simply specify the name of the dataset and caption column (in this case it's \"prompt\")\n\n```\n--dataset_name=./3d_icon\n--caption_column=prompt\n```\n\nYou can also load a dataset straight from the Hub by specifying its name in `dataset_name`.\nLook [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset.\n\n- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer \n    - To use Prodigy, please make sure to install the prodigyopt library: `pip install prodigyopt`\n- **pivotal tuning**\n\n### Example #1: Pivotal tuning\n**Now, we can launch training:**\n\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\nexport DATASET_NAME=\"./3d_icon\"\nexport OUTPUT_DIR=\"3d-icon-Flux-LoRA\"\n\naccelerate launch train_dreambooth_lora_flux_advanced.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --instance_prompt=\"3d icon in the style of TOK\" \\\n  
--output_dir=$OUTPUT_DIR \\\n  --caption_column=\"prompt\" \\\n  --mixed_precision=\"bf16\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --repeats=1 \\\n  --report_to=\"wandb\" \\\n  --gradient_accumulation_steps=1 \\\n  --gradient_checkpointing \\\n  --learning_rate=1.0 \\\n  --text_encoder_lr=1.0 \\\n  --optimizer=\"prodigy\" \\\n  --train_text_encoder_ti \\\n  --train_text_encoder_ti_frac=0.5 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --rank=8 \\\n  --max_train_steps=700 \\\n  --checkpointing_steps=2000 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nTo better track our training experiments, the following flags are useful:\n\n* `--report_to=\"wandb\"` (set in the command above) will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.\n* `--validation_prompt` and `--validation_epochs` (not set in the command above) allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.\n\nOur experiments were conducted on a single 40GB A100 GPU.\n\n### Example #2: Pivotal tuning with T5\nNow let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize\nthe T5 embeddings as well. 
We can do this by simply adding `--enable_t5_ti` to the previous configuration:\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\nexport DATASET_NAME=\"./3d_icon\"\nexport OUTPUT_DIR=\"3d-icon-Flux-LoRA\"\n\naccelerate launch train_dreambooth_lora_flux_advanced.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --instance_prompt=\"3d icon in the style of TOK\" \\\n  --output_dir=$OUTPUT_DIR \\\n  --caption_column=\"prompt\" \\\n  --mixed_precision=\"bf16\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --repeats=1 \\\n  --report_to=\"wandb\"\\\n  --gradient_accumulation_steps=1 \\\n  --gradient_checkpointing \\\n  --learning_rate=1.0 \\\n  --text_encoder_lr=1.0 \\\n  --optimizer=\"prodigy\"\\\n  --train_text_encoder_ti\\\n  --enable_t5_ti\\\n  --train_text_encoder_ti_frac=0.5\\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --rank=8 \\\n  --max_train_steps=700 \\\n  --checkpointing_steps=2000 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n### Example #3: Textual Inversion\nTo explore a pure textual inversion - i.e. only optimizing the text embeddings w/o training transformer LoRA layers, we \ncan set the value for `--train_transformer_frac` - which is responsible for the percent of epochs in which the transformer is \ntrained. 
By setting `--train_transformer_frac == 0` and enabling `--train_text_encoder_ti` we trigger a textual inversion train \nrun.\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\nexport DATASET_NAME=\"./3d_icon\"\nexport OUTPUT_DIR=\"3d-icon-Flux-LoRA\"\n\naccelerate launch train_dreambooth_lora_flux_advanced.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --instance_prompt=\"3d icon in the style of TOK\" \\\n  --output_dir=$OUTPUT_DIR \\\n  --caption_column=\"prompt\" \\\n  --mixed_precision=\"bf16\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --repeats=1 \\\n  --report_to=\"wandb\"\\\n  --gradient_accumulation_steps=1 \\\n  --gradient_checkpointing \\\n  --learning_rate=1.0 \\\n  --text_encoder_lr=1.0 \\\n  --optimizer=\"prodigy\"\\\n  --train_text_encoder_ti\\\n  --enable_t5_ti\\\n  --train_text_encoder_ti_frac=0.5\\\n  --train_transformer_frac=0\\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --rank=8 \\\n  --max_train_steps=700 \\\n  --checkpointing_steps=2000 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n### Inference - pivotal tuning\n\nOnce training is done, we can perform inference like so:\n1. starting with loading the transformer lora weights\n```python\nimport torch\nfrom huggingface_hub import hf_hub_download, upload_file\nfrom diffusers import AutoPipelineForText2Image\nfrom safetensors.torch import load_file\n\nusername = \"linoyts\"\nrepo_id = f\"{username}/3d-icon-Flux-LoRA\"\n\npipe = AutoPipelineForText2Image.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16).to('cuda')\n\n\npipe.load_lora_weights(repo_id, weight_name=\"pytorch_lora_weights.safetensors\")\n```\n2. now we load the pivotal tuning embeddings \n> [!NOTE] #1 if `--enable_t5_ti` wasn't passed, we only load the embeddings to the CLIP encoder.\n\n> [!NOTE] #2 the number of tokens (i.e. 
<s0>,...,<si>) is either determined by `--num_new_tokens_per_abstraction` or by `--initializer_concept`. Make sure to update inference code accordingly :)\n```python\ntext_encoders = [pipe.text_encoder, pipe.text_encoder_2]\ntokenizers = [pipe.tokenizer, pipe.tokenizer_2]\n\nembedding_path = hf_hub_download(repo_id=repo_id, filename=\"3d-icon-Flux-LoRA_emb.safetensors\", repo_type=\"model\")\n\nstate_dict = load_file(embedding_path)\n# load embeddings of text_encoder 1 (CLIP ViT-L/14)\npipe.load_textual_inversion(state_dict[\"clip_l\"], token=[\"<s0>\", \"<s1>\"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)\n# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`\npipe.load_textual_inversion(state_dict[\"t5\"], token=[\"<s0>\", \"<s1>\"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)\n```\n\n3. let's generate images\n\n```python\ninstance_token = \"<s0><s1>\"\nprompt = f\"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}\"\n\nimage = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={\"scale\": 1.0}).images[0]\nimage.save(\"llama.png\")\n```\n\n### Inference - pure textual inversion\nIn this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion train run is a\n`.safetensors` file containing the trained embeddings for the new tokens either for the CLIP encoder, or for both encoders (CLIP and T5) \n\n1. 
starting with loading the embeddings.\n💡note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder\n\n```python\nimport torch\nfrom huggingface_hub import hf_hub_download, upload_file\nfrom diffusers import AutoPipelineForText2Image\nfrom safetensors.torch import load_file\n\nusername = \"linoyts\"\nrepo_id = f\"{username}/3d-icon-Flux-LoRA\"\n\npipe = AutoPipelineForText2Image.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16).to('cuda')\n\ntext_encoders = [pipe.text_encoder, pipe.text_encoder_2]\ntokenizers = [pipe.tokenizer, pipe.tokenizer_2]\n\nembedding_path = hf_hub_download(repo_id=repo_id, filename=\"3d-icon-Flux-LoRA_emb.safetensors\", repo_type=\"model\")\n\nstate_dict = load_file(embedding_path)\n# load embeddings of text_encoder 1 (CLIP ViT-L/14)\npipe.load_textual_inversion(state_dict[\"clip_l\"], token=[\"<s0>\", \"<s1>\"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)\n# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`\npipe.load_textual_inversion(state_dict[\"t5\"], token=[\"<s0>\", \"<s1>\"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)\n```\n2. let's generate images\n\n```python\ninstance_token = \"<s0><s1>\"\nprompt = f\"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}\"\n\n# the Flux pipeline takes attention kwargs as `joint_attention_kwargs`\nimage = pipe(prompt=prompt, num_inference_steps=25, joint_attention_kwargs={\"scale\": 1.0}).images[0]\nimage.save(\"llama.png\")\n```\n\n### Comfy UI / AUTOMATIC1111 Inference\nThe new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!\n\n**AUTOMATIC1111 / SD.Next** \\\nIn AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.\n- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. 
You can then include it in your `models/Lora` directory.\n- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.\n\nYou can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.\n\n**ComfyUI** \\\nIn ComfyUI we will load a LoRA and a textual embedding at the same time.\n- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)\n- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).\n"
  },
  {
    "path": "examples/advanced_diffusion_training/requirements.txt",
    "content": "accelerate>=0.31.0\ntorchvision\ntransformers>=4.41.2\nftfy\ntensorboard\nJinja2\npeft>=0.11.1\nsentencepiece"
  },
  {
    "path": "examples/advanced_diffusion_training/requirements_flux.txt",
    "content": "accelerate>=0.31.0\ntorchvision\ntransformers>=4.41.2\nftfy\ntensorboard\nJinja2\npeft>=0.11.1\nsentencepiece"
  },
  {
    "path": "examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\nfrom diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"photo\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-flux-pipe\"\n    script_path = \"examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py\"\n\n    def test_dreambooth_lora_flux(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n          
      --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_text_encoder_flux(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --train_text_encoder\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            
# make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            starts_with_expected_prefix = all(\n                (key.startswith(\"transformer\") or key.startswith(\"text_encoder\")) for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_expected_prefix)\n\n    def test_dreambooth_lora_pivotal_tuning_flux_clip(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --train_text_encoder_ti\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n            # make sure embeddings were also saved\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, f\"{os.path.basename(tmpdir)}_emb.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in 
lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # make sure the state_dict has the correct naming in the parameters.\n            textual_inversion_state_dict = safetensors.torch.load_file(\n                os.path.join(tmpdir, f\"{os.path.basename(tmpdir)}_emb.safetensors\")\n            )\n            is_clip = all(\"clip_l\" in k for k in textual_inversion_state_dict.keys())\n            self.assertTrue(is_clip)\n\n            # when performing pivotal tuning, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --train_text_encoder_ti\n                --enable_t5_ti\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n            # make sure embeddings were also saved\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, f\"{os.path.basename(tmpdir)}_emb.safetensors\")))\n\n            # 
make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # make sure the state_dict has the correct naming in the parameters.\n            textual_inversion_state_dict = safetensors.torch.load_file(\n                os.path.join(tmpdir, f\"{os.path.basename(tmpdir)}_emb.safetensors\")\n            )\n            is_te = all((\"clip_l\" in k or \"t5\" in k) for k in textual_inversion_state_dict.keys())\n            self.assertTrue(is_te)\n\n            # when performing pivotal tuning, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, 
\"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n     
       --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n    def test_dreambooth_lora_with_metadata(self):\n        # Use a `lora_alpha` that is different from `rank`.\n        lora_alpha = 8\n        rank = 4\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --lora_alpha={lora_alpha}\n                --rank={rank}\n                --learning_rate 5.0e-04\n             
   --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            state_dict_file = os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")\n            self.assertTrue(os.path.isfile(state_dict_file))\n\n            # Check if the metadata was properly serialized.\n            with safetensors.torch.safe_open(state_dict_file, framework=\"pt\", device=\"cpu\") as f:\n                metadata = f.metadata() or {}\n\n            metadata.pop(\"format\", None)\n            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)\n            if raw:\n                raw = json.loads(raw)\n\n            loaded_lora_alpha = raw[\"transformer.lora_alpha\"]\n            self.assertTrue(loaded_lora_alpha == lora_alpha)\n            loaded_lora_rank = raw[\"transformer.r\"]\n            self.assertTrue(loaded_lora_rank == rank)\n"
  },
  {
    "path": "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport re\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import List, Optional\n\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom safetensors.torch import save_file\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom 
torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    FluxPipeline,\n    FluxTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    _set_state_dict_into_text_encoder,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    train_text_encoder_ti=False,\n    enable_t5_ti=False,\n    pure_textual_inversion=False,\n    token_abstraction_dict=None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    trigger_str = f\"You should use {instance_prompt} to trigger the image generation.\"\n\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n    diffusers_load_lora = \"\"\n    diffusers_imports_pivotal = \"\"\n    diffusers_example_pivotal = \"\"\n    if not pure_textual_inversion:\n        
diffusers_load_lora = (\n            f\"\"\"pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\"\"\"\n        )\n    if train_text_encoder_ti:\n        embeddings_filename = f\"{repo_folder}_emb\"\n        ti_keys = \", \".join(f'\"{match}\"' for match in re.findall(r\"<s\\d+>\", instance_prompt))\n        trigger_str = (\n            \"To trigger image generation of trained concept (or concepts), replace each concept identifier \"\n            \"in your prompt with the new inserted tokens:\\n\"\n        )\n        diffusers_imports_pivotal = \"\"\"from huggingface_hub import hf_hub_download\n    from safetensors.torch import load_file\n            \"\"\"\n        if enable_t5_ti:\n            diffusers_example_pivotal = f\"\"\"embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type=\"model\")\n    state_dict = load_file(embedding_path)\n    pipeline.load_textual_inversion(state_dict[\"clip_l\"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)\n    pipeline.load_textual_inversion(state_dict[\"t5\"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)\n            \"\"\"\n        else:\n            diffusers_example_pivotal = f\"\"\"embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type=\"model\")\n    state_dict = load_file(embedding_path)\n    pipeline.load_textual_inversion(state_dict[\"clip_l\"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)\n            \"\"\"\n        if token_abstraction_dict:\n            for key, value in token_abstraction_dict.items():\n                tokens = \"\".join(value)\n                trigger_str += f\"\"\"\n    to trigger concept `{key}` → use `{tokens}` in your prompt \\n\n    \"\"\"\n\n    model_description = f\"\"\"\n# Flux DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## 
Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).\n\nWas LoRA for the text encoder enabled? {train_text_encoder}.\n\nPivotal tuning was enabled: {train_text_encoder_ti}.\n\n## Trigger words\n\n{trigger_str}\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n{diffusers_imports_pivotal}\npipeline = AutoPipelineForText2Image.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16).to('cuda')\n{diffusers_load_lora}\n{diffusers_example_pivotal}\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"flux\",\n        \"flux-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef 
load_text_encoders(class_one, class_two):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = class_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    return text_encoder_one, text_encoder_two\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    # pre-calculate  prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast\n    with torch.no_grad():\n        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(\n            pipeline_args[\"prompt\"], prompt_2=pipeline_args[\"prompt\"]\n        )\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            
tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. \"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. 
By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--token_abstraction\",\n        type=str,\n        default=\"TOK\",\n        help=\"identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, \"\n        \"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. \"\n        \"'TOK,TOK2,TOK3' etc.\",\n    )\n\n    parser.add_argument(\n        \"--num_new_tokens_per_abstraction\",\n        type=int,\n        default=None,\n        help=\"number of new tokens inserted to the tokenizers per token_abstraction identifier when \"\n        \"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new \"\n        \"tokens - <si><si+1> \",\n    )\n    parser.add_argument(\n        \"--initializer_concept\",\n        type=str,\n        default=None,\n        help=\"the concept to use to initialize the new inserted tokens when training with \"\n        \"--train_text_encoder_ti = True. By default, new tokens (<si><si+1>) are initialized with random value. 
\"\n        \"Alternatively, you could specify a different word/words whose value will be used as the starting point for the new inserted tokens. \"\n        \"--num_new_tokens_per_abstraction is ignored when initializer_concept is provided\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder_ti\",\n        action=\"store_true\",\n        help=(\"Whether to use pivotal tuning / textual inversion\"),\n    )\n    parser.add_argument(\n        \"--enable_t5_ti\",\n        action=\"store_true\",\n        help=(\n            \"Whether to use pivotal tuning / textual inversion for the T5 encoder as well (in addition to CLIP encoder)\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--train_text_encoder_ti_frac\",\n        type=float,\n        default=0.5,\n        help=(\"The percentage of epochs to perform textual inversion\"),\n    )\n\n    parser.add_argument(\n        \"--train_text_encoder_frac\",\n        type=float,\n        default=1.0,\n        help=(\"The percentage of epochs to perform text encoder tuning\"),\n    )\n    parser.add_argument(\n        \"--train_transformer_frac\",\n        type=float,\n        default=1.0,\n        help=(\"The percentage of epochs to perform transformer tuning\"),\n    )\n\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"the FLUX.1 dev variant is a guidance distilled model\",\n    )\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n       
 \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\n        \"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for transformer params\"\n    )\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            \"The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string. \"\n            'E.g. - \"to_k,to_q,to_v,to_out.0\" will result in lora training of attention layers only. 
For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. 
Defaults to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    if args.train_text_encoder and args.train_text_encoder_ti:\n        raise ValueError(\n            \"Specify only one of `--train_text_encoder` or `--train_text_encoder_ti`. \"\n            \"For full LoRA text encoder training check `--train_text_encoder`, for textual \"\n            \"inversion training check `--train_text_encoder_ti`\"\n        )\n    if args.train_transformer_frac < 1 and not args.train_text_encoder_ti:\n        raise ValueError(\n            \"--train_transformer_frac must be == 1 if text_encoder training / textual inversion is not enabled.\"\n        )\n    if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1:\n        raise ValueError(\n            \"--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. 
\"\n            \"This contradicts with --max_train_steps, please specify different values or set both to 1.\"\n        )\n    if args.enable_t5_ti and not args.train_text_encoder_ti:\n        logger.warning(\"You need not use --enable_t5_ti without --train_text_encoder_ti.\")\n\n    if args.train_text_encoder_ti and args.initializer_concept and args.num_new_tokens_per_abstraction:\n        logger.warning(\n            \"When specifying --initializer_concept, the number of tokens per abstraction is detrimned \"\n            \"by the initializer token. --num_new_tokens_per_abstraction will be ignored\"\n        )\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        if args.class_data_dir is not None:\n            logger.warning(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            logger.warning(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\n# Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py\nclass TokenEmbeddingsHandler:\n    def __init__(self, text_encoders, tokenizers):\n        self.text_encoders = text_encoders\n        self.tokenizers = tokenizers\n\n        self.train_ids: Optional[torch.Tensor] = None\n        self.train_ids_t5: Optional[torch.Tensor] = None\n        self.inserting_toks: Optional[List[str]] = None\n        self.embeddings_settings = {}\n\n    def initialize_new_tokens(self, inserting_toks: List[str]):\n        idx = 0\n        for tokenizer, text_encoder in 
zip(self.tokenizers, self.text_encoders):\n            assert isinstance(inserting_toks, list), \"inserting_toks should be a list of strings.\"\n            assert all(isinstance(tok, str) for tok in inserting_toks), (\n                \"All elements in inserting_toks should be strings.\"\n            )\n\n            self.inserting_toks = inserting_toks\n            special_tokens_dict = {\"additional_special_tokens\": self.inserting_toks}\n            tokenizer.add_special_tokens(special_tokens_dict)\n            # Resize the token embeddings as we are adding new special tokens to the tokenizer\n            text_encoder.resize_token_embeddings(len(tokenizer))\n\n            # Convert the token abstractions to ids\n            if idx == 0:\n                self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)\n            else:\n                self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)\n\n            # random initialization of new tokens\n            embeds = (\n                text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens\n            )\n            std_token_embedding = embeds.weight.data.std()\n\n            logger.info(f\"{idx} text encoder's std_token_embedding: {std_token_embedding}\")\n\n            train_ids = self.train_ids if idx == 0 else self.train_ids_t5\n            # if initializer_concept are not provided, token embeddings are initialized randomly\n            if args.initializer_concept is None:\n                hidden_size = (\n                    text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size\n                )\n                embeds.weight.data[train_ids] = (\n                    torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)\n                    * std_token_embedding\n                )\n            else:\n                # Convert the 
initializer_token, placeholder_token to ids\n                initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False)\n                for token_idx, token_id in enumerate(train_ids):\n                    embeds.weight.data[token_id] = (embeds.weight.data)[\n                        initializer_token_ids[token_idx % len(initializer_token_ids)]\n                    ].clone()\n\n            self.embeddings_settings[f\"original_embeddings_{idx}\"] = embeds.weight.data.clone()\n            self.embeddings_settings[f\"std_token_embedding_{idx}\"] = std_token_embedding\n\n            # makes sure we don't update any embedding weights besides the newly added token\n            index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)\n            index_no_updates[train_ids] = False\n\n            self.embeddings_settings[f\"index_no_updates_{idx}\"] = index_no_updates\n\n            logger.info(self.embeddings_settings[f\"index_no_updates_{idx}\"].shape)\n\n            idx += 1\n\n    def save_embeddings(self, file_path: str):\n        assert self.train_ids is not None, \"Initialize new tokens before saving embeddings.\"\n        tensors = {}\n        # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl\n        idx_to_text_encoder_name = {0: \"clip_l\", 1: \"t5\"}\n        for idx, text_encoder in enumerate(self.text_encoders):\n            train_ids = self.train_ids if idx == 0 else self.train_ids_t5\n            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared\n            assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), \"Tokenizers should be the same.\"\n            new_token_embeddings = embeds.weight.data[train_ids]\n\n            # New tokens for each text encoder are saved under \"clip_l\" (for text_encoder 0),\n            # Note: When loading with diffusers, any name can work - simply specify in inference\n            
tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings\n            # tensors[f\"text_encoders_{idx}\"] = new_token_embeddings\n\n        save_file(tensors, file_path)\n\n    @property\n    def dtype(self):\n        return self.text_encoders[0].dtype\n\n    @property\n    def device(self):\n        return self.text_encoders[0].device\n\n    @torch.no_grad()\n    def retract_embeddings(self):\n        for idx, text_encoder in enumerate(self.text_encoders):\n            embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared\n            index_no_updates = self.embeddings_settings[f\"index_no_updates_{idx}\"]\n            embeds.weight.data[index_no_updates] = (\n                self.embeddings_settings[f\"original_embeddings_{idx}\"][index_no_updates]\n                .to(device=text_encoder.device)\n                .to(dtype=text_encoder.dtype)\n            )\n\n            # for the parts that were updated, we need to normalize them\n            # to have the same std as before\n            std_token_embedding = self.embeddings_settings[f\"std_token_embedding_{idx}\"]\n\n            index_updates = ~index_no_updates\n            new_embeddings = embeds.weight.data[index_updates]\n            off_ratio = std_token_embedding / new_embeddings.std()\n\n            new_embeddings = new_embeddings * (off_ratio**0.1)\n            embeds.weight.data[index_updates] = new_embeddings\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        args,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        train_text_encoder_ti,\n        token_abstraction_dict=None,  # token mapping for textual inversion\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n    ):\n        
self.size = size\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n        self.token_abstraction_dict = token_abstraction_dict\n        self.train_text_encoder_ti = train_text_encoder_ti\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n        if interpolation is None:\n            raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode=}.\")\n        train_resize = 
transforms.Resize(size, interpolation=interpolation)\n        train_crop = transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - self.size) / 2.0)))\n                x1 = max(0, int(round((image.width - self.size) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n      
          transforms.Resize(size, interpolation=interpolation),\n                transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                if self.train_text_encoder_ti:\n                    # replace instances of --token_abstraction in caption with the new tokens: \"<si><si+1>\" etc.\n                    for token_abs, token_replacement in self.token_abstraction_dict.items():\n                        caption = caption.replace(token_abs, \"\".join(token_replacement))\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # the given instance prompt is used for all images\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class 
and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        return_length=False,\n        return_overflowing_tokens=False,\n        add_special_tokens=add_special_tokens,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\ndef _encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length=512,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            
return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    if hasattr(text_encoder, \"module\"):\n        dtype = text_encoder.module.dtype\n    else:\n        dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n\n    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    text_input_ids=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n    if hasattr(text_encoder, \"module\"):\n        dtype = text_encoder.module.dtype\n    else:\n        dtype = text_encoder.dtype\n    # Use pooled output of CLIPTextModel\n    prompt_embeds = prompt_embeds.pooler_output\n    prompt_embeds 
= prompt_embeds.to(dtype=dtype, device=device)\n\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n    return prompt_embeds\n\n\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n    text_input_ids_list=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    if hasattr(text_encoders[0], \"module\"):\n        dtype = text_encoders[0].module.dtype\n    else:\n        dtype = text_encoders[0].dtype\n\n    pooled_prompt_embeds = _encode_prompt_with_clip(\n        text_encoder=text_encoders[0],\n        tokenizer=tokenizers[0],\n        prompt=prompt,\n        device=device if device is not None else text_encoders[0].device,\n        num_images_per_prompt=num_images_per_prompt,\n        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,\n    )\n\n    prompt_embeds = _encode_prompt_with_t5(\n        text_encoder=text_encoders[1],\n        tokenizer=tokenizers[1],\n        max_sequence_length=max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        device=device if device is not None else text_encoders[1].device,\n        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,\n    )\n\n    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n    return prompt_embeds, pooled_prompt_embeds, text_ids\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if 
torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n 
       if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n\n            pipeline = FluxPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):\n                    images = pipeline(prompt=example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n           
         image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        model_id = args.hub_model_id or Path(args.output_dir).name\n        repo_id = None\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=model_id,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, 
subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    if args.train_text_encoder_ti:\n        # we parse the provided token identifier (or identifiers) into a list. s.t. - \"TOK\" -> [\"TOK\"], \"TOK,\n        # TOK2\" -> [\"TOK\", \"TOK2\"] etc.\n        token_abstraction_list = [place_holder.strip() for place_holder in re.split(r\",\\s*\", args.token_abstraction)]\n        logger.info(f\"list of token identifiers: {token_abstraction_list}\")\n\n        if args.initializer_concept is None:\n            num_new_tokens_per_abstraction = (\n                2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction\n            )\n        # if args.initializer_concept is provided, we ignore args.num_new_tokens_per_abstraction\n        else:\n            token_ids = tokenizer_one.encode(args.initializer_concept, add_special_tokens=False)\n            num_new_tokens_per_abstraction = len(token_ids)\n            if args.enable_t5_ti:\n                token_ids_t5 = tokenizer_two.encode(args.initializer_concept, add_special_tokens=False)\n                num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5))\n            logger.info(\n                f\"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}\"\n            )\n\n        token_abstraction_dict = {}\n        token_idx = 0\n        for i, token in enumerate(token_abstraction_list):\n            token_abstraction_dict[token] = [f\"<s{token_idx + i + j}>\" for j in range(num_new_tokens_per_abstraction)]\n            token_idx += num_new_tokens_per_abstraction - 1\n\n        # replace instances of --token_abstraction in --instance_prompt with the new tokens: \"<si><si+1>\" etc.\n        for token_abs, token_replacement in token_abstraction_dict.items():\n            new_instance_prompt = args.instance_prompt.replace(token_abs, \"\".join(token_replacement))\n            if 
args.instance_prompt == new_instance_prompt:\n                logger.warning(\n                    \"Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified \"\n                    \"--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning\"\n                )\n            args.instance_prompt = new_instance_prompt\n            if args.with_prior_preservation:\n                args.class_prompt = args.class_prompt.replace(token_abs, \"\".join(token_replacement))\n            if args.validation_prompt:\n                args.validation_prompt = args.validation_prompt.replace(token_abs, \"\".join(token_replacement))\n\n        # initialize the new tokens for textual inversion\n        text_encoders = [text_encoder_one, text_encoder_two] if args.enable_t5_ti else [text_encoder_one]\n        tokenizers = [tokenizer_one, tokenizer_two] if args.enable_t5_ti else [tokenizer_one]\n        embedding_handler = TokenEmbeddingsHandler(text_encoders, tokenizers)\n        inserting_toks = []\n        for new_tok in token_abstraction_dict.values():\n            inserting_toks.extend(new_tok)\n        embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == 
torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=weight_dtype)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\n            \"attn.to_k\",\n            \"attn.to_q\",\n            \"attn.to_v\",\n            \"attn.to_out.0\",\n            \"attn.add_k_proj\",\n            \"attn.add_q_proj\",\n            \"attn.add_v_proj\",\n            \"attn.to_add_out\",\n            \"ff.net.0.proj\",\n            \"ff.net.2\",\n            \"ff_context.net.0.proj\",\n            \"ff_context.net.2\",\n        ]\n    # now we will add new LoRA weights to the attention layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.lora_alpha,\n            lora_dropout=args.lora_dropout,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder_one.add_adapter(text_lora_config)\n\n    def 
unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            modules_to_save = {}\n            for model in models:\n                if isinstance(model, type(unwrap_model(transformer))):\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                    modules_to_save[\"transformer\"] = model\n                elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                    if args.train_text_encoder:  # when --train_text_encoder_ti we don't save the layers\n                        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)\n                        modules_to_save[\"text_encoder\"] = model\n                elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                    pass  # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            FluxPipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n        if args.train_text_encoder_ti:\n            embedding_handler.save_embeddings(f\"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors\")\n\n   
 def load_model_hook(models, input_dir):\n        transformer_ = None\n        text_encoder_one_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(transformer))):\n                transformer_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                text_encoder_one_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict = FluxPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n        if args.train_text_encoder:\n            # Do we need to call `scale_lora_layers()` here?\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_])\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one])\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n    if args.train_text_encoder:\n        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n    # if we use textual inversion, we freeze all parameters except for the token embeddings\n    # in text encoder\n    elif args.train_text_encoder_ti:\n        text_lora_parameters_one = []  # CLIP\n        for name, param in text_encoder_one.named_parameters():\n            if \"token_embedding\" in name:\n                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16\n                if args.mixed_precision 
== \"fp16\":\n                    param.data = param.to(dtype=torch.float32)\n                param.requires_grad = True\n                text_lora_parameters_one.append(param)\n            else:\n                param.requires_grad = False\n        if args.enable_t5_ti:  # whether to do pivotal tuning/textual inversion for T5 as well\n            text_lora_parameters_two = []\n            for name, param in text_encoder_two.named_parameters():\n                if \"shared\" in name:\n                    # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16\n                    if args.mixed_precision == \"fp16\":\n                        param.data = param.to(dtype=torch.float32)\n                    param.requires_grad = True\n                    text_lora_parameters_two.append(param)\n                else:\n                    param.requires_grad = False\n\n    # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training\n    freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)\n\n    # if --train_text_encoder_ti and train_transformer_frac == 0, we're essentially performing textual inversion\n    # and not training transformer LoRA layers\n    pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    if not freeze_text_encoder:\n        # different learning rate for text encoder and transformer\n        text_parameters_one_with_lr = {\n            \"params\": text_lora_parameters_one,\n            \"weight_decay\": args.adam_weight_decay_text_encoder\n            if args.adam_weight_decay_text_encoder\n            else args.adam_weight_decay,\n            \"lr\": args.text_encoder_lr,\n        }\n        if not args.enable_t5_ti:\n            # pure textual inversion - only 
clip\n            if pure_textual_inversion:\n                params_to_optimize = [text_parameters_one_with_lr]\n                te_idx = 0\n            else:  # regular te training or regular pivotal for clip\n                params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]\n                te_idx = 1\n        elif args.enable_t5_ti:\n            # pivotal tuning of clip & t5\n            text_parameters_two_with_lr = {\n                \"params\": text_lora_parameters_two,\n                \"weight_decay\": args.adam_weight_decay_text_encoder\n                if args.adam_weight_decay_text_encoder\n                else args.adam_weight_decay,\n                \"lr\": args.text_encoder_lr,\n            }\n            # pure textual inversion - only clip & t5\n            if pure_textual_inversion:\n                params_to_optimize = [text_parameters_one_with_lr, text_parameters_two_with_lr]\n                te_idx = 0\n            else:  # regular pivotal tuning of clip & t5\n                params_to_optimize = [\n                    transformer_parameters_with_lr,\n                    text_parameters_one_with_lr,\n                    text_parameters_two_with_lr,\n                ]\n                te_idx = 1\n    else:\n        params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if not freeze_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. 
\"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters to be\n            # --learning_rate\n\n            params_to_optimize[te_idx][\"lr\"] = args.learning_rate\n            params_to_optimize[-1][\"lr\"] = args.learning_rate\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        args=args,\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        train_text_encoder_ti=args.train_text_encoder_ti,\n        token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    if freeze_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two]\n        text_encoders = [text_encoder_one, text_encoder_two]\n\n        def compute_text_embeddings(prompt, text_encoders, tokenizers):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds, text_ids 
= encode_prompt(\n                    text_encoders, tokenizers, prompt, args.max_sequence_length\n                )\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n                text_ids = text_ids.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if freeze_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if freeze_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers\n            )\n\n    # Clear the memory here\n    if freeze_text_encoder and not train_dataset.custom_instance_prompts:\n        del tokenizers, text_encoders, text_encoder_one, text_encoder_two\n        free_memory()\n\n    # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion\n    add_special_tokens_clip = True if args.train_text_encoder_ti else False\n    add_special_tokens_t5 = True if (args.train_text_encoder_ti and args.enable_t5_ti) else False\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n\n    if not train_dataset.custom_instance_prompts:\n        if freeze_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            pooled_prompt_embeds = instance_pooled_prompt_embeds\n            text_ids = instance_text_ids\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)\n                text_ids = torch.cat([text_ids, class_text_ids], dim=0)\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts)\n        # we need to tokenize and encode the batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(\n                tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip\n            )\n            tokens_two = tokenize_prompt(\n                tokenizer_two,\n                args.instance_prompt,\n                max_sequence_length=args.max_sequence_length,\n                add_special_tokens=add_special_tokens_t5,\n            )\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(\n                    tokenizer_one,\n                    args.class_prompt,\n                    max_sequence_length=77,\n                    add_special_tokens=add_special_tokens_clip,\n                )\n                class_tokens_two = tokenize_prompt(\n                    tokenizer_two,\n                    args.class_prompt,\n                    max_sequence_length=args.max_sequence_length,\n                    add_special_tokens=add_special_tokens_t5,\n                )\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = 
torch.cat([tokens_two, class_tokens_two], dim=0)\n\n    vae_config_shift_factor = vae.config.shift_factor\n    vae_config_scaling_factor = vae.config.scaling_factor\n    vae_config_block_out_channels = vae.config.block_out_channels\n    if args.cache_latents:\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=weight_dtype\n                )\n                latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n\n        if args.validation_prompt is None:\n            del vae\n            free_memory()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if not freeze_text_encoder:\n        if args.enable_t5_ti:\n            (\n                transformer,\n                
text_encoder_one,\n                text_encoder_two,\n                optimizer,\n                train_dataloader,\n                lr_scheduler,\n            ) = accelerator.prepare(\n                transformer,\n                text_encoder_one,\n                text_encoder_two,\n                optimizer,\n                train_dataloader,\n                lr_scheduler,\n            )\n        else:\n            transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n                transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler\n            )\n\n    else:\n        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            transformer, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. 
\"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux-dev-lora-advanced\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            logger.info(f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\")\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            logger.info(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == 
t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    if args.train_text_encoder:\n        num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)\n        num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs)\n    elif args.train_text_encoder_ti:  # args.train_text_encoder_ti\n        num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)\n        num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs)\n\n    # flag used for textual inversion\n    pivoted_te = False\n    pivoted_tr = False\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n        # if performing any kind of optimization of text_encoder params\n        if args.train_text_encoder or args.train_text_encoder_ti:\n            if epoch == num_train_epochs_text_encoder:\n                # flag to stop text encoder optimization\n                logger.info(f\"PIVOT TE {epoch}\")\n                pivoted_te = True\n            else:\n                # still optimizing the text encoder\n                if args.train_text_encoder:\n                    text_encoder_one.train()\n                    # set top parameter requires_grad = True for gradient checkpointing works\n                    unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)\n                elif args.train_text_encoder_ti:  # textual inversion / pivotal tuning\n                    text_encoder_one.train()\n                if args.enable_t5_ti:\n                    text_encoder_two.train()\n\n            if epoch == num_train_epochs_transformer:\n                # flag to stop transformer optimization\n                logger.info(f\"PIVOT TRANSFORMER {epoch}\")\n                pivoted_tr = True\n\n   
     for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            if not freeze_text_encoder:\n                models_to_accumulate.extend([text_encoder_one])\n                if args.enable_t5_ti:\n                    models_to_accumulate.extend([text_encoder_two])\n            if pivoted_te:\n                # stopping optimization of text_encoder params\n                optimizer.param_groups[te_idx][\"lr\"] = 0.0\n                optimizer.param_groups[-1][\"lr\"] = 0.0\n            elif pivoted_tr and not pure_textual_inversion:\n                logger.info(f\"PIVOT TRANSFORMER {epoch}\")\n                optimizer.param_groups[0][\"lr\"] = 0.0\n\n            with accelerator.accumulate(models_to_accumulate):\n                prompts = batch[\"prompts\"]\n\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    elems_to_repeat = 1\n                    if freeze_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n                    else:\n                        tokens_one = tokenize_prompt(\n                            tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens_clip\n                        )\n                        tokens_two = tokenize_prompt(\n                            tokenizer_two,\n                            prompts,\n                            max_sequence_length=args.max_sequence_length,\n                            add_special_tokens=add_special_tokens_t5,\n                        )\n                else:\n                    elems_to_repeat = len(prompts)\n\n                if not freeze_text_encoder:\n                    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n        
                text_encoders=[text_encoder_one, text_encoder_two],\n                        tokenizers=[None, None],\n                        text_input_ids_list=[\n                            tokens_one.repeat(elems_to_repeat, 1),\n                            tokens_two.repeat(elems_to_repeat, 1),\n                        ],\n                        max_sequence_length=args.max_sequence_length,\n                        device=accelerator.device,\n                        prompt=prompts,\n                    )\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].sample()\n                else:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)\n\n                latent_image_ids = FluxPipeline._prepare_latent_image_ids(\n                    model_input.shape[0],\n                    model_input.shape[2] // 2,\n                    model_input.shape[3] // 2,\n                    accelerator.device,\n                    weight_dtype,\n                )\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    
mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                packed_noisy_model_input = FluxPipeline._pack_latents(\n                    noisy_model_input,\n                    batch_size=model_input.shape[0],\n                    num_channels_latents=model_input.shape[1],\n                    height=model_input.shape[2],\n                    width=model_input.shape[3],\n                )\n\n                # handle guidance\n                if unwrap_model(transformer).config.guidance_embeds:\n                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)\n                    guidance = guidance.expand(model_input.shape[0])\n                else:\n                    guidance = None\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,\n                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    return_dict=False,\n                )[0]\n                model_pred = FluxPipeline._unpack_latents(\n                    
model_pred,\n                    height=model_input.shape[2] * vae_scale_factor,\n                    width=model_input.shape[3] * vae_scale_factor,\n                    vae_scale_factor=vae_scale_factor,\n                )\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    if not freeze_text_encoder:\n                        if args.train_text_encoder:  # text encoder tuning\n                            params_to_clip = 
itertools.chain(transformer.parameters(), text_encoder_one.parameters())\n                        elif pure_textual_inversion:\n                            if args.enable_t5_ti:\n                                params_to_clip = itertools.chain(\n                                    text_encoder_one.parameters(), text_encoder_two.parameters()\n                                )\n                            else:\n                                params_to_clip = itertools.chain(text_encoder_one.parameters())\n                        else:\n                            if args.enable_t5_ti:\n                                params_to_clip = itertools.chain(\n                                    transformer.parameters(),\n                                    text_encoder_one.parameters(),\n                                    text_encoder_two.parameters(),\n                                )\n                            else:\n                                params_to_clip = itertools.chain(\n                                    transformer.parameters(), text_encoder_one.parameters()\n                                )\n                    else:\n                        params_to_clip = itertools.chain(transformer.parameters())\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # every step, we reset the embeddings to the original embeddings.\n                if args.train_text_encoder_ti:\n                    embedding_handler.retract_embeddings()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ 
saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        if args.train_text_encoder_ti:\n                            embedding_handler.save_embeddings(\n                                f\"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors\"\n                            )\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": 
lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if freeze_text_encoder:  # no text encoder one, two optimizations\n                    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n                    text_encoder_one.to(weight_dtype)\n                    text_encoder_two.to(weight_dtype)\n                pipeline = FluxPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=unwrap_model(text_encoder_one),\n                    text_encoder_2=unwrap_model(text_encoder_two),\n                    transformer=unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n                if freeze_text_encoder:\n                    del text_encoder_one, text_encoder_two\n                    free_memory()\n\n                images = None\n                del pipeline\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        transformer = unwrap_model(transformer)\n        if args.upcast_before_saving:\n            
transformer.to(torch.float32)\n        else:\n            transformer = transformer.to(weight_dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n        modules_to_save[\"transformer\"] = transformer\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))\n            modules_to_save[\"text_encoder\"] = text_encoder_one\n        else:\n            text_encoder_lora_layers = None\n\n        if not pure_textual_inversion:\n            FluxPipeline.save_lora_weights(\n                save_directory=args.output_dir,\n                transformer_lora_layers=transformer_lora_layers,\n                text_encoder_lora_layers=text_encoder_lora_layers,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n        if args.train_text_encoder_ti:\n            embeddings_path = f\"{args.output_dir}/{os.path.basename(args.output_dir)}_emb.safetensors\"\n            embedding_handler.save_embeddings(embeddings_path)\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = FluxPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        if not pure_textual_inversion:\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt}\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n           
     torch_dtype=weight_dtype,\n            )\n\n        save_model_card(\n            model_id if not args.push_to_hub else repo_id,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            train_text_encoder=args.train_text_encoder,\n            train_text_encoder_ti=args.train_text_encoder_ti,\n            enable_t5_ti=args.enable_t5_ti,\n            pure_textual_inversion=pure_textual_inversion,\n            token_abstraction_dict=train_dataset.token_abstraction_dict,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=args.validation_prompt,\n            repo_folder=args.output_dir,\n        )\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n        del pipeline\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n# ]\n# ///\n\nimport argparse\nimport gc\nimport hashlib\nimport itertools\nimport logging\nimport math\nimport os\nimport re\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import List, Optional\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\n# imports of the TokenEmbeddingsHandler class\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom safetensors.torch import load_file, save_file\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import 
tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DPMSolverMultistepScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr\nfrom diffusers.utils import (\n    check_min_version,\n    convert_all_state_dict_to_peft,\n    convert_state_dict_to_diffusers,\n    convert_state_dict_to_kohya,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    use_dora: bool,\n    images: list = None,\n    base_model: str = None,\n    train_text_encoder=False,\n    train_text_encoder_ti=False,\n    token_abstraction_dict=None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n    vae_path=None,\n):\n    lora = \"lora\" if not use_dora else \"dora\"\n\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n    else:\n        widget_dict.append({\"text\": instance_prompt})\n    embeddings_filename = f\"{repo_folder}_emb\"\n    instance_prompt_webui = re.sub(r\"<s\\d+>\", \"\", re.sub(r\"<s\\d+>\", embeddings_filename, instance_prompt, count=1))\n    ti_keys = \", 
\".join(f'\"{match}\"' for match in re.findall(r\"<s\\d+>\", instance_prompt))\n    if instance_prompt_webui != embeddings_filename:\n        instance_prompt_sentence = f\"For example, `{instance_prompt_webui}`\"\n    else:\n        instance_prompt_sentence = \"\"\n    trigger_str = f\"You should use {instance_prompt} to trigger the image generation.\"\n    diffusers_imports_pivotal = \"\"\n    diffusers_example_pivotal = \"\"\n    webui_example_pivotal = \"\"\n    if train_text_encoder_ti:\n        trigger_str = (\n            \"To trigger image generation of trained concept(or concepts) replace each concept identifier \"\n            \"in you prompt with the new inserted tokens:\\n\"\n        )\n        diffusers_imports_pivotal = \"\"\"from huggingface_hub import hf_hub_download\nfrom safetensors.torch import load_file\n        \"\"\"\n        diffusers_example_pivotal = f\"\"\"embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type=\"model\")\nstate_dict = load_file(embedding_path)\npipeline.load_textual_inversion(state_dict[\"clip_l\"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)\n        \"\"\"\n        webui_example_pivotal = f\"\"\"- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.\n    - Place it on it on your `embeddings` folder\n    - Use it by adding `{embeddings_filename}` to your prompt. 
{instance_prompt_sentence}\n    (you need both the LoRA and the embeddings as they were trained together for this LoRA)\n    \"\"\"\n        if token_abstraction_dict:\n            for key, value in token_abstraction_dict.items():\n                tokens = \"\".join(value)\n                trigger_str += f\"\"\"\nto trigger concept `{key}` → use `{tokens}` in your prompt \\n\n\"\"\"\n    model_description = f\"\"\"\n# SD1.5 LoRA DreamBooth - {repo_id}\n\n<Gallery />\n\n## Model description\n\n### These are {repo_id} LoRA adaption weights for {base_model}.\n\n## Download model\n\n### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke\n\n- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.\n    - Place it on your `models/Lora` folder.\n    - On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).\n{webui_example_pivotal}\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n{diffusers_imports_pivotal}\npipeline = AutoPipelineForText2Image.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\n{diffusers_example_pivotal}\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## Trigger words\n\n{trigger_str}\n\n## Details\nAll [Files & versions](/{repo_id}/tree/main).\n\nThe weights were trained using [🧨 diffusers Advanced Dreambooth Training 
Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py).\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\nPivotal tuning was enabled: {train_text_encoder_ti}.\n\nSpecial VAE used for training: {vae_path}.\n\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"openrail++\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        inference=True,\n        widget=widget_dict,\n    )\n\n    tags = [\n        \"text-to-image\",\n        \"diffusers\",\n        \"diffusers-training\",\n        lora,\n        \"template:sd-lora\",\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained 
model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. 'fp16'\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand. To load the custom captions, the training set directory needs to follow the structure of a \"\n            \"datasets ImageFolder, containing both the images and the corresponding caption for each image. See \"\n            \"https://huggingface.co/docs/datasets/image_dataset for more information\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset. In some cases, a dataset may have more than one configuration (for example, \"\n        \"if it contains different subsets of data and you only wish to load a specific subset); in that case, specify the desired configuration using --dataset_config_name. 
Leave as \"\n        \"None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=\"A path to local folder containing the training data of instance images. Specify this arg instead of \"\n        \"--dataset_name if you wish to train using a local folder without custom captions. If you wish to train with custom captions please specify \"\n        \"--dataset_name instead.\",\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--token_abstraction\",\n        type=str,\n        default=\"TOK\",\n        help=\"identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, \"\n        \"captions - e.g. TOK. 
To use multiple identifiers, please specify them in a comma separated string - e.g. \"\n        \"'TOK,TOK2,TOK3' etc.\",\n    )\n\n    parser.add_argument(\n        \"--num_new_tokens_per_abstraction\",\n        type=int,\n        default=2,\n        help=\"number of new tokens inserted to the tokenizers per token_abstraction identifier when \"\n        \"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new \"\n        \"tokens - <si><si+1> \",\n    )\n\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lora-dreambooth-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation 
steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--train_text_encoder_ti\",\n        action=\"store_true\",\n        help=(\"Whether to use textual inversion\"),\n    )\n\n    parser.add_argument(\n        \"--train_text_encoder_ti_frac\",\n        type=float,\n        default=0.5,\n        help=(\"The percentage of epochs to perform textual inversion\"),\n    )\n\n    parser.add_argument(\n        \"--train_text_encoder_frac\",\n        type=float,\n        default=1.0,\n        help=(\"The percentage of epochs to perform text encoder tuning\"),\n    )\n\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"adamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=None, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. 
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--use_dora\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://huggingface.co/papers/2402.09353. \"\n            \"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and 
args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    if args.train_text_encoder and args.train_text_encoder_ti:\n        raise ValueError(\n            \"Specify only one of `--train_text_encoder` or `--train_text_encoder_ti`. \"\n            \"For full LoRA text encoder training check `--train_text_encoder`, for textual \"\n            \"inversion training check `--train_text_encoder_ti`\"\n        )\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\n# Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py\nclass TokenEmbeddingsHandler:\n    def __init__(self, text_encoders, tokenizers):\n        self.text_encoders = text_encoders\n        self.tokenizers = tokenizers\n\n        self.train_ids: Optional[torch.Tensor] = None\n        self.inserting_toks: Optional[List[str]] = None\n        self.embeddings_settings = {}\n\n    def initialize_new_tokens(self, inserting_toks: List[str]):\n        idx = 0\n        for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):\n            assert isinstance(inserting_toks, list), \"inserting_toks should be a list of strings.\"\n            assert 
all(isinstance(tok, str) for tok in inserting_toks), (\n                \"All elements in inserting_toks should be strings.\"\n            )\n\n            self.inserting_toks = inserting_toks\n            special_tokens_dict = {\"additional_special_tokens\": self.inserting_toks}\n            tokenizer.add_special_tokens(special_tokens_dict)\n            text_encoder.resize_token_embeddings(len(tokenizer))\n\n            self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)\n\n            # random initialization of new tokens\n            std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std()\n\n            print(f\"{idx} text encoder's std_token_embedding: {std_token_embedding}\")\n\n            text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (\n                torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)\n                .to(device=self.device)\n                .to(dtype=self.dtype)\n                * std_token_embedding\n            )\n            self.embeddings_settings[f\"original_embeddings_{idx}\"] = (\n                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()\n            )\n            self.embeddings_settings[f\"std_token_embedding_{idx}\"] = std_token_embedding\n\n            inu = torch.ones((len(tokenizer),), dtype=torch.bool)\n            inu[self.train_ids] = False\n\n            self.embeddings_settings[f\"index_no_updates_{idx}\"] = inu\n\n            print(self.embeddings_settings[f\"index_no_updates_{idx}\"].shape)\n\n            idx += 1\n\n    # Copied from train_dreambooth_lora_sdxl_advanced.py\n    def save_embeddings(self, file_path: str):\n        assert self.train_ids is not None, \"Initialize new tokens before saving embeddings.\"\n        tensors = {}\n        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 -  CLIP ViT-G/14 - TODO - change for sd\n        idx_to_text_encoder_name = {0: 
\"clip_l\", 1: \"clip_g\"}\n        for idx, text_encoder in enumerate(self.text_encoders):\n            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(\n                self.tokenizers[0]\n            ), \"Tokenizers should be the same.\"\n            new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]\n\n            # New tokens for each text encoder are saved under \"clip_l\" (for text_encoder 0), \"clip_g\" (for\n            # text_encoder 1) to keep compatible with the ecosystem.\n            # Note: When loading with diffusers, any name can work - simply specify in inference\n            tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings\n            # tensors[f\"text_encoders_{idx}\"] = new_token_embeddings\n\n        save_file(tensors, file_path)\n\n    @property\n    def dtype(self):\n        return self.text_encoders[0].dtype\n\n    @property\n    def device(self):\n        return self.text_encoders[0].device\n\n    @torch.no_grad()\n    def retract_embeddings(self):\n        for idx, text_encoder in enumerate(self.text_encoders):\n            index_no_updates = self.embeddings_settings[f\"index_no_updates_{idx}\"]\n            text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (\n                self.embeddings_settings[f\"original_embeddings_{idx}\"][index_no_updates]\n                .to(device=text_encoder.device)\n                .to(dtype=text_encoder.dtype)\n            )\n\n            # for the parts that were updated, we need to normalize them\n            # to have the same std as before\n            std_token_embedding = self.embeddings_settings[f\"std_token_embedding_{idx}\"]\n\n            index_updates = ~index_no_updates\n            new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]\n            off_ratio = std_token_embedding / new_embeddings.std()\n\n            
new_embeddings = new_embeddings * (off_ratio**0.1)\n            text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        dataset_name,\n        dataset_config_name,\n        cache_dir,\n        image_column,\n        caption_column,\n        train_text_encoder_ti,\n        class_data_root=None,\n        class_num=None,\n        token_abstraction_dict=None,  # token mapping for textual inversion\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n        self.token_abstraction_dict = token_abstraction_dict\n        self.train_text_encoder_ti = train_text_encoder_ti\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. 
If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                dataset_name,\n                dataset_config_name,\n                cache_dir=cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n        if interpolation is None:\n            raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode=}.\")\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, 
interpolation=interpolation),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.instance_images[index % self.num_instance_images]\n        instance_image = exif_transpose(instance_image)\n\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        example[\"instance_images\"] = self.image_transforms(instance_image)\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                if self.train_text_encoder_ti:\n                    # replace instances of --token_abstraction in caption with the new tokens: \"<si><si+1>\" etc.\n                    for token_abs, token_replacement in self.token_abstraction_dict.items():\n                        caption = caption.replace(token_abs, \"\".join(token_replacement))\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so use the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, 
with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt, add_special_tokens=False):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n        add_special_tokens=add_special_tokens,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):\n    for i, text_encoder in enumerate(text_encoders):\n        if tokenizers is not None:\n            tokenizer = tokenizers[i]\n            text_input_ids = tokenize_prompt(tokenizer, prompt)\n        else:\n            assert text_input_ids_list is not None\n            text_input_ids = 
text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device),\n            output_hidden_states=True,\n        )\n\n    return prompt_embeds[0]\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if 
prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = StableDiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + 
cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        model_id = args.hub_model_id or Path(args.output_dir).name\n        repo_id = None\n        if args.push_to_hub:\n            repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        variant=args.variant,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    vae_scaling_factor = vae.config.scaling_factor\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    if 
args.train_text_encoder_ti:\n        # we parse the provided token identifier (or identifiers) into a list. s.t. - \"TOK\" -> [\"TOK\"], \"TOK,\n        # TOK2\" -> [\"TOK\", \"TOK2\"] etc.\n        token_abstraction_list = \"\".join(args.token_abstraction.split()).split(\",\")\n        logger.info(f\"list of token identifiers: {token_abstraction_list}\")\n\n        token_abstraction_dict = {}\n        token_idx = 0\n        for i, token in enumerate(token_abstraction_list):\n            token_abstraction_dict[token] = [\n                f\"<s{token_idx + i + j}>\" for j in range(args.num_new_tokens_per_abstraction)\n            ]\n            token_idx += args.num_new_tokens_per_abstraction - 1\n\n        # replace instances of --token_abstraction in --instance_prompt with the new tokens: \"<si><si+1>\" etc.\n        for token_abs, token_replacement in token_abstraction_dict.items():\n            args.instance_prompt = args.instance_prompt.replace(token_abs, \"\".join(token_replacement))\n            if args.with_prior_preservation:\n                args.class_prompt = args.class_prompt.replace(token_abs, \"\".join(token_replacement))\n\n        # initialize the new tokens for textual inversion\n        embedding_handler = TokenEmbeddingsHandler([text_encoder_one], [tokenizer_one])\n        inserting_toks = []\n        for new_tok in token_abstraction_dict.values():\n            inserting_toks.extend(new_tok)\n        embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks)\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == 
\"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    # The VAE is always in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, \"\n                    \"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n\n    # now we will add new LoRA weights to the attention layers\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        lora_dropout=args.lora_dropout,\n        use_dora=args.use_dora,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    unet.add_adapter(unet_lora_config)\n\n    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.\n    # So, instead, we monkey-patch the forward calls of its attention-blocks.\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.rank,\n            lora_dropout=args.lora_dropout,\n            use_dora=args.use_dora,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder_one.add_adapter(text_lora_config)\n\n    # if we use textual inversion, we freeze all parameters except for the token embeddings\n    # in text encoder\n    elif args.train_text_encoder_ti:\n        text_lora_parameters_one = []\n        for name, param in text_encoder_one.named_parameters():\n            if \"token_embedding\" in name:\n                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16\n                param = param.to(dtype=torch.float32)\n                param.requires_grad = True\n                text_lora_parameters_one.append(param)\n            else:\n                param.requires_grad = False\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [unet]\n        if args.train_text_encoder:\n            
models.extend([text_encoder_one])\n        for model in models:\n            for param in model.parameters():\n                # only upcast trainable parameters (LoRA) into fp32\n                if param.requires_grad:\n                    param.data = param.to(torch.float32)\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here: either just the unet attention processor layers,\n            # or both the unet and text encoder attention layers\n            unet_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(accelerator.unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):\n                    if args.train_text_encoder:\n                        text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(\n                            get_peft_model_state_dict(model)\n                        )\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionPipeline.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n            )\n        if args.train_text_encoder_ti:\n            embedding_handler.save_embeddings(f\"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors\")\n\n    def load_model_hook(models, input_dir):\n        unet_ = 
None\n        text_encoder_one_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(accelerator.unwrap_model(unet))):\n                unet_ = model\n            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):\n                text_encoder_one_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        if args.train_text_encoder:\n            # Do we need to call `scale_lora_layers()` here?\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n            _set_state_dict_into_text_encoder(\n                lora_state_dict, prefix=\"text_encoder_2.\", text_encoder=text_encoder_one_\n            )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [unet_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_])\n                # only upcast trainable parameters (LoRA) into fp32\n                cast_training_params(models)\n        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n        StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)\n\n        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if \"text_encoder.\" in k}\n        StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder(\n            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_\n        )\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))\n\n    if args.train_text_encoder:\n        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n\n    # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training\n    freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)\n\n    # Optimization parameters\n    unet_lora_parameters_with_lr = {\"params\": unet_lora_parameters, \"lr\": 
args.learning_rate}\n    if not freeze_text_encoder:\n        # different learning rate for text encoder and unet\n        text_lora_parameters_one_with_lr = {\n            \"params\": text_lora_parameters_one,\n            \"weight_decay\": args.adam_weight_decay_text_encoder\n            if args.adam_weight_decay_text_encoder\n            else args.adam_weight_decay,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr]\n    else:\n        params_to_optimize = [unet_lora_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. 
\"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        dataset_name=args.dataset_name,\n        dataset_config_name=args.dataset_config_name,\n        cache_dir=args.cache_dir,\n        image_column=args.image_column,\n        train_text_encoder_ti=args.train_text_encoder_ti,\n        caption_column=args.caption_column,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one]\n        text_encoders = [text_encoder_one]\n\n        def 
compute_text_embeddings(prompt, text_encoders, tokenizers):\n            with torch.no_grad():\n                prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n            return prompt_embeds\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if freeze_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers)\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if freeze_text_encoder:\n            class_prompt_hidden_states = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers)\n\n    # Clear the memory here\n    if freeze_text_encoder and not train_dataset.custom_instance_prompts:\n        del tokenizers, text_encoders\n        gc.collect()\n        torch.cuda.empty_cache()\n\n    # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion\n    add_special_tokens = True if args.train_text_encoder_ti else False\n\n    if not train_dataset.custom_instance_prompts:\n        if freeze_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the\n        # batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens)\n            if args.with_prior_preservation:\n                class_tokens_one = 
tokenize_prompt(tokenizer_one, args.class_prompt, add_special_tokens)\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n\n    if args.train_text_encoder_ti and args.validation_prompt:\n        # replace instances of --token_abstraction in validation prompt with the new tokens: \"<si><si+1>\" etc.\n        for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():\n            args.validation_prompt = args.validation_prompt.replace(token_abs, \"\".join(token_replacement))\n    print(\"validation prompt:\", args.validation_prompt)\n\n    if args.cache_latents:\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=torch.float32\n                )\n                latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n\n        if args.validation_prompt is None:\n            del vae\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n    
    optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if not freeze_text_encoder:\n        unet, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder_one, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n        logger.warning(\n            f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n            f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. 
\"\n            f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n        )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"dreambooth-lora-sd-15\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the mos recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    if args.train_text_encoder:\n        num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)\n    elif args.train_text_encoder_ti:  # args.train_text_encoder_ti\n        num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        # if performing any kind of optimization of text_encoder params\n        if args.train_text_encoder or args.train_text_encoder_ti:\n            if epoch == num_train_epochs_text_encoder:\n                print(\"PIVOT HALFWAY\", epoch)\n                # stop optimizing the text_encoder params by setting their learning rate to zero,\n                # so that only the unet params keep being updated\n                optimizer.param_groups[1][\"lr\"] = 0.0\n\n            else:\n                # still optimizing the text encoder\n                text_encoder_one.train()\n                # set requires_grad = True on the embeddings so that gradient checkpointing works\n                if args.train_text_encoder:\n                    text_encoder_one.text_model.embeddings.requires_grad_(True)\n\n        unet.train()\n        for step, batch in enumerate(train_dataloader):\n            with 
accelerator.accumulate(unet):\n                prompts = batch[\"prompts\"]\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    if freeze_text_encoder:\n                        prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)\n\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)\n\n                if args.cache_latents:\n                    model_input = latents_cache[step].sample()\n                else:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n\n                model_input = model_input * vae_scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    model_input = model_input.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device\n                    )\n                bsz = model_input.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                )\n                timesteps = timesteps.long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n\n              
  # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.\n                if not train_dataset.custom_instance_prompts:\n                    elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz\n\n                else:\n                    elems_to_repeat_text_embeds = 1\n\n                # Predict the noise residual\n                if freeze_text_encoder:\n                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)\n                    model_pred = unet(noisy_model_input, timesteps, prompt_embeds_input).sample\n                else:\n                    prompt_embeds = encode_prompt(\n                        text_encoders=[text_encoder_one],\n                        tokenizers=None,\n                        prompt=None,\n                        text_input_ids_list=[tokens_one],\n                    )\n                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)\n                    model_pred = unet(noisy_model_input, timesteps, prompt_embeds_input).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n               
     prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n\n                    if args.with_prior_preservation:\n                        # if we're using prior preservation, we calc snr for instance loss only -\n                        # and hence only need timesteps corresponding to instance images\n                        snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)\n                    else:\n                        snr_timesteps = timesteps\n\n                    snr = compute_snr(noise_scheduler, snr_timesteps)\n                    base_weight = (\n                        torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr\n                    )\n\n                    if noise_scheduler.config.prediction_type == \"v_prediction\":\n                        # Velocity objective needs to be floored to an SNR weight of one.\n                        mse_loss_weights = base_weight + 1\n                    else:\n                        # Epsilon and sample both use the same loss weights.\n                        mse_loss_weights = base_weight\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + 
args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet_lora_parameters, text_lora_parameters_one)\n                        if (args.train_text_encoder or args.train_text_encoder_ti)\n                        else unet_lora_parameters\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # every step, we reset the embeddings to the original embeddings.\n                if args.train_text_encoder_ti:\n                    for idx, text_encoder in enumerate(text_encoders):\n                        embedding_handler.retract_embeddings()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n      
                          removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                logger.info(\n                    f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                # create pipeline\n                if freeze_text_encoder:\n                    text_encoder_one = text_encoder_cls_one.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"text_encoder\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                pipeline = StableDiffusionPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    tokenizer=tokenizer_one,\n                    text_encoder=accelerator.unwrap_model(text_encoder_one),\n                    unet=accelerator.unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it\n                scheduler_args = {}\n\n                if \"variance_type\" in pipeline.scheduler.config:\n                    variance_type = pipeline.scheduler.config.variance_type\n\n                    if variance_type in [\"learned\", \"learned_range\"]:\n                        variance_type = \"fixed_small\"\n\n                    scheduler_args[\"variance_type\"] = variance_type\n\n                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(\n                    pipeline.scheduler.config, **scheduler_args\n                )\n\n                pipeline = pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = (\n                    torch.Generator(device=accelerator.device).manual_seed(args.seed)\n                    if args.seed is not None\n                    else None\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n\n                if torch.backends.mps.is_available():\n                    autocast_ctx = nullcontext()\n                else:\n                    autocast_ctx = torch.autocast(accelerator.device.type)\n\n                with autocast_ctx:\n                    images = [\n                        pipeline(**pipeline_args, generator=generator).images[0]\n                        for _ in range(args.num_validation_images)\n                    ]\n\n                for tracker in accelerator.trackers:\n                    if tracker.name == \"tensorboard\":\n                        np_images = np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                
\"validation\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                                ]\n                            }\n                        )\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet = unet.to(torch.float32)\n        unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        if args.train_text_encoder:\n            text_encoder_one = accelerator.unwrap_model(text_encoder_one)\n            text_encoder_lora_layers = convert_state_dict_to_diffusers(\n                get_peft_model_state_dict(text_encoder_one.to(torch.float32))\n            )\n        else:\n            text_encoder_lora_layers = None\n\n        StableDiffusionPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_layers,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n        )\n\n        if args.train_text_encoder_ti:\n            embeddings_path = f\"{args.output_dir}/{args.output_dir}_emb.safetensors\"\n            embedding_handler.save_embeddings(embeddings_path)\n\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            # Final inference\n            # Load previous pipeline\n            vae = AutoencoderKL.from_pretrained(\n                vae_path,\n                subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n            pipeline = StableDiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n             
   vae=vae,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n\n            # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it\n            scheduler_args = {}\n\n            if \"variance_type\" in pipeline.scheduler.config:\n                variance_type = pipeline.scheduler.config.variance_type\n\n                if variance_type in [\"learned\", \"learned_range\"]:\n                    variance_type = \"fixed_small\"\n\n                scheduler_args[\"variance_type\"] = variance_type\n\n            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # load new tokens\n            if args.train_text_encoder_ti:\n                state_dict = load_file(embeddings_path)\n                all_new_tokens = []\n                for key, value in token_abstraction_dict.items():\n                    all_new_tokens.extend(value)\n                pipeline.load_textual_inversion(\n                    state_dict[\"clip_l\"],\n                    token=all_new_tokens,\n                    text_encoder=pipeline.text_encoder,\n                    tokenizer=pipeline.tokenizer,\n                )\n            # run inference\n            pipeline = pipeline.to(accelerator.device)\n            generator = (\n                torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n            )\n            images = [\n                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n                for _ in range(args.num_validation_images)\n            ]\n\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n          
          np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    tracker.log(\n                        {\n                            \"test\": [\n                                wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n                            ]\n                        }\n                    )\n\n        # Convert to WebUI format\n        lora_state_dict = load_file(f\"{args.output_dir}/pytorch_lora_weights.safetensors\")\n        peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)\n        kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)\n        save_file(kohya_state_dict, f\"{args.output_dir}/{Path(args.output_dir).name}.safetensors\")\n\n        save_model_card(\n            model_id if not args.push_to_hub else repo_id,\n            use_dora=args.use_dora,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            train_text_encoder=args.train_text_encoder,\n            train_text_encoder_ti=args.train_text_encoder_ti,\n            token_abstraction_dict=train_dataset.token_abstraction_dict,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=args.validation_prompt,\n            repo_folder=args.output_dir,\n            vae_path=args.pretrained_vae_model_name_or_path,\n        )\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n# ]\n# ///\n\nimport argparse\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport re\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import List, Optional\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, hf_hub_download, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom safetensors.torch import load_file, save_file\nfrom torch.utils.data import Dataset\nfrom torchvision import 
transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DPMSolverMultistepScheduler,\n    EDMEulerScheduler,\n    EulerDiscreteScheduler,\n    StableDiffusionXLPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr\nfrom diffusers.utils import (\n    check_min_version,\n    convert_all_state_dict_to_peft,\n    convert_state_dict_to_diffusers,\n    convert_state_dict_to_kohya,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef determine_scheduler_type(pretrained_model_name_or_path, revision):\n    model_index_filename = \"model_index.json\"\n    if os.path.isdir(pretrained_model_name_or_path):\n        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)\n    else:\n        model_index = hf_hub_download(\n            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision\n        )\n\n    with open(model_index, \"r\") as f:\n        scheduler_type = json.load(f)[\"scheduler\"][1]\n    return scheduler_type\n\n\ndef save_model_card(\n    repo_id: str,\n    use_dora: bool,\n    images: list = None,\n    base_model: str = None,\n    train_text_encoder=False,\n    train_text_encoder_ti=False,\n    token_abstraction_dict=None,\n    instance_prompt: str = None,\n    validation_prompt: str = None,\n    repo_folder=None,\n    vae_path=None,\n):\n    lora = \"lora\" if not use_dora else \"dora\"\n\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n    else:\n        widget_dict.append({\"text\": instance_prompt})\n    embeddings_filename = f\"{repo_folder}_emb\"\n    instance_prompt_webui = re.sub(r\"<s\\d+>\", \"\", re.sub(r\"<s\\d+>\", embeddings_filename, instance_prompt, count=1))\n    ti_keys = \", \".join(f'\"{match}\"' for match in re.findall(r\"<s\\d+>\", instance_prompt))\n    if instance_prompt_webui != embeddings_filename:\n        instance_prompt_sentence = f\"For example, `{instance_prompt_webui}`\"\n    else:\n        instance_prompt_sentence = \"\"\n    trigger_str = f\"You should use {instance_prompt} to trigger the image generation.\"\n    
diffusers_imports_pivotal = \"\"\n    diffusers_example_pivotal = \"\"\n    webui_example_pivotal = \"\"\n    license = \"\"\n    if \"playground\" in base_model:\n        license = \"\"\"\\n\n    ## License\n\n    Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).\n    \"\"\"\n\n    if train_text_encoder_ti:\n        trigger_str = (\n            \"To trigger image generation of trained concept (or concepts) replace each concept identifier \"\n            \"in your prompt with the newly inserted tokens:\\n\"\n        )\n        diffusers_imports_pivotal = \"\"\"from huggingface_hub import hf_hub_download\nfrom safetensors.torch import load_file\n        \"\"\"\n        diffusers_example_pivotal = f\"\"\"embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type=\"model\")\nstate_dict = load_file(embedding_path)\npipeline.load_textual_inversion(state_dict[\"clip_l\"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)\npipeline.load_textual_inversion(state_dict[\"clip_g\"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)\n        \"\"\"\n        webui_example_pivotal = f\"\"\"- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.\n    - Place it in your `embeddings` folder\n    - Use it by adding `{embeddings_filename}` to your prompt. 
{instance_prompt_sentence}\n    (you need both the LoRA and the embeddings as they were trained together for this LoRA)\n    \"\"\"\n        if token_abstraction_dict:\n            for key, value in token_abstraction_dict.items():\n                tokens = \"\".join(value)\n                trigger_str += f\"\"\"\nto trigger concept `{key}` → use `{tokens}` in your prompt \\n\n\"\"\"\n\n    model_description = f\"\"\"\n# SDXL LoRA DreamBooth - {repo_id}\n\n<Gallery />\n\n## Model description\n\n### These are {repo_id} LoRA adaption weights for {base_model}.\n\n## Download model\n\n### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke\n\n- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**.\n    - Place it on your `models/Lora` folder.\n    - On AUTOMATIC1111, load the LoRA by adding `<lora:{repo_folder}:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).\n{webui_example_pivotal}\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n{diffusers_imports_pivotal}\npipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\n{diffusers_example_pivotal}\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## Trigger words\n\n{trigger_str}\n\n## Details\nAll [Files & versions](/{repo_id}/tree/main).\n\nThe weights were trained using [🧨 diffusers Advanced Dreambooth Training 
Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py).\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\nPivotal tuning was enabled: {train_text_encoder_ti}.\n\nSpecial VAE used for training: {vae_path}.\n\n{license}\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"openrail++\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        lora,\n        \"template:sd-lora\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n\n    # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if not args.do_edm_style_training:\n        if \"variance_type\" in pipeline.scheduler.config:\n            variance_type = pipeline.scheduler.config.variance_type\n\n            if variance_type in [\"learned\", \"learned_range\"]:\n                variance_type = \"fixed_small\"\n\n            scheduler_args[\"variance_type\"] = variance_type\n\n        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better\n    # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051\n    if torch.backends.mps.is_available() or \"playground\" in args.pretrained_model_name_or_path:\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n    
            }\n            )\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.To load the custom captions, the training set directory needs to follow the structure of a \"\n            \"datasets ImageFolder, containing both the images and the corresponding caption for each image. see: \"\n            \"https://huggingface.co/docs/datasets/image_dataset for more information\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset. In some cases, a dataset may have more than one configuration (for example \"\n        \"if it contains different subsets of data within, and you only wish to load a specific subset - in that case specify the desired configuration using --dataset_config_name. Leave as \"\n        \"None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=\"A path to local folder containing the training data of instance images. Specify this arg instead of \"\n        \"--dataset_name if you wish to train using a local folder without custom captions. 
If you wish to train with custom captions please specify \"\n        \"--dataset_name instead.\",\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--token_abstraction\",\n        type=str,\n        default=\"TOK\",\n        help=\"identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, \"\n        \"captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. \"\n        \"'TOK,TOK2,TOK3' etc.\",\n    )\n\n    parser.add_argument(\n        \"--num_new_tokens_per_abstraction\",\n        type=int,\n        default=2,\n        help=\"number of new tokens inserted to the tokenizers per token_abstraction identifier when \"\n        \"--train_text_encoder_ti = True. 
By default, each --token_abstraction (e.g. TOK) is mapped to 2 new \"\n        \"tokens - <si><si+1> \",\n    )\n\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--do_edm_style_training\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to conduct training using the EDM formulation as introduced in https://huggingface.co/papers/2206.00364.\",\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lora-dreambooth-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--clip_skip\",\n        type=int,\n        default=None,\n        help=\"Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that \"\n        \"the output of the pre-final layer will be used for computing the prompt embeddings.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. 
Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--train_text_encoder_ti\",\n        action=\"store_true\",\n        help=(\"Whether to use textual inversion\"),\n    )\n\n    parser.add_argument(\n        \"--train_text_encoder_ti_frac\",\n        type=float,\n        default=0.5,\n        help=(\"The percentage of epochs to perform textual inversion\"),\n    )\n\n    parser.add_argument(\n        \"--train_text_encoder_frac\",\n        type=float,\n        default=1.0,\n        help=(\"The percentage of epochs to perform text encoder tuning\"),\n    )\n\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. 
Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=None, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
\"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  
Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--use_dora\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to train a DoRA as proposed in DoRA: Weight-Decomposed Low-Rank Adaptation https://huggingface.co/papers/2402.09353. \"\n            \"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`\"\n        ),\n    )\n    parser.add_argument(\n        \"--lora_unet_blocks\",\n        type=str,\n        default=None,\n        help=(\n            \"The U-Net blocks to tune during training. Please specify them in a comma-separated string, e.g. 
`unet.up_blocks.0.attentions.0,unet.up_blocks.0.attentions.1` etc. \"\n            \"NOTE: By default (if not specified), regular LoRA training is performed. \"\n            \"If `--use_blora` is enabled, this arg will be ignored, since in B-LoRA training the targeted U-Net blocks are `unet.up_blocks.0.attentions.0` and `unet.up_blocks.0.attentions.1`\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_blora\",\n        action=\"store_true\",\n        help=(\n            \"Whether to train a B-LoRA as proposed in Implicit Style-Content Separation using B-LoRA https://huggingface.co/papers/2403.14572. \"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    if args.train_text_encoder and args.train_text_encoder_ti:\n        raise ValueError(\n            \"Specify only one of `--train_text_encoder` or `--train_text_encoder_ti`. 
\"\n            \"For full LoRA text encoder training check --train_text_encoder, for textual \"\n            \"inversion training check `--train_text_encoder_ti`\"\n        )\n    if args.use_blora and args.lora_unet_blocks:\n        warnings.warn(\n            \"You specified both `--use_blora` and `--lora_unet_blocks`, for B-LoRA training, target unet blocks are: `unet.up_blocks.0.attentions.0` and `unet.up_blocks.0.attentions.1`. \"\n            \"If you wish to target different U-net blocks, don't enable `--use_blora`\"\n        )\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\n# Taken (and slightly modified) from B-LoRA repo https://github.com/yardenfren1996/B-LoRA/blob/main/blora_utils.py\ndef is_belong_to_blocks(key, blocks):\n    try:\n        for g in blocks:\n            if g in key:\n                return True\n        return False\n    except Exception as e:\n        raise type(e)(f\"failed to is_belong_to_block, due to: {e}\")\n\n\ndef get_unet_lora_target_modules(unet, use_blora, target_blocks=None):\n    if use_blora:\n        content_b_lora_blocks = \"unet.up_blocks.0.attentions.0\"\n        style_b_lora_blocks = \"unet.up_blocks.0.attentions.1\"\n        target_blocks = 
[content_b_lora_blocks, style_b_lora_blocks]\n    try:\n        blocks = [(\".\").join(blk.split(\".\")[1:]) for blk in target_blocks]\n\n        attns = [\n            attn_processor_name.rsplit(\".\", 1)[0]\n            for attn_processor_name, _ in unet.attn_processors.items()\n            if is_belong_to_blocks(attn_processor_name, blocks)\n        ]\n\n        target_modules = [f\"{attn}.{mat}\" for mat in [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"] for attn in attns]\n        return target_modules\n    except Exception as e:\n        raise type(e)(\n            f\"failed to get_target_modules, due to: {e}. \"\n            f\"Please check the modules specified in --lora_unet_blocks are correct\"\n        )\n\n\n# Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py\nclass TokenEmbeddingsHandler:\n    def __init__(self, text_encoders, tokenizers):\n        self.text_encoders = text_encoders\n        self.tokenizers = tokenizers\n\n        self.train_ids: Optional[torch.Tensor] = None\n        self.inserting_toks: Optional[List[str]] = None\n        self.embeddings_settings = {}\n\n    def initialize_new_tokens(self, inserting_toks: List[str]):\n        idx = 0\n        for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):\n            assert isinstance(inserting_toks, list), \"inserting_toks should be a list of strings.\"\n            assert all(isinstance(tok, str) for tok in inserting_toks), (\n                \"All elements in inserting_toks should be strings.\"\n            )\n\n            self.inserting_toks = inserting_toks\n            special_tokens_dict = {\"additional_special_tokens\": self.inserting_toks}\n            tokenizer.add_special_tokens(special_tokens_dict)\n            text_encoder.resize_token_embeddings(len(tokenizer))\n\n            self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)\n\n            # random initialization of new tokens\n            std_token_embedding = 
text_encoder.text_model.embeddings.token_embedding.weight.data.std()\n\n            print(f\"{idx} text encoder's std_token_embedding: {std_token_embedding}\")\n\n            text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = (\n                torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size)\n                .to(device=self.device)\n                .to(dtype=self.dtype)\n                * std_token_embedding\n            )\n            self.embeddings_settings[f\"original_embeddings_{idx}\"] = (\n                text_encoder.text_model.embeddings.token_embedding.weight.data.clone()\n            )\n            self.embeddings_settings[f\"std_token_embedding_{idx}\"] = std_token_embedding\n\n            inu = torch.ones((len(tokenizer),), dtype=torch.bool)\n            inu[self.train_ids] = False\n\n            self.embeddings_settings[f\"index_no_updates_{idx}\"] = inu\n\n            print(self.embeddings_settings[f\"index_no_updates_{idx}\"].shape)\n\n            idx += 1\n\n    def save_embeddings(self, file_path: str):\n        assert self.train_ids is not None, \"Initialize new tokens before saving embeddings.\"\n        tensors = {}\n        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 -  CLIP ViT-G/14\n        idx_to_text_encoder_name = {0: \"clip_l\", 1: \"clip_g\"}\n        for idx, text_encoder in enumerate(self.text_encoders):\n            assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(\n                self.tokenizers[0]\n            ), \"Tokenizers should be the same.\"\n            new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]\n\n            # New tokens for each text encoder are saved under \"clip_l\" (for text_encoder 0), \"clip_g\" (for\n            # text_encoder 1) to keep compatible with the ecosystem.\n            # Note: When loading with diffusers, any name can work - simply specify in 
inference\n            tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings\n            # tensors[f\"text_encoders_{idx}\"] = new_token_embeddings\n\n        save_file(tensors, file_path)\n\n    @property\n    def dtype(self):\n        return self.text_encoders[0].dtype\n\n    @property\n    def device(self):\n        return self.text_encoders[0].device\n\n    @torch.no_grad()\n    def retract_embeddings(self):\n        for idx, text_encoder in enumerate(self.text_encoders):\n            index_no_updates = self.embeddings_settings[f\"index_no_updates_{idx}\"]\n            text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = (\n                self.embeddings_settings[f\"original_embeddings_{idx}\"][index_no_updates]\n                .to(device=text_encoder.device)\n                .to(dtype=text_encoder.dtype)\n            )\n\n            # for the parts that were updated, we need to normalize them\n            # to have the same std as before\n            std_token_embedding = self.embeddings_settings[f\"std_token_embedding_{idx}\"]\n\n            index_updates = ~index_no_updates\n            new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates]\n            off_ratio = std_token_embedding / new_embeddings.std()\n\n            new_embeddings = new_embeddings * (off_ratio**0.1)\n            text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        train_text_encoder_ti,\n        class_data_root=None,\n        class_num=None,\n        token_abstraction_dict=None,  # token mapping for textual inversion\n        size=1024,\n        
repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n        self.token_abstraction_dict = token_abstraction_dict\n        self.train_text_encoder_ti = train_text_encoder_ti\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. 
Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        # image processing to prepare for using SD-XL micro-conditioning\n        self.original_sizes = []\n        self.crop_top_lefts = []\n        self.pixel_values = []\n        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n        if interpolation is None:\n            raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode}.\")\n        train_resize = transforms.Resize(size, interpolation=interpolation)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        # if using B-LoRA for single image. 
do not use transformations\n        single_image = len(self.instance_images) < 2\n        for image in self.instance_images:\n            if not single_image:\n                image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            self.original_sizes.append((image.height, image.width))\n            image = train_resize(image)\n\n            if not single_image and args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop or single_image:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            crop_top_left = (y1, x1)\n            self.crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n\n            self.original_sizes_class_imgs = []\n            self.crop_top_lefts_class_imgs = []\n            self.pixel_values_class_imgs = []\n            self.class_images = [Image.open(path) for path in self.class_images_path]\n            for image in self.class_images:\n                image = exif_transpose(image)\n                if not image.mode == \"RGB\":\n                    image = image.convert(\"RGB\")\n                
self.original_sizes_class_imgs.append((image.height, image.width))\n                image = train_resize(image)\n                if args.random_flip and random.random() < 0.5:\n                    # flip\n                    image = train_flip(image)\n                if args.center_crop:\n                    y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                    x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                    image = train_crop(image)\n                else:\n                    y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                    image = crop(image, y1, x1, h, w)\n                crop_top_left = (y1, x1)\n                self.crop_top_lefts_class_imgs.append(crop_top_left)\n                image = train_transforms(image)\n                self.pixel_values_class_imgs.append(image)\n\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=interpolation),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"instance_images\"] = self.pixel_values[index % self.num_instance_images]\n        example[\"original_size\"] = self.original_sizes[index % self.num_instance_images]\n        example[\"crop_top_left\"] = self.crop_top_lefts[index % 
self.num_instance_images]\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                if self.train_text_encoder_ti:\n                    # replace instances of --token_abstraction in caption with the new tokens: \"<si><si+1>\" etc.\n                    for token_abs, token_replacement in self.token_abstraction_dict.items():\n                        caption = caption.replace(token_abs, \"\".join(token_replacement))\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so fall back to the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            example[\"class_prompt\"] = self.class_prompt\n            example[\"class_images\"] = self.pixel_values_class_imgs[index % self.num_class_images]\n            example[\"class_original_size\"] = self.original_sizes_class_imgs[index % self.num_class_images]\n            example[\"class_crop_top_left\"] = self.crop_top_lefts_class_imgs[index % self.num_class_images]\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n    original_sizes = [example[\"original_size\"] for example in examples]\n    crop_top_lefts = [example[\"crop_top_left\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n        original_sizes += 
[example[\"class_original_size\"] for example in examples]\n        crop_top_lefts += [example[\"class_crop_top_left\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\n        \"pixel_values\": pixel_values,\n        \"prompts\": prompts,\n        \"original_sizes\": original_sizes,\n        \"crop_top_lefts\": crop_top_lefts,\n    }\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt, add_special_tokens=False):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n        add_special_tokens=add_special_tokens,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):\n    prompt_embeds_list = []\n\n    for i, text_encoder in enumerate(text_encoders):\n        if tokenizers is not None:\n            tokenizer = tokenizers[i]\n            text_input_ids = tokenize_prompt(tokenizer, prompt)\n        else:\n            assert text_input_ids_list is not None\n            text_input_ids = text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False\n        
)\n\n        # We are only ever interested in the pooled output of the final text encoder\n        pooled_prompt_embeds = prompt_embeds[0]\n        if clip_skip is None:\n            prompt_embeds = prompt_embeds[-1][-2]\n        else:\n            # \"2\" because SDXL always indexes from the penultimate layer.\n            prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if args.do_edm_style_training and args.snr_gamma is not None:\n        raise ValueError(\"Min-SNR formulation is not supported when conducting EDM-style training.\")\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = StableDiffusionXLPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if 
accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        model_id = args.hub_model_id or Path(args.output_dir).name\n        repo_id = None\n        if args.push_to_hub:\n            repo_id = create_repo(repo_id=model_id, exist_ok=True, token=args.hub_token).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        variant=args.variant,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        variant=args.variant,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)\n    if \"EDM\" in scheduler_type:\n        args.do_edm_style_training = True\n        noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n        logger.info(\"Performing EDM-style training!\")\n    elif args.do_edm_style_training:\n        noise_scheduler = EulerDiscreteScheduler.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n        )\n        logger.info(\"Performing EDM-style training!\")\n    else:\n        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    text_encoder_one = 
text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    latents_mean = latents_std = None\n    if hasattr(vae.config, \"latents_mean\") and vae.config.latents_mean is not None:\n        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)\n    if hasattr(vae.config, \"latents_std\") and vae.config.latents_std is not None:\n        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    if args.train_text_encoder_ti:\n        # we parse the provided token identifier (or identifiers) into a list. s.t. 
- \"TOK\" -> [\"TOK\"], \"TOK,\n        # TOK2\" -> [\"TOK\", \"TOK2\"] etc.\n        token_abstraction_list = \"\".join(args.token_abstraction.split()).split(\",\")\n        logger.info(f\"list of token identifiers: {token_abstraction_list}\")\n\n        token_abstraction_dict = {}\n        token_idx = 0\n        for i, token in enumerate(token_abstraction_list):\n            token_abstraction_dict[token] = [\n                f\"<s{token_idx + i + j}>\" for j in range(args.num_new_tokens_per_abstraction)\n            ]\n            token_idx += args.num_new_tokens_per_abstraction - 1\n\n        # replace instances of --token_abstraction in --instance_prompt with the new tokens: \"<si><si+1>\" etc.\n        for token_abs, token_replacement in token_abstraction_dict.items():\n            args.instance_prompt = args.instance_prompt.replace(token_abs, \"\".join(token_replacement))\n            if args.with_prior_preservation:\n                args.class_prompt = args.class_prompt.replace(token_abs, \"\".join(token_replacement))\n            if args.validation_prompt:\n                args.validation_prompt = args.validation_prompt.replace(token_abs, \"\".join(token_replacement))\n                print(\"validation prompt:\", args.validation_prompt)\n        # initialize the new tokens for textual inversion\n        embedding_handler = TokenEmbeddingsHandler(\n            [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]\n        )\n        inserting_toks = []\n        for new_tok in token_abstraction_dict.values():\n            inserting_toks.extend(new_tok)\n        embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks)\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and 
non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    # The VAE is always in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, \"\n                    \"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n            text_encoder_two.gradient_checkpointing_enable()\n\n    # now we will add new LoRA weights to the attention layers\n\n    if args.use_blora:\n        # if using B-LoRA, the targeted blocks to train are automatically set\n        target_modules = get_unet_lora_target_modules(unet, use_blora=True)\n    elif args.lora_unet_blocks:\n        # if training specific unet blocks not in the B-LoRA scheme\n        target_blocks_list = \"\".join(args.lora_unet_blocks.split()).split(\",\")\n        logger.info(f\"list of unet blocks to train: {target_blocks_list}\")\n        target_modules = get_unet_lora_target_modules(unet, use_blora=False, target_blocks=target_blocks_list)\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        use_dora=args.use_dora,\n        lora_alpha=args.rank,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    unet.add_adapter(unet_lora_config)\n\n    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.\n    # So, instead, we monkey-patch the forward calls of its attention-blocks.\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            use_dora=args.use_dora,\n            lora_alpha=args.rank,\n            lora_dropout=args.lora_dropout,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder_one.add_adapter(text_lora_config)\n        text_encoder_two.add_adapter(text_lora_config)\n\n    # if we use textual inversion, we freeze all parameters except for the token 
embeddings\n    # in text encoder\n    elif args.train_text_encoder_ti:\n        text_lora_parameters_one = []\n        for name, param in text_encoder_one.named_parameters():\n            if \"token_embedding\" in name:\n                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16\n                param.data = param.to(dtype=torch.float32)\n                param.requires_grad = True\n                text_lora_parameters_one.append(param)\n            else:\n                param.requires_grad = False\n        text_lora_parameters_two = []\n        for name, param in text_encoder_two.named_parameters():\n            if \"token_embedding\" in name:\n                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16\n                param.data = param.to(dtype=torch.float32)\n                param.requires_grad = True\n                text_lora_parameters_two.append(param)\n            else:\n                param.requires_grad = False\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here. 
Either it is just the unet attn processor layers\n            # or it is the unet and text encoder attn layers\n            unet_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            text_encoder_two_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                    if args.train_text_encoder:\n                        text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(\n                            get_peft_model_state_dict(model)\n                        )\n                elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                    if args.train_text_encoder:\n                        text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(\n                            get_peft_model_state_dict(model)\n                        )\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionXLPipeline.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,\n            )\n        if args.train_text_encoder_ti:\n            embedding_handler.save_embeddings(f\"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors\")\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n        text_encoder_one_ = None\n        text_encoder_two_ = None\n\n        while len(models) > 
0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(unet))):\n                unet_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                text_encoder_one_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                text_encoder_two_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        if args.train_text_encoder:\n            # Do we need to call `scale_lora_layers()` here?\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n            _set_state_dict_into_text_encoder(\n                lora_state_dict, prefix=\"text_encoder_2.\", text_encoder=text_encoder_two_\n            )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [unet_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_, text_encoder_two_])\n                # only upcast trainable parameters (LoRA) into fp32\n                cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [unet]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one, text_encoder_two])\n\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))\n\n    if args.train_text_encoder:\n        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))\n\n    # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training\n    freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)\n\n    # Optimization parameters\n    unet_lora_parameters_with_lr = {\"params\": unet_lora_parameters, \"lr\": args.learning_rate}\n  
  if not freeze_text_encoder:\n        # different learning rate for text encoder and unet\n        text_lora_parameters_one_with_lr = {\n            \"params\": text_lora_parameters_one,\n            \"weight_decay\": args.adam_weight_decay_text_encoder\n            if args.adam_weight_decay_text_encoder\n            else args.adam_weight_decay,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        text_lora_parameters_two_with_lr = {\n            \"params\": text_lora_parameters_two,\n            \"weight_decay\": args.adam_weight_decay_text_encoder\n            if args.adam_weight_decay_text_encoder\n            else args.adam_weight_decay,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        params_to_optimize = [\n            unet_lora_parameters_with_lr,\n            text_lora_parameters_one_with_lr,\n            text_lora_parameters_two_with_lr,\n        ]\n    else:\n        params_to_optimize = [unet_lora_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy].\"\n            \"Defaulting to adamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. 
\"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n            params_to_optimize[2][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        train_text_encoder_ti=args.train_text_encoder_ti,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Computes additional embeddings/ids required by the SDXL UNet.\n    # regular text embeddings (when `train_text_encoder` is not True)\n    # pooled text embeddings\n    # time ids\n\n    def compute_time_ids(crops_coords_top_left, original_size=None):\n        # 
Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        target_size = (args.resolution, args.resolution)\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n        return add_time_ids\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two]\n        text_encoders = [text_encoder_one, text_encoder_two]\n\n        def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip)\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. 
the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if freeze_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers, args.clip_skip\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if freeze_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers, args.clip_skip\n            )\n\n    # Clear the memory here\n    if freeze_text_encoder and not train_dataset.custom_instance_prompts:\n        del tokenizers, text_encoders\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n\n    # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion\n    add_special_tokens = True if args.train_text_encoder_ti else False\n\n    if not train_dataset.custom_instance_prompts:\n        if freeze_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            unet_add_text_embeds = instance_pooled_prompt_embeds\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the\n        # batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens)\n            tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, add_special_tokens)\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, add_special_tokens)\n                class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, add_special_tokens)\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)\n\n    if args.cache_latents:\n        latents_cache = []\n        # Store vae config before potential deletion\n        vae_scaling_factor = vae.config.scaling_factor\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=torch.float32\n                )\n                
latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n\n        if args.validation_prompt is None:\n            del vae\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n    else:\n        vae_scaling_factor = vae.config.scaling_factor\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if not freeze_text_encoder:\n        unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = 
math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = (\n            \"dreambooth-lora-sd-xl\"\n            if \"playground\" not in args.pretrained_model_name_or_path\n            else \"dreambooth-lora-playground\"\n        )\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the mos recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n\n        
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    if args.train_text_encoder:\n        num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)\n    elif args.train_text_encoder_ti:\n        num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)\n    # flag used for textual inversion\n    pivoted = False\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        # if performing any kind of optimization of text_encoder params\n        if args.train_text_encoder or args.train_text_encoder_ti:\n            if epoch == num_train_epochs_text_encoder:\n                print(\"PIVOT HALFWAY\", epoch)\n                # stopping optimization of text_encoder params\n                # this flag is used to reset the optimizer to optimize only on unet params\n                pivoted = True\n\n            else:\n                # still optimizing the text encoder\n                text_encoder_one.train()\n                text_encoder_two.train()\n                # set requires_grad = True on the top-level embeddings so that gradient checkpointing works\n                if args.train_text_encoder:\n                    accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)\n                    accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)\n\n        for step, batch in enumerate(train_dataloader):\n            if pivoted:\n                # stopping optimization of text_encoder params\n                # resetting the optimizer to optimize only on unet params\n                optimizer.param_groups[1][\"lr\"] = 0.0\n                optimizer.param_groups[2][\"lr\"] = 0.0\n\n            with 
accelerator.accumulate(unet):\n                prompts = batch[\"prompts\"]\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    if freeze_text_encoder:\n                        prompt_embeds, unet_add_text_embeds = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers, args.clip_skip\n                        )\n\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens)\n                        tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens)\n\n                if args.cache_latents:\n                    model_input = latents_cache[step].sample()\n                else:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n\n                if latents_mean is None and latents_std is None:\n                    model_input = model_input * vae_scaling_factor\n                    if args.pretrained_vae_model_name_or_path is None:\n                        model_input = model_input.to(weight_dtype)\n                else:\n                    latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)\n                    latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)\n                    model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std\n                    model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device\n                    )\n\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                if not args.do_edm_style_training:\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                    )\n                    timesteps = timesteps.long()\n                else:\n                    # in EDM formulation, the model is conditioned on the pre-conditioned noise levels\n                    # instead of discrete timesteps, so here we sample indices to get the noise levels\n                    # from `scheduler.timesteps`\n                    indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))\n                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n                # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.\n                # We then precondition the final model inputs based on these sigmas instead of the timesteps.\n                # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                if args.do_edm_style_training:\n                    sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)\n                    if \"EDM\" in scheduler_type:\n                        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)\n                    else:\n                        inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)\n\n                # time ids\n                
add_time_ids = torch.cat(\n                    [\n                        compute_time_ids(original_size=s, crops_coords_top_left=c)\n                        for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])\n                    ]\n                )\n\n                # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.\n                if not train_dataset.custom_instance_prompts:\n                    elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz\n\n                else:\n                    elems_to_repeat_text_embeds = 1\n\n                # Predict the noise residual\n                if freeze_text_encoder:\n                    unet_added_conditions = {\n                        \"time_ids\": add_time_ids,\n                        \"text_embeds\": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),\n                    }\n                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)\n                    model_pred = unet(\n                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,\n                        timesteps,\n                        prompt_embeds_input,\n                        added_cond_kwargs=unet_added_conditions,\n                        return_dict=False,\n                    )[0]\n                else:\n                    unet_added_conditions = {\"time_ids\": add_time_ids}\n                    prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                        text_encoders=[text_encoder_one, text_encoder_two],\n                        tokenizers=None,\n                        prompt=None,\n                        text_input_ids_list=[tokens_one, tokens_two],\n                        clip_skip=args.clip_skip,\n                    )\n                    unet_added_conditions.update(\n                        {\"text_embeds\": 
pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}\n                    )\n                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)\n                    model_pred = unet(\n                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,\n                        timesteps,\n                        prompt_embeds_input,\n                        added_cond_kwargs=unet_added_conditions,\n                        return_dict=False,\n                    )[0]\n\n                weighting = None\n                if args.do_edm_style_training:\n                    # Similar to the input preconditioning, the model predictions are also preconditioned\n                    # on noised model inputs (before preconditioning) and the sigmas.\n                    # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                    if \"EDM\" in scheduler_type:\n                        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)\n                    else:\n                        if noise_scheduler.config.prediction_type == \"epsilon\":\n                            model_pred = model_pred * (-sigmas) + noisy_model_input\n                        elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                            model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (\n                                noisy_model_input / (sigmas**2 + 1)\n                            )\n                    # We are not doing weighting here because it tends to result in numerical problems.\n                    # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051\n                    # There might be other alternatives for weighting as well:\n                    # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686\n                    if \"EDM\" not in scheduler_type:\n                
        weighting = (sigmas**-2.0).float()\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = model_input if args.do_edm_style_training else noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = (\n                        model_input\n                        if args.do_edm_style_training\n                        else noise_scheduler.get_velocity(model_input, noise, timesteps)\n                    )\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    if weighting is not None:\n                        prior_loss = torch.mean(\n                            (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                                target_prior.shape[0], -1\n                            ),\n                            1,\n                        )\n                        prior_loss = prior_loss.mean()\n                    else:\n                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                if args.snr_gamma is None:\n                    if weighting is not None:\n                        loss = torch.mean(\n                            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(\n                                target.shape[0], -1\n                            ),\n           
                 1,\n                        )\n                        loss = loss.mean()\n                    else:\n                        loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n\n                    if args.with_prior_preservation:\n                        # if we're using prior preservation, we calc snr for instance loss only -\n                        # and hence only need timesteps corresponding to instance images\n                        snr_timesteps, _ = torch.chunk(timesteps, 2, dim=0)\n                    else:\n                        snr_timesteps = timesteps\n\n                    snr = compute_snr(noise_scheduler, snr_timesteps)\n                    base_weight = (\n                        torch.stack([snr, args.snr_gamma * torch.ones_like(snr_timesteps)], dim=1).min(dim=1)[0] / snr\n                    )\n\n                    if noise_scheduler.config.prediction_type == \"v_prediction\":\n                        # Velocity objective needs to be floored to an SNR weight of one.\n                        mse_loss_weights = base_weight + 1\n                    else:\n                        # Epsilon and sample both use the same loss weights.\n                        mse_loss_weights = base_weight\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * 
prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)\n                        if (args.train_text_encoder or args.train_text_encoder_ti)\n                        else unet_lora_parameters\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # every step, we reset the embeddings to the original embeddings.\n                if args.train_text_encoder_ti:\n                    embedding_handler.retract_embeddings()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = 
checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if freeze_text_encoder:\n                    text_encoder_one = text_encoder_cls_one.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"text_encoder\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                    text_encoder_two = text_encoder_cls_two.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"text_encoder_2\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                pipeline = 
StableDiffusionXLPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    tokenizer=tokenizer_one,\n                    tokenizer_2=tokenizer_two,\n                    text_encoder=accelerator.unwrap_model(text_encoder_one),\n                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),\n                    unet=accelerator.unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n\n                images = log_validation(\n                    pipeline,\n                    args,\n                    accelerator,\n                    pipeline_args,\n                    epoch,\n                )\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        unet = unet.to(torch.float32)\n        unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_lora_layers = convert_state_dict_to_diffusers(\n                get_peft_model_state_dict(text_encoder_one.to(torch.float32))\n            )\n            text_encoder_two = unwrap_model(text_encoder_two)\n            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(\n                get_peft_model_state_dict(text_encoder_two.to(torch.float32))\n            )\n        else:\n            text_encoder_lora_layers = None\n            text_encoder_2_lora_layers = None\n\n        StableDiffusionXLPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_layers,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n            
text_encoder_2_lora_layers=text_encoder_2_lora_layers,\n        )\n\n        if args.train_text_encoder_ti:\n            embeddings_path = f\"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors\"\n            embedding_handler.save_embeddings(embeddings_path)\n\n        # Final inference\n        # Load previous pipeline\n        vae = AutoencoderKL.from_pretrained(\n            vae_path,\n            subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipeline = StableDiffusionXLPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=vae,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt, \"num_inference_steps\": 25}\n            images = log_validation(\n                pipeline,\n                args,\n                accelerator,\n                pipeline_args,\n                epoch,\n                is_final_validation=True,\n            )\n\n        # Convert to WebUI format\n        lora_state_dict = load_file(f\"{args.output_dir}/pytorch_lora_weights.safetensors\")\n        peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)\n        kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)\n        save_file(kohya_state_dict, f\"{args.output_dir}/{Path(args.output_dir).name}.safetensors\")\n\n        save_model_card(\n            model_id if not args.push_to_hub else repo_id,\n            use_dora=args.use_dora,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n     
       train_text_encoder=args.train_text_encoder,\n            train_text_encoder_ti=args.train_text_encoder_ti,\n            token_abstraction_dict=train_dataset.token_abstraction_dict,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=args.validation_prompt,\n            repo_folder=args.output_dir,\n            vae_path=args.pretrained_vae_model_name_or_path,\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/amused/README.md",
    "content": "## Amused training\n\nAmused can be finetuned on simple datasets relatively cheaply and quickly. Using 8bit optimizers, lora, and gradient accumulation, amused can be finetuned with as little as 5.5 GB. Here are a set of examples for finetuning amused on some relatively simple datasets. These training recipes are aggressively oriented towards minimal resources and fast verification -- i.e. the batch sizes are quite low and the learning rates are quite high. For optimal quality, you will probably want to increase the batch sizes and decrease learning rates.\n\nAll training examples use fp16 mixed precision and gradient checkpointing. We don't show 8 bit adam + lora as its about the same memory use as just using lora (bitsandbytes uses full precision optimizer states for weights below a minimum size).\n\n### Finetuning the 256 checkpoint\n\nThese examples finetune on this [nouns](https://huggingface.co/datasets/m1guelpf/nouns) dataset.\n\nExample results:\n\n![noun1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun1.png) ![noun2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun2.png) ![noun3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun3.png)\n\n\n#### Full finetuning\n\nBatch size: 8, Learning rate: 1e-4, Gives decent results in 750-1000 steps\n\n| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |\n|------------|-----------------------------|------------------|-------------|\n|    8        |          1                   |     8             |      19.7 GB       |\n|    4        |          2                   |     8             |      18.3 GB       |\n|    1        |          8                   |     8             |      17.9 GB       |\n\n```sh\naccelerate launch train_amused.py \\\n    --output_dir <output path> \\\n    --train_batch_size <batch size> \\\n    --gradient_accumulation_steps <gradient accumulation steps> 
\\\n    --learning_rate 1e-4 \\\n    --pretrained_model_name_or_path amused/amused-256 \\\n    --instance_data_dataset  'm1guelpf/nouns' \\\n    --image_key image \\\n    --prompt_key text \\\n    --resolution 256 \\\n    --mixed_precision fp16 \\\n    --lr_scheduler constant \\\n    --validation_prompts \\\n        'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \\\n        'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \\\n        'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \\\n        'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \\\n        'a pixel art character with square red glasses' \\\n        'a pixel art character' \\\n        'square red glasses on a pixel art character' \\\n        'square red glasses on a pixel art character with a baseball-shaped head' \\\n    --max_train_steps 10000 \\\n    --checkpointing_steps 500 \\\n    --validation_steps 250 \\\n    --gradient_checkpointing\n```\n\n#### Full finetuning + 8 bit adam\n\nNote that this training config keeps the batch size low and the learning rate high to get results fast with low resources. However, due to 8 bit adam, it will diverge eventually. 
If you want to train for longer, you will have to up the batch size and lower the learning rate.\n\nBatch size: 16, Learning rate: 2e-5, Gives decent results in ~750 steps\n\n| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |\n|------------|-----------------------------|------------------|-------------|\n|    16        |          1                   |     16             |      20.1 GB       |\n|    8        |          2                   |      16           |      15.6 GB       |\n|    1        |          16                   |     16            |      10.7 GB       |\n\n```sh\naccelerate launch train_amused.py \\\n    --output_dir <output path> \\\n    --train_batch_size <batch size> \\\n    --gradient_accumulation_steps <gradient accumulation steps> \\\n    --learning_rate 2e-5 \\\n    --use_8bit_adam \\\n    --pretrained_model_name_or_path amused/amused-256 \\\n    --instance_data_dataset  'm1guelpf/nouns' \\\n    --image_key image \\\n    --prompt_key text \\\n    --resolution 256 \\\n    --mixed_precision fp16 \\\n    --lr_scheduler constant \\\n    --validation_prompts \\\n        'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \\\n        'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \\\n        'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \\\n        'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \\\n        'a pixel art character with square red glasses' \\\n        'a pixel art character' \\\n        'square red glasses on a pixel art character' \\\n        'square red glasses on a pixel art character with a baseball-shaped head' \\\n    --max_train_steps 10000 \\\n    --checkpointing_steps 500 \\\n    --validation_steps 250 \\\n    
--gradient_checkpointing\n```\n\n#### Full finetuning + lora\n\nBatch size: 16, Learning rate: 8e-4, Gives decent results in 1000-1250 steps\n\n| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |\n|------------|-----------------------------|------------------|-------------|\n|    16        |          1                   |     16             |      14.1 GB       |\n|    8        |          2                   |      16           |      10.1 GB       |\n|    1        |          16                   |     16            |      6.5 GB       |\n\n```sh\naccelerate launch train_amused.py \\\n    --output_dir <output path> \\\n    --train_batch_size <batch size> \\\n    --gradient_accumulation_steps <gradient accumulation steps> \\\n    --learning_rate 8e-4 \\\n    --use_lora \\\n    --pretrained_model_name_or_path amused/amused-256 \\\n    --instance_data_dataset  'm1guelpf/nouns' \\\n    --image_key image \\\n    --prompt_key text \\\n    --resolution 256 \\\n    --mixed_precision fp16 \\\n    --lr_scheduler constant \\\n    --validation_prompts \\\n        'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \\\n        'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \\\n        'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \\\n        'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \\\n        'a pixel art character with square red glasses' \\\n        'a pixel art character' \\\n        'square red glasses on a pixel art character' \\\n        'square red glasses on a pixel art character with a baseball-shaped head' \\\n    --max_train_steps 10000 \\\n    --checkpointing_steps 500 \\\n    --validation_steps 250 \\\n    --gradient_checkpointing\n```\n\n### 
Finetuning the 512 checkpoint\n\nThese examples finetune on this [minecraft](https://huggingface.co/monadical-labs/minecraft-preview) dataset.\n\nExample results:\n\n![minecraft1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft1.png) ![minecraft2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft2.png) ![minecraft3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft3.png)\n\n#### Full finetuning\n\nBatch size: 8, Learning rate: 8e-5, Gives decent results in 500-1000 steps\n\n| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |\n|------------|-----------------------------|------------------|-------------|\n|    8        |          1                   |     8             |      24.2 GB       |\n|    4        |          2                   |     8             |      19.7 GB       |\n|    1        |          8                   |     8             |      16.99 GB       |\n\n```sh\naccelerate launch train_amused.py \\\n    --output_dir <output path> \\\n    --train_batch_size <batch size> \\\n    --gradient_accumulation_steps <gradient accumulation steps> \\\n    --learning_rate 8e-5 \\\n    --pretrained_model_name_or_path amused/amused-512 \\\n    --instance_data_dataset  'monadical-labs/minecraft-preview' \\\n    --prompt_prefix 'minecraft ' \\\n    --image_key image \\\n    --prompt_key text \\\n    --resolution 512 \\\n    --mixed_precision fp16 \\\n    --lr_scheduler constant \\\n    --validation_prompts \\\n        'minecraft Avatar' \\\n        'minecraft character' \\\n        'minecraft' \\\n        'minecraft president' \\\n        'minecraft pig' \\\n    --max_train_steps 10000 \\\n    --checkpointing_steps 500 \\\n    --validation_steps 250 \\\n    --gradient_checkpointing\n```\n\n#### Full finetuning + 8 bit adam\n\nBatch size: 8, Learning rate: 5e-6, Gives decent results in 500-1000 steps\n\n| Batch Size | Gradient 
Accumulation Steps | Effective Total Batch Size | Memory Used |\n|------------|-----------------------------|------------------|-------------|\n|    8        |          1                   |     8             |      21.2 GB       |\n|    4        |          2                   |     8             |      13.3 GB       |\n|    1        |          8                   |     8             |      9.9 GB       |\n\n```sh\naccelerate launch train_amused.py \\\n    --output_dir <output path> \\\n    --train_batch_size <batch size> \\\n    --gradient_accumulation_steps <gradient accumulation steps> \\\n    --learning_rate 5e-6 \\\n    --pretrained_model_name_or_path amused/amused-512 \\\n    --instance_data_dataset  'monadical-labs/minecraft-preview' \\\n    --prompt_prefix 'minecraft ' \\\n    --image_key image \\\n    --prompt_key text \\\n    --resolution 512 \\\n    --mixed_precision fp16 \\\n    --lr_scheduler constant \\\n    --validation_prompts \\\n        'minecraft Avatar' \\\n        'minecraft character' \\\n        'minecraft' \\\n        'minecraft president' \\\n        'minecraft pig' \\\n    --max_train_steps 10000 \\\n    --checkpointing_steps 500 \\\n    --validation_steps 250 \\\n    --gradient_checkpointing\n```\n\n#### Full finetuning + lora\n\nBatch size: 8, Learning rate: 1e-4, Gives decent results in 500-1000 steps\n\n| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used |\n|------------|-----------------------------|------------------|-------------|\n|    8        |          1                   |     8             |      12.7 GB       |\n|    4        |          2                   |     8             |      9.0 GB       |\n|    1        |          8                   |     8             |      5.6 GB       |\n\n```sh\naccelerate launch train_amused.py \\\n    --output_dir <output path> \\\n    --train_batch_size <batch size> \\\n    --gradient_accumulation_steps <gradient accumulation steps> \\\n    --learning_rate 
1e-4 \\\n    --use_lora \\\n    --pretrained_model_name_or_path amused/amused-512 \\\n    --instance_data_dataset  'monadical-labs/minecraft-preview' \\\n    --prompt_prefix 'minecraft ' \\\n    --image_key image \\\n    --prompt_key text \\\n    --resolution 512 \\\n    --mixed_precision fp16 \\\n    --lr_scheduler constant \\\n    --validation_prompts \\\n        'minecraft Avatar' \\\n        'minecraft character' \\\n        'minecraft' \\\n        'minecraft president' \\\n        'minecraft pig' \\\n    --max_train_steps 10000 \\\n    --checkpointing_steps 500 \\\n    --validation_steps 250 \\\n    --gradient_checkpointing\n```\n\n### Styledrop\n\n[Styledrop](https://huggingface.co/papers/2306.00983) is an efficient finetuning method for learning a new style from just one or very few images. It has an optional first stage that generates additional, human-picked training samples, which can be used to augment the initial images. Our examples skip this optional image selection stage and finetune on a single image.\n\nThis is our example style image:\n![example](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png)\n\nDownload it to your local directory with\n```sh\nwget https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png\n```\n\n#### 256\n\nExample results:\n\n![glowing_256_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_1.png) ![glowing_256_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_2.png) ![glowing_256_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_3.png)\n\nLearning rate: 4e-4, Gives decent results in 1500-2000 steps\n\nMemory used: 6.5 GB\n\n```sh\naccelerate launch train_amused.py \\\n    --output_dir <output path> \\\n    --mixed_precision fp16 \\\n    
--report_to wandb \\\n    --use_lora \\\n    --pretrained_model_name_or_path amused/amused-256 \\\n    --train_batch_size 1 \\\n    --lr_scheduler constant \\\n    --learning_rate 4e-4 \\\n    --validation_prompts \\\n        'A chihuahua walking on the street in [V] style' \\\n        'A banana on the table in [V] style' \\\n        'A church on the street in [V] style' \\\n        'A tabby cat walking in the forest in [V] style' \\\n    --instance_data_image 'A mushroom in [V] style.png' \\\n    --max_train_steps 10000 \\\n    --checkpointing_steps 500 \\\n    --validation_steps 100 \\\n    --resolution 256\n```\n\n#### 512\n\nExample results:\n\n![glowing_512_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_1.png) ![glowing_512_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_2.png) ![glowing_512_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_3.png)\n\nLearning rate: 1e-3, LoRA alpha: 1, Gives decent results in 1500-2000 steps\n\nMemory used: 5.6 GB\n\n```sh\naccelerate launch train_amused.py \\\n    --output_dir <output path> \\\n    --mixed_precision fp16 \\\n    --report_to wandb \\\n    --use_lora \\\n    --pretrained_model_name_or_path amused/amused-512 \\\n    --train_batch_size 1 \\\n    --lr_scheduler constant \\\n    --learning_rate 1e-3 \\\n    --validation_prompts \\\n        'A chihuahua walking on the street in [V] style' \\\n        'A banana on the table in [V] style' \\\n        'A church on the street in [V] style' \\\n        'A tabby cat walking in the forest in [V] style' \\\n    --instance_data_image 'A mushroom in [V] style.png' \\\n    --max_train_steps 100000 \\\n    --checkpointing_steps 500 \\\n    --validation_steps 100 \\\n    --resolution 512 \\\n    --lora_alpha 1\n```"
  },
  {
    "path": "examples/amused/train_amused.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport logging\nimport math\nimport os\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import DataLoader, Dataset, default_collate\nfrom torchvision import transforms\nfrom transformers import (\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n)\n\nimport diffusers.optimization\nfrom diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel\nfrom diffusers.loaders import AmusedLoraLoaderMixin\nfrom diffusers.utils import is_wandb_available\n\n\nif is_wandb_available():\n    import wandb\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        
\"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--instance_data_dataset\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A Hugging Face dataset containing the training images\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--instance_data_image\", type=str, default=None, required=False, help=\"A single training image\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\"--ema_decay\", type=float, default=0.9999)\n    parser.add_argument(\"--ema_update_after_step\", type=int, default=0)\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"muse_training\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. The number of epochs is derived from this value.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. \"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. \"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step \"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--logging_steps\",\n        type=int,\n        default=50,\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more details\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=0.0003,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running each prompt in\"\n            \" `args.validation_prompts` and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. 
Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"wandb\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`,'\n            ' `\"wandb\"` (default) and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--validation_prompts\", type=str, nargs=\"*\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\"--split_vae_encode\", type=int, required=False, default=None)\n    parser.add_argument(\"--min_masking_rate\", type=float, default=0.0)\n    parser.add_argument(\"--cond_dropout_prob\", type=float, default=0.0)\n    parser.add_argument(\"--max_grad_norm\", default=None, type=float, help=\"Max gradient norm.\", required=False)\n    parser.add_argument(\"--use_lora\", action=\"store_true\", help=\"Fine-tune the model using LoRA\")\n    parser.add_argument(\"--text_encoder_use_lora\", action=\"store_true\", help=\"Fine-tune the text encoder using LoRA\")\n    parser.add_argument(\"--lora_r\", default=16, type=int)\n    parser.add_argument(\"--lora_alpha\", default=32, type=int)\n    parser.add_argument(\"--lora_target_modules\", default=[\"to_q\", \"to_k\", \"to_v\"], type=str, nargs=\"+\")\n    parser.add_argument(\"--text_encoder_lora_r\", default=16, type=int)\n    parser.add_argument(\"--text_encoder_lora_alpha\", default=32, 
type=int)\n    parser.add_argument(\"--text_encoder_lora_target_modules\", default=[\"to_q\", \"to_k\", \"to_v\"], type=str, nargs=\"+\")\n    parser.add_argument(\"--train_text_encoder\", action=\"store_true\")\n    parser.add_argument(\"--image_key\", type=str, required=False)\n    parser.add_argument(\"--prompt_key\", type=str, required=False)\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\"--prompt_prefix\", type=str, required=False, default=None)\n\n    args = parser.parse_args()\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    num_datasources = sum(\n        [x is not None for x in [args.instance_data_dir, args.instance_data_image, args.instance_data_dataset]]\n    )\n\n    if num_datasources != 1:\n        raise ValueError(\n            \"Provide exactly one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`\"\n        )\n\n    if args.instance_data_dir is not None:\n        if not os.path.exists(args.instance_data_dir):\n            raise ValueError(f\"Does not exist: `--instance_data_dir` {args.instance_data_dir}\")\n\n    if args.instance_data_image is not None:\n        if not os.path.exists(args.instance_data_image):\n            raise ValueError(f\"Does not exist: `--instance_data_image` {args.instance_data_image}\")\n\n    if args.instance_data_dataset is not None and (args.image_key is None or args.prompt_key is None):\n        raise ValueError(\"`--instance_data_dataset` requires setting `--image_key` and `--prompt_key`\")\n\n    return args\n\n\nclass InstanceDataRootDataset(Dataset):\n    def __init__(\n        self,\n        instance_data_root,\n        
tokenizer,\n        size=512,\n    ):\n        self.size = size\n        self.tokenizer = tokenizer\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n\n    def __len__(self):\n        return len(self.instance_images_path)\n\n    def __getitem__(self, index):\n        image_path = self.instance_images_path[index % len(self.instance_images_path)]\n        instance_image = Image.open(image_path)\n        rv = process_image(instance_image, self.size)\n\n        prompt = os.path.splitext(os.path.basename(image_path))[0]\n        rv[\"prompt_input_ids\"] = tokenize_prompt(self.tokenizer, prompt)[0]\n        return rv\n\n\nclass InstanceDataImageDataset(Dataset):\n    def __init__(\n        self,\n        instance_data_image,\n        train_batch_size,\n        size=512,\n    ):\n        self.value = process_image(Image.open(instance_data_image), size)\n        self.train_batch_size = train_batch_size\n\n    def __len__(self):\n        # Needed so a full batch of the data can be returned. 
Otherwise will return\n        # batches of size 1\n        return self.train_batch_size\n\n    def __getitem__(self, index):\n        return self.value\n\n\nclass HuggingFaceDataset(Dataset):\n    def __init__(\n        self,\n        hf_dataset,\n        tokenizer,\n        image_key,\n        prompt_key,\n        prompt_prefix=None,\n        size=512,\n    ):\n        self.size = size\n        self.image_key = image_key\n        self.prompt_key = prompt_key\n        self.tokenizer = tokenizer\n        self.hf_dataset = hf_dataset\n        self.prompt_prefix = prompt_prefix\n\n    def __len__(self):\n        return len(self.hf_dataset)\n\n    def __getitem__(self, index):\n        item = self.hf_dataset[index]\n\n        rv = process_image(item[self.image_key], self.size)\n\n        prompt = item[self.prompt_key]\n\n        if self.prompt_prefix is not None:\n            prompt = self.prompt_prefix + prompt\n\n        rv[\"prompt_input_ids\"] = tokenize_prompt(self.tokenizer, prompt)[0]\n\n        return rv\n\n\ndef process_image(image, size):\n    image = exif_transpose(image)\n\n    if not image.mode == \"RGB\":\n        image = image.convert(\"RGB\")\n\n    orig_height = image.height\n    orig_width = image.width\n\n    image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)\n\n    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))\n    image = transforms.functional.crop(image, c_top, c_left, size, size)\n\n    image = transforms.ToTensor()(image)\n\n    micro_conds = torch.tensor(\n        [orig_width, orig_height, c_top, c_left, 6.0],\n    )\n\n    return {\"image\": image, \"micro_conds\": micro_conds}\n\n\ndef tokenize_prompt(tokenizer, prompt):\n    return tokenizer(\n        prompt,\n        truncation=True,\n        padding=\"max_length\",\n        max_length=77,\n        return_tensors=\"pt\",\n    ).input_ids\n\n\ndef encode_prompt(text_encoder, input_ids):\n    outputs = 
text_encoder(input_ids, return_dict=True, output_hidden_states=True)\n    encoder_hidden_states = outputs.hidden_states[-2]\n    cond_embeds = outputs[0]\n    return encoder_hidden_states, cond_embeds\n\n\ndef main(args):\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if accelerator.is_main_process:\n        os.makedirs(args.output_dir, exist_ok=True)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"amused\", config=vars(copy.deepcopy(args)))\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # TODO - will have to fix loading if training text encoder\n    text_encoder = CLIPTextModelWithProjection.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision, variant=args.variant\n    )\n    vq_model = VQModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vqvae\", revision=args.revision, variant=args.variant\n    
)\n\n    if args.train_text_encoder:\n        if args.text_encoder_use_lora:\n            lora_config = LoraConfig(\n                r=args.text_encoder_lora_r,\n                lora_alpha=args.text_encoder_lora_alpha,\n                target_modules=args.text_encoder_lora_target_modules,\n            )\n            text_encoder.add_adapter(lora_config)\n        text_encoder.train()\n        text_encoder.requires_grad_(True)\n    else:\n        text_encoder.eval()\n        text_encoder.requires_grad_(False)\n\n    vq_model.requires_grad_(False)\n\n    model = UVit2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n\n    if args.use_lora:\n        lora_config = LoraConfig(\n            r=args.lora_r,\n            lora_alpha=args.lora_alpha,\n            target_modules=args.lora_target_modules,\n        )\n        model.add_adapter(lora_config)\n\n    model.train()\n\n    if args.gradient_checkpointing:\n        model.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    if args.use_ema:\n        ema = EMAModel(\n            model.parameters(),\n            decay=args.ema_decay,\n            update_after_step=args.ema_update_after_step,\n            model_cls=UVit2DModel,\n            model_config=model.config,\n        )\n\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n            text_encoder_lora_layers_to_save = None\n\n            for model_ in models:\n                if isinstance(model_, type(accelerator.unwrap_model(model))):\n                    if args.use_lora:\n                        transformer_lora_layers_to_save = get_peft_model_state_dict(model_)\n                    else:\n                        model_.save_pretrained(os.path.join(output_dir, 
\"transformer\"))\n                elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):\n                    if args.text_encoder_use_lora:\n                        text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_)\n                    else:\n                        model_.save_pretrained(os.path.join(output_dir, \"text_encoder\"))\n                else:\n                    raise ValueError(f\"unexpected save model: {model_.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:\n                AmusedLoraLoaderMixin.save_lora_weights(\n                    output_dir,\n                    transformer_lora_layers=transformer_lora_layers_to_save,\n                    text_encoder_lora_layers=text_encoder_lora_layers_to_save,\n                )\n\n            if args.use_ema:\n                ema.save_pretrained(os.path.join(output_dir, \"ema_model\"))\n\n    def load_model_hook(models, input_dir):\n        transformer = None\n        text_encoder_ = None\n\n        while len(models) > 0:\n            model_ = models.pop()\n\n            if isinstance(model_, type(accelerator.unwrap_model(model))):\n                if args.use_lora:\n                    transformer = model_\n                else:\n                    load_model = UVit2DModel.from_pretrained(os.path.join(input_dir, \"transformer\"))\n                    model_.load_state_dict(load_model.state_dict())\n                    del load_model\n            # Check the popped `model_`, not the outer `model`, so the text encoder branch is reachable.\n            elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))):\n                if args.text_encoder_use_lora:\n                    text_encoder_ = model_\n                else:\n                    load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, \"text_encoder\"))\n                    
model_.load_state_dict(load_model.state_dict())\n                    del load_model\n            else:\n                raise ValueError(f\"unexpected load model: {model_.__class__}\")\n\n        if transformer is not None or text_encoder_ is not None:\n            lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict(input_dir)\n            AmusedLoraLoaderMixin.load_lora_into_text_encoder(\n                lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_\n            )\n            AmusedLoraLoaderMixin.load_lora_into_transformer(\n                lora_state_dict, network_alphas=network_alphas, transformer=transformer\n            )\n\n        if args.use_ema:\n            load_from = EMAModel.from_pretrained(os.path.join(input_dir, \"ema_model\"), model_cls=UVit2DModel)\n            ema.load_state_dict(load_from.state_dict())\n            del load_from\n\n    accelerator.register_load_state_pre_hook(load_model_hook)\n    accelerator.register_save_state_pre_hook(save_model_hook)\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n        )\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    # no decay on bias and layernorm and embedding\n    no_decay = [\"bias\", \"layer_norm.weight\", \"mlm_ln.weight\", \"embeddings.weight\"]\n    optimizer_grouped_parameters = [\n        {\n            \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n            \"weight_decay\": args.adam_weight_decay,\n        },\n        {\n            \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n            \"weight_decay\": 0.0,\n        },\n    ]\n\n    if args.train_text_encoder:\n        optimizer_grouped_parameters.append(\n            {\"params\": text_encoder.parameters(), \"weight_decay\": args.adam_weight_decay}\n        )\n\n    optimizer = optimizer_cls(\n        optimizer_grouped_parameters,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    logger.info(\"Creating dataloaders and lr_scheduler\")\n\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    if args.instance_data_dir is not None:\n        dataset = InstanceDataRootDataset(\n            instance_data_root=args.instance_data_dir,\n            tokenizer=tokenizer,\n            size=args.resolution,\n        )\n    elif args.instance_data_image is not None:\n        dataset = InstanceDataImageDataset(\n            instance_data_image=args.instance_data_image,\n            train_batch_size=args.train_batch_size,\n            size=args.resolution,\n        )\n    elif args.instance_data_dataset is not None:\n        dataset = HuggingFaceDataset(\n            hf_dataset=load_dataset(args.instance_data_dataset, split=\"train\"),\n            tokenizer=tokenizer,\n         
   image_key=args.image_key,\n            prompt_key=args.prompt_key,\n            prompt_prefix=args.prompt_prefix,\n            size=args.resolution,\n        )\n    else:\n        assert False\n\n    train_dataloader = DataLoader(\n        dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        num_workers=args.dataloader_num_workers,\n        collate_fn=default_collate,\n    )\n    train_dataloader.num_batches = len(train_dataloader)\n\n    lr_scheduler = diffusers.optimization.get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n    )\n\n    logger.info(\"Preparing model, optimizer and dataloaders\")\n\n    if args.train_text_encoder:\n        model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare(\n            model, optimizer, lr_scheduler, train_dataloader, text_encoder\n        )\n    else:\n        model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(\n            model, optimizer, lr_scheduler, train_dataloader\n        )\n\n    train_dataloader.num_batches = len(train_dataloader)\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if not args.train_text_encoder:\n        text_encoder.to(device=accelerator.device, dtype=weight_dtype)\n\n    vq_model.to(device=accelerator.device)\n\n    if args.use_ema:\n        ema.to(accelerator.device)\n\n    with nullcontext() if args.train_text_encoder else torch.no_grad():\n        empty_embeds, empty_clip_embeds = encode_prompt(\n            text_encoder, tokenize_prompt(tokenizer, \"\").to(text_encoder.device, non_blocking=True)\n        )\n\n        # There is a single image, we can just pre-encode the 
single prompt\n        if args.instance_data_image is not None:\n            prompt = os.path.splitext(os.path.basename(args.instance_data_image))[0]\n            encoder_hidden_states, cond_embeds = encode_prompt(\n                text_encoder, tokenize_prompt(tokenizer, prompt).to(text_encoder.device, non_blocking=True)\n            )\n            encoder_hidden_states = encoder_hidden_states.repeat(args.train_batch_size, 1, 1)\n            cond_embeds = cond_embeds.repeat(args.train_batch_size, 1)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    # Afterwards we recalculate our number of training epochs.\n    # Note: We are not doing epoch based training here, but just using this for book keeping and being able to\n    # reuse the same training loop with other datasets/loaders.\n    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # Train!\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num training steps = {args.max_train_steps}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n\n    resume_from_checkpoint = args.resume_from_checkpoint\n    if resume_from_checkpoint:\n        if resume_from_checkpoint == \"latest\":\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            if len(dirs) > 0:\n                resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1])\n            else:\n                resume_from_checkpoint = None\n\n        if resume_from_checkpoint is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n        else:\n            accelerator.print(f\"Resuming from checkpoint {resume_from_checkpoint}\")\n\n    if resume_from_checkpoint is None:\n        global_step = 0\n        first_epoch = 0\n    else:\n        accelerator.load_state(resume_from_checkpoint)\n        global_step = int(os.path.basename(resume_from_checkpoint).split(\"-\")[1])\n        first_epoch = global_step // num_update_steps_per_epoch\n\n    # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to\n    # reuse the same training loop with other datasets/loaders.\n    for epoch in range(first_epoch, num_train_epochs):\n        for batch in train_dataloader:\n            with torch.no_grad():\n                micro_conds = batch[\"micro_conds\"].to(accelerator.device, non_blocking=True)\n                pixel_values = batch[\"image\"].to(accelerator.device, non_blocking=True)\n\n                batch_size = pixel_values.shape[0]\n\n                split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size\n  
              num_splits = math.ceil(batch_size / split_batch_size)\n                image_tokens = []\n                for i in range(num_splits):\n                    start_idx = i * split_batch_size\n                    end_idx = min((i + 1) * split_batch_size, batch_size)\n                    # Size of this split, not of the full batch\n                    bs = end_idx - start_idx\n                    image_tokens.append(\n                        vq_model.quantize(vq_model.encode(pixel_values[start_idx:end_idx]).latents)[2][2].reshape(\n                            bs, -1\n                        )\n                    )\n                image_tokens = torch.cat(image_tokens, dim=0)\n\n                batch_size, seq_len = image_tokens.shape\n\n                timesteps = torch.rand(batch_size, device=image_tokens.device)\n                mask_prob = torch.cos(timesteps * math.pi * 0.5)\n                mask_prob = mask_prob.clip(args.min_masking_rate)\n\n                num_token_masked = (seq_len * mask_prob).round().clamp(min=1)\n                batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1)\n                mask = batch_randperm < num_token_masked.unsqueeze(-1)\n\n                mask_id = accelerator.unwrap_model(model).config.vocab_size - 1\n                input_ids = torch.where(mask, mask_id, image_tokens)\n                labels = torch.where(mask, image_tokens, -100)\n\n                if args.cond_dropout_prob > 0.0:\n                    assert encoder_hidden_states is not None\n\n                    batch_size = encoder_hidden_states.shape[0]\n\n                    mask = (\n                        torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1)\n                        < args.cond_dropout_prob\n                    )\n\n                    empty_embeds_ = empty_embeds.expand(batch_size, -1, -1)\n                    encoder_hidden_states = torch.where(\n                        (encoder_hidden_states * mask).bool(), 
encoder_hidden_states, empty_embeds_\n                    )\n\n                    empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1)\n                    cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_)\n\n                bs = input_ids.shape[0]\n                vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1)\n                resolution = args.resolution // vae_scale_factor\n                input_ids = input_ids.reshape(bs, resolution, resolution)\n\n            if \"prompt_input_ids\" in batch:\n                with nullcontext() if args.train_text_encoder else torch.no_grad():\n                    encoder_hidden_states, cond_embeds = encode_prompt(\n                        text_encoder, batch[\"prompt_input_ids\"].to(accelerator.device, non_blocking=True)\n                    )\n\n            # Train Step\n            with accelerator.accumulate(model):\n                codebook_size = accelerator.unwrap_model(model).config.codebook_size\n\n                logits = (\n                    model(\n                        input_ids=input_ids,\n                        encoder_hidden_states=encoder_hidden_states,\n                        micro_conds=micro_conds,\n                        pooled_text_emb=cond_embeds,\n                    )\n                    .reshape(bs, codebook_size, -1)\n                    .permute(0, 2, 1)\n                    .reshape(-1, codebook_size)\n                )\n\n                loss = F.cross_entropy(\n                    logits,\n                    labels.view(-1),\n                    ignore_index=-100,\n                    reduction=\"mean\",\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                avg_masking_rate = 
accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean()\n\n                accelerator.backward(loss)\n\n                if args.max_grad_norm is not None and accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema.step(model.parameters())\n\n                if (global_step + 1) % args.logging_steps == 0:\n                    logs = {\n                        \"step_loss\": avg_loss.item(),\n                        \"lr\": lr_scheduler.get_last_lr()[0],\n                        \"avg_masking_rate\": avg_masking_rate.item(),\n                    }\n                    accelerator.log(logs, step=global_step + 1)\n\n                    logger.info(\n                        f\"Step: {global_step + 1} \"\n                        f\"Loss: {avg_loss.item():0.4f} \"\n                        f\"LR: {lr_scheduler.get_last_lr()[0]:0.6f}\"\n                    )\n\n                if (global_step + 1) % args.checkpointing_steps == 0:\n                    save_checkpoint(args, accelerator, global_step + 1)\n\n                if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process:\n                    if args.use_ema:\n                        ema.store(model.parameters())\n                        ema.copy_to(model.parameters())\n\n                    with torch.no_grad():\n                        logger.info(\"Generating images...\")\n\n                        model.eval()\n\n                        if args.train_text_encoder:\n                            text_encoder.eval()\n\n                        scheduler = AmusedScheduler.from_pretrained(\n             
               args.pretrained_model_name_or_path,\n                            subfolder=\"scheduler\",\n                            revision=args.revision,\n                            variant=args.variant,\n                        )\n\n                        pipe = AmusedPipeline(\n                            transformer=accelerator.unwrap_model(model),\n                            tokenizer=tokenizer,\n                            text_encoder=text_encoder,\n                            vqvae=vq_model,\n                            scheduler=scheduler,\n                        )\n\n                        pil_images = pipe(prompt=args.validation_prompts).images\n                        wandb_images = [\n                            wandb.Image(image, caption=args.validation_prompts[i])\n                            for i, image in enumerate(pil_images)\n                        ]\n\n                        wandb.log({\"generated_images\": wandb_images}, step=global_step + 1)\n\n                        model.train()\n\n                        if args.train_text_encoder:\n                            text_encoder.train()\n\n                    if args.use_ema:\n                        ema.restore(model.parameters())\n\n                global_step += 1\n\n            # Stop training if max steps is reached\n            if global_step >= args.max_train_steps:\n                break\n        # End for\n\n    accelerator.wait_for_everyone()\n\n    # Evaluate and save checkpoint at the end of training\n    save_checkpoint(args, accelerator, global_step)\n\n    # Save the final trained checkpoint\n    if accelerator.is_main_process:\n        model = accelerator.unwrap_model(model)\n        if args.use_ema:\n            ema.copy_to(model.parameters())\n        model.save_pretrained(args.output_dir)\n\n    accelerator.end_training()\n\n\ndef save_checkpoint(args, accelerator, global_step):\n    output_dir = args.output_dir\n\n    # _before_ saving state, check if this save 
would set us over the `checkpoints_total_limit`\n    if accelerator.is_main_process and args.checkpoints_total_limit is not None:\n        checkpoints = os.listdir(output_dir)\n        checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n        checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n        if len(checkpoints) >= args.checkpoints_total_limit:\n            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n            removing_checkpoints = checkpoints[0:num_to_remove]\n\n            logger.info(\n                f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n            )\n            logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n            for removing_checkpoint in removing_checkpoints:\n                removing_checkpoint = os.path.join(output_dir, removing_checkpoint)\n                shutil.rmtree(removing_checkpoint)\n\n    save_path = Path(output_dir) / f\"checkpoint-{global_step}\"\n    accelerator.save_state(save_path)\n    logger.info(f\"Saved state to {save_path}\")\n\n\nif __name__ == \"__main__\":\n    main(parse_args())\n"
  },
  {
    "path": "examples/cogvideo/README.md",
    "content": "# LoRA finetuning example for CogVideoX\n\nLow-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.\n\nIn a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:\n\n- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).\n- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.\n\nAt the moment, LoRA finetuning has only been tested for [CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b).\n\n> [!NOTE]\n> The scripts for CogVideoX come with limited support and may not be fully compatible with different training techniques. They are not feature-rich either and simply serve as minimal examples of finetuning to take inspiration from and improve.\n>\n> A repository containing memory-optimized finetuning scripts with support for multiple resolutions, dataset preparation, captioning, etc. is available [here](https://github.com/a-r-r-o-w/cogvideox-factory), which will be maintained jointly by the CogVideoX and Diffusers teams.\n\n## Data Preparation\n\nThe training scripts accept data in two formats.\n\n**First data format**\n\nTwo files where one file contains line-separated prompts and another file contains line-separated paths to video data (the path to video files must be relative to the path you pass when specifying `--instance_data_root`). 
Let's take a look at an example to understand this better!\n\nAssume you've specified `--instance_data_root` as `/dataset`, and that this directory contains the files: `prompts.txt` and `videos.txt`.\n\nThe `prompts.txt` file should contain line-separated prompts:\n\n```\nA black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.\nA black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.\n...\n```\n\nThe `videos.txt` file should contain line-separated paths to video files. Note that the path should be _relative_ to the `--instance_data_root` directory.\n\n```\nvideos/00000.mp4\nvideos/00001.mp4\n...\n```\n\nOverall, this is what your dataset would look like if you ran the `tree` command on the dataset root directory:\n\n```\n/dataset\n├── prompts.txt\n├── videos.txt\n├── videos\n    ├── videos/00000.mp4\n    ├── videos/00001.mp4\n    ├── ...\n```\n\nWhen using this format, the `--caption_column` must be `prompts.txt` and `--video_column` must be `videos.txt`.\n\n**Second data format**\n\nYou could use a single CSV file. For the sake of this example, assume you have a `metadata.csv` file. 
The expected format is:\n\n```\n<CAPTION_COLUMN>,<PATH_TO_VIDEO_COLUMN>\n\"\"\"A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.\"\"\",\"\"\"00000.mp4\"\"\"\n\"\"\"A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.\"\"\",\"\"\"00001.mp4\"\"\"\n...\n```\n\nIn this case, the `--instance_data_root` should be the location where the videos are stored and `--dataset_name` should be either a path to a local folder or a `load_dataset`-compatible hosted HF Dataset Repository or URL. Assuming you have videos of your Minecraft gameplay at `https://huggingface.co/datasets/my-awesome-username/minecraft-videos`, you would have to specify `my-awesome-username/minecraft-videos`.\n\nWhen using this format, the `--caption_column` must be `<CAPTION_COLUMN>` and `--video_column` must be `<PATH_TO_VIDEO_COLUMN>`.\n\nYou are not strictly restricted to the CSV format. As long as the `load_dataset` method supports the file format to load a basic `<PATH_TO_VIDEO_COLUMN>` and `<CAPTION_COLUMN>`, you should be good to go. The reason for going through these dataset organization gymnastics for loading video data is that we found `load_dataset` from the datasets library to not fully support all kinds of video formats. 
This will undoubtedly be improved in the future.\n\n> [!NOTE]\n> CogVideoX works best with long and descriptive LLM-augmented prompts for video generation. We recommend pre-processing your videos by first generating a summary using a VLM and then augmenting the prompts with an LLM. To generate the above captions, we use [MiniCPM-V-26](https://huggingface.co/openbmb/MiniCPM-V-2_6) and [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct). A very barebones and no-frills example for this is available [here](https://gist.github.com/a-r-r-o-w/4dee20250e82f4e44690a02351324a4a). The official recommendation for augmenting prompts is [ChatGLM](https://huggingface.co/THUDM?search_models=chatglm) and a length of 50-100 words is considered good.\n\n> [!NOTE]\n> It is expected that your dataset is already pre-processed. If not, some basic pre-processing can be done by playing with the following parameters:\n> `--height`, `--width`, `--fps`, `--max_num_frames`, `--skip_frames_start` and `--skip_frames_end`.\n> Presently, all videos in your dataset should contain the same number of video frames when using a training batch size > 1.\n\n<!-- TODO: Implement frame packing in future to address above issue. -->\n\n## Training\n\nYou need to set up your development environment by installing the necessary requirements. 
The following packages are required:\n- Torch 2.0 or above based on the training features you are utilizing (might require latest or nightly versions for quantized/deepspeed training)\n- `pip install diffusers transformers accelerate peft huggingface_hub` for all things modeling and training related\n- `pip install datasets decord` for loading video training data\n- `pip install bitsandbytes` for using 8-bit Adam or AdamW optimizers for memory-optimized training\n- `pip install wandb` optionally for monitoring training logs\n- `pip install deepspeed` optionally for [DeepSpeed](https://github.com/microsoft/DeepSpeed) training\n- `pip install prodigyopt` optionally if you would like to use the Prodigy optimizer for training\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment:\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook):\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, setting torch compile mode to True can give dramatic speedups. 
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.\n\nIf you would like to push your model to the HF Hub after training is completed with a neat model card, make sure you're logged in:\n\n```\nhf auth login\n\n# Alternatively, you could upload your model manually using:\n# hf upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora\n```\n\nMake sure your data is prepared as described in [Data Preparation](#data-preparation). When ready, you can begin training!\n\nAssuming you are training on 50 videos of a similar concept, we have found 1500-2000 steps to work well. The official recommendation, however, is 100 videos with a total of 4000 steps. Assuming you are training on a single GPU with a `--train_batch_size` of `1`:\n- 1500 steps on 50 videos would correspond to `30` training epochs\n- 4000 steps on 100 videos would correspond to `40` training epochs\n\nThe following bash script launches training for text-to-video LoRA.\n\n```bash\n#!/bin/bash\n\nGPU_IDS=\"0\"\n\naccelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \\\n  --pretrained_model_name_or_path THUDM/CogVideoX-2b \\\n  --cache_dir <CACHE_DIR> \\\n  --instance_data_root <PATH_TO_WHERE_VIDEO_FILES_ARE_STORED> \\\n  --dataset_name my-awesome-name/my-awesome-dataset \\\n  --caption_column <CAPTION_COLUMN> \\\n  --video_column <PATH_TO_VIDEO_COLUMN> \\\n  --id_token <ID_TOKEN> \\\n  --validation_prompt \"<ID_TOKEN> Spiderman swinging over buildings:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. 
The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \\\n  --validation_prompt_separator ::: \\\n  --num_validation_videos 1 \\\n  --validation_epochs 10 \\\n  --seed 42 \\\n  --rank 64 \\\n  --lora_alpha 64 \\\n  --mixed_precision fp16 \\\n  --output_dir /raid/aryan/cogvideox-lora \\\n  --height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \\\n  --train_batch_size 1 \\\n  --num_train_epochs 30 \\\n  --checkpointing_steps 1000 \\\n  --gradient_accumulation_steps 1 \\\n  --learning_rate 1e-3 \\\n  --lr_scheduler cosine_with_restarts \\\n  --lr_warmup_steps 200 \\\n  --lr_num_cycles 1 \\\n  --enable_slicing \\\n  --enable_tiling \\\n  --optimizer Adam \\\n  --adam_beta1 0.9 \\\n  --adam_beta2 0.95 \\\n  --max_grad_norm 1.0 \\\n  --report_to wandb\n```\n\nTo launch image-to-video finetuning instead, run the `train_cogvideox_image_to_video_lora.py` file. Additionally, you will have to pass `--validation_images` as paths to initial images corresponding to `--validation_prompt` for I2V validation to work.\n\nTo better track our training experiments, we're using the following flags in the command above:\n* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.\n* `--validation_prompt` and `--validation_epochs` allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.\n\nNote that setting the `<ID_TOKEN>` is not necessary. From some limited experimentation, we found it to work better (as it resembles [Dreambooth](https://huggingface.co/docs/diffusers/en/training/dreambooth) like training) than without. When provided, the ID_TOKEN is prepended to each prompt. 
So, if your ID_TOKEN was `\"DISNEY\"` and your prompt was `\"Spiderman swinging over buildings\"`, the effective prompt used in training would be `\"DISNEY Spiderman swinging over buildings\"`. When not provided, you would either be training without any such additional token or could augment your dataset to apply the token where you wish before starting the training.\n\n> [!TIP]\n> You can pass `--use_8bit_adam` to reduce the memory requirements of training.\n> You can pass `--video_reshape_mode` to control video cropping. Supported options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.\n\n> [!IMPORTANT]\n> The following settings have been tested at the time of adding CogVideoX LoRA training support:\n> - Our testing was primarily done on CogVideoX-2b. We will work on CogVideoX-5b and CogVideoX-5b-I2V soon.\n> - One dataset consisting of 70 training videos of resolutions `200 x 480 x 720` (F x H x W). From this, by using frame skipping in data preprocessing, we created two smaller 49-frame and 16-frame datasets for faster experimentation and because the maximum limit recommended by the CogVideoX team is 49 frames. Out of the 70 videos, we created three groups of 10, 25 and 50 videos. All videos were similar in nature of the concept being trained.\n> - 25+ videos worked best for training new concepts and styles.\n> - We found that it is better to train with an identifier token that can be specified as `--id_token`. This is similar to Dreambooth-like training but normal finetuning without such a token works too.\n> - Trained concept seemed to work decently well when combined with completely unrelated prompts. We expect even better results if CogVideoX-5B is finetuned.\n> - The original repository uses a `lora_alpha` of `1`. We found this not suitable in many runs, possibly due to difference in modeling backends and training settings. 
Our recommendation is to set `lora_alpha` to either `rank` or `rank // 2`.\n> - If you're training on data whose captions generate bad results with the original model, a `rank` of 64 and above is good and also the recommendation by the team behind CogVideoX. If the generations are already moderately good on your training captions, a `rank` of 16/32 should work. We found that setting the rank too low, say `4`, is not ideal and doesn't produce promising results.\n> - The authors of CogVideoX recommend 4000 training steps and 100 training videos overall to achieve the best result. While that might yield the best results, we found from our limited experimentation that 2000 steps and 25 videos could also be sufficient.\n> - When using the Prodigy optimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From my very limited testing, I found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `--prodigy_safeguard_warmup` and `--prodigy_decouple`.\n> - The recommended learning rate by the CogVideoX authors and from our experimentation with Adam/AdamW is between `1e-3` and `1e-4` for a dataset of 25+ videos.\n>\n> Note that our testing is not exhaustive due to limited time for exploration. 
Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.\n\n## Inference\n\nOnce you have trained a LoRA model, inference can be done by simply loading the LoRA weights into the `CogVideoXPipeline`.\n\n```python\nimport torch\nfrom diffusers import CogVideoXPipeline\nfrom diffusers.utils import export_to_video\n\npipe = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-2b\", torch_dtype=torch.float16)\n# pipe.load_lora_weights(\"/path/to/lora/weights\", adapter_name=\"cogvideox-lora\") # Or,\npipe.load_lora_weights(\"my-awesome-hf-username/my-awesome-lora-name\", adapter_name=\"cogvideox-lora\") # If loading from the HF Hub\npipe.to(\"cuda\")\n\n# Assuming lora_alpha=32 and rank=64 for training. If different, set accordingly\npipe.set_adapters([\"cogvideox-lora\"], [32 / 64])\n\nprompt = (\n    \"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The \"\n    \"panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other \"\n    \"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, \"\n    \"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. \"\n    \"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical \"\n    \"atmosphere of this unique musical performance\"\n)\nframes = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]\nexport_to_video(frames, \"output.mp4\", fps=8)\n```\n\nIf you've trained a LoRA for `CogVideoXImageToVideoPipeline` instead, everything in the above example remains the same except you must also pass an image as the initial condition for generation.\n"
  },
  {
    "path": "examples/cogvideo/requirements.txt",
    "content": "accelerate>=0.31.0\ntorchvision\ntransformers>=4.41.2\nftfy\ntensorboard\nJinja2\npeft>=0.11.1\nsentencepiece\ndecord>=0.6.0\nimageio-ffmpeg"
  },
  {
    "path": "examples/cogvideo/train_cogvideox_image_to_video_lora.py",
    "content": "# Copyright 2025 The HuggingFace Team.\n# All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom datetime import timedelta\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport transformers\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKLCogVideoX,\n    CogVideoXDPMScheduler,\n    CogVideoXImageToVideoPipeline,\n    CogVideoXTransformer3DModel,\n)\nfrom diffusers.models.embeddings import get_3d_rotary_pos_embed\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid\nfrom diffusers.training_utils import cast_training_params, free_memory\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    export_to_video,\n    
is_wandb_available,\n    load_image,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef get_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script for CogVideoX.\")\n\n    # Model information\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    # Dataset information\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_root\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data.\"),\n    )\n    parser.add_argument(\n        \"--video_column\",\n        type=str,\n        default=\"video\",\n        help=\"The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.\",\n    )\n    parser.add_argument(\n        \"--id_token\", type=str, default=None, help=\"Identifier token appended to the start of each prompt if provided.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n\n    # Validation\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"One or more prompt(s) that is used during validation to verify that the model is learning. 
Multiple validation prompts should be separated by the '--validation_prompt_separator' string.\",\n    )\n    parser.add_argument(\n        \"--validation_images\",\n        type=str,\n        default=None,\n        help=\"One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt_separator\",\n        type=str,\n        default=\":::\",\n        help=\"String that separates multiple validation prompts\",\n    )\n    parser.add_argument(\n        \"--num_validation_videos\",\n        type=int,\n        default=1,\n        help=\"Number of videos that should be generated during validation per `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run validation every X epochs. 
Validation consists of running the prompt `args.validation_prompt` `args.num_validation_videos` times.\"\n        ),\n    )\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=6,\n        help=\"The guidance scale to use while sampling validation videos.\",\n    )\n    parser.add_argument(\n        \"--use_dynamic_cfg\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.\",\n    )\n\n    # Training information\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=128,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=float,\n        default=128,\n        help=(\"The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`\"),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"cogvideox-i2v-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--height\",\n        type=int,\n        default=480,\n        help=\"All input videos are resized to this height.\",\n    )\n    parser.add_argument(\n        \"--width\",\n        type=int,\n        default=720,\n        help=\"All input videos are resized to this width.\",\n    )\n    parser.add_argument(\"--fps\", type=int, default=8, help=\"All input videos will be used at this FPS.\")\n    parser.add_argument(\n        \"--max_num_frames\", type=int, default=49, help=\"All input videos will be truncated to these many frames.\"\n    )\n    parser.add_argument(\n        \"--skip_frames_start\",\n        type=int,\n        default=0,\n        help=\"Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.\",\n    )\n    parser.add_argument(\n        \"--skip_frames_end\",\n        type=int,\n        default=0,\n        help=\"Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.\",\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip videos horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. 
If provided, overrides `--num_train_epochs`.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n          
  'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--enable_slicing\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to use VAE slicing for saving memory.\",\n    )\n    parser.add_argument(\n        \"--enable_tiling\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to use VAE tiling for saving memory.\",\n    )\n    parser.add_argument(\n        \"--noised_image_dropout\",\n        type=float,\n        default=0.05,\n        help=\"Image condition dropout probability.\",\n    )\n\n    # Optimizer\n    parser.add_argument(\n        \"--optimizer\",\n        type=lambda s: s.lower(),\n        default=\"adam\",\n        choices=[\"adam\", \"adamw\", \"prodigy\"],\n        help=(\"The optimizer type to use.\"),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. 
Ignored if optimizer is not set to Adam or AdamW\",\n    )\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.95, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", action=\"store_true\", help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for LoRA params\")\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--prodigy_use_bias_correction\", action=\"store_true\", help=\"Turn on Adam's bias correction.\")\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        action=\"store_true\",\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage.\",\n    )\n\n    # Other information\n    parser.add_argument(\"--tracker_name\", type=str, default=None, help=\"Project tracker name\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep 
in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=\"Directory where logs are stored.\",\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=None,\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--nccl_timeout\", type=int, default=600, help=\"NCCL backend timeout in seconds.\")\n\n    return parser.parse_args()\n\n\nclass VideoDataset(Dataset):\n    def __init__(\n        self,\n        instance_data_root: str | None = None,\n        dataset_name: str | None = None,\n        dataset_config_name: str | None = None,\n        caption_column: str = \"text\",\n        video_column: str = \"video\",\n        height: int = 480,\n        width: int = 720,\n        fps: int = 8,\n        max_num_frames: int = 49,\n        skip_frames_start: int = 0,\n        skip_frames_end: int = 0,\n        cache_dir: str | None = None,\n        id_token: str | None = None,\n    ) -> None:\n        super().__init__()\n\n        self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None\n        self.dataset_name = dataset_name\n        self.dataset_config_name = dataset_config_name\n        self.caption_column = caption_column\n        self.video_column = video_column\n        self.height = height\n        self.width = width\n        self.fps 
= fps\n        self.max_num_frames = max_num_frames\n        self.skip_frames_start = skip_frames_start\n        self.skip_frames_end = skip_frames_end\n        self.cache_dir = cache_dir\n        self.id_token = id_token or \"\"\n\n        if dataset_name is not None:\n            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()\n        else:\n            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()\n\n        self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]\n\n        self.num_instance_videos = len(self.instance_video_paths)\n        if self.num_instance_videos != len(self.instance_prompts):\n            raise ValueError(\n                f\"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset.\"\n            )\n\n        self.instance_videos = self._preprocess_data()\n\n    def __len__(self):\n        return self.num_instance_videos\n\n    def __getitem__(self, index):\n        return {\n            \"instance_prompt\": self.instance_prompts[index],\n            \"instance_video\": self.instance_videos[index],\n        }\n\n    def _load_dataset_from_hub(self):\n        try:\n            from datasets import load_dataset\n        except ImportError:\n            raise ImportError(\n                \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                \"local folder containing images only, specify --instance_data_root instead.\"\n            )\n\n        # Downloading and loading a dataset from the hub. 
See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n        dataset = load_dataset(\n            self.dataset_name,\n            self.dataset_config_name,\n            cache_dir=self.cache_dir,\n        )\n        column_names = dataset[\"train\"].column_names\n\n        if self.video_column is None:\n            video_column = column_names[0]\n            logger.info(f\"`video_column` defaulting to {video_column}\")\n        else:\n            video_column = self.video_column\n            if video_column not in column_names:\n                raise ValueError(\n                    f\"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                )\n\n        if self.caption_column is None:\n            caption_column = column_names[1]\n            logger.info(f\"`caption_column` defaulting to {caption_column}\")\n        else:\n            caption_column = self.caption_column\n            if self.caption_column not in column_names:\n                raise ValueError(\n                    f\"`--caption_column` value '{self.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                )\n\n        instance_prompts = dataset[\"train\"][caption_column]\n        instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset[\"train\"][video_column]]\n\n        return instance_prompts, instance_videos\n\n    def _load_dataset_from_local_path(self):\n        if not self.instance_data_root.exists():\n            raise ValueError(\"Instance videos root folder does not exist\")\n\n        prompt_path = self.instance_data_root.joinpath(self.caption_column)\n        video_path = self.instance_data_root.joinpath(self.video_column)\n\n        if not prompt_path.exists() or not prompt_path.is_file():\n            raise ValueError(\n                \"Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts.\"\n            )\n        if not video_path.exists() or not video_path.is_file():\n            raise ValueError(\n                \"Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory.\"\n            )\n\n        with open(prompt_path, \"r\", encoding=\"utf-8\") as file:\n            instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]\n        with open(video_path, \"r\", encoding=\"utf-8\") as file:\n            instance_videos = [\n                self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0\n            ]\n\n        if any(not path.is_file() for path in instance_videos):\n            raise ValueError(\n                \"Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found at least one path that is not a valid file.\"\n            )\n\n        return instance_prompts, instance_videos\n\n    def _preprocess_data(self):\n        try:\n            import decord\n 
       except ImportError:\n            raise ImportError(\n                \"The `decord` package is required for loading the video dataset. Install with `pip install decord`\"\n            )\n\n        decord.bridge.set_bridge(\"torch\")\n\n        videos = []\n        train_transforms = transforms.Compose(\n            [\n                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),\n            ]\n        )\n\n        for filename in self.instance_video_paths:\n            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)\n            video_num_frames = len(video_reader)\n\n            start_frame = min(self.skip_frames_start, video_num_frames)\n            end_frame = max(0, video_num_frames - self.skip_frames_end)\n            if end_frame <= start_frame:\n                frames = video_reader.get_batch([start_frame])\n            elif end_frame - start_frame <= self.max_num_frames:\n                frames = video_reader.get_batch(list(range(start_frame, end_frame)))\n            else:\n                indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))\n                frames = video_reader.get_batch(indices)\n\n            # Ensure that we don't go over the limit\n            frames = frames[: self.max_num_frames]\n            selected_num_frames = frames.shape[0]\n\n            # Choose first (4k + 1) frames as this is how many is required by the VAE\n            remainder = (3 + (selected_num_frames % 4)) % 4\n            if remainder != 0:\n                frames = frames[:-remainder]\n            selected_num_frames = frames.shape[0]\n\n            assert (selected_num_frames - 1) % 4 == 0\n\n            # Training transforms\n            frames = frames.float()\n            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)\n            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]\n\n        return 
videos\n\n\ndef save_model_card(\n    repo_id: str,\n    videos=None,\n    base_model: str = None,\n    validation_prompt=None,\n    repo_folder=None,\n    fps=8,\n):\n    widget_dict = []\n    if videos is not None:\n        for i, video in enumerate(videos):\n            video_path = f\"final_video_{i}.mp4\"\n            export_to_video(video, os.path.join(repo_folder, video_path), fps=fps)\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": video_path}},\n            )\n\n    model_description = f\"\"\"\n# CogVideoX LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} LoRA weights for {base_model}.\n\nThe weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_image_to_video_lora.py).\n\nWas LoRA for the text encoder enabled? No.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nimport torch\nfrom diffusers import CogVideoXImageToVideoPipeline\nfrom diffusers.utils import load_image, export_to_video\n\npipe = CogVideoXImageToVideoPipeline.from_pretrained(\"THUDM/CogVideoX-5b-I2V\", torch_dtype=torch.bfloat16).to(\"cuda\")\npipe.load_lora_weights(\"{repo_id}\", weight_name=\"pytorch_lora_weights.safetensors\", adapter_name=\"cogvideox-i2v-lora\")\n\n# The LoRA adapter weights are determined by what was used for training.\n# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.\n# It can be made lower or higher from what was used in training to decrease or amplify the effect\n# of the LoRA up to a tolerance, beyond which one might notice no effect at all or overflows.\npipe.set_adapters([\"cogvideox-i2v-lora\"], [32 / 64])\n\nimage = load_image(\"/path/to/image\")\nvideo = pipe(image=image, prompt=
\"{validation_prompt}\", guidance_scale=6, use_dynamic_cfg=True).frames[0]\nexport_to_video(video, \"output.mp4\", fps=8)\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=validation_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"image-to-video\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"cogvideox\",\n        \"cogvideox-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipe,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    is_final_validation: bool = False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}.\"\n    )\n    # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if \"variance_type\" in pipe.scheduler.config:\n        variance_type = pipe.scheduler.config.variance_type\n\n        if variance_type in [\"learned\", \"learned_range\"]:\n            variance_type = \"fixed_small\"\n\n        scheduler_args[\"variance_type\"] = variance_type\n\n    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)\n    pipe = pipe.to(accelerator.device)\n    # pipe.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n\n    videos = []\n    for _ in range(args.num_validation_videos):\n        video = pipe(**pipeline_args, generator=generator, output_type=\"np\").frames[0]\n        videos.append(video)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"wandb\":\n            video_filenames = []\n            for i, video in enumerate(videos):\n                prompt = (\n                    pipeline_args[\"prompt\"][:25]\n                    .replace(\" \", \"_\")\n                    .replace(\" \", \"_\")\n                    .replace(\"'\", \"_\")\n                    .replace('\"', \"_\")\n                    .replace(\"/\", \"_\")\n                )\n                filename = os.path.join(args.output_dir, f\"{phase_name}_video_{i}_{prompt}.mp4\")\n                export_to_video(video, filename, fps=8)\n                video_filenames.append(filename)\n\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Video(filename, caption=f\"{i}: {pipeline_args['prompt']}\")\n                        for i, filename in enumerate(video_filenames)\n                    ]\n                }\n            )\n\n    del pipe\n    
free_memory()\n\n    return videos\n\n\ndef _get_t5_prompt_embeds(\n    tokenizer: T5Tokenizer,\n    text_encoder: T5EncoderModel,\n    prompt: Union[str, List[str]],\n    num_videos_per_prompt: int = 1,\n    max_sequence_length: int = 226,\n    device: Optional[torch.device] = None,\n    dtype: Optional[torch.dtype] = None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"`text_input_ids` must be provided when the tokenizer is not specified.\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    _, seq_len, _ = prompt_embeds.shape\n    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef encode_prompt(\n    tokenizer: T5Tokenizer,\n    text_encoder: T5EncoderModel,\n    prompt: Union[str, List[str]],\n    num_videos_per_prompt: int = 1,\n    max_sequence_length: int = 226,\n    device: Optional[torch.device] = None,\n    dtype: Optional[torch.dtype] = None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    prompt_embeds = _get_t5_prompt_embeds(\n        tokenizer,\n        text_encoder,\n        prompt=prompt,\n        num_videos_per_prompt=num_videos_per_prompt,\n        max_sequence_length=max_sequence_length,\n        
device=device,\n        dtype=dtype,\n        text_input_ids=text_input_ids,\n    )\n    return prompt_embeds\n\n\ndef compute_prompt_embeddings(\n    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False\n):\n    if requires_grad:\n        prompt_embeds = encode_prompt(\n            tokenizer,\n            text_encoder,\n            prompt,\n            num_videos_per_prompt=1,\n            max_sequence_length=max_sequence_length,\n            device=device,\n            dtype=dtype,\n        )\n    else:\n        with torch.no_grad():\n            prompt_embeds = encode_prompt(\n                tokenizer,\n                text_encoder,\n                prompt,\n                num_videos_per_prompt=1,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n    return prompt_embeds\n\n\ndef prepare_rotary_positional_embeddings(\n    height: int,\n    width: int,\n    num_frames: int,\n    vae_scale_factor_spatial: int = 8,\n    patch_size: int = 2,\n    attention_head_dim: int = 64,\n    device: Optional[torch.device] = None,\n    base_height: int = 480,\n    base_width: int = 720,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    grid_height = height // (vae_scale_factor_spatial * patch_size)\n    grid_width = width // (vae_scale_factor_spatial * patch_size)\n    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)\n    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)\n\n    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)\n    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(\n        embed_dim=attention_head_dim,\n        crops_coords=grid_crops_coords,\n        grid_size=(grid_height, grid_width),\n        temporal_size=num_frames,\n        device=device,\n    )\n\n    return freqs_cos, freqs_sin\n\n\ndef get_optimizer(args, 
params_to_optimize, use_deepspeed: bool = False):\n    # Use DeepSpeed optimizer\n    if use_deepspeed:\n        from accelerate.utils import DummyOptim\n\n        return DummyOptim(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            eps=args.adam_epsilon,\n            weight_decay=args.adam_weight_decay,\n        )\n\n    # Optimizer creation\n    supported_optimizers = [\"adam\", \"adamw\", \"prodigy\"]\n    if args.optimizer not in supported_optimizers:\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and args.optimizer.lower() not in [\"adam\", \"adamw\"]:\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n    if args.optimizer.lower() == \"adamw\":\n        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            eps=args.adam_epsilon,\n            weight_decay=args.adam_weight_decay,\n        )\n    elif args.optimizer.lower() == \"adam\":\n        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            eps=args.adam_epsilon,\n            
weight_decay=args.adam_weight_decay,\n        )\n    elif args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    return optimizer\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    init_kwargs = InitProcessGroupKwargs(backend=\"nccl\", timeout=timedelta(seconds=args.nccl_timeout))\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[ddp_kwargs, init_kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                
repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Prepare models and scheduler\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n\n    text_encoder = T5EncoderModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n\n    # CogVideoX-2b weights are stored in float16\n    # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16\n    load_dtype = torch.bfloat16 if \"5b\" in args.pretrained_model_name_or_path.lower() else torch.float16\n    transformer = CogVideoXTransformer3DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        torch_dtype=load_dtype,\n        revision=args.revision,\n        variant=args.variant,\n    )\n\n    vae = AutoencoderKLCogVideoX.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n\n    scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    if args.enable_slicing:\n        vae.enable_slicing()\n    if args.enable_tiling:\n        vae.enable_tiling()\n\n    # We only train the additional adapter LoRA layers\n    text_encoder.requires_grad_(False)\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.state.deepspeed_plugin:\n        # DeepSpeed is handling precision, use what's in the DeepSpeed config\n        if (\n            \"fp16\" in accelerator.state.deepspeed_plugin.deepspeed_config\n            
and accelerator.state.deepspeed_plugin.deepspeed_config[\"fp16\"][\"enabled\"]\n        ):\n            weight_dtype = torch.float16\n        if (\n            \"bf16\" in accelerator.state.deepspeed_plugin.deepspeed_config\n            and accelerator.state.deepspeed_plugin.deepspeed_config[\"bf16\"][\"enabled\"]\n        ):\n            weight_dtype = torch.bfloat16\n    else:\n        if accelerator.mixed_precision == \"fp16\":\n            weight_dtype = torch.float16\n        elif accelerator.mixed_precision == \"bf16\":\n            weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    # now we will add new LoRA weights to the attention layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        init_lora_weights=True,\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, 
type(unwrap_model(transformer))):\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            CogVideoXImageToVideoPipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(transformer))):\n                transformer_ = model\n            else:\n                raise ValueError(f\"Unexpected save model: {model.__class__}\")\n\n        lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params([transformer_])\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params([transformer], dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    use_deepspeed_optimizer = (\n        accelerator.state.deepspeed_plugin is not None\n        and \"optimizer\" in accelerator.state.deepspeed_plugin.deepspeed_config\n    )\n    use_deepspeed_scheduler = (\n        accelerator.state.deepspeed_plugin is not None\n        and \"scheduler\" in accelerator.state.deepspeed_plugin.deepspeed_config\n    )\n\n    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)\n\n    # Dataset and DataLoader\n    train_dataset = VideoDataset(\n        instance_data_root=args.instance_data_root,\n        dataset_name=args.dataset_name,\n        
dataset_config_name=args.dataset_config_name,\n        caption_column=args.caption_column,\n        video_column=args.video_column,\n        height=args.height,\n        width=args.width,\n        fps=args.fps,\n        max_num_frames=args.max_num_frames,\n        skip_frames_start=args.skip_frames_start,\n        skip_frames_end=args.skip_frames_end,\n        cache_dir=args.cache_dir,\n        id_token=args.id_token,\n    )\n\n    def encode_video(video):\n        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)\n        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]\n        image = video[:, :, :1].clone()\n\n        latent_dist = vae.encode(video).latent_dist\n\n        image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)\n        image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)\n        noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None]\n        image_latent_dist = vae.encode(noisy_image).latent_dist\n\n        return latent_dist, image_latent_dist\n\n    train_dataset.instance_prompts = [\n        compute_prompt_embeddings(\n            tokenizer,\n            text_encoder,\n            [prompt],\n            transformer.config.max_text_seq_length,\n            accelerator.device,\n            weight_dtype,\n            requires_grad=False,\n        )\n        for prompt in train_dataset.instance_prompts\n    ]\n    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]\n\n    def collate_fn(examples):\n        videos = []\n        images = []\n        for example in examples:\n            latent_dist, image_latent_dist = example[\"instance_video\"]\n\n            video_latents = latent_dist.sample() * vae.config.scaling_factor\n            image_latents = image_latent_dist.sample() * vae.config.scaling_factor\n            video_latents = video_latents.permute(0, 2, 1, 3, 4)\n            
image_latents = image_latents.permute(0, 2, 1, 3, 4)\n\n            padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:])\n            latent_padding = image_latents.new_zeros(padding_shape)\n            image_latents = torch.cat([image_latents, latent_padding], dim=1)\n\n            if random.random() < args.noised_image_dropout:\n                image_latents = torch.zeros_like(image_latents)\n\n            videos.append(video_latents)\n            images.append(image_latents)\n\n        videos = torch.cat(videos)\n        images = torch.cat(images)\n        videos = videos.to(memory_format=torch.contiguous_format).float()\n        images = images.to(memory_format=torch.contiguous_format).float()\n\n        prompts = [example[\"instance_prompt\"] for example in examples]\n        prompts = torch.cat(prompts)\n\n        return {\n            \"videos\": (videos, images),\n            \"prompts\": prompts,\n        }\n\n    train_dataloader = DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=collate_fn,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    if use_deepspeed_scheduler:\n        from accelerate.utils import DummyScheduler\n\n        lr_scheduler = DummyScheduler(\n            name=args.lr_scheduler,\n            optimizer=optimizer,\n            total_num_steps=args.max_train_steps * accelerator.num_processes,\n            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        )\n    else:\n        lr_scheduler = get_scheduler(\n            
args.lr_scheduler,\n            optimizer=optimizer,\n            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n            num_training_steps=args.max_train_steps * accelerator.num_processes,\n            num_cycles=args.lr_num_cycles,\n            power=args.lr_power,\n        )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = args.tracker_name or \"cogvideox-i2v-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model[\"params\"])\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num trainable parameters = {num_trainable_parameters}\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  
Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if not args.resume_from_checkpoint:\n        initial_global_step = 0\n    else:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)\n\n    # For DeepSpeed training\n    model_config = transformer.module.config if hasattr(transformer, \"module\") else transformer.config\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n\n            with accelerator.accumulate(models_to_accumulate):\n                video_latents, image_latents = batch[\"videos\"]\n                prompt_embeds = batch[\"prompts\"]\n\n                video_latents = video_latents.to(dtype=weight_dtype)  # [B, F, C, H, W]\n                image_latents = image_latents.to(dtype=weight_dtype)  # [B, F, C, H, W]\n\n                batch_size, num_frames, num_channels, height, width = video_latents.shape\n\n                # Sample a random timestep for each image\n                timesteps = torch.randint(\n                    0, scheduler.config.num_train_timesteps, (batch_size,), device=video_latents.device\n                )\n                timesteps = timesteps.long()\n\n                # Sample noise that will be added to the latents\n                noise = torch.randn_like(video_latents)\n\n         
       # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps)\n                noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2)\n\n                # Prepare rotary embeds\n                image_rotary_emb = (\n                    prepare_rotary_positional_embeddings(\n                        height=args.height,\n                        width=args.width,\n                        num_frames=num_frames,\n                        vae_scale_factor_spatial=vae_scale_factor_spatial,\n                        patch_size=model_config.patch_size,\n                        attention_head_dim=model_config.attention_head_dim,\n                        device=accelerator.device,\n                    )\n                    if model_config.use_rotary_positional_embeddings\n                    else None\n                )\n\n                # Predict the noise residual\n                model_output = transformer(\n                    hidden_states=noisy_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep=timesteps,\n                    image_rotary_emb=image_rotary_emb,\n                    return_dict=False,\n                )[0]\n                model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps)\n\n                alphas_cumprod = scheduler.alphas_cumprod[timesteps]\n                weights = 1 / (1 - alphas_cumprod)\n                while len(weights.shape) < len(model_pred.shape):\n                    weights = weights.unsqueeze(-1)\n\n                target = video_latents\n\n                loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)\n                loss = loss.mean()\n                accelerator.backward(loss)\n\n                if 
accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                if accelerator.state.deepspeed_plugin is None:\n                    optimizer.step()\n                    optimizer.zero_grad()\n\n                lr_scheduler.step()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"Removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for 
removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:\n                # Create pipeline\n                pipe = CogVideoXImageToVideoPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    transformer=unwrap_model(transformer),\n                    scheduler=scheduler,\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)\n                validation_images = args.validation_images.split(args.validation_prompt_separator)\n\n                for validation_image, validation_prompt in zip(validation_images, validation_prompts):\n                    pipeline_args = {\n                        \"image\": load_image(validation_image),\n                        \"prompt\": validation_prompt,\n                        \"guidance_scale\": args.guidance_scale,\n                        \"use_dynamic_cfg\": args.use_dynamic_cfg,\n                        \"height\": args.height,\n                        \"width\": args.width,\n          
          }\n\n                    validation_outputs = log_validation(\n                        pipe=pipe,\n                        args=args,\n                        accelerator=accelerator,\n                        pipeline_args=pipeline_args,\n                        epoch=epoch,\n                    )\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        dtype = (\n            torch.float16\n            if args.mixed_precision == \"fp16\"\n            else torch.bfloat16\n            if args.mixed_precision == \"bf16\"\n            else torch.float32\n        )\n        transformer = transformer.to(dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        CogVideoXImageToVideoPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n        )\n\n        # Cleanup trained models to save memory\n        del transformer\n        free_memory()\n\n        # Final test inference\n        pipe = CogVideoXImageToVideoPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)\n\n        if args.enable_slicing:\n            pipe.vae.enable_slicing()\n        if args.enable_tiling:\n            pipe.vae.enable_tiling()\n\n        # Load LoRA weights\n        lora_scaling = args.lora_alpha / args.rank\n        pipe.load_lora_weights(args.output_dir, adapter_name=\"cogvideox-i2v-lora\")\n        pipe.set_adapters([\"cogvideox-i2v-lora\"], [lora_scaling])\n\n        # Run inference\n        validation_outputs = []\n        if args.validation_prompt and args.num_validation_videos > 0:\n            validation_prompts = 
args.validation_prompt.split(args.validation_prompt_separator)\n            validation_images = args.validation_images.split(args.validation_prompt_separator)\n\n            for validation_image, validation_prompt in zip(validation_images, validation_prompts):\n                pipeline_args = {\n                    \"image\": load_image(validation_image),\n                    \"prompt\": validation_prompt,\n                    \"guidance_scale\": args.guidance_scale,\n                    \"use_dynamic_cfg\": args.use_dynamic_cfg,\n                    \"height\": args.height,\n                    \"width\": args.width,\n                }\n\n                video = log_validation(\n                    pipe=pipe,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    is_final_validation=True,\n                )\n                validation_outputs.extend(video)\n\n        if args.push_to_hub:\n            validation_prompt = args.validation_prompt or \"\"\n            validation_prompt = validation_prompt.split(args.validation_prompt_separator)[0]\n            save_model_card(\n                repo_id,\n                videos=validation_outputs,\n                base_model=args.pretrained_model_name_or_path,\n                validation_prompt=validation_prompt,\n                repo_folder=args.output_dir,\n                fps=args.fps,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n    main(args)\n"
  },
  {
    "path": "examples/cogvideo/train_cogvideox_lora.py",
    "content": "# Copyright 2025 The HuggingFace Team.\n# All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torchvision.transforms as TT\nimport transformers\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision.transforms import InterpolationMode\nfrom torchvision.transforms.functional import resize\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer\n\nimport diffusers\nfrom diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models.embeddings import get_3d_rotary_pos_embed\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid\nfrom diffusers.training_utils import cast_training_params, free_memory\nfrom diffusers.utils import check_min_version, 
convert_unet_state_dict_to_peft, export_to_video, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef get_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script for CogVideoX.\")\n\n    # Model information\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    # Dataset information\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_root\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data.\"),\n    )\n    parser.add_argument(\n        \"--video_column\",\n        type=str,\n        default=\"video\",\n        help=\"The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.\",\n    )\n    parser.add_argument(\n        \"--id_token\", type=str, default=None, help=\"Identifier token appended to the start of each prompt if provided.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n\n    # Validation\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"One or more prompt(s) that is used during validation to verify that the model is learning. 
Multiple validation prompts should be separated by the '--validation_prompt_separator' string.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt_separator\",\n        type=str,\n        default=\":::\",\n        help=\"String that separates multiple validation prompts\",\n    )\n    parser.add_argument(\n        \"--num_validation_videos\",\n        type=int,\n        default=1,\n        help=\"Number of videos that should be generated during validation per `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=6,\n        help=\"The guidance scale to use while sampling validation videos.\",\n    )\n    parser.add_argument(\n        \"--use_dynamic_cfg\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.\",\n    )\n\n    # Training information\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=128,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=float,\n        default=128,\n        help=(\"The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`\"),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"cogvideox-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--height\",\n        type=int,\n        default=480,\n        help=\"All input videos are resized to this height.\",\n    )\n    parser.add_argument(\n        \"--width\",\n        type=int,\n        default=720,\n        help=\"All input videos are resized to this width.\",\n    )\n    parser.add_argument(\n        \"--video_reshape_mode\",\n        type=str,\n        default=\"center\",\n        help=\"All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']\",\n    )\n    parser.add_argument(\"--fps\", type=int, default=8, help=\"All input videos will be used at this FPS.\")\n    parser.add_argument(\n        \"--max_num_frames\", type=int, default=49, help=\"All input videos will be truncated to this many frames.\"\n    )\n    parser.add_argument(\n        \"--skip_frames_start\",\n        type=int,\n        default=0,\n        help=\"Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.\",\n    )\n    parser.add_argument(\n        \"--skip_frames_end\",\n        type=int,\n        default=0,\n        help=\"Number of frames to skip from the end of each input video. 
Useful if training data contains outro sequences.\",\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip videos horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. If provided, overrides `--num_train_epochs`.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--enable_slicing\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to use VAE slicing for saving memory.\",\n    )\n    parser.add_argument(\n        \"--enable_tiling\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to use VAE tiling for saving memory.\",\n    )\n\n    # Optimizer\n    parser.add_argument(\n        \"--optimizer\",\n        type=lambda s: s.lower(),\n        default=\"adam\",\n        choices=[\"adam\", \"adamw\", \"prodigy\"],\n        help=(\"The optimizer type to use.\"),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.95, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"Coefficients for computing the Prodigy optimizer's stepsize using running averages. 
If set to None, uses the value of square root of beta2.\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", action=\"store_true\", help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--prodigy_use_bias_correction\", action=\"store_true\", help=\"Turn on Adam's bias correction.\")\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        action=\"store_true\",\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage.\",\n    )\n\n    # Other information\n    parser.add_argument(\"--tracker_name\", type=str, default=None, help=\"Project tracker name\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=\"Directory where logs are stored.\",\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=None,\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n\n    return parser.parse_args()\n\n\nclass VideoDataset(Dataset):\n    def __init__(\n        self,\n        instance_data_root: str | None = None,\n        dataset_name: str | None = None,\n        dataset_config_name: str | None = None,\n        caption_column: str = \"text\",\n        video_column: str = \"video\",\n        height: int = 480,\n        width: int = 720,\n        video_reshape_mode: str = \"center\",\n        fps: int = 8,\n        max_num_frames: int = 49,\n        skip_frames_start: int = 0,\n        skip_frames_end: int = 0,\n        cache_dir: str | None = None,\n        id_token: str | None = None,\n    ) -> None:\n        super().__init__()\n\n        self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None\n        self.dataset_name = dataset_name\n        self.dataset_config_name = dataset_config_name\n        self.caption_column = caption_column\n        self.video_column = video_column\n        self.height = height\n        self.width = width\n        self.video_reshape_mode = video_reshape_mode\n        self.fps = fps\n        self.max_num_frames = max_num_frames\n        self.skip_frames_start = skip_frames_start\n        self.skip_frames_end = skip_frames_end\n        self.cache_dir = cache_dir\n        self.id_token = id_token or \"\"\n\n        if dataset_name is not None:\n            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()\n        else:\n            
self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()\n\n        self.num_instance_videos = len(self.instance_video_paths)\n        if self.num_instance_videos != len(self.instance_prompts):\n            raise ValueError(\n                f\"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset.\"\n            )\n\n        self.instance_videos = self._preprocess_data()\n\n    def __len__(self):\n        return self.num_instance_videos\n\n    def __getitem__(self, index):\n        return {\n            \"instance_prompt\": self.id_token + self.instance_prompts[index],\n            \"instance_video\": self.instance_videos[index],\n        }\n\n    def _load_dataset_from_hub(self):\n        try:\n            from datasets import load_dataset\n        except ImportError:\n            raise ImportError(\n                \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                \"local folder containing images only, specify --instance_data_root instead.\"\n            )\n\n        # Downloading and loading a dataset from the hub. 
See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n        dataset = load_dataset(\n            self.dataset_name,\n            self.dataset_config_name,\n            cache_dir=self.cache_dir,\n        )\n        column_names = dataset[\"train\"].column_names\n\n        if self.video_column is None:\n            video_column = column_names[0]\n            logger.info(f\"`video_column` defaulting to {video_column}\")\n        else:\n            video_column = self.video_column\n            if video_column not in column_names:\n                raise ValueError(\n                    f\"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                )\n\n        if self.caption_column is None:\n            caption_column = column_names[1]\n            logger.info(f\"`caption_column` defaulting to {caption_column}\")\n        else:\n            caption_column = self.caption_column\n            if self.caption_column not in column_names:\n                raise ValueError(\n                    f\"`--caption_column` value '{self.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                )\n\n        instance_prompts = dataset[\"train\"][caption_column]\n        instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset[\"train\"][video_column]]\n\n        return instance_prompts, instance_videos\n\n    def _load_dataset_from_local_path(self):\n        if not self.instance_data_root.exists():\n            raise ValueError(\"Instance videos root folder does not exist\")\n\n        prompt_path = self.instance_data_root.joinpath(self.caption_column)\n        video_path = self.instance_data_root.joinpath(self.video_column)\n\n        if not prompt_path.exists() or not prompt_path.is_file():\n            raise ValueError(\n                \"Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts.\"\n            )\n        if not video_path.exists() or not video_path.is_file():\n            raise ValueError(\n                \"Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory.\"\n            )\n\n        with open(prompt_path, \"r\", encoding=\"utf-8\") as file:\n            instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]\n        with open(video_path, \"r\", encoding=\"utf-8\") as file:\n            instance_videos = [\n                self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0\n            ]\n\n        if any(not path.is_file() for path in instance_videos):\n            raise ValueError(\n                \"Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found at least one path that is not a valid file.\"\n            )\n\n        return instance_prompts, instance_videos\n\n    def _resize_for_rectangle_crop(self, arr):\n        image_size = 
self.height, self.width\n        reshape_mode = self.video_reshape_mode\n        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:\n            arr = resize(\n                arr,\n                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],\n                interpolation=InterpolationMode.BICUBIC,\n            )\n        else:\n            arr = resize(\n                arr,\n                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],\n                interpolation=InterpolationMode.BICUBIC,\n            )\n\n        h, w = arr.shape[2], arr.shape[3]\n        arr = arr.squeeze(0)\n\n        delta_h = h - image_size[0]\n        delta_w = w - image_size[1]\n\n        if reshape_mode == \"random\" or reshape_mode == \"none\":\n            top = np.random.randint(0, delta_h + 1)\n            left = np.random.randint(0, delta_w + 1)\n        elif reshape_mode == \"center\":\n            top, left = delta_h // 2, delta_w // 2\n        else:\n            raise NotImplementedError\n        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])\n        return arr\n\n    def _preprocess_data(self):\n        try:\n            import decord\n        except ImportError:\n            raise ImportError(\n                \"The `decord` package is required for loading the video dataset. 
Install with `pip install decord`\"\n            )\n\n        decord.bridge.set_bridge(\"torch\")\n\n        progress_dataset_bar = tqdm(\n            range(0, len(self.instance_video_paths)),\n            desc=\"Loading progress resize and crop videos\",\n        )\n        videos = []\n\n        for filename in self.instance_video_paths:\n            video_reader = decord.VideoReader(uri=filename.as_posix())\n            video_num_frames = len(video_reader)\n\n            start_frame = min(self.skip_frames_start, video_num_frames)\n            end_frame = max(0, video_num_frames - self.skip_frames_end)\n            if end_frame <= start_frame:\n                frames = video_reader.get_batch([start_frame])\n            elif end_frame - start_frame <= self.max_num_frames:\n                frames = video_reader.get_batch(list(range(start_frame, end_frame)))\n            else:\n                indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))\n                frames = video_reader.get_batch(indices)\n\n            # Ensure that we don't go over the limit\n            frames = frames[: self.max_num_frames]\n            selected_num_frames = frames.shape[0]\n\n            # Choose first (4k + 1) frames as this is how many is required by the VAE\n            remainder = (3 + (selected_num_frames % 4)) % 4\n            if remainder != 0:\n                frames = frames[:-remainder]\n            selected_num_frames = frames.shape[0]\n\n            assert (selected_num_frames - 1) % 4 == 0\n\n            # Training transforms\n            frames = (frames - 127.5) / 127.5\n            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]\n            progress_dataset_bar.set_description(\n                f\"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}\"\n            )\n            frames = self._resize_for_rectangle_crop(frames)\n            
videos.append(frames.contiguous())  # [F, C, H, W]\n            progress_dataset_bar.update(1)\n\n        progress_dataset_bar.close()\n        return videos\n\n\ndef save_model_card(\n    repo_id: str,\n    videos=None,\n    base_model: str = None,\n    validation_prompt=None,\n    repo_folder=None,\n    fps=8,\n):\n    widget_dict = []\n    if videos is not None:\n        for i, video in enumerate(videos):\n            export_to_video(video, os.path.join(repo_folder, f\"final_video_{i}.mp4\"), fps=fps)\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"final_video_{i}.mp4\"}}\n            )\n\n    model_description = f\"\"\"\n# CogVideoX LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} LoRA weights for {base_model}.\n\nThe weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).\n\nWas LoRA for the text encoder enabled? 
No.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import CogVideoXPipeline\nimport torch\n\npipe = CogVideoXPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.bfloat16).to(\"cuda\")\npipe.load_lora_weights(\"{repo_id}\", weight_name=\"pytorch_lora_weights.safetensors\", adapter_name=\"cogvideox-lora\")\n\n# The LoRA adapter weights are determined by what was used for training.\n# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.\n# It can be made lower or higher from what was used in training to decrease or amplify the effect\n# of the LoRA up to a tolerance, beyond which one might notice no effect at all or overflows.\npipe.set_adapters([\"cogvideox-lora\"], [32 / 64])\n\nvideo = pipe(\"{validation_prompt}\", guidance_scale=6, use_dynamic_cfg=True).frames[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b/blob/main/LICENSE) and [here](https://huggingface.co/THUDM/CogVideoX-2b/blob/main/LICENSE).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=validation_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-video\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"cogvideox\",\n        \"cogvideox-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    
model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipe,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    is_final_validation: bool = False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}.\"\n    )\n    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if \"variance_type\" in pipe.scheduler.config:\n        variance_type = pipe.scheduler.config.variance_type\n\n        if variance_type in [\"learned\", \"learned_range\"]:\n            variance_type = \"fixed_small\"\n\n        scheduler_args[\"variance_type\"] = variance_type\n\n    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)\n    pipe = pipe.to(accelerator.device)\n    # pipe.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n\n    videos = []\n    for _ in range(args.num_validation_videos):\n        pt_images = pipe(**pipeline_args, generator=generator, output_type=\"pt\").frames[0]\n        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])\n\n        image_np = VaeImageProcessor.pt_to_numpy(pt_images)\n        image_pil = VaeImageProcessor.numpy_to_pil(image_np)\n\n        videos.append(image_pil)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"wandb\":\n            video_filenames = []\n            for i, video in enumerate(videos):\n                prompt = (\n                    pipeline_args[\"prompt\"][:25]\n                    .replace(\" \", \"_\")\n                    .replace(\" \", \"_\")\n                    .replace(\"'\", \"_\")\n               
     .replace('\"', \"_\")\n                    .replace(\"/\", \"_\")\n                )\n                filename = os.path.join(args.output_dir, f\"{phase_name}_video_{i}_{prompt}.mp4\")\n                export_to_video(video, filename, fps=8)\n                video_filenames.append(filename)\n\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Video(filename, caption=f\"{i}: {pipeline_args['prompt']}\")\n                        for i, filename in enumerate(video_filenames)\n                    ]\n                }\n            )\n\n    del pipe\n    free_memory()\n\n    return videos\n\n\ndef _get_t5_prompt_embeds(\n    tokenizer: T5Tokenizer,\n    text_encoder: T5EncoderModel,\n    prompt: Union[str, List[str]],\n    num_videos_per_prompt: int = 1,\n    max_sequence_length: int = 226,\n    device: Optional[torch.device] = None,\n    dtype: Optional[torch.dtype] = None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"`text_input_ids` must be provided when the tokenizer is not specified.\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    _, seq_len, _ = prompt_embeds.shape\n    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)\n\n    return 
prompt_embeds\n\n\ndef encode_prompt(\n    tokenizer: T5Tokenizer,\n    text_encoder: T5EncoderModel,\n    prompt: Union[str, List[str]],\n    num_videos_per_prompt: int = 1,\n    max_sequence_length: int = 226,\n    device: Optional[torch.device] = None,\n    dtype: Optional[torch.dtype] = None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    prompt_embeds = _get_t5_prompt_embeds(\n        tokenizer,\n        text_encoder,\n        prompt=prompt,\n        num_videos_per_prompt=num_videos_per_prompt,\n        max_sequence_length=max_sequence_length,\n        device=device,\n        dtype=dtype,\n        text_input_ids=text_input_ids,\n    )\n    return prompt_embeds\n\n\ndef compute_prompt_embeddings(\n    tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False\n):\n    if requires_grad:\n        prompt_embeds = encode_prompt(\n            tokenizer,\n            text_encoder,\n            prompt,\n            num_videos_per_prompt=1,\n            max_sequence_length=max_sequence_length,\n            device=device,\n            dtype=dtype,\n        )\n    else:\n        with torch.no_grad():\n            prompt_embeds = encode_prompt(\n                tokenizer,\n                text_encoder,\n                prompt,\n                num_videos_per_prompt=1,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n    return prompt_embeds\n\n\ndef prepare_rotary_positional_embeddings(\n    height: int,\n    width: int,\n    num_frames: int,\n    vae_scale_factor_spatial: int = 8,\n    patch_size: int = 2,\n    attention_head_dim: int = 64,\n    device: Optional[torch.device] = None,\n    base_height: int = 480,\n    base_width: int = 720,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    grid_height = height // (vae_scale_factor_spatial * patch_size)\n    grid_width = width // (vae_scale_factor_spatial 
* patch_size)\n    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)\n    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)\n\n    grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)\n    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(\n        embed_dim=attention_head_dim,\n        crops_coords=grid_crops_coords,\n        grid_size=(grid_height, grid_width),\n        temporal_size=num_frames,\n        device=device,\n    )\n\n    return freqs_cos, freqs_sin\n\n\ndef get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):\n    # Use DeepSpeed optimizer\n    if use_deepspeed:\n        from accelerate.utils import DummyOptim\n\n        return DummyOptim(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            eps=args.adam_epsilon,\n            weight_decay=args.adam_weight_decay,\n        )\n\n    # Optimizer creation\n    supported_optimizers = [\"adam\", \"adamw\", \"prodigy\"]\n    if args.optimizer not in supported_optimizers:\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and args.optimizer.lower() not in [\"adam\", \"adamw\"]:\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n    if args.optimizer.lower() == \"adamw\":\n        optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            eps=args.adam_epsilon,\n            weight_decay=args.adam_weight_decay,\n        )\n    elif args.optimizer.lower() == \"adam\":\n        optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            eps=args.adam_epsilon,\n            weight_decay=args.adam_weight_decay,\n        )\n    elif args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    return optimizer\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Prepare models and 
scheduler\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n\n    text_encoder = T5EncoderModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n\n    # CogVideoX-2b weights are stored in float16\n    # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16\n    load_dtype = torch.bfloat16 if \"5b\" in args.pretrained_model_name_or_path.lower() else torch.float16\n    transformer = CogVideoXTransformer3DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        torch_dtype=load_dtype,\n        revision=args.revision,\n        variant=args.variant,\n    )\n\n    vae = AutoencoderKLCogVideoX.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n\n    scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    if args.enable_slicing:\n        vae.enable_slicing()\n    if args.enable_tiling:\n        vae.enable_tiling()\n\n    # We only train the additional adapter LoRA layers\n    text_encoder.requires_grad_(False)\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.state.deepspeed_plugin:\n        # DeepSpeed is handling precision, use what's in the DeepSpeed config\n        if (\n            \"fp16\" in accelerator.state.deepspeed_plugin.deepspeed_config\n            and accelerator.state.deepspeed_plugin.deepspeed_config[\"fp16\"][\"enabled\"]\n        ):\n            weight_dtype = torch.float16\n      
  if (\n            \"bf16\" in accelerator.state.deepspeed_plugin.deepspeed_config\n            and accelerator.state.deepspeed_plugin.deepspeed_config[\"bf16\"][\"enabled\"]\n        ):\n            weight_dtype = torch.bfloat16\n    else:\n        if accelerator.mixed_precision == \"fp16\":\n            weight_dtype = torch.float16\n        elif accelerator.mixed_precision == \"bf16\":\n            weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    # now we will add new LoRA weights to the attention layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        init_lora_weights=True,\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(unwrap_model(transformer))):\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                
else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            CogVideoXPipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(transformer))):\n                transformer_ = model\n            else:\n                raise ValueError(f\"Unexpected save model: {model.__class__}\")\n\n        lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params([transformer_])\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params([transformer], dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    use_deepspeed_optimizer = (\n        accelerator.state.deepspeed_plugin is not None\n        and \"optimizer\" in accelerator.state.deepspeed_plugin.deepspeed_config\n    )\n    use_deepspeed_scheduler = (\n        accelerator.state.deepspeed_plugin is not None\n        and \"scheduler\" in accelerator.state.deepspeed_plugin.deepspeed_config\n    )\n\n    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)\n\n    # Dataset and DataLoader\n    train_dataset = VideoDataset(\n        instance_data_root=args.instance_data_root,\n        dataset_name=args.dataset_name,\n        
dataset_config_name=args.dataset_config_name,\n        caption_column=args.caption_column,\n        video_column=args.video_column,\n        height=args.height,\n        width=args.width,\n        video_reshape_mode=args.video_reshape_mode,\n        fps=args.fps,\n        max_num_frames=args.max_num_frames,\n        skip_frames_start=args.skip_frames_start,\n        skip_frames_end=args.skip_frames_end,\n        cache_dir=args.cache_dir,\n        id_token=args.id_token,\n    )\n\n    def encode_video(video):\n        video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)\n        video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]\n        latent_dist = vae.encode(video).latent_dist\n        return latent_dist\n\n    # Distribute video encoding across processes: each process only encodes its own shard\n    num_videos = len(train_dataset.instance_videos)\n    num_procs = accelerator.num_processes\n    local_rank = accelerator.process_index\n    local_count = len(range(local_rank, num_videos, num_procs))\n\n    progress_encode_bar = tqdm(\n        range(local_count),\n        desc=\"Encoding videos\",\n        disable=not accelerator.is_local_main_process,\n    )\n\n    encoded_videos = [None] * num_videos\n    for i, video in enumerate(train_dataset.instance_videos):\n        if i % num_procs == local_rank:\n            encoded_videos[i] = encode_video(video)\n            progress_encode_bar.update(1)\n    progress_encode_bar.close()\n\n    # Broadcast encoded latent distributions so every process has the full set\n    if num_procs > 1:\n        import torch.distributed as dist\n\n        from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution\n\n        ref_params = next(v for v in encoded_videos if v is not None).parameters\n        for i in range(num_videos):\n            src = i % num_procs\n            if encoded_videos[i] is not None:\n                params = encoded_videos[i].parameters.contiguous()\n            else:\n     
           params = torch.empty_like(ref_params)\n            dist.broadcast(params, src=src)\n            encoded_videos[i] = DiagonalGaussianDistribution(params)\n\n    train_dataset.instance_videos = encoded_videos\n\n    def collate_fn(examples):\n        videos = [example[\"instance_video\"].sample() * vae.config.scaling_factor for example in examples]\n        prompts = [example[\"instance_prompt\"] for example in examples]\n\n        videos = torch.cat(videos)\n        videos = videos.permute(0, 2, 1, 3, 4)\n        videos = videos.to(memory_format=torch.contiguous_format).float()\n\n        return {\n            \"videos\": videos,\n            \"prompts\": prompts,\n        }\n\n    train_dataloader = DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=collate_fn,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    if use_deepspeed_scheduler:\n        from accelerate.utils import DummyScheduler\n\n        lr_scheduler = DummyScheduler(\n            name=args.lr_scheduler,\n            optimizer=optimizer,\n            total_num_steps=args.max_train_steps * accelerator.num_processes,\n            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        )\n    else:\n        lr_scheduler = get_scheduler(\n            args.lr_scheduler,\n            optimizer=optimizer,\n            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n            num_training_steps=args.max_train_steps * accelerator.num_processes,\n            num_cycles=args.lr_num_cycles,\n            
power=args.lr_power,\n        )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = args.tracker_name or \"cogvideox-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model[\"params\"])\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num trainable parameters = {num_trainable_parameters}\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if not args.resume_from_checkpoint:\n        initial_global_step = 0\n    else:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)\n\n    # For DeepSpeed training\n    model_config = transformer.module.config if hasattr(transformer, \"module\") else transformer.config\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        
transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n\n            with accelerator.accumulate(models_to_accumulate):\n                model_input = batch[\"videos\"].to(dtype=weight_dtype)  # [B, F, C, H, W]\n                prompts = batch[\"prompts\"]\n\n                # encode prompts\n                prompt_embeds = compute_prompt_embeddings(\n                    tokenizer,\n                    text_encoder,\n                    prompts,\n                    model_config.max_text_seq_length,\n                    accelerator.device,\n                    weight_dtype,\n                    requires_grad=False,\n                )\n\n                # Sample noise that will be added to the latents\n                noise = torch.randn_like(model_input)\n                batch_size, num_frames, num_channels, height, width = model_input.shape\n\n                # Sample a random timestep for each image\n                timesteps = torch.randint(\n                    0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device\n                )\n                timesteps = timesteps.long()\n\n                # Prepare rotary embeds\n                image_rotary_emb = (\n                    prepare_rotary_positional_embeddings(\n                        height=args.height,\n                        width=args.width,\n                        num_frames=num_frames,\n                        vae_scale_factor_spatial=vae_scale_factor_spatial,\n                        patch_size=model_config.patch_size,\n                        attention_head_dim=model_config.attention_head_dim,\n                        device=accelerator.device,\n                    )\n                    if model_config.use_rotary_positional_embeddings\n                    else None\n                )\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n            
    # (this is the forward diffusion process)\n                noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)\n\n                # Predict the noise residual\n                model_output = transformer(\n                    hidden_states=noisy_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep=timesteps,\n                    image_rotary_emb=image_rotary_emb,\n                    return_dict=False,\n                )[0]\n                model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps)\n\n                alphas_cumprod = scheduler.alphas_cumprod[timesteps]\n                weights = 1 / (1 - alphas_cumprod)\n                while len(weights.shape) < len(model_pred.shape):\n                    weights = weights.unsqueeze(-1)\n\n                target = model_input\n\n                loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)\n                loss = loss.mean()\n                accelerator.backward(loss)\n\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                if accelerator.state.deepspeed_plugin is None:\n                    optimizer.step()\n                    optimizer.zero_grad()\n\n                lr_scheduler.step()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if 
args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"Removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:\n                # Create pipeline\n            
    pipe = CogVideoXPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    transformer=unwrap_model(transformer),\n                    text_encoder=unwrap_model(text_encoder),\n                    scheduler=scheduler,\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)\n                for validation_prompt in validation_prompts:\n                    pipeline_args = {\n                        \"prompt\": validation_prompt,\n                        \"guidance_scale\": args.guidance_scale,\n                        \"use_dynamic_cfg\": args.use_dynamic_cfg,\n                        \"height\": args.height,\n                        \"width\": args.width,\n                    }\n\n                    validation_outputs = log_validation(\n                        pipe=pipe,\n                        args=args,\n                        accelerator=accelerator,\n                        pipeline_args=pipeline_args,\n                        epoch=epoch,\n                    )\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        dtype = (\n            torch.float16\n            if args.mixed_precision == \"fp16\"\n            else torch.bfloat16\n            if args.mixed_precision == \"bf16\"\n            else torch.float32\n        )\n        transformer = transformer.to(dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        CogVideoXPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n        )\n\n        # Cleanup trained models to save memory\n        del transformer\n        free_memory()\n\n       
 # Final test inference\n        pipe = CogVideoXPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)\n\n        if args.enable_slicing:\n            pipe.vae.enable_slicing()\n        if args.enable_tiling:\n            pipe.vae.enable_tiling()\n\n        # Load LoRA weights\n        lora_scaling = args.lora_alpha / args.rank\n        pipe.load_lora_weights(args.output_dir, adapter_name=\"cogvideox-lora\")\n        pipe.set_adapters([\"cogvideox-lora\"], [lora_scaling])\n\n        # Run inference\n        validation_outputs = []\n        if args.validation_prompt and args.num_validation_videos > 0:\n            validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)\n            for validation_prompt in validation_prompts:\n                pipeline_args = {\n                    \"prompt\": validation_prompt,\n                    \"guidance_scale\": args.guidance_scale,\n                    \"use_dynamic_cfg\": args.use_dynamic_cfg,\n                    \"height\": args.height,\n                    \"width\": args.width,\n                }\n\n                video = log_validation(\n                    pipe=pipe,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    is_final_validation=True,\n                )\n                validation_outputs.extend(video)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                videos=validation_outputs,\n                base_model=args.pretrained_model_name_or_path,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n                fps=args.fps,\n   
         )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n    main(args)\n"
  },
  {
    "path": "examples/cogview4-control/README.md",
    "content": "# Training CogView4 Control\n\nThis (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about CogView4 Control family, refer to the following resources:\n\nTo incorporate additional condition latents, we expand the input features of CogView-4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `CogView4ControlPipeline`.\n\n> [!NOTE]\n> **Gated model**\n>\n> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:\n\n```bash\nhf auth login\n```\n\nThe example command below shows how to launch fine-tuning for pose conditions. 
The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) used here already includes the pose conditions of the original images, so we don't have to compute them.\n\n```bash\naccelerate launch train_control_lora_cogview4.py \\\n  --pretrained_model_name_or_path=\"THUDM/CogView4-6B\" \\\n  --dataset_name=\"raulc0399/open_pose_controlnet\" \\\n  --output_dir=\"pose-control-lora\" \\\n  --mixed_precision=\"bf16\" \\\n  --train_batch_size=1 \\\n  --rank=64 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=5000 \\\n  --validation_image=\"openpose.png\" \\\n  --validation_prompt=\"A couple, 4k photo, highly detailed\" \\\n  --offload \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).\n\nYou need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). Once it's merged, you should install `diffusers` from `main`.\n\nThe training script exposes additional CLI args that might be useful to experiment with:\n\n* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.\n* `train_norm_layers`: When set, additionally trains the normalization scales. The script takes care of saving and loading them.\n* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify \"all-linear\", LoRA is attached to all the linear layers.\n\n### Training with DeepSpeed\n\nIt's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the ZeRO Stage 2 system optimization. 
To use it, save the following config to a YAML file (feel free to modify as needed):\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  gradient_clipping: 1.0\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n```\n\nThen pass the config file when launching training:\n\n```bash\naccelerate launch --config_file=CONFIG_FILE.yaml ...\n```\n\n### Inference\n\nThe pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:\n\n```bash\npip install controlnet_aux\n```\n\nAnd then we are ready:\n\n```py\nfrom controlnet_aux import OpenposeDetector\nfrom diffusers import CogView4ControlPipeline\nfrom diffusers.utils import load_image\nfrom PIL import Image\nimport numpy as np\nimport torch\n\npipe = CogView4ControlPipeline.from_pretrained(\"THUDM/CogView4-6B\", torch_dtype=torch.bfloat16).to(\"cuda\")\npipe.load_lora_weights(\"...\") # change this.\n\nopen_pose = OpenposeDetector.from_pretrained(\"lllyasviel/Annotators\")\n\n# prepare pose condition.\nurl = \"https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg\"\nimage = load_image(url)\nimage = open_pose(image, detect_resolution=512, image_resolution=1024)\nimage = np.array(image)[:, :, ::-1]\nimage = Image.fromarray(np.uint8(image))\n\nprompt = \"A couple, 4k photo, highly detailed\"\n\ngen_images = pipe(\n  prompt=prompt,\n  control_image=image,\n  num_inference_steps=50,\n  attention_kwargs={\"scale\": 0.9},\n  guidance_scale=25., 
\n).images[0]\ngen_images.save(\"output.png\")\n```\n\n## Full fine-tuning\n\nWe provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command:\n\n```bash\naccelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \\\n  --pretrained_model_name_or_path=\"THUDM/CogView4-6B\" \\\n  --dataset_name=\"raulc0399/open_pose_controlnet\" \\\n  --output_dir=\"pose-control\" \\\n  --mixed_precision=\"bf16\" \\\n  --train_batch_size=2 \\\n  --dataloader_num_workers=4 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --proportion_empty_prompts=0.2 \\\n  --learning_rate=5e-5 \\\n  --adam_weight_decay=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"cosine\" \\\n  --lr_warmup_steps=1000 \\\n  --checkpointing_steps=1000 \\\n  --max_train_steps=10000 \\\n  --validation_steps=200 \\\n  --validation_image \"2_pose_1024.jpg\" \"3_pose_1024.jpg\" \\\n  --validation_prompt \"two friends sitting by each other enjoying a day at the park, full hd, cinematic\" \"person enjoying a day at the park, full hd, cinematic\" \\\n  --offload \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nChange the `validation_image` and `validation_prompt` as needed.\n\nFor inference, this time, we will run:\n\n```py\nfrom controlnet_aux import OpenposeDetector\nfrom diffusers import CogView4ControlPipeline, CogView4Transformer2DModel\nfrom diffusers.utils import load_image\nfrom PIL import Image\nimport numpy as np\nimport torch \n\ntransformer = CogView4Transformer2DModel.from_pretrained(\"...\") # change this.\npipe = CogView4ControlPipeline.from_pretrained(\n  \"THUDM/CogView4-6B\",  transformer=transformer, torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nopen_pose = OpenposeDetector.from_pretrained(\"lllyasviel/Annotators\")\n\n# prepare pose condition.\nurl = \"https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg\"\nimage = load_image(url)\nimage = open_pose(image, 
detect_resolution=512, image_resolution=1024)\nimage = np.array(image)[:, :, ::-1]\nimage = Image.fromarray(np.uint8(image))\n\nprompt = \"A couple, 4k photo, highly detailed\"\n\ngen_images = pipe(\n  prompt=prompt,\n  control_image=image,\n  num_inference_steps=50,\n  guidance_scale=25., \n).images[0]\ngen_images.save(\"output.png\")\n```\n\n## Things to note\n\n* The scripts provided in this directory are experimental and educational. This means we may have to tweak things to get good results on a given condition. We believe this is best done with the community 🤗\n* The scripts are not memory-optimized, but when `--offload` is specified, the VAE and the text encoders are offloaded to the CPU whenever they are not in use.\n* LoRAs can be extracted from the fully fine-tuned model. We currently don't provide any utilities for that, but users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py), which provides similar functionality."
  },
  {
    "path": "examples/cogview4-control/requirements.txt",
    "content": "transformers==4.47.0\nwandb\ntorch\ntorchvision\naccelerate==1.2.0\npeft>=0.14.0\n"
  },
  {
    "path": "examples/cogview4-control/train_control_cogview4.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    CogView4ControlPipeline,\n    CogView4Transformer2DModel,\n    FlowMatchEulerDiscreteScheduler,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not 
installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nNORM_LAYER_PREFIXES = [\"norm_q\", \"norm_k\", \"norm_added_q\", \"norm_added_k\"]\n\n\ndef encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):\n    pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()\n    pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor\n    return pixel_latents.to(weight_dtype)\n\n\ndef log_validation(cogview4_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):\n    logger.info(\"Running validation... \")\n\n    if not is_final_validation:\n        cogview4_transformer = accelerator.unwrap_model(cogview4_transformer)\n        pipeline = CogView4ControlPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            transformer=cogview4_transformer,\n            torch_dtype=weight_dtype,\n        )\n    else:\n        transformer = CogView4Transformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)\n        pipeline = CogView4ControlPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            transformer=transformer,\n            torch_dtype=weight_dtype,\n        )\n\n    pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = 
args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n    if is_final_validation or torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = load_image(validation_image)\n        # maybe need to inference on 1024 to get a good image\n        validation_image = validation_image.resize((args.resolution, args.resolution))\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with autocast_ctx:\n                image = pipeline(\n                    prompt=validation_prompt,\n                    control_image=validation_image,\n                    num_inference_steps=50,\n                    guidance_scale=args.guidance_scale,\n                    max_sequence_length=args.max_sequence_length,\n                    generator=generator,\n                    height=args.resolution,\n                    width=args.resolution,\n                ).images[0]\n            image = image.resize((args.resolution, args.resolution))\n            images.append(image)\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n                formatted_images = []\n                
formatted_images.append(np.asarray(validation_image))\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n                formatted_images = np.stack(formatted_images)\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n                formatted_images.append(wandb.Image(validation_image, caption=\"Conditioning\"))\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({tracker_key: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    # Free the pipeline only after every tracker has logged its images.\n    del pipeline\n    free_memory()\n    return image_logs\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i}](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# cogview4-control-{repo_id}\n\nThese are Control weights trained on 
{base_model} with new type of conditioning.\n{img_str}\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogView4-6b/blob/main/LICENSE.md)\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"cogview4\",\n        \"cogview4-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"control\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a CogView4 Control training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"cogview4-control\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_sequence_length\", type=int, default=128, help=\"The maximum sequence length for the prompt.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with 
the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the control conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\"--log_dataset_samples\", action=\"store_true\", help=\"Whether to log some dataset samples.\")\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A 
set of paths to the control conditioning images to be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=1,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` `args.num_validation_images` times\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"cogview4_train_control\",\n        help=(\n            \"The `project_name` argument passed to `Accelerator.init_trackers`. For\"\n            \" more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    parser.add_argument(\n        \"--jsonl_for_train\",\n        type=str,\n        default=None,\n        help=\"Path to the jsonl file containing the training data.\",\n    )\n    parser.add_argument(\n        \"--only_target_transformer_blocks\",\n        action=\"store_true\",\n        help=\"If we should only target the transformer blocks to train along with the input layer (`x_embedder`).\",\n    )\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"The guidance scale used for the transformer.\",\n    )\n\n    parser.add_argument(\n        
\"--upcast_before_saving\",\n        action=\"store_true\",\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. 
Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoders to CPU when they are not used.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.jsonl_for_train is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--jsonl_for_train`\")\n\n    if args.dataset_name is not None and args.jsonl_for_train is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--jsonl_for_train`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and `--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the cogview4 transformer.\"\n        )\n\n    return args\n\n\ndef get_train_dataset(args, 
accelerator):\n    dataset = None\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    if args.jsonl_for_train is not None:\n        # load from json\n        dataset = load_dataset(\"json\", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)\n        dataset = dataset.flatten_indices()\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    with accelerator.main_process_first():\n        train_dataset = dataset[\"train\"].shuffle(seed=args.seed)\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.select(range(args.max_train_samples))\n    return train_dataset\n\n\ndef prepare_train_dataset(dataset, accelerator):\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.ToTensor(),\n            transforms.Lambda(lambda x: x * 2 - 1),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [\n            (image.convert(\"RGB\") if not isinstance(image, str) else Image.open(image).convert(\"RGB\"))\n            for image in examples[args.image_column]\n        ]\n        images = [image_transforms(image) for image in images]\n\n        conditioning_images = [\n            (image.convert(\"RGB\") if not isinstance(image, str) else Image.open(image).convert(\"RGB\"))\n            for image in examples[args.conditioning_image_column]\n        ]\n        conditioning_images = [image_transforms(image) for image in conditioning_images]\n        examples[\"pixel_values\"] = images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n\n        is_caption_list = 
isinstance(examples[args.caption_column][0], list)\n        if is_caption_list:\n            examples[\"captions\"] = [max(example, key=len) for example in examples[args.caption_column]]\n        else:\n            examples[\"captions\"] = list(examples[args.caption_column])\n\n        return examples\n\n    with accelerator.main_process_first():\n        dataset = dataset.with_transform(preprocess_train)\n\n    return dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n    captions = [example[\"captions\"] for example in examples]\n    return {\"pixel_values\": pixel_values, \"conditioning_pixel_values\": conditioning_pixel_values, \"captions\": captions}\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_out_dir = Path(args.output_dir, args.logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.\n    if torch.backends.mps.is_available():\n        logger.info(\"MPS is enabled. Disabling AMP.\")\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        # DEBUG, INFO, WARNING, ERROR, CRITICAL\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load models. 
We will load the text encoders later in a pipeline to compute\n    # embeddings.\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    cogview4_transformer = CogView4Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    logger.info(\"All models loaded successfully\")\n\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"scheduler\",\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    if not args.only_target_transformer_blocks:\n        cogview4_transformer.requires_grad_(True)\n    vae.requires_grad_(False)\n\n    # cast down and move to the CPU\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # let's not move the VAE to the GPU yet.\n    vae.to(dtype=torch.float32)  # keep the VAE in float32.\n\n    # enable image inputs\n    with torch.no_grad():\n        patch_size = cogview4_transformer.config.patch_size\n        initial_input_channels = cogview4_transformer.config.in_channels * patch_size**2\n        new_linear = torch.nn.Linear(\n            cogview4_transformer.patch_embed.proj.in_features * 2,\n            cogview4_transformer.patch_embed.proj.out_features,\n            bias=cogview4_transformer.patch_embed.proj.bias is not None,\n            dtype=cogview4_transformer.dtype,\n            device=cogview4_transformer.device,\n        )\n        new_linear.weight.zero_()\n        new_linear.weight[:, :initial_input_channels].copy_(cogview4_transformer.patch_embed.proj.weight)\n        if 
cogview4_transformer.patch_embed.proj.bias is not None:\n            new_linear.bias.copy_(cogview4_transformer.patch_embed.proj.bias)\n        cogview4_transformer.patch_embed.proj = new_linear\n\n    assert torch.all(cogview4_transformer.patch_embed.proj.weight[:, initial_input_channels:].data == 0)\n    cogview4_transformer.register_to_config(\n        in_channels=cogview4_transformer.config.in_channels * 2, out_channels=cogview4_transformer.config.in_channels\n    )\n\n    if args.only_target_transformer_blocks:\n        cogview4_transformer.patch_embed.proj.requires_grad_(True)\n        for name, module in cogview4_transformer.named_modules():\n            if \"transformer_blocks\" in name:\n                module.requires_grad_(True)\n            else:\n                module.requires_grad_(False)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                for model in models:\n                    if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))):\n                        model = unwrap_model(model)\n                        model.save_pretrained(os.path.join(output_dir, \"transformer\"))\n                    else:\n                        raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    if weights:\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            transformer_ = None\n\n            if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n                while 
len(models) > 0:\n                    model = models.pop()\n\n                    if isinstance(unwrap_model(model), type(unwrap_model(cogview4_transformer))):\n                        transformer_ = model  # noqa: F841\n                    else:\n                        raise ValueError(f\"unexpected save model: {unwrap_model(model).__class__}\")\n\n            else:\n                transformer_ = CogView4Transformer2DModel.from_pretrained(input_dir, subfolder=\"transformer\")  # noqa: F841\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        cogview4_transformer.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimization parameters\n    optimizer = optimizer_class(\n        cogview4_transformer.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Prepare dataset and dataloader.\n    train_dataset = get_train_dataset(args, accelerator)\n    
train_dataset = prepare_train_dataset(train_dataset, accelerator)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n    # Prepare everything with our `accelerator`.\n    cogview4_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        cogview4_transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n        
        f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Create a pipeline for text encoding. 
We will move this pipeline to GPU/CPU as needed.\n    text_encoding_pipeline = CogView4ControlPipeline.from_pretrained(\n        args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype\n    )\n    tokenizer = text_encoding_pipeline.tokenizer\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            logger.info(f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\")\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            logger.info(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    if accelerator.is_main_process and args.report_to == \"wandb\" and args.log_dataset_samples:\n        logger.info(\"Logging some dataset samples.\")\n        formatted_images = []\n        formatted_control_images = []\n        all_prompts = []\n        for i, batch in enumerate(train_dataloader):\n            images = (batch[\"pixel_values\"] + 1) / 2\n            control_images = (batch[\"conditioning_pixel_values\"] + 1) / 2\n            prompts = batch[\"captions\"]\n\n            if len(formatted_images) > 10:\n                break\n\n            for img, control_img, 
prompt in zip(images, control_images, prompts):\n                formatted_images.append(img)\n                formatted_control_images.append(control_img)\n                all_prompts.append(prompt)\n\n        logged_artifacts = []\n        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):\n            logged_artifacts.append(wandb.Image(control_img, caption=\"Conditioning\"))\n            logged_artifacts.append(wandb.Image(img, caption=prompt))\n\n        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == \"wandb\"]\n        wandb_tracker[0].log({\"dataset_samples\": logged_artifacts})\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        cogview4_transformer.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(cogview4_transformer):\n                # Convert images to latent space\n                # vae encode\n                prompts = batch[\"captions\"]\n                attention_mask = tokenizer(\n                    prompts,\n                    padding=\"longest\",  # not use max length\n                    max_length=args.max_sequence_length,\n                    truncation=True,\n                    add_special_tokens=True,\n                    return_tensors=\"pt\",\n                ).attention_mask.float()\n\n                pixel_latents = encode_images(batch[\"pixel_values\"], vae.to(accelerator.device), weight_dtype)\n                control_latents = encode_images(\n                    batch[\"conditioning_pixel_values\"], vae.to(accelerator.device), weight_dtype\n                )\n                if args.offload:\n                    vae.cpu()\n\n       
         # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                bsz = pixel_latents.shape[0]\n                noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n\n                # Add noise according to the CogView4 schedule\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)\n                sigmas = noise_scheduler_copy.sigmas[indices].to(device=pixel_latents.device)\n                captions = batch[\"captions\"]\n                image_seq_lens = torch.tensor(\n                    pixel_latents.shape[2] * pixel_latents.shape[3] // patch_size**2,\n                    dtype=pixel_latents.dtype,\n                    device=pixel_latents.device,\n                )  # H * W  / VAE patch_size\n                mu = torch.sqrt(image_seq_lens / 256)\n                mu = mu * 0.75 + 0.25\n                scale_factors = mu / (mu + (1 / sigmas - 1) ** 1.0).to(\n                    dtype=pixel_latents.dtype, device=pixel_latents.device\n                )\n                scale_factors = scale_factors.view(len(batch[\"captions\"]), 1, 1, 1)\n                noisy_model_input = (1.0 - scale_factors) * pixel_latents + scale_factors * noise\n                concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)\n                text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)\n\n                with torch.no_grad():\n                    (\n                        
prompt_embeds,\n                        pooled_prompt_embeds,\n                    ) = text_encoding_pipeline.encode_prompt(captions, \"\")\n                original_size = (args.resolution, args.resolution)\n                original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)\n\n                target_size = (args.resolution, args.resolution)\n                target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=prompt_embeds.device)\n\n                target_size = target_size.repeat(len(batch[\"captions\"]), 1)\n                original_size = original_size.repeat(len(batch[\"captions\"]), 1)\n                crops_coords_top_left = torch.tensor([(0, 0)], dtype=prompt_embeds.dtype, device=prompt_embeds.device)\n                crops_coords_top_left = crops_coords_top_left.repeat(len(batch[\"captions\"]), 1)\n\n                # this could be optimized by not having to do any text encoding and just\n                # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`\n                if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:\n                    # Here, we directly pass 16 pad tokens from pooled_prompt_embeds to prompt_embeds.\n                    prompt_embeds = pooled_prompt_embeds\n                if args.offload:\n                    text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n                # Predict.\n                noise_pred_cond = cogview4_transformer(\n                    hidden_states=concatenated_noisy_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep=timesteps,\n                    original_size=original_size,\n                    target_size=target_size,\n                    crop_coords=crops_coords_top_left,\n                    return_dict=False,\n                    attention_mask=attention_mask,\n                )[0]\n                # 
these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n                # flow-matching loss\n                target = noise - pixel_latents\n\n                weighting = weighting.view(len(batch[\"captions\"]), 1, 1, 1)\n                loss = torch.mean(\n                    (weighting.float() * (noise_pred_cond.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n                accelerator.backward(loss)\n\n                if accelerator.sync_gradients:\n                    params_to_clip = cogview4_transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                          
  # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            cogview4_transformer=cogview4_transformer,\n                            args=args,\n                            accelerator=accelerator,\n                            weight_dtype=weight_dtype,\n                            step=global_step,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save 
it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        cogview4_transformer = unwrap_model(cogview4_transformer)\n        if args.upcast_before_saving:\n            cogview4_transformer.to(torch.float32)\n        cogview4_transformer.save_pretrained(args.output_dir)\n\n        del cogview4_transformer\n        del text_encoding_pipeline\n        del vae\n        free_memory()\n\n        # Run a final round of validation.\n        image_logs = None\n        if args.validation_prompt is not None:\n            image_logs = log_validation(\n                cogview4_transformer=None,\n                args=args,\n                accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                step=global_step,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\", \"checkpoint-*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/community/README.md",
    "content": "# Community Pipeline Examples\n\n> **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).**\n\n**Community pipeline** examples consist pipelines that have been added by the community.\nPlease have a look at the following tables to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out.\nIf a community pipeline doesn't work as expected, please open an issue and ping the author on it.\n\nPlease also check out our [Community Scripts](https://github.com/huggingface/diffusers/blob/main/examples/community/README_community_scripts.md) examples for tips and tricks that you can use with diffusers without having to run a community pipeline.\n\n| Example                                                                                                                               | Description                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              | Code Example                                                                              | Colab                                                                                                                                                                                                              |                                                        Author 
|\n|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|\n|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://huggingface.co/papers/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/), and [Ednaordinary](https://github.com/Ednaordinary)|\n|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan 
Han](https://jellyheadandrew.github.io)|\n|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|\n|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|\n| HD-Painter                                                                                                                            | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method.                                                                                                                                                                                                                                                                                                               
| [HD-Painter](#hd-painter)                                                                 | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/PAIR/HD-Painter)                                                                              | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |\n| Marigold Monocular Depth Estimation                                                                                                   | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.)                                                                                                                                                                                                                                                        | [Marigold Depth Estimation](#marigold-depth-estimation)                                   | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/toshas/marigold) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |\n| LLM-grounded Diffusion (LMD+)                                                                                                         | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. 
[Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion)                                                                                                                                                                                                                                                                                                                                                                                                                                   | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion)                             | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) |                [Long (Tony) Lian](https://tonylian.com/) |\n| CLIP Guided Stable Diffusion                                                                                                          | Doing CLIP guidance for text to image generation with Stable Diffusion                                                                                                                                                                                                                                                                                                                                                                                                                                   | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion)                             | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) |                [Suraj Patil](https://github.com/patil-suraj/) |\n| One Step U-Net (Dummy)             
                                                                                                   | Example showcasing of how to use Community Pipelines (see <https://github.com/huggingface/diffusers/issues/841>)                                                                                                                                                                                                                                                                                                                                                                                           | [One Step U-Net](#one-step-unet)                                                          | -                                                                                                                                                                                                                  |    [Patrick von Platen](https://github.com/patrickvonplaten/) |\n| Stable Diffusion Interpolation                                                                                                        | Interpolate the latent space of Stable Diffusion between different prompts/seeds                                                                                                                                                                                                                                                                                                                                                                                                                         | [Stable Diffusion Interpolation](#stable-diffusion-interpolation)                         | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_interpolation.ipynb)                                                                                                                                                           |                       [Nate 
Raw](https://github.com/nateraw/) |\n| Stable Diffusion Mega                                                                                                                 | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega)                                           | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_mega.ipynb)                                                                                                                                                                             |    [Patrick von Platen](https://github.com/patrickvonplaten/) |\n| Long Prompt Weighting Stable Diffusion                                                                                                | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt.                                                                                                                                                                                                                                                                                                                                                                                                  
| [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion)         | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/long_prompt_weighting_stable_diffusion.ipynb)                                                                                        |                           [SkyTNT](https://github.com/SkyTNT) |\n| Speech to Image                                                                                                                       | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images                                                                                                                                                                                                                                                                                                                                                                                                            | [Speech to Image](#speech-to-image)                                                       |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb)                                                                                                                                                                                                   |             [Mikail Duzenli](https://github.com/MikailINTech)\n| Wild Card Stable Diffusion                                                                                                            | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values                                                                                                                                                                                                         
                                                                   | [Wildcard Stable Diffusion](#wildcard-stable-diffusion)                                   | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb)                                                                                                                                                                                 |              [Shyam Sudhakaran](https://github.com/shyamsn97) |\n| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain \"&#124;\" in prompts (as an AND condition) and weights (separated by \"&#124;\" as well) to positively / negatively weight prompts.                                                                                                                                                                                                                                                                                                            | [Composable Stable Diffusion](#composable-stable-diffusion)                               | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/composable_stable_diffusion.ipynb)                                                                                                                                                                                                                 |                      [Mark Rich](https://github.com/MarkRich) |\n| Seed Resizing Stable Diffusion                                                                                                        | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation.                                                                                                                                    
                                                                                                                                                                                                                                                   | [Seed Resizing](#seed-resizing)                                                           | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb)                                                                                                                                                                                                                  |                      [Mark Rich](https://github.com/MarkRich) |\n| Imagic Stable Diffusion                                                                                                               | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image                                                                                                                                                                                                                                                                                                                                                                                                                   | [Imagic Stable Diffusion](#imagic-stable-diffusion)                                       | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb)                                                                                                                                                                                                  |                      [Mark Rich](https://github.com/MarkRich) |\n| Multilingual Stable Diffusion                                                                                                         | Stable Diffusion Pipeline that supports prompts in 50 different 
languages.                                                                                                                                                                                                                                                                                                                                                                                                                               | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline)                  | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb)                                                                                                                                                                             |          [Juan Carlos Piñeros](https://github.com/juancopi81) |\n| GlueGen Stable Diffusion                                                                                                         | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter.                                                                                                                                                                                                                                                                                                                                                                                                                               
| [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline)                  | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/gluegen_stable_diffusion.ipynb)                                                                                                                                                                                                                  |          [Phạm Hồng Vinh](https://github.com/rootonchair) |\n| Image to Image Inpainting Stable Diffusion                                                                                            | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting                                                                                                                                                                                                                                                                                                                                                                                                            | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/image_to_image_inpainting_stable_diffusion.ipynb)                                                                                                                                                                                                                  |                    [Alex McKinney](https://github.com/vvvm23) |\n| Text Based Inpainting Stable Diffusion                                                                                                | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting                                                                                                                                                                                
                                                                                                                                                                                                              | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion)     | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/text_based_inpainting_stable_dffusion.ipynb)                                                                                                                                                                                                    |                   [Dhruv Karan](https://github.com/unography) |\n| Bit Diffusion                                                                                                                         | Diffusion on discrete data                                                                                                                                                                                                                                                                                                                                                                                                                                                                               | [Bit Diffusion](#bit-diffusion)                                                           | -  |                       [Stuti R.](https://github.com/kingstut) |\n| K-Diffusion Stable Diffusion                                                                                                          | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py)                                                                                                                                                                                                                                                                     
                                                                                             | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion)                   | -  |    [Patrick von Platen](https://github.com/patrickvonplaten/) |\n| Checkpoint Merger Pipeline                                                                                                            | Diffusion Pipeline that enables merging of saved model checkpoints                                                                                                                                                                                                                                                                                                                                                                                                                                       | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline)                                 | -                                                                                                                                                                                                                  | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |\n| Stable Diffusion v1.1-1.4 Comparison                                                                                                  | Run all 4 model checkpoints for Stable Diffusion and compare their results together                                                                                                                                                                                                                                                                                                                                                                                                                      | [Stable Diffusion Comparison](#stable-diffusion-comparisons)                              | 
[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_comparison.ipynb) |        [Suvaditya Mukherjee](https://github.com/suvadityamuk) |\n| MagicMix                                                                                                                              | Diffusion Pipeline for semantic mixing of an image and a text prompt                                                                                                                                                                                                                                                                                                                                                                                                                                     | [MagicMix](#magic-mix)                                                                    | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/magic_mix.ipynb) |                    [Partho Das](https://github.com/daspartho) |\n| Stable UnCLIP                                                                                                                         | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `\"kakaobrain/karlo-v1-alpha\"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `\"lambdalabs/sd-image-variations-diffusers\"` ).                                                                                                                                                                                                                   
| [Stable UnCLIP](#stable-unclip)                                                           | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb)  |                                [Ray Wang](https://wrong.wang) |\n| UnCLIP Text Interpolation Pipeline                                                                                                    | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts                                                                                                                                                                                                                                                                                                                                                                | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline)                 | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |\n| UnCLIP Image Interpolation Pipeline                                                                                                   | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings                                                                                                                                                                                                                                                                                                                                                                | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline)               | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay 
Devarinti](https://github.com/Abhinay1997/) |\n| DDIM Noise Comparative Analysis Pipeline                                                                                              | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://huggingface.co/papers/2204.00227))                                                                                                                                                                                                                                                                                                                             | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline)     | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)|              [Aengus (Duc-Anh)](https://github.com/aengusng8) |\n| CLIP Guided Img2Img Stable Diffusion Pipeline                                                                                         | Doing CLIP guidance for image to image generation with Stable Diffusion                                                                                                                                                                                                                                                                                                                                                                                                                                  | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion)             | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_img2img_stable_diffusion.ipynb) |               [Nipun Jindal](https://github.com/nipunjindal/) |\n| TensorRT Stable Diffusion Text to Image Pipeline                                                                                             
       | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT                                                                                                                                                                                                                                                                                                                                                                                                                                      | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline)      | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/tensorrt_text2image_stable_diffusion_pipeline.ipynb) |              [Asfiya Baig](https://github.com/asfiyab-nvidia) |\n| EDICT Image Editing Pipeline                                                                                                          | Diffusion pipeline for text-guided image editing                                                                                                                                                                                                                                                                                                                                                                                                                                                         | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline)                             | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) |                    [Joqsan Azocar](https://github.com/Joqsan) |\n| Stable Diffusion RePaint                                                                                                              | Stable Diffusion pipeline using [RePaint](https://huggingface.co/papers/2201.09865) for inpainting.                                                                             
 | [Stable Diffusion RePaint](#stable-diffusion-repaint)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_repaint.ipynb)|                  [Markus Pobitzer](https://github.com/Markus-Pobitzer) |\n| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline)      | - |              [Asfiya Baig](https://github.com/asfiyab-nvidia) |\n| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |\n| CLIP Guided Images Mixing Stable Diffusion Pipeline | Combine images using usual diffusion models. 
| [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_images_mixing_with_stable_diffusion.ipynb) | [Karachev Denis](https://github.com/TheDenk) |\n| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline)      | - |              [Asfiya Baig](https://github.com/asfiyab-nvidia) |\n|   IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://huggingface.co/papers/2305.03486) | [IADB Pipeline](#iadb-pipeline)      | - |              [Thomas Chambon](https://github.com/tchambon) |\n|   Zero1to3 Pipeline | 
Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://huggingface.co/papers/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline)      | - |              [Xin Kong](https://github.com/kxhit) |\n| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline that supports prompts and negative prompts of unlimited length and uses A1111-style prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |\n| Stable Diffusion Mixture Tiling Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SD 1.5](#stable-diffusion-mixture-tiling-pipeline-sd-15) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |\n| Stable Diffusion Mixture Canvas Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending. Works by defining a list of Text2Image region objects that detail the region of influence of each diffuser. 
| [Stable Diffusion Mixture Canvas Pipeline SD 1.5](#stable-diffusion-mixture-canvas-pipeline-sd-15) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |\n| Stable Diffusion Mixture Tiling Pipeline SDXL | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SDXL](#stable-diffusion-mixture-tiling-pipeline-sdxl) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/elismasilva/mixture-of-diffusers-sdxl-tiling) | [Eliseu Silva](https://github.com/DEVAIEXP/) |\n| Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL | This is an advanced pipeline that leverages ControlNet Tile and Mixture-of-Diffusers techniques, integrating tile diffusion directly into the latent space denoising process. Designed to overcome the limitations of conventional pixel-space tile processing, this pipeline delivers Super Resolution (SR) upscaling for higher-quality images, reduced processing time, and greater adaptability. 
| [Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL](#stable-diffusion-mod-controlnet-tile-sr-pipeline-sdxl) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/elismasilva/mod-control-tile-upscaler-sdxl) | [Eliseu Silva](https://github.com/DEVAIEXP/) |\n| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |\n| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |\n| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |\n| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_2_prompt_pipeline.ipynb) | [Umer H. 
Adil](https://twitter.com/UmerHAdil) |\n|   Latent Consistency Pipeline                                                                                                    | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://huggingface.co/papers/2310.04378)                                                                                                                                                                                                                                                                                                                                                                                                                                      | [Latent Consistency Pipeline](#latent-consistency-pipeline)      | - |              [Simian Luo](https://github.com/luosiallen) |\n|   Latent Consistency Img2img Pipeline                                                                                                    | Img2img pipeline for Latent Consistency Models                                                                                                                                                                                                                                                                                                                                                                                                                                    | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline)      | - |              [Logan Zoellner](https://github.com/nagolinc) |\n|   Latent Consistency Interpolation Pipeline                                                                                                    | Interpolate the latent space of Latent Consistency Models with multiple prompts                                                                                                                                            
                                                                                                                                                                                                                                                                                        | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |\n| SDE Drag Pipeline                                                                                                                         | The pipeline supports drag editing of images using stochastic differential equations                                                                                                                                                                                                                                                                                                                                                                                                                | [SDE Drag Pipeline](#sde-drag-pipeline)                                                     | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/sde_drag.ipynb) | [NieShen](https://github.com/NieShenRuc) [Fengqi Zhu](https://github.com/Monohydroxides) |\n|   Regional Prompting Pipeline                                                                                               | Assign multiple prompts for different regions                                                                                                                                                                                                                                                                                                                                                    
|  [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |\n| LDM3D-sr (LDM3D upscaler)                                                                                                             | Upscale low resolution RGB and depth inputs to high resolution                                                                                                                                                                                                                                                                                                                                                                                                                              | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline)                                                                             | -                                                                                                                                                                                                             |                                                        [Estelle Aflalo](https://github.com/estelleafl) |\n| AnimateDiff ControlNet Pipeline                                                                                                    | Combines AnimateDiff with precise motion control using ControlNets                                                                                                                                                                                                                                                                                                                                                                                                                                    | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |\n|   DemoFusion Pipeline                                                                                                    | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://huggingface.co/papers/2311.16973)                                                                                                                                                                                                                                                                                                                                                                                                                                      | [DemoFusion Pipeline](#demofusion)      | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/demo_fusion.ipynb) |              [Ruoyi Du](https://github.com/RuoyiDu) |\n|   Instaflow Pipeline                                                                                                    | Implementation of [InstaFlow! 
One-Step Stable Diffusion with Rectified Flow](https://huggingface.co/papers/2309.06380)                                                                                                                                                                                                                                                                                                                                                                                                                                      | [Instaflow Pipeline](#instaflow-pipeline)      | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/insta_flow.ipynb) |              [Ayush Mangal](https://github.com/ayushtues) |\n|   Null-Text Inversion Pipeline  | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://huggingface.co/papers/2211.09794) as a pipeline.                                                                                                                                                                                                                                                                                                                                                                                                                                      | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/)      | - |              [Junsheng Luan](https://github.com/Junsheng121) |\n|   Rerender A Video Pipeline                                                                                                    | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://huggingface.co/papers/2306.07954)                                                                                                                                                                                                                                                                      
                                                                                                                                                                | [Rerender A Video Pipeline](#rerender-a-video)      | - |              [Yifan Zhou](https://github.com/SingleZombie) |\n| StyleAligned Pipeline                                                                                                    | Implementation of [Style Aligned Image Generation via Shared Attention](https://huggingface.co/papers/2312.02133)                                                                                                                                                                                                                                                                                                                                                                                                                                   | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |\n| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |\n|   IP Adapter FaceID Stable Diffusion                                                                                               | Stable Diffusion Pipeline that supports IP Adapter Face ID                                                                                                                                                                                                             
                                                                                                                                     |  [IP Adapter Face ID](#ip-adapter-face-id) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ip_adapter_face_id.ipynb)| [Fabio Rigano](https://github.com/fabiorigano) |\n|   InstantID Pipeline                                                                                               | Stable Diffusion XL Pipeline that supports InstantID                                                                                                                                                                                                                                                                                                                                                 |  [InstantID Pipeline](#instantid-pipeline) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |\n|   UFOGen Scheduler                                                                                               | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines)                                                                                                                                                                                                                                                                                                                                                 |  [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |\n| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |\n| Stable 
Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |\n|   FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://huggingface.co/papers/2403.12962) | [FRESCO V2V Pipeline](#fresco)      | - |              [Yifan Zhou](https://github.com/SingleZombie) |\n| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |\n| PIXART-α Controlnet pipeline | Implementation of the ControlNet model for PixArt-α and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |\n| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). 
| [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |\n| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. 
Tolga Cangöz](https://github.com/tolgacangoz) |\n| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|\n| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|\n| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |\n| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. 
| [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |\n| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|\n| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - |  [Net-Mist](https://github.com/Net-Mist) |\n| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports ControlNet with the Flux Fill model.| [Flux Fill ControlNet Pipeline](#flux-fill-controlnet-pipeline) | - |  [pratim4dasude](https://github.com/pratim4dasude) |\n\nTo load a custom pipeline, pass the `custom_pipeline` argument to `DiffusionPipeline.from_pretrained`, set to the name of one of the files in `diffusers/examples/community`. 
Feel free to send a PR with your own pipelines; we will merge them quickly.\n\n```py\npipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", custom_pipeline=\"filename_in_the_community_folder\")\n```\n\n## Example usages\n\n### Spatiotemporal Skip Guidance\n\n**Junha Hyung\\*, Kinam Kim\\*, Susung Hong, Min-Jung Kim, Jaegul Choo**\n\n**KAIST AI, University of Washington**\n\n[*Spatiotemporal Skip Guidance (STG) for Enhanced Video Diffusion Sampling*](https://huggingface.co/papers/2411.18664) (CVPR 2025) is a simple training-free sampling guidance method for enhancing transformer-based video diffusion models. STG employs an implicit weak model via self-perturbation, avoiding the need for external models or additional training. By selectively skipping spatiotemporal layers, STG produces an aligned, degraded version of the original model to boost sample quality without compromising diversity or dynamic degree.\n\nThe following video shows STG applied to Mochi.\n\n\nhttps://github.com/user-attachments/assets/148adb59-da61-4c50-9dfa-425dcb5c23b3\n\nMore examples and information can be found on the [GitHub repository](https://github.com/junhahyung/STGuidance) and the [Project website](https://junhahyung.github.io/STGuidance/).\n\n#### Usage example\n```python\nimport torch\nfrom pipeline_stg_mochi import MochiSTGPipeline\nfrom diffusers.utils import export_to_video\n\n# Load the pipeline\npipe = MochiSTGPipeline.from_pretrained(\"genmo/mochi-1-preview\", variant=\"bf16\", torch_dtype=torch.bfloat16)\n\n# Move the pipeline to the GPU\npipe = pipe.to(\"cuda\")\n\n#--------Option--------#\nprompt = \"A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style.\"\nstg_applied_layers_idx = [34]\nstg_scale = 1.0 # 0.0 for CFG\ndo_rescaling = False # assumption: rescaling disabled; adjust as needed\n#----------------------#\n\n# Generate video frames\nframes = pipe(\n    prompt, \n    height=480,\n    width=480,\n    num_frames=81,\n 
   stg_applied_layers_idx=stg_applied_layers_idx,\n    stg_scale=stg_scale,\n    generator=torch.Generator().manual_seed(42),\n    do_rescaling=do_rescaling,\n).frames[0]\n\nexport_to_video(frames, \"output.mp4\", fps=30)\n```\n\n### Adaptive Mask Inpainting\n\n**Hyeonwoo Kim\\*, Sookwan Han\\*, Patrick Kwon, Hanbyul Joo**\n\n**Seoul National University, Naver Webtoon**\n\nAdaptive Mask Inpainting, presented in the ECCV'24 oral paper [*Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models*](https://snuvclab.github.io/coma), is an algorithm designed to insert humans into scene images without altering the background. Traditional inpainting methods often fail to preserve object geometry and details within the masked region, leading to false affordances. Adaptive Mask Inpainting addresses this issue by progressively specifying the inpainting region over diffusion timesteps, ensuring that the inserted human integrates seamlessly with the existing scene.\n\nHere is a demonstration of Adaptive Mask Inpainting:\n\n<video controls>\n  <source src=\"https://snuvclab.github.io/coma/static/videos/adaptive_mask_inpainting_vis.mp4\" type=\"video/mp4\">\n  Your browser does not support the video tag.\n</video>\n\n![teaser-img](https://snuvclab.github.io/coma/static/images/example_result_adaptive_mask_inpainting.png)\n\n\nYou can find additional information about Adaptive Mask Inpainting in the [paper](https://huggingface.co/papers/2401.12978) or on the [project website](https://snuvclab.github.io/coma).\n\n#### Usage example\nFirst, clone the diffusers GitHub repository and run the following commands to set up the environment.\n```Shell\ngit clone https://github.com/huggingface/diffusers.git\ncd diffusers\n\nconda create --name ami python=3.9 -y\nconda activate ami\n\nconda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y\npython -m pip install detectron2==0.6 -f 
https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html\npip install easydict\npip install diffusers==0.20.2 accelerate safetensors transformers\npip install setuptools==59.5.0\npip install opencv-python\npip install numpy==1.24.1\n```\nThen, run the below code under 'diffusers' directory.\n```python\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom diffusers import DDIMScheduler\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import load_image\n\nfrom examples.community.adaptive_mask_inpainting import download_file, AdaptiveMaskInpaintPipeline, AMI_INSTALL_MESSAGE\n\nprint(AMI_INSTALL_MESSAGE)\n\nfrom easydict import EasyDict\n\n\n\nif __name__ == \"__main__\":    \n    \"\"\"\n    Download Necessary Files\n    \"\"\"\n    download_file(\n        url = \"https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/model_final_edd263.pkl?download=true\",\n        output_file = \"model_final_edd263.pkl\",\n        exist_ok=True,\n    )\n    download_file(\n        url = \"https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/pointrend_rcnn_R_50_FPN_3x_coco.yaml?download=true\",\n        output_file = \"pointrend_rcnn_R_50_FPN_3x_coco.yaml\",\n        exist_ok=True,\n    )\n    download_file(\n        url = \"https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_img.png?download=true\",\n        output_file = \"input_img.png\",\n        exist_ok=True,\n    )\n    download_file(\n        url = \"https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/input_mask.png?download=true\",\n        output_file = \"input_mask.png\",\n        exist_ok=True,\n    )\n    download_file(\n        url = \"https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-PointRend-RCNN-FPN.yaml?download=true\",\n        
output_file = \"Base-PointRend-RCNN-FPN.yaml\",\n        exist_ok=True,\n    )\n    download_file(\n        url = \"https://huggingface.co/datasets/jellyheadnadrew/adaptive-mask-inpainting-test-images/resolve/main/Base-RCNN-FPN.yaml?download=true\",\n        output_file = \"Base-RCNN-FPN.yaml\",\n        exist_ok=True,\n    )\n    \n    \"\"\" \n    Prepare Adaptive Mask Inpainting Pipeline\n    \"\"\"\n    # device\n    device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n    num_steps = 50\n    \n    # Scheduler\n    scheduler = DDIMScheduler(\n        beta_start=0.00085, \n        beta_end=0.012, \n        beta_schedule=\"scaled_linear\", \n        clip_sample=False, \n        set_alpha_to_one=False\n    )\n    scheduler.set_timesteps(num_inference_steps=num_steps)\n\n    ## load models as pipelines\n    pipeline = AdaptiveMaskInpaintPipeline.from_pretrained(\n        \"Uminosachi/realisticVisionV51_v51VAE-inpainting\", \n        scheduler=scheduler, \n        torch_dtype=torch.float16, \n        requires_safety_checker=False\n    ).to(device)\n\n    ## disable safety checker\n    enable_safety_checker = False\n    if not enable_safety_checker:\n        pipeline.safety_checker = None\n    \n    \"\"\" \n    Run Adaptive Mask Inpainting \n    \"\"\"\n    default_mask_image = Image.open(\"./input_mask.png\").convert(\"L\")\n    init_image = Image.open(\"./input_img.png\").convert(\"RGB\")\n    \n    \n    seed = 59\n    generator = torch.Generator(device=device)\n    generator.manual_seed(seed)\n    \n    image = pipeline(\n        prompt=\"a man sitting on a couch\",\n        negative_prompt=\"worst quality, normal quality, low quality, bad anatomy, artifacts, blurry, cropped, watermark, greyscale, nsfw\",\n        image=init_image,\n        default_mask_image=default_mask_image,\n        guidance_scale=11.0,\n        strength=0.98,\n        use_adaptive_mask=True,\n        generator=generator,\n        
enforce_full_mask_ratio=0.0,\n        visualization_save_dir=\"./ECCV2024_adaptive_mask_inpainting_demo\", # DON'T CHANGE THIS!!!\n        human_detection_thres=0.015,\n    ).images[0]\n\n    \n    image.save(f'final_img.png')\n```\n#### [Troubleshooting]\n\nIf you run into an error `cannot import name 'cached_download' from 'huggingface_hub'` (issue [1851](https://github.com/easydiffusion/easydiffusion/issues/1851)), remove `cached_download` from the import line in the file `diffusers/utils/dynamic_modules_utils.py`. \n\nFor example, change the import line from `.../env/lib/python3.8/site-packages/diffusers/utils/dynamic_modules_utils.py`.\n\n\n### Flux with CFG\n\nKnow more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).\n\nExample usage:\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\nmodel_name = \"black-forest-labs/FLUX.1-dev\"\nprompt = \"a watercolor painting of a unicorn\"\nnegative_prompt = \"pink\"\n\n# Load the diffusion pipeline\npipeline = DiffusionPipeline.from_pretrained(\n    model_name,\n    torch_dtype=torch.bfloat16,\n    custom_pipeline=\"pipeline_flux_with_cfg\"\n)\npipeline.enable_model_cpu_offload()\n\n# Generate the image\nimg = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    true_cfg=1.5,\n    guidance_scale=3.5,\n    generator=torch.manual_seed(0)\n).images[0]\n\n# Save the generated image\nimg.save(\"cfg_flux.png\")\nprint(\"Image generated and saved successfully.\")\n```\n\n### Differential Diffusion\n\n**Eran Levin, Ohad Fried**\n\n**Tel Aviv University, Reichman University**\n\nDiffusion models have revolutionized image generation and editing, producing state-of-the-art results in conditioned and unconditioned image synthesis. 
While current techniques enable user control over the degree of change in an image edit, the controllability is limited to global changes over an entire edited region. This paper introduces a novel framework that enables customization of the amount of change per pixel or per image region. Our framework can be integrated into any existing diffusion model, enhancing it with this capability. Such granular control on the quantity of change opens up a diverse array of new editing capabilities, such as control of the extent to which individual objects are modified, or the ability to introduce gradual spatial changes. Furthermore, we showcase the framework's effectiveness in soft-inpainting---the completion of portions of an image while subtly adjusting the surrounding areas to ensure seamless integration. Additionally, we introduce a new tool for exploring the effects of different change quantities. Our framework operates solely during inference, requiring no model training or fine-tuning. We demonstrate our method with the current open state-of-the-art models, and validate it via both quantitative and qualitative comparisons, and a user study.\n\n![teaser-img](https://github.com/exx8/differential-diffusion/raw/main/assets/teaser.png)\n\nYou can find additional information about Differential Diffusion in the [paper](https://differential-diffusion.github.io/paper.pdf) or in the [project website](https://differential-diffusion.github.io/).\n\n#### Usage example\n\n```python\nimport torch\nfrom torchvision import transforms\n\nfrom diffusers import DPMSolverMultistepScheduler\nfrom diffusers.utils import load_image\nfrom examples.community.pipeline_stable_diffusion_xl_differential_img2img import (\n    StableDiffusionXLDifferentialImg2ImgPipeline,\n)\n\n\npipeline = StableDiffusionXLDifferentialImg2ImgPipeline.from_pretrained(\n    \"SG161222/RealVisXL_V4.0\", torch_dtype=torch.float16, variant=\"fp16\"\n).to(\"cuda\")\npipeline.scheduler = 
DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)\n\n\ndef preprocess_image(image):\n    image = image.convert(\"RGB\")\n    image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))(image)\n    image = transforms.ToTensor()(image)\n    image = image * 2 - 1\n    image = image.unsqueeze(0).to(\"cuda\")\n    return image\n\n\ndef preprocess_map(map):\n    map = map.convert(\"L\")\n    map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map)\n    map = transforms.ToTensor()(map)\n    map = map.to(\"cuda\")\n    return map\n\n\nimage = preprocess_image(\n    load_image(\n        \"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true\"\n    )\n)\n\nmask = preprocess_map(\n    load_image(\n        \"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true\"\n    )\n)\n\nprompt = \"a green pear\"\nnegative_prompt = \"blurry\"\n\nimage = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    guidance_scale=7.5,\n    num_inference_steps=25,\n    original_image=image,\n    image=image,\n    strength=1.0,\n    map=mask,\n).images[0]\n\nimage.save(\"result.png\")\n```\n\n### HD-Painter\n\nImplementation of [HD-Painter: High-Resolution and Prompt-Faithful Text-Guided Image Inpainting with Diffusion Models](https://huggingface.co/papers/2312.14091).\n\n![teaser-img](https://raw.githubusercontent.com/Picsart-AI-Research/HD-Painter/main/__assets__/github/teaser.jpg)\n\nThe abstract from the paper is:\n\nRecent progress in text-guided image inpainting, based on the unprecedented success of text-to-image diffusion models, has led to exceptionally realistic and visually plausible results.\nHowever, there is still significant potential for improvement in current text-to-image inpainting models, particularly in better aligning the inpainted area 
with user prompts and performing high-resolution inpainting.\nTherefore, in this paper we introduce _HD-Painter_, a completely **training-free** approach that **accurately follows to prompts** and coherently **scales to high-resolution** image inpainting.\nTo this end, we design the _Prompt-Aware Introverted Attention (PAIntA)_ layer enhancing self-attention scores by prompt information and resulting in better text alignment generations.\nTo further improve the prompt coherence we introduce the _Reweighting Attention Score Guidance (RASG)_ mechanism seamlessly integrating a post-hoc sampling strategy into general form of DDIM to prevent out-of-distribution latent shifts.\nMoreover, HD-Painter allows extension to larger scales by introducing a specialized super-resolution technique customized for inpainting, enabling the completion of missing regions in images of up to 2K resolution.\nOur experiments demonstrate that HD-Painter surpasses existing state-of-the-art approaches qualitatively and quantitatively, achieving an impressive generation accuracy improvement of **61.4** vs **51.9**.\nWe will make the codes publicly available.\n\nYou can find additional information about HD-Painter in the [paper](https://huggingface.co/papers/2312.14091) or the [original codebase](https://github.com/Picsart-AI-Research/HD-Painter).\n\n#### Usage example\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline, DDIMScheduler\nfrom diffusers.utils import load_image, make_image_grid\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-inpainting\",\n    custom_pipeline=\"hd_painter\"\n)\npipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)\n\nprompt = \"wooden boat\"\ninit_image = load_image(\"https://raw.githubusercontent.com/Picsart-AI-Research/HD-Painter/main/__assets__/samples/images/2.jpg\")\nmask_image = 
load_image(\"https://raw.githubusercontent.com/Picsart-AI-Research/HD-Painter/main/__assets__/samples/masks/2.png\")\n\nimage = pipe(prompt, init_image, mask_image, use_rasg=True, use_painta=True, generator=torch.manual_seed(12345)).images[0]\n\nmake_image_grid([init_image, mask_image, image], rows=1, cols=3)\n```\n\n### Marigold Depth Estimation\n\nMarigold is a universal monocular depth estimator that delivers accurate and sharp predictions in the wild. Based on Stable Diffusion, it is trained exclusively with synthetic depth data and excels in zero-shot adaptation to real-world imagery. This pipeline is an official implementation of the inference process. More details can be found on our [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) (also implemented with diffusers).\n\n![Marigold Teaser](https://marigoldmonodepth.github.io/images/teaser_collage_compressed.jpg)\n\nThis depth estimation pipeline processes a single input image through multiple diffusion denoising stages to estimate depth maps. These maps are subsequently merged to produce the final output. 
Below is an example code snippet, including optional arguments:\n\n```python\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import load_image\n\n# Original DDIM version (higher quality)\npipe = DiffusionPipeline.from_pretrained(\n    \"prs-eth/marigold-v1-0\",\n    custom_pipeline=\"marigold_depth_estimation\"\n    # torch_dtype=torch.float16,                # (optional) Run with half-precision (16-bit float).\n    # variant=\"fp16\",                           # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint\n)\n\n# (New) LCM version (faster speed)\npipe = DiffusionPipeline.from_pretrained(\n    \"prs-eth/marigold-depth-lcm-v1-0\",\n    custom_pipeline=\"marigold_depth_estimation\"\n    # torch_dtype=torch.float16,                # (optional) Run with half-precision (16-bit float).\n    # variant=\"fp16\",                           # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint\n)\n\npipe.to(\"cuda\")\n\nimg_path_or_url = \"https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_example.jpg\"\nimage: Image.Image = load_image(img_path_or_url)\n\npipeline_output = pipe(\n    image,                    # Input image.\n    # ----- recommended setting for DDIM version -----\n    # denoising_steps=10,     # (optional) Number of denoising steps of each inference pass. Default: 10.\n    # ensemble_size=10,       # (optional) Number of inference passes in the ensemble. Default: 10.\n    # ------------------------------------------------\n\n    # ----- recommended setting for LCM version ------\n    # denoising_steps=4,\n    # ensemble_size=5,\n    # -------------------------------------------------\n\n    # processing_res=768,     # (optional) Maximum resolution of processing. If set to 0: will not resize at all. 
Defaults to 768.\n    # match_input_res=True,   # (optional) Resize depth prediction to match input resolution.\n    # batch_size=0,           # (optional) Inference batch size, no bigger than `ensemble_size`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.\n    # seed=2024,              # (optional) Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing `batch_size=1` helps to increase reproducibility. To ensure full reproducibility, deterministic mode needs to be used.\n    # color_map=\"Spectral\",   # (optional) Colormap used to colorize the depth map. Defaults to \"Spectral\". Set to `None` to skip colormap generation.\n    # show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.\n)\n\ndepth: np.ndarray = pipeline_output.depth_np                    # Predicted depth map\ndepth_colored: Image.Image = pipeline_output.depth_colored      # Colorized prediction\n\n# Save as uint16 PNG\ndepth_uint16 = (depth * 65535.0).astype(np.uint16)\nImage.fromarray(depth_uint16).save(\"./depth_map.png\", mode=\"I;16\")\n\n# Save colorized depth map\ndepth_colored.save(\"./depth_colored.png\")\n```\n\n### LLM-grounded Diffusion\n\nLMD and LMD+ greatly improve the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. They improve spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., it uses the SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of the LMD+ model used in our work. 
[Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion)\n\n![Main Image](https://llm-grounded-diffusion.github.io/main_figure.jpg)\n![Visualizations: Enhanced Prompt Understanding](https://llm-grounded-diffusion.github.io/visualizations.jpg)\n\nThis pipeline can be used with an LLM or on its own. We provide a parser that parses LLM outputs to the layouts. You can obtain the prompt to input to the LLM for layout generation [here](https://github.com/TonyLianLong/LLM-groundedDiffusion/blob/main/prompt.py). After feeding the prompt to an LLM (e.g., GPT-4 on ChatGPT website), you can feed the LLM response into our pipeline.\n\nThe following code has been tested on 1x RTX 4090, but it should also support GPUs with lower GPU memory.\n\n#### Use this pipeline with an LLM\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\n    \"longlian/lmd_plus\",\n    custom_pipeline=\"llm_grounded_diffusion\",\n    custom_revision=\"main\",\n    variant=\"fp16\", torch_dtype=torch.float16\n)\npipe.enable_model_cpu_offload()\n\n# Generate directly from a text prompt and an LLM response\nprompt = \"a waterfall and a modern high speed train in a beautiful forest with fall foliage\"\nphrases, boxes, bg_prompt, neg_prompt = pipe.parse_llm_response(\"\"\"\n[('a waterfall', [71, 105, 148, 258]), ('a modern high speed train', [255, 223, 181, 149])]\nBackground prompt: A beautiful forest with fall foliage\nNegative prompt:\n\"\"\")\n\nimages = pipe(\n    prompt=prompt,\n    negative_prompt=neg_prompt,\n    phrases=phrases,\n    boxes=boxes,\n    gligen_scheduled_sampling_beta=0.4,\n    output_type=\"pil\",\n    num_inference_steps=50,\n    lmd_guidance_kwargs={}\n).images\n\nimages[0].save(\"./lmd_plus_generation.jpg\")\n```\n\n#### Use this pipeline on its own for layout generation\n\n```python\nimport torch\nfrom diffusers import 
DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\n    \"longlian/lmd_plus\",\n    custom_pipeline=\"llm_grounded_diffusion\",\n    variant=\"fp16\", torch_dtype=torch.float16\n)\npipe.enable_model_cpu_offload()\n\n# Generate an image described by the prompt and\n# insert objects described by text at the region defined by bounding boxes\nprompt = \"a waterfall and a modern high speed train in a beautiful forest with fall foliage\"\nboxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]\nphrases = [\"a waterfall\", \"a modern high speed train\"]\n\nimages = pipe(\n    prompt=prompt,\n    phrases=phrases,\n    boxes=boxes,\n    gligen_scheduled_sampling_beta=0.4,\n    output_type=\"pil\",\n    num_inference_steps=50,\n    lmd_guidance_kwargs={}\n).images\n\nimages[0].save(\"./lmd_plus_generation.jpg\")\n```\n\n### CLIP Guided Stable Diffusion\n\nCLIP guided stable diffusion can help to generate more realistic images\nby guiding stable diffusion at every denoising step with an additional CLIP model.\n\nThe following code requires roughly 12GB of GPU RAM.\n\n```python\nfrom diffusers import DiffusionPipeline\nfrom transformers import CLIPImageProcessor, CLIPModel\nimport torch\n\n\nfeature_extractor = CLIPImageProcessor.from_pretrained(\"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\")\nclip_model = CLIPModel.from_pretrained(\"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\", torch_dtype=torch.float16)\n\n\nguided_pipeline = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    custom_pipeline=\"clip_guided_stable_diffusion\",\n    clip_model=clip_model,\n    feature_extractor=feature_extractor,\n    torch_dtype=torch.float16,\n)\nguided_pipeline.enable_attention_slicing()\nguided_pipeline = guided_pipeline.to(\"cuda\")\n\nprompt = \"fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, 
digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece\"\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(0)\nimages = []\nfor i in range(4):\n    image = guided_pipeline(\n        prompt,\n        num_inference_steps=50,\n        guidance_scale=7.5,\n        clip_guidance_scale=100,\n        num_cutouts=4,\n        use_cutouts=False,\n        generator=generator,\n    ).images[0]\n    images.append(image)\n\n# save images locally\nfor i, img in enumerate(images):\n    img.save(f\"./clip_guided_sd/image_{i}.png\")\n```\n\nThe `images` list contains a list of PIL images that can be saved locally or displayed directly in a google colab.\nGenerated images tend to be of higher quality than natively using stable diffusion. E.g. the above script generates the following images:\n\n![clip_guidance](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/clip_guidance/merged_clip_guidance.jpg).\n\n### One Step Unet\n\nThe dummy \"one-step-unet\" can be run as follows:\n\n```python\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"google/ddpm-cifar10-32\", custom_pipeline=\"one_step_unet\")\npipe()\n```\n\n**Note**: This community pipeline is not useful as a feature, but rather just serves as an example of how community pipelines can be added (see <https://github.com/huggingface/diffusers/issues/841>).\n\n### Stable Diffusion Interpolation\n\nThe following code can be run on a GPU of at least 8GB VRAM and should take approximately 5 minutes.\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    variant='fp16',\n    torch_dtype=torch.float16,\n    safety_checker=None,  # Very important for videos...lots of false positives while interpolating\n    custom_pipeline=\"interpolate_stable_diffusion\",\n).to('cuda')\npipe.enable_attention_slicing()\n\nframe_filepaths = 
pipe.walk(\n    prompts=['a dog', 'a cat', 'a horse'],\n    seeds=[42, 1337, 1234],\n    num_interpolation_steps=16,\n    output_dir='./dreams',\n    batch_size=4,\n    height=512,\n    width=512,\n    guidance_scale=8.5,\n    num_inference_steps=50,\n)\n```\n\nThe `walk(...)` function returns a list of file paths to the images saved under the folder defined in `output_dir`. You can use these images to create videos of stable diffusion.\n\n> **Please have a look at <https://github.com/nateraw/stable-diffusion-videos> for more detailed information on how to create videos using stable diffusion as well as more feature-complete functionality.**\n\n### Stable Diffusion Mega\n\nThe Stable Diffusion Mega Pipeline lets you use the main use cases of the stable diffusion pipeline in a single class.\n\n```python\n#!/usr/bin/env python3\nfrom diffusers import DiffusionPipeline\nfrom PIL import Image\nimport requests\nfrom io import BytesIO\nimport torch\n\n\ndef download_image(url):\n    response = requests.get(url)\n    return Image.open(BytesIO(response.content)).convert(\"RGB\")\n\npipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", custom_pipeline=\"stable_diffusion_mega\", torch_dtype=torch.float16, variant=\"fp16\")\npipe.to(\"cuda\")\npipe.enable_attention_slicing()\n\n\n### Text-to-Image\nimages = pipe.text2img(\"An astronaut riding a horse\").images\n\n### Image-to-Image\ninit_image = download_image(\"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\")\n\nprompt = \"A fantasy landscape, trending on artstation\"\n\nimages = pipe.img2img(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images\n\n### Inpainting\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = 
\"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\ninit_image = download_image(img_url).resize((512, 512))\nmask_image = download_image(mask_url).resize((512, 512))\n\nprompt = \"a cat sitting on a bench\"\nimages = pipe.inpaint(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.75).images\n```\n\nAs shown above this one pipeline can run all both \"text-to-image\", \"image-to-image\", and \"inpainting\" in one pipeline.\n\n### Long Prompt Weighting Stable Diffusion\n\nFeatures of this custom pipeline:\n\n- Input a prompt without the 77 token length limit.\n- Includes tx2img, img2img, and inpainting pipelines.\n- Emphasize/weigh part of your prompt with parentheses as so: `a baby deer with (big eyes)`\n- De-emphasize part of your prompt as so: `a [baby] deer with big eyes`\n- Precisely weigh part of your prompt as so: `a baby deer with (big eyes:1.3)`\n\nPrompt weighting equivalents:\n\n- `a baby deer with` == `(a baby deer with:1.0)`\n- `(big eyes)` == `(big eyes:1.1)`\n- `((big eyes))` == `(big eyes:1.21)`\n- `[big eyes]` == `(big eyes:0.91)`\n\nYou can run this custom pipeline as so:\n\n#### PyTorch\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\n    'hakurei/waifu-diffusion',\n    custom_pipeline=\"lpw_stable_diffusion\",\n    torch_dtype=torch.float16\n)\npipe = pipe.to(\"cuda\")\n\nprompt = \"best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms\"\nneg_prompt = \"lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, 
error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry\"\n\npipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]\n```\n\n#### onnxruntime\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\n    'CompVis/stable-diffusion-v1-4',\n    custom_pipeline=\"lpw_stable_diffusion_onnx\",\n    revision=\"onnx\",\n    provider=\"CUDAExecutionProvider\"\n)\n\nprompt = \"a photo of an astronaut riding a horse on mars, best quality\"\nneg_prompt = \"lowres, bad anatomy, error body, error hair, error arm, error hands, bad hands, error fingers, bad fingers, missing fingers, error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry\"\n\npipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]\n```\n\nIf you see `Token indices sequence length is longer than the specified maximum sequence length for this model ( *** > 77 ) . Running this sequence through the model will result in indexing errors`. 
Do not worry, it is normal.\n\n### Speech to Image\n\nThe following code can generate an image from an audio sample using pre-trained OpenAI whisper-small and Stable Diffusion.\n\n```Python\nimport torch\n\nimport matplotlib.pyplot as plt\nfrom datasets import load_dataset\nfrom diffusers import DiffusionPipeline\nfrom transformers import (\n    WhisperForConditionalGeneration,\n    WhisperProcessor,\n)\n\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n\naudio_sample = ds[3]\n\ntext = audio_sample[\"text\"].lower()\nspeech_data = audio_sample[\"audio\"][\"array\"]\n\nmodel = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-small\").to(device)\nprocessor = WhisperProcessor.from_pretrained(\"openai/whisper-small\")\n\ndiffuser_pipeline = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"speech_to_image_diffusion\",\n    speech_model=model,\n    speech_processor=processor,\n    torch_dtype=torch.float16,\n)\n\ndiffuser_pipeline.enable_attention_slicing()\ndiffuser_pipeline = diffuser_pipeline.to(device)\n\noutput = diffuser_pipeline(speech_data)\nplt.imshow(output.images[0])\n```\n\nThis example produces the following image:\n\n![image](https://user-images.githubusercontent.com/45072645/196901736-77d9c6fc-63ee-4072-90b0-dc8b903d63e3.png)\n\n### Wildcard Stable Diffusion\n\nFollowing the great examples from <https://github.com/jtkelm2/stable-diffusion-webui-1/blob/master/scripts/wildcards.py> and <https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts#wildcards>, here's a minimal implementation that allows for users to add \"wildcards\", denoted by `__wildcard__` to prompts that are used as placeholders for randomly sampled values given by either a dictionary or a `.txt` file. 
For example:\n\nSay we have a prompt:\n\n```\nprompt = \"__animal__ sitting on a __object__ wearing a __clothing__\"\n```\n\nWe can then define possible values to be sampled for `animal`, `object`, and `clothing`. These can come from a `.txt` file with the same name as the category.\n\nThe possible values can also be defined or combined by using a dictionary like: `{\"animal\":[\"dog\", \"cat\", \"mouse\"]}`.\n\nThe actual pipeline works just like `StableDiffusionPipeline`, except the `__call__` method takes in:\n\n- `wildcard_files`: list of file paths for wildcard replacement\n- `wildcard_option_dict`: dict with key as `wildcard` and values as a list of possible replacements\n- `num_prompt_samples`: number of prompts to sample, uniformly sampling wildcards\n\nA full example:\n\nCreate `animal.txt`, with contents like:\n\n```\ndog\ncat\nmouse\n```\n\nCreate `object.txt`, with contents like:\n\n```\nchair\nsofa\nbench\n```\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"wildcard_stable_diffusion\",\n    torch_dtype=torch.float16,\n)\nprompt = \"__animal__ sitting on a __object__ wearing a __clothing__\"\nout = pipe(\n    prompt,\n    wildcard_option_dict={\n        \"clothing\":[\"hat\", \"shirt\", \"scarf\", \"beret\"]\n    },\n    wildcard_files=[\"object.txt\", \"animal.txt\"],\n    num_prompt_samples=1\n)\nout.images[0].save(\"image.png\")\ntorch.cuda.empty_cache()\n```\n\n### Composable Stable Diffusion\n\n[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models.\n\n```python\nimport torch as th\nimport numpy as np\nimport torchvision.utils as tvu\n\nfrom diffusers import DiffusionPipeline\n\nimport argparse\n\nparser = 
argparse.ArgumentParser()\nparser.add_argument(\"--prompt\", type=str, default=\"mystical trees | A magical pond | dark\",\n                    help=\"use '|' as the delimiter to compose separate sentences.\")\nparser.add_argument(\"--steps\", type=int, default=50)\nparser.add_argument(\"--scale\", type=float, default=7.5)\nparser.add_argument(\"--weights\", type=str, default=\"7.5 | 7.5 | -7.5\")\nparser.add_argument(\"--seed\", type=int, default=2)\nparser.add_argument(\"--model_path\", type=str, default=\"CompVis/stable-diffusion-v1-4\")\nparser.add_argument(\"--num_images\", type=int, default=1)\nargs = parser.parse_args()\n\nhas_cuda = th.cuda.is_available()\ndevice = th.device('cpu' if not has_cuda else 'cuda')\n\nprompt = args.prompt\nscale = args.scale\nsteps = args.steps\n\npipe = DiffusionPipeline.from_pretrained(\n    args.model_path,\n    custom_pipeline=\"composable_stable_diffusion\",\n).to(device)\n\npipe.safety_checker = None\n\nimages = []\n# Create the generator on the same device as the pipeline so the script also runs without CUDA\ngenerator = th.Generator(device).manual_seed(args.seed)\nfor i in range(args.num_images):\n    image = pipe(prompt, guidance_scale=scale, num_inference_steps=steps,\n                 weights=args.weights, generator=generator).images[0]\n    images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)\ngrid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)\ntvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')\nprint(\"Image saved successfully!\")\n```\n\n### Imagic Stable Diffusion\n\nAllows you to edit an image using stable diffusion.\n\n```python\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\nimport torch\nimport os\nfrom diffusers import DiffusionPipeline, DDIMScheduler\n\nhas_cuda = torch.cuda.is_available()\ndevice = torch.device('cpu' if not has_cuda else 'cuda')\npipe = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    safety_checker=None,\n    custom_pipeline=\"imagic_stable_diffusion\",\n    
scheduler=DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n).to(device)\n# Create the generator on the same device as the pipeline so the script also runs without CUDA\ngenerator = torch.Generator(device).manual_seed(0)\nseed = 0\nprompt = \"A photo of Barack Obama smiling with a big grin\"\nurl = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'\nresponse = requests.get(url)\ninit_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\ninit_image = init_image.resize((512, 512))\nres = pipe.train(\n    prompt,\n    image=init_image,\n    generator=generator)\nres = pipe(alpha=1, guidance_scale=7.5, num_inference_steps=50)\nos.makedirs(\"imagic\", exist_ok=True)\nimage = res.images[0]\nimage.save('./imagic/imagic_image_alpha_1.png')\nres = pipe(alpha=1.5, guidance_scale=7.5, num_inference_steps=50)\nimage = res.images[0]\nimage.save('./imagic/imagic_image_alpha_1_5.png')\nres = pipe(alpha=2, guidance_scale=7.5, num_inference_steps=50)\nimage = res.images[0]\nimage.save('./imagic/imagic_image_alpha_2.png')\n```\n\n### Seed Resizing\n\nTest seed resizing. First generate an image at 512 by 512, then generate an image with the same seed at 512 by 592 using seed resizing. 
Finally, generate a 512 by 592 image using the original Stable Diffusion pipeline.\n\n```python\nimport os\nimport torch as th\nfrom diffusers import DiffusionPipeline\n\n# Ensure the save directory exists or create it\nsave_dir = './seed_resize/'\nos.makedirs(save_dir, exist_ok=True)\n\nhas_cuda = th.cuda.is_available()\ndevice = th.device('cpu' if not has_cuda else 'cuda')\n\npipe = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"seed_resize_stable_diffusion\"\n).to(device)\n\ndef dummy(images, **kwargs):\n    return images, False\n\npipe.safety_checker = dummy\n\nimages = []\nth.manual_seed(0)\n# Seed the generator on the same device as the pipeline so the example also runs on CPU\ngenerator = th.Generator(device).manual_seed(0)\n\nprompt = \"A painting of a futuristic cop\"\n\nwidth = 512\nheight = 512\n\nres = pipe(\n    prompt,\n    guidance_scale=7.5,\n    num_inference_steps=50,\n    height=height,\n    width=width,\n    generator=generator)\nimage = res.images[0]\nimage.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image.png'.format(w=width, h=height)))\n\n# Reset the seed so the resized run starts from the same random state\nth.manual_seed(0)\ngenerator = th.Generator(device).manual_seed(0)\n\npipe = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"seed_resize_stable_diffusion\"\n).to(device)\n\nwidth = 512\nheight = 592\n\nres = pipe(\n    prompt,\n    guidance_scale=7.5,\n    num_inference_steps=50,\n    height=height,\n    width=width,\n    generator=generator)\nimage = res.images[0]\nimage.save(os.path.join(save_dir, 'seed_resize_{w}_{h}_image.png'.format(w=width, h=height)))\n\npipe_compare = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"seed_resize_stable_diffusion\"\n).to(device)\n\nres = pipe_compare(\n    prompt,\n    guidance_scale=7.5,\n    num_inference_steps=50,\n    height=height,\n    width=width,\n    generator=generator\n)\n\nimage = res.images[0]\nimage.save(os.path.join(save_dir, 
'seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height)))\n```\n\n### Multilingual Stable Diffusion Pipeline\n\nThe following code can generate images from texts in different languages using the pre-trained [mBART-50 many-to-one multilingual machine translation model](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) and Stable Diffusion.\n\n```python\nfrom PIL import Image\n\nimport torch\n\nfrom diffusers import DiffusionPipeline\nfrom transformers import (\n    pipeline,\n    MBart50TokenizerFast,\n    MBartForConditionalGeneration,\n)\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\ndevice_dict = {\"cuda\": 0, \"cpu\": -1}\n\n# helper function taken from: https://huggingface.co/blog/stable_diffusion\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows*cols\n\n    w, h = imgs[0].size\n    grid = Image.new('RGB', size=(cols*w, rows*h))\n    grid_w, grid_h = grid.size\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i%cols*w, i//cols*h))\n    return grid\n\n# Add language detection pipeline\nlanguage_detection_model_ckpt = \"papluca/xlm-roberta-base-language-detection\"\nlanguage_detection_pipeline = pipeline(\"text-classification\",\n                                       model=language_detection_model_ckpt,\n                                       device=device_dict[device])\n\n# Add model for language translation\ntrans_tokenizer = MBart50TokenizerFast.from_pretrained(\"facebook/mbart-large-50-many-to-one-mmt\")\ntrans_model = MBartForConditionalGeneration.from_pretrained(\"facebook/mbart-large-50-many-to-one-mmt\").to(device)\n\ndiffuser_pipeline = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"multilingual_stable_diffusion\",\n    detection_pipeline=language_detection_pipeline,\n    translation_model=trans_model,\n    translation_tokenizer=trans_tokenizer,\n    
torch_dtype=torch.float16,\n)\n\ndiffuser_pipeline.enable_attention_slicing()\ndiffuser_pipeline = diffuser_pipeline.to(device)\n\nprompt = [\"a photograph of an astronaut riding a horse\",\n          \"Una casa en la playa\",\n          \"Ein Hund, der Orange isst\",\n          \"Un restaurant parisien\"]\n\noutput = diffuser_pipeline(prompt)\n\nimages = output.images\n\ngrid = image_grid(images, rows=2, cols=2)\n```\n\nThis example produces the following images:\n![image](https://user-images.githubusercontent.com/4313860/198328706-295824a4-9856-4ce5-8e66-278ceb42fd29.png)\n\n### GlueGen Stable Diffusion Pipeline\n\nGlueGen is a minimal adapter that aligns any encoder (a text encoder for a different language, multilingual RoBERTa, AudioCLIP) with the CLIP text encoder used in the standard Stable Diffusion model. This makes it easy to adapt existing English Stable Diffusion checkpoints to other languages without an image captioning dataset or long training hours.\n\nMake sure you have downloaded `gluenet_French_clip_overnorm_over3_noln.ckpt` for French from [GlueGen's official repo](https://github.com/salesforce/GlueGen/tree/main). Pre-trained weights are also available for Chinese, Italian, Japanese, and Spanish, or you can train your own.\n\n```python\nimport os\nimport gc\nimport urllib.request\nimport torch\nfrom transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM, CLIPTokenizer, CLIPTextModel\nfrom diffusers import DiffusionPipeline\n\n# Download checkpoints\nCHECKPOINTS = [\n    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Chinese_clip_overnorm_over3_noln.ckpt\",\n    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_French_clip_overnorm_over3_noln.ckpt\",\n    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Italian_clip_overnorm_over3_noln.ckpt\",\n    
\"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Japanese_clip_overnorm_over3_noln.ckpt\",\n    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Spanish_clip_overnorm_over3_noln.ckpt\",\n    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_sound2img_audioclip_us8k.ckpt\"\n]\n\nLANGUAGE_PROMPTS = {\n    \"French\": \"une voiture sur la plage\",\n    #\"Chinese\": \"海滩上的一辆车\",\n    #\"Italian\": \"una macchina sulla spiaggia\",\n    #\"Japanese\": \"浜辺の車\",\n    #\"Spanish\": \"un coche en la playa\"\n}\n\ndef download_checkpoints(checkpoint_dir):\n    os.makedirs(checkpoint_dir, exist_ok=True)\n    for url in CHECKPOINTS:\n        filename = os.path.join(checkpoint_dir, os.path.basename(url))\n        if not os.path.exists(filename):\n            print(f\"Downloading {filename}...\")\n            urllib.request.urlretrieve(url, filename)\n            print(f\"Downloaded {filename}\")\n        else:\n            print(f\"Checkpoint {filename} already exists, skipping download.\")\n    return checkpoint_dir\n\ndef load_checkpoint(pipeline, checkpoint_path, device):\n    state_dict = torch.load(checkpoint_path, map_location=device)\n    state_dict = state_dict.get(\"state_dict\", state_dict)\n    missing_keys, unexpected_keys = pipeline.unet.load_state_dict(state_dict, strict=False)\n    return pipeline\n\ndef generate_image(pipeline, prompt, device, output_path):\n    with torch.inference_mode():\n        image = pipeline(\n            prompt,\n            generator=torch.Generator(device=device).manual_seed(42),\n            num_inference_steps=50\n        ).images[0]\n        image.save(output_path)\n        print(f\"Image saved to {output_path}\")\n\ncheckpoint_dir = download_checkpoints(\"./checkpoints_all/gluenet_checkpoint\")\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\nprint(f\"Using 
device: {device}\")\n\ntokenizer = XLMRobertaTokenizer.from_pretrained(\"xlm-roberta-base\", use_fast=False)\nmodel = XLMRobertaForMaskedLM.from_pretrained(\"xlm-roberta-base\").to(device)\ninputs = tokenizer(\"Ceci est une phrase incomplète avec un [MASK].\", return_tensors=\"pt\").to(device)\nwith torch.inference_mode():\n    _ = model(**inputs)\n\n\nclip_tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\nclip_text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device)\n\n# Initialize pipeline\npipeline = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    text_encoder=clip_text_encoder,\n    tokenizer=clip_tokenizer,\n    custom_pipeline=\"gluegen\",\n    safety_checker=None\n).to(device)\n\nos.makedirs(\"outputs\", exist_ok=True)\n\n# Generate images\nfor language, prompt in LANGUAGE_PROMPTS.items():\n\n    checkpoint_file = f\"gluenet_{language}_clip_overnorm_over3_noln.ckpt\"\n    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)\n    try:\n        pipeline = load_checkpoint(pipeline, checkpoint_path, device)\n        output_path = f\"outputs/gluegen_output_{language.lower()}.png\"\n        generate_image(pipeline, prompt, device, output_path)\n    except Exception as e:\n        print(f\"Error processing {language} model: {e}\")\n        continue\n\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n    gc.collect()\n```\n\nWhich will produce:\n\n![output_image](https://github.com/rootonchair/diffusers/assets/23548268/db43ffb6-8667-47c1-8872-26f85dc0a57f)\n\n### Image to Image Inpainting Stable Diffusion\n\nSimilar to the standard stable diffusion inpainting example, except with the addition of an `inner_image` argument.\n\n`image`, `inner_image`, and `mask` should have the same dimensions. 
`inner_image` should have an alpha (transparency) channel.\n\nThe aim is to overlay two images, then mask out the boundary between `image` and `inner_image` to allow stable diffusion to make the connection more seamless.\nFor example, this could be used to place a logo on a shirt and make it blend seamlessly.\n\n```python\nimport torch\nimport requests\nfrom PIL import Image\nfrom io import BytesIO\nfrom diffusers import DiffusionPipeline\n\nimage_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\ninner_image_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\ndef load_image(url, mode=\"RGB\"):\n    response = requests.get(url)\n    if response.status_code == 200:\n        return Image.open(BytesIO(response.content)).convert(mode).resize((512, 512))\n    else:\n        raise FileNotFoundError(f\"Could not retrieve image from {url}\")\n\n\ninit_image = load_image(image_url, mode=\"RGB\")\ninner_image = load_image(inner_image_url, mode=\"RGBA\")\nmask_image = load_image(mask_url, mode=\"RGB\")\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\",\n    custom_pipeline=\"img2img_inpainting\",\n    torch_dtype=torch.float16\n)\npipe = pipe.to(\"cuda\")\n\nprompt = \"a mecha robot sitting on a bench\"\nimage = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]\n\nimage.save(\"output.png\")\n```\n\n![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png)\n\n### Text Based Inpainting Stable Diffusion\n\nUse a text prompt to generate the mask for the area to be 
inpainted.\nCurrently uses the CLIPSeg model for mask generation, then calls the standard Stable Diffusion Inpainting pipeline to perform the inpainting.\n\n```python\nfrom transformers import CLIPSegProcessor, CLIPSegForImageSegmentation\nfrom diffusers import DiffusionPipeline\nfrom PIL import Image\nimport requests\nimport torch\n\n# Load CLIPSeg model and processor\nprocessor = CLIPSegProcessor.from_pretrained(\"CIDAS/clipseg-rd64-refined\")\nmodel = CLIPSegForImageSegmentation.from_pretrained(\"CIDAS/clipseg-rd64-refined\").to(\"cuda\")\n\n# Load Stable Diffusion Inpainting Pipeline with custom pipeline\npipe = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-inpainting\",\n    custom_pipeline=\"text_inpainting\",\n    segmentation_model=model,\n    segmentation_processor=processor\n).to(\"cuda\")\n\n# Load input image\nurl = \"https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true\"\nimage = Image.open(requests.get(url, stream=True).raw)\n\n# Step 1: Resize input image for CLIPSeg (224x224)\nsegmentation_input = image.resize((224, 224))\n\n# Step 2: Generate segmentation mask\ntext = \"a glass\"  # Object to mask\ninputs = processor(text=text, images=segmentation_input, return_tensors=\"pt\").to(\"cuda\")\n\nwith torch.no_grad():\n    mask = model(**inputs).logits.sigmoid()  # Get segmentation mask\n\n# Resize mask back to 512x512 for SD inpainting\nmask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(512, 512), mode=\"bilinear\").squeeze(0)\n\n# Step 3: Resize input image for Stable Diffusion\nimage = image.resize((512, 512))\n\n# Step 4: Run inpainting with Stable Diffusion\nprompt = \"a cup\"  # The masked-out region will be replaced with this\nresult = pipe(image=image, mask=mask, prompt=prompt,text=text).images[0]\n\n# Save output\nresult.save(\"inpainting_output.png\")\nprint(\"Inpainting completed. 
Image saved as 'inpainting_output.png'.\")\n```\n\n### Bit Diffusion\n\nBased on <https://huggingface.co/papers/2208.04202>, this pipeline performs diffusion on discrete data, e.g. discrete image data or DNA sequence data. An unconditional discrete image can be generated like this:\n\n```python\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"google/ddpm-cifar10-32\", custom_pipeline=\"bit_diffusion\")\nimage = pipe().images[0]\n```\n\n### Stable Diffusion with K Diffusion\n\nMake sure you have @crowsonkb's <https://github.com/crowsonkb/k-diffusion> installed:\n\n```sh\npip install k-diffusion\n```\n\nYou can use the community pipeline as follows:\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\nseed = 33\n\npipe = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", custom_pipeline=\"sd_text2img_k_diffusion\")\npipe = pipe.to(\"cuda\")\n\nprompt = \"an astronaut riding a horse on mars\"\npipe.set_scheduler(\"sample_heun\")\ngenerator = torch.Generator(device=\"cuda\").manual_seed(seed)\nimage = pipe(prompt, generator=generator, num_inference_steps=20).images[0]\n\nimage.save(\"./astronaut_heun_k_diffusion.png\")\n```\n\nTo check that K Diffusion and `diffusers` yield the same results, run both with the same prompt and seed:\n\n**Diffusers**:\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline, EulerDiscreteScheduler\n\nseed = 33\nprompt = \"an astronaut riding a horse on mars\"\n\npipe = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\")\npipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)\npipe = pipe.to(\"cuda\")\n\ngenerator = torch.Generator(device=\"cuda\").manual_seed(seed)\nimage = pipe(prompt, generator=generator, num_inference_steps=50).images[0]\n```\n\n![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler.png)\n\n**K Diffusion**:\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline, EulerDiscreteScheduler\n\nseed = 33\nprompt = \"an astronaut riding a horse on mars\"\n\npipe = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", 
custom_pipeline=\"sd_text2img_k_diffusion\")\npipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)\npipe = pipe.to(\"cuda\")\n\npipe.set_scheduler(\"sample_euler\")\ngenerator = torch.Generator(device=\"cuda\").manual_seed(seed)\nimage = pipe(prompt, generator=generator, num_inference_steps=50).images[0]\n```\n\n![k_diffusion_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png)\n\n### Checkpoint Merger Pipeline\n\nBased on the AUTOMATIC1111/webui for checkpoint merging. This is a custom pipeline that merges up to 3 pretrained model checkpoints as long as they are in the HuggingFace model_index.json format.\n\nCheckpoint merging is currently memory intensive because it modifies the weights of a DiffusionPipeline object in place. Expect at least 13GB RAM usage on Kaggle GPU kernels, and on Colab you might run out of the 12GB memory even while merging two checkpoints.\n\nUsage:\n\n```python\nfrom diffusers import DiffusionPipeline\n\n# Return a CheckpointMergerPipeline class that allows you to merge checkpoints.\n# The checkpoint passed here is ignored. But still pass one of the checkpoints you plan to\n# merge for convenience\npipe = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", custom_pipeline=\"checkpoint_merger\")\n\n# There are multiple possible scenarios:\n# The pipeline with the merged checkpoints is returned in all the scenarios\n\n# Compatible checkpoints, i.e. matching model_index.json files. The meta attributes in model_index.json (attrs with _ as prefix) are ignored during comparison.\nmerged_pipe = pipe.merge([\"CompVis/stable-diffusion-v1-4\", \"CompVis/stable-diffusion-v1-2\"], interp=\"sigmoid\", alpha=0.4)\n\n# Incompatible checkpoints in model_index.json but merge might be possible. 
Use force=True to ignore model_index.json compatibility\nmerged_pipe_1 = pipe.merge([\"CompVis/stable-diffusion-v1-4\", \"hakurei/waifu-diffusion\"], force=True, interp=\"sigmoid\", alpha=0.4)\n\n# Three checkpoint merging. Only \"add_difference\" method actually works on all three checkpoints. Using any other options will ignore the 3rd checkpoint.\nmerged_pipe_2 = pipe.merge([\"CompVis/stable-diffusion-v1-4\", \"hakurei/waifu-diffusion\", \"prompthero/openjourney\"], force=True, interp=\"add_difference\", alpha=0.4)\n\nprompt = \"An astronaut riding a horse on Mars\"\n\nimage = merged_pipe(prompt).images[0]\n```\n\nSome examples along with the merge details:\n\n1. \"CompVis/stable-diffusion-v1-4\" + \"hakurei/waifu-diffusion\" ; Sigmoid interpolation; alpha = 0.8\n\n![Stable plus Waifu Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stability_v1_4_waifu_sig_0.8.png)\n\n2. \"hakurei/waifu-diffusion\" + \"prompthero/openjourney\" ; Inverse Sigmoid interpolation; alpha = 0.8\n\n![Waifu plus openjourney Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/waifu_openjourney_inv_sig_0.8.png)\n\n3. \"CompVis/stable-diffusion-v1-4\" + \"hakurei/waifu-diffusion\" + \"prompthero/openjourney\"; Add Difference interpolation; alpha = 0.5\n\n![Stable plus Waifu plus openjourney add_diff 0.5](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stable_waifu_openjourney_add_diff_0.5.png)\n\n### Stable Diffusion Comparisons\n\nThis Community Pipeline enables the comparison between the 4 checkpoints that exist for Stable Diffusion. They can be found through the following links:\n\n1. [Stable Diffusion v1.1](https://huggingface.co/CompVis/stable-diffusion-v1-1)\n2. [Stable Diffusion v1.2](https://huggingface.co/CompVis/stable-diffusion-v1-2)\n3. [Stable Diffusion v1.3](https://huggingface.co/CompVis/stable-diffusion-v1-3)\n4. 
[Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4)\n\n```python\nfrom diffusers import DiffusionPipeline\nimport matplotlib.pyplot as plt\n\npipe = DiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', custom_pipeline='suvadityamuk/StableDiffusionComparison')\npipe.enable_attention_slicing()\npipe = pipe.to('cuda')\nprompt = \"an astronaut riding a horse on mars\"\noutput = pipe(prompt)\n\n# Show the four results in a 2x2 grid, one panel per checkpoint\nplt.subplot(2, 2, 1)\nplt.imshow(output.images[0])\nplt.title('Stable Diffusion v1.1')\nplt.axis('off')\nplt.subplot(2, 2, 2)\nplt.imshow(output.images[1])\nplt.title('Stable Diffusion v1.2')\nplt.axis('off')\nplt.subplot(2, 2, 3)\nplt.imshow(output.images[2])\nplt.title('Stable Diffusion v1.3')\nplt.axis('off')\nplt.subplot(2, 2, 4)\nplt.imshow(output.images[3])\nplt.title('Stable Diffusion v1.4')\nplt.axis('off')\n\nplt.show()\n```\n\nAs a result, you can view a grid of all 4 generated images together, which shows how the output changes as training advances across the 4 checkpoints.\n\n### Magic Mix\n\nImplementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://huggingface.co/papers/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.\n\nThe method has 3 parameters:\n\n- `mix_factor`: It is the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process.\n- `kmax` and `kmin`: These determine the range for the layout and content generation process. 
A higher value of `kmax` results in the loss of more information about the layout of the original image, and a higher value of `kmin` results in more steps for the content generation process.\n\nHere is an example usage:\n\n```python\nimport requests\nfrom diffusers import DiffusionPipeline, DDIMScheduler\nfrom PIL import Image\nfrom io import BytesIO\n\npipe = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"magic_mix\",\n    scheduler=DDIMScheduler.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"scheduler\"),\n).to('cuda')\n\nurl = \"https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg\"\nresponse = requests.get(url)\nimage = Image.open(BytesIO(response.content)).convert(\"RGB\")  # Convert to RGB to avoid issues\nmix_img = pipe(\n    image,\n    prompt='bed',\n    kmin=0.3,\n    kmax=0.5,\n    mix_factor=0.5,\n    )\nmix_img.save('phone_bed_mix.jpg')\n```\n\nThe `mix_img` is a PIL image that can be saved locally or displayed directly in Google Colab. The generated image is a mix of the layout semantics of the given image and the content semantics of the prompt.\n\nE.g. 
the above script takes the first image below as input and generates the second:\n\n`phone.jpg`\n\n![206903102-34e79b9f-9ed2-4fac-bb38-82871343c655](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg)\n\n`phone_bed_mix.jpg`\n\n![206903104-913a671d-ef53-4ae4-919d-64c3059c8f67](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg)\n\nFor more example generations check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb).\n\n### Stable UnCLIP\n\n`UnCLIPPipeline(\"kakaobrain/karlo-v1-alpha\")` provides a prior model that can generate a CLIP image embedding from text.\n`StableDiffusionImageVariationPipeline(\"lambdalabs/sd-image-variations-diffusers\")` provides a decoder model that can generate images from a CLIP image embedding.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\ndevice = torch.device(\"cpu\" if not torch.cuda.is_available() else \"cuda\")\n\npipeline = DiffusionPipeline.from_pretrained(\n    \"kakaobrain/karlo-v1-alpha\",\n    torch_dtype=torch.float16,\n    custom_pipeline=\"stable_unclip\",\n    decoder_pipe_kwargs=dict(\n        image_encoder=None,\n    ),\n)\npipeline.to(device)\n\nprompt = \"a shiba inu wearing a beret and black turtleneck\"\nrandom_generator = torch.Generator(device=device).manual_seed(1000)\noutput = pipeline(\n    prompt=prompt,\n    width=512,\n    height=512,\n    generator=random_generator,\n    prior_guidance_scale=4,\n    prior_num_inference_steps=25,\n    decoder_guidance_scale=8,\n    decoder_num_inference_steps=50,\n)\n\nimage = output.images[0]\nimage.save(\"./shiba-inu.jpg\")\n\n# debug\n\n# `pipeline.decoder_pipe` is a regular StableDiffusionImageVariationPipeline instance.\n# It is used to convert clip image embedding to latents, then fed into VAE decoder.\nprint(pipeline.decoder_pipe.__class__)\n# <class 
'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline'>\n\n# this pipeline only uses prior module in \"kakaobrain/karlo-v1-alpha\"\n# It is used to convert clip text embedding to clip image embedding.\nprint(pipeline)\n# StableUnCLIPPipeline {\n#   \"_class_name\": \"StableUnCLIPPipeline\",\n#   \"_diffusers_version\": \"0.12.0.dev0\",\n#   \"prior\": [\n#     \"diffusers\",\n#     \"PriorTransformer\"\n#   ],\n#   \"prior_scheduler\": [\n#     \"diffusers\",\n#     \"UnCLIPScheduler\"\n#   ],\n#   \"text_encoder\": [\n#     \"transformers\",\n#     \"CLIPTextModelWithProjection\"\n#   ],\n#   \"tokenizer\": [\n#     \"transformers\",\n#     \"CLIPTokenizer\"\n#   ]\n# }\n\n# pipeline.prior_scheduler is the scheduler used for prior in UnCLIP.\nprint(pipeline.prior_scheduler)\n# UnCLIPScheduler {\n#   \"_class_name\": \"UnCLIPScheduler\",\n#   \"_diffusers_version\": \"0.12.0.dev0\",\n#   \"clip_sample\": true,\n#   \"clip_sample_range\": 5.0,\n#   \"num_train_timesteps\": 1000,\n#   \"prediction_type\": \"sample\",\n#   \"variance_type\": \"fixed_small_log\"\n# }\n```\n\n`shiba-inu.jpg`\n\n![shiba-inu](https://user-images.githubusercontent.com/16448529/209185639-6e5ec794-ce9d-4883-aa29-bd6852a2abad.jpg)\n\n### UnCLIP Text Interpolation Pipeline\n\nThis Diffusion Pipeline takes two prompts and interpolates between the two input prompts using spherical interpolation ( slerp ). The input prompts are converted to text embeddings by the pipeline's text_encoder and the interpolation is done on the resulting text_embeddings over the number of steps specified. 
Defaults to 5 steps.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\ndevice = torch.device(\"cpu\" if not torch.cuda.is_available() else \"cuda\")\n\npipe = DiffusionPipeline.from_pretrained(\n    \"kakaobrain/karlo-v1-alpha\",\n    torch_dtype=torch.float16,\n    custom_pipeline=\"unclip_text_interpolation\"\n)\npipe.to(device)\n\nstart_prompt = \"A photograph of an adult lion\"\nend_prompt = \"A photograph of a lion cub\"\n# For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths.\ngenerator = torch.Generator(device=device).manual_seed(42)\n\noutput = pipe(start_prompt, end_prompt, steps=6, generator=generator, enable_sequential_cpu_offload=False)\n\n# Save each interpolation step as its own image\nfor i, image in enumerate(output.images):\n    image.save('result%s.jpg' % i)\n```\n\nThe resulting images in order:\n\n![result_0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_0.png)\n![result_1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_1.png)\n![result_2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_2.png)\n![result_3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_3.png)\n![result_4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_4.png)\n![result_5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPTextInterpolationSamples/resolve/main/lion_to_cub_5.png)\n\n### UnCLIP Image Interpolation Pipeline\n\nThis Diffusion Pipeline takes two images or an image_embeddings tensor of size 2 and interpolates between their embeddings using spherical interpolation (slerp). 
The input images/image_embeddings are converted to image embeddings by the pipeline's image_encoder and the interpolation is done on the resulting image_embeddings over the number of steps specified. Defaults to 5 steps.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom PIL import Image\nimport requests\nfrom io import BytesIO\n\ndevice = torch.device(\"cpu\" if not torch.cuda.is_available() else \"cuda\")\ndtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16\n\npipe = DiffusionPipeline.from_pretrained(\n    \"kakaobrain/karlo-v1-alpha-image-variations\",\n    torch_dtype=dtype,\n    custom_pipeline=\"unclip_image_interpolation\"\n)\npipe.to(device)\n\n# List of image URLs\nimage_urls = [\n    'https://camo.githubusercontent.com/ef13c8059b12947c0d5e8d3ea88900de6bf1cd76bbf61ace3928e824c491290e/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f4e616761536169416268696e61792f556e434c4950496d616765496e746572706f6c6174696f6e53616d706c65732f7265736f6c76652f6d61696e2f7374617272795f6e696768742e6a7067',\n    'https://camo.githubusercontent.com/d1947ab7c49ae3f550c28409d5e8b120df48e456559cf4557306c0848337702c/68747470733a2f2f68756767696e67666163652e636f2f64617461736574732f4e616761536169416268696e61792f556e434c4950496d616765496e746572706f6c6174696f6e53616d706c65732f7265736f6c76652f6d61696e2f666c6f776572732e6a7067'\n]\n\n# Open images from URLs\nimages = []\nfor url in image_urls:\n    response = requests.get(url)\n    img = Image.open(BytesIO(response.content))\n    images.append(img)\n\n# For best results keep the prompts close in length to each other. 
Of course, feel free to try out with differing lengths.\ngenerator = torch.Generator(device=device).manual_seed(42)\n\noutput = pipe(image=images, steps=6, generator=generator)\n\nfor i, image in enumerate(output.images):\n    image.save('starry_to_flowers_%s.jpg' % i)\n```\n\nThe original images:-\n\n![starry](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_night.jpg)\n![flowers](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/flowers.jpg)\n\nThe resulting images in order:-\n\n![result0](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_0.png)\n![result1](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_1.png)\n![result2](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_2.png)\n![result3](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_3.png)\n![result4](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_4.png)\n![result5](https://huggingface.co/datasets/NagaSaiAbhinay/UnCLIPImageInterpolationSamples/resolve/main/starry_to_flowers_5.png)\n\n### DDIM Noise Comparative Analysis Pipeline\n\n#### **Research question: What visual concepts do the diffusion models learn from each noise level during training?**\n\nThe [P2 weighting (CVPR 2022)](https://huggingface.co/papers/2204.00227) paper proposed an approach to answer the above question, which is their second contribution.\nThe approach consists of the following steps:\n\n1. The input is an image x0.\n2. Perturb it to xt using a diffusion process q(xt|x0).\n    - `strength` is a value between 0.0 and 1.0, that controls the amount of noise that is added to the input image. 
Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.\n3. Reconstruct the image with the learned denoising process pθ(ˆx0|xt).\n4. Compare x0 and ˆx0 among various t to show how each step contributes to the sample.\n\nThe authors used the [openai/guided-diffusion](https://github.com/openai/guided-diffusion) model to denoise images in the FFHQ dataset. This pipeline extends their second contribution by investigating DDIM on any input image.\n\n```python\nimport torch\nfrom PIL import Image\nimport numpy as np\nfrom diffusers import DiffusionPipeline\n\nimage_path = \"path/to/your/image\"  # images from CelebA-HQ might be better\nimage_pil = Image.open(image_path)\nimage_name = image_path.split(\"/\")[-1].split(\".\")[0]\n\ndevice = torch.device(\"cpu\" if not torch.cuda.is_available() else \"cuda\")\npipe = DiffusionPipeline.from_pretrained(\n    \"google/ddpm-ema-celebahq-256\",\n    custom_pipeline=\"ddim_noise_comparative_analysis\",\n)\npipe = pipe.to(device)\n\n# Sweep the noise strength and save one reconstruction per level\nfor strength in np.linspace(0.1, 1, 25):\n    denoised_image, latent_timestep = pipe(\n        image_pil, strength=strength, return_dict=False\n    )\n    denoised_image = denoised_image[0]\n    denoised_image.save(\n        f\"noise_comparative_analysis_{image_name}_{latent_timestep}.png\"\n    )\n```\n\nHere is the result of this pipeline (which is DDIM) on the CelebA-HQ dataset.\n\n![noise-comparative-analysis](https://user-images.githubusercontent.com/67547213/224677066-4474b2ed-56ab-4c27-87c6-de3c0255eb9c.jpeg)\n\n### CLIP Guided Img2Img Stable Diffusion\n\nCLIP guided Img2Img stable diffusion can help to generate more realistic images with an initial image\nby guiding stable diffusion at every denoising step with an additional CLIP model.\n\nThe following code requires roughly 12GB of GPU RAM.\n\n```python\nfrom io import BytesIO\nimport requests\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom PIL import Image\nfrom transformers import CLIPImageProcessor, 
CLIPModel\n\n# Load CLIP model and feature extractor\nfeature_extractor = CLIPImageProcessor.from_pretrained(\n    \"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\"\n)\nclip_model = CLIPModel.from_pretrained(\n    \"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\", torch_dtype=torch.float16\n)\n\n# Load guided pipeline\nguided_pipeline = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"clip_guided_stable_diffusion_img2img\",\n    clip_model=clip_model,\n    feature_extractor=feature_extractor,\n    torch_dtype=torch.float16,\n)\nguided_pipeline.enable_attention_slicing()\nguided_pipeline = guided_pipeline.to(\"cuda\")\n\n# Define prompt and fetch image\nprompt = \"fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece\"\nurl = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\nresponse = requests.get(url)\nedit_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n\n# Run the pipeline\nimage = guided_pipeline(\n    prompt=prompt,\n    height=512,  # Height of the output image\n    width=512,   # Width of the output image\n    image=edit_image,  # Input image to guide the diffusion\n    strength=0.75,  # How much to transform the input image\n    num_inference_steps=30,  # Number of diffusion steps\n    guidance_scale=7.5,  # Scale of the classifier-free guidance\n    clip_guidance_scale=100,  # Scale of the CLIP guidance\n    num_images_per_prompt=1,  # Generate one image per prompt\n    eta=0.0,  # Noise scheduling parameter\n    num_cutouts=4,  # Number of cutouts for CLIP guidance\n    use_cutouts=False,  # Whether to use cutouts\n    output_type=\"pil\",  # Output as PIL image\n).images[0]\n\n# Display the generated 
image\nimage.show()\n\n```\n\nInit Image\n\n![img2img_init_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img_init.jpg)\n\nOutput Image\n\n![img2img_clip_guidance](https://huggingface.co/datasets/njindal/images/resolve/main/clip_guided_img2img.jpg)\n\n### TensorRT Text2Image Stable Diffusion Pipeline\n\nThe TensorRT Pipeline can be used to accelerate the Text2Image Stable Diffusion Inference run.\n\nNOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.\n\n```python\nimport torch\nfrom diffusers import DDIMScheduler\nfrom diffusers.pipelines import DiffusionPipeline\n\n# Use the DDIMScheduler scheduler here instead\nscheduler = DDIMScheduler.from_pretrained(\"stabilityai/stable-diffusion-2-1\", subfolder=\"scheduler\")\n\npipe = DiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-1\",\n    custom_pipeline=\"stable_diffusion_tensorrt_txt2img\",\n    variant='fp16',\n    torch_dtype=torch.float16,\n    scheduler=scheduler,)\n\n# re-use cached folder to save ONNX models and TensorRT Engines\npipe.set_cached_folder(\"stabilityai/stable-diffusion-2-1\", variant='fp16',)\n\npipe = pipe.to(\"cuda\")\n\nprompt = \"a beautiful photograph of Mt. Fuji during cherry blossom\"\nimage = pipe(prompt).images[0]\nimage.save('tensorrt_mt_fuji.png')\n```\n\n### EDICT Image Editing Pipeline\n\nThis pipeline implements the text-guided image editing approach from the paper [EDICT: Exact Diffusion Inversion via Coupled Transformations](https://huggingface.co/papers/2211.12446). 
You have to pass:\n\n- `image`: the `PIL` image you want to edit.\n- `base_prompt`: the text prompt describing the current image (before editing).\n- `target_prompt`: the text prompt describing the image with the edits applied.\n\n```python\nfrom diffusers import DiffusionPipeline, DDIMScheduler\nfrom transformers import CLIPTextModel\nimport requests\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\nfrom IPython.display import display\n\ndef center_crop_and_resize(im):\n    # crop the largest centered square, then resize it to 512x512\n    width, height = im.size\n    d = min(width, height)\n    left = (width - d) / 2\n    upper = (height - d) / 2\n    right = (width + d) / 2\n    lower = (height + d) / 2\n\n    return im.crop((left, upper, right, lower)).resize((512, 512))\n\ntorch_dtype = torch.float16\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# scheduler and text_encoder param values as in the paper\nscheduler = DDIMScheduler(\n    num_train_timesteps=1000,\n    beta_start=0.00085,\n    beta_end=0.012,\n    beta_schedule=\"scaled_linear\",\n    set_alpha_to_one=False,\n    clip_sample=False,\n)\n\ntext_encoder = CLIPTextModel.from_pretrained(\n    pretrained_model_name_or_path=\"openai/clip-vit-large-patch14\",\n    torch_dtype=torch_dtype,\n)\n\n# initialize pipeline\npipeline = DiffusionPipeline.from_pretrained(\n    pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"edict_pipeline\",\n    variant=\"fp16\",\n    scheduler=scheduler,\n    text_encoder=text_encoder,\n    leapfrog_steps=True,\n    torch_dtype=torch_dtype,\n).to(device)\n\n# download image\nimage_url = \"https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg\"\nresponse = requests.get(image_url)\nimage = Image.open(BytesIO(response.content))\n\n# preprocess it\ncropped_image = center_crop_and_resize(image)\n\n# define the prompts\nbase_prompt = \"A dog\"\ntarget_prompt = \"A golden retriever\"\n\n# run the pipeline\nresult_image = pipeline(\n      base_prompt=base_prompt,\n      
target_prompt=target_prompt,\n      image=cropped_image,\n)\n\ndisplay(result_image)\n```\n\nInit Image\n\n![img2img_init_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg)\n\nOutput Image\n\n![img2img_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1_cropped_generated.png)\n\n### Stable Diffusion RePaint\n\nThis pipeline uses the [RePaint](https://huggingface.co/papers/2201.09865) logic on the latent space of stable diffusion. It can\nbe used similarly to other image inpainting pipelines but does not rely on a specific inpainting model. This means you can use\nmodels that are not specifically created for inpainting.\n\nMake sure to use the `RePaintScheduler` as shown in the example below.\n\nDisclaimer: The mask gets transferred into latent space, which may lead to unexpected changes on the edge of the masked part.\nInference is also a lot slower than with other inpainting pipelines.\n\n```py\nimport PIL.Image\nimport PIL.ImageOps\nimport requests\nimport torch\nfrom io import BytesIO\nfrom diffusers import StableDiffusionPipeline, RePaintScheduler\n\ndef download_image(url):\n    response = requests.get(url)\n    return PIL.Image.open(BytesIO(response.content)).convert(\"RGB\")\n\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\ninit_image = download_image(img_url).resize((512, 512))\nmask_image = download_image(mask_url).resize((512, 512))\nmask_image = PIL.ImageOps.invert(mask_image)\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16, custom_pipeline=\"stable_diffusion_repaint\",\n)\npipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config)\npipe = pipe.to(\"cuda\")\n\nprompt = \"Face of a yellow cat, high resolution, sitting 
on a park bench\"\nimage = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]\n```\n\n### TensorRT Image2Image Stable Diffusion Pipeline\n\nThe TensorRT Pipeline can be used to accelerate the Image2Image Stable Diffusion Inference run.\n\nNOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.\n\n```python\nimport requests\nfrom io import BytesIO\nfrom PIL import Image\nimport torch\nfrom diffusers import DDIMScheduler\nfrom diffusers import DiffusionPipeline\n\n# Use the DDIMScheduler scheduler here instead\nscheduler = DDIMScheduler.from_pretrained(\"stabilityai/stable-diffusion-2-1\",\n                                            subfolder=\"scheduler\")\n\npipe = DiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-1\",\n                                            custom_pipeline=\"stable_diffusion_tensorrt_img2img\",\n                                            variant='fp16',\n                                            torch_dtype=torch.float16,\n                                            scheduler=scheduler,)\n\n# re-use cached folder to save ONNX models and TensorRT Engines\npipe.set_cached_folder(\"stabilityai/stable-diffusion-2-1\", variant='fp16',)\n\npipe = pipe.to(\"cuda\")\n\nurl = \"https://pajoca.com/wp-content/uploads/2022/09/tekito-yamakawa-1.png\"\nresponse = requests.get(url)\ninput_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\nprompt = \"photorealistic new zealand hills\"\nimage = pipe(prompt, image=input_image, strength=0.75,).images[0]\nimage.save('tensorrt_img2img_new_zealand_hills.png')\n```\n\n### Stable Diffusion BoxDiff\nBoxDiff is a training-free method for controlled generation with bounding box coordinates. It should work with any Stable Diffusion model. 
Below shows an example with `stable-diffusion-2-1-base`.\n```py\nimport torch\nfrom PIL import Image, ImageDraw\nfrom copy import deepcopy\n\nfrom examples.community.pipeline_stable_diffusion_boxdiff import StableDiffusionBoxDiffPipeline\n\ndef draw_box_with_text(img, boxes, names):\n    colors = [\"red\", \"olive\", \"blue\", \"green\", \"orange\", \"brown\", \"cyan\", \"purple\"]\n    img_new = deepcopy(img)\n    draw = ImageDraw.Draw(img_new)\n\n    W, H = img.size\n    for bid, box in enumerate(boxes):\n        draw.rectangle([box[0] * W, box[1] * H, box[2] * W, box[3] * H], outline=colors[bid % len(colors)], width=4)\n        draw.text((box[0] * W, box[1] * H), names[bid], fill=colors[bid % len(colors)])\n    return img_new\n\npipe = StableDiffusionBoxDiffPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-2-1-base\",\n    torch_dtype=torch.float16,\n)\npipe.to(\"cuda\")\n\n# example 1\nprompt = \"as the aurora lights up the sky, a herd of reindeer leisurely wanders on the grassy meadow, admiring the breathtaking view, a serene lake quietly reflects the magnificent display, and in the distance, a snow-capped mountain stands majestically, fantasy, 8k, highly detailed\"\nphrases = [\n    \"aurora\",\n    \"reindeer\",\n    \"meadow\",\n    \"lake\",\n    \"mountain\"\n]\nboxes = [[1,3,512,202], [75,344,421,495], [1,327,508,507], [2,217,507,341], [1,135,509,242]]\n\n# example 2\n# prompt = \"A rabbit wearing sunglasses looks very proud\"\n# phrases = [\"rabbit\", \"sunglasses\"]\n# boxes = [[67,87,366,512], [66,130,364,262]]\n\nboxes = [[x / 512 for x in box] for box in boxes]\n\nimages = pipe(\n    prompt,\n    boxdiff_phrases=phrases,\n    boxdiff_boxes=boxes,\n    boxdiff_kwargs={\n        \"attention_res\": 16,\n        \"normalize_eot\": True\n    },\n    num_inference_steps=50,\n    guidance_scale=7.5,\n    generator=torch.manual_seed(42),\n    safety_checker=None\n).images\n\ndraw_box_with_text(images[0], boxes, 
phrases).save(\"output.png\")\n```\n\n\n### Stable Diffusion Reference\n\nThis pipeline uses the Reference Control. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236) and the [sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280).\n\nBased on [this issue](https://github.com/huggingface/diffusers/issues/3566):\n\n- `EulerAncestralDiscreteScheduler` gives poor results.\n\n```py\nimport torch\nfrom diffusers import UniPCMultistepScheduler\nfrom diffusers.utils import load_image\n\n# Assumption: run from the root of the diffusers repository so that `examples.community` is importable\nfrom examples.community.stable_diffusion_reference import StableDiffusionReferencePipeline\n\ninput_image = load_image(\"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\")\n\npipe = StableDiffusionReferencePipeline.from_pretrained(\n       \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n       safety_checker=None,\n       torch_dtype=torch.float16\n       ).to('cuda:0')\n\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\nresult_img = pipe(ref_image=input_image,\n      prompt=\"1girl\",\n      num_inference_steps=20,\n      reference_attn=True,\n      reference_adain=True).images[0]\n```\n\nReference Image\n\n![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png)\n\nOutput Image of `reference_attn=True` and `reference_adain=False`\n\n![output_image](https://github.com/huggingface/diffusers/assets/24734142/813b5c6a-6d89-46ba-b7a4-2624e240eea5)\n\nOutput Image of `reference_attn=False` and `reference_adain=True`\n\n![output_image](https://github.com/huggingface/diffusers/assets/24734142/ffc90339-9ef0-4c4d-a544-135c3e5644da)\n\nOutput Image of `reference_attn=True` and `reference_adain=True`\n\n![output_image](https://github.com/huggingface/diffusers/assets/24734142/3c5255d6-867d-4d35-b202-8dfd30cc6827)\n\n### Stable Diffusion ControlNet Reference\n\nThis pipeline uses the Reference Control with 
ControlNet. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236) and the [sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280).\n\nBased on [this issue](https://github.com/huggingface/diffusers/issues/3566):\n\n- `EulerAncestralDiscreteScheduler` gives poor results.\n- `guess_mode=True` works well for ControlNet v1.1.\n\n```py\nimport cv2\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom diffusers import ControlNetModel, UniPCMultistepScheduler\nfrom diffusers.utils import load_image\n\n# Assumption: run from the root of the diffusers repository so that `examples.community` is importable\nfrom examples.community.stable_diffusion_controlnet_reference import StableDiffusionControlNetReferencePipeline\n\ninput_image = load_image(\"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\")\n\n# get canny image\nimage = cv2.Canny(np.array(input_image), 100, 200)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], axis=2)\ncanny_image = Image.fromarray(image)\n\ncontrolnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\npipe = StableDiffusionControlNetReferencePipeline.from_pretrained(\n       \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n       controlnet=controlnet,\n       safety_checker=None,\n       torch_dtype=torch.float16\n       ).to('cuda:0')\n\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\nresult_img = pipe(ref_image=input_image,\n      prompt=\"1girl\",\n      image=canny_image,\n      num_inference_steps=20,\n      reference_attn=True,\n      reference_adain=True).images[0]\n```\n\nReference Image\n\n![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png)\n\nOutput Image\n\n![output_image](https://github.com/huggingface/diffusers/assets/24734142/7b9a5830-f173-4b92-b0cf-73d0e9c01d60)\n\n### Stable Diffusion on IPEX\n\nThis diffusion pipeline aims to accelerate the inference of Stable-Diffusion on Intel Xeon CPUs with 
BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).\n\nTo use this pipeline, you need to:\n\n1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)\n\n**Note:** For each PyTorch release, there is a corresponding release of the IPEX. Here is the mapping relationship. It is recommended to install PyTorch/IPEX2.0 to get the best performance.\n\n|PyTorch Version|IPEX Version|\n|--|--|\n|[v2.0.\\*](https://github.com/pytorch/pytorch/tree/v2.0.1 \"v2.0.1\")|[v2.0.\\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|\n|[v1.13.\\*](https://github.com/pytorch/pytorch/tree/v1.13.0 \"v1.13.0\")|[v1.13.\\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|\n\nYou can simply use pip to install IPEX with the latest version.\n\n```sh\npython -m pip install intel_extension_for_pytorch\n```\n\n**Note:** To install a specific version, run with the following command:\n\n```sh\npython -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu\n```\n\n2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX acceleration. 
Supported inference datatypes are Float32 and BFloat16.\n\n**Note:** The image height/width passed to `prepare_for_ipex()` should be the same as the values used for pipeline inference.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", custom_pipeline=\"stable_diffusion_ipex\")\n# For Float32\npipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) # value of image height/width should be consistent with the pipeline inference\n# For BFloat16\npipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) # value of image height/width should be consistent with the pipeline inference\n```\n\nThen you can use the IPEX pipeline in a similar way to the default stable diffusion pipeline.\n\n```python\n# For Float32\nimage = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] # value of image height/width should be consistent with 'prepare_for_ipex()'\n# For BFloat16\nwith torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n    image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] # value of image height/width should be consistent with 'prepare_for_ipex()'\n```\n\nThe following code compares the performance of the original stable diffusion pipeline with the IPEX-optimized pipeline.\n\n```python\nimport torch\nimport intel_extension_for_pytorch as ipex\nfrom diffusers import DiffusionPipeline, StableDiffusionPipeline\nimport time\n\nprompt = \"sailing ship in storm by Rembrandt\"\nmodel_id = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n# Helper function for time evaluation\ndef elapsed_time(pipeline, nb_pass=3, num_inference_steps=20):\n    # warmup\n    for _ in range(2):\n        pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512)\n    # time evaluation\n    start = time.time()\n    for _ in range(nb_pass):\n        pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512)\n    end 
= time.time()\n    return (end - start) / nb_pass\n\n##############     bf16 inference performance    ###############\n\n# 1. IPEX Pipeline initialization\npipe = DiffusionPipeline.from_pretrained(model_id, custom_pipeline=\"stable_diffusion_ipex\")\npipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512)\n\n# 2. Original Pipeline initialization\npipe2 = StableDiffusionPipeline.from_pretrained(model_id)\n\n# 3. Compare performance between Original Pipeline and IPEX Pipeline\nwith torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n    latency = elapsed_time(pipe)\n    print(\"Latency of StableDiffusionIPEXPipeline--bf16\", latency)\n    latency = elapsed_time(pipe2)\n    print(\"Latency of StableDiffusionPipeline--bf16\", latency)\n\n##############     fp32 inference performance    ###############\n\n# 1. IPEX Pipeline initialization\npipe3 = DiffusionPipeline.from_pretrained(model_id, custom_pipeline=\"stable_diffusion_ipex\")\npipe3.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512)\n\n# 2. Original Pipeline initialization\npipe4 = StableDiffusionPipeline.from_pretrained(model_id)\n\n# 3. Compare performance between Original Pipeline and IPEX Pipeline\nlatency = elapsed_time(pipe3)\nprint(\"Latency of StableDiffusionIPEXPipeline--fp32\", latency)\nlatency = elapsed_time(pipe4)\nprint(\"Latency of StableDiffusionPipeline--fp32\", latency)\n```\n\n### Stable Diffusion XL on IPEX\n\nThis diffusion pipeline aims to accelerate the inference of Stable-Diffusion XL on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).\n\nTo use this pipeline, you need to:\n\n1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)\n\n**Note:** For each PyTorch release, there is a corresponding release of IPEX. Here is the mapping relationship. 
It is recommended to install Pytorch/IPEX2.0 to get the best performance.\n\n|PyTorch Version|IPEX Version|\n|--|--|\n|[v2.0.\\*](https://github.com/pytorch/pytorch/tree/v2.0.1 \"v2.0.1\")|[v2.0.\\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)|\n|[v1.13.\\*](https://github.com/pytorch/pytorch/tree/v1.13.0 \"v1.13.0\")|[v1.13.\\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|\n\nYou can simply use pip to install IPEX with the latest version.\n\n```sh\npython -m pip install intel_extension_for_pytorch\n```\n\n**Note:** To install a specific version, run with the following command:\n\n```sh\npython -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu\n```\n\n2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX acceleration. Supported inference datatypes are Float32 and BFloat16.\n\n**Note:** The values of `height` and `width` used during preparation with `prepare_for_ipex()` should be the same when running inference with the prepared pipeline.\n\n```python\npipe = StableDiffusionXLPipelineIpex.from_pretrained(\"stabilityai/sdxl-turbo\", low_cpu_mem_usage=True, use_safetensors=True)\n# value of image height/width should be consistent with the pipeline inference\n# For Float32\npipe.prepare_for_ipex(torch.float32, prompt, height=512, width=512)\n# For BFloat16\npipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)\n```\n\nThen you can use the ipex pipeline in a similar way to the default stable diffusion xl pipeline.\n\n```python\n# value of image height/width should be consistent with 'prepare_for_ipex()'\n# For Float32\nimage = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]\n# For BFloat16\nwith torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n    image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, 
width=512, guidance_scale=guidance_scale).images[0]\n```\n\nThe following code compares the performance of the original stable diffusion xl pipeline with the ipex-optimized pipeline.\nBy using this optimized pipeline, we can get about 1.4-2 times performance boost with BFloat16 on fourth generation of Intel Xeon CPUs,\ncode-named Sapphire Rapids.\n\n```python\nimport torch\nfrom diffusers import StableDiffusionXLPipeline\nfrom pipeline_stable_diffusion_xl_ipex import StableDiffusionXLPipelineIpex\nimport time\n\nprompt = \"sailing ship in storm by Rembrandt\"\nmodel_id = \"stabilityai/sdxl-turbo\"\nsteps = 4\n\n# Helper function for time evaluation\ndef elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):\n    # warmup\n    for _ in range(2):\n        images = pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=0.0).images\n    # time evaluation\n    start = time.time()\n    for _ in range(nb_pass):\n        pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=0.0)\n    end = time.time()\n    return (end - start) / nb_pass\n\n##############     bf16 inference performance    ###############\n\n# 1. IPEX Pipeline initialization\npipe = StableDiffusionXLPipelineIpex.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)\npipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)\n\n# 2. Original Pipeline initialization\npipe2 = StableDiffusionXLPipeline.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)\n\n# 3. 
Compare performance between Original Pipeline and IPEX Pipeline\nwith torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n    latency = elapsed_time(pipe, num_inference_steps=steps)\n    print(\"Latency of StableDiffusionXLPipelineIpex--bf16\", latency, \"s for total\", steps, \"steps\")\n    latency = elapsed_time(pipe2, num_inference_steps=steps)\n    print(\"Latency of StableDiffusionXLPipeline--bf16\", latency, \"s for total\", steps, \"steps\")\n\n##############     fp32 inference performance    ###############\n\n# 1. IPEX Pipeline initialization\npipe3 = StableDiffusionXLPipelineIpex.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)\npipe3.prepare_for_ipex(torch.float32, prompt, height=512, width=512)\n\n# 2. Original Pipeline initialization\npipe4 = StableDiffusionXLPipeline.from_pretrained(model_id, low_cpu_mem_usage=True, use_safetensors=True)\n\n# 3. Compare performance between Original Pipeline and IPEX Pipeline\nlatency = elapsed_time(pipe3, num_inference_steps=steps)\nprint(\"Latency of StableDiffusionXLPipelineIpex--fp32\", latency, \"s for total\", steps, \"steps\")\nlatency = elapsed_time(pipe4, num_inference_steps=steps)\nprint(\"Latency of StableDiffusionXLPipeline--fp32\", latency, \"s for total\", steps, \"steps\")\n```\n\n### CLIP Guided Images Mixing With Stable Diffusion\n\n![clip_guided_images_mixing_examples](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/main.png)\n\nThe CLIP guided Stable Diffusion images mixing pipeline combines two images using standard diffusion models.\nIt can optionally use a CoCa model to generate the image descriptions, so you do not have to write them yourself.\n\n[More code examples](https://github.com/TheDenk/images_mixing)\n\n#### Example Images Mixing (with CoCa)\n\n```python\nimport PIL.Image\nimport torch\nimport requests\nimport open_clip\nfrom open_clip import SimpleTokenizer\nfrom io import BytesIO\nfrom diffusers import DiffusionPipeline\nfrom transformers import CLIPImageProcessor, 
CLIPModel\n\n\ndef download_image(url):\n    response = requests.get(url)\n    return PIL.Image.open(BytesIO(response.content)).convert(\"RGB\")\n\n# Load the additional models\nfeature_extractor = CLIPImageProcessor.from_pretrained(\n    \"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\"\n)\nclip_model = CLIPModel.from_pretrained(\n    \"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\", torch_dtype=torch.float16\n)\ncoca_model = open_clip.create_model('coca_ViT-L-14', pretrained='laion2B-s13B-b90k').to('cuda')\ncoca_model.dtype = torch.float16\ncoca_transform = open_clip.image_transform(\n    coca_model.visual.image_size,\n    is_train=False,\n    mean=getattr(coca_model.visual, 'image_mean', None),\n    std=getattr(coca_model.visual, 'image_std', None),\n)\ncoca_tokenizer = SimpleTokenizer()\n\n# Create the pipeline\nmixing_pipeline = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\",\n    custom_pipeline=\"clip_guided_images_mixing_stable_diffusion\",\n    clip_model=clip_model,\n    feature_extractor=feature_extractor,\n    coca_model=coca_model,\n    coca_tokenizer=coca_tokenizer,\n    coca_transform=coca_transform,\n    torch_dtype=torch.float16,\n)\nmixing_pipeline.enable_attention_slicing()\nmixing_pipeline = mixing_pipeline.to(\"cuda\")\n\n# Run the pipeline\ngenerator = torch.Generator(device=\"cuda\").manual_seed(17)\n\ncontent_image = download_image(\"https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir.jpg\")\nstyle_image = download_image(\"https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/gigachad.jpg\")\n\npipe_images = mixing_pipeline(\n    num_inference_steps=50,\n    content_image=content_image,\n    style_image=style_image,\n    noise_strength=0.65,\n    slerp_latent_style_strength=0.9,\n    slerp_prompt_style_strength=0.1,\n    slerp_clip_image_style_strength=0.1,\n    
guidance_scale=9.0,\n    batch_size=1,\n    clip_guidance_scale=100,\n    generator=generator,\n).images\n\noutput_path = \"mixed_output.jpg\"\npipe_images[0].save(output_path)\nprint(f\"Image saved successfully at {output_path}\")\n```\n\n![image_mixing_result](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir_gigachad.png)\n\n### Stable Diffusion XL Long Weighted Prompt Pipeline\n\nThis SDXL pipeline supports unlimited length prompt and negative prompt, compatible with A1111 prompt weighted style.\n\nYou can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is a sample code to use this pipeline.\n\n```python\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import load_image\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\"\n    , torch_dtype       = torch.float16\n    , use_safetensors   = True\n    , variant           = \"fp16\"\n    , custom_pipeline   = \"lpw_stable_diffusion_xl\",\n)\n\nprompt = \"photo of a cute (white) cat running on the grass\" * 20\nprompt2 = \"chasing (birds:1.5)\" * 20\nprompt = f\"{prompt},{prompt2}\"\nneg_prompt = \"blur, low quality, carton, animate\"\n\npipe.to(\"cuda\")\n\n# text2img\nt2i_images = pipe(\n    prompt=prompt,\n    negative_prompt=neg_prompt,\n).images  # alternatively, you can call the .text2img() function\n\n# img2img\ninput_image = load_image(\"/path/to/local/image.png\")  # or URL to your input image\ni2i_images = pipe.img2img(\n  prompt=prompt,\n  negative_prompt=neg_prompt,\n  image=input_image,\n  strength=0.8,  # higher strength will result in more variation compared to original image\n).images\n\n# inpaint\ninput_mask = load_image(\"/path/to/local/mask.png\")  # or URL to your input inpainting mask\ninpaint_images = pipe.inpaint(\n  prompt=\"photo of a cute (black) cat running on the grass\" * 20,\n  negative_prompt=neg_prompt,\n  
image=input_image,\n  mask=input_mask,\n  strength=0.6,  # higher strength will result in more variation compared to original image\n).images\n\npipe.to(\"cpu\")\ntorch.cuda.empty_cache()\n\nfrom IPython.display import display  # assuming you are using this code in a notebook\ndisplay(t2i_images[0])\ndisplay(i2i_images[0])\ndisplay(inpaint_images[0])\n```\n\nIn the code above, `prompt2` is appended to `prompt`, making the combined prompt longer than 77 tokens. Even so, \"birds\" still show up in the result.\n\n![Stable Diffusion XL Long Weighted Prompt Pipeline sample](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_long_weighted_prompt.png)\n\nFor more results, check out [PR #6114](https://github.com/huggingface/diffusers/pull/6114).\n\n### Stable Diffusion Mixture Tiling Pipeline SD 1.5\n\nThis pipeline uses the Mixture of Diffusers method. Refer to the [Mixture of Diffusers](https://huggingface.co/papers/2302.02412) paper for more details.\n\n```python\nfrom diffusers import LMSDiscreteScheduler, DiffusionPipeline\n\n# Create scheduler and model (similar to StableDiffusionPipeline)\nscheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000)\npipeline = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", scheduler=scheduler, custom_pipeline=\"mixture_tiling\")\npipeline.to(\"cuda\")\n\n# Mixture of Diffusers generation\nimage = pipeline(\n    prompt=[[\n        \"A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\",\n        \"A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\",\n        \"An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\"\n    ]],\n    
tile_height=640,\n    tile_width=640,\n    tile_row_overlap=0,\n    tile_col_overlap=256,\n    guidance_scale=8,\n    seed=7178915308,\n    num_inference_steps=50,\n)[\"images\"][0]\n```\n\n![mixture_tiling_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/mixture_tiling.png)\n\n### Stable Diffusion Mixture Canvas Pipeline SD 1.5\n\nThis pipeline uses the Mixture of Diffusers method. Refer to the [Mixture of Diffusers](https://huggingface.co/papers/2302.02412) paper for more details.\n\n```python\nfrom PIL import Image\nfrom diffusers import LMSDiscreteScheduler, DiffusionPipeline\nfrom diffusers.pipelines.pipeline_utils import Image2ImageRegion, Text2ImageRegion, preprocess_image\n\n\n# Load and preprocess guide image\niic_image = preprocess_image(Image.open(\"input_image.png\").convert(\"RGB\"))\n\n# Create scheduler and model (similar to StableDiffusionPipeline)\nscheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000)\n# custom_pipeline is an argument of from_pretrained(), not of .to()\npipeline = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", scheduler=scheduler, custom_pipeline=\"mixture_canvas\")\npipeline.to(\"cuda\")\n\n# Mixture of Diffusers generation\noutput = pipeline(\n    canvas_height=800,\n    canvas_width=352,\n    regions=[\n        Text2ImageRegion(0, 800, 0, 352, guidance_scale=8,\n            prompt=\"best quality, masterpiece, WLOP, sakimichan, art contest winner on pixiv, 8K, intricate details, wet effects, rain drops, ethereal, mysterious, futuristic, UHD, HDR, cinematic lighting, in a beautiful forest, rainy day, award winning, trending on artstation, beautiful confident cheerful young woman, wearing a futuristic sleeveless dress, ultra beautiful detailed  eyes, hyper-detailed face, complex,  perfect, model,  textured,  chiaroscuro, professional make-up, realistic, figure in frame, \"),\n        Image2ImageRegion(352-800, 352, 0, 352, reference_image=iic_image, strength=1.0),\n    ],\n    
num_inference_steps=100,\n    seed=5525475061,\n)[\"images\"][0]\n```\n\n![Input_Image](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/input_image.png)\n![mixture_canvas_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/canvas.png)\n\n### Stable Diffusion Mixture Tiling Pipeline SDXL\n\nThis pipeline uses the Mixture of Diffusers technique. Refer to the [Mixture of Diffusers](https://huggingface.co/papers/2302.02412) paper for more details.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, AutoencoderKL\n\ndevice = \"cuda\"\n\n# Load fixed vae (optional)\nvae = AutoencoderKL.from_pretrained(\n    \"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16\n).to(device)\n\n# Create scheduler and model (similar to StableDiffusionPipeline)\nmodel_id = \"stablediffusionapi/yamermix-v8-vae\"\nscheduler = DPMSolverMultistepScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000)\npipe = DiffusionPipeline.from_pretrained(\n    model_id,\n    torch_dtype=torch.float16,\n    vae=vae,\n    custom_pipeline=\"mixture_tiling_sdxl\",\n    scheduler=scheduler,\n    use_safetensors=False,\n).to(device)\n\npipe.enable_model_cpu_offload()\npipe.enable_vae_tiling()\npipe.enable_vae_slicing()\n\ngenerator = torch.Generator(device).manual_seed(297984183)\n\n# Mixture of Diffusers generation\nimage = pipe(\n    prompt=[[\n        \"A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\",\n        \"A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\",\n        \"An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\"\n    ]],\n 
   tile_height=1024,\n    tile_width=1280,\n    tile_row_overlap=0,\n    tile_col_overlap=256,\n    guidance_scale_tiles=[[7, 7, 7]],  # or guidance_scale=7 if it is the same for all prompts\n    height=1024,\n    width=3840,\n    generator=generator,\n    num_inference_steps=30,\n)[\"images\"][0]\n```\n\n![mixture_tiling_results](https://huggingface.co/datasets/elismasilva/results/resolve/main/mixture_of_diffusers_sdxl_1.png)\n\n### Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL\n\nThis pipeline implements the [MoD (Mixture-of-Diffusers)](https://huggingface.co/papers/2408.06072) tiled diffusion technique and combines it with SDXL's ControlNet Tile process to generate super-resolution (SR) images.\n\nIt works best at 4x scale, but you can try adjusting the parameters for higher scales.\n\n````python\nimport torch\nfrom diffusers import DiffusionPipeline, ControlNetUnionModel, AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel\nfrom diffusers.utils import load_image\nfrom PIL import Image\n\ndevice = \"cuda\"\n\n# Initialize the models and pipeline\ncontrolnet = ControlNetUnionModel.from_pretrained(\n    \"brad-twinkl/controlnet-union-sdxl-1.0-promax\", torch_dtype=torch.float16\n).to(device=device)\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16).to(device=device)\n\nmodel_id = \"SG161222/RealVisXL_V5.0\"\npipe = DiffusionPipeline.from_pretrained(\n    model_id,\n    torch_dtype=torch.float16,\n    vae=vae,\n    controlnet=controlnet,\n    custom_pipeline=\"mod_controlnet_tile_sr_sdxl\",\n    use_safetensors=True,\n    variant=\"fp16\",\n).to(device)\n\nunet = UNet2DConditionModel.from_pretrained(model_id, subfolder=\"unet\", variant=\"fp16\", use_safetensors=True)\n\n#pipe.enable_model_cpu_offload()  # << Enable this if you have limited VRAM\npipe.enable_vae_tiling() # << Enable this if you have limited VRAM\npipe.enable_vae_slicing() # << Enable this if you have limited VRAM\n\n# Set selected 
scheduler\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\n# Load image\ncontrol_image = load_image(\"https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1.jpg\")\noriginal_height = control_image.height\noriginal_width = control_image.width\nprint(f\"Current resolution: H:{original_height} x W:{original_width}\")\n\n# Pre-upscale image for tiling\nresolution = 4096\ntile_gaussian_sigma = 0.3\nmax_tile_size = 1024 # or 1280\n\ncurrent_size = max(control_image.size)\nscale_factor = max(2, resolution / current_size)\nnew_size = (int(control_image.width * scale_factor), int(control_image.height * scale_factor))\nimage = control_image.resize(new_size, Image.LANCZOS)\n\n# Update target height and width\ntarget_height = image.height\ntarget_width = image.width\nprint(f\"Target resolution: H:{target_height} x W:{target_width}\")\n\n# Calculate overlap size\nnormal_tile_overlap, border_tile_overlap = pipe.calculate_overlap(target_width, target_height)\n\n# Set other params\ntile_weighting_method = pipe.TileWeightingMethod.COSINE.value\nguidance_scale = 4\nnum_inference_steps = 35\ndenoising_strength = 0.65\ncontrolnet_strength = 1.0\nprompt = \"high-quality, noise-free edges, high quality, 4k, hd, 8k\"\nnegative_prompt = \"blurry, pixelated, noisy, low resolution, artifacts, poor details\"\n\n# Image generation\ngenerated_image = pipe(\n    image=image,\n    control_image=control_image,\n    control_mode=[6],\n    controlnet_conditioning_scale=float(controlnet_strength),\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    normal_tile_overlap=normal_tile_overlap,\n    border_tile_overlap=border_tile_overlap,\n    height=target_height,\n    width=target_width,\n    original_size=(original_width, original_height),\n    target_size=(target_width, target_height),\n    guidance_scale=guidance_scale,\n    strength=float(denoising_strength),\n    tile_weighting_method=tile_weighting_method,\n    
max_tile_size=max_tile_size,\n    tile_gaussian_sigma=float(tile_gaussian_sigma),\n    num_inference_steps=num_inference_steps,\n)[\"images\"][0]\n````\n![Upscaled](https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1_input_4x.png)\n\n### TensorRT Inpainting Stable Diffusion Pipeline\n\nThe TensorRT Pipeline can be used to accelerate the Inpainting Stable Diffusion Inference run.\n\nNOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.\n\n```python\nimport requests\nfrom io import BytesIO\nfrom PIL import Image\nimport torch\nfrom diffusers import PNDMScheduler\nfrom diffusers.pipelines import DiffusionPipeline\n\n# Use the PNDMScheduler scheduler here instead\nscheduler = PNDMScheduler.from_pretrained(\"stabilityai/stable-diffusion-2-inpainting\", subfolder=\"scheduler\")\n\npipe = DiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-inpainting\",\n    custom_pipeline=\"stable_diffusion_tensorrt_inpaint\",\n    variant='fp16',\n    torch_dtype=torch.float16,\n    scheduler=scheduler,\n    )\n\n# re-use cached folder to save ONNX models and TensorRT Engines\npipe.set_cached_folder(\"stabilityai/stable-diffusion-2-inpainting\", variant='fp16',)\n\npipe = pipe.to(\"cuda\")\n\nurl = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nresponse = requests.get(url)\ninput_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\nresponse = requests.get(mask_url)\nmask_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n\nprompt = \"a mecha robot sitting on a bench\"\nimage = pipe(prompt, image=input_image, mask_image=mask_image, strength=0.75,).images[0]\nimage.save('tensorrt_inpaint_mecha_robot.png')\n```\n\n### IADB pipeline\n\nThis pipeline is the implementation of the 
[α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://huggingface.co/papers/2305.03486) paper.\nIt is a simple and minimalist diffusion model.\n\nThe following code shows how to use the IADB pipeline to generate images using a pretrained celebahq-256 model.\n\n```python\nimport matplotlib.pyplot as plt\nfrom diffusers import DiffusionPipeline\n\npipeline_iadb = DiffusionPipeline.from_pretrained(\"thomasc4/iadb-celebahq-256\", custom_pipeline='iadb')\n\npipeline_iadb = pipeline_iadb.to('cuda')\n\noutput = pipeline_iadb(batch_size=4, num_inference_steps=128)\nfor i in range(len(output[0])):\n    plt.imshow(output[0][i])\n    plt.show()\n```\n\nSampling with the IADB formulation is easy, and can be done in a few lines (the pipeline already implements it):\n\n```python\nimport torch\n\n\ndef sample_iadb(model, x0, nb_step):\n    x_alpha = x0\n    for t in range(nb_step):\n        alpha = (t/nb_step)\n        alpha_next =((t+1)/nb_step)\n\n        d = model(x_alpha, torch.tensor(alpha, device=x_alpha.device))['sample']\n        x_alpha = x_alpha + (alpha_next-alpha)*d\n\n    return x_alpha\n```\n\nThe training loop is also straightforward (schematic; `sample_noise`, `sample_dataset`, `D`, `batch_size`, and `optimizer` are left undefined):\n\n```python\n# Training loop\nwhile True:\n    x0 = sample_noise()\n    x1 = sample_dataset()\n\n    alpha = torch.rand(batch_size)\n\n    # Blend\n    x_alpha = (1-alpha) * x0 + alpha * x1\n\n    # Loss\n    loss = torch.sum((D(x_alpha, alpha)- (x1-x0))**2)\n    optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n```\n\n### Zero1to3 pipeline\n\nThis pipeline is the implementation of the [Zero-1-to-3: Zero-shot One Image to 3D Object](https://huggingface.co/papers/2303.11328) paper.\nThe original pytorch-lightning [repo](https://github.com/cvlab-columbia/zero123) and a diffusers [repo](https://github.com/kxhit/zero123-hf) are also available.\n\nThe following code shows how to use the Zero1to3 pipeline to generate novel view synthesis images using a pretrained stable diffusion model.\n\n```python\nimport os\nimport torch\nfrom pipeline_zero1to3 import Zero1to3StableDiffusionPipeline\nfrom diffusers.utils import 
load_image\n\nmodel_id = \"kxic/zero123-165000\"  # zero123-105000, zero123-165000, zero123-xl\n\npipe = Zero1to3StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)\n\npipe.enable_xformers_memory_efficient_attention()\npipe.enable_vae_tiling()\npipe.enable_attention_slicing()\npipe = pipe.to(\"cuda\")\n\nnum_images_per_prompt = 4\n\n# test inference pipeline\n# x y z, Polar angle (vertical rotation in degrees)  Azimuth angle (horizontal rotation in degrees)  Zoom (relative distance from center)\nquery_pose1 = [-75.0, 100.0, 0.0]\nquery_pose2 = [-20.0, 125.0, 0.0]\nquery_pose3 = [-55.0, 90.0, 0.0]\n\n# load image\n# H, W = (256, 256) # H, W = (512, 512)   # zero123 training is 256,256\n\n# for batch input\ninput_image1 = load_image(\"./demo/4_blackarm.png\")  # load_image(\"https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/4_blackarm.png\")\ninput_image2 = load_image(\"./demo/8_motor.png\")  # load_image(\"https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/8_motor.png\")\ninput_image3 = load_image(\"./demo/7_london.png\")  # load_image(\"https://cvlab-zero123-live.hf.space/file=/home/user/app/configs/7_london.png\")\ninput_images = [input_image1, input_image2, input_image3]\nquery_poses = [query_pose1, query_pose2, query_pose3]\n\n# # for single input\n# H, W = (256, 256)\n# input_images = [input_image2.resize((H, W), PIL.Image.NEAREST)]\n# query_poses = [query_pose2]\n\n\n# better do preprocessing\nfrom gradio_new import preprocess_image, create_carvekit_interface\nimport numpy as np\nimport PIL.Image as Image\n\npre_images = []\nmodels = dict()\nprint('Instantiating Carvekit HiInterface...')\nmodels['carvekit'] = create_carvekit_interface()\nif not isinstance(input_images, list):\n    input_images = [input_images]\nfor raw_im in input_images:\n    input_im = preprocess_image(models, raw_im, True)\n    H, W = input_im.shape[:2]\n    pre_images.append(Image.fromarray((input_im * 
255.0).astype(np.uint8)))\ninput_images = pre_images\n\n# infer pipeline, in original zero123 num_inference_steps=76\nimages = pipe(input_imgs=input_images, prompt_imgs=input_images, poses=query_poses, height=H, width=W,\n              guidance_scale=3.0, num_images_per_prompt=num_images_per_prompt, num_inference_steps=50).images\n\n# save imgs\nlog_dir = \"logs\"\nos.makedirs(log_dir, exist_ok=True)\nbs = len(input_images)\ni = 0\nfor obj in range(bs):\n    for idx in range(num_images_per_prompt):\n        images[i].save(os.path.join(log_dir,f\"obj{obj}_{idx}.jpg\"))\n        i += 1\n```\n\n### Stable Diffusion XL Reference\n\nThis pipeline uses the Reference. Refer to the [Stable Diffusion Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference) section for more information.\n\n```py\nimport torch\n# from diffusers import DiffusionPipeline\nfrom diffusers.utils import load_image\nfrom diffusers.schedulers import UniPCMultistepScheduler\n\nfrom .stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline\n\ninput_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg\")\n\n# pipe = DiffusionPipeline.from_pretrained(\n#     \"stabilityai/stable-diffusion-xl-base-1.0\",\n#     custom_pipeline=\"stable_diffusion_xl_reference\",\n#     torch_dtype=torch.float16,\n#     use_safetensors=True,\n#     variant=\"fp16\").to('cuda:0')\n\npipe = StableDiffusionXLReferencePipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    use_safetensors=True,\n    variant=\"fp16\").to('cuda:0')\n\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\nresult_img = pipe(ref_image=input_image,\n      prompt=\"a dog\",\n      num_inference_steps=20,\n      reference_attn=True,\n      reference_adain=True).images[0]\n```\n\nReference 
Image\n\n![reference_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg)\n\nOutput Image\n\n`prompt: a dog`\n\n`reference_attn=False, reference_adain=True, num_inference_steps=20`\n![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_adain_dog.png)\n\nReference Image\n![reference_image](https://github.com/huggingface/diffusers/assets/34944964/449bdab6-e744-4fb2-9620-d4068d9a741b)\n\nOutput Image\n\n`prompt: A dog`\n\n`reference_attn=True, reference_adain=False, num_inference_steps=20`\n![Output_image](https://github.com/huggingface/diffusers/assets/34944964/fff2f16f-6e91-434b-abcc-5259d866c31e)\n\nReference Image\n![reference_image](https://github.com/huggingface/diffusers/assets/34944964/077ed4fe-2991-4b79-99a1-009f056227d1)\n\nOutput Image\n\n`prompt: An astronaut riding a lion`\n\n`reference_attn=True, reference_adain=True, num_inference_steps=20`\n![output_image](https://github.com/huggingface/diffusers/assets/34944964/9b2f1aca-886f-49c3-89ec-d2031c8e3670)\n\n### Stable Diffusion XL ControlNet Reference\n\nThis pipeline combines Reference Control with ControlNet. 
Refer to the [Stable Diffusion ControlNet Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-controlnet-reference) and [Stable Diffusion XL Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-xl-reference) sections for more information.\n\n```py\nfrom diffusers import ControlNetModel, AutoencoderKL\nfrom diffusers.schedulers import UniPCMultistepScheduler\nfrom diffusers.utils import load_image\nimport numpy as np\nimport torch\n\nimport cv2\nfrom PIL import Image\n\nfrom .stable_diffusion_xl_controlnet_reference import StableDiffusionXLControlNetReferencePipeline\n\n# download an image\ncanny_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg\"\n)\n\nref_image = load_image(\n    \"https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png\"\n)\n\n# initialize the models and pipeline\ncontrolnet_conditioning_scale = 0.5  # recommended for good generalization\ncontrolnet = ControlNetModel.from_pretrained(\n    \"diffusers/controlnet-canny-sdxl-1.0\", torch_dtype=torch.float16\n)\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\npipe = StableDiffusionXLControlNetReferencePipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", controlnet=controlnet, vae=vae, torch_dtype=torch.float16\n).to(\"cuda:0\")\n\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\n# get canny image\nimage = np.array(canny_image)\nimage = cv2.Canny(image, 100, 200)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], axis=2)\ncanny_image = Image.fromarray(image)\n\n# generate image\nimage = pipe(\n    prompt=\"a cat\",\n    num_inference_steps=20,\n    controlnet_conditioning_scale=controlnet_conditioning_scale,\n    
image=canny_image,\n    ref_image=ref_image,\n    reference_attn=False,\n    reference_adain=True,\n    style_fidelity=1.0,\n    generator=torch.Generator(\"cuda\").manual_seed(42)\n).images[0]\n```\n\nCanny ControlNet Image\n\n![canny_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg)\n\nReference Image\n\n![ref_image](https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png)\n\nOutput Image\n\n`prompt: a cat`\n\n`reference_attn=True, reference_adain=True, num_inference_steps=20, style_fidelity=1.0`\n\n![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_attn_adain_canny_cat.png)\n\n`reference_attn=False, reference_adain=True, num_inference_steps=20, style_fidelity=1.0`\n\n![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_adain_canny_cat.png)\n\n`reference_attn=True, reference_adain=False, num_inference_steps=20, style_fidelity=1.0`\n\n![Output_image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_attn_canny_cat.png)\n\n### Stable diffusion fabric pipeline\n\nFABRIC is an approach applicable to a wide range of popular diffusion models. It exploits\nthe self-attention layer present in the most widely used architectures to condition\nthe diffusion process on a set of feedback images.\n\n```python\nimport requests\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\n\nfrom diffusers import DiffusionPipeline\n\n# load the pipeline\n# make sure you're logged in with `hf auth login`\nmodel_id_or_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n# can also be used with dreamlike-art/dreamlike-photoreal-2.0\npipe = DiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, custom_pipeline=\"pipeline_fabric\").to(\"cuda\")\n\n# let's 
specify a prompt\nprompt = \"An astronaut riding an elephant\"\nnegative_prompt = \"lowres, cropped\"\n\n# call the pipeline\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    num_inference_steps=20,\n    generator=torch.manual_seed(12)\n).images[0]\n\nimage.save(\"horse_to_elephant.jpg\")\n\n# let's try another example with feedback\nurl = \"https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png\"\nresponse = requests.get(url)\ninit_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n\nprompt = \"photo, A blue colored car, fish eye\"\nliked = [init_image]\n# the same goes for `disliked`\n\n# call the pipeline\ntorch.manual_seed(0)\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    liked=liked,\n    num_inference_steps=20,\n).images[0]\n\nimage.save(\"black_to_blue.png\")\n```\n\n*With enough feedback, you can create very similar high-quality images.*\n\nThe original codebase can be found at [sd-fabric/fabric](https://github.com/sd-fabric/fabric), and available checkpoints are [dreamlike-art/dreamlike-photoreal-2.0](https://huggingface.co/dreamlike-art/dreamlike-photoreal-2.0), [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), and [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) (may give unexpected results).\n\nLet's have a look at the images (_512x512_):\n\n| Without Feedback | With Feedback (1st image) |\n|---------------------|---------------------|\n| ![Image 1](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/fabric_wo_feedback.jpg) | ![Feedback Image 1](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/fabric_w_feedback.png) |\n\n### Masked Im2Im Stable Diffusion Pipeline\n\nThis pipeline reimplements the sketch inpaint feature from A1111 for 
non-inpaint models. The following code reads two images: the original and a copy with the mask painted over it. It computes the mask as the difference between the two images and inpaints the area defined by the mask.\n\n```python\nimport numpy\nimport PIL.Image\nimport torch\nfrom diffusers import EulerAncestralDiscreteScheduler\n\n# MaskedStableDiffusionImg2ImgPipeline is assumed to be imported from the community pipeline module\n\n# read the original image\nimg = PIL.Image.open(\"./mech.png\")\n# read image with mask painted over\nimg_paint = PIL.Image.open(\"./mech_painted.png\")\nneq = numpy.any(numpy.array(img) != numpy.array(img_paint), axis=-1)\nmask = neq / neq.max()\n\npipeline = MaskedStableDiffusionImg2ImgPipeline.from_pretrained(\"frankjoshua/icbinpICantBelieveIts_v8\")\n\n# works best with EulerAncestralDiscreteScheduler\npipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)\ngenerator = torch.Generator(device=\"cpu\").manual_seed(4)\n\nprompt = \"a man wearing a mask\"\nresult = pipeline(prompt=prompt, image=img_paint, mask=mask, strength=0.75,\n                  generator=generator)\nresult.images[0].save(\"result.png\")\n```\n\noriginal image mech.png\n\n<img src=https://github.com/noskill/diffusers/assets/733626/10ad972d-d655-43cb-8de1-039e3d79e849 width=\"25%\" >\n\nimage with mask mech_painted.png\n\n<img src=https://github.com/noskill/diffusers/assets/733626/c334466a-67fe-4377-9ff7-f46021b9c224 width=\"25%\" >\n\nresult:\n\n<img src=https://github.com/noskill/diffusers/assets/733626/23a0a71d-51db-471e-926a-107ac62512a8 width=\"25%\" >\n\n### Masked Im2Im Stable Diffusion Pipeline XL\n\nThis pipeline implements the sketch inpaint feature from A1111 for non-inpaint models. The following code reads two images: the original and a copy with the mask painted over it. It computes the mask as the difference between the two images and inpaints the area defined by the mask. 
The latent code is initialized from the image with the mask by default, so the color of the mask affects the result.\n\n```python\nimport PIL.Image\nimport torch\n\nimg = PIL.Image.open(\"./mech.png\")\n# read image with mask painted over\nimg_paint = PIL.Image.open(\"./mech_painted.png\")\n\npipeline = MaskedStableDiffusionXLImg2ImgPipeline.from_pretrained(\"frankjoshua/juggernautXL_v8Rundiffusion\", torch_dtype=torch.float16)\n\npipeline.to('cuda')\npipeline.enable_xformers_memory_efficient_attention()\n\nprompt = \"a mech warrior wearing a mask\"\nseed = 8348273636437\nfor i in range(10):\n    generator = torch.Generator(device=\"cuda\").manual_seed(seed + i)\n    print(seed + i)\n    result = pipeline(prompt=prompt, blur=48, image=img_paint, original_image=img, strength=0.9,\n                          generator=generator, num_inference_steps=60, num_images_per_prompt=1)\n    im = result.images[0]\n    im.save(f\"result{i}.png\")\n```\n\noriginal image mech.png\n\n<img src=https://github.com/noskill/diffusers/assets/733626/10ad972d-d655-43cb-8de1-039e3d79e849 width=\"25%\" >\n\nimage with mask mech_painted.png\n\n<img src=https://github.com/noskill/diffusers/assets/733626/c334466a-67fe-4377-9ff7-f46021b9c224 width=\"25%\" >\n\nresult:\n\n<img src=https://github.com/noskill/diffusers/assets/733626/5043fb57-a785-4606-a5ba-a36704f7cb42 width=\"25%\" >\n\n### Prompt2Prompt Pipeline\n\nPrompt2Prompt allows the following edits:\n\n- ReplaceEdit (change words in prompt)\n- ReplaceEdit with local blend (change words in prompt, keep image part unrelated to changes constant)\n- RefineEdit (add words to prompt)\n- RefineEdit with local blend (add words to prompt, keep image part unrelated to changes constant)\n- ReweightEdit (modulate importance of words)\n\nHere's a full example for `ReplaceEdit`:\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\nimport numpy as np\nfrom PIL import Image\n\npipe = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\", \n    
custom_pipeline=\"pipeline_prompt2prompt\"\n).to(\"cuda\")\n\nprompts = [\n    \"A turtle playing with a ball\",\n    \"A monkey playing with a ball\"\n]\n\ncross_attention_kwargs = {\n    \"edit_type\": \"replace\",\n    \"cross_replace_steps\": 0.4,\n    \"self_replace_steps\": 0.4\n}\n\noutputs = pipe(\n    prompt=prompts,\n    height=512,\n    width=512,\n    num_inference_steps=50,\n    cross_attention_kwargs=cross_attention_kwargs\n)\n\noutputs.images[0].save(\"output_image_0.png\")\n```\n\nAnd abbreviated examples for the other edits:\n\n`ReplaceEdit with local blend`\n\n```python\nprompts = [\"A turtle playing with a ball\",\n           \"A monkey playing with a ball\"]\n\ncross_attention_kwargs = {\n    \"edit_type\": \"replace\",\n    \"cross_replace_steps\": 0.4,\n    \"self_replace_steps\": 0.4,\n    \"local_blend_words\": [\"turtle\", \"monkey\"]\n}\n```\n\n`RefineEdit`\n\n```python\nprompts = [\"A turtle\",\n           \"A turtle in a forest\"]\n\ncross_attention_kwargs = {\n    \"edit_type\": \"refine\",\n    \"cross_replace_steps\": 0.4,\n    \"self_replace_steps\": 0.4,\n}\n```\n\n`RefineEdit with local blend`\n\n```python\nprompts = [\"A turtle\",\n           \"A turtle in a forest\"]\n\ncross_attention_kwargs = {\n    \"edit_type\": \"refine\",\n    \"cross_replace_steps\": 0.4,\n    \"self_replace_steps\": 0.4,\n    \"local_blend_words\": [\"in\", \"a\", \"forest\"]\n}\n```\n\n`ReweightEdit`\n\n```python\nprompts = [\"A smiling turtle\"] * 2\n\ncross_attention_kwargs = {\n    \"edit_type\": \"reweight\",\n    \"cross_replace_steps\": 0.4,\n    \"self_replace_steps\": 0.4,\n    \"equalizer_words\": [\"smiling\"],\n    \"equalizer_strengths\": [5]\n}\n```\n\nSide note: See [this GitHub gist](https://gist.github.com/UmerHA/b65bb5fb9626c9c73f3ade2869e36164) if you want to visualize the attention maps.\n\n### Latent Consistency Pipeline\n\nLatent Consistency Models were proposed in [Latent Consistency Models: Synthesizing High-Resolution 
Images with Few-Step Inference](https://huggingface.co/papers/2310.04378) by _Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, Hang Zhao_ from Tsinghua University.\n\nThe abstract of the paper reads as follows:\n\n*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (rombach et al). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference. Project Page: [this https URL](https://latent-consistency-models.github.io/)*\n\nThe model can be used with `diffusers` as follows:\n\n- *1. Load the model from the community pipeline.*\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\"SimianLuo/LCM_Dreamshaper_v7\", custom_pipeline=\"latent_consistency_txt2img\", custom_revision=\"main\")\n\n# To save GPU memory, torch.float16 can be used, but it may compromise image quality.\npipe.to(torch_device=\"cuda\", torch_dtype=torch.float32)\n```\n\n- 2. 
Run inference with as few as 4 steps:\n\n```py\nprompt = \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\"\n\n# Can be set to 1~50 steps. LCM supports fast inference even <= 4 steps. Recommend: 1~8 steps.\nnum_inference_steps = 4\n\nimages = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type=\"pil\").images\n```\n\nFor any questions or feedback, feel free to reach out to [Simian Luo](https://github.com/luosiallen).\n\nYou can also try this pipeline directly in the [🚀 official spaces](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).\n\n### Latent Consistency Img2img Pipeline\n\nThis pipeline extends the Latent Consistency Pipeline to allow it to take an input image.\n\n- 1. Load the model from the community pipeline:\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\"SimianLuo/LCM_Dreamshaper_v7\", custom_pipeline=\"latent_consistency_img2img\")\n\n# To save GPU memory, torch.float16 can be used, but it may compromise image quality.\npipe.to(torch_device=\"cuda\", torch_dtype=torch.float32)\n```\n\n- 2. Run inference with as few as 4 steps:\n\n```py\nfrom PIL import Image\n\nprompt = \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\"\n\ninput_image = Image.open(\"myimg.png\")\n\nstrength = 0.5  # strength=0 (no change), strength=1 (completely overwrite image)\n\n# Can be set to 1~50 steps. LCM supports fast inference even <= 4 steps. Recommend: 1~8 steps.\nnum_inference_steps = 4\n\nimages = pipe(prompt=prompt, image=input_image, strength=strength, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type=\"pil\").images\n```\n\n### Latent Consistency Interpolation Pipeline\n\nThis pipeline extends the Latent Consistency Pipeline to allow for interpolation of the latent space between multiple prompts. 
It is similar to the [Stable Diffusion Interpolate](https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py) and [unCLIP Interpolate](https://github.com/huggingface/diffusers/blob/main/examples/community/unclip_text_interpolation.py) community pipelines.\n\n```py\nimport torch\nimport numpy as np\n\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"SimianLuo/LCM_Dreamshaper_v7\", custom_pipeline=\"latent_consistency_interpolate\")\n\n# To save GPU memory, torch.float16 can be used, but it may compromise image quality.\npipe.to(torch_device=\"cuda\", torch_dtype=torch.float32)\n\nprompts = [\n    \"Self-portrait oil painting, a beautiful cyborg with golden hair, Margot Robbie, 8k\",\n    \"Self-portrait oil painting, an extremely strong man, body builder, Huge Jackman, 8k\",\n    \"An astronaut floating in space, renaissance art, realistic, high quality, 8k\",\n    \"Oil painting of a cat, cute, dream-like\",\n    \"Hugging face emoji, cute, realistic\"\n]\nnum_inference_steps = 4\nnum_interpolation_steps = 60\nseed = 1337\n\ntorch.manual_seed(seed)\nnp.random.seed(seed)\n\nimages = pipe(\n    prompt=prompts,\n    height=512,\n    width=512,\n    num_inference_steps=num_inference_steps,\n    num_interpolation_steps=num_interpolation_steps,\n    guidance_scale=8.0,\n    embedding_interpolation_type=\"lerp\",\n    latent_interpolation_type=\"slerp\",\n    process_batch_size=4,  # Make it higher or lower based on your GPU memory\n    generator=torch.Generator().manual_seed(seed),  # torch.Generator takes a device, so seed it via manual_seed\n)\n\nassert len(images) == (len(prompts) - 1) * num_interpolation_steps\n```\n\n### StableDiffusionUpscaleLDM3D Pipeline\n\n[LDM3D-VR](https://huggingface.co/papers/2311.03226) is an extended version of LDM3D.\n\nThe abstract from the paper is:\n*Latent diffusion models have proven to be state-of-the-art in the creation and manipulation of visual outputs. 
However, as far as we know, the generation of depth maps jointly with RGB is still limited. We introduce LDM3D-VR, a suite of diffusion models targeting virtual reality development that includes LDM3D-pano and LDM3D-SR. These models enable the generation of panoramic RGBD based on textual prompts and the upscaling of low-resolution inputs to high-resolution RGBD, respectively. Our models are fine-tuned from existing pretrained models on datasets containing panoramic/high-resolution RGB images, depth maps and captions. Both models are evaluated in comparison to existing related methods*\n\nTwo checkpoints are available for use:\n\n- [ldm3d-pano](https://huggingface.co/Intel/ldm3d-pano). This checkpoint enables the generation of panoramic images and requires the StableDiffusionLDM3DPipeline pipeline to be used.\n- [ldm3d-sr](https://huggingface.co/Intel/ldm3d-sr). This checkpoint enables the upscaling of RGB and depth images. Can be used in cascade after the original LDM3D pipeline using the StableDiffusionUpscaleLDM3DPipeline pipeline.\n\n```py\nfrom PIL import Image\nimport os\nimport torch\nfrom diffusers import StableDiffusionLDM3DPipeline, DiffusionPipeline\n\n# Generate a rgb/depth output from LDM3D\n\npipe_ldm3d = StableDiffusionLDM3DPipeline.from_pretrained(\"Intel/ldm3d-4c\")\npipe_ldm3d.to(\"cuda\")\n\nprompt = \"A picture of some lemons on a table\"\noutput = pipe_ldm3d(prompt)\nrgb_image, depth_image = output.rgb, output.depth\nrgb_image[0].save(\"lemons_ldm3d_rgb.jpg\")\ndepth_image[0].save(\"lemons_ldm3d_depth.png\")\n\n# Upscale the previous output to a resolution of (1024, 1024)\n\npipe_ldm3d_upscale = DiffusionPipeline.from_pretrained(\"Intel/ldm3d-sr\", custom_pipeline=\"pipeline_stable_diffusion_upscale_ldm3d\")\n\npipe_ldm3d_upscale.to(\"cuda\")\n\nlow_res_img = Image.open(\"lemons_ldm3d_rgb.jpg\").convert(\"RGB\")\nlow_res_depth = Image.open(\"lemons_ldm3d_depth.png\").convert(\"L\")\noutputs = pipe_ldm3d_upscale(prompt=\"high quality high 
resolution uhd 4k image\", rgb=low_res_img, depth=low_res_depth, num_inference_steps=50, target_res=[1024, 1024])\n\nupscaled_rgb, upscaled_depth = outputs.rgb[0], outputs.depth[0]\nupscaled_rgb.save(\"upscaled_lemons_rgb.png\")\nupscaled_depth.save(\"upscaled_lemons_depth.png\")\n```\n\n### ControlNet + T2I Adapter Pipeline\n\nThis pipeline combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once.\nIt receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale=0` or `controlnet_conditioning_scale=0`, it will act as a full ControlNet module or as a full T2IAdapter module, respectively.\n\n```py\nimport cv2\nimport numpy as np\nimport torch\nfrom controlnet_aux.midas import MidasDetector\nfrom PIL import Image\n\nfrom diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.utils import load_image\nfrom examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import (\n    StableDiffusionXLControlNetAdapterPipeline,\n)\n\ncontrolnet_depth = ControlNetModel.from_pretrained(\n    \"diffusers/controlnet-depth-sdxl-1.0\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n    use_safetensors=True\n)\nadapter_depth = T2IAdapter.from_pretrained(\n  \"TencentARC/t2i-adapter-depth-midas-sdxl-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, use_safetensors=True)\n\npipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    controlnet=controlnet_depth,\n    adapter=adapter_depth,\n    vae=vae,\n    variant=\"fp16\",\n    use_safetensors=True,\n    
torch_dtype=torch.float16,\n)\npipe = pipe.to(\"cuda\")\npipe.enable_xformers_memory_efficient_attention()\n# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)\nmidas_depth = MidasDetector.from_pretrained(\n  \"valhalla/t2iadapter-aux-models\", filename=\"dpt_large_384.pt\", model_type=\"dpt_large\"\n).to(\"cuda\")\n\nprompt = \"a tiger sitting on a park bench\"\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\n\nimage = load_image(img_url).resize((1024, 1024))\n\ndepth_image = midas_depth(\n  image, detect_resolution=512, image_resolution=1024\n)\n\nstrength = 0.5\n\nimages = pipe(\n    prompt,\n    control_image=depth_image,\n    adapter_image=depth_image,\n    num_inference_steps=30,\n    controlnet_conditioning_scale=strength,\n    adapter_conditioning_scale=strength,\n).images\nimages[0].save(\"controlnet_and_adapter.png\")\n```\n\n### ControlNet + T2I Adapter + Inpainting Pipeline\n\n```py\nimport cv2\nimport numpy as np\nimport torch\nfrom controlnet_aux.midas import MidasDetector\nfrom PIL import Image\n\nfrom diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.utils import load_image\nfrom examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import (\n    StableDiffusionXLControlNetAdapterInpaintPipeline,\n)\n\ncontrolnet_depth = ControlNetModel.from_pretrained(\n    \"diffusers/controlnet-depth-sdxl-1.0\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n    use_safetensors=True\n)\nadapter_depth = T2IAdapter.from_pretrained(\n  \"TencentARC/t2i-adapter-depth-midas-sdxl-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n)\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, use_safetensors=True)\n\npipe = 
StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(\n    \"diffusers/stable-diffusion-xl-1.0-inpainting-0.1\",\n    controlnet=controlnet_depth,\n    adapter=adapter_depth,\n    vae=vae,\n    variant=\"fp16\",\n    use_safetensors=True,\n    torch_dtype=torch.float16,\n)\npipe = pipe.to(\"cuda\")\npipe.enable_xformers_memory_efficient_attention()\n# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)\nmidas_depth = MidasDetector.from_pretrained(\n  \"valhalla/t2iadapter-aux-models\", filename=\"dpt_large_384.pt\", model_type=\"dpt_large\"\n).to(\"cuda\")\n\nprompt = \"a tiger sitting on a park bench\"\nimg_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\nmask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\nimage = load_image(img_url).resize((1024, 1024))\nmask_image = load_image(mask_url).resize((1024, 1024))\n\ndepth_image = midas_depth(\n  image, detect_resolution=512, image_resolution=1024\n)\n\nstrength = 0.4\n\nimages = pipe(\n    prompt,\n    image=image,\n    mask_image=mask_image,\n    control_image=depth_image,\n    adapter_image=depth_image,\n    num_inference_steps=30,\n    controlnet_conditioning_scale=strength,\n    adapter_conditioning_scale=strength,\n    strength=0.7,\n).images\nimages[0].save(\"controlnet_and_adapter_inpaint.png\")\n```\n\n### Regional Prompting Pipeline\n\nThis pipeline is a port of the [Regional Prompter extension](https://github.com/hako-mikan/sd-webui-regional-prompter) for [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to `diffusers`.\nThis code implements a pipeline for the Stable Diffusion model, enabling the division of the canvas into multiple regions, with different prompts applicable to each region. 
Users can specify regions in two ways: using `Cols` and `Rows` modes for grid-like divisions, or the `Prompt` mode for regions calculated based on prompts.\n\n![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline1.png)\n\n### Usage\n\n### Sample Code\n\n```py\nimport time\n\nfrom examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline\n\n# model_path, vae, and negative_prompt are assumed to be defined beforehand\npipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)\n\nrp_args = {\n    \"mode\":\"rows\",\n    \"div\": \"1;1;1\"\n}\n\nprompt = \"\"\"\ngreen hair twintail BREAK\nred blouse BREAK\nblue skirt\n\"\"\"\n\nimages = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    guidance_scale=7.5,\n    height=768,\n    width=512,\n    num_inference_steps=20,\n    num_images_per_prompt=1,\n    rp_args=rp_args\n    ).images\n\n# Save each image under a timestamped, numbered file name\ntimestamp = time.strftime(r\"%Y%m%d%H%M%S\")\nfor i, image in enumerate(images, start=1):\n    file_name = f'img-{timestamp}-{i}.png'\n    image.save(file_name)\n```\n\n### Cols, Rows mode\n\nIn the Cols, Rows mode, you can split the canvas vertically and horizontally and assign prompts to each region. The split ratio can be specified by 'div', and you can set the division ratio like '3;3;2' or '0.1;0.5'. Furthermore, as will be described later, you can also subdivide the split Cols, Rows to specify more complex regions.\n\nIn the image below, the canvas is divided into three parts, and a separate prompt is applied to each. The prompts are divided by 'BREAK', and each is applied to the respective region.\n![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline2.png)\n\n```\ngreen hair twintail BREAK\nred blouse BREAK\nblue skirt\n```\n\n### 2-Dimensional division\n\nThe prompt consists of instructions separated by the term `BREAK` and is assigned to different regions of a two-dimensional space. 
The image is initially split in the main splitting direction, which in this case is rows, due to the presence of a single semicolon `;`, dividing the space into an upper and a lower section. Additional sub-splitting is then applied, indicated by commas. The upper row is split into ratios of `2:1:1`, while the lower row is split into a ratio of `4:6`. Rows themselves are split in a `1:2` ratio. According to the reference image, the blue sky is designated as the first region, green hair as the second, the bookshelf as the third, and so on, in a sequence based on their position from the top left. The terrarium is placed on the desk in the fourth region, and the orange dress and sofa are in the fifth region, conforming to their respective splits.\n\n```py\nrp_args = {\n    \"mode\":\"rows\",\n    \"div\": \"1,2,1,1;2,4,6\"\n}\n\nprompt = \"\"\"\nblue sky BREAK\ngreen hair BREAK\nbook shelf BREAK\nterrarium on the desk BREAK\norange dress and sofa\n\"\"\"\n```\n\n![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline4.png)\n\n### Prompt Mode\n\nSpecifying regions in advance has limitations: fixed regions can be a hindrance when designating complex shapes or dynamic compositions. In Prompt mode, the region is determined after image generation has begun, which makes it possible to accommodate dynamic compositions and complex regions.\nFor further information, see [here](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/main/prompt_en.md).\n\n### Syntax\n\n```\nbaseprompt target1 target2 BREAK\neffect1, target1 BREAK\neffect2 ,target2\n```\n\nFirst, write the base prompt. In the base prompt, write the words (target1, target2) for which you want to create a mask. Next, separate them with BREAK. Next, write the prompt corresponding to target1. Then enter a comma and write target1. 
The order of the targets in the base prompt and the order of the BREAK-separated targets do not have to match, so\n\n```\ntarget2 baseprompt target1  BREAK\neffect1, target1 BREAK\neffect2 ,target2\n```\n\nis also effective.\n\n### Sample\n\nIn this example, masks are calculated for shirt, tie, skirt, and color prompts are specified only for those regions.\n\n```py\nrp_args = {\n    \"mode\": \"prompt-ex\",\n    \"save_mask\": True,\n    \"th\": \"0.4,0.6,0.6\",\n}\n\nprompt = \"\"\"\na girl in street with shirt, tie, skirt BREAK\nred, shirt BREAK\ngreen, tie BREAK\nblue , skirt\n\"\"\"\n```\n\n![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline3.png)\n\n### Threshold\n\nThe threshold determines the mask created by the prompt. One value can be set per mask, as the appropriate range varies widely depending on the target prompt. If multiple regions are used, enter the values separated by commas. For example, hair tends to be ambiguous and requires a small value, while face tends to be large and requires a small value. The values should follow the order of the BREAK-separated prompts.\n\n```\na lady ,hair, face  BREAK\nred, hair BREAK\ntanned ,face\n```\n\n`threshold : 0.4,0.6`\nIf only one value is given for multiple regions, it is applied to all of them.\n\n### Prompt and Prompt-EX\n\nThe difference is that in Prompt, duplicate regions are added, whereas in Prompt-EX, duplicate regions are overwritten sequentially. Since they are processed in order, placing a target with a large region first makes it easier for the effects of small regions to remain visible.\n\n### Accuracy\n\nIn the case of a 512x512 image, Attention mode reduces the size of the region to about 8x8 pixels deep in the U-Net, so that small regions get mixed up; Latent mode calculates at 64x64, so that the region is exact.\n\n```\ngirl hair twintail frills,ribbons, dress, face BREAK\ngirl, ,face\n```\n\n### Mask\n\nWhen an image is generated, the generated mask is displayed. 
It is generated at the same size as the image, but is actually used at a much smaller size.\n\n### Use common prompt\n\nText placed before `ADDCOMM` is prepended to every regional prompt. This is useful when you want to include elements common to all regions. For example, when generating pictures of three people with different appearances, it's necessary to include the instruction of 'three people' in all regions. It's also useful for inserting quality tags and similar elements. For example, if you write the following:\n\n```\nbest quality, 3persons in garden, ADDCOMM\na girl white dress BREAK\na boy blue shirt BREAK\nan old man red suit\n```\n\nIf common is enabled, this prompt is converted to the following:\n\n```\nbest quality, 3persons in garden, a girl white dress BREAK\nbest quality, 3persons in garden, a boy blue shirt BREAK\nbest quality, 3persons in garden, an old man red suit\n```\n\n### Use base prompt\n\nYou can use a base prompt to apply the prompt to all areas. You can set a base prompt by adding `ADDBASE` at the end. Base prompts can also be combined with common prompts, but the base prompt must be specified first.\n\n```\n2d animation style ADDBASE\nmasterpiece, high quality ADDCOMM\n(blue sky)++ BREAK\ngreen hair twintail BREAK\nbook shelf BREAK\nmessy desk BREAK\norange++ dress and sofa\n```\n\n### Negative prompt\n\nNegative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs in the negative prompt must match the number in the prompt. If they do not match, the negative prompt is used without being divided into regions.\n\n### Parameters\n\nTo activate Regional Prompter, it is necessary to enter settings in `rp_args`. The items that can be set are as follows. 
`rp_args` is a dictionary.\n\n### Input Parameters\n\nParameters are specified through the `rp_args` dictionary.\n\n```py\nrp_args = {\n    \"mode\":\"rows\",\n    \"div\": \"1;1;1\"\n}\n\npipe(prompt=prompt, rp_args=rp_args)\n```\n\n### Required Parameters\n\n- `mode`: Specifies the method for defining regions. Choose from `Cols`, `Rows`, `Prompt`, or `Prompt-Ex`. This parameter is case-insensitive.\n- `div`: Used in `Cols` and `Rows` modes. Details on how to specify this are provided under the respective `Cols` and `Rows` sections.\n- `th`: Used in `Prompt` mode. The method of specification is detailed under the `Prompt` section.\n\n### Optional Parameters\n\n- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.\n- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if base ratio is set to 0.2, then resulting images will consist of `20%*BASE_PROMPT + 80%*REGION_PROMPT`\n\nThe Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.\n\n### Diffusion Posterior Sampling Pipeline\n\n- Reference paper\n\n    ```bibtex\n    @article{chung2022diffusion,\n    title={Diffusion posterior sampling for general noisy inverse problems},\n    author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},\n    journal={arXiv preprint arXiv:2209.14687},\n    year={2022}\n    }\n    ```\n\n- This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given observation on $y$, unconditional generative model $p(x)$ and differentiable operator $y=f(x)$.\n\n- For example, $f(.)$ can be downsample operator, then $y$ is a downsampled image, and the pipeline becomes a super-resolution pipeline.\n- To use this pipeline, you need to know your operator $f(.)$ and corrupted image $y$, and pass them during the call. 
For example, as in the main function of `dps_pipeline.py`, you need to first define the Gaussian blurring operator $f(.)$. The operator should be a callable `nn.Module` with gradients disabled for all of its parameters:\n\n    ```python\n    import numpy as np\n    import scipy.ndimage\n    import torch\n    import torch.nn.functional as F\n    from torch import nn\n\n    # define the Gaussian blurring operator first\n    class GaussialBlurOperator(nn.Module):\n        def __init__(self, kernel_size, intensity):\n            super().__init__()\n\n            class Blurkernel(nn.Module):\n                def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):\n                    super().__init__()\n                    self.blur_type = blur_type\n                    self.kernel_size = kernel_size\n                    self.std = std\n                    self.seq = nn.Sequential(\n                        nn.ReflectionPad2d(self.kernel_size//2),\n                        nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)\n                    )\n                    self.weights_init()\n\n                def forward(self, x):\n                    return self.seq(x)\n\n                def weights_init(self):\n                    if self.blur_type == \"gaussian\":\n                        n = np.zeros((self.kernel_size, self.kernel_size))\n                        n[self.kernel_size // 2, self.kernel_size // 2] = 1\n                        k = scipy.ndimage.gaussian_filter(n, sigma=self.std)\n                        k = torch.from_numpy(k)\n                        self.k = k\n                        for name, f in self.named_parameters():\n                            f.data.copy_(k)\n                    elif self.blur_type == \"motion\":\n                        # Kernel is not defined in this snippet; this branch assumes a motion-blur Kernel class is available\n                        k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix\n                        k = torch.from_numpy(k)\n                        self.k = k\n                        for name, f in 
self.named_parameters():\n                            f.data.copy_(k)\n\n                def update_weights(self, k):\n                    if not torch.is_tensor(k):\n                        k = torch.from_numpy(k)\n                    for name, f in self.named_parameters():\n                        f.data.copy_(k)\n\n                def get_kernel(self):\n                    return self.k\n\n            self.kernel_size = kernel_size\n            self.conv = Blurkernel(blur_type='gaussian',\n                                kernel_size=kernel_size,\n                                std=intensity)\n            self.kernel = self.conv.get_kernel()\n            self.conv.update_weights(self.kernel.type(torch.float32))\n\n            for param in self.parameters():\n                param.requires_grad = False\n\n        def forward(self, data, **kwargs):\n            return self.conv(data)\n\n        def transpose(self, data, **kwargs):\n            return data\n\n        def get_kernel(self):\n            return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)\n    ```\n\n- Next, you should obtain the corrupted image $y$ by the operator. In this example, we generate $y$ from the source image $x$. 
However, in practice, having the operator $f(.)$ and the corrupted image $y$ is enough:\n\n    ```python\n    import numpy as np\n    from PIL import Image\n    from torchvision.utils import save_image\n\n    # set up source image\n    src = Image.open('sample.png')\n    # read image into [1,3,H,W]\n    src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2,0,1)[None]\n    # normalize image to [-1,1]\n    src = (src / 127.5) - 1.0\n    src = src.to(\"cuda\")\n\n    # set up operator and measurement\n    operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to(\"cuda\")\n    measurement = operator(src)\n\n    # save the source and corrupted images\n    save_image((src+1.0)/2.0, \"dps_src.png\")\n    save_image((measurement+1.0)/2.0, \"dps_mea.png\")\n    ```\n\n- We provide an example pair of saved source and corrupted images, using the Gaussian blur operator above\n  - Source image:\n  - ![sample](https://github.com/tongdaxu/Images/assets/22267548/4d2a1216-08d1-4aeb-9ce3-7a2d87561d65)\n  - Gaussian blurred image:\n  - ![ddpm_generated_image](https://github.com/tongdaxu/Images/assets/22267548/65076258-344b-4ed8-b704-a04edaade8ae)\n  - You can download those images to run the example on your own.\n\n- Next, we need to define a loss function for diffusion posterior sampling. In most cases, the RMSE is fine:\n\n    ```python\n    def RMSELoss(yhat, y):\n        return torch.sqrt(torch.sum((yhat-y)**2))\n    ```\n\n- Next, as with any other diffusion model, we need the score estimator and scheduler. 
As we are working with $256x256$ face images, we use ddpm-celebahq-256:\n\n    ```python\n    from diffusers import DDPMScheduler, UNet2DModel\n\n    # set up scheduler\n    scheduler = DDPMScheduler.from_pretrained(\"google/ddpm-celebahq-256\")\n    scheduler.set_timesteps(1000)\n\n    # set up model\n    model = UNet2DModel.from_pretrained(\"google/ddpm-celebahq-256\").to(\"cuda\")\n    ```\n\n- And finally, run the pipeline:\n\n    ```python\n    # finally, the pipeline (DPSPipeline is the class defined in dps_pipeline.py)\n    dpspipe = DPSPipeline(model, scheduler)\n    image = dpspipe(\n        measurement=measurement,\n        operator=operator,\n        loss_fn=RMSELoss,\n        zeta=1.0,\n    ).images[0]\n    image.save(\"dps_generated_image.png\")\n    ```\n\n- `zeta` is a hyperparameter in the range $[0,1]$ that needs to be tuned for best results. With `zeta=1`, you should get the following reconstruction:\n  - Reconstructed image:\n  - ![sample](https://github.com/tongdaxu/Images/assets/22267548/0ceb5575-d42e-4f0b-99c0-50e69c982209)\n\n- The reconstruction is perceptually similar to the source image, but differs in the details.\n- In `dps_pipeline.py`, we also provide a super-resolution example, which should produce:\n  - Downsampled image:\n  - ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13)\n  - Reconstructed image:\n  - ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f)\n\n### AnimateDiff ControlNet Pipeline\n\nThis pipeline combines AnimateDiff and ControlNet. Enjoy precise motion control for your videos! 
Refer to [this issue](https://github.com/huggingface/diffusers/issues/5866) for more details.\n\n```py\nimport torch\nfrom diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, DiffusionPipeline, DPMSolverMultistepScheduler\nfrom diffusers.utils import export_to_gif\nfrom PIL import Image\n\nmotion_id = \"guoyww/animatediff-motion-adapter-v1-5-2\"\nadapter = MotionAdapter.from_pretrained(motion_id)\ncontrolnet = ControlNetModel.from_pretrained(\"lllyasviel/control_v11p_sd15_openpose\", torch_dtype=torch.float16)\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16)\n\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\npipe = DiffusionPipeline.from_pretrained(\n    model_id,\n    motion_adapter=adapter,\n    controlnet=controlnet,\n    vae=vae,\n    custom_pipeline=\"pipeline_animatediff_controlnet\",\n    torch_dtype=torch.float16,\n).to(device=\"cuda\")\npipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(\n    model_id, subfolder=\"scheduler\", beta_schedule=\"linear\", clip_sample=False, timestep_spacing=\"linspace\", steps_offset=1\n)\npipe.enable_vae_slicing()\n\n# Load the 16 conditioning frames (frame_1.png ... frame_16.png)\nconditioning_frames = []\nfor i in range(1, 16 + 1):\n    conditioning_frames.append(Image.open(f\"frame_{i}.png\"))\n\nprompt = \"astronaut in space, dancing\"\nnegative_prompt = \"bad quality, worst quality, jpeg artifacts, ugly\"\nresult = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=512,\n    height=768,\n    conditioning_frames=conditioning_frames,\n    num_inference_steps=20,\n)\n\n# result.frames[0] holds the frames of the first generated video\nexport_to_gif(result.frames[0], \"result.gif\")\n```\n\n<table>\n  <tr><td colspan=\"2\" align=center><b>Conditioning Frames</b></td></tr>\n  <tr align=center>\n    <td align=center><img src=\"https://user-images.githubusercontent.com/7365912/265043418-23291941-864d-495a-8ba8-d02e05756396.gif\" alt=\"input-frames\"></td>\n  </tr>\n  <tr><td colspan=\"2\" align=center><b>AnimateDiff model: 
SG161222/Realistic_Vision_V5.1_noVAE</b></td></tr>\n  <tr>\n    <td align=center><img src=\"https://github.com/huggingface/diffusers/assets/72266394/baf301e2-d03c-4129-bd84-203a1de2b2be\" alt=\"gif-1\"></td>\n    <td align=center><img src=\"https://github.com/huggingface/diffusers/assets/72266394/9f923475-ecaf-452b-92c8-4e42171182d8\" alt=\"gif-2\"></td>\n  </tr>\n  <tr><td colspan=\"2\" align=center><b>AnimateDiff model: CardosAnime</b></td></tr>\n  <tr>\n    <td align=center><img src=\"https://github.com/huggingface/diffusers/assets/72266394/b2c41028-38a0-45d6-86ed-fec7446b87f7\" alt=\"gif-1\"></td>\n    <td align=center><img src=\"https://github.com/huggingface/diffusers/assets/72266394/eb7d2952-72e4-44fa-b664-077c79b4fc70\" alt=\"gif-2\"></td>\n  </tr>\n</table>\n\nYou can also use multiple controlnets at once!\n\n```python\nfrom io import BytesIO\n\nimport imageio\nimport requests\nimport torch\nfrom diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, DiffusionPipeline, DPMSolverMultistepScheduler\nfrom diffusers.utils import export_to_gif\nfrom PIL import Image\n\nmotion_id = \"guoyww/animatediff-motion-adapter-v1-5-2\"\nadapter = MotionAdapter.from_pretrained(motion_id)\ncontrolnet1 = ControlNetModel.from_pretrained(\"lllyasviel/control_v11p_sd15_openpose\", torch_dtype=torch.float16)\ncontrolnet2 = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16)\n\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\npipe = DiffusionPipeline.from_pretrained(\n    model_id,\n    motion_adapter=adapter,\n    controlnet=[controlnet1, controlnet2],\n    vae=vae,\n    custom_pipeline=\"pipeline_animatediff_controlnet\",\n    torch_dtype=torch.float16,\n).to(device=\"cuda\")\npipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(\n    model_id, subfolder=\"scheduler\", clip_sample=False, timestep_spacing=\"linspace\", steps_offset=1, 
beta_schedule=\"linear\",\n)\npipe.enable_vae_slicing()\n\ndef load_video(file_path: str):\n    images = []\n\n    if file_path.startswith(('http://', 'https://')):\n        # If the file_path is a URL\n        response = requests.get(file_path)\n        response.raise_for_status()\n        content = BytesIO(response.content)\n        vid = imageio.get_reader(content)\n    else:\n        # Assuming it's a local file path\n        vid = imageio.get_reader(file_path)\n\n    for frame in vid:\n        pil_image = Image.fromarray(frame)\n        images.append(pil_image)\n\n    return images\n\nvideo = load_video(\"dance.gif\")\n\n# You need to install it using `pip install controlnet_aux`\nfrom controlnet_aux.processor import Processor\n\np1 = Processor(\"openpose_full\")\ncn1 = [p1(frame) for frame in video]\n\np2 = Processor(\"canny\")\ncn2 = [p2(frame) for frame in video]\n\nprompt = \"astronaut in space, dancing\"\nnegative_prompt = \"bad quality, worst quality, jpeg artifacts, ugly\"\nresult = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    width=512,\n    height=768,\n    conditioning_frames=[cn1, cn2],\n    num_inference_steps=20,\n)\n\nexport_to_gif(result.frames[0], \"result.gif\")\n```\n\n### DemoFusion\n\nThis pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://huggingface.co/papers/2311.16973).\nThe original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).\n\n- `view_batch_size` (`int`, defaults to 16):\n  The batch size for multiple denoising paths. Typically, a larger batch size can result in higher efficiency but comes with increased GPU memory requirements.\n\n- `stride` (`int`, defaults to 64):\n  The stride of moving local patches. 
A smaller stride is better for alleviating seam issues, but it also introduces additional computational overhead and inference time.\n\n- `cosine_scale_1` (`float`, defaults to 3):\n  Control the strength of skip-residual. For specific impacts, please refer to Appendix C in the DemoFusion paper.\n\n- `cosine_scale_2` (`float`, defaults to 1):\n  Control the strength of dilated sampling. For specific impacts, please refer to Appendix C in the DemoFusion paper.\n\n- `cosine_scale_3` (`float`, defaults to 1):\n  Control the strength of the Gaussian filter. For specific impacts, please refer to Appendix C in the DemoFusion paper.\n\n- `sigma` (`float`, defaults to 1):\n  The standard value of the Gaussian filter. Larger sigma promotes the global guidance of dilated sampling, but has the potential of over-smoothing.\n\n- `multi_decoder` (`bool`, defaults to True):\n  Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, a tiled decoder becomes necessary.\n\n- `show_image` (`bool`, defaults to False):\n  Determine whether to show intermediate results during generation.\n\n```py\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    custom_pipeline=\"pipeline_demofusion_sdxl\",\n    custom_revision=\"main\",\n    torch_dtype=torch.float16,\n)\npipe = pipe.to(\"cuda\")\n\nprompt = \"Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. 
Her attire, simple yet dignified.\"\nnegative_prompt = \"blurry, ugly, duplicate, poorly drawn, deformed, mosaic\"\n\nimages = pipe(\n    prompt,\n    negative_prompt=negative_prompt,\n    height=3072,\n    width=3072,\n    view_batch_size=16,\n    stride=64,\n    num_inference_steps=50,\n    guidance_scale=7.5,\n    cosine_scale_1=3,\n    cosine_scale_2=1,\n    cosine_scale_3=1,\n    sigma=0.8,\n    multi_decoder=True,\n    show_image=True\n)\n```\n\nYou can display and save the generated images as:\n\n```py\nimport os\n\nfrom PIL import Image\n\n\ndef image_grid(imgs, save_path=None):\n    # PIL's Image.size is (width, height)\n    w = 0\n    for img in imgs:\n        w_, h_ = img.size\n        w += w_\n    # The grid height is taken from the last image, assumed to be the tallest\n    h = h_\n    grid = Image.new('RGB', size=(w, h))\n\n    w = 0\n    for i, img in enumerate(imgs):\n        w_, h_ = img.size\n        # Paste side by side, aligned to the bottom edge of the grid\n        grid.paste(img, box=(w, h - h_))\n        if save_path is not None:\n            img.save(os.path.join(save_path, \"img_{}.jpg\".format((i + 1) * 1024)))\n        w += w_\n\n    return grid\n\nimage_grid(images, save_path=\"./outputs/\")\n```\n\n ![output_example](https://github.com/PRIS-CV/DemoFusion/blob/main/output_example.png)\n\n### SDE Drag pipeline\n\nThis pipeline provides drag-and-drop image editing using stochastic differential equations. 
It enables image editing by inputting prompt, image, mask_image, source_points, and target_points.\n\n![SDE Drag Image](https://github.com/huggingface/diffusers/assets/75928535/bd54f52f-f002-4951-9934-b2a4592771a5)\n\nSee [paper](https://huggingface.co/papers/2311.01410), [paper page](https://ml-gsai.github.io/SDE-Drag-demo/), [original repo](https://github.com/ML-GSAI/SDE-Drag) for more information.\n\n```py\nimport torch\nfrom diffusers import DDIMScheduler, DiffusionPipeline\nfrom PIL import Image\nimport requests\nfrom io import BytesIO\nimport numpy as np\n\n# Load the pipeline\nmodel_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nscheduler = DDIMScheduler.from_pretrained(model_path, subfolder=\"scheduler\")\npipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline=\"sde_drag\")\n\n# Ensure the model is moved to the GPU\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\npipe.to(device)\n\n# Function to load image from URL\ndef load_image_from_url(url):\n    response = requests.get(url)\n    return Image.open(BytesIO(response.content)).convert(\"RGB\")\n\n# Function to prepare mask\ndef prepare_mask(mask_image):\n    # Convert to grayscale\n    mask = mask_image.convert(\"L\")\n    return mask\n\n# Function to convert numpy array to PIL Image\ndef array_to_pil(array):\n    # Ensure the array is in uint8 format\n    if array.dtype != np.uint8:\n        if array.max() <= 1.0:\n            array = (array * 255).astype(np.uint8)\n        else:\n            array = array.astype(np.uint8)\n    \n    # Handle different array shapes\n    if len(array.shape) == 3:\n        if array.shape[0] == 3:  # If channels first\n            array = array.transpose(1, 2, 0)\n        return Image.fromarray(array)\n    elif len(array.shape) == 4:  # If batch dimension\n        array = array[0]\n        if array.shape[0] == 3:  # If channels first\n            array = array.transpose(1, 2, 0)\n        return Image.fromarray(array)\n 
   else:\n        raise ValueError(f\"Unexpected array shape: {array.shape}\")\n\n# Image and mask URLs\nimage_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png'\nmask_url = 'https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png'\n\n# Load the images\nimage = load_image_from_url(image_url)\nmask_image = load_image_from_url(mask_url)\n\n# Resize images to a size that's compatible with the model's latent space\nimage = image.resize((512, 512))\nmask_image = mask_image.resize((512, 512))\n\n# Prepare the mask (keep as PIL Image)\nmask = prepare_mask(mask_image)\n\n# Provide the prompt and points for drag editing\nprompt = \"A cute dog\"\nsource_points = [[32, 32]]  # Adjusted for 512x512 image\ntarget_points = [[64, 64]]  # Adjusted for 512x512 image\n\n# Generate the output image\noutput_array = pipe(\n    prompt=prompt,\n    image=image,\n    mask_image=mask,\n    source_points=source_points,\n    target_points=target_points\n)\n\n# Convert output array to PIL Image and save\noutput_image = array_to_pil(output_array)\noutput_image.save(\"./output.png\")\nprint(\"Output image saved as './output.png'\")\n\n```\n\n### Instaflow Pipeline\n\nInstaFlow is an ultra-fast, one-step image generator that achieves image quality close to Stable Diffusion, significantly reducing the demand of computational resources. 
This efficiency is made possible through a recent [Rectified Flow](https://github.com/gnobitab/RectifiedFlow) technique, which trains probability flows with straight trajectories, hence inherently requiring only a single step for fast inference.\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\n\npipe = DiffusionPipeline.from_pretrained(\"XCLIU/instaflow_0_9B_from_sd_1_5\", torch_dtype=torch.float16, custom_pipeline=\"instaflow_one_step\")\npipe.to(\"cuda\")  ### if GPU is not available, comment this line\nprompt = \"A hyper-realistic photo of a cute cat.\"\n\nimages = pipe(prompt=prompt,\n            num_inference_steps=1,\n            guidance_scale=0.0).images\nimages[0].save(\"./image.png\")\n```\n\n![image1](https://huggingface.co/datasets/ayushtues/instaflow_images/resolve/main/instaflow_cat.png)\n\nYou can also combine it with a LoRA out of the box, like <https://huggingface.co/artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5>, to unlock cool use cases in a single step!\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\n# Use the GPU when one is available, otherwise fall back to the CPU\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\npipe = DiffusionPipeline.from_pretrained(\"XCLIU/instaflow_0_9B_from_sd_1_5\", torch_dtype=torch.float16, custom_pipeline=\"instaflow_one_step\")\npipe.to(device)\npipe.load_lora_weights(\"artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5\")\nprompt = \"logo, A logo for a fitness app, dynamic running figure, energetic colors (red, orange) ),LogoRedAF ,\"\nimages = pipe(prompt=prompt,\n            num_inference_steps=1,\n            guidance_scale=0.0).images\nimages[0].save(\"./image.png\")\n```\n\n![image0](https://huggingface.co/datasets/ayushtues/instaflow_images/resolve/main/instaflow_logo.png)\n\n### Null-Text Inversion pipeline\n\nThis pipeline provides null-text inversion for editing real images. 
It enables null-text optimization, and DDIM reconstruction with or without null-text optimization. No prompt-to-prompt code is implemented as there is a Prompt2PromptPipeline.\n\n- Reference paper\n\n    ```bibtex\n    @article{hertz2022prompt,\n    title={Prompt-to-prompt image editing with cross attention control},\n    author={Hertz, Amir and Mokady, Ron and Tenenbaum, Jay and Aberman, Kfir and Pritch, Yael and Cohen-Or, Daniel},\n    journal={arXiv preprint arXiv:2208.01626},\n    year={2022}\n    }\n    ```\n\n```py\nfrom diffusers import DDIMScheduler\nfrom examples.community.pipeline_null_text_inversion import NullTextPipeline\nimport torch\n\ndevice = \"cuda\"\n# Provide invert_prompt and the image for null-text optimization.\ninvert_prompt = \"A lying cat\"\ninput_image = \"siamese.jpg\"\nsteps = 50\n\n# Provide the prompt used for generation. Use the same prompt for reconstruction\nprompt = \"A lying cat\"\n# or a different one for editing.\nprompt = \"A lying dog\"\n\n# Float32 is essential for the optimization to work well\nmodel_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nscheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.0120, beta_schedule=\"scaled_linear\")\npipeline = NullTextPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float32).to(device)\n\n# Keep the inverted latent so it can be reused without repeating the inversion\ninverted_latent, uncond = pipeline.invert(input_image, invert_prompt, num_inner_steps=10, early_stop_epsilon=1e-5, num_inference_steps=steps)\npipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_steps=steps).images[0].save(input_image+\".output.jpg\")\n```\n\n### Rerender A Video\n\nThis is the Diffusers implementation of the zero-shot video-to-video translation pipeline [Rerender A Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `gmflow_dir`. 
After that, you can run the pipeline with:\n\n```py\nimport sys\ngmflow_dir = \"/path/to/gmflow\"\nsys.path.insert(0, gmflow_dir)\n\nfrom diffusers import ControlNetModel, AutoencoderKL, DDIMScheduler, DiffusionPipeline\nfrom diffusers.utils import export_to_video\nimport numpy as np\nimport torch\n\nimport cv2\nfrom PIL import Image\n\ndef video_to_frame(video_path: str, interval: int):\n    vidcap = cv2.VideoCapture(video_path)\n    success = True\n\n    count = 0\n    res = []\n    while success:\n        count += 1\n        success, image = vidcap.read()\n        if count % interval != 1:\n            continue\n        if image is not None:\n            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n            res.append(image)\n\n    vidcap.release()\n    return res\n\ninput_video_path = 'path/to/video'\ninput_interval = 10\nframes = video_to_frame(\n    input_video_path, input_interval)\n\ncontrol_frames = []\n# get canny image\nfor frame in frames:\n    np_image = cv2.Canny(frame, 50, 100)\n    np_image = np_image[:, :, None]\n    np_image = np.concatenate([np_image, np_image, np_image], axis=2)\n    canny_image = Image.fromarray(np_image)\n    control_frames.append(canny_image)\n\n# You can use any ControlNet here\ncontrolnet = ControlNetModel.from_pretrained(\n    \"lllyasviel/sd-controlnet-canny\").to('cuda')\n\n# You can use any finetuned SD here\npipe = DiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\", controlnet=controlnet, custom_pipeline='rerender_a_video').to('cuda')\n\n# Optional: you can download vae-ft-mse-840000-ema-pruned.ckpt to enhance the results\n# pipe.vae = AutoencoderKL.from_single_file(\n#     \"path/to/vae-ft-mse-840000-ema-pruned.ckpt\").to('cuda')\n\npipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)\n\ngenerator = torch.manual_seed(0)\nframes = [Image.fromarray(frame) for frame in frames]\noutput_frames = pipe(\n    \"a beautiful woman in CG style, best quality, extremely detailed\",\n    
frames,\n    control_frames,\n    num_inference_steps=20,\n    strength=0.75,\n    controlnet_conditioning_scale=0.7,\n    generator=generator,\n    warp_start=0.0,\n    warp_end=0.1,\n    mask_start=0.5,\n    mask_end=0.8,\n    mask_strength=0.5,\n    negative_prompt='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'\n).frames[0]\n\nexport_to_video(\n    output_frames, \"/path/to/video.mp4\", 5)\n```\n\n### StyleAligned Pipeline\n\nThis pipeline is the implementation of [Style Aligned Image Generation via Shared Attention](https://huggingface.co/papers/2312.02133). You can find more results [here](https://github.com/huggingface/diffusers/pull/6489#issuecomment-1881209354).\n\n> Large-scale Text-to-Image (T2I) models have rapidly gained prominence across creative fields, generating visually compelling outputs from textual prompts. However, controlling these models to ensure consistent style remains challenging, with existing methods necessitating fine-tuning and manual intervention to disentangle content and style. In this paper, we introduce StyleAligned, a novel technique designed to establish style alignment among a series of generated images. By employing minimal `attention sharing' during the diffusion process, our method maintains style consistency across images within T2I models. This approach allows for the creation of style-consistent images using a reference style through a straightforward inversion operation. 
Our method's evaluation across diverse styles and text prompts demonstrates high-quality synthesis and fidelity, underscoring its efficacy in achieving consistent style across various inputs.\n\n```python\nfrom typing import List\n\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom PIL import Image\n\nmodel_id = \"a-r-r-o-w/dreamshaper-xl-turbo\"\npipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant=\"fp16\", custom_pipeline=\"pipeline_sdxl_style_aligned\")\npipe = pipe.to(\"cuda\")\n\n# Enable memory saving techniques\npipe.enable_vae_slicing()\npipe.enable_vae_tiling()\n\nprompt = [\n  \"a toy train. macro photo. 3d game asset\",\n  \"a toy airplane. macro photo. 3d game asset\",\n  \"a toy bicycle. macro photo. 3d game asset\",\n  \"a toy car. macro photo. 3d game asset\",\n]\nnegative_prompt = \"low quality, worst quality, \"\n\n# Enable StyleAligned\npipe.enable_style_aligned(\n    share_group_norm=False,\n    share_layer_norm=False,\n    share_attention=True,\n    adain_queries=True,\n    adain_keys=True,\n    adain_values=False,\n    full_attention_share=False,\n    shared_score_scale=1.0,\n    shared_score_shift=0.0,\n    only_self_level=0.0,\n)\n\n# Run inference\nimages = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    guidance_scale=2,\n    height=1024,\n    width=1024,\n    num_inference_steps=10,\n    generator=torch.Generator().manual_seed(42),\n).images\n\n# Disable StyleAligned if you do not wish to use it anymore\npipe.disable_style_aligned()\n```\n\n### AnimateDiff Image-To-Video Pipeline\n\nThis pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.\n\nThis pipeline relies on a \"hack\" discovered by the community that allows the generation of videos given an input image with AnimateDiff. 
It works by creating a copy of the image `num_frames` times and progressively adding more noise to the image based on the strength and latent interpolation method.\n\n```py\nimport torch\nfrom diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler\nfrom diffusers.utils import export_to_gif, load_image\n\nmodel_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\nadapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5-2\")\npipe = DiffusionPipeline.from_pretrained(model_id, motion_adapter=adapter, custom_pipeline=\"pipeline_animatediff_img2video\").to(\"cuda\")\npipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder=\"scheduler\", clip_sample=False, timestep_spacing=\"linspace\", beta_schedule=\"linear\", steps_offset=1)\n\nimage = load_image(\"snail.png\")\noutput = pipe(\n  image=image,\n  prompt=\"A snail moving on the ground\",\n  strength=0.8,\n  latent_interpolation_method=\"slerp\",  # can be lerp, slerp, or your own callback\n)\nframes = output.frames[0]\nexport_to_gif(frames, \"animation.gif\")\n```\n\n### IP Adapter Face ID\n\nIP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.\nYou need to install `insightface` and all its requirements to use this model.\nYou must pass the image embedding tensor as `image_embeds` to the `DiffusionPipeline` instead of `ip_adapter_image`.\nYou can find more results [here](https://github.com/huggingface/diffusers/pull/6276).\n\n```py\nimport torch\nfrom diffusers.utils import load_image\nimport cv2\nimport numpy as np\nfrom diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler\nfrom insightface.app import FaceAnalysis\n\n\nnoise_scheduler = DDIMScheduler(\n    num_train_timesteps=1000,\n    beta_start=0.00085,\n    beta_end=0.012,\n    beta_schedule=\"scaled_linear\",\n    clip_sample=False,\n    set_alpha_to_one=False,\n    steps_offset=1,\n)\nvae = 
AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\").to(dtype=torch.float16)\npipeline = DiffusionPipeline.from_pretrained(\n    \"SG161222/Realistic_Vision_V4.0_noVAE\",\n    torch_dtype=torch.float16,\n    scheduler=noise_scheduler,\n    vae=vae,\n    custom_pipeline=\"ip_adapter_face_id\"\n)\npipeline.load_ip_adapter_face_id(\"h94/IP-Adapter-FaceID\", \"ip-adapter-faceid_sd15.bin\")\npipeline.to(\"cuda\")\n\ngenerator = torch.Generator(device=\"cpu\").manual_seed(42)\nnum_images = 2\n\nimage = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png\")\n\napp = FaceAnalysis(name=\"buffalo_l\", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\napp.prepare(ctx_id=0, det_size=(640, 640))\n# load_image returns an RGB image; convert it to the BGR layout used by OpenCV\nimage = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)\nfaces = app.get(image)\nimage = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)\nimages = pipeline(\n    prompt=\"A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower\",\n    image_embeds=image,\n    negative_prompt=\"monochrome, lowres, bad anatomy, worst quality, low quality\",\n    num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704,\n    generator=generator\n).images\n\nfor i in range(num_images):\n    images[i].save(f\"c{i}.png\")\n```\n\n### InstantID Pipeline\n\nInstantID is a new state-of-the-art tuning-free method to achieve ID-preserving generation with only a single image, supporting various downstream tasks. 
For any usage question, please refer to the [official implementation](https://github.com/InstantID/InstantID).\n\n```py\n# !pip install diffusers opencv-python transformers accelerate insightface\nimport diffusers\nfrom diffusers.utils import load_image\nfrom diffusers import ControlNetModel\n\nimport cv2\nimport torch\nimport numpy as np\nfrom PIL import Image\n\nfrom insightface.app import FaceAnalysis\nfrom pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps\n\n# prepare 'antelopev2' under ./models\n# https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304\napp = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\napp.prepare(ctx_id=0, det_size=(640, 640))\n\n# prepare models under ./checkpoints\n# https://huggingface.co/InstantX/InstantID\nfrom huggingface_hub import hf_hub_download\n\nhf_hub_download(repo_id=\"InstantX/InstantID\", filename=\"ControlNetModel/config.json\", local_dir=\"./checkpoints\")\nhf_hub_download(repo_id=\"InstantX/InstantID\", filename=\"ControlNetModel/diffusion_pytorch_model.safetensors\", local_dir=\"./checkpoints\")\nhf_hub_download(repo_id=\"InstantX/InstantID\", filename=\"ip-adapter.bin\", local_dir=\"./checkpoints\")\n\nface_adapter = './checkpoints/ip-adapter.bin'\ncontrolnet_path = './checkpoints/ControlNetModel'\n\n# load IdentityNet\ncontrolnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)\n\nbase_model = 'wangqixun/YamerMIX_v8'\npipe = StableDiffusionXLInstantIDPipeline.from_pretrained(\n    base_model,\n    controlnet=controlnet,\n    torch_dtype=torch.float16\n)\npipe.to(\"cuda\")\n\n# load adapter\npipe.load_ip_adapter_instantid(face_adapter)\n\n# load an image\nface_image = load_image(\"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png\")\n\n# prepare face emb\nface_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))\nface_info = 
sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the face with the largest bounding-box area\nface_emb = face_info['embedding']\nface_kps = draw_kps(face_image, face_info['kps'])\n\n# prompt\nprompt = \"film noir style, ink sketch|vector, male man, highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic\"\nnegative_prompt = \"ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vibrant, colorful\"\n\n# generate image\npipe.set_ip_adapter_scale(0.8)\nimage = pipe(\n    prompt,\n    image_embeds=face_emb,\n    image=face_kps,\n    controlnet_conditioning_scale=0.8,\n).images[0]\n```\n\n### UFOGen Scheduler\n\n[UFOGen](https://huggingface.co/papers/2311.09257) is a generative model designed for fast one-step text-to-image generation, trained via adversarial training starting from an initial pretrained diffusion model such as Stable Diffusion. `scheduling_ufogen.py` implements one-step and multistep sampling for UFOGen models and is compatible with pipelines like `StableDiffusionPipeline`. 
A usage example is as follows:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\nfrom scheduling_ufogen import UFOGenScheduler\n\n# NOTE: currently, I am not aware of any publicly available UFOGen model checkpoints trained from SD v1.5.\nufogen_model_id_or_path = \"/path/to/ufogen/model\"\npipe = StableDiffusionPipeline.from_pretrained(\n    ufogen_model_id_or_path,\n    torch_dtype=torch.float16,\n)\n\n# You can initialize a UFOGenScheduler as follows:\npipe.scheduler = UFOGenScheduler.from_config(pipe.scheduler.config)\n\nprompt = \"Three cats having dinner at a table at new years eve, cinematic shot, 8k.\"\n\n# Onestep sampling\nonestep_image = pipe(prompt, num_inference_steps=1).images[0]\n\n# Multistep sampling\nmultistep_image = pipe(prompt, num_inference_steps=4).images[0]\n```\n\n### FRESCO\n\nThis is the Diffusers implementation of the zero-shot video-to-video translation pipeline [FRESCO](https://github.com/williamyang1991/FRESCO) (without Ebsynth postprocessing and background smoothing). To run the code, please install gmflow. Then modify the path in `gmflow_dir`. 
After that, you can run the pipeline with:\n\n```py\nfrom PIL import Image\nimport cv2\nimport torch\nimport numpy as np\n\nfrom diffusers import ControlNetModel, DDIMScheduler, DiffusionPipeline\nimport sys\n\ngmflow_dir = \"/path/to/gmflow\"\nsys.path.insert(0, gmflow_dir)\n\ndef video_to_frame(video_path: str, interval: int):\n    vidcap = cv2.VideoCapture(video_path)\n    success = True\n\n    count = 0\n    res = []\n    while success:\n        count += 1\n        success, image = vidcap.read()\n        if count % interval != 1:\n            continue\n        if image is not None:\n            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n            res.append(image)\n            if len(res) >= 8:\n                break\n\n    vidcap.release()\n    return res\n\n\ninput_video_path = 'https://github.com/williamyang1991/FRESCO/raw/main/data/car-turn.mp4'\noutput_video_path = 'car.gif'\n\n# You can use any finetuned SD here\nmodel_path = 'SG161222/Realistic_Vision_V2.0'\n\nprompt = 'a red car turns in the winter'\na_prompt = ', RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, '\nn_prompt = '(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation'\n\ninput_interval = 5\nframes = video_to_frame(\n    input_video_path, input_interval)\n\ncontrol_frames = []\n# get canny image\nfor frame in frames:\n    image = cv2.Canny(frame, 50, 100)\n    np_image = np.array(image)\n    np_image = np_image[:, :, None]\n    np_image = np.concatenate([np_image, np_image, np_image], axis=2)\n    canny_image = Image.fromarray(np_image)\n    control_frames.append(canny_image)\n\n# You can use any ControlNet here\ncontrolnet = 
ControlNetModel.from_pretrained(\n    \"lllyasviel/sd-controlnet-canny\").to('cuda')\n\npipe = DiffusionPipeline.from_pretrained(\n    model_path, controlnet=controlnet, custom_pipeline='fresco_v2v').to('cuda')\npipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)\n\ngenerator = torch.manual_seed(0)\nframes = [Image.fromarray(frame) for frame in frames]\n\noutput_frames = pipe(\n    prompt + a_prompt,\n    frames,\n    control_frames,\n    num_inference_steps=20,\n    strength=0.75,\n    controlnet_conditioning_scale=0.7,\n    generator=generator,\n    negative_prompt=n_prompt\n).images\n\noutput_frames[0].save(output_video_path, save_all=True,\n                 append_images=output_frames[1:], duration=100, loop=0)\n```\n\n### AnimateDiff on IPEX\n\nThis diffusion pipeline aims to accelerate the inference of AnimateDiff on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).\n\nTo use this pipeline, you need to:\n1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)\n\n**Note:** For each PyTorch release, there is a corresponding release of IPEX. The mapping is shown in the table below. It is recommended to install PyTorch/IPEX 2.3 to get the best performance.\n\n|PyTorch Version|IPEX Version|\n|--|--|\n|[v2.3.\\*](https://github.com/pytorch/pytorch/tree/v2.3.0 \"v2.3.0\")|[v2.3.\\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0+cpu)|\n|[v1.13.\\*](https://github.com/pytorch/pytorch/tree/v1.13.0 \"v1.13.0\")|[v1.13.\\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|\n\nYou can install the latest version of IPEX with pip:\n```bash\npython -m pip install intel_extension_for_pytorch\n```\n**Note:** To install a specific version, run the following command:\n```bash\npython -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu\n```\n2. 
After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX acceleration. Supported inference datatypes are Float32 and BFloat16.\n\n```python\npipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)\n# For Float32\npipe.prepare_for_ipex(torch.float32, prompt=\"A girl smiling\")\n# For BFloat16\npipe.prepare_for_ipex(torch.bfloat16, prompt=\"A girl smiling\")\n```\n\nThen you can use the ipex pipeline in a similar way to the default animatediff pipeline.\n```python\n# For Float32\noutput = pipe(prompt=\"A girl smiling\", guidance_scale=1.0, num_inference_steps=step)\n# For BFloat16\nwith torch.cpu.amp.autocast(enabled = True, dtype = torch.bfloat16):\n    output = pipe(prompt=\"A girl smiling\", guidance_scale=1.0, num_inference_steps=step)\n```\n\nThe following code compares the performance of the original animatediff pipeline with the ipex-optimized pipeline.\nWith this optimized pipeline, we can get a speedup of about 1.5-2.2x with BFloat16 on the fifth generation of Intel Xeon CPUs, code-named Emerald Rapids.\n\n```python\nimport torch\nfrom diffusers import MotionAdapter, AnimateDiffPipeline, EulerDiscreteScheduler\nfrom huggingface_hub import hf_hub_download\nfrom safetensors.torch import load_file\nfrom pipeline_animatediff_ipex import AnimateDiffPipelineIpex\nimport time\n\ndevice = \"cpu\"\ndtype = torch.float32\n\nprompt = \"A girl smiling\"\nstep = 8  # Options: [1,2,4,8]\nrepo = \"ByteDance/AnimateDiff-Lightning\"\nckpt = f\"animatediff_lightning_{step}step_diffusers.safetensors\"\nbase = \"emilianJR/epiCRealism\"  # Choose your favorite base model.\n\nadapter = MotionAdapter().to(device, dtype)\nadapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))\n\n# Helper function for time evaluation\ndef elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):\n    # warmup\n    for _ in range(2):\n        output = pipeline(prompt = prompt, guidance_scale=1.0, num_inference_steps = 
num_inference_steps)\n    #time evaluation\n    start = time.time()\n    for _ in range(nb_pass):\n        pipeline(prompt = prompt, guidance_scale=1.0, num_inference_steps = num_inference_steps)\n    end = time.time()\n    return (end - start) / nb_pass\n\n##############     bf16 inference performance    ###############\n\n# 1. IPEX Pipeline initialization\npipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)\npipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing=\"trailing\", beta_schedule=\"linear\")\npipe.prepare_for_ipex(torch.bfloat16, prompt = prompt)\n\n# 2. Original Pipeline initialization\npipe2 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)\npipe2.scheduler = EulerDiscreteScheduler.from_config(pipe2.scheduler.config, timestep_spacing=\"trailing\", beta_schedule=\"linear\")\n\n# 3. Compare performance between Original Pipeline and IPEX Pipeline\nwith torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n    latency = elapsed_time(pipe, num_inference_steps=step)\n    print(\"Latency of AnimateDiffPipelineIpex--bf16\", latency, \"s for total\", step, \"steps\")\n    latency = elapsed_time(pipe2, num_inference_steps=step)\n    print(\"Latency of AnimateDiffPipeline--bf16\", latency, \"s for total\", step, \"steps\")\n\n##############     fp32 inference performance    ###############\n\n# 1. IPEX Pipeline initialization\npipe3 = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)\npipe3.scheduler = EulerDiscreteScheduler.from_config(pipe3.scheduler.config, timestep_spacing=\"trailing\", beta_schedule=\"linear\")\npipe3.prepare_for_ipex(torch.float32, prompt = prompt)\n\n# 2. 
Original Pipeline initialization\npipe4 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)\npipe4.scheduler = EulerDiscreteScheduler.from_config(pipe4.scheduler.config, timestep_spacing=\"trailing\", beta_schedule=\"linear\")\n\n# 3. Compare performance between Original Pipeline and IPEX Pipeline\nlatency = elapsed_time(pipe3, num_inference_steps=step)\nprint(\"Latency of AnimateDiffPipelineIpex--fp32\", latency, \"s for total\", step, \"steps\")\nlatency = elapsed_time(pipe4, num_inference_steps=step)\nprint(\"Latency of AnimateDiffPipeline--fp32\",latency, \"s for total\", step, \"steps\")\n```\n### HunyuanDiT with Differential Diffusion\n\n#### Usage\n\n```python\nimport torch\nfrom diffusers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import load_image\nfrom PIL import Image\nfrom torchvision import transforms\n\nfrom pipeline_hunyuandit_differential_img2img import (\n    HunyuanDiTDifferentialImg2ImgPipeline,\n)\n\n\npipe = HunyuanDiTDifferentialImg2ImgPipeline.from_pretrained(\n    \"Tencent-Hunyuan/HunyuanDiT-Diffusers\", torch_dtype=torch.float16\n).to(\"cuda\")\n\n\nsource_image = load_image(\n    \"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png\"\n)\nmap = load_image(\n    \"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask_2.png\"\n)\nprompt = \"a green pear\"\nnegative_prompt = \"blurry\"\n\nimage = pipe(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    image=source_image,\n    num_inference_steps=28,\n    guidance_scale=4.5,\n    strength=1.0,\n    map=map,\n).images[0]\n```\n\n| ![Gradient](https://github.com/user-attachments/assets/e38ce4d5-1ae6-4df0-ab43-adc1b45716b5) | ![Input](https://github.com/user-attachments/assets/9c95679c-e9d7-4f5a-90d6-560203acd6b3) | ![Output](https://github.com/user-attachments/assets/5313ff64-a0c4-418b-8b55-a38f1a5e7532) |\n| 
-------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |\n| Gradient                                                                                     | Input                                                                                     | Output                                                                                     |\n\nA colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.\n\n### 🪆Matryoshka Diffusion Models\n\n![🪆Matryoshka Diffusion Models](https://github.com/user-attachments/assets/bf90b53b-48c3-4769-a805-d9dfe4a7c572)\n\nThe Abstract of the paper:\n>Diffusion models are the _de-facto_ approach for generating high-quality images and videos but learning high-dimensional models remains a formidable task due to computational and optimization challenges. Existing methods often resort to training cascaded models in pixel space, or using a downsampled latent space of a separately trained auto-encoder. In this paper, we introduce Matryoshka Diffusion (MDM), **a novel framework for high-resolution image and video synthesis**. We propose a diffusion process that denoises inputs at multiple resolutions jointly and uses a **NestedUNet** architecture where features and parameters for small scale inputs are nested within those of the large scales. In addition, MDM enables a progressive training schedule from lower to higher resolutions which leads to significant improvements in optimization for high-resolution generation. 
We demonstrate the effectiveness of our approach on various benchmarks, including class-conditioned image generation, high-resolution text-to-image, and text-to-video applications. Remarkably, we can train a **_single pixel-space model_ at resolutions of up to 1024 × 1024 pixels**, demonstrating strong zero shot generalization using the **CC12M dataset, which contains only 12 million images**. Code and pre-trained checkpoints are released at https://github.com/apple/ml-mdm.\n\n- `64×64, nesting_level=0`: 1.719 GiB. With `50` DDIM inference steps:\n\n**64x64**\n:-------------------------:\n| <img src=\"https://github.com/user-attachments/assets/032738eb-c6cd-4fd9-b4d7-a7317b4b6528\" width=\"222\" height=\"222\" alt=\"bird_64_64\"> |\n\n- `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps:\n\n**64x64**             |  **256x256**\n:-------------------------:|:-------------------------:\n| <img src=\"https://github.com/user-attachments/assets/21b9ad8b-eea6-4603-80a2-31180f391589\" width=\"222\" height=\"222\" alt=\"bird_256_64\"> | <img src=\"https://github.com/user-attachments/assets/fc411682-8a36-422c-9488-395b77d4406e\" width=\"222\" height=\"222\" alt=\"bird_256_256\"> |\n\n- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible in this context! 
With `250` DDIM inference steps:\n\n**64x64**             |  **256x256**  |  **1024x1024**\n:-------------------------:|:-------------------------:|:-------------------------:\n| <img src=\"https://github.com/user-attachments/assets/febf4b98-3dee-4a8e-9946-fd42e1f232e6\" width=\"222\" height=\"222\" alt=\"bird_1024_64\"> | <img src=\"https://github.com/user-attachments/assets/c5f85b40-5d6d-4267-a92a-c89dff015b9b\" width=\"222\" height=\"222\" alt=\"bird_1024_256\"> | <img src=\"https://github.com/user-attachments/assets/ad66b913-4367-4cb9-889e-bc06f4d96148\" width=\"222\" height=\"222\" alt=\"bird_1024_1024\"> |\n\n```py\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import make_image_grid\n\n# nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64\npipe = DiffusionPipeline.from_pretrained(\"tolgacangoz/matryoshka-diffusion-models\",\n                                         nesting_level=0,\n                                         trust_remote_code=False,  # One needs to give permission for this code to run\n                                         ).to(\"cuda\")\n\nprompt0 = \"a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree\"\nprompt = f\"breathtaking {prompt0}. award-winning, professional, highly detailed\"\nimage = pipe(prompt, num_inference_steps=50).images\nmake_image_grid(image, rows=1, cols=len(image))\n\n# pipe.change_nesting_level(<int>)  # 0, 1, or 2\n# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.\n```\n\n### Stable Diffusion XL Attentive Eraser Pipeline\n<img src=\"https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/fenmian.png\"  width=\"600\" />\n\n**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. 
This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. For more details, or if you have questions, please refer to the [paper](https://huggingface.co/papers/2412.12974) and the [official implementation](https://github.com/Anonym0u3/AttentiveEraser).\n\n#### Key features\n\n- **Tuning-Free**: No additional training is required, making it easy to integrate and use.\n- **Flexible Mask Support**: Works with different types of masks for targeted object removal.\n- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.\n\n#### Usage example\nTo use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:\n```py\nimport torch\nfrom diffusers import DDIMScheduler, DiffusionPipeline\nfrom diffusers.utils import load_image\nimport torch.nn.functional as F\nfrom torchvision.transforms.functional import to_tensor, gaussian_blur\n\ndtype = torch.float16\ndevice = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n\nscheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\npipeline = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    custom_pipeline=\"pipeline_stable_diffusion_xl_attentive_eraser\",\n    scheduler=scheduler,\n    variant=\"fp16\",\n    use_safetensors=True,\n    torch_dtype=dtype,\n).to(device)\n\n\ndef preprocess_image(image_path, device):\n    image = to_tensor((load_image(image_path)))\n    image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]\n    # Expand single-channel inputs to 3 channels; resizing and casting apply to every image\n    if image.shape[1] != 3:\n        image = image.expand(-1, 3, -1, -1)\n    image = F.interpolate(image, (1024, 1024))\n    image = image.to(dtype).to(device)\n    return image\n\ndef preprocess_mask(mask_path, device):\n    mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))\n    mask = mask.unsqueeze_(0).float()  # 0 or 1\n    mask = F.interpolate(mask, (1024, 1024))\n    mask = gaussian_blur(mask, kernel_size=(77, 77))\n    mask[mask < 0.1] = 0\n    mask[mask >= 0.1] = 1\n    mask = mask.to(dtype).to(device)\n    return mask\n\nprompt = \"\" # Set prompt to null\nseed=123 \ngenerator = torch.Generator(device=device).manual_seed(seed)\nsource_image_path = \"https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png\"\nmask_path = \"https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png\"\nsource_image = preprocess_image(source_image_path, device)\nmask = preprocess_mask(mask_path, device)\n\nimage = pipeline(\n    prompt=prompt, \n    image=source_image,\n    mask_image=mask,\n    height=1024,\n    width=1024,\n    AAS=True, # enable AAS\n    strength=0.8, # inpainting strength\n    rm_guidance_scale=9, # removal guidance scale\n    ss_steps = 9, # similarity suppression steps\n    ss_scale = 0.3, # similarity suppression scale\n    AAS_start_step=0, # AAS start step\n    AAS_start_layer=34, # AAS start layer\n    AAS_end_layer=70, # AAS end layer\n    num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)\n    generator=generator,\n    guidance_scale=1,\n).images[0]\nimage.save('./removed_img.png')\nprint(\"Object removal completed\")\n```\n\n| Source Image                                                                                   | Mask                                                                                        | Output                                                                                              |\n| ---------------------------------------------------------------------------------------------- | 
------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |\n| ![Source Image](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png) | ![Mask](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png) | ![Output](https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/AE_step40_layer34.png) |\n\n# Perturbed-Attention Guidance\n\n[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://huggingface.co/papers/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)\n\nThis implementation is based on [Diffusers](https://huggingface.co/docs/diffusers/index). `StableDiffusionPAGPipeline` is a modification of `StableDiffusionPipeline` to support Perturbed-Attention Guidance (PAG).\n\n## Example Usage\n\n```py\nimport os\nimport torch\n\nfrom accelerate.utils import set_seed\n\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.utils import load_image, make_image_grid\nfrom diffusers.utils.torch_utils import randn_tensor\n\npipe = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    custom_pipeline=\"hyoungwoncho/sd_perturbed_attention_guidance\",\n    torch_dtype=torch.float16\n)\n\ndevice = \"cuda\"\npipe = pipe.to(device)\n\npag_scale = 5.0\npag_applied_layers_index = ['m0']\n\nbatch_size = 4\nseed = 10\n\nbase_dir = \"./results/\"\ngrid_dir = base_dir + \"/pag\" + str(pag_scale) + \"/\"\n\nif not os.path.exists(grid_dir):\n    os.makedirs(grid_dir)\n\nset_seed(seed)\n\nlatent_input = randn_tensor(shape=(batch_size,4,64,64), generator=None, device=device, dtype=torch.float16)\n\noutput_baseline = pipe(\n    \"\",\n    width=512,\n    height=512,\n    num_inference_steps=50,\n    guidance_scale=0.0,\n    pag_scale=0.0,\n    
pag_applied_layers_index=pag_applied_layers_index,\n    num_images_per_prompt=batch_size,\n    latents=latent_input\n).images\n\noutput_pag = pipe(\n    \"\",\n    width=512,\n    height=512,\n    num_inference_steps=50,\n    guidance_scale=0.0,\n    pag_scale=5.0,\n    pag_applied_layers_index=pag_applied_layers_index,\n    num_images_per_prompt=batch_size,\n    latents=latent_input\n).images\n\ngrid_image = make_image_grid(output_baseline + output_pag, rows=2, cols=batch_size)\ngrid_image.save(grid_dir + \"sample.png\")\n```\n\n## PAG Parameters\n\n`pag_scale` : guidance scale of PAG (e.g. 5.0)\n\n`pag_applied_layers_index` : indices of the layers to which the perturbation is applied (e.g. ['m0'])\n\n# PIXART-α Controlnet pipeline\n\n[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md)\n\nThis is the implementation of the ControlNet model and the pipeline for the PixArt-alpha model, adapted for use with Hugging Face Diffusers.\n\n## Example Usage\n\nThis example uses the PixArt HED ControlNet model, converted from the ControlNet model trained by the authors of the paper.\n\n```py\nimport sys\nimport os\nimport torch\nimport torchvision.transforms as T\nimport torchvision.transforms.functional as TF\n\nfrom pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline\nfrom diffusers.utils import load_image\n\nfrom diffusers.image_processor import PixArtImageProcessor\n\nfrom controlnet_aux import HEDdetector\n\nsys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nfrom pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel\n\ncontrolnet_repo_id = \"raulc0399/pixart-alpha-hed-controlnet\"\n\nweight_dtype = torch.float16\nimage_size = 1024\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ntorch.manual_seed(0)\n\n# load controlnet\ncontrolnet = PixArtControlNetAdapterModel.from_pretrained(\n    controlnet_repo_id,\n    
torch_dtype=weight_dtype,\n    use_safetensors=True,\n).to(device)\n\npipe = PixArtAlphaControlnetPipeline.from_pretrained(\n    \"PixArt-alpha/PixArt-XL-2-1024-MS\",\n    controlnet=controlnet,\n    torch_dtype=weight_dtype,\n    use_safetensors=True,\n).to(device)\n\nimages_path = \"images\"\ncontrol_image_file = \"0_7.jpg\"\n\nprompt = \"battleship in space, galaxy in background\"\n\ncontrol_image_name = control_image_file.split('.')[0]\n\ncontrol_image = load_image(f\"{images_path}/{control_image_file}\")\nprint(control_image.size)\nheight, width = control_image.size\n\nhed = HEDdetector.from_pretrained(\"lllyasviel/Annotators\")\n\ncondition_transform = T.Compose([\n    T.Lambda(lambda img: img.convert('RGB')),\n    T.CenterCrop([image_size, image_size]),\n])\n\ncontrol_image = condition_transform(control_image)\nhed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)\n\nhed_edge.save(f\"{images_path}/{control_image_name}_hed.jpg\")\n\n# run pipeline\nwith torch.no_grad():\n    out = pipe(\n        prompt=prompt,\n        image=hed_edge,\n        num_inference_steps=14,\n        guidance_scale=4.5,\n        height=image_size,\n        width=image_size,\n    )\n\n    out.images[0].save(f\"{images_path}//{control_image_name}_output.jpg\")\n    \n```\n\nIn the folder examples/pixart there is also a script that can be used to train new models.\nPlease check the script `train_controlnet_hf_diffusers.sh` on how to start the training.\n\n# CogVideoX DDIM Inversion Pipeline\n\nThis implementation performs DDIM inversion on the video based on CogVideoX and uses guided attention to reconstruct or edit the inversion latents.\n\n## Example Usage\n\n```python\nimport torch\n\nfrom examples.community.cogvideox_ddim_inversion import CogVideoXPipelineForDDIMInversion\n\n\n# Load pretrained pipeline\npipeline = CogVideoXPipelineForDDIMInversion.from_pretrained(\n    \"THUDM/CogVideoX1.5-5B\",\n    
torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\n# Run DDIM inversion; the inverse and reconstructed latents are exported to video below\noutput = pipeline(\n    prompt=\"prompt that describes the edited video\",\n    video_path=\"path/to/input.mp4\",\n    guidance_scale=6.0,\n    num_inference_steps=50,\n    skip_frames_start=0,\n    skip_frames_end=0,\n    frame_sample_step=None,\n    max_num_frames=81,\n    width=720,\n    height=480,\n    seed=42,\n)\npipeline.export_latents_to_video(output.inverse_latents[-1], \"path/to/inverse_video.mp4\", fps=8)\npipeline.export_latents_to_video(output.recon_latents[-1], \"path/to/recon_video.mp4\", fps=8)\n```\n\n# FaithDiff Stable Diffusion XL Pipeline\n\n[Project](https://jychen9811.github.io/FaithDiff_page/) / [GitHub](https://github.com/JyChen9811/FaithDiff/)\n\nThis is the implementation of the FaithDiff pipeline for SDXL, adapted for use with Hugging Face Diffusers.\n\nFor more details, see the project links above.\n\n## Example Usage\n\nThis example upscales and restores a low-quality image. The input image has a resolution of 512x512 and is upscaled by 2x, to a final resolution of 1024x1024. Larger scale factors are possible, but in those cases it is recommended that the input image be at least 1024x1024. 
To upscale this image by 4x, for example, it would be recommended to re-input the result into a new 2x processing, thus performing progressive scaling.\n\n````py\nimport random\nimport numpy as np\nimport torch\nfrom diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler\nfrom huggingface_hub import hf_hub_download\nfrom diffusers.utils import load_image\nfrom PIL import Image\n\ndevice = \"cuda\"\ndtype = torch.float16\nMAX_SEED = np.iinfo(np.int32).max\n\n# Download weights for additional unet layers\nmodel_file = hf_hub_download(\n    \"jychen9811/FaithDiff\",\n    filename=\"FaithDiff.bin\", local_dir=\"./proc_data/faithdiff\", local_dir_use_symlinks=False\n)\n\n# Initialize the models and pipeline\nvae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=dtype)\n\nmodel_id = \"SG161222/RealVisXL_V4.0\"\npipe = DiffusionPipeline.from_pretrained(\n    model_id,\n    torch_dtype=dtype,\n    vae=vae,\n    unet=None, #<- Do not load with original model.\n    custom_pipeline=\"pipeline_faithdiff_stable_diffusion_xl\",    \n    use_safetensors=True,\n    variant=\"fp16\",\n).to(device)\n\n# Here we need use pipeline internal unet model\npipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder=\"unet\", variant=\"fp16\", use_safetensors=True)\n\n# Load additional layers to the model\npipe.unet.load_additional_layers(weight_path=\"proc_data/faithdiff/FaithDiff.bin\", dtype=dtype)\n\n# Enable vae tiling\npipe.set_encoder_tile_settings()\npipe.enable_vae_tiling()\n\n# Optimization\npipe.enable_model_cpu_offload()\n\n# Set selected scheduler\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\n#input params\nprompt = \"The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. 
\"\nupscale = 2 # scale here\nstart_point = \"lr\" # or \"noise\"\nlatent_tiled_overlap = 0.5\nlatent_tiled_size = 1024\n\n# Load image\nlq_image = load_image(\"https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png\")\noriginal_height = lq_image.height\noriginal_width = lq_image.width\nprint(f\"Current resolution: H:{original_height} x W:{original_width}\")\n\nwidth = original_width * int(upscale)\nheight = original_height * int(upscale)\nprint(f\"Final resolution: H:{height} x W:{width}\")\n\n# Restoration\nimage = lq_image.resize((width, height), Image.LANCZOS)\ninput_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image)\n\ngenerator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))\ngen_image = pipe(lr_img=input_image, \n                 prompt = prompt,                  \n                 num_inference_steps=20, \n                 guidance_scale=5, \n                 generator=generator, \n                 start_point=start_point, \n                 height = height_now, \n                 width=width_now, \n                 overlap=latent_tiled_overlap, \n                 target_size=(latent_tiled_size, latent_tiled_size)\n                ).images[0]\n\ncropped_image = gen_image.crop((0, 0, width_init, height_init))\ncropped_image.save(\"data/result.png\")\n````\n### Result\n[<img src=\"https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG\" width=\"512px\" height=\"512px\"/>](https://imgsli.com/MzY1NzE2)\n\n\n# Stable Diffusion 3 InstructPix2Pix Pipeline\nThis the implementation of the Stable Diffusion 3 InstructPix2Pix Pipeline, based on the HuggingFace Diffusers.\n\n## Example Usage\nThis pipeline aims to edit image based on user's instruction by using SD3\n````py\nimport torch\nfrom diffusers import SD3Transformer2DModel\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import load_image\n\n\nresolution = 512\nimage = 
load_image(\"https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png\").resize(\n    (resolution, resolution)\n)\nedit_instruction = \"Turn sky into a sunny one\"\n\n\npipe = DiffusionPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-3-medium-diffusers\", custom_pipeline=\"pipeline_stable_diffusion_3_instruct_pix2pix\", torch_dtype=torch.float16).to('cuda')\n\npipe.transformer = SD3Transformer2DModel.from_pretrained(\"CaptainZZZ/sd3-instructpix2pix\",torch_dtype=torch.float16).to('cuda')\n\nedited_image = pipe(\n    prompt=edit_instruction,\n    image=image,\n    height=resolution,\n    width=resolution,\n    guidance_scale=7.5,\n    image_guidance_scale=1.5,\n    num_inference_steps=30,\n).images[0]\n\nedited_image.save(\"edited_image.png\")\n````\n|Original|Edited|\n|---|---|\n|![Original image](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/StableDiffusion3InstructPix2Pix/mountain.png)|![Edited image](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/StableDiffusion3InstructPix2Pix/edited.png)|\n\n### Note\nThis model was trained at 512x512, so 512x512 inputs work best.\nFor better editing performance, please refer to this powerful model https://huggingface.co/BleachNick/SD3_UltraEdit_freeform and the paper \"UltraEdit: Instruction-based Fine-Grained Image Editing at Scale\". Many thanks to the authors for their contribution!\n\n# Flux Kontext multiple images\n\nThis implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.\n\nAs explained in Section 3 of [the paper](https://huggingface.co/papers/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. 
In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.\n\n## Example Usage\n\nThis pipeline loads two reference images and generates a new image based on them.\n\n```python\nimport torch\n\nfrom diffusers import FluxKontextPipeline\nfrom diffusers.utils import load_image\n\n\npipe = FluxKontextPipeline.from_pretrained(\n    \"black-forest-labs/FLUX.1-Kontext-dev\",\n    torch_dtype=torch.bfloat16,\n    custom_pipeline=\"pipeline_flux_kontext_multiple_images\",\n)\npipe.to(\"cuda\")\n\npikachu_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png\"\n).convert(\"RGB\")\ncat_image = load_image(\n    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png\"\n).convert(\"RGB\")\n\n\nprompts = [\n    \"Pikachu and the cat are sitting together at a pizzeria table, enjoying a delicious pizza.\",\n]\nimages = pipe(\n    multiple_images=[(pikachu_image, cat_image)],\n    prompt=prompts,\n    guidance_scale=2.5,\n    generator=torch.Generator().manual_seed(42),\n).images\nimages[0].save(\"pizzeria.png\")\n```\n\n# Flux Fill ControlNet Pipeline\n\nThis implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.\n\nWhile FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. 
In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.\n\n## Example Usage\n\n\n```python\nimport torch\nfrom diffusers import (\n    FluxControlNetModel,\n    FluxPriorReduxPipeline,\n)\nfrom diffusers.utils import load_image\n\n# NEW PIPELINE (updated name)\nfrom pipline_flux_fill_controlnet_Inpaint import  FluxControlNetFillInpaintPipeline\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\ndtype = torch.bfloat16\n\n# Models\nbase_model = \"black-forest-labs/FLUX.1-Fill-dev\"\ncontrolnet_model = \"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0\"\nprior_model = \"black-forest-labs/FLUX.1-Redux-dev\"\n\n# Load ControlNet\ncontrolnet = FluxControlNetModel.from_pretrained(\n    controlnet_model,\n    torch_dtype=dtype,\n)\n\n# Load Fill + ControlNet Pipeline\nfill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(\n    base_model,\n    controlnet=controlnet,\n    torch_dtype=dtype,\n).to(device)\n\n# OPTIONAL FP8\n# fill_pipe.transformer.enable_layerwise_casting(\n#     storage_dtype=torch.float8_e4m3fn,\n#     compute_dtype=torch.bfloat16\n# )\n\n#  OPTIONAL Prior Redux\n#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(\n#    prior_model,\n#    torch_dtype=dtype,\n#).to(device)\n\n# Inputs\n\n# combined_image = load_image(\"person_input.png\")\n\n\n# 1. Prior conditioning\n#prior_out = pipe_prior_redux(\n#    image=cloth_image,\n#    prompt=cloth_prompt,\n#)\n\n# 2. 
Fill Inpaint with ControlNet\n\n# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).\n\nimg = load_image(r\"imgs/background.jpg\")\nmask = load_image(r\"imgs/mask.png\")\n\ncontrol_image_depth = load_image(r\"imgs/dog_depth _2.png\")\n\nresult = fill_pipe(\n    prompt=\"a dog on a bench\",\n    image=img,\n    mask_image=mask,\n\n    control_image=control_image_depth,\n    control_mode=[2],  # union mode\n    control_guidance_start=0.0,\n    control_guidance_end=0.8,\n    controlnet_conditioning_scale=0.9,\n\n    height=1024,\n    width=1024,\n\n    strength=1.0,\n    guidance_scale=50.0,\n    num_inference_steps=60,\n    max_sequence_length=512,\n\n#    **prior_out,\n)\n\n# result.images[0].save(\"flux_fill_controlnet_inpaint.png\")\n\nfrom datetime import datetime\ntimestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\nresult.images[0].save(f\"flux_fill_controlnet_inpaint_depth{timestamp}.jpg\")\n```\n\n"
  },
  {
    "path": "examples/community/README_community_scripts.md",
    "content": "# Community Scripts\n\n**Community scripts** consist of inference examples using Diffusers pipelines that have been added by the community.\nPlease have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste code example that you can try out.\nIf a community script doesn't work as expected, please open an issue and ping the author on it.\n\n| Example                                                                                                                               | Description                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              | Code Example                                                                              | Colab                                                                                                                                                                                                              |                                                        Author 
|\n|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|\n| Using IP-Adapter with Negative Noise                                                                                                  | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details)                                                                                                                                                                                                                                                    | [IP-Adapter Negative Noise](#ip-adapter-negative-noise)                                   |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ip_adapter_negative_noise.ipynb) | [Álvaro Somoza](https://github.com/asomoza)|\n| Asymmetric Tiling                                                                                                  |configure seamless image tiling independently for the X and Y axes  
                                                                                                   | [Asymmetric Tiling](#asymmetric-tiling)                                   |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/asymetric_tiling.ipynb) | [alexisrolland](https://github.com/alexisrolland)|\n| Prompt Scheduling Callback                                                                                                  |Allows changing prompts during a generation                                                                                                                                                                                                      | [Prompt Scheduling Callback](#prompt-scheduling-callback)                                   |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_scheduling_callback.ipynb) | [hlky](https://github.com/hlky)|\n\n\n## Example usages\n\n### IP Adapter Negative Noise\n\nDiffusers pipelines are fully integrated with IP-Adapter, which allows you to prompt the diffusion model with an image. However, they do not support negative image prompts (there is no `negative_ip_adapter_image` argument) the same way they support negative text prompts. When you pass an `ip_adapter_image`, the pipeline creates a zero-filled tensor as the negative image. This script shows you how to create negative noise from `ip_adapter_image` and use it to significantly improve the generation quality while preserving the composition of images.\n\n[cubiq](https://github.com/cubiq) initially developed this feature in his [repository](https://github.com/cubiq/ComfyUI_IPAdapter_plus). The community script was contributed by [asomoza](https://github.com/asomoza). 
You can find more details about this experiment in [this discussion](https://github.com/huggingface/diffusers/discussions/7167).\n\nIP-Adapter without negative noise\n\n|source|result|\n|---|---|\n|![20240229150812](https://github.com/huggingface/diffusers/assets/5442875/901d8bd8-7a59-4fe7-bda1-a0e0d6c7dffd)|![20240229163923_normal](https://github.com/huggingface/diffusers/assets/5442875/3432e25a-ece6-45f4-a3f4-fca354f40b5b)|\n\nIP-Adapter with negative noise\n\n|source|result|\n|---|---|\n|![20240229150812](https://github.com/huggingface/diffusers/assets/5442875/901d8bd8-7a59-4fe7-bda1-a0e0d6c7dffd)|![20240229163923](https://github.com/huggingface/diffusers/assets/5442875/736fd15a-36ba-40c0-a7d8-6ec1ac26f788)|\n\n```python\nimport torch\n\nfrom diffusers import AutoencoderKL, DPMSolverMultistepScheduler, StableDiffusionXLPipeline\nfrom diffusers.models import ImageProjection\nfrom diffusers.utils import load_image\n\n\ndef encode_image(\n    image_encoder,\n    feature_extractor,\n    image,\n    device,\n    num_images_per_prompt,\n    output_hidden_states=None,\n    negative_image=None,\n):\n    dtype = next(image_encoder.parameters()).dtype\n\n    if not isinstance(image, torch.Tensor):\n        image = feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n    image = image.to(device=device, dtype=dtype)\n    if output_hidden_states:\n        image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]\n        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n\n        if negative_image is None:\n            uncond_image_enc_hidden_states = image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n        else:\n            if not isinstance(negative_image, torch.Tensor):\n                negative_image = feature_extractor(negative_image, return_tensors=\"pt\").pixel_values\n            negative_image = 
negative_image.to(device=device, dtype=dtype)\n            uncond_image_enc_hidden_states = image_encoder(negative_image, output_hidden_states=True).hidden_states[-2]\n\n        uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n        return image_enc_hidden_states, uncond_image_enc_hidden_states\n    else:\n        image_embeds = image_encoder(image).image_embeds\n        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n        uncond_image_embeds = torch.zeros_like(image_embeds)\n\n        return image_embeds, uncond_image_embeds\n\n\n@torch.no_grad()\ndef prepare_ip_adapter_image_embeds(\n    unet,\n    image_encoder,\n    feature_extractor,\n    ip_adapter_image,\n    do_classifier_free_guidance,\n    device,\n    num_images_per_prompt,\n    ip_adapter_negative_image=None,\n):\n    if not isinstance(ip_adapter_image, list):\n        ip_adapter_image = [ip_adapter_image]\n\n    if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):\n        raise ValueError(\n            f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n        )\n\n    image_embeds = []\n    for single_ip_adapter_image, image_proj_layer in zip(\n        ip_adapter_image, unet.encoder_hid_proj.image_projection_layers\n    ):\n        output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n        single_image_embeds, single_negative_image_embeds = encode_image(\n            image_encoder,\n            feature_extractor,\n            single_ip_adapter_image,\n            device,\n            1,\n            output_hidden_state,\n            negative_image=ip_adapter_negative_image,\n        )\n        single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n        single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)\n\n        if do_classifier_free_guidance:\n            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n            single_image_embeds = single_image_embeds.to(device)\n\n        image_embeds.append(single_image_embeds)\n\n    return image_embeds\n\n\nvae = AutoencoderKL.from_pretrained(\n    \"madebyollin/sdxl-vae-fp16-fix\",\n    torch_dtype=torch.float16,\n).to(\"cuda\")\n\npipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"RunDiffusion/Juggernaut-XL-v9\",\n    torch_dtype=torch.float16,\n    vae=vae,\n    variant=\"fp16\",\n).to(\"cuda\")\n\npipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\npipeline.scheduler.config.use_karras_sigmas = True\n\npipeline.load_ip_adapter(\n    \"h94/IP-Adapter\",\n    subfolder=\"sdxl_models\",\n    weight_name=\"ip-adapter-plus_sdxl_vit-h.safetensors\",\n    image_encoder_folder=\"models/image_encoder\",\n)\npipeline.set_ip_adapter_scale(0.7)\n\nip_image = load_image(\"source.png\")\nnegative_ip_image = load_image(\"noise.png\")\n\nimage_embeds = 
prepare_ip_adapter_image_embeds(\n    unet=pipeline.unet,\n    image_encoder=pipeline.image_encoder,\n    feature_extractor=pipeline.feature_extractor,\n    ip_adapter_image=[[ip_image]],\n    do_classifier_free_guidance=True,\n    device=\"cuda\",\n    num_images_per_prompt=1,\n    ip_adapter_negative_image=negative_ip_image,\n)\n\n\nprompt = \"cinematic photo of a cyborg in the city, 4k, high quality, intricate, highly detailed\"\nnegative_prompt = \"blurry, smooth, plastic\"\n\nimage = pipeline(\n    prompt=prompt,\n    negative_prompt=negative_prompt,\n    ip_adapter_image_embeds=image_embeds,\n    guidance_scale=6.0,\n    num_inference_steps=25,\n    generator=torch.Generator(device=\"cpu\").manual_seed(1556265306),\n).images[0]\n\nimage.save(\"result.png\")\n```\n\n### Asymmetric Tiling\nStable Diffusion is not trained to generate seamless textures. However, you can use this simple script to add tiling to your generation. This script is contributed by [alexisrolland](https://github.com/alexisrolland). 
See more details in [this issue](https://github.com/huggingface/diffusers/issues/556).\n\n\n|Generated|Tiled|\n|---|---|\n|![20240313003235_573631814](https://github.com/huggingface/diffusers/assets/5442875/eca174fb-06a4-464e-a3a7-00dbb024543e)|![wall](https://github.com/huggingface/diffusers/assets/5442875/b4aa774b-2a6a-4316-a8eb-8f30b5f4d024)|\n\n\n```py\nimport torch\nfrom typing import Optional\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.models.lora import LoRACompatibleConv\n\ndef seamless_tiling(pipeline, x_axis, y_axis):\n    def asymmetric_conv2d_convforward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):\n        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)\n        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])\n        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)\n        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)\n        return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)\n    x_mode = 'circular' if x_axis else 'constant'\n    y_mode = 'circular' if y_axis else 'constant'\n    targets = [pipeline.vae, pipeline.text_encoder, pipeline.unet]\n    convolution_layers = []\n    for target in targets:\n        for module in target.modules():\n            if isinstance(module, torch.nn.Conv2d):\n                convolution_layers.append(module)\n    for layer in convolution_layers:\n        if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:\n            layer.lora_layer = lambda *x: 0\n        layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)\n    return pipeline\n\npipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16, 
use_safetensors=True)\npipeline.enable_model_cpu_offload()\nprompt = [\"texture of a red brick wall\"]\nseed = 123456\ngenerator = torch.Generator(device='cuda').manual_seed(seed)\n\npipeline = seamless_tiling(pipeline=pipeline, x_axis=True, y_axis=True)\nimage = pipeline(\n    prompt=prompt,\n    width=512,\n    height=512,\n    num_inference_steps=20,\n    guidance_scale=7,\n    num_images_per_prompt=1,\n    generator=generator\n).images[0]\nseamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)\n\ntorch.cuda.empty_cache()\nimage.save('image.png')\n```\n\n### Prompt Scheduling callback\n\nPrompt scheduling callback allows changing prompts during a generation, like [prompt editing in A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing)\n\n```python\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks\nfrom diffusers.configuration_utils import register_to_config\nimport torch\nfrom typing import Any, Dict, Tuple, Union\n\n\nclass SDPromptSchedulingCallback(PipelineCallback):\n    @register_to_config\n    def __init__(\n        self,\n        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n        cutoff_step_ratio=None,\n        cutoff_step_index=None,\n    ):\n        super().__init__(\n            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index\n        )\n\n    tensor_inputs = [\"prompt_embeds\"]\n\n    def callback_fn(\n        self, pipeline, step_index, timestep, callback_kwargs\n    ) -> dict[str, Any]:\n        cutoff_step_ratio = self.config.cutoff_step_ratio\n        cutoff_step_index = self.config.cutoff_step_index\n        if isinstance(self.config.encoded_prompt, tuple):\n            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt\n        else:\n            prompt_embeds = self.config.encoded_prompt\n\n        # Use cutoff_step_index if it's not None, otherwise use 
cutoff_step_ratio\n        cutoff_step = (\n            cutoff_step_index\n            if cutoff_step_index is not None\n            else int(pipeline.num_timesteps * cutoff_step_ratio)\n        )\n\n        if step_index == cutoff_step:\n            if pipeline.do_classifier_free_guidance:\n                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds\n        return callback_kwargs\n\n\npipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(\n    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n    use_safetensors=True,\n).to(\"cuda\")\npipeline.safety_checker = None\npipeline.requires_safety_checker = False\n\ncallback = MultiPipelineCallbacks(\n    [\n        SDPromptSchedulingCallback(\n            encoded_prompt=pipeline.encode_prompt(\n                prompt=f\"prompt {index}\",\n                negative_prompt=f\"negative prompt {index}\",\n                device=pipeline._execution_device,\n                num_images_per_prompt=1,\n                # pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has run\n                do_classifier_free_guidance=True,\n            ),\n            cutoff_step_index=index,\n        ) for index in range(1, 20)\n    ]\n)\n\nimage = pipeline(\n    prompt=\"prompt\",\n    negative_prompt=\"negative prompt\",\n    callback_on_step_end=callback,\n    callback_on_step_end_tensor_inputs=[\"prompt_embeds\"],\n).images[0]\ntorch.cuda.empty_cache()\nimage.save('image.png')\n```\n\n```python\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks\nfrom diffusers.configuration_utils import register_to_config\nimport torch\nfrom typing import Any, Dict, Tuple, Union\n\n\nclass SDXLPromptSchedulingCallback(PipelineCallback):\n    @register_to_config\n    def __init__(\n        
self,\n        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n        add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n        add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n        cutoff_step_ratio=None,\n        cutoff_step_index=None,\n    ):\n        super().__init__(\n            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index\n        )\n\n    tensor_inputs = [\"prompt_embeds\", \"add_text_embeds\", \"add_time_ids\"]\n\n    def callback_fn(\n        self, pipeline, step_index, timestep, callback_kwargs\n    ) -> dict[str, Any]:\n        cutoff_step_ratio = self.config.cutoff_step_ratio\n        cutoff_step_index = self.config.cutoff_step_index\n        if isinstance(self.config.encoded_prompt, tuple):\n            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt\n        else:\n            prompt_embeds = self.config.encoded_prompt\n        if isinstance(self.config.add_text_embeds, tuple):\n            add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds\n        else:\n            add_text_embeds = self.config.add_text_embeds\n        if isinstance(self.config.add_time_ids, tuple):\n            add_time_ids, negative_add_time_ids = self.config.add_time_ids\n        else:\n            add_time_ids = self.config.add_time_ids\n\n        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio\n        cutoff_step = (\n            cutoff_step_index\n            if cutoff_step_index is not None\n            else int(pipeline.num_timesteps * cutoff_step_ratio)\n        )\n\n        if step_index == cutoff_step:\n            if pipeline.do_classifier_free_guidance:\n                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n                add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])\n                add_time_ids = torch.cat([negative_add_time_ids, 
add_time_ids])\n            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds\n            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds\n            callback_kwargs[self.tensor_inputs[2]] = add_time_ids\n        return callback_kwargs\n\n\npipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\",\n    torch_dtype=torch.float16,\n    variant=\"fp16\",\n    use_safetensors=True,\n).to(\"cuda\")\n\ncallbacks = []\nfor index in range(1, 20):\n    (\n        prompt_embeds,\n        negative_prompt_embeds,\n        pooled_prompt_embeds,\n        negative_pooled_prompt_embeds,\n    ) = pipeline.encode_prompt(\n        prompt=f\"prompt {index}\",\n        negative_prompt=f\"negative prompt {index}\",\n        device=pipeline._execution_device,\n        num_images_per_prompt=1,\n        # pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has run\n        do_classifier_free_guidance=True,\n    )\n    text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n    add_time_ids = pipeline._get_add_time_ids(\n        (1024, 1024),\n        (0, 0),\n        (1024, 1024),\n        dtype=prompt_embeds.dtype,\n        text_encoder_projection_dim=text_encoder_projection_dim,\n    )\n    negative_add_time_ids = pipeline._get_add_time_ids(\n        (1024, 1024),\n        (0, 0),\n        (1024, 1024),\n        dtype=prompt_embeds.dtype,\n        text_encoder_projection_dim=text_encoder_projection_dim,\n    )\n    callbacks.append(\n        SDXLPromptSchedulingCallback(\n            encoded_prompt=(prompt_embeds, negative_prompt_embeds),\n            add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),\n            add_time_ids=(add_time_ids, negative_add_time_ids),\n            cutoff_step_index=index,\n        )\n    )\n\n\ncallback = MultiPipelineCallbacks(callbacks)\n\nimage = pipeline(\n    prompt=\"prompt\",\n    negative_prompt=\"negative 
prompt\",\n    callback_on_step_end=callback,\n    callback_on_step_end_tensor_inputs=[\n        \"prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n    ],\n).images[0]\n```\n"
  },
  {
    "path": "examples/community/adaptive_mask_inpainting.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/\n\nimport inspect\nimport os\nimport shutil\nfrom glob import glob\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport cv2\nimport numpy as np\nimport PIL.Image\nimport requests\nimport torch\nfrom detectron2.config import get_cfg\nfrom detectron2.data import MetadataCatalog\nfrom detectron2.engine import DefaultPredictor\nfrom detectron2.projects import point_rend\nfrom detectron2.structures.instances import Instances\nfrom detectron2.utils.visualizer import ColorMode, Visualizer\nfrom packaging import version\nfrom tqdm import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    is_accelerate_available,\n    is_accelerate_version,\n   
 logging,\n    randn_tensor,\n)\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nAMI_INSTALL_MESSAGE = \"\"\"\n\nExample Demo of Adaptive Mask Inpainting\n\nBeyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models\nKim et al.\nECCV-2024 (Oral)\n\n\nPlease prepare the environment via\n\n```\nconda create --name ami python=3.9 -y\nconda activate ami\n\nconda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y\npython -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html\npip install easydict\npip install diffusers==0.20.2 accelerate safetensors transformers\npip install setuptools==59.5.0\npip install opencv-python\npip install numpy==1.24.1\n```\n\n\nPut the code inside the root of diffusers library (e.g., as '/home/username/diffusers/adaptive_mask_inpainting_example.py') and run the python code.\n\n\n\n\n\"\"\"\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install transformers accelerate\n        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler\n        >>> from diffusers.utils import load_image\n        >>> import numpy as np\n        >>> import torch\n\n        >>> init_image = load_image(\n        ...     \"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png\"\n        ... )\n        >>> init_image = init_image.resize((512, 512))\n\n        >>> generator = torch.Generator(device=\"cpu\").manual_seed(1)\n\n        >>> mask_image = load_image(\n        ...     \"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png\"\n        ... )\n        >>> mask_image = mask_image.resize((512, 512))\n\n\n        >>> def make_inpaint_condition(image, image_mask):\n        ...     
image = np.array(image.convert(\"RGB\")).astype(np.float32) / 255.0\n        ...     image_mask = np.array(image_mask.convert(\"L\")).astype(np.float32) / 255.0\n\n        ...     assert image.shape[0:1] == image_mask.shape[0:1], \"image and image_mask must have the same image size\"\n        ...     image[image_mask > 0.5] = -1.0  # set as masked pixel\n        ...     image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)\n        ...     image = torch.from_numpy(image)\n        ...     return image\n\n\n        >>> control_image = make_inpaint_condition(init_image, mask_image)\n\n        >>> controlnet = ControlNetModel.from_pretrained(\n        ...     \"lllyasviel/control_v11p_sd15_inpaint\", torch_dtype=torch.float16\n        ... )\n        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(\n        ...     \"stable-diffusion-v1-5/stable-diffusion-v1-5\", controlnet=controlnet, torch_dtype=torch.float16\n        ... )\n\n        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> # generate image\n        >>> image = pipe(\n        ...     \"a handsome man with ray-ban sunglasses\",\n        ...     num_inference_steps=20,\n        ...     generator=generator,\n        ...     eta=1.0,\n        ...     image=init_image,\n        ...     mask_image=mask_image,\n        ...     control_image=control_image,\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\ndef download_file(url, output_file, exist_ok: bool):\n    if exist_ok and os.path.exists(output_file):\n        return\n\n    response = requests.get(url, stream=True)\n\n    with open(output_file, \"wb\") as file:\n        for chunk in tqdm(response.iter_content(chunk_size=8192), desc=f\"Downloading '{output_file}'...\"):\n            if chunk:\n                file.write(chunk)\n\n\ndef generate_video_from_imgs(images_save_directory, fps=15.0, delete_dir=True):\n    # delete videos if exists\n    if os.path.exists(f\"{images_save_directory}.mp4\"):\n        os.remove(f\"{images_save_directory}.mp4\")\n    if os.path.exists(f\"{images_save_directory}_before_process.mp4\"):\n        os.remove(f\"{images_save_directory}_before_process.mp4\")\n\n    # assume there are \"enumerated\" images under \"images_save_directory\"\n    assert os.path.isdir(images_save_directory)\n    ImgPaths = sorted(glob(f\"{images_save_directory}/*\"))\n\n    if len(ImgPaths) == 0:\n        print(\"\\tSkipping, since there must be at least one image to create mp4\\n\")\n    else:\n        # mp4 configuration\n        video_path = images_save_directory + \"_before_process.mp4\"\n\n        # Get height and width config\n        images = sorted([ImgPath.split(\"/\")[-1] for ImgPath in ImgPaths if ImgPath.endswith(\".png\")])\n        frame = cv2.imread(os.path.join(images_save_directory, images[0]))\n        height, width, channels = frame.shape\n\n        # create mp4 video writer\n        fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n        video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))\n        for image in images:\n            video.write(cv2.imread(os.path.join(images_save_directory, image)))\n        cv2.destroyAllWindows()\n        video.release()\n\n        # generated video is not compatible with HTML5. 
Post-process and change codec of video, so that it is applicable to HTML.\n        os.system(\n            f'ffmpeg -i \"{images_save_directory}_before_process.mp4\" -vcodec libx264 -f mp4 \"{images_save_directory}.mp4\" '\n        )\n\n    # remove group of images, and remove video before post-process.\n    if delete_dir and os.path.exists(images_save_directory):\n        shutil.rmtree(images_save_directory)\n    # remove 'before-process' video\n    if os.path.exists(f\"{images_save_directory}_before_process.mp4\"):\n        os.remove(f\"{images_save_directory}_before_process.mp4\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image\ndef prepare_mask_and_masked_image(image, mask, height, width, return_image=False):\n    \"\"\"\n    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be\n    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the\n    ``image`` and ``1`` for the ``mask``.\n\n    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be\n    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.\n\n    Args:\n        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``\n            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.\n        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``\n            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.\n\n\n    Raises:\n        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. 
ValueError: ``torch.Tensor`` mask\n        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.\n        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not\n            (or the other way around).\n\n    Returns:\n        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4\n            dimensions: ``batch x channels x height x width``.\n    \"\"\"\n\n    if image is None:\n        raise ValueError(\"`image` input cannot be undefined.\")\n\n    if mask is None:\n        raise ValueError(\"`mask_image` input cannot be undefined.\")\n\n    if isinstance(image, torch.Tensor):\n        if not isinstance(mask, torch.Tensor):\n            raise TypeError(f\"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not\")\n\n        # Batch single image\n        if image.ndim == 3:\n            assert image.shape[0] == 3, \"Image outside a batch should be of shape (3, H, W)\"\n            image = image.unsqueeze(0)\n\n        # Batch and add channel dim for single mask\n        if mask.ndim == 2:\n            mask = mask.unsqueeze(0).unsqueeze(0)\n\n        # Batch single mask or add channel dim\n        if mask.ndim == 3:\n            # Single batched mask, no channel dim or single mask not batched but channel dim\n            if mask.shape[0] == 1:\n                mask = mask.unsqueeze(0)\n\n            # Batched masks no channel dim\n            else:\n                mask = mask.unsqueeze(1)\n\n        assert image.ndim == 4 and mask.ndim == 4, \"Image and Mask must have 4 dimensions\"\n        assert image.shape[-2:] == mask.shape[-2:], \"Image and Mask must have the same spatial dimensions\"\n        assert image.shape[0] == mask.shape[0], \"Image and Mask must have the same batch size\"\n\n        # Check image is in [-1, 1]\n        if image.min() < -1 or image.max() > 1:\n            raise ValueError(\"Image should be in [-1, 1] range\")\n\n        # Check mask is in 
[0, 1]\n        if mask.min() < 0 or mask.max() > 1:\n            raise ValueError(\"Mask should be in [0, 1] range\")\n\n        # Binarize mask\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n\n        # Image as float32\n        image = image.to(dtype=torch.float32)\n    elif isinstance(mask, torch.Tensor):\n        raise TypeError(f\"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not\")\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            # resize all images w.r.t passed height and width\n            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n        # preprocess mask\n        if isinstance(mask, (PIL.Image.Image, np.ndarray)):\n            mask = [mask]\n\n        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):\n            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]\n            mask = np.concatenate([np.array(m.convert(\"L\"))[None, None, :] for m in mask], axis=0)\n            mask = mask.astype(np.float32) / 255.0\n        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):\n            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)\n\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n        mask = torch.from_numpy(mask)\n\n    masked_image = image * (mask < 0.5)\n\n    # n.b. 
ensure backwards compatibility as old function does not return image\n    if return_image:\n        return mask, masked_image, image\n\n    return mask, masked_image\n\n\nclass AdaptiveMaskInpaintPipeline(\n    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin\n):\n    r\"\"\"\n    Pipeline for text-guided image inpainting using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n\n    Args:\n        vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: Union[AutoencoderKL, AsymmetricAutoencoderKL],\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        # safety_checker: StableDiffusionSafetyChecker,\n        safety_checker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        self.register_adaptive_mask_model()\n        self.register_adaptive_mask_settings()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"skip_prk_steps\", True) is False:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration\"\n                \" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make\"\n                \" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to\"\n                \" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face\"\n                \" Hub, it would be very nice if you could open a Pull request for the\"\n                \" `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"skip_prk_steps not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"skip_prk_steps\"] = True\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4\n        if unet is not None and unet.config.in_channels != 9:\n            logger.info(f\"You have loaded a UNet with {unet.config.in_channels} input channels which.\")\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n        \"\"\" Preparation for Adaptive Mask inpainting \"\"\"\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload\n    def enable_model_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offload all models to CPU to reduce memory usage with a low impact on performance. 
Moves one whole model at a\n        time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.\n        Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the\n        iterative execution of the `unet`.\n        \"\"\"\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n            from accelerate import cpu_offload_with_hook\n        else:\n            raise ImportError(\"`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.\")\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        if self.device.type != \"cpu\":\n            self.to(\"cpu\", silence_dtype_warnings=True)\n            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)\n\n        hook = None\n        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:\n            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)\n\n        if self.safety_checker is not None:\n            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)\n\n        # We'll offload the last model manually.\n        self.final_offload_hook = hook\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: 
(`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, LoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: 
{removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size 
{len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text 
embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = 
generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        strength,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should in [0.0, 1.0] but is {strength}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n        image=None,\n        timestep=None,\n        is_strength_max=True,\n        return_noise=False,\n        return_image_latents=False,\n    ):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if (image is None or timestep is None) and not is_strength_max:\n            raise ValueError(\n                \"Since strength < 1. 
initial latents are to be initialised as a combination of Image + Noise. \"\n                \"However, either the image or the noise timestep has not been provided.\"\n            )\n\n        if return_image_latents or (latents is None and not is_strength_max):\n            image = image.to(device=device, dtype=dtype)\n            image_latents = self._encode_vae_image(image=image, generator=generator)\n\n        if latents is None:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # if strength is 1. then initialise the latents to noise, else initialise to image + noise\n            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)\n            # if pure noise then scale the initial latents by the scheduler's init sigma\n            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents\n        else:\n            noise = latents.to(device)\n            latents = noise * self.scheduler.init_noise_sigma\n\n        outputs = (latents,)\n\n        if return_noise:\n            outputs += (noise,)\n\n        if return_image_latents:\n            outputs += (image_latents,)\n\n        return outputs\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        if isinstance(generator, list):\n            image_latents = [\n                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)\n\n        image_latents = self.vae.config.scaling_factor * image_latents\n\n        return image_latents\n\n    def prepare_mask_latents(\n        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance\n    ):\n        # 
resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(\n            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)\n        )\n        mask = mask.to(device=device, dtype=dtype)\n\n        masked_image = masked_image.to(device=device, dtype=dtype)\n        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n        if masked_image_latents.shape[0] < batch_size:\n            if not batch_size % masked_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n        masked_image_latents = (\n            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n        )\n\n        # aligning device to prevent device errors when concating it with the latent model input\n        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n        return mask, masked_image_latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        return timesteps, num_inference_steps - t_start\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[torch.FloatTensor, PIL.Image.Image] = None,\n        default_mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 1.0,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: 
float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        use_adaptive_mask: bool = True,\n        enforce_full_mask_ratio: float = 0.5,\n        human_detection_thres: float = 0.008,\n        visualization_save_dir: str = None,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            image (`PIL.Image.Image`):\n                `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked\n                out with `default_mask_image` and repainted according to `prompt`).\n            default_mask_image (`PIL.Image.Image`):\n                `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted\n                while black pixels are preserved. If `default_mask_image` is a PIL image, it is converted to a single channel\n                (luminance) before use. 
If it's a tensor, it should contain one color channel (L) instead of 3, so the\n                expected shape would be `(B, H, W, 1)`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            strength (`float`, *optional*, defaults to 1.0):\n                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference. This parameter is modulated by `strength`.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that calls every `callback_steps` steps during inference. The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        Examples:\n\n        ```py\n        >>> import PIL\n        >>> import requests\n        >>> import torch\n        >>> from io import BytesIO\n\n        >>> from diffusers import AdaptiveMaskInpaintPipeline\n\n\n        >>> def download_image(url):\n        ...     response = requests.get(url)\n        ...     return PIL.Image.open(BytesIO(response.content)).convert(\"RGB\")\n\n\n        >>> img_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\n        >>> mask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\n        >>> init_image = download_image(img_url).resize((512, 512))\n        >>> default_mask_image = download_image(mask_url).resize((512, 512))\n\n        >>> pipe = AdaptiveMaskInpaintPipeline.from_pretrained(\n        ...     
\"stable-diffusion-v1-5/stable-diffusion-inpainting\", torch_dtype=torch.float16\n        ... )\n        >>> pipe = pipe.to(\"cuda\")\n\n        >>> prompt = \"Face of a yellow cat, high resolution, sitting on a park bench\"\n        >>> image = pipe(prompt=prompt, image=init_image, default_mask_image=default_mask_image).images[0]\n        ```\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n        # 0. Default height and width to unet\n        width, height = image.size\n        # height = height or self.unet.config.sample_size * self.vae_scale_factor\n        # width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            strength,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. set timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps=num_inference_steps, strength=strength, device=device\n        )\n        # check that number of inference steps is not < 1 - as this doesn't make sense\n        if num_inference_steps < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n            )\n        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise\n        is_strength_max = strength == 1.0\n\n        # 5. 
Preprocess mask and image (will be used later, once again)\n        mask, masked_image, init_image = prepare_mask_and_masked_image(\n            image, default_mask_image, height, width, return_image=True\n        )\n        default_mask_image_np = np.array(default_mask_image).astype(np.uint8) / 255\n        mask_condition = mask.clone()\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        num_channels_unet = self.unet.config.in_channels\n        return_image_latents = num_channels_unet == 4\n\n        latents_outputs = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n            image=init_image,\n            timestep=latent_timestep,\n            is_strength_max=is_strength_max,\n            return_noise=True,\n            return_image_latents=return_image_latents,\n        )\n\n        if return_image_latents:\n            latents, noise, image_latents = latents_outputs\n        else:\n            latents, noise = latents_outputs\n\n        # 7. Prepare mask latent variables\n        mask, masked_image_latents = self.prepare_mask_latents(\n            mask,\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance,\n        )\n\n        # 8. 
Check that sizes of mask, masked image and latents match\n        if num_channels_unet == 9:\n            # default case for stable-diffusion-v1-5/stable-diffusion-inpainting\n            num_channels_mask = mask.shape[1]\n            num_channels_masked_image = masked_image_latents.shape[1]\n            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:\n                raise ValueError(\n                    f\"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects\"\n                    f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                    f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                    f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                    \" `pipeline.unet` or your `default_mask_image` or `image` input.\"\n                )\n        elif num_channels_unet != 4:\n            raise ValueError(\n                f\"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.\"\n            )\n\n        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 10. 
Denoising loop\n        mask_image_np = default_mask_image_np\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n                # concat latents, mask, masked_image_latents in the channel dimension\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                if num_channels_unet == 9:\n                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n                else:\n                    raise NotImplementedError\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1 & predicted original sample x_0\n                outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)\n                latents = outputs[\"prev_sample\"]  # x_t-1\n                pred_orig_latents = outputs[\"pred_original_sample\"]  # x_0\n\n                # run segmentation\n                if use_adaptive_mask:\n                    if enforce_full_mask_ratio > 0.0:\n                        
use_default_mask = t < self.scheduler.config.num_train_timesteps * enforce_full_mask_ratio\n                    elif enforce_full_mask_ratio == 0.0:\n                        use_default_mask = False\n                    else:\n                        raise NotImplementedError\n\n                    pred_orig_image = self.decode_to_npuint8_image(pred_orig_latents)\n                    dilate_num = self.adaptive_mask_settings.dilate_scheduler(i)\n                    do_adapt_mask = self.adaptive_mask_settings.provoke_scheduler(i)\n                    if do_adapt_mask:\n                        mask, masked_image_latents, mask_image_np, vis_np = self.adapt_mask(\n                            init_image,\n                            pred_orig_image,\n                            default_mask_image_np,\n                            dilate_num=dilate_num,\n                            use_default_mask=use_default_mask,\n                            height=height,\n                            width=width,\n                            batch_size=batch_size,\n                            num_images_per_prompt=num_images_per_prompt,\n                            prompt_embeds=prompt_embeds,\n                            device=device,\n                            generator=generator,\n                            do_classifier_free_guidance=do_classifier_free_guidance,\n                            i=i,\n                            human_detection_thres=human_detection_thres,\n                            mask_image_np=mask_image_np,\n                        )\n\n                    if self.adaptive_mask_model.use_visualizer:\n                        import matplotlib.pyplot as plt\n\n                        # mask_image_new_colormap = np.clip(0.6 + (1.0 - mask_image_np), a_min=0.0, a_max=1.0) * 255\n\n                        os.makedirs(visualization_save_dir, exist_ok=True)\n\n                        # 
Image.fromarray(mask_image_new_colormap).convert(\"L\").save(f\"{visualization_save_dir}/masks/{i:05}.png\")\n                        plt.axis(\"off\")\n                        plt.subplot(1, 2, 1)\n                        plt.imshow(mask_image_np)\n                        plt.subplot(1, 2, 2)\n                        plt.imshow(pred_orig_image)\n                        plt.savefig(f\"{visualization_save_dir}/{i:05}.png\", bbox_inches=\"tight\")\n                        plt.close(\"all\")\n\n                if num_channels_unet == 4:\n                    init_latents_proper = image_latents[:1]\n                    init_mask = mask[:1]\n\n                    if i < len(timesteps) - 1:\n                        noise_timestep = timesteps[i + 1]\n                        init_latents_proper = self.scheduler.add_noise(\n                            init_latents_proper, noise, torch.tensor([noise_timestep])\n                        )\n\n                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        if not output_type == \"latent\":\n            condition_kwargs = {}\n            if isinstance(self.vae, AsymmetricAutoencoderKL):\n                init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)\n                init_image_condition = init_image.clone()\n                init_image = self._encode_vae_image(init_image, generator=generator)\n                mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)\n                condition_kwargs = {\"image\": init_image_condition, \"mask\": mask_condition}\n            image = self.vae.decode(latents / 
self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if self.adaptive_mask_model.use_visualizer:\n            generate_video_from_imgs(images_save_directory=visualization_save_dir, fps=10, delete_dir=True)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n    def decode_to_npuint8_image(self, latents):\n        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **{})[\n            0\n        ]  # torch, float32, -1.~1.\n        image = self.image_processor.postprocess(image, output_type=\"pt\", do_denormalize=[True] * image.shape[0])\n        image = (image.squeeze().permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)  # np, uint8, 0~255\n        return image\n\n    def register_adaptive_mask_settings(self):\n        from easydict import EasyDict\n\n        num_steps = 50\n\n        step_num = int(num_steps * 0.1)\n        final_step_num = num_steps - step_num * 7\n        # adaptive mask settings\n        self.adaptive_mask_settings = EasyDict(\n            dilate_scheduler=MaskDilateScheduler(\n                max_dilate_num=20,\n                num_inference_steps=num_steps,\n                schedule=[20] * 
step_num\n                + [10] * step_num\n                + [5] * step_num\n                + [4] * step_num\n                + [3] * step_num\n                + [2] * step_num\n                + [1] * step_num\n                + [0] * final_step_num,\n            ),\n            dilate_kernel=np.ones((3, 3), dtype=np.uint8),\n            provoke_scheduler=ProvokeScheduler(\n                num_inference_steps=num_steps,\n                schedule=list(range(2, 10 + 1, 2)) + list(range(12, 40 + 1, 2)) + [45],\n                is_zero_indexing=False,\n            ),\n        )\n\n    def register_adaptive_mask_model(self):\n        # declare segmentation model used for mask adaptation\n        use_visualizer = True\n        # assert not use_visualizer, \\\n        # \"\"\"\n        # If you plan to 'use_visualizer', USE WITH CAUTION.\n        # It creates a directory of images and masks, which is used for merging into a video.\n        # The procedure involves deleting the directory of images, which means that\n        # if you set the directory wrong you can have other important files blown away.\n        # \"\"\"\n\n        self.adaptive_mask_model = PointRendPredictor(\n            # pointrend_thres=0.2,\n            pointrend_thres=0.9,\n            device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n            use_visualizer=use_visualizer,\n            config_pth=\"pointrend_rcnn_R_50_FPN_3x_coco.yaml\",\n            weights_pth=\"model_final_edd263.pkl\",\n        )\n\n    def adapt_mask(self, init_image, pred_orig_image, default_mask_image, dilate_num, use_default_mask, **kwargs):\n        ## predict mask to use for adaptation\n        adapt_output = self.adaptive_mask_model(pred_orig_image)  # vis can be None if 'use_visualizer' is False\n        mask = adapt_output[\"mask\"]\n        vis = adapt_output[\"vis\"]\n\n        ## if mask is empty or too small, use default_mask_image. 
else, use dilate and intersect with default_mask_image\n        if use_default_mask or mask.sum() < 512 * 512 * kwargs[\"human_detection_thres\"]:  # 0.005\n            # set mask as default mask\n            mask = default_mask_image  # HxW\n\n        else:\n            ## timestep-adaptive mask\n            mask = cv2.dilate(\n                mask, self.adaptive_mask_settings.dilate_kernel, iterations=dilate_num\n            )  # dilate_kernel: np.ones((3,3), np.uint8)\n            mask = np.logical_and(mask, default_mask_image)  # HxW\n\n        ## prepare mask as pt tensor format\n        mask = torch.tensor(mask, dtype=torch.float32).to(kwargs[\"device\"])[None, None]  # 1 x 1 x H x W\n        mask, masked_image = prepare_mask_and_masked_image(\n            init_image.to(kwargs[\"device\"]), mask, kwargs[\"height\"], kwargs[\"width\"], return_image=False\n        )\n\n        mask_image_np = mask.clone().squeeze().detach().cpu().numpy()\n\n        mask, masked_image_latents = self.prepare_mask_latents(\n            mask,\n            masked_image,\n            kwargs[\"batch_size\"] * kwargs[\"num_images_per_prompt\"],\n            kwargs[\"height\"],\n            kwargs[\"width\"],\n            kwargs[\"prompt_embeds\"].dtype,\n            kwargs[\"device\"],\n            kwargs[\"generator\"],\n            kwargs[\"do_classifier_free_guidance\"],\n        )\n\n        return mask, masked_image_latents, mask_image_np, vis\n\n\ndef seg2bbox(seg_mask: np.ndarray):\n    nonzero_i, nonzero_j = seg_mask.nonzero()\n    min_i, max_i = nonzero_i.min(), nonzero_i.max()\n    min_j, max_j = nonzero_j.min(), nonzero_j.max()\n\n    return np.array([min_j, min_i, max_j + 1, max_i + 1])\n\n\ndef merge_bbox(bboxes: list):\n    assert len(bboxes) > 0\n\n    all_bboxes = np.stack(bboxes, axis=0)  # shape: N_bbox X 4\n    merged_bbox = np.zeros_like(all_bboxes[0])  # shape: 4,\n\n    merged_bbox[0] = all_bboxes[:, 0].min()\n    merged_bbox[1] = all_bboxes[:, 1].min()\n    
merged_bbox[2] = all_bboxes[:, 2].max()\n    merged_bbox[3] = all_bboxes[:, 3].max()\n\n    return merged_bbox\n\n\nclass PointRendPredictor:\n    def __init__(\n        self,\n        cat_id_to_focus=0,\n        pointrend_thres=0.9,\n        device=\"cuda\",\n        use_visualizer=False,\n        merge_mode=\"merge\",\n        config_pth=None,\n        weights_pth=None,\n    ):\n        super().__init__()\n\n        # category id to focus (default: 0, which is human)\n        self.cat_id_to_focus = cat_id_to_focus\n\n        # setup coco metadata\n        self.coco_metadata = MetadataCatalog.get(\"coco_2017_val\")\n        self.cfg = get_cfg()\n\n        # get segmentation model config\n        point_rend.add_pointrend_config(self.cfg)  # --> Add PointRend-specific config\n        self.cfg.merge_from_file(config_pth)\n        self.cfg.MODEL.WEIGHTS = weights_pth\n        self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = pointrend_thres\n        self.cfg.MODEL.DEVICE = device\n\n        # get segmentation model\n        self.pointrend_seg_model = DefaultPredictor(self.cfg)\n\n        # settings for visualizer\n        self.use_visualizer = use_visualizer\n\n        # mask-merge mode\n        assert merge_mode in [\"merge\", \"max-confidence\"], f\"'merge_mode': {merge_mode} not implemented.\"\n        self.merge_mode = merge_mode\n\n    def merge_mask(self, masks, scores=None):\n        if self.merge_mode == \"merge\":\n            mask = np.any(masks, axis=0)\n        elif self.merge_mode == \"max-confidence\":\n            mask = masks[np.argmax(scores)]\n        return mask\n\n    def vis_seg_on_img(self, image, mask):\n        if type(mask) == np.ndarray:\n            mask = torch.tensor(mask)\n        v = Visualizer(image, self.coco_metadata, scale=0.5, instance_mode=ColorMode.IMAGE_BW)\n        instances = Instances(image_size=image.shape[:2], pred_masks=mask if len(mask.shape) == 3 else mask[None])\n        vis = 
v.draw_instance_predictions(instances.to(\"cpu\")).get_image()\n        return vis\n\n    def __call__(self, image):\n        # run segmentation\n        outputs = self.pointrend_seg_model(image)\n        instances = outputs[\"instances\"]\n\n        # merge instances for the category-id to focus\n        is_class = instances.pred_classes == self.cat_id_to_focus\n        masks = instances.pred_masks[is_class]\n        masks = masks.detach().cpu().numpy()  # [N, img_size, img_size]\n        mask = self.merge_mask(masks, scores=instances.scores[is_class])\n\n        return {\n            \"asset_mask\": None,\n            \"mask\": mask.astype(np.uint8),\n            \"vis\": self.vis_seg_on_img(image, mask) if self.use_visualizer else None,\n        }\n\n\nclass MaskDilateScheduler:\n    def __init__(self, max_dilate_num=15, num_inference_steps=50, schedule=None):\n        super().__init__()\n        self.max_dilate_num = max_dilate_num\n        self.schedule = [num_inference_steps - i for i in range(num_inference_steps)] if schedule is None else schedule\n        assert len(self.schedule) == num_inference_steps\n\n    def __call__(self, i):\n        return min(self.max_dilate_num, self.schedule[i])\n\n\nclass ProvokeScheduler:\n    def __init__(self, num_inference_steps=50, schedule=None, is_zero_indexing=False):\n        super().__init__()\n        if len(schedule) > 0:\n            if is_zero_indexing:\n                assert max(schedule) <= num_inference_steps - 1\n            else:\n                assert max(schedule) <= num_inference_steps\n\n        # register as self\n        self.is_zero_indexing = is_zero_indexing\n        self.schedule = schedule\n\n    def __call__(self, i):\n        if self.is_zero_indexing:\n            return i in self.schedule\n        else:\n            return i + 1 in self.schedule\n"
  },
  {
    "path": "examples/community/bit_diffusion.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nfrom einops import rearrange, reduce\n\nfrom diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DConditionModel\nfrom diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput\nfrom diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput\n\n\nBITS = 8\n\n\n# convert to bit representations and back taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py\ndef decimal_to_bits(x, bits=BITS):\n    \"\"\"expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1\"\"\"\n    device = x.device\n\n    x = (x * 255).int().clamp(0, 255)\n\n    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)\n    mask = rearrange(mask, \"d -> d 1 1\")\n    x = rearrange(x, \"b c h w -> b c 1 h w\")\n\n    bits = ((x & mask) != 0).float()\n    bits = rearrange(bits, \"b c d h w -> b (c d) h w\")\n    bits = bits * 2 - 1\n    return bits\n\n\ndef bits_to_decimal(x, bits=BITS):\n    \"\"\"expects bits from -1 to 1, outputs image tensor from 0 to 1\"\"\"\n    device = x.device\n\n    x = (x > 0).int()\n    mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)\n\n    mask = rearrange(mask, \"d -> d 1 1\")\n    x = rearrange(x, \"b (c d) h w -> b c d h w\", d=8)\n    dec = reduce(x * mask, \"b c d h w -> b c h w\", \"sum\")\n    return (dec / 255).clamp(0.0, 1.0)\n\n\n# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale\ndef ddim_bit_scheduler_step(\n    self,\n    model_output: torch.Tensor,\n    timestep: int,\n    sample: torch.Tensor,\n    eta: float = 0.0,\n    use_clipped_model_output: bool = True,\n    generator=None,\n    return_dict: bool = True,\n) -> Union[DDIMSchedulerOutput, Tuple]:\n    \"\"\"\n    Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion\n    process from the learned model outputs (most often the predicted noise).\n    Args:\n        model_output (`torch.Tensor`): direct output from learned diffusion model.\n        timestep (`int`): current discrete timestep in the diffusion chain.\n        sample (`torch.Tensor`):\n            current instance of sample being created by diffusion process.\n        eta (`float`): weight of noise for added noise in diffusion step.\n        use_clipped_model_output (`bool`):\n            if `True`, re-derive `model_output` from the clipped predicted x_0 before computing the direction term.\n        generator: random number generator.\n        return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class\n    Returns:\n        [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:\n        [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When\n        returning a tuple, the first element is the sample tensor.\n    \"\"\"\n    if self.num_inference_steps is None:\n        raise ValueError(\n            \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n        )\n\n    # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502\n    # Ideally, read the DDIM paper for a detailed understanding\n\n    # Notation (<variable name> -> <name in paper>)\n    # - pred_noise_t -> e_theta(x_t, t)\n    # - pred_original_sample -> f_theta(x_t, t) or x_0\n    # - std_dev_t -> sigma_t\n    # - eta -> η\n    # - pred_sample_direction -> \"direction pointing to x_t\"\n    # - pred_prev_sample -> \"x_t-1\"\n\n    # 1. get previous step value (=t-1)\n    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps\n\n    # 2. compute alphas, betas\n    alpha_prod_t = self.alphas_cumprod[timestep]\n    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n\n    beta_prod_t = 1 - alpha_prod_t\n\n    # 3. 
compute predicted original sample from predicted noise also called\n    # \"predicted x_0\" of formula (12) from https://huggingface.co/papers/2010.02502\n    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n\n    # 4. Clip \"predicted x_0\"\n    scale = self.bit_scale\n    if self.config.clip_sample:\n        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)\n\n    # 5. compute variance: \"sigma_t(η)\" -> see formula (16)\n    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)\n    variance = self._get_variance(timestep, prev_timestep)\n    std_dev_t = eta * variance ** (0.5)\n\n    if use_clipped_model_output:\n        # the model_output is always re-derived from the clipped x_0 in Glide\n        model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)\n\n    # 6. compute \"direction pointing to x_t\" of formula (12) from https://huggingface.co/papers/2010.02502\n    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output\n\n    # 7. 
compute x_t without \"random noise\" of formula (12) from https://huggingface.co/papers/2010.02502\n    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction\n\n    if eta > 0:\n        # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072\n        device = model_output.device if torch.is_tensor(model_output) else \"cpu\"\n        noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)\n        variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise\n\n        prev_sample = prev_sample + variance\n\n    if not return_dict:\n        return (prev_sample,)\n\n    return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)\n\n\ndef ddpm_bit_scheduler_step(\n    self,\n    model_output: torch.Tensor,\n    timestep: int,\n    sample: torch.Tensor,\n    prediction_type=\"epsilon\",\n    generator=None,\n    return_dict: bool = True,\n) -> Union[DDPMSchedulerOutput, Tuple]:\n    \"\"\"\n    Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion\n    process from the learned model outputs (most often the predicted noise).\n    Args:\n        model_output (`torch.Tensor`): direct output from learned diffusion model.\n        timestep (`int`): current discrete timestep in the diffusion chain.\n        sample (`torch.Tensor`):\n            current instance of sample being created by diffusion process.\n        prediction_type (`str`, default `epsilon`):\n            indicates whether the model predicts the noise (`epsilon`), or the samples (`sample`).\n        generator: random number generator.\n        return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class\n    Returns:\n        [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:\n        [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When\n        returning a tuple, the first element is the sample tensor.\n    \"\"\"\n    t = timestep\n\n    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [\"learned\", \"learned_range\"]:\n        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)\n    else:\n        predicted_variance = None\n\n    # 1. compute alphas, betas\n    alpha_prod_t = self.alphas_cumprod[t]\n    alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one\n    beta_prod_t = 1 - alpha_prod_t\n    beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n    # 2. compute predicted original sample from predicted noise also called\n    # \"predicted x_0\" of formula (15) from https://huggingface.co/papers/2006.11239\n    if prediction_type == \"epsilon\":\n        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n    elif prediction_type == \"sample\":\n        pred_original_sample = model_output\n    else:\n        raise ValueError(f\"Unsupported prediction_type {prediction_type}.\")\n\n    # 3. 
Clip \"predicted x_0\"\n    scale = self.bit_scale\n    if self.config.clip_sample:\n        pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)\n\n    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t\n    # See formula (7) from https://huggingface.co/papers/2006.11239\n    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t\n    current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t\n\n    # 5. Compute predicted previous sample µ_t\n    # See formula (7) from https://huggingface.co/papers/2006.11239\n    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample\n\n    # 6. Add noise\n    variance = 0\n    if t > 0:\n        noise = torch.randn(\n            model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator\n        ).to(model_output.device)\n        variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise\n\n    pred_prev_sample = pred_prev_sample + variance\n\n    if not return_dict:\n        return (pred_prev_sample,)\n\n    return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)\n\n\nclass BitDiffusion(DiffusionPipeline):\n    def __init__(\n        self,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, DDPMScheduler],\n        bit_scale: Optional[float] = 1.0,\n    ):\n        super().__init__()\n        self.bit_scale = bit_scale\n        self.scheduler.step = (\n            ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step\n        )\n\n        self.register_modules(unet=unet, scheduler=scheduler)\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        height: Optional[int] = 256,\n        width: Optional[int] = 256,\n        num_inference_steps: Optional[int] = 50,\n        generator: torch.Generator | None 
= None,\n        batch_size: Optional[int] = 1,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        **kwargs,\n    ) -> Union[Tuple, ImagePipelineOutput]:\n        latents = torch.randn(\n            (batch_size, self.unet.config.in_channels, height, width),\n            generator=generator,\n        )\n        latents = decimal_to_bits(latents) * self.bit_scale\n        latents = latents.to(self.device)\n\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        for t in self.progress_bar(self.scheduler.timesteps):\n            # predict the noise residual\n            noise_pred = self.unet(latents, t).sample\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        image = bits_to_decimal(latents)\n\n        if output_type == \"pil\":\n            # numpy_to_pil expects a channels-last NumPy array, not a (b, c, h, w) tensor\n            image = image.cpu().permute(0, 2, 3, 1).numpy()\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image,)\n\n        return ImagePipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/checkpoint_merger.py",
    "content": "import glob\nimport os\nfrom typing import Dict, List, Union\n\nimport safetensors.torch\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom huggingface_hub.utils import validate_hf_hub_args\n\nfrom diffusers import DiffusionPipeline, __version__\nfrom diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME\nfrom diffusers.utils import CONFIG_NAME, ONNX_WEIGHTS_NAME, WEIGHTS_NAME\n\n\nclass CheckpointMergerPipeline(DiffusionPipeline):\n    \"\"\"\n    A class that supports merging diffusion models based on the discussion here:\n    https://github.com/huggingface/diffusers/issues/877\n\n    Example usage:-\n\n    pipe = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", custom_pipeline=\"checkpoint_merger.py\")\n\n    merged_pipe = pipe.merge([\"CompVis/stable-diffusion-v1-4\",\"prompthero/openjourney\"], interp = 'inv_sigmoid', alpha = 0.8, force = True)\n\n    merged_pipe.to('cuda')\n\n    prompt = \"An astronaut riding a unicycle on Mars\"\n\n    results = merged_pipe(prompt)\n\n    ## For more details, see the docstring for the merge method.\n\n    \"\"\"\n\n    def __init__(self):\n        self.register_to_config()\n        super().__init__()\n\n    def _compare_model_configs(self, dict0, dict1):\n        if dict0 == dict1:\n            return True\n        else:\n            config0, meta_keys0 = self._remove_meta_keys(dict0)\n            config1, meta_keys1 = self._remove_meta_keys(dict1)\n            if config0 == config1:\n                print(f\"Warning !: Mismatch in keys {meta_keys0} and {meta_keys1}.\")\n                return True\n        return False\n\n    def _remove_meta_keys(self, config_dict: Dict):\n        meta_keys = []\n        temp_dict = config_dict.copy()\n        for key in config_dict.keys():\n            if key.startswith(\"_\"):\n                temp_dict.pop(key)\n                meta_keys.append(key)\n        return (temp_dict, meta_keys)\n\n    @torch.no_grad()\n    
@validate_hf_hub_args\n    def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):\n        \"\"\"\n        Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints (weights) of the models passed\n        in the argument 'pretrained_model_name_or_path_list' as a list.\n\n        Parameters:\n        -----------\n            pretrained_model_name_or_path_list : A list of valid pretrained model names in the HuggingFace hub or paths to locally stored models in the HuggingFace format.\n\n            **kwargs:\n                Supports all the default DiffusionPipeline.get_config_dict kwargs, viz.\n\n                cache_dir, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.\n\n                alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. With an alpha of 0.8,\n                    the first model checkpoint affects the final result far less than with an alpha of 0.2.\n\n                interp - The interpolation method to use for the merging. Supports \"sigmoid\", \"inv_sigmoid\", \"add_diff\" and None.\n                    Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only \"add_diff\" is supported.\n\n                force - Whether to ignore mismatch in model_index.json for the current models. Defaults to False.\n\n                variant - which variant of a pretrained model to load, e.g. 
\"fp16\" (None)\n\n        \"\"\"\n        # Default kwargs from DiffusionPipeline\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        force_download = kwargs.pop(\"force_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", False)\n        token = kwargs.pop(\"token\", None)\n        variant = kwargs.pop(\"variant\", None)\n        revision = kwargs.pop(\"revision\", None)\n        torch_dtype = kwargs.pop(\"torch_dtype\", torch.float32)\n        device_map = kwargs.pop(\"device_map\", None)\n\n        if not isinstance(torch_dtype, torch.dtype):\n            torch_dtype = torch.float32\n            print(f\"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.\")\n\n        alpha = kwargs.pop(\"alpha\", 0.5)\n        interp = kwargs.pop(\"interp\", None)\n\n        print(\"Received list\", pretrained_model_name_or_path_list)\n        print(f\"Combining with alpha={alpha}, interpolation mode={interp}\")\n\n        checkpoint_count = len(pretrained_model_name_or_path_list)\n        # Ignore result from model_index_json comparison of the two checkpoints\n        force = kwargs.pop(\"force\", False)\n\n        # If less than 2 checkpoints, nothing to merge. If more than 3, not supported for now.\n        if checkpoint_count > 3 or checkpoint_count < 2:\n            raise ValueError(\n                \"Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being\"\n                \" passed.\"\n            )\n\n        print(\"Received the right number of checkpoints\")\n        # chkpt0, chkpt1 = pretrained_model_name_or_path_list[0:2]\n        # chkpt2 = pretrained_model_name_or_path_list[2] if checkpoint_count == 3 else None\n\n        # Validate that the checkpoints can be merged\n        # Step 1: Load the model config and compare the checkpoints. 
We'll compare the model_index.json first while ignoring the keys starting with '_'\n        config_dicts = []\n        for pretrained_model_name_or_path in pretrained_model_name_or_path_list:\n            config_dict = DiffusionPipeline.load_config(\n                pretrained_model_name_or_path,\n                cache_dir=cache_dir,\n                force_download=force_download,\n                proxies=proxies,\n                local_files_only=local_files_only,\n                token=token,\n                revision=revision,\n            )\n            config_dicts.append(config_dict)\n\n        comparison_result = True\n        for idx in range(1, len(config_dicts)):\n            comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])\n            if not force and comparison_result is False:\n                raise ValueError(\"Incompatible checkpoints. Please check model_index.json for the models.\")\n        print(\"Compatible model_index.json files found\")\n        # Step 2: Basic Validation has succeeded. 
Let's download the models and save them into our local files.\n        cached_folders = []\n        for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):\n            folder_names = [k for k in config_dict.keys() if not k.startswith(\"_\")]\n            allow_patterns = [os.path.join(k, \"*\") for k in folder_names]\n            allow_patterns += [\n                WEIGHTS_NAME,\n                SCHEDULER_CONFIG_NAME,\n                CONFIG_NAME,\n                ONNX_WEIGHTS_NAME,\n                DiffusionPipeline.config_name,\n            ]\n            requested_pipeline_class = config_dict.get(\"_class_name\")\n            user_agent = {\"diffusers\": __version__, \"pipeline_class\": requested_pipeline_class}\n\n            cached_folder = (\n                pretrained_model_name_or_path\n                if os.path.isdir(pretrained_model_name_or_path)\n                else snapshot_download(\n                    pretrained_model_name_or_path,\n                    cache_dir=cache_dir,\n                    proxies=proxies,\n                    local_files_only=local_files_only,\n                    token=token,\n                    revision=revision,\n                    allow_patterns=allow_patterns,\n                    user_agent=user_agent,\n                )\n            )\n            print(\"Cached Folder\", cached_folder)\n            cached_folders.append(cached_folder)\n\n        # Step 3:-\n        # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place\n        final_pipe = DiffusionPipeline.from_pretrained(\n            cached_folders[0],\n            torch_dtype=torch_dtype,\n            device_map=device_map,\n            variant=variant,\n        )\n        final_pipe.to(self.device)\n\n        checkpoint_path_2 = None\n        if len(cached_folders) > 2:\n            checkpoint_path_2 = os.path.join(cached_folders[2])\n\n        if interp == \"sigmoid\":\n  
          theta_func = CheckpointMergerPipeline.sigmoid\n        elif interp == \"inv_sigmoid\":\n            theta_func = CheckpointMergerPipeline.inv_sigmoid\n        elif interp == \"add_diff\":\n            theta_func = CheckpointMergerPipeline.add_difference\n        else:\n            theta_func = CheckpointMergerPipeline.weighted_sum\n\n        # Find each module's state dict.\n        for attr in final_pipe.config.keys():\n            if not attr.startswith(\"_\"):\n                checkpoint_path_1 = os.path.join(cached_folders[1], attr)\n                if os.path.exists(checkpoint_path_1):\n                    files = [\n                        *glob.glob(os.path.join(checkpoint_path_1, \"*.safetensors\")),\n                        *glob.glob(os.path.join(checkpoint_path_1, \"*.bin\")),\n                    ]\n                    checkpoint_path_1 = files[0] if len(files) > 0 else None\n                if len(cached_folders) < 3:\n                    checkpoint_path_2 = None\n                else:\n                    checkpoint_path_2 = os.path.join(cached_folders[2], attr)\n                    if os.path.exists(checkpoint_path_2):\n                        files = [\n                            *glob.glob(os.path.join(checkpoint_path_2, \"*.safetensors\")),\n                            *glob.glob(os.path.join(checkpoint_path_2, \"*.bin\")),\n                        ]\n                        checkpoint_path_2 = files[0] if len(files) > 0 else None\n                # For an attr if both checkpoint_path_1 and 2 are None, ignore.\n                # If at least one is present, deal with it according to interp method, of course only if the state_dict keys match.\n                if checkpoint_path_1 is None and checkpoint_path_2 is None:\n                    print(f\"Skipping {attr}: not present in 2nd or 3rd model\")\n                    continue\n                try:\n                    module = getattr(final_pipe, attr)\n                    if 
isinstance(module, bool):  # ignore requires_safety_checker boolean\n                        continue\n                    theta_0 = getattr(module, \"state_dict\")\n                    theta_0 = theta_0()\n\n                    update_theta_0 = getattr(module, \"load_state_dict\")\n                    theta_1 = (\n                        safetensors.torch.load_file(checkpoint_path_1)\n                        if (checkpoint_path_1.endswith(\".safetensors\"))\n                        else torch.load(checkpoint_path_1, map_location=\"cpu\")\n                    )\n                    theta_2 = None\n                    if checkpoint_path_2:\n                        theta_2 = (\n                            safetensors.torch.load_file(checkpoint_path_2)\n                            if (checkpoint_path_2.endswith(\".safetensors\"))\n                            else torch.load(checkpoint_path_2, map_location=\"cpu\")\n                        )\n\n                    if not theta_0.keys() == theta_1.keys():\n                        print(f\"Skipping {attr}: key mismatch\")\n                        continue\n                    if theta_2 and not theta_1.keys() == theta_2.keys():\n                        print(f\"Skipping {attr}: key mismatch\")\n                        continue\n                except Exception as e:\n                    print(f\"Skipping {attr} due to an unexpected error: {str(e)}\")\n                    continue\n                print(f\"MERGING {attr}\")\n\n                for key in theta_0.keys():\n                    if theta_2:\n                        theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)\n                    else:\n                        theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)\n\n                del theta_1\n                del theta_2\n                update_theta_0(theta_0)\n\n                del theta_0\n        return final_pipe\n\n    @staticmethod\n    def weighted_sum(theta0, theta1, theta2, alpha):\n        
return ((1 - alpha) * theta0) + (alpha * theta1)\n\n    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)\n    @staticmethod\n    def sigmoid(theta0, theta1, theta2, alpha):\n        alpha = alpha * alpha * (3 - (2 * alpha))\n        return theta0 + ((theta1 - theta0) * alpha)\n\n    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)\n    @staticmethod\n    def inv_sigmoid(theta0, theta1, theta2, alpha):\n        import math\n\n        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)\n        return theta0 + ((theta1 - theta0) * alpha)\n\n    @staticmethod\n    def add_difference(theta0, theta1, theta2, alpha):\n        return theta0 + (theta1 - theta2) * (1.0 - alpha)\n"
  },
  {
    "path": "examples/community/clip_guided_images_mixing_stable_diffusion.py",
    "content": "# -*- coding: utf-8 -*-\nimport inspect\nfrom typing import Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom torch.nn import functional as F\nfrom torchvision import transforms\nfrom transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.utils import PIL_INTERPOLATION\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\ndef preprocess(image, w, h):\n    if isinstance(image, torch.Tensor):\n        return image\n    elif isinstance(image, PIL.Image.Image):\n        image = [image]\n\n    if isinstance(image[0], PIL.Image.Image):\n        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION[\"lanczos\"]))[None, :] for i in image]\n        image = np.concatenate(image, axis=0)\n        image = np.array(image).astype(np.float32) / 255.0\n        image = image.transpose(0, 3, 1, 2)\n        image = 2.0 * image - 1.0\n        image = torch.from_numpy(image)\n    elif isinstance(image[0], torch.Tensor):\n        image = torch.cat(image, dim=0)\n    return image\n\n\ndef slerp(t, v0, v1, DOT_THRESHOLD=0.9995):\n    if not isinstance(v0, np.ndarray):\n        inputs_are_torch = True\n        input_device = v0.device\n        v0 = v0.cpu().numpy()\n        v1 = v1.cpu().numpy()\n\n    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))\n    if np.abs(dot) > DOT_THRESHOLD:\n        v2 = (1 - t) * v0 + t * v1\n    else:\n        theta_0 = np.arccos(dot)\n        sin_theta_0 = np.sin(theta_0)\n        theta_t = theta_0 * t\n        sin_theta_t = np.sin(theta_t)\n        s0 = np.sin(theta_0 - theta_t) / 
sin_theta_0\n        s1 = sin_theta_t / sin_theta_0\n        v2 = s0 * v0 + s1 * v1\n\n    if inputs_are_torch:\n        v2 = torch.from_numpy(v2).to(input_device)\n\n    return v2\n\n\ndef spherical_dist_loss(x, y):\n    x = F.normalize(x, dim=-1)\n    y = F.normalize(y, dim=-1)\n    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n\n\ndef set_requires_grad(model, value):\n    for param in model.parameters():\n        param.requires_grad = value\n\n\nclass CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline, StableDiffusionMixin):\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        clip_model: CLIPModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],\n        feature_extractor: CLIPImageProcessor,\n        coca_model=None,\n        coca_tokenizer=None,\n        coca_transform=None,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            clip_model=clip_model,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            coca_model=coca_model,\n            coca_tokenizer=coca_tokenizer,\n            coca_transform=coca_transform,\n        )\n        self.feature_extractor_size = (\n            feature_extractor.size\n            if isinstance(feature_extractor.size, int)\n            else feature_extractor.size[\"shortest_edge\"]\n        )\n        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)\n        set_requires_grad(self.text_encoder, False)\n        set_requires_grad(self.clip_model, False)\n\n    def freeze_vae(self):\n        set_requires_grad(self.vae, False)\n\n    def unfreeze_vae(self):\n        
set_requires_grad(self.vae, True)\n\n    def freeze_unet(self):\n        set_requires_grad(self.unet, False)\n\n    def unfreeze_unet(self):\n        set_requires_grad(self.unet, True)\n\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start:]\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):\n        if not isinstance(image, torch.Tensor):\n            raise ValueError(f\"`image` has to be of type `torch.Tensor` but is {type(image)}\")\n\n        image = image.to(device=device, dtype=dtype)\n\n        if isinstance(generator, list):\n            init_latents = [\n                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)\n            ]\n            init_latents = torch.cat(init_latents, dim=0)\n        else:\n            init_latents = self.vae.encode(image).latent_dist.sample(generator)\n\n        # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor\n        init_latents = 0.18215 * init_latents\n        init_latents = init_latents.repeat_interleave(batch_size, dim=0)\n\n        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    def get_image_description(self, image):\n        transformed_image = self.coca_transform(image).unsqueeze(0)\n        with torch.no_grad(), torch.cuda.amp.autocast():\n            generated = self.coca_model.generate(transformed_image.to(device=self.device, 
dtype=self.coca_model.dtype))\n        generated = self.coca_tokenizer.decode(generated[0].cpu().numpy())\n        return generated.split(\"<end_of_text>\")[0].replace(\"<start_of_text>\", \"\").rstrip(\" .,\")\n\n    def get_clip_image_embeddings(self, image, batch_size):\n        clip_image_input = self.feature_extractor.preprocess(image)\n        clip_image_features = torch.from_numpy(clip_image_input[\"pixel_values\"][0]).unsqueeze(0).to(self.device).half()\n        image_embeddings_clip = self.clip_model.get_image_features(clip_image_features)\n        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)\n        image_embeddings_clip = image_embeddings_clip.repeat_interleave(batch_size, dim=0)\n        return image_embeddings_clip\n\n    @torch.enable_grad()\n    def cond_fn(\n        self,\n        latents,\n        timestep,\n        index,\n        text_embeddings,\n        noise_pred_original,\n        original_image_embeddings_clip,\n        clip_guidance_scale,\n    ):\n        latents = latents.detach().requires_grad_()\n\n        latent_model_input = self.scheduler.scale_model_input(latents, timestep)\n\n        # predict the noise residual\n        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample\n\n        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):\n            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]\n            beta_prod_t = 1 - alpha_prod_t\n            # compute predicted original sample from predicted noise also called\n            # \"predicted x_0\" of formula (12) from https://huggingface.co/papers/2010.02502\n            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)\n\n            fac = torch.sqrt(beta_prod_t)\n            sample = pred_original_sample * (fac) + latents * (1 - fac)\n        elif isinstance(self.scheduler, 
LMSDiscreteScheduler):\n            sigma = self.scheduler.sigmas[index]\n            sample = latents - sigma * noise_pred\n        else:\n            raise ValueError(f\"scheduler type {type(self.scheduler)} not supported\")\n\n        # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor\n        sample = 1 / 0.18215 * sample\n        image = self.vae.decode(sample).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        image = transforms.Resize(self.feature_extractor_size)(image)\n        image = self.normalize(image).to(latents.dtype)\n\n        image_embeddings_clip = self.clip_model.get_image_features(image)\n        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)\n\n        loss = spherical_dist_loss(image_embeddings_clip, original_image_embeddings_clip).mean() * clip_guidance_scale\n\n        grads = -torch.autograd.grad(loss, latents)[0]\n\n        if isinstance(self.scheduler, LMSDiscreteScheduler):\n            latents = latents.detach() + grads * (sigma**2)\n            noise_pred = noise_pred_original\n        else:\n            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads\n        return noise_pred, latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        style_image: Union[torch.Tensor, PIL.Image.Image],\n        content_image: Union[torch.Tensor, PIL.Image.Image],\n        style_prompt: str | None = None,\n        content_prompt: str | None = None,\n        height: Optional[int] = 512,\n        width: Optional[int] = 512,\n        noise_strength: float = 0.6,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        batch_size: Optional[int] = 1,\n        eta: float = 0.0,\n        clip_guidance_scale: Optional[float] = 100,\n        generator: torch.Generator | None = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        
slerp_latent_style_strength: float = 0.8,\n        slerp_prompt_style_strength: float = 0.1,\n        slerp_clip_image_style_strength: float = 0.1,\n    ):\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(f\"You have passed {batch_size} batch_size, but only {len(generator)} generators.\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if isinstance(generator, torch.Generator) and batch_size > 1:\n            generator = [generator] + [None] * (batch_size - 1)\n\n        coca_is_none = [\n            (\"model\", self.coca_model is None),\n            (\"tokenizer\", self.coca_tokenizer is None),\n            (\"transform\", self.coca_transform is None),\n        ]\n        coca_is_none = [x[0] for x in coca_is_none if x[1]]\n        coca_is_none_str = \", \".join(coca_is_none)\n        # generate prompts with coca model if prompt is None\n        if content_prompt is None:\n            if len(coca_is_none):\n                raise ValueError(\n                    f\"Content prompt is None and CoCa [{coca_is_none_str}] is None.\"\n                    f\" Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline.\"\n                )\n            content_prompt = self.get_image_description(content_image)\n        if style_prompt is None:\n            if len(coca_is_none):\n                raise ValueError(\n                    f\"Style prompt is None and CoCa [{coca_is_none_str}] is None.\"\n                    f\" Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline.\"\n                )\n            style_prompt = self.get_image_description(style_image)\n\n        # get prompt text embeddings for content and style\n        content_text_input = self.tokenizer(\n            content_prompt,\n            padding=\"max_length\",\n            
max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        content_text_embeddings = self.text_encoder(content_text_input.input_ids.to(self.device))[0]\n\n        style_text_input = self.tokenizer(\n            style_prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        style_text_embeddings = self.text_encoder(style_text_input.input_ids.to(self.device))[0]\n\n        text_embeddings = slerp(slerp_prompt_style_strength, content_text_embeddings, style_text_embeddings)\n\n        # duplicate text embeddings for each generation per prompt\n        text_embeddings = text_embeddings.repeat_interleave(batch_size, dim=0)\n\n        # set timesteps\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n        extra_set_kwargs = {}\n        if accepts_offset:\n            extra_set_kwargs[\"offset\"] = 1\n\n        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        self.scheduler.timesteps.to(self.device)\n\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, noise_strength, self.device)\n        latent_timestep = timesteps[:1].repeat(batch_size)\n\n        # Preprocess image\n        preprocessed_content_image = preprocess(content_image, width, height)\n        content_latents = self.prepare_latents(\n            preprocessed_content_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator\n        )\n\n        preprocessed_style_image = preprocess(style_image, width, height)\n        style_latents = self.prepare_latents(\n            preprocessed_style_image, latent_timestep, batch_size, 
text_embeddings.dtype, self.device, generator\n        )\n\n        latents = slerp(slerp_latent_style_strength, content_latents, style_latents)\n\n        if clip_guidance_scale > 0:\n            content_clip_image_embedding = self.get_clip_image_embeddings(content_image, batch_size)\n            style_clip_image_embedding = self.get_clip_image_embeddings(style_image, batch_size)\n            clip_image_embeddings = slerp(\n                slerp_clip_image_style_strength, content_clip_image_embedding, style_clip_image_embedding\n            )\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            max_length = content_text_input.input_ids.shape[-1]\n            uncond_input = self.tokenizer([\"\"], padding=\"max_length\", max_length=max_length, return_tensors=\"pt\")\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n            # duplicate unconditional embeddings for each generation per prompt\n            uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size, dim=0)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently 
doesn't work in `mps`.\n        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not work reproducibly on mps\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                    self.device\n                )\n            else:\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents.shape != latents_shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents = latents.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = 
torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n                # perform classifier free guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # perform clip guidance\n                if clip_guidance_scale > 0:\n                    text_embeddings_for_guidance = (\n                        text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings\n                    )\n                    noise_pred, latents = self.cond_fn(\n                        latents,\n                        t,\n                        i,\n                        text_embeddings_for_guidance,\n                        noise_pred,\n                        clip_image_embeddings,\n                        clip_guidance_scale,\n                    )\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                progress_bar.update()\n        # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, None)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)\n"
  },
  {
    "path": "examples/community/clip_guided_stable_diffusion.py",
    "content": "import inspect\nfrom typing import List, Optional, Union\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torchvision import transforms\nfrom transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\n\n\nclass MakeCutouts(nn.Module):\n    def __init__(self, cut_size, cut_power=1.0):\n        super().__init__()\n\n        self.cut_size = cut_size\n        self.cut_power = cut_power\n\n    def forward(self, pixel_values, num_cutouts):\n        sideY, sideX = pixel_values.shape[2:4]\n        max_size = min(sideX, sideY)\n        min_size = min(sideX, sideY, self.cut_size)\n        cutouts = []\n        for _ in range(num_cutouts):\n            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)\n            offsetx = torch.randint(0, sideX - size + 1, ())\n            offsety = torch.randint(0, sideY - size + 1, ())\n            cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]\n            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))\n        return torch.cat(cutouts)\n\n\ndef spherical_dist_loss(x, y):\n    x = F.normalize(x, dim=-1)\n    y = F.normalize(y, dim=-1)\n    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n\n\ndef set_requires_grad(model, value):\n    for param in model.parameters():\n        param.requires_grad = value\n\n\nclass CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):\n    \"\"\"CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000\n    - https://github.com/Jack000/glid-3-xl\n    
- https://github.dev/crowsonkb/k-diffusion\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        clip_model: CLIPModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            clip_model=clip_model,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n        )\n\n        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)\n        self.cut_out_size = (\n            feature_extractor.size\n            if isinstance(feature_extractor.size, int)\n            else feature_extractor.size[\"shortest_edge\"]\n        )\n        self.make_cutouts = MakeCutouts(self.cut_out_size)\n\n        set_requires_grad(self.text_encoder, False)\n        set_requires_grad(self.clip_model, False)\n\n    def freeze_vae(self):\n        set_requires_grad(self.vae, False)\n\n    def unfreeze_vae(self):\n        set_requires_grad(self.vae, True)\n\n    def freeze_unet(self):\n        set_requires_grad(self.unet, False)\n\n    def unfreeze_unet(self):\n        set_requires_grad(self.unet, True)\n\n    @torch.enable_grad()\n    def cond_fn(\n        self,\n        latents,\n        timestep,\n        index,\n        text_embeddings,\n        noise_pred_original,\n        text_embeddings_clip,\n        clip_guidance_scale,\n        num_cutouts,\n        use_cutouts=True,\n    ):\n        latents = latents.detach().requires_grad_()\n\n        latent_model_input = self.scheduler.scale_model_input(latents, timestep)\n\n        # predict the noise residual\n        noise_pred = 
self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample\n\n        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):\n            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]\n            beta_prod_t = 1 - alpha_prod_t\n            # compute predicted original sample from predicted noise also called\n            # \"predicted x_0\" of formula (12) from https://huggingface.co/papers/2010.02502\n            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)\n\n            fac = torch.sqrt(beta_prod_t)\n            sample = pred_original_sample * (fac) + latents * (1 - fac)\n        elif isinstance(self.scheduler, LMSDiscreteScheduler):\n            sigma = self.scheduler.sigmas[index]\n            sample = latents - sigma * noise_pred\n        else:\n            raise ValueError(f\"scheduler type {type(self.scheduler)} not supported\")\n\n        sample = 1 / self.vae.config.scaling_factor * sample\n        image = self.vae.decode(sample).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        if use_cutouts:\n            image = self.make_cutouts(image, num_cutouts)\n        else:\n            image = transforms.Resize(self.cut_out_size)(image)\n        image = self.normalize(image).to(latents.dtype)\n\n        image_embeddings_clip = self.clip_model.get_image_features(image)\n        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)\n\n        if use_cutouts:\n            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)\n            dists = dists.view([num_cutouts, sample.shape[0], -1])\n            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale\n        else:\n            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale\n\n        grads = -torch.autograd.grad(loss, latents)[0]\n\n        if 
isinstance(self.scheduler, LMSDiscreteScheduler):\n            latents = latents.detach() + grads * (sigma**2)\n            noise_pred = noise_pred_original\n        else:\n            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads\n        return noise_pred, latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: Optional[int] = 512,\n        width: Optional[int] = 512,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        clip_guidance_scale: Optional[float] = 100,\n        clip_prompt: Optional[Union[str, List[str]]] = None,\n        num_cutouts: Optional[int] = 4,\n        use_cutouts: Optional[bool] = True,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n    ):\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        # get prompt text embeddings\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]\n        # duplicate text embeddings for each generation per prompt\n        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)\n\n        if clip_guidance_scale > 
0:\n            if clip_prompt is not None:\n                clip_text_input = self.tokenizer(\n                    clip_prompt,\n                    padding=\"max_length\",\n                    max_length=self.tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).input_ids.to(self.device)\n            else:\n                clip_text_input = text_input.input_ids.to(self.device)\n            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)\n            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)\n            # duplicate text embeddings clip for each generation per prompt\n            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)\n\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            max_length = text_input.input_ids.shape[-1]\n            uncond_input = self.tokenizer([\"\"], padding=\"max_length\", max_length=max_length, return_tensors=\"pt\")\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n            # duplicate unconditional embeddings for each generation per prompt\n            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not work reproducibly on mps\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                    self.device\n                )\n            else:\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents.shape != latents_shape:\n                raise 
ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents = latents.to(self.device)\n\n        # set timesteps\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n        extra_set_kwargs = {}\n        if accepts_offset:\n            extra_set_kwargs[\"offset\"] = 1\n\n        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n\n        for i, t in enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n         
   noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform classifier free guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # perform clip guidance\n            if clip_guidance_scale > 0:\n                text_embeddings_for_guidance = (\n                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings\n                )\n                noise_pred, latents = self.cond_fn(\n                    latents,\n                    t,\n                    i,\n                    text_embeddings_for_guidance,\n                    noise_pred,\n                    text_embeddings_clip,\n                    clip_guidance_scale,\n                    num_cutouts,\n                    use_cutouts,\n                )\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n        # scale and decode the image latents with vae\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, None)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)\n"
  },
  {
    "path": "examples/community/clip_guided_stable_diffusion_img2img.py",
    "content": "import inspect\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torchvision import transforms\nfrom transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.utils import PIL_INTERPOLATION, deprecate\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        from io import BytesIO\n\n        import requests\n        import torch\n        from diffusers import DiffusionPipeline\n        from PIL import Image\n        from transformers import CLIPImageProcessor, CLIPModel\n\n        feature_extractor = CLIPImageProcessor.from_pretrained(\n            \"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\"\n        )\n        clip_model = CLIPModel.from_pretrained(\n            \"laion/CLIP-ViT-B-32-laion2B-s34B-b79K\", torch_dtype=torch.float16\n        )\n\n\n        guided_pipeline = DiffusionPipeline.from_pretrained(\n            \"CompVis/stable-diffusion-v1-4\",\n            # custom_pipeline=\"clip_guided_stable_diffusion\",\n            custom_pipeline=\"/home/njindal/diffusers/examples/community/clip_guided_stable_diffusion.py\",\n            clip_model=clip_model,\n            feature_extractor=feature_extractor,\n            torch_dtype=torch.float16,\n        )\n        guided_pipeline.enable_attention_slicing()\n        guided_pipeline = guided_pipeline.to(\"cuda\")\n\n        prompt = \"fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy 
magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece\"\n\n        url = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n\n        response = requests.get(url)\n        init_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n\n        image = guided_pipeline(\n            prompt=prompt,\n            num_inference_steps=30,\n            image=init_image,\n            strength=0.75,\n            guidance_scale=7.5,\n            clip_guidance_scale=100,\n            num_cutouts=4,\n            use_cutouts=False,\n        ).images[0]\n        display(image)\n        ```\n\"\"\"\n\n\ndef preprocess(image, w, h):\n    if isinstance(image, torch.Tensor):\n        return image\n    elif isinstance(image, PIL.Image.Image):\n        image = [image]\n\n    if isinstance(image[0], PIL.Image.Image):\n        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION[\"lanczos\"]))[None, :] for i in image]\n        image = np.concatenate(image, axis=0)\n        image = np.array(image).astype(np.float32) / 255.0\n        image = image.transpose(0, 3, 1, 2)\n        image = 2.0 * image - 1.0\n        image = torch.from_numpy(image)\n    elif isinstance(image[0], torch.Tensor):\n        image = torch.cat(image, dim=0)\n    return image\n\n\nclass MakeCutouts(nn.Module):\n    def __init__(self, cut_size, cut_power=1.0):\n        super().__init__()\n\n        self.cut_size = cut_size\n        self.cut_power = cut_power\n\n    def forward(self, pixel_values, num_cutouts):\n        sideY, sideX = pixel_values.shape[2:4]\n        max_size = min(sideX, sideY)\n        min_size = min(sideX, sideY, self.cut_size)\n        cutouts = []\n        for _ in range(num_cutouts):\n            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)\n     
       offsetx = torch.randint(0, sideX - size + 1, ())\n            offsety = torch.randint(0, sideY - size + 1, ())\n            cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]\n            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))\n        return torch.cat(cutouts)\n\n\ndef spherical_dist_loss(x, y):\n    x = F.normalize(x, dim=-1)\n    y = F.normalize(y, dim=-1)\n    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)\n\n\ndef set_requires_grad(model, value):\n    for param in model.parameters():\n        param.requires_grad = value\n\n\nclass CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):\n    \"\"\"CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000\n    - https://github.com/Jack000/glid-3-xl\n    - https://github.dev/crowsonkb/k-diffusion\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        clip_model: CLIPModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            clip_model=clip_model,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n        )\n\n        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)\n        self.cut_out_size = (\n            feature_extractor.size\n            if isinstance(feature_extractor.size, int)\n            else feature_extractor.size[\"shortest_edge\"]\n        )\n        self.make_cutouts = MakeCutouts(self.cut_out_size)\n\n        set_requires_grad(self.text_encoder, False)\n        
set_requires_grad(self.clip_model, False)\n\n    def freeze_vae(self):\n        set_requires_grad(self.vae, False)\n\n    def unfreeze_vae(self):\n        set_requires_grad(self.vae, True)\n\n    def freeze_unet(self):\n        set_requires_grad(self.unet, False)\n\n    def unfreeze_unet(self):\n        set_requires_grad(self.unet, True)\n\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start:]\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if isinstance(generator, list):\n            init_latents = [\n                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)\n            ]\n            init_latents = torch.cat(init_latents, dim=0)\n        else:\n            init_latents = self.vae.encode(image).latent_dist.sample(generator)\n\n        init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            deprecation_message = (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    @torch.enable_grad()\n    def cond_fn(\n        self,\n        latents,\n        timestep,\n        index,\n        text_embeddings,\n        noise_pred_original,\n        text_embeddings_clip,\n        clip_guidance_scale,\n        num_cutouts,\n        use_cutouts=True,\n    ):\n        latents = latents.detach().requires_grad_()\n\n        latent_model_input = self.scheduler.scale_model_input(latents, timestep)\n\n        # predict the noise residual\n        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample\n\n        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):\n            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]\n            beta_prod_t = 1 - alpha_prod_t\n            # compute predicted original sample from predicted noise also called\n            # \"predicted x_0\" of formula (12) from 
https://huggingface.co/papers/2010.02502\n            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)\n\n            fac = torch.sqrt(beta_prod_t)\n            sample = pred_original_sample * (fac) + latents * (1 - fac)\n        elif isinstance(self.scheduler, LMSDiscreteScheduler):\n            sigma = self.scheduler.sigmas[index]\n            sample = latents - sigma * noise_pred\n        else:\n            raise ValueError(f\"scheduler type {type(self.scheduler)} not supported\")\n\n        sample = 1 / self.vae.config.scaling_factor * sample\n        image = self.vae.decode(sample).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        if use_cutouts:\n            image = self.make_cutouts(image, num_cutouts)\n        else:\n            image = transforms.Resize(self.cut_out_size)(image)\n        image = self.normalize(image).to(latents.dtype)\n\n        image_embeddings_clip = self.clip_model.get_image_features(image)\n        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)\n\n        if use_cutouts:\n            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)\n            dists = dists.view([num_cutouts, sample.shape[0], -1])\n            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale\n        else:\n            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale\n\n        grads = -torch.autograd.grad(loss, latents)[0]\n\n        if isinstance(self.scheduler, LMSDiscreteScheduler):\n            latents = latents.detach() + grads * (sigma**2)\n            noise_pred = noise_pred_original\n        else:\n            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads\n        return noise_pred, latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: Optional[int] = 512,\n        width: 
Optional[int] = 512,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        clip_guidance_scale: Optional[float] = 100,\n        clip_prompt: Optional[Union[str, List[str]]] = None,\n        num_cutouts: Optional[int] = 4,\n        use_cutouts: Optional[bool] = True,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n    ):\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        # get prompt text embeddings\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]\n        # duplicate text embeddings for each generation per prompt\n        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)\n\n        # set timesteps\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n        extra_set_kwargs = {}\n        if accepts_offset:\n            extra_set_kwargs[\"offset\"] = 1\n\n        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n        # Some schedulers like PNDM have timesteps as 
arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        # (`Tensor.to` is not in-place, so the result has to be assigned back)\n        self.scheduler.timesteps = self.scheduler.timesteps.to(self.device)\n\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # Preprocess image\n        image = preprocess(image, width, height)\n        if latents is None:\n            latents = self.prepare_latents(\n                image,\n                latent_timestep,\n                batch_size,\n                num_images_per_prompt,\n                text_embeddings.dtype,\n                self.device,\n                generator,\n            )\n\n        if clip_guidance_scale > 0:\n            if clip_prompt is not None:\n                clip_text_input = self.tokenizer(\n                    clip_prompt,\n                    padding=\"max_length\",\n                    max_length=self.tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).input_ids.to(self.device)\n            else:\n                clip_text_input = text_input.input_ids.to(self.device)\n            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)\n            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)\n            # duplicate text embeddings clip for each generation per prompt\n            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)\n\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            max_length = text_input.input_ids.shape[-1]\n            uncond_input = self.tokenizer([\"\"], padding=\"max_length\", max_length=max_length, return_tensors=\"pt\")\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n            # duplicate unconditional embeddings for each generation per prompt\n            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not work reproducibly on mps\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                    self.device\n                )\n            else:\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents.shape != latents_shape:\n                raise 
ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents = latents.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n\n        with self.progress_bar(total=num_inference_steps):\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n                # perform classifier free guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # perform clip guidance\n                if clip_guidance_scale > 0:\n       
             text_embeddings_for_guidance = (\n                        text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings\n                    )\n                    noise_pred, latents = self.cond_fn(\n                        latents,\n                        t,\n                        i,\n                        text_embeddings_for_guidance,\n                        noise_pred,\n                        text_embeddings_clip,\n                        clip_guidance_scale,\n                        num_cutouts,\n                        use_cutouts,\n                    )\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n        # scale and decode the image latents with vae\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, None)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)\n"
  },
  {
    "path": "examples/community/cogvideox_ddim_inversion.py",
    "content": "\"\"\"\nThis script performs DDIM inversion for video frames using a pre-trained model and generates\na video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to\nprocess video frames, apply the DDIM inverse scheduler, and produce an output video.\n\n**Please notice that this script is based on the CogVideoX 5B model, and would not generate\na good result for 2B variants.**\n\nUsage:\n    python cogvideox_ddim_inversion.py\n        --model-path /path/to/model\n        --prompt \"a prompt\"\n        --video-path /path/to/video.mp4\n        --output-path /path/to/output\n\nFor more details about the cli arguments, please run `python cogvideox_ddim_inversion.py --help`.\n\nAuthor:\n    LittleNyima <littlenyima[at]163[dot]com>\n\"\"\"\n\nimport argparse\nimport math\nimport os\nfrom typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast\n\nimport torch\nimport torch.nn.functional as F\nimport torchvision.transforms as T\nfrom transformers import T5EncoderModel, T5Tokenizer\n\nfrom diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0\nfrom diffusers.models.autoencoders import AutoencoderKLCogVideoX\nfrom diffusers.models.embeddings import apply_rotary_emb\nfrom diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel\nfrom diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps\nfrom diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler\nfrom diffusers.utils import export_to_video\n\n\n# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.\n# Very few bug reports but it happens. 
Look in decord Github issues for more relevant information.\nimport decord  # isort: skip\n\n\nclass DDIMInversionArguments(TypedDict):\n    model_path: str\n    prompt: str\n    video_path: str\n    output_path: str\n    guidance_scale: float\n    num_inference_steps: int\n    skip_frames_start: int\n    skip_frames_end: int\n    frame_sample_step: Optional[int]\n    max_num_frames: int\n    width: int\n    height: int\n    fps: int\n    dtype: torch.dtype\n    seed: int\n    device: torch.device\n\n\ndef get_args() -> DDIMInversionArguments:\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--model_path\", type=str, required=True, help=\"Path of the pretrained model\")\n    parser.add_argument(\"--prompt\", type=str, required=True, help=\"Prompt for the direct sample procedure\")\n    parser.add_argument(\"--video_path\", type=str, required=True, help=\"Path of the video for inversion\")\n    parser.add_argument(\"--output_path\", type=str, default=\"output\", help=\"Path of the output videos\")\n    parser.add_argument(\"--guidance_scale\", type=float, default=6.0, help=\"Classifier-free guidance scale\")\n    parser.add_argument(\"--num_inference_steps\", type=int, default=50, help=\"Number of inference steps\")\n    parser.add_argument(\"--skip_frames_start\", type=int, default=0, help=\"Number of skipped frames from the start\")\n    parser.add_argument(\"--skip_frames_end\", type=int, default=0, help=\"Number of skipped frames from the end\")\n    parser.add_argument(\"--frame_sample_step\", type=int, default=None, help=\"Temporal stride of the sampled frames\")\n    parser.add_argument(\"--max_num_frames\", type=int, default=81, help=\"Max number of sampled frames\")\n    parser.add_argument(\"--width\", type=int, default=720, help=\"Resized width of the video frames\")\n    parser.add_argument(\"--height\", type=int, default=480, help=\"Resized height of the video frames\")\n    parser.add_argument(\"--fps\", type=int, default=8, 
help=\"Frame rate of the output videos\")\n    parser.add_argument(\"--dtype\", type=str, default=\"bf16\", choices=[\"bf16\", \"fp16\"], help=\"Dtype of the model\")\n    parser.add_argument(\"--seed\", type=int, default=42, help=\"Seed for the random number generator\")\n    parser.add_argument(\"--device\", type=str, default=\"cuda\", choices=[\"cuda\", \"cpu\"], help=\"Device for inference\")\n\n    args = parser.parse_args()\n    args.dtype = torch.bfloat16 if args.dtype == \"bf16\" else torch.float16\n    args.device = torch.device(args.device)\n\n    return DDIMInversionArguments(**vars(args))\n\n\nclass CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):\n    def __init__(self):\n        super().__init__()\n\n    def calculate_attention(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn: Attention,\n        batch_size: int,\n        image_seq_length: int,\n        text_seq_length: int,\n        attention_mask: Optional[torch.Tensor],\n        image_rotary_emb: Optional[torch.Tensor],\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        r\"\"\"\n        Core attention computation with inversion-guided RoPE integration.\n\n        Args:\n            query (`torch.Tensor`): `[batch_size, seq_len, dim]` query tensor\n            key (`torch.Tensor`): `[batch_size, seq_len, dim]` key tensor\n            value (`torch.Tensor`): `[batch_size, seq_len, dim]` value tensor\n            attn (`Attention`): Parent attention module with projection layers\n            batch_size (`int`): Effective batch size (after chunk splitting)\n            image_seq_length (`int`): Length of image feature sequence\n            text_seq_length (`int`): Length of text feature sequence\n            attention_mask (`Optional[torch.Tensor]`): Attention mask tensor\n            image_rotary_emb (`Optional[torch.Tensor]`): Rotary embeddings for image positions\n\n        Returns:\n            
`Tuple[torch.Tensor, torch.Tensor]`:\n                (1) hidden_states: [batch_size, image_seq_length, dim] processed image features\n                (2) encoder_hidden_states: [batch_size, text_seq_length, dim] processed text features\n        \"\"\"\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        if attn.norm_q is not None:\n            query = attn.norm_q(query)\n        if attn.norm_k is not None:\n            key = attn.norm_k(key)\n\n        # Apply RoPE if needed\n        if image_rotary_emb is not None:\n            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)\n            if not attn.is_cross_attention:\n                if key.size(2) == query.size(2):  # Attention for reference hidden states\n                    key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)\n                else:  # RoPE should be applied to each group of image tokens\n                    key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb(\n                        key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb\n                    )\n                    key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(\n                        key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb\n                    )\n\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n\n        # linear proj\n        
hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        encoder_hidden_states, hidden_states = hidden_states.split(\n            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1\n        )\n        return hidden_states, encoder_hidden_states\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        image_rotary_emb: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        r\"\"\"\n        Process the dual-path attention for the inversion-guided denoising procedure.\n\n        Args:\n            attn (`Attention`): Parent attention module\n            hidden_states (`torch.Tensor`): `[batch_size, image_seq_len, dim]` Image tokens\n            encoder_hidden_states (`torch.Tensor`): `[batch_size, text_seq_len, dim]` Text tokens\n            attention_mask (`Optional[torch.Tensor]`): Optional attention mask\n            image_rotary_emb (`Optional[torch.Tensor]`): Rotary embeddings for image tokens\n\n        Returns:\n            `Tuple[torch.Tensor, torch.Tensor]`:\n                (1) Final hidden states: `[batch_size, image_seq_length, dim]` Resulting image tokens\n                (2) Final encoder states: `[batch_size, text_seq_length, dim]` Resulting text tokens\n        \"\"\"\n        image_seq_length = hidden_states.size(1)\n        text_seq_length = encoder_hidden_states.size(1)\n\n        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            attention_mask = 
attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        query = attn.to_q(hidden_states)\n        key = attn.to_k(hidden_states)\n        value = attn.to_v(hidden_states)\n\n        query, query_reference = query.chunk(2)\n        key, key_reference = key.chunk(2)\n        value, value_reference = value.chunk(2)\n        batch_size = batch_size // 2\n\n        hidden_states, encoder_hidden_states = self.calculate_attention(\n            query=query,\n            key=torch.cat((key, key_reference), dim=1),\n            value=torch.cat((value, value_reference), dim=1),\n            attn=attn,\n            batch_size=batch_size,\n            image_seq_length=image_seq_length,\n            text_seq_length=text_seq_length,\n            attention_mask=attention_mask,\n            image_rotary_emb=image_rotary_emb,\n        )\n        hidden_states_reference, encoder_hidden_states_reference = self.calculate_attention(\n            query=query_reference,\n            key=key_reference,\n            value=value_reference,\n            attn=attn,\n            batch_size=batch_size,\n            image_seq_length=image_seq_length,\n            text_seq_length=text_seq_length,\n            attention_mask=attention_mask,\n            image_rotary_emb=image_rotary_emb,\n        )\n\n        return (\n            torch.cat((hidden_states, hidden_states_reference)),\n            torch.cat((encoder_hidden_states, encoder_hidden_states_reference)),\n        )\n\n\nclass OverrideAttnProcessors:\n    r\"\"\"\n    Context manager for temporarily overriding attention processors in CogVideo transformer blocks.\n\n    Designed for DDIM inversion process, replaces original attention processors with\n    `CogVideoXAttnProcessor2_0ForDDIMInversion` and restores them upon exit. 
Uses Python context manager\n    pattern to safely manage processor replacement.\n\n    Typical usage:\n    ```python\n    with OverrideAttnProcessors(transformer):\n        # Perform DDIM inversion operations\n    ```\n\n    Args:\n        transformer (`CogVideoXTransformer3DModel`):\n            The transformer model containing attention blocks to be modified. Should have\n            `transformer_blocks` attribute containing `CogVideoXBlock` instances.\n    \"\"\"\n\n    def __init__(self, transformer: CogVideoXTransformer3DModel):\n        self.transformer = transformer\n        self.original_processors = {}\n\n    def __enter__(self):\n        for block in self.transformer.transformer_blocks:\n            block = cast(CogVideoXBlock, block)\n            self.original_processors[id(block)] = block.attn1.get_processor()\n            block.attn1.set_processor(CogVideoXAttnProcessor2_0ForDDIMInversion())\n\n    def __exit__(self, _0, _1, _2):\n        for block in self.transformer.transformer_blocks:\n            block = cast(CogVideoXBlock, block)\n            block.attn1.set_processor(self.original_processors[id(block)])\n\n\ndef get_video_frames(\n    video_path: str,\n    width: int,\n    height: int,\n    skip_frames_start: int,\n    skip_frames_end: int,\n    max_num_frames: int,\n    frame_sample_step: Optional[int],\n) -> torch.FloatTensor:\n    \"\"\"\n    Extract and preprocess video frames from a video file for VAE processing.\n\n    Args:\n        video_path (`str`): Path to input video file\n        width (`int`): Target frame width for decoding\n        height (`int`): Target frame height for decoding\n        skip_frames_start (`int`): Number of frames to skip at video start\n        skip_frames_end (`int`): Number of frames to skip at video end\n        max_num_frames (`int`): Maximum allowed number of output frames\n        frame_sample_step (`Optional[int]`):\n            Frame sampling step size. 
If None, automatically calculated as:\n            (total_frames - skipped_frames) // max_num_frames\n\n    Returns:\n        `torch.FloatTensor`: Preprocessed frames in `[F, C, H, W]` format where:\n        - `F`: Number of frames (adjusted to 4k + 1 for VAE compatibility)\n        - `C`: Channels (3 for RGB)\n        - `H`: Frame height\n        - `W`: Frame width\n    \"\"\"\n    with decord.bridge.use_torch():\n        video_reader = decord.VideoReader(uri=video_path, width=width, height=height)\n        video_num_frames = len(video_reader)\n        start_frame = min(skip_frames_start, video_num_frames)\n        end_frame = max(0, video_num_frames - skip_frames_end)\n\n        if end_frame <= start_frame:\n            indices = [start_frame]\n        elif end_frame - start_frame <= max_num_frames:\n            indices = list(range(start_frame, end_frame))\n        else:\n            step = frame_sample_step or (end_frame - start_frame) // max_num_frames\n            indices = list(range(start_frame, end_frame, step))\n\n        frames = video_reader.get_batch(indices=indices)\n        frames = frames[:max_num_frames].float()  # ensure that we don't go over the limit\n\n        # Choose first (4k + 1) frames as this is how many is required by the VAE\n        selected_num_frames = frames.size(0)\n        remainder = (3 + selected_num_frames) % 4\n        if remainder != 0:\n            frames = frames[:-remainder]\n        assert frames.size(0) % 4 == 1\n\n        # Normalize the frames\n        transform = T.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)\n        frames = torch.stack(tuple(map(transform, frames)), dim=0)\n\n        return frames.permute(0, 3, 1, 2).contiguous()  # [F, C, H, W]\n\n\nclass CogVideoXDDIMInversionOutput:\n    inverse_latents: torch.FloatTensor\n    recon_latents: torch.FloatTensor\n\n    def __init__(self, inverse_latents: torch.FloatTensor, recon_latents: torch.FloatTensor):\n        self.inverse_latents = inverse_latents\n        
self.recon_latents = recon_latents\n\n\nclass CogVideoXPipelineForDDIMInversion(CogVideoXPipeline):\n    def __init__(\n        self,\n        tokenizer: T5Tokenizer,\n        text_encoder: T5EncoderModel,\n        vae: AutoencoderKLCogVideoX,\n        transformer: CogVideoXTransformer3DModel,\n        scheduler: CogVideoXDDIMScheduler,\n    ):\n        super().__init__(\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            vae=vae,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n        self.inverse_scheduler = DDIMInverseScheduler(**scheduler.config)\n\n    def encode_video_frames(self, video_frames: torch.FloatTensor) -> torch.FloatTensor:\n        \"\"\"\n        Encode video frames into latent space using the Variational Autoencoder (VAE).\n\n        Args:\n            video_frames (`torch.FloatTensor`):\n                Input frames tensor in `[F, C, H, W]` format from `get_video_frames()`\n\n        Returns:\n            `torch.FloatTensor`: Encoded latents in `[1, F_latent, D, H_latent, W_latent]` format where:\n            - `F_latent`: Number of latent frames (the VAE compresses the temporal axis, so fewer than the input)\n            - `D`: Latent channel dimension\n            - `H_latent`: Latent space height (`H // self.vae_scale_factor_spatial`)\n            - `W_latent`: Latent space width (`W // self.vae_scale_factor_spatial`)\n        \"\"\"\n        vae: AutoencoderKLCogVideoX = self.vae\n        video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)\n        video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]\n        latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)\n        return latent_dist * vae.config.scaling_factor\n\n    @torch.no_grad()\n    def export_latents_to_video(self, latents: torch.FloatTensor, video_path: str, fps: int):\n        r\"\"\"\n        Decode latent vectors into video and export as video file.\n\n        Args:\n            latents (`torch.FloatTensor`): Encoded latents in 
`[B, F, D, H_latent, W_latent]` format from\n                `encode_video_frames()`\n            video_path (`str`): Output path for video file\n            fps (`int`): Target frames per second for output video\n        \"\"\"\n        video = self.decode_latents(latents)\n        frames = self.video_processor.postprocess_video(video=video, output_type=\"pil\")\n        os.makedirs(os.path.dirname(video_path), exist_ok=True)\n        export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)\n\n    # Modified from CogVideoXPipeline.__call__\n    @torch.no_grad()\n    def sample(\n        self,\n        latents: torch.FloatTensor,\n        scheduler: Union[DDIMInverseScheduler, CogVideoXDDIMScheduler],\n        prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 6,\n        use_dynamic_cfg: bool = False,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        reference_latents: torch.FloatTensor = None,\n    ) -> torch.FloatTensor:\n        r\"\"\"\n        Execute the core sampling loop for video generation/inversion using CogVideoX.\n\n        Implements the full denoising trajectory recording for both DDIM inversion and\n        generation processes. Supports dynamic classifier-free guidance and reference\n        latent conditioning.\n\n        Args:\n            latents (`torch.FloatTensor`):\n                Initial noise tensor of shape `[B, F, C, H, W]`.\n            scheduler (`Union[DDIMInverseScheduler, CogVideoXDDIMScheduler]`):\n                Scheduling strategy for diffusion process. 
Use:\n                (1) `DDIMInverseScheduler` for inversion\n                (2) `CogVideoXDDIMScheduler` for generation\n            prompt (`Optional[Union[str, List[str]]]`):\n                Text prompt(s) for conditional generation. Defaults to unconditional.\n            negative_prompt (`Optional[Union[str, List[str]]]`):\n                Negative prompt(s) for guidance. Requires `guidance_scale > 1`.\n            num_inference_steps (`int`):\n                Number of denoising steps. Affects quality/compute trade-off.\n            guidance_scale (`float`):\n                Classifier-free guidance weight. 1.0 = no guidance.\n            use_dynamic_cfg (`bool`):\n                Enable time-varying guidance scale (cosine schedule)\n            eta (`float`):\n                DDIM variance parameter (0 = deterministic process)\n            generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`):\n                Random number generator(s) for reproducibility\n            attention_kwargs (`Optional[Dict[str, Any]]`):\n                Custom parameters for attention modules\n            reference_latents (`torch.FloatTensor`):\n                Reference latent trajectory for conditional sampling. Shape should match\n                `[T, B, F, C, H, W]` where `T` is number of timesteps\n\n        Returns:\n            `torch.FloatTensor`:\n                Full denoising trajectory tensor of shape `[T, B, F, C, H, W]`.\n        \"\"\"\n        self._guidance_scale = guidance_scale\n        self._attention_kwargs = attention_kwargs\n        self._interrupt = False\n\n        device = self._execution_device\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            negative_prompt,\n            do_classifier_free_guidance,\n            device=device,\n        )\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n        if reference_latents is not None:\n            prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)\n        self._num_timesteps = len(timesteps)\n\n        # 5. Prepare latents.\n        latents = latents.to(device=device) * scheduler.init_noise_sigma\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n        if isinstance(scheduler, DDIMInverseScheduler):  # Inverse scheduler does not accept extra kwargs\n            extra_step_kwargs = {}\n\n        # 7. Create rotary embeds if required\n        image_rotary_emb = (\n            self._prepare_rotary_positional_embeddings(\n                height=latents.size(3) * self.vae_scale_factor_spatial,\n                width=latents.size(4) * self.vae_scale_factor_spatial,\n                num_frames=latents.size(1),\n                device=device,\n            )\n            if self.transformer.config.use_rotary_positional_embeddings\n            else None\n        )\n\n        # 8. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)\n\n        trajectory = torch.zeros_like(latents).unsqueeze(0).repeat(len(timesteps), 1, 1, 1, 1, 1)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                if reference_latents is not None:\n                    reference = reference_latents[i]\n                    reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference\n                    latent_model_input = torch.cat([latent_model_input, reference], dim=0)\n                latent_model_input = scheduler.scale_model_input(latent_model_input, t)\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latent_model_input.shape[0])\n\n                # predict noise model_output\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep=timestep,\n                    image_rotary_emb=image_rotary_emb,\n                    attention_kwargs=attention_kwargs,\n                    return_dict=False,\n                )[0]\n                noise_pred = noise_pred.float()\n\n                if reference_latents is not None:  # Recover the original batch size\n                    noise_pred, _ = noise_pred.chunk(2)\n\n                # perform guidance\n                if use_dynamic_cfg:\n                    self._guidance_scale = 1 + guidance_scale * (\n                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2\n                    )\n                if 
do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the noisy sample x_t-1 -> x_t\n                latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n                latents = latents.to(prompt_embeds.dtype)\n                trajectory[i] = latents\n\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):\n                    progress_bar.update()\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        return trajectory\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: str,\n        video_path: str,\n        guidance_scale: float,\n        num_inference_steps: int,\n        skip_frames_start: int,\n        skip_frames_end: int,\n        frame_sample_step: Optional[int],\n        max_num_frames: int,\n        width: int,\n        height: int,\n        seed: int,\n    ):\n        \"\"\"\n        Performs DDIM inversion on a video to reconstruct it with a new prompt.\n\n        Args:\n            prompt (`str`): The text prompt to guide the reconstruction.\n            video_path (`str`): Path to the input video file.\n            guidance_scale (`float`): Scale for classifier-free guidance.\n            num_inference_steps (`int`): Number of denoising steps.\n            skip_frames_start (`int`): Number of frames to skip from the beginning of the video.\n            skip_frames_end (`int`): Number of frames to skip from the end of the video.\n            frame_sample_step (`Optional[int]`): Step size for sampling frames. 
If None, it is derived from the\n                clip length as `(end_frame - start_frame) // max_num_frames`.\n            max_num_frames (`int`): Maximum number of frames to process.\n            width (`int`): Width the input video frames are resized to.\n            height (`int`): Height the input video frames are resized to.\n            seed (`int`): Random seed for reproducibility.\n\n        Returns:\n            `CogVideoXDDIMInversionOutput`: Contains the inverse latents and reconstructed latents.\n        \"\"\"\n        if not self.transformer.config.use_rotary_positional_embeddings:\n            raise NotImplementedError(\"This script supports CogVideoX 5B model only.\")\n        video_frames = get_video_frames(\n            video_path=video_path,\n            width=width,\n            height=height,\n            skip_frames_start=skip_frames_start,\n            skip_frames_end=skip_frames_end,\n            max_num_frames=max_num_frames,\n            frame_sample_step=frame_sample_step,\n        ).to(device=self.device)\n        video_latents = self.encode_video_frames(video_frames=video_frames)\n        inverse_latents = self.sample(\n            latents=video_latents,\n            scheduler=self.inverse_scheduler,\n            prompt=\"\",\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            generator=torch.Generator(device=self.device).manual_seed(seed),\n        )\n        with OverrideAttnProcessors(transformer=self.transformer):\n            recon_latents = self.sample(\n                latents=torch.randn_like(video_latents),\n                scheduler=self.scheduler,\n                prompt=prompt,\n                num_inference_steps=num_inference_steps,\n                guidance_scale=guidance_scale,\n                generator=torch.Generator(device=self.device).manual_seed(seed),\n                reference_latents=reversed(inverse_latents),\n            )\n        return CogVideoXDDIMInversionOutput(\n            inverse_latents=inverse_latents,\n            
recon_latents=recon_latents,\n        )\n\n\nif __name__ == \"__main__\":\n    arguments = get_args()\n    pipeline = CogVideoXPipelineForDDIMInversion.from_pretrained(\n        arguments.pop(\"model_path\"),\n        torch_dtype=arguments.pop(\"dtype\"),\n    ).to(device=arguments.pop(\"device\"))\n\n    output_path = arguments.pop(\"output_path\")\n    fps = arguments.pop(\"fps\")\n    # Use only the file stem so an absolute `video_path` cannot override `output_path` in `os.path.join`\n    video_name = os.path.splitext(os.path.basename(arguments.get(\"video_path\")))[0]\n    inverse_video_path = os.path.join(output_path, f\"{video_name}_inversion.mp4\")\n    recon_video_path = os.path.join(output_path, f\"{video_name}_reconstruction.mp4\")\n\n    # Run DDIM inversion\n    output = pipeline(**arguments)\n    pipeline.export_latents_to_video(output.inverse_latents[-1], inverse_video_path, fps)\n    pipeline.export_latents_to_video(output.recon_latents[-1], recon_video_path, fps)\n"
  },
  {
    "path": "examples/community/composable_stable_diffusion.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Callable, List, Optional, Union\n\nimport torch\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import (\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    EulerAncestralDiscreteScheduler,\n    EulerDiscreteScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n)\nfrom diffusers.utils import deprecate, logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[\n            DDIMScheduler,\n            PNDMScheduler,\n            LMSDiscreteScheduler,\n            EulerDiscreteScheduler,\n            EulerAncestralDiscreteScheduler,\n            DPMSolverMultistepScheduler,\n        ],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `list(int)`):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`):\n                The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n        \"\"\"\n        batch_size = len(prompt) if isinstance(prompt, list) else 1\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n\n        if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n            attention_mask = text_inputs.attention_mask.to(device)\n        else:\n            attention_mask = None\n\n        text_embeddings = self.text_encoder(\n            text_input_ids.to(device),\n            attention_mask=attention_mask,\n        )\n        text_embeddings = text_embeddings[0]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n     
           uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            uncond_embeddings = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            uncond_embeddings = uncond_embeddings[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * 
num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        return text_embeddings\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if 
accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(self, prompt, height, width, callback_steps):\n        if not isinstance(prompt, str) and not isinstance(prompt, list):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if latents is None:\n            if device.type == \"mps\":\n                # randn does not work reproducibly on mps\n                latents = torch.randn(shape, generator=generator, device=\"cpu\", dtype=dtype).to(device)\n            else:\n                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            if latents.shape != shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {shape}\")\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, 
List[str]],\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        weights: str | None = \"\",\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            weights (`str`, *optional*, defaults to an empty string):\n                `|`-separated weights, one per `|`-separated sub-prompt in `prompt`. Only used when `prompt` contains\n                `|`; if empty, every sub-prompt is weighted by `guidance_scale`.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(prompt, height, width, callback_steps)\n\n        # 2. Define call parameters\n        batch_size = 1 if isinstance(prompt, str) else len(prompt)\n        device = self._execution_device\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        if \"|\" in prompt:\n            prompt = [x.strip() for x in prompt.split(\"|\")]\n            print(f\"composing {prompt}...\")\n\n            if not weights:\n                # specify weights for prompts (excluding the unconditional score)\n                print(\"using equal positive weights (conjunction) for all prompts...\")\n                weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)\n            else:\n                # set prompt weight for each\n                num_prompts = len(prompt) if isinstance(prompt, list) else 1\n                weights = [float(w.strip()) for w in weights.split(\"|\")]\n                # guidance scale as the default\n                if len(weights) < num_prompts:\n                    weights.append(guidance_scale)\n                else:\n                    weights = weights[:num_prompts]\n                assert len(weights) == len(prompt), \"weights specified are not equal to the number of prompts\"\n                weights = torch.tensor(weights, device=self.device).reshape(-1, 1, 1, 1)\n        else:\n            weights = guidance_scale\n\n        # 3. Encode input prompt\n        text_embeddings = self._encode_prompt(\n            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            text_embeddings.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # composable diffusion\n        if isinstance(prompt, list) and batch_size == 1:\n            # remove extra unconditional embedding\n            # N = one unconditional embed + conditional embeds\n            text_embeddings = text_embeddings[len(prompt) - 1 :]\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = []\n                for j in range(text_embeddings.shape[0]):\n                    noise_pred.append(\n                        self.unet(latent_model_input[:1], t, encoder_hidden_states=text_embeddings[j : j + 1]).sample\n                    )\n                noise_pred = torch.cat(noise_pred, dim=0)\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred[:1], noise_pred[1:]\n                    noise_pred = noise_pred_uncond + (weights * (noise_pred_text - 
noise_pred_uncond)).sum(\n                        dim=0, keepdims=True\n                    )\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # 8. Post-processing\n        image = self.decode_latents(latents)\n\n        # 9. Run safety checker\n        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)\n\n        # 10. Convert to PIL\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/ddim_noise_comparative_analysis.py",
    "content": "# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Union\n\nimport PIL.Image\nimport torch\nfrom torchvision import transforms\n\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput\nfrom diffusers.schedulers import DDIMScheduler\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\ntrans = transforms.Compose(\n    [\n        transforms.Resize((256, 256)),\n        transforms.ToTensor(),\n        transforms.Normalize([0.5], [0.5]),\n    ]\n)\n\n\ndef preprocess(image):\n    if isinstance(image, torch.Tensor):\n        return image\n    elif isinstance(image, PIL.Image.Image):\n        image = [image]\n\n    image = [trans(img.convert(\"RGB\")) for img in image]\n    image = torch.stack(image)\n    return image\n\n\nclass DDIMNoiseComparativeAnalysisPipeline(DiffusionPipeline):\n    r\"\"\"\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Parameters:\n        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image. 
Can be one of\n            [`DDPMScheduler`] or [`DDIMScheduler`].\n    \"\"\"\n\n    def __init__(self, unet, scheduler):\n        super().__init__()\n\n        # make sure scheduler can always be converted to DDIM\n        scheduler = DDIMScheduler.from_config(scheduler.config)\n\n        self.register_modules(unet=unet, scheduler=scheduler)\n\n    def check_inputs(self, strength):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start:]\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        init_latents = image.to(device=device, dtype=dtype)\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        print(\"add noise to latents at timestep\", timestep)\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        strength: float = 0.8,\n        batch_size: int = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        eta: float = 0.0,\n        num_inference_steps: int = 50,\n        use_clipped_model_output: Optional[bool] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n    ) -> Union[ImagePipelineOutput, Tuple]:\n        r\"\"\"\n        Args:\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, that will be used as the starting point for the\n                process.\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`\n                will be used as a starting point, adding more noise to it the larger the `strength`. The number of\n                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will\n                be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. 
A value of 1, therefore, essentially ignores `image`.\n            batch_size (`int`, *optional*, defaults to 1):\n                The number of images to generate.\n            generator (`torch.Generator`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            eta (`float`, *optional*, defaults to 0.0):\n                The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            use_clipped_model_output (`bool`, *optional*, defaults to `None`):\n                If `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed\n                downstream to the scheduler. So use `None` for schedulers which don't support this argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is\n            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(strength)\n\n        # 2. Preprocess image\n        image = preprocess(image)\n\n        # 3. 
set timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=self.device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)\n        latent_timestep = timesteps[:1].repeat(batch_size)\n\n        # 4. Prepare latent variables\n        latents = self.prepare_latents(image, latent_timestep, batch_size, self.unet.dtype, self.device, generator)\n        image = latents\n\n        # 5. Denoising loop\n        for t in self.progress_bar(timesteps):\n            # 1. predict noise model_output\n            model_output = self.unet(image, t).sample\n\n            # 2. predict previous mean of image x_t-1 and add variance depending on eta\n            # eta corresponds to η in paper and should be between [0, 1]\n            # do x_t -> x_t-1\n            image = self.scheduler.step(\n                model_output,\n                t,\n                image,\n                eta=eta,\n                use_clipped_model_output=use_clipped_model_output,\n                generator=generator,\n            ).prev_sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, latent_timestep.item())\n\n        return ImagePipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/dps_pipeline.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom math import pi\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom diffusers import DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DModel\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nclass DPSPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for Diffusion Posterior Sampling.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    Parameters:\n        unet ([`UNet2DModel`]):\n            A `UNet2DModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image. 
Can be one of\n            [`DDPMScheduler`], or [`DDIMScheduler`].\n    \"\"\"\n\n    model_cpu_offload_seq = \"unet\"\n\n    def __init__(self, unet, scheduler):\n        super().__init__()\n        self.register_modules(unet=unet, scheduler=scheduler)\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        measurement: torch.Tensor,\n        operator: torch.nn.Module,\n        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],\n        batch_size: int = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        num_inference_steps: int = 1000,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        zeta: float = 0.3,\n    ) -> Union[ImagePipelineOutput, Tuple]:\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            measurement (`torch.Tensor`, *required*):\n                A 'torch.Tensor', the corrupted image\n            operator (`torch.nn.Module`, *required*):\n                A 'torch.nn.Module', the operator generating the corrupted image\n            loss_fn (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *required*):\n                A 'Callable[[torch.Tensor, torch.Tensor], torch.Tensor]', the loss function used\n                between the measurements, for most of the cases using RMSE is fine.\n            batch_size (`int`, *optional*, defaults to 1):\n                The number of images to generate.\n            generator (`torch.Generator`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            num_inference_steps (`int`, *optional*, defaults to 1000):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.\n\n        Example:\n\n        ```py\n        >>> from diffusers import DDPMPipeline\n\n        >>> # load model and scheduler\n        >>> pipe = DDPMPipeline.from_pretrained(\"google/ddpm-cat-256\")\n\n        >>> # run pipeline in inference (sample random noise and denoise)\n        >>> image = pipe().images[0]\n\n        >>> # save image\n        >>> image.save(\"ddpm_generated_image.png\")\n        ```\n\n        Returns:\n            [`~pipelines.ImagePipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is\n                returned where the first element is a list with the generated images\n        \"\"\"\n        # Sample gaussian noise to begin loop\n        if isinstance(self.unet.config.sample_size, int):\n            image_shape = (\n                batch_size,\n                self.unet.config.in_channels,\n                self.unet.config.sample_size,\n                self.unet.config.sample_size,\n            )\n        else:\n            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)\n\n        if self.device.type == \"mps\":\n            # randn does not work reproducibly on mps\n            image = randn_tensor(image_shape, generator=generator)\n            image = image.to(self.device)\n        else:\n            image = randn_tensor(image_shape, generator=generator, device=self.device)\n\n        # set step values\n        
self.scheduler.set_timesteps(num_inference_steps)\n\n        for t in self.progress_bar(self.scheduler.timesteps):\n            with torch.enable_grad():\n                # 1. predict noise model_output\n                image = image.requires_grad_()\n                model_output = self.unet(image, t).sample\n\n                # 2. compute previous image x'_{t-1} and original prediction x0_{t}\n                scheduler_out = self.scheduler.step(model_output, t, image, generator=generator)\n                image_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample\n\n                # 3. compute y'_t = f(x0_{t})\n                measurement_pred = operator(origi_pred)\n\n                # 4. compute loss = d(y, y'_t-1)\n                loss = loss_fn(measurement, measurement_pred)\n                loss.backward()\n\n                print(\"distance: {0:.4f}\".format(loss.item()))\n\n                with torch.no_grad():\n                    image_pred = image_pred - zeta * image.grad\n                    image = image_pred.detach()\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image,)\n\n        return ImagePipelineOutput(images=image)\n\n\nif __name__ == \"__main__\":\n    import scipy\n    from torch import nn\n    from torchvision.utils import save_image\n\n    # defining the operators f(.) 
of y = f(x)\n    # super-resolution operator\n    class SuperResolutionOperator(nn.Module):\n        def __init__(self, in_shape, scale_factor):\n            super().__init__()\n\n            # Resizer local class, do not use outiside the SR operator class\n            class Resizer(nn.Module):\n                def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):\n                    super(Resizer, self).__init__()\n\n                    # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa\n                    scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)\n\n                    # Choose interpolation method, each method has the matching kernel size\n                    def cubic(x):\n                        absx = np.abs(x)\n                        absx2 = absx**2\n                        absx3 = absx**3\n                        return (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + (\n                            -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2\n                        ) * ((1 < absx) & (absx <= 2))\n\n                    def lanczos2(x):\n                        return (\n                            (np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps)\n                            / ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)\n                        ) * (abs(x) < 2)\n\n                    def box(x):\n                        return ((-0.5 <= x) & (x < 0.5)) * 1.0\n\n                    def lanczos3(x):\n                        return (\n                            (np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps)\n                            / ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)\n                        ) * (abs(x) < 3)\n\n                    def linear(x):\n                        return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))\n\n    
                method, kernel_width = {\n                        \"cubic\": (cubic, 4.0),\n                        \"lanczos2\": (lanczos2, 4.0),\n                        \"lanczos3\": (lanczos3, 6.0),\n                        \"box\": (box, 1.0),\n                        \"linear\": (linear, 2.0),\n                        None: (cubic, 4.0),  # set default interpolation method as cubic\n                    }.get(kernel)\n\n                    # Antialiasing is only used when downscaling\n                    antialiasing *= np.any(np.array(scale_factor) < 1)\n\n                    # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient\n                    sorted_dims = np.argsort(np.array(scale_factor))\n                    self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]\n\n                    # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction\n                    field_of_view_list = []\n                    weights_list = []\n                    for dim in self.sorted_dims:\n                        # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the\n                        # weights that multiply the values there to get its result.\n                        weights, field_of_view = self.contributions(\n                            in_shape[dim], output_shape[dim], scale_factor[dim], method, kernel_width, antialiasing\n                        )\n\n                        # convert to torch tensor\n                        weights = torch.tensor(weights.T, dtype=torch.float32)\n\n                        # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for\n                        # tmp_im[field_of_view.T], (bsxfun style)\n                        weights_list.append(\n                            nn.Parameter(\n          
                      torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),\n                                requires_grad=False,\n                            )\n                        )\n                        field_of_view_list.append(\n                            nn.Parameter(\n                                torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False\n                            )\n                        )\n\n                    self.field_of_view = nn.ParameterList(field_of_view_list)\n                    self.weights = nn.ParameterList(weights_list)\n\n                def forward(self, in_tensor):\n                    x = in_tensor\n\n                    # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim\n                    for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):\n                        # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize\n                        x = torch.transpose(x, dim, 0)\n\n                        # This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1.\n                        # For each pixel in the output image it matches the positions that influence it from the input image (along 1 dim\n                        # only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with\n                        # the matching set of weights. 
We do this with one big element-wise tensor multiplication (MATLAB bsxfun style):\n                        # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the\n                        # same number\n                        x = torch.sum(x[fov] * w, dim=0)\n\n                        # Finally we swap back the axes to the original order\n                        x = torch.transpose(x, dim, 0)\n\n                    return x\n\n                def fix_scale_and_size(self, input_shape, output_shape, scale_factor):\n                    # First fix the scale-factor (if given) into the standardized form the function expects (a list of scale factors\n                    # with the same length as the number of input dimensions)\n                    if scale_factor is not None:\n                        # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.\n                        if np.isscalar(scale_factor) and len(input_shape) > 1:\n                            scale_factor = [scale_factor, scale_factor]\n\n                        # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales\n                        scale_factor = list(scale_factor)\n                        scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor\n\n                    # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size\n                    # to all the unspecified dimensions\n                    if output_shape is not None:\n                        output_shape = list(input_shape[len(output_shape) :]) + list(np.uint(np.array(output_shape)))\n\n                    # Dealing with the case of a non-given scale-factor, calculating according to output-shape. 
note that this is\n                    # sub-optimal, because there can be different scales to the same output-shape.\n                    if scale_factor is None:\n                        scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)\n\n                    # Dealing with missing output-shape. calculating according to scale-factor\n                    if output_shape is None:\n                        output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))\n\n                    return scale_factor, output_shape\n\n                def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):\n                    # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied\n                    # such that each position from the field_of_view will be multiplied with a matching filter from the\n                    # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers\n                    # around it. This is only done for one dimension of the image.\n\n                    # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of\n                    # 1/sf. this means filtering is more 'low-pass filter'.\n                    fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel\n                    kernel_width *= 1.0 / scale if antialiasing else 1.0\n\n                    # These are the coordinates of the output image\n                    out_coordinates = np.arange(1, out_length + 1)\n\n                    # since both scale-factor and output size can be provided simultaneously, preserving the center of the image requires shifting\n                    # the output coordinates. 
The deviation arises because out_length doesn't necessarily equal in_length*scale.\n                    # To keep the center we need to subtract half of this deviation so that we get equal margins for both sides and center is preserved.\n                    shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2\n\n                    # These are the matching positions of the output-coordinates on the input image coordinates.\n                    # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:\n                    # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.\n                    # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to\n                    # the right boundary of pixel 2. Pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big\n                    # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).\n                    # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is\n                    # at d=1, and so on (d = p - 0.5). We calculate (d_new = d_old / sf) which means:\n                    # (p_new-0.5 = (p_old-0.5) / sf)     ->          p_new = p_old/sf + 0.5 * (1-1/sf)\n                    match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)\n\n                    # This is the left boundary to start multiplying the filter from, it depends on the size of the filter\n                    left_boundary = np.floor(match_coordinates - kernel_width / 2)\n\n                    # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers\n                    # of the pixels it only covered a part from. 
So we add one pixel at each side to consider (weights can zero them out)\n                    expanded_kernel_width = np.ceil(kernel_width) + 2\n\n                    # Determine a set of field_of_view for each output position, these are the pixels in the input image\n                    # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and the\n                    # vertical dim is the pixels it 'sees' (kernel_size + 2)\n                    field_of_view = np.squeeze(\n                        np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)\n                    )\n\n                    # Assign weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and the\n                    # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in\n                    # 'field_of_view')\n                    weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)\n\n                    # Normalize weights to sum up to 1. 
be careful from dividing by 0\n                    sum_weights = np.sum(weights, axis=1)\n                    sum_weights[sum_weights == 0] = 1.0\n                    weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)\n\n                    # We use this mirror structure as a trick for reflection padding at the boundaries\n                    mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))\n                    field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]\n\n                    # Get rid of  weights and pixel positions that are of zero weight\n                    non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))\n                    weights = np.squeeze(weights[:, non_zero_out_pixels])\n                    field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])\n\n                    # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size\n                    return weights, field_of_view\n\n            self.down_sample = Resizer(in_shape, 1 / scale_factor)\n            for param in self.parameters():\n                param.requires_grad = False\n\n        def forward(self, data, **kwargs):\n            return self.down_sample(data)\n\n    # Gaussian blurring operator\n    class GaussialBlurOperator(nn.Module):\n        def __init__(self, kernel_size, intensity):\n            super().__init__()\n\n            class Blurkernel(nn.Module):\n                def __init__(self, blur_type=\"gaussian\", kernel_size=31, std=3.0):\n                    super().__init__()\n                    self.blur_type = blur_type\n                    self.kernel_size = kernel_size\n                    self.std = std\n                    self.seq = nn.Sequential(\n                        nn.ReflectionPad2d(self.kernel_size // 2),\n                        nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),\n          
          )\n                    self.weights_init()\n\n                def forward(self, x):\n                    return self.seq(x)\n\n                def weights_init(self):\n                    if self.blur_type == \"gaussian\":\n                        n = np.zeros((self.kernel_size, self.kernel_size))\n                        n[self.kernel_size // 2, self.kernel_size // 2] = 1\n                        k = scipy.ndimage.gaussian_filter(n, sigma=self.std)\n                        k = torch.from_numpy(k)\n                        self.k = k\n                        for name, f in self.named_parameters():\n                            f.data.copy_(k)\n\n                def update_weights(self, k):\n                    if not torch.is_tensor(k):\n                        k = torch.from_numpy(k)\n                    for name, f in self.named_parameters():\n                        f.data.copy_(k)\n\n                def get_kernel(self):\n                    return self.k\n\n            self.kernel_size = kernel_size\n            self.conv = Blurkernel(blur_type=\"gaussian\", kernel_size=kernel_size, std=intensity)\n            self.kernel = self.conv.get_kernel()\n            self.conv.update_weights(self.kernel.type(torch.float32))\n\n            for param in self.parameters():\n                param.requires_grad = False\n\n        def forward(self, data, **kwargs):\n            return self.conv(data)\n\n        def transpose(self, data, **kwargs):\n            return data\n\n        def get_kernel(self):\n            return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)\n\n    # assuming the forward process y = f(x) is polluted by Gaussian noise, use l2 norm\n    def RMSELoss(yhat, y):\n        return torch.sqrt(torch.sum((yhat - y) ** 2))\n\n    # set up source image\n    src = Image.open(\"sample.png\")\n    # read image into [1,3,H,W]\n    src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None]\n    # normalize image to [-1,1]\n 
   src = (src / 127.5) - 1.0\n    src = src.to(\"cuda\")\n\n    # set up operator and measurement\n    # operator = SuperResolutionOperator(in_shape=src.shape, scale_factor=4).to(\"cuda\")\n    operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to(\"cuda\")\n    measurement = operator(src)\n\n    # set up scheduler\n    scheduler = DDPMScheduler.from_pretrained(\"google/ddpm-celebahq-256\")\n    scheduler.set_timesteps(1000)\n\n    # set up model\n    model = UNet2DModel.from_pretrained(\"google/ddpm-celebahq-256\").to(\"cuda\")\n\n    save_image((src + 1.0) / 2.0, \"dps_src.png\")\n    save_image((measurement + 1.0) / 2.0, \"dps_mea.png\")\n\n    # finally, the pipeline\n    dpspipe = DPSPipeline(model, scheduler)\n    image = dpspipe(\n        measurement=measurement,\n        operator=operator,\n        loss_fn=RMSELoss,\n        zeta=1.0,\n    ).images[0]\n\n    image.save(\"dps_generated_image.png\")\n"
  },
  {
    "path": "examples/community/edict_pipeline.py",
    "content": "import torch\nfrom PIL import Image\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.utils import (\n    deprecate,\n)\n\n\nclass EDICTPipeline(DiffusionPipeline):\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: DDIMScheduler,\n        mixing_coeff: float = 0.93,\n        leapfrog_steps: bool = True,\n    ):\n        self.mixing_coeff = mixing_coeff\n        self.leapfrog_steps = leapfrog_steps\n\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n        )\n\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    def _encode_prompt(\n        self, prompt: str, negative_prompt: str | None = None, do_classifier_free_guidance: bool = False\n    ):\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        prompt_embeds = self.text_encoder(text_inputs.input_ids.to(self.device)).last_hidden_state\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device)\n\n        if do_classifier_free_guidance:\n            uncond_tokens = \"\" if negative_prompt is None else negative_prompt\n\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                
padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device)).last_hidden_state\n\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor):\n        x = self.mixing_coeff * x + (1 - self.mixing_coeff) * y\n        y = self.mixing_coeff * y + (1 - self.mixing_coeff) * x\n\n        return [x, y]\n\n    def noise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor):\n        y = (y - (1 - self.mixing_coeff) * x) / self.mixing_coeff\n        x = (x - (1 - self.mixing_coeff) * y) / self.mixing_coeff\n\n        return [x, y]\n\n    def _get_alpha_and_beta(self, t: torch.Tensor):\n        # as self.alphas_cumprod is always in cpu\n        t = int(t)\n\n        alpha_prod = self.scheduler.alphas_cumprod[t] if t >= 0 else self.scheduler.final_alpha_cumprod\n\n        return alpha_prod, 1 - alpha_prod\n\n    def noise_step(\n        self,\n        base: torch.Tensor,\n        model_input: torch.Tensor,\n        model_output: torch.Tensor,\n        timestep: torch.Tensor,\n    ):\n        prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps\n\n        alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep)\n        alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep)\n\n        a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5\n        b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5\n\n        next_model_input = (base - b_t * model_output) / a_t\n\n        return model_input, next_model_input.to(base.dtype)\n\n    def denoise_step(\n        self,\n        base: torch.Tensor,\n        model_input: torch.Tensor,\n        model_output: torch.Tensor,\n        
timestep: torch.Tensor,\n    ):\n        prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps\n\n        alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep)\n        alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep)\n\n        a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5\n        b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5\n        next_model_input = a_t * base + b_t * model_output\n\n        return model_input, next_model_input.to(base.dtype)\n\n    @torch.no_grad()\n    def decode_latents(self, latents: torch.Tensor):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        return image\n\n    @torch.no_grad()\n    def prepare_latents(\n        self,\n        image: Image.Image,\n        text_embeds: torch.Tensor,\n        timesteps: torch.Tensor,\n        guidance_scale: float,\n        generator: torch.Generator | None = None,\n    ):\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        image = image.to(device=self.device, dtype=text_embeds.dtype)\n        latent = self.vae.encode(image).latent_dist.sample(generator)\n\n        latent = self.vae.config.scaling_factor * latent\n\n        coupled_latents = [latent.clone(), latent.clone()]\n\n        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):\n            coupled_latents = self.noise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1])\n\n            # j - model_input index, k - base index\n            for j in range(2):\n                k = j ^ 1\n\n                if self.leapfrog_steps:\n                    if i % 2 == 0:\n                        k, j = j, k\n\n                model_input = coupled_latents[j]\n                base = coupled_latents[k]\n\n                latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else 
model_input\n\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample\n\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                base, model_input = self.noise_step(\n                    base=base,\n                    model_input=model_input,\n                    model_output=noise_pred,\n                    timestep=t,\n                )\n\n                coupled_latents[k] = model_input\n\n        return coupled_latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        base_prompt: str,\n        target_prompt: str,\n        image: Image.Image,\n        guidance_scale: float = 3.0,\n        num_inference_steps: int = 50,\n        strength: float = 0.8,\n        negative_prompt: str | None = None,\n        generator: torch.Generator | None = None,\n        output_type: str | None = \"pil\",\n    ):\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        image = self.image_processor.preprocess(image)\n\n        base_embeds = self._encode_prompt(base_prompt, negative_prompt, do_classifier_free_guidance)\n        target_embeds = self._encode_prompt(target_prompt, negative_prompt, do_classifier_free_guidance)\n\n        self.scheduler.set_timesteps(num_inference_steps, self.device)\n\n        t_limit = num_inference_steps - int(num_inference_steps * strength)\n        fwd_timesteps = self.scheduler.timesteps[t_limit:]\n        bwd_timesteps = fwd_timesteps.flip(0)\n\n        coupled_latents = self.prepare_latents(image, base_embeds, bwd_timesteps, guidance_scale, generator)\n\n        for i, t in tqdm(enumerate(fwd_timesteps), total=len(fwd_timesteps)):\n            # j - model_input index, k - base index\n            for k in range(2):\n                j = k ^ 1\n\n                if 
self.leapfrog_steps:\n                    if i % 2 == 1:\n                        k, j = j, k\n\n                model_input = coupled_latents[j]\n                base = coupled_latents[k]\n\n                latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input\n\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=target_embeds).sample\n\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                base, model_input = self.denoise_step(\n                    base=base,\n                    model_input=model_input,\n                    model_output=noise_pred,\n                    timestep=t,\n                )\n\n                coupled_latents[k] = model_input\n\n            coupled_latents = self.denoise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1])\n\n        # either one is fine\n        final_latent = coupled_latents[0]\n\n        if output_type not in [\"latent\", \"pt\", \"np\", \"pil\"]:\n            deprecation_message = (\n                f\"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: \"\n                \"`pil`, `np`, `pt`, `latent`\"\n            )\n            deprecate(\"Unsupported output_type\", \"1.0.0\", deprecation_message, standard_warn=False)\n            output_type = \"np\"\n\n        if output_type == \"latent\":\n            image = final_latent\n        else:\n            image = self.decode_latents(final_latent)\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        return image\n"
  },
  {
    "path": "examples/community/fresco_v2v.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.model_zoo\nfrom einops import rearrange, repeat\nfrom gmflow.gmflow import GMFlow\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.attention_processor import AttnProcessor2_0\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    
deprecate,\n    logging,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef clear_cache():\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\ndef coords_grid(b, h, w, homogeneous=False, device=None):\n    y, x = torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]\n\n    stacks = [x, y]\n\n    if homogeneous:\n        ones = torch.ones_like(x)  # [H, W]\n        stacks.append(ones)\n\n    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]\n\n    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]\n\n    if device is not None:\n        grid = grid.to(device)\n\n    return grid\n\n\ndef bilinear_sample(img, sample_coords, mode=\"bilinear\", padding_mode=\"zeros\", return_mask=False):\n    # img: [B, C, H, W]\n    # sample_coords: [B, 2, H, W] in image scale\n    if sample_coords.size(1) != 2:  # [B, H, W, 2]\n        sample_coords = sample_coords.permute(0, 3, 1, 2)\n\n    b, _, h, w = sample_coords.shape\n\n    # Normalize to [-1, 1]\n    x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1\n    y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1\n\n    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]\n\n    img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)\n\n    if return_mask:\n        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)  # [B, H, W]\n\n        return img, mask\n\n    return img\n\n\nclass Dilate:\n    def __init__(self, kernel_size=7, channels=1, device=\"cpu\"):\n        self.kernel_size = kernel_size\n        self.channels = channels\n        gaussian_kernel = torch.ones(1, 1, self.kernel_size, self.kernel_size)\n        gaussian_kernel = gaussian_kernel.repeat(self.channels, 1, 1, 1)\n        self.mean = (self.kernel_size - 1) // 2\n        gaussian_kernel = gaussian_kernel.to(device)\n  
      self.gaussian_filter = gaussian_kernel\n\n    def __call__(self, x):\n        x = F.pad(x, (self.mean, self.mean, self.mean, self.mean), \"replicate\")\n        return torch.clamp(F.conv2d(x, self.gaussian_filter, bias=None), 0, 1)\n\n\ndef flow_warp(feature, flow, mask=False, mode=\"bilinear\", padding_mode=\"zeros\"):\n    b, c, h, w = feature.size()\n    assert flow.size(1) == 2\n\n    grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]\n    grid = grid.to(feature.dtype)\n    return bilinear_sample(feature, grid, mode=mode, padding_mode=padding_mode, return_mask=mask)\n\n\ndef forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):\n    # fwd_flow, bwd_flow: [B, 2, H, W]\n    # alpha and beta values are following UnFlow\n    # (https://huggingface.co/papers/1711.07837)\n    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4\n    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2\n    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]\n\n    warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]\n    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]\n\n    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]\n    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)\n\n    threshold = alpha * flow_mag + beta\n\n    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]\n    bwd_occ = (diff_bwd > threshold).float()\n\n    return fwd_occ, bwd_occ\n\n\ndef numpy2tensor(img):\n    x0 = torch.from_numpy(img.copy()).float().cuda() / 255.0 * 2.0 - 1.0\n    x0 = torch.stack([x0], dim=0)\n    # einops.rearrange(x0, 'b h w c -> b c h w').clone()\n    return x0.permute(0, 3, 1, 2)\n\n\ndef calc_mean_std(feat, eps=1e-5, chunk=1):\n    size = feat.size()\n    assert len(size) == 4\n    if chunk == 2:\n        feat = torch.cat(feat.chunk(2), dim=3)\n    N, C = size[:2]\n    feat_var = feat.view(N // chunk, C, -1).var(dim=2) + eps\n    feat_std = 
feat_var.sqrt().view(N, C, 1, 1)\n    feat_mean = feat.view(N // chunk, C, -1).mean(dim=2).view(N // chunk, C, 1, 1)\n    return feat_mean.repeat(chunk, 1, 1, 1), feat_std.repeat(chunk, 1, 1, 1)\n\n\ndef adaptive_instance_normalization(content_feat, style_feat, chunk=1):\n    assert content_feat.size()[:2] == style_feat.size()[:2]\n    size = content_feat.size()\n    style_mean, style_std = calc_mean_std(style_feat, chunk)\n    content_mean, content_std = calc_mean_std(content_feat)\n\n    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)\n    return normalized_feat * style_std.expand(size) + style_mean.expand(size)\n\n\ndef optimize_feature(\n    sample, flows, occs, correlation_matrix=[], intra_weight=1e2, iters=20, unet_chunk_size=2, optimize_temporal=True\n):\n    \"\"\"\n    FRESCO-guided latent feature optimization\n    * optimize spatial correspondence (match correlation_matrix)\n    * optimize temporal correspondence (match warped_image)\n    \"\"\"\n    if (flows is None or occs is None or (not optimize_temporal)) and (\n        intra_weight == 0 or len(correlation_matrix) == 0\n    ):\n        return sample\n    # flows=[fwd_flows, bwd_flows]: (N-1)*2*H1*W1\n    # occs=[fwd_occs, bwd_occs]: (N-1)*H1*W1\n    # sample: 2N*C*H*W\n    torch.cuda.empty_cache()\n    video_length = sample.shape[0] // unet_chunk_size\n    latent = rearrange(sample.to(torch.float32), \"(b f) c h w -> b f c h w\", f=video_length)\n\n    cs = torch.nn.Parameter((latent.detach().clone()))\n    optimizer = torch.optim.Adam([cs], lr=0.2)\n\n    # unify resolution\n    if flows is not None and occs is not None:\n        scale = sample.shape[2] * 1.0 / flows[0].shape[2]\n        kernel = int(1 / scale)\n        bwd_flow_ = F.interpolate(flows[1] * scale, scale_factor=scale, mode=\"bilinear\").repeat(\n            unet_chunk_size, 1, 1, 1\n        )\n        bwd_occ_ = F.max_pool2d(occs[1].unsqueeze(1), kernel_size=kernel).repeat(\n            
unet_chunk_size, 1, 1, 1\n        )  # 2(N-1)*1*H1*W1\n        fwd_flow_ = F.interpolate(flows[0] * scale, scale_factor=scale, mode=\"bilinear\").repeat(\n            unet_chunk_size, 1, 1, 1\n        )\n        fwd_occ_ = F.max_pool2d(occs[0].unsqueeze(1), kernel_size=kernel).repeat(\n            unet_chunk_size, 1, 1, 1\n        )  # 2(N-1)*1*H1*W1\n        # match frame 0,1,2,3 and frame 1,2,3,0\n        reshuffle_list = list(range(1, video_length)) + [0]\n\n    # attention_probs is the GRAM matrix of the normalized feature\n    attention_probs = None\n    for tmp in correlation_matrix:\n        if sample.shape[2] * sample.shape[3] == tmp.shape[1]:\n            attention_probs = tmp  # 2N*HW*HW\n            break\n\n    n_iter = [0]\n    while n_iter[0] < iters:\n\n        def closure():\n            optimizer.zero_grad()\n\n            loss = 0\n\n            # temporal consistency loss\n            if optimize_temporal and flows is not None and occs is not None:\n                c1 = rearrange(cs[:, :], \"b f c h w -> (b f) c h w\")\n                c2 = rearrange(cs[:, reshuffle_list], \"b f c h w -> (b f) c h w\")\n                warped_image1 = flow_warp(c1, bwd_flow_)\n                warped_image2 = flow_warp(c2, fwd_flow_)\n                loss = (\n                    abs((c2 - warped_image1) * (1 - bwd_occ_)) + abs((c1 - warped_image2) * (1 - fwd_occ_))\n                ).mean() * 2\n\n            # spatial consistency loss\n            if attention_probs is not None and intra_weight > 0:\n                cs_vector = rearrange(cs, \"b f c h w -> (b f) (h w) c\")\n                # attention_scores = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))\n                # cs_attention_probs = attention_scores.softmax(dim=-1)\n                cs_vector = cs_vector / ((cs_vector**2).sum(dim=2, keepdims=True) ** 0.5)\n                cs_attention_probs = torch.bmm(cs_vector, cs_vector.transpose(-1, -2))\n                tmp = F.l1_loss(cs_attention_probs, 
attention_probs) * intra_weight\n                loss = tmp + loss\n\n            loss.backward()\n            n_iter[0] += 1\n\n            return loss\n\n        optimizer.step(closure)\n\n    torch.cuda.empty_cache()\n    return adaptive_instance_normalization(rearrange(cs.data.to(sample.dtype), \"b f c h w -> (b f) c h w\"), sample)\n\n\n@torch.no_grad()\ndef warp_tensor(sample, flows, occs, saliency, unet_chunk_size):\n    \"\"\"\n    Warp images or features based on optical flow\n    Fuse the warped images or features based on occlusion masks and saliency map\n    \"\"\"\n    scale = sample.shape[2] * 1.0 / flows[0].shape[2]\n    kernel = int(1 / scale)\n    bwd_flow_ = F.interpolate(flows[1] * scale, scale_factor=scale, mode=\"bilinear\")\n    bwd_occ_ = F.max_pool2d(occs[1].unsqueeze(1), kernel_size=kernel)  # (N-1)*1*H1*W1\n    if scale == 1:\n        bwd_occ_ = Dilate(kernel_size=13, device=sample.device)(bwd_occ_)\n    fwd_flow_ = F.interpolate(flows[0] * scale, scale_factor=scale, mode=\"bilinear\")\n    fwd_occ_ = F.max_pool2d(occs[0].unsqueeze(1), kernel_size=kernel)  # (N-1)*1*H1*W1\n    if scale == 1:\n        fwd_occ_ = Dilate(kernel_size=13, device=sample.device)(fwd_occ_)\n    scale2 = sample.shape[2] * 1.0 / saliency.shape[2]\n    saliency = F.interpolate(saliency, scale_factor=scale2, mode=\"bilinear\")\n    latent = sample.to(torch.float32)\n    video_length = sample.shape[0] // unet_chunk_size\n    warp_saliency = flow_warp(saliency, bwd_flow_)\n    warp_saliency_ = flow_warp(saliency[0:1], fwd_flow_[video_length - 1 : video_length])\n\n    for j in range(unet_chunk_size):\n        for ii in range(video_length - 1):\n            i = video_length * j + ii\n            warped_image = flow_warp(latent[i : i + 1], bwd_flow_[ii : ii + 1])\n            mask = (1 - bwd_occ_[ii : ii + 1]) * saliency[ii + 1 : ii + 2] * warp_saliency[ii : ii + 1]\n            latent[i + 1 : i + 2] = latent[i + 1 : i + 2] * (1 - mask) + warped_image * mask\n        i = 
video_length * j\n        ii = video_length - 1\n        warped_image = flow_warp(latent[i : i + 1], fwd_flow_[ii : ii + 1])\n        mask = (1 - fwd_occ_[ii : ii + 1]) * saliency[ii : ii + 1] * warp_saliency_\n        latent[ii + i : ii + i + 1] = latent[ii + i : ii + i + 1] * (1 - mask) + warped_image * mask\n\n    return latent.to(sample.dtype)\n\n\ndef my_forward(\n    self,\n    steps=[],\n    layers=[0, 1, 2, 3],\n    flows=None,\n    occs=None,\n    correlation_matrix=[],\n    intra_weight=1e2,\n    iters=20,\n    optimize_temporal=True,\n    saliency=None,\n):\n    \"\"\"\n    Hacked pipe.unet.forward()\n    copied from https://github.com/huggingface/diffusers/blob/v0.19.3/src/diffusers/models/unet_2d_condition.py#L700\n    if you are using a new version of diffusers, please copy the source code and modify it accordingly (find [HACK] in the code)\n    * restore and return the decoder features\n    * optimize the decoder features\n    * perform background smoothing\n    \"\"\"\n\n    def forward(\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNet2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n\n        Returns:\n            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample 
is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            is_npu = sample.device.type == \"npu\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if (is_mps or is_npu) else torch.float64\n            else:\n                dtype = torch.int32 if (is_mps or is_npu) else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n       
 if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if 
\"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = added_cond_kwargs.get(\"hint\")\n            aug_emb, hint = self.add_embedding(image_embs, hint)\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kadinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        # 3. down\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_block_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = 
downsample_block(hidden_states=sample, temb=emb)\n\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    sample += down_block_additional_residuals.pop(0)\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n                encoder_attention_mask=encoder_attention_mask,\n            )\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        \"\"\"\n        [HACK] restore the decoder features in up_samples\n        \"\"\"\n        up_samples = ()\n        # down_samples = ()\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            \"\"\"\n            [HACK] restore the decoder features in up_samples\n            [HACK] optimize the decoder features\n            [HACK] perform background smoothing\n            \"\"\"\n            if i in layers:\n                up_samples += (sample,)\n            if timestep in steps and i in layers:\n                sample = optimize_feature(\n                    sample, flows, occs, correlation_matrix, intra_weight, iters, optimize_temporal=optimize_temporal\n                )\n                if saliency is not None:\n                    sample = warp_tensor(sample, flows, occs, saliency, 2)\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = upsample_block(\n        
            hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n\n        # 6. post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        \"\"\"\n        [HACK] return the output feature as well as the decoder features\n        \"\"\"\n        if not return_dict:\n            return (sample,) + up_samples\n\n        return UNet2DConditionOutput(sample=sample)\n\n    return forward\n\n\n@torch.no_grad()\ndef get_single_mapping_ind(bwd_flow, bwd_occ, imgs, scale=1.0):\n    \"\"\"\n    FLATTEN: Optical fLow-guided attention (Temporal-guided attention)\n    Find the correspondence between every pixel in a pair of frames\n\n    [input]\n    bwd_flow: 1*2*H*W\n    bwd_occ: 1*H*W      i.e., f2 = warp(f1, bwd_flow) * bwd_occ\n    imgs: 2*3*H*W       i.e., [f1,f2]\n\n    [output]\n    mapping_ind: pixel index correspondence\n    unlinkedmask: indicate whether a pixel has no correspondence\n    i.e., f2 = f1[mapping_ind] * unlinkedmask\n    \"\"\"\n    flows = F.interpolate(bwd_flow, scale_factor=1.0 / scale, mode=\"bilinear\")[0][[1, 0]] / scale  # 2*H*W\n    _, H, W = flows.shape\n    masks = torch.logical_not(F.interpolate(bwd_occ[None], scale_factor=1.0 / scale, mode=\"bilinear\") > 0.5)[\n        0\n    ]  # 1*H*W\n    frames = F.interpolate(imgs, scale_factor=1.0 / scale, mode=\"bilinear\").view(2, 3, -1)  # 2*3*HW\n    grid = torch.stack(torch.meshgrid([torch.arange(H), torch.arange(W)]), dim=0).to(flows.device)  # 2*H*W\n    warp_grid = torch.round(grid + flows)\n    mask = torch.logical_and(\n        torch.logical_and(\n            torch.logical_and(torch.logical_and(warp_grid[0] >= 0, warp_grid[0] < H), warp_grid[1] >= 0),\n            warp_grid[1] < W,\n        ),\n        masks[0],\n    ).view(-1)  # HW\n    warp_grid = warp_grid.view(2, -1)  # 2*HW\n    warp_ind = 
(warp_grid[0] * W + warp_grid[1]).to(torch.long)  # HW\n    mapping_ind = torch.zeros_like(warp_ind) - 1  # HW\n\n    for f0ind, f1ind in enumerate(warp_ind):\n        if mask[f0ind]:\n            if mapping_ind[f1ind] == -1:\n                mapping_ind[f1ind] = f0ind\n            else:\n                targetv = frames[0, :, f1ind]\n                pref0ind = mapping_ind[f1ind]\n                prev = frames[1, :, pref0ind]\n                v = frames[1, :, f0ind]\n                if ((prev - targetv) ** 2).mean() > ((v - targetv) ** 2).mean():\n                    mask[pref0ind] = False\n                    mapping_ind[f1ind] = f0ind\n                else:\n                    mask[f0ind] = False\n\n    unusedind = torch.arange(len(mask)).to(mask.device)[~mask]\n    unlinkedmask = mapping_ind == -1\n    mapping_ind[unlinkedmask] = unusedind\n    return mapping_ind, unlinkedmask\n\n\n@torch.no_grad()\ndef get_mapping_ind(bwd_flows, bwd_occs, imgs, scale=1.0):\n    \"\"\"\n    FLATTEN: Optical fLow-guided attention (Temporal-guided attention)\n    Find pixel correspondence between every pair of consecutive frames in a batch\n\n    [input]\n    bwd_flows: (N-1)*2*H*W\n    bwd_occs: (N-1)*H*W\n    imgs: N*3*H*W\n\n    [output]\n    fwd_mappings: N*1*HW\n    bwd_mappings: N*1*HW\n    flattn_mask: HW*1*N*N\n    i.e., imgs[i,:,fwd_mappings[i]] corresponds to imgs[0]\n    i.e., imgs[i,:,fwd_mappings[i]][:,bwd_mappings[i]] restores the original imgs[i]\n    \"\"\"\n    N, H, W = imgs.shape[0], int(imgs.shape[2] // scale), int(imgs.shape[3] // scale)\n    iterattn_mask = torch.ones(H * W, N, N, dtype=torch.bool).to(imgs.device)\n    for i in range(len(imgs) - 1):\n        one_mask = torch.ones(N, N, dtype=torch.bool).to(imgs.device)\n        one_mask[: i + 1, i + 1 :] = False\n        one_mask[i + 1 :, : i + 1] = False\n        mapping_ind, unlinkedmask = get_single_mapping_ind(\n            bwd_flows[i : i + 1], bwd_occs[i : i + 1], imgs[i : i + 2], scale\n        )\n        if i 
== 0:\n            fwd_mapping = [torch.arange(len(mapping_ind)).to(mapping_ind.device)]\n            bwd_mapping = [torch.arange(len(mapping_ind)).to(mapping_ind.device)]\n        iterattn_mask[unlinkedmask[fwd_mapping[-1]]] = torch.logical_and(\n            iterattn_mask[unlinkedmask[fwd_mapping[-1]]], one_mask\n        )\n        fwd_mapping += [mapping_ind[fwd_mapping[-1]]]\n        bwd_mapping += [torch.sort(fwd_mapping[-1])[1]]\n    fwd_mappings = torch.stack(fwd_mapping, dim=0).unsqueeze(1)\n    bwd_mappings = torch.stack(bwd_mapping, dim=0).unsqueeze(1)\n    return fwd_mappings, bwd_mappings, iterattn_mask.unsqueeze(1)\n\n\ndef apply_FRESCO_opt(\n    pipe,\n    steps=[],\n    layers=[0, 1, 2, 3],\n    flows=None,\n    occs=None,\n    correlation_matrix=[],\n    intra_weight=1e2,\n    iters=20,\n    optimize_temporal=True,\n    saliency=None,\n):\n    \"\"\"\n    Apply FRESCO-based optimization to a StableDiffusionPipeline\n    \"\"\"\n    pipe.unet.forward = my_forward(\n        pipe.unet, steps, layers, flows, occs, correlation_matrix, intra_weight, iters, optimize_temporal, saliency\n    )\n\n\n@torch.no_grad()\ndef get_intraframe_paras(pipe, imgs, frescoProc, prompt_embeds, do_classifier_free_guidance=True, generator=None):\n    \"\"\"\n    Get parameters for spatial-guided attention and optimization\n    * perform one step denoising\n    * collect attention feature, stored in frescoProc.controller.stored_attn['decoder_attn']\n    * compute the gram matrix of the normalized feature for spatial consistency loss\n    \"\"\"\n\n    noise_scheduler = pipe.scheduler\n    timestep = noise_scheduler.timesteps[-1]\n    device = pipe._execution_device\n    B, C, H, W = imgs.shape\n\n    frescoProc.controller.disable_controller()\n    apply_FRESCO_opt(pipe)\n    frescoProc.controller.clear_store()\n    frescoProc.controller.enable_store()\n\n    latents = pipe.prepare_latents(\n        imgs.to(pipe.unet.dtype), timestep, B, 1, prompt_embeds.dtype, device, 
generator=generator, repeat_noise=False\n    )\n\n    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n    model_output = pipe.unet(\n        latent_model_input,\n        timestep,\n        encoder_hidden_states=prompt_embeds,\n        cross_attention_kwargs=None,\n        return_dict=False,\n    )\n\n    frescoProc.controller.disable_store()\n\n    # gram matrix of the normalized feature for spatial consistency loss\n    correlation_matrix = []\n    for tmp in model_output[1:]:\n        latent_vector = rearrange(tmp, \"b c h w -> b (h w) c\")\n        latent_vector = latent_vector / ((latent_vector**2).sum(dim=2, keepdims=True) ** 0.5)\n        attention_probs = torch.bmm(latent_vector, latent_vector.transpose(-1, -2))\n        correlation_matrix += [attention_probs.detach().clone().to(torch.float32)]\n        del attention_probs, latent_vector, tmp\n    del model_output\n\n    clear_cache()\n\n    return correlation_matrix\n\n\n@torch.no_grad()\ndef get_flow_and_interframe_paras(flow_model, imgs):\n    \"\"\"\n    Get parameters for temporal-guided attention and optimization\n    * predict optical flow and occlusion mask\n    * compute pixel index correspondence for FLATTEN\n    \"\"\"\n    images = torch.stack([torch.from_numpy(img).permute(2, 0, 1).float() for img in imgs], dim=0).cuda()\n    imgs_torch = torch.cat([numpy2tensor(img) for img in imgs], dim=0)\n\n    reshuffle_list = list(range(1, len(images))) + [0]\n\n    results_dict = flow_model(\n        images,\n        images[reshuffle_list],\n        attn_splits_list=[2],\n        corr_radius_list=[-1],\n        prop_radius_list=[-1],\n        pred_bidir_flow=True,\n    )\n    flow_pr = results_dict[\"flow_preds\"][-1]  # [2*B, 2, H, W]\n    fwd_flows, bwd_flows = flow_pr.chunk(2)  # [B, 2, H, W]\n    fwd_occs, bwd_occs = forward_backward_consistency_check(fwd_flows, bwd_flows)  # [B, H, W]\n\n    warped_image1 = flow_warp(images, bwd_flows)\n    bwd_occs = 
torch.clamp(\n        bwd_occs + (abs(images[reshuffle_list] - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0, 1\n    )\n\n    warped_image2 = flow_warp(images[reshuffle_list], fwd_flows)\n    fwd_occs = torch.clamp(fwd_occs + (abs(images - warped_image2).mean(dim=1) > 255 * 0.25).float(), 0, 1)\n\n    attn_mask = []\n    for scale in [8.0, 16.0, 32.0]:\n        bwd_occs_ = F.interpolate(bwd_occs[:-1].unsqueeze(1), scale_factor=1.0 / scale, mode=\"bilinear\")\n        attn_mask += [\n            torch.cat((bwd_occs_[0:1].reshape(1, -1) > -1, bwd_occs_.reshape(bwd_occs_.shape[0], -1) > 0.5), dim=0)\n        ]\n\n    fwd_mappings = []\n    bwd_mappings = []\n    interattn_masks = []\n    for scale in [8.0, 16.0]:\n        fwd_mapping, bwd_mapping, interattn_mask = get_mapping_ind(bwd_flows, bwd_occs, imgs_torch, scale=scale)\n        fwd_mappings += [fwd_mapping]\n        bwd_mappings += [bwd_mapping]\n        interattn_masks += [interattn_mask]\n\n    interattn_paras = {}\n    interattn_paras[\"fwd_mappings\"] = fwd_mappings\n    interattn_paras[\"bwd_mappings\"] = bwd_mappings\n    interattn_paras[\"interattn_masks\"] = interattn_masks\n\n    clear_cache()\n\n    return [fwd_flows, bwd_flows], [fwd_occs, bwd_occs], attn_mask, interattn_paras\n\n\nclass AttentionControl:\n    \"\"\"\n    Control FRESCO-based attention\n    * enable/disable spatial-guided attention\n    * enable/disable temporal-guided attention\n    * enable/disable cross-frame attention\n    * collect intermediate attention feature (for spatial-guided attention)\n    \"\"\"\n\n    def __init__(self):\n        self.stored_attn = self.get_empty_store()\n        self.store = False\n        self.index = 0\n        self.attn_mask = None\n        self.interattn_paras = None\n        self.use_interattn = False\n        self.use_cfattn = False\n        self.use_intraattn = False\n        self.intraattn_bias = 0\n        self.intraattn_scale_factor = 0.2\n        self.interattn_scale_factor = 0.2\n\n   
 @staticmethod\n    def get_empty_store():\n        return {\n            \"decoder_attn\": [],\n        }\n\n    def clear_store(self):\n        del self.stored_attn\n        torch.cuda.empty_cache()\n        gc.collect()\n        self.stored_attn = self.get_empty_store()\n        self.disable_intraattn()\n\n    # store attention feature of the input frame for spatial-guided attention\n    def enable_store(self):\n        self.store = True\n\n    def disable_store(self):\n        self.store = False\n\n    # spatial-guided attention\n    def enable_intraattn(self):\n        self.index = 0\n        self.use_intraattn = True\n        self.disable_store()\n        if len(self.stored_attn[\"decoder_attn\"]) == 0:\n            self.use_intraattn = False\n\n    def disable_intraattn(self):\n        self.index = 0\n        self.use_intraattn = False\n        self.disable_store()\n\n    def disable_cfattn(self):\n        self.use_cfattn = False\n\n    # cross frame attention\n    def enable_cfattn(self, attn_mask=None):\n        if attn_mask:\n            if self.attn_mask:\n                del self.attn_mask\n                torch.cuda.empty_cache()\n            self.attn_mask = attn_mask\n            self.use_cfattn = True\n        else:\n            if self.attn_mask:\n                self.use_cfattn = True\n            else:\n                print(\"Warning: no valid cross-frame attention parameters available!\")\n                self.disable_cfattn()\n\n    def disable_interattn(self):\n        self.use_interattn = False\n\n    # temporal-guided attention\n    def enable_interattn(self, interattn_paras=None):\n        if interattn_paras:\n            if self.interattn_paras:\n                del self.interattn_paras\n                torch.cuda.empty_cache()\n            self.interattn_paras = interattn_paras\n            self.use_interattn = True\n        else:\n            if self.interattn_paras:\n                self.use_interattn = True\n            else:\n        
        print(\"Warning: no valid temporal-guided attention parameters available!\")\n                self.disable_interattn()\n\n    def disable_controller(self):\n        self.disable_intraattn()\n        self.disable_interattn()\n        self.disable_cfattn()\n\n    def enable_controller(self, interattn_paras=None, attn_mask=None):\n        self.enable_intraattn()\n        self.enable_interattn(interattn_paras)\n        self.enable_cfattn(attn_mask)\n\n    def forward(self, context):\n        if self.store:\n            self.stored_attn[\"decoder_attn\"].append(context.detach())\n        if self.use_intraattn and len(self.stored_attn[\"decoder_attn\"]) > 0:\n            tmp = self.stored_attn[\"decoder_attn\"][self.index]\n            self.index = self.index + 1\n            if self.index >= len(self.stored_attn[\"decoder_attn\"]):\n                self.index = 0\n                self.disable_store()\n            return tmp\n        return context\n\n    def __call__(self, context):\n        context = self.forward(context)\n        return context\n\n\nclass FRESCOAttnProcessor2_0:\n    \"\"\"\n    Hack self attention to FRESCO-based attention\n    * adding spatial-guided attention\n    * adding temporal-guided attention\n    * adding cross-frame attention\n\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    Usage:\n    frescoProc = FRESCOAttnProcessor2_0(2, AttentionControl())\n    attnProc = AttnProcessor2_0()\n\n    attn_processor_dict = {}\n    for k in pipe.unet.attn_processors.keys():\n        if k.startswith(\"up_blocks.2\") or k.startswith(\"up_blocks.3\"):\n            attn_processor_dict[k] = frescoProc\n        else:\n            attn_processor_dict[k] = attnProc\n    pipe.unet.set_attn_processor(attn_processor_dict)\n    \"\"\"\n\n    def __init__(self, unet_chunk_size=2, controller=None):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise 
ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n        self.unet_chunk_size = unet_chunk_size\n        self.controller = controller\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        crossattn = False\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n            if self.controller and self.controller.store:\n                self.controller(hidden_states.detach().clone())\n        else:\n            crossattn = True\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        # BC * HW * 8D\n        key = attn.to_k(encoder_hidden_states)\n        value = 
attn.to_v(encoder_hidden_states)\n\n        query_raw, key_raw = None, None\n        if self.controller and self.controller.use_interattn and (not crossattn):\n            query_raw, key_raw = query.clone(), key.clone()\n\n        inner_dim = key.shape[-1]  # 8D\n        head_dim = inner_dim // attn.heads  # D\n\n        \"\"\"for efficient cross-frame attention\"\"\"\n        if self.controller and self.controller.use_cfattn and (not crossattn):\n            video_length = key.size()[0] // self.unet_chunk_size\n            former_frame_index = [0] * video_length\n            attn_mask = None\n            if self.controller.attn_mask is not None:\n                for m in self.controller.attn_mask:\n                    if m.shape[1] == key.shape[1]:\n                        attn_mask = m\n            # BC * HW * 8D --> B * C * HW * 8D\n            key = rearrange(key, \"(b f) d c -> b f d c\", f=video_length)\n            # B * C * HW * 8D --> B * C * HW * 8D\n            if attn_mask is None:\n                key = key[:, former_frame_index]\n            else:\n                key = repeat(key[:, attn_mask], \"b d c -> b f d c\", f=video_length)\n            # B * C * HW * 8D --> BC * HW * 8D\n            key = rearrange(key, \"b f d c -> (b f) d c\").detach()\n            value = rearrange(value, \"(b f) d c -> b f d c\", f=video_length)\n            if attn_mask is None:\n                value = value[:, former_frame_index]\n            else:\n                value = repeat(value[:, attn_mask], \"b d c -> b f d c\", f=video_length)\n            value = rearrange(value, \"b f d c -> (b f) d c\").detach()\n\n        # BC * HW * 8D --> BC * HW * 8 * D --> BC * 8 * HW * D\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        # BC * 8 * HW2 * D\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        # BC * 8 * HW2 * D2\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 
2)\n\n        \"\"\"for spatial-guided intra-frame attention\"\"\"\n        if self.controller and self.controller.use_intraattn and (not crossattn):\n            ref_hidden_states = self.controller(None)\n            assert ref_hidden_states.shape == encoder_hidden_states.shape\n            query_ = attn.to_q(ref_hidden_states)\n            key_ = attn.to_k(ref_hidden_states)\n\n            # BC * 8 * HW * D\n            query_ = query_.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n            key_ = key_.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n            query = F.scaled_dot_product_attention(\n                query_,\n                key_ * self.controller.intraattn_scale_factor,\n                query,\n                attn_mask=torch.eye(query_.size(-2), key_.size(-2), dtype=query.dtype, device=query.device)\n                * self.controller.intraattn_bias,\n            ).detach()\n\n            del query_, key_\n            torch.cuda.empty_cache()\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        # output: BC * 8 * HW * D2\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        \"\"\"for temporal-guided inter-frame attention (FLATTEN)\"\"\"\n        if self.controller and self.controller.use_interattn and (not crossattn):\n            del query, key, value\n            torch.cuda.empty_cache()\n            bwd_mapping = None\n            fwd_mapping = None\n            for i, f in enumerate(self.controller.interattn_paras[\"fwd_mappings\"]):\n                if f.shape[2] == hidden_states.shape[2]:\n                    fwd_mapping = f\n                    bwd_mapping = self.controller.interattn_paras[\"bwd_mappings\"][i]\n                    interattn_mask = self.controller.interattn_paras[\"interattn_masks\"][i]\n 
           video_length = key_raw.size()[0] // self.unet_chunk_size\n            # BC * HW * 8D --> C * 8BD * HW\n            key = rearrange(key_raw, \"(b f) d c -> f (b c) d\", f=video_length)\n            query = rearrange(query_raw, \"(b f) d c -> f (b c) d\", f=video_length)\n            # BC * 8 * HW * D --> C * 8BD * HW\n            # key = rearrange(hidden_states, \"(b f) h d c -> f (b h c) d\", f=video_length) ########\n            # query = rearrange(hidden_states, \"(b f) h d c -> f (b h c) d\", f=video_length) #######\n\n            value = rearrange(hidden_states, \"(b f) h d c -> f (b h c) d\", f=video_length)\n            key = torch.gather(key, 2, fwd_mapping.expand(-1, key.shape[1], -1))\n            query = torch.gather(query, 2, fwd_mapping.expand(-1, query.shape[1], -1))\n            value = torch.gather(value, 2, fwd_mapping.expand(-1, value.shape[1], -1))\n            # C * 8BD * HW --> BHW, C, 8D\n            key = rearrange(key, \"f (b c) d -> (b d) f c\", b=self.unet_chunk_size)\n            query = rearrange(query, \"f (b c) d -> (b d) f c\", b=self.unet_chunk_size)\n            value = rearrange(value, \"f (b c) d -> (b d) f c\", b=self.unet_chunk_size)\n            # BHW * C * 8D --> BHW * C * 8 * D--> BHW * 8 * C * D\n            query = query.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()\n            key = key.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()\n            value = value.view(-1, video_length, attn.heads, head_dim).transpose(1, 2).detach()\n            hidden_states_ = F.scaled_dot_product_attention(\n                query,\n                key * self.controller.interattn_scale_factor,\n                value,\n                # .to(query.dtype)-1.0) * 1e6 -\n                attn_mask=(interattn_mask.repeat(self.unet_chunk_size, 1, 1, 1)),\n                # torch.eye(interattn_mask.shape[2]).to(query.device).to(query.dtype) * 1e4,\n            )\n\n            # BHW * 8 * C * D 
--> C * 8BD * HW\n            hidden_states_ = rearrange(hidden_states_, \"(b d) h f c -> f (b h c) d\", b=self.unet_chunk_size)\n            hidden_states_ = torch.gather(\n                hidden_states_, 2, bwd_mapping.expand(-1, hidden_states_.shape[1], -1)\n            ).detach()\n            # C * 8BD * HW --> BC * 8 * HW * D\n            hidden_states = rearrange(\n                hidden_states_, \"f (b h c) d -> (b f) h d c\", b=self.unet_chunk_size, h=attn.heads\n            )\n\n        # BC * 8 * HW * D --> BC * HW * 8D\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\ndef apply_FRESCO_attn(pipe):\n    \"\"\"\n    Apply FRESCO-guided attention to a StableDiffusionPipeline\n    \"\"\"\n    frescoProc = FRESCOAttnProcessor2_0(2, AttentionControl())\n    attnProc = AttnProcessor2_0()\n    attn_processor_dict = {}\n    for k in pipe.unet.attn_processors.keys():\n        if k.startswith(\"up_blocks.2\") or k.startswith(\"up_blocks.3\"):\n            attn_processor_dict[k] = frescoProc\n        else:\n            attn_processor_dict[k] = attnProc\n    pipe.unet.set_attn_processor(attn_processor_dict)\n    return frescoProc\n\n\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return 
encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\ndef prepare_image(image):\n    if isinstance(image, torch.Tensor):\n        # Batch single image\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        image = image.to(dtype=torch.float32)\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    return image\n\n\nclass FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):\n    r\"\"\"\n    Pipeline for video-to-video translation using Stable Diffusion with FRESCO Algorithm.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the `unet` during the denoising process. If you set multiple\n            ControlNets as a list, the outputs from each ControlNet are added together to create one combined\n            additional conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__(\n            vae,\n            text_encoder,\n            tokenizer,\n            unet,\n            controlnet,\n            scheduler,\n            safety_checker,\n            feature_extractor,\n            image_encoder,\n            requires_safety_checker,\n        )\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing 
`safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n        frescoProc = 
FRESCOAttnProcessor2_0(2, AttentionControl())\n        attnProc = AttnProcessor2_0()\n        attn_processor_dict = {}\n        for k in self.unet.attn_processors.keys():\n            if k.startswith(\"up_blocks.2\") or k.startswith(\"up_blocks.3\"):\n                attn_processor_dict[k] = frescoProc\n            else:\n                attn_processor_dict[k] = attnProc\n        self.unet.set_attn_processor(attn_processor_dict)\n        self.frescoProc = frescoProc\n\n        flow_model = GMFlow(\n            feature_channels=128,\n            num_scales=1,\n            upsample_factor=8,\n            num_head=1,\n            attention_type=\"swin\",\n            ffn_dim_expansion=4,\n            num_transformer_layers=6,\n        ).to(self.device)\n\n        checkpoint = torch.utils.model_zoo.load_url(\n            \"https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth\",\n            map_location=lambda storage, loc: storage,\n        )\n        weights = checkpoint[\"model\"] if \"model\" in checkpoint else checkpoint\n        flow_model.load_state_dict(weights, strict=False)\n        flow_model.eval()\n        self.flow_model = flow_model\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. 
Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != 
len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            repeat_dims = [1]\n            image_embeds = []\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                    single_negative_image_embeds = single_negative_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * 
len(single_negative_image_embeds.shape[1:]))\n                    )\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                else:\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                image_embeds.append(single_image_embeds)\n\n        return image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) 
instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if callback_steps is not None and (not 
isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. The conditionings will be fixed across the prompts.\"\n                )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in image):\n                raise ValueError(\"Only a single batch of multiple conditionings is supported at the moment.\")\n            elif len(image) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets.\"\n                )\n\n            for image_ in image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"Only a single batch of multiple conditionings is supported at the moment.\")\n                elif len(controlnet_conditioning_scale) != len(self.controlnet.nets):\n                    raise ValueError(\n                        \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                        \" the same length as the number of controlnets\"\n                    )\n        else:\n            
assert False\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. 
Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and 
isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, 
\"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents\n    def prepare_latents(\n        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, repeat_noise, generator=None\n    ):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                init_latents = [\n                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                    for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            deprecation_message = (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        if repeat_noise:\n            noise = randn_tensor((1, *shape[1:]), generator=generator, device=device, dtype=dtype)\n            one_tuple = (1,) * (len(shape) - 1)\n            noise = noise.repeat(batch_size, *one_tuple)\n        else:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        frames: Union[List[np.ndarray], torch.FloatTensor] = None,\n        control_frames: Union[List[np.ndarray], torch.FloatTensor] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        end_opt_step=15,\n        num_intraattn_steps=1,\n        
step_interattn_end=350,\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            frames (`List[np.ndarray]` or `torch.FloatTensor`): The input images to be used as the starting point for the image generation process.\n            control_frames (`List[np.ndarray]` or `torch.FloatTensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            strength (`float`, *optional*, defaults to 0.8):\n                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what not to include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. 
If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            end_opt_step:\n                The feature optimization is activated from strength * num_inference_step to end_opt_step.\n            num_intraattn_steps:\n                Apply num_interattn_steps steps of spatial-guided attention.\n            step_interattn_end:\n                Apply temporal-guided attention in [step_interattn_end, 1000] steps\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            
control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            control_frames[0],\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. Define call parameters\n        batch_size = len(frames)\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n        prompt_embeds = prompt_embeds.repeat(batch_size, 1, 1)\n        negative_prompt_embeds = negative_prompt_embeds.repeat(batch_size, 1, 1)\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. Prepare image\n        imgs_np = []\n        for frame in frames:\n            if isinstance(frame, PIL.Image.Image):\n                imgs_np.append(np.asarray(frame))\n            else:\n                # np.ndarray\n                imgs_np.append(frame)\n        images_pt = self.image_processor.preprocess(frames).to(dtype=torch.float32)\n\n        # 5. 
Prepare controlnet_conditioning_image\n        if isinstance(controlnet, ControlNetModel):\n            control_image = self.prepare_control_image(\n                image=control_frames,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n        elif isinstance(controlnet, MultiControlNetModel):\n            control_images = []\n\n            for control_image_ in control_frames:\n                control_image_ = self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            control_image = control_images\n        else:\n            assert False\n\n        self.flow_model.to(device)\n\n        flows, occs, attn_mask, interattn_paras = get_flow_and_interframe_paras(self.flow_model, imgs_np)\n        correlation_matrix = get_intraframe_paras(self, images_pt, self.frescoProc, prompt_embeds, generator)\n\n        \"\"\"\n        Flexible settings for attention:\n        * Turn off FRESCO-guided attention: frescoProc.controller.disable_controller()\n        Then you can turn on one specific attention submodule\n        * Turn on Cross-frame attention: frescoProc.controller.enable_cfattn(attn_mask)\n        * Turn on Spatial-guided 
attention: frescoProc.controller.enable_intraattn()\n        * Turn on Temporal-guided attention: frescoProc.controller.enable_interattn(interattn_paras)\n\n        Flexible settings for optimization:\n        * Turn off Spatial-guided optimization: set optimize_temporal = False in apply_FRESCO_opt()\n        * Turn off Temporal-guided optimization: set correlation_matrix = [] in apply_FRESCO_opt()\n        * Turn off FRESCO-guided optimization: disable_FRESCO_opt(pipe)\n\n        Flexible settings for background smoothing:\n        * Turn off background smoothing: set saliency = None in apply_FRESCO_opt()\n        \"\"\"\n\n        self.frescoProc.controller.enable_controller(interattn_paras=interattn_paras, attn_mask=attn_mask)\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n        apply_FRESCO_opt(\n            self,\n            steps=timesteps[:end_opt_step],\n            flows=flows,\n            occs=occs,\n            correlation_matrix=correlation_matrix,\n            saliency=None,\n            optimize_temporal=True,\n        )\n\n        clear_cache()\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        self._num_timesteps = len(timesteps)\n\n        # 6. Prepare latent variables\n        latents = self.prepare_latents(\n            images_pt,\n            latent_timestep,\n            batch_size,\n            num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator=generator,\n            repeat_noise=True,\n        )\n\n        # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if ip_adapter_image is not None or ip_adapter_image_embeds is not None\n            else None\n        )\n\n        # 7.2 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if i >= num_intraattn_steps:\n                    self.frescoProc.controller.disable_intraattn()\n                if t < step_interattn_end:\n                    self.frescoProc.controller.disable_interattn()\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # controlnet(s) inference\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n              
  else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=control_image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                )\n\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    
mid_block_additional_residual=mid_block_res_sample,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if 
not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/gluegen.py",
    "content": "import inspect\nfrom typing import Any, Dict, List, Optional, Union\n\nimport torch\nimport torch.nn as nn\nfrom transformers import AutoModel, AutoTokenizer, CLIPImageProcessor\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    logging,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass TranslatorBase(nn.Module):\n    def __init__(self, num_tok, dim, dim_out, mult=2):\n        super().__init__()\n\n        self.dim_in = dim\n        self.dim_out = dim_out\n\n        self.net_tok = nn.Sequential(\n            nn.Linear(num_tok, int(num_tok * mult)),\n            nn.LayerNorm(int(num_tok * mult)),\n            nn.GELU(),\n            nn.Linear(int(num_tok * mult), int(num_tok * mult)),\n            nn.LayerNorm(int(num_tok * mult)),\n            nn.GELU(),\n            nn.Linear(int(num_tok * mult), num_tok),\n            nn.LayerNorm(num_tok),\n        )\n\n        self.net_sen = nn.Sequential(\n            nn.Linear(dim, int(dim * mult)),\n            nn.LayerNorm(int(dim * mult)),\n            nn.GELU(),\n            nn.Linear(int(dim * mult), int(dim * mult)),\n            nn.LayerNorm(int(dim * mult)),\n            nn.GELU(),\n            nn.Linear(int(dim * mult), dim_out),\n            
nn.LayerNorm(dim_out),\n        )\n\n    def forward(self, x):\n        if self.dim_in == self.dim_out:\n            indentity_0 = x\n            x = self.net_sen(x)\n            x += indentity_0\n            x = x.transpose(1, 2)\n\n            indentity_1 = x\n            x = self.net_tok(x)\n            x += indentity_1\n            x = x.transpose(1, 2)\n        else:\n            x = self.net_sen(x)\n            x = x.transpose(1, 2)\n\n            x = self.net_tok(x)\n            x = x.transpose(1, 2)\n        return x\n\n\nclass TranslatorBaseNoLN(nn.Module):\n    def __init__(self, num_tok, dim, dim_out, mult=2):\n        super().__init__()\n\n        self.dim_in = dim\n        self.dim_out = dim_out\n\n        self.net_tok = nn.Sequential(\n            nn.Linear(num_tok, int(num_tok * mult)),\n            nn.GELU(),\n            nn.Linear(int(num_tok * mult), int(num_tok * mult)),\n            nn.GELU(),\n            nn.Linear(int(num_tok * mult), num_tok),\n        )\n\n        self.net_sen = nn.Sequential(\n            nn.Linear(dim, int(dim * mult)),\n            nn.GELU(),\n            nn.Linear(int(dim * mult), int(dim * mult)),\n            nn.GELU(),\n            nn.Linear(int(dim * mult), dim_out),\n        )\n\n    def forward(self, x):\n        if self.dim_in == self.dim_out:\n            indentity_0 = x\n            x = self.net_sen(x)\n            x += indentity_0\n            x = x.transpose(1, 2)\n\n            indentity_1 = x\n            x = self.net_tok(x)\n            x += indentity_1\n            x = x.transpose(1, 2)\n        else:\n            x = self.net_sen(x)\n            x = x.transpose(1, 2)\n\n            x = self.net_tok(x)\n            x = x.transpose(1, 2)\n        return x\n\n\nclass TranslatorNoLN(nn.Module):\n    def __init__(self, num_tok, dim, dim_out, mult=2, depth=5):\n        super().__init__()\n\n        self.blocks = nn.ModuleList([TranslatorBase(num_tok, dim, dim, mult=2) for d in range(depth)])\n        self.gelu 
= nn.GELU()\n\n        self.tail = TranslatorBaseNoLN(num_tok, dim, dim_out, mult=2)\n\n    def forward(self, x):\n        for block in self.blocks:\n            x = block(x) + x\n            x = self.gelu(x)\n\n        x = self.tail(x)\n        return x\n\n\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass GlueGenStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLoraLoaderMixin):\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: AutoModel,\n        tokenizer: AutoTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        language_adapter: TranslatorNoLN = None,\n        tensor_norm: torch.Tensor = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            
vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            language_adapter=language_adapter,\n            tensor_norm=tensor_norm,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def load_language_adapter(\n        self,\n        model_path: str,\n        num_token: int,\n        dim: int,\n        dim_out: int,\n        tensor_norm: torch.Tensor,\n        mult: int = 2,\n        depth: int = 5,\n    ):\n        device = self._execution_device\n        self.tensor_norm = tensor_norm.to(device)\n        self.language_adapter = TranslatorNoLN(num_tok=num_token, dim=dim, dim_out=dim_out, mult=mult, depth=depth).to(\n            device\n        )\n        self.language_adapter.load_state_dict(torch.load(model_path))\n\n    def _adapt_language(self, prompt_embeds: torch.Tensor):\n        prompt_embeds = prompt_embeds / 3\n        prompt_embeds = self.language_adapter(prompt_embeds) * (self.tensor_norm / 2)\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: 
(`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") 
and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            elif self.language_adapter is not None:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n            # Run prompt language adapter\n            if self.language_adapter is not None:\n                prompt_embeds = self._adapt_language(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" 
{prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n            # Run negative prompt language adapter\n            if self.language_adapter is not None:\n                negative_prompt_embeds = self._adapt_language(negative_prompt_embeds)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        
if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length 
{len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    
def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        clip_skip: Optional[int] = None,\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. 
Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when\n                using zero terminal SNR.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        # 0. 
Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n        # to deal with lora scaling and other possible forward hooks\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 6.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and 
self.guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # update the progress bar once per completed scheduler step\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/hd_painter.py",
    "content": "import math\nimport numbers\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.models import AsymmetricAutoencoderKL, ImageProjection\nfrom diffusers.models.attention_processor import Attention, AttnProcessor\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (\n    StableDiffusionInpaintPipeline,\n    retrieve_timesteps,\n)\nfrom diffusers.utils import deprecate\n\n\nclass RASGAttnProcessor:\n    def __init__(self, mask, token_idx, scale_factor):\n        self.attention_scores = None  # Stores the last output of the similarity matrix here. Each layer will get its own RASGAttnProcessor assigned\n        self.mask = mask\n        self.token_idx = token_idx\n        self.scale_factor = scale_factor\n        self.mask_resoltuion = mask.shape[-1] * mask.shape[-2]  # 64 x 64 if the image is 512x512\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        temb: Optional[torch.Tensor] = None,\n        scale: float = 1.0,\n    ) -> torch.Tensor:\n        # Same as the default AttnProcessor up until the part where similarity matrix gets saved\n        downscale_factor = self.mask_resoltuion // hidden_states.shape[1]\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, 
sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        # Automatically recognize the resolution and save the attention similarity values\n        # We need to use the values before the softmax function, hence the rewritten get_attention_scores function.\n        if downscale_factor == self.scale_factor**2:\n            self.attention_scores = get_attention_scores(attn, query, key, attention_mask)\n            attention_probs = self.attention_scores.softmax(dim=-1)\n            attention_probs = attention_probs.to(query.dtype)\n        else:\n            attention_probs = attn.get_attention_scores(query, key, attention_mask)  # Original code\n\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        
hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass PAIntAAttnProcessor:\n    def __init__(self, transformer_block, mask, token_idx, do_classifier_free_guidance, scale_factors):\n        self.transformer_block = transformer_block  # Stores the parent transformer block.\n        self.mask = mask\n        self.scale_factors = scale_factors\n        self.do_classifier_free_guidance = do_classifier_free_guidance\n        self.token_idx = token_idx\n        self.shape = mask.shape[2:]\n        self.mask_resoltuion = mask.shape[-1] * mask.shape[-2]  # 64 x 64\n        self.default_processor = AttnProcessor()\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        temb: Optional[torch.Tensor] = None,\n        scale: float = 1.0,\n    ) -> torch.Tensor:\n        # Automatically recognize the resolution of the current attention layer and resize the masks accordingly\n        downscale_factor = self.mask_resoltuion // hidden_states.shape[1]\n\n        mask = None\n        for factor in self.scale_factors:\n            if downscale_factor == factor**2:\n                shape = (self.shape[0] // factor, self.shape[1] // factor)\n                mask = F.interpolate(self.mask, shape, mode=\"bicubic\")  # B, 1, H, W\n                break\n        if mask is None:\n            return self.default_processor(attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale)\n\n        # STARTS HERE\n        residual = hidden_states\n        # Save the input hidden_states for later use\n        input_hidden_states = hidden_states\n\n        # ================================================== #\n        # =============== SELF ATTENTION 1 ================= #\n        # ================================================== #\n\n        if attn.spatial_norm is not 
None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        # self_attention_probs = attn.get_attention_scores(query, key, attention_mask) # We can't use post-softmax attention scores in this case\n        self_attention_scores = get_attention_scores(\n            attn, query, key, attention_mask\n        )  # The custom function returns pre-softmax probabilities\n        self_attention_probs = self_attention_scores.softmax(\n            dim=-1\n        )  # Manually compute the probabilities here, the scores will be reused in the second part of PAIntA\n        self_attention_probs = self_attention_probs.to(query.dtype)\n\n        hidden_states = torch.bmm(self_attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = 
attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        # x = x + self.attn1(self.norm1(x))\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:  # So many residuals everywhere\n            hidden_states = hidden_states + residual\n\n        self_attention_output_hidden_states = hidden_states / attn.rescale_output_factor\n\n        # ================================================== #\n        # ============ BasicTransformerBlock =============== #\n        # ================================================== #\n        # We use a hack by running the code from the BasicTransformerBlock that is between Self and Cross attentions here\n        # The other option would've been modifying the BasicTransformerBlock and adding this functionality here.\n        # I assumed that changing the BasicTransformerBlock would have been a bigger deal and decided to use this hack instead.\n\n        # The SelfAttention block receives the normalized latents from the BasicTransformerBlock,\n        # But the residual of the output is the non-normalized version.\n        # Therefore we unnormalize the input hidden state here\n        unnormalized_input_hidden_states = (\n            input_hidden_states + self.transformer_block.norm1.bias\n        ) * self.transformer_block.norm1.weight\n\n        # TODO: return if necessary\n        # if self.use_ada_layer_norm_zero:\n        #     attn_output = gate_msa.unsqueeze(1) * attn_output\n        # elif self.use_ada_layer_norm_single:\n        #     attn_output = gate_msa * attn_output\n\n        transformer_hidden_states = self_attention_output_hidden_states + unnormalized_input_hidden_states\n        if transformer_hidden_states.ndim == 4:\n            transformer_hidden_states = transformer_hidden_states.squeeze(1)\n\n        # TODO: return if necessary\n        
# 2.5 GLIGEN Control\n        # if gligen_kwargs is not None:\n        #     transformer_hidden_states = self.fuser(transformer_hidden_states, gligen_kwargs[\"objs\"])\n        # NOTE: we experimented with using GLIGEN and HDPainter together, the results were not that great\n\n        # 3. Cross-Attention\n        if self.transformer_block.use_ada_layer_norm:\n            # transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states, timestep)\n            raise NotImplementedError()\n        elif self.transformer_block.use_ada_layer_norm_zero or self.transformer_block.use_layer_norm:\n            transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states)\n        elif self.transformer_block.use_ada_layer_norm_single:\n            # For PixArt norm2 isn't applied here:\n            # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103\n            transformer_norm_hidden_states = transformer_hidden_states\n        elif self.transformer_block.use_ada_layer_norm_continuous:\n            # transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states, added_cond_kwargs[\"pooled_text_emb\"])\n            raise NotImplementedError()\n        else:\n            raise ValueError(\"Incorrect norm\")\n\n        if self.transformer_block.pos_embed is not None and self.transformer_block.use_ada_layer_norm_single is False:\n            transformer_norm_hidden_states = self.transformer_block.pos_embed(transformer_norm_hidden_states)\n\n        # ================================================== #\n        # ================= CROSS ATTENTION ================ #\n        # ================================================== #\n\n        # We do an initial pass of the CrossAttention up to obtaining the similarity matrix here.\n        # The similarity matrix is used to obtain scaling coefficients for the attention 
matrix of the self attention\n        # We reuse the previously computed self-attention matrix, and only repeat the steps after the softmax\n\n        cross_attention_input_hidden_states = (\n            transformer_norm_hidden_states  # Renaming the variable for the sake of readability\n        )\n\n        # Split off the conditional half only when classifier-free guidance has doubled the batch\n        if self.do_classifier_free_guidance:\n            # Our scaling coefficients depend only on the conditional part, so we split the inputs\n            (\n                _cross_attention_input_hidden_states_unconditional,\n                cross_attention_input_hidden_states_conditional,\n            ) = cross_attention_input_hidden_states.chunk(2)\n\n            # Same split for the encoder_hidden_states i.e. the tokens\n            # Since the SelfAttention processors don't get the encoder states as input, we inject them into the processor in the beginning.\n            _encoder_hidden_states_unconditional, encoder_hidden_states_conditional = self.encoder_hidden_states.chunk(\n                2\n            )\n        else:\n            # Without classifier-free guidance the batch holds only the conditional inputs, so nothing is split\n            cross_attention_input_hidden_states_conditional = cross_attention_input_hidden_states\n            encoder_hidden_states_conditional = self.encoder_hidden_states\n\n        # Rename the variables for the sake of readability\n        # The part below is the beginning of the __call__ function of the following CrossAttention layer\n        cross_attention_hidden_states = cross_attention_input_hidden_states_conditional\n        cross_attention_encoder_hidden_states = encoder_hidden_states_conditional\n\n        attn2 = self.transformer_block.attn2\n\n        if attn2.spatial_norm is not None:\n            cross_attention_hidden_states = attn2.spatial_norm(cross_attention_hidden_states, temb)\n\n        input_ndim = cross_attention_hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = 
cross_attention_hidden_states.shape\n            cross_attention_hidden_states = cross_attention_hidden_states.view(\n                batch_size, channel, height * width\n            ).transpose(1, 2)\n\n        (\n            batch_size,\n            sequence_length,\n            _,\n        ) = cross_attention_hidden_states.shape  # It is definitely a cross attention, so no need for an if block\n        # TODO: change the attention_mask here\n        attention_mask = attn2.prepare_attention_mask(\n            None, sequence_length, batch_size\n        )  # I assume the attention mask is the same...\n\n        if attn2.group_norm is not None:\n            cross_attention_hidden_states = attn2.group_norm(cross_attention_hidden_states.transpose(1, 2)).transpose(\n                1, 2\n            )\n\n        query2 = attn2.to_q(cross_attention_hidden_states)\n\n        if attn2.norm_cross:\n            cross_attention_encoder_hidden_states = attn2.norm_encoder_hidden_states(\n                cross_attention_encoder_hidden_states\n            )\n\n        key2 = attn2.to_k(cross_attention_encoder_hidden_states)\n        query2 = attn2.head_to_batch_dim(query2)\n        key2 = attn2.head_to_batch_dim(key2)\n\n        cross_attention_probs = attn2.get_attention_scores(query2, key2, attention_mask)\n\n        # CrossAttention ends here, the remaining part is not used\n\n        # ================================================== #\n        # ================ SELF ATTENTION 2 ================ #\n        # ================================================== #\n        # DEJA VU!\n\n        mask = (mask > 0.5).to(self_attention_output_hidden_states.dtype)\n        m = mask.to(self_attention_output_hidden_states.device)\n        # m = rearrange(m, 'b c h w -> b (h w) c').contiguous()\n        m = m.permute(0, 2, 3, 1).reshape((m.shape[0], -1, m.shape[1])).contiguous()  # B HW 1\n        m = torch.matmul(m, m.permute(0, 2, 1)) + (1 - m)\n\n        # # Compute scaling 
coefficients for the similarity matrix\n        # # Select the cross attention values for the correct tokens only!\n        # cross_attention_probs = cross_attention_probs.mean(dim = 0)\n        # cross_attention_probs = cross_attention_probs[:, self.token_idx].sum(dim=1)\n\n        # cross_attention_probs = cross_attention_probs.reshape(shape)\n        # gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(self_attention_output_hidden_states.device)\n        # cross_attention_probs = gaussian_smoothing(cross_attention_probs.unsqueeze(0))[0] # optional smoothing\n        # cross_attention_probs = cross_attention_probs.reshape(-1)\n        # cross_attention_probs = ((cross_attention_probs - torch.median(cross_attention_probs.ravel())) / torch.max(cross_attention_probs.ravel())).clip(0, 1)\n\n        # c = (1 - m) * cross_attention_probs.reshape(1, 1, -1) + m # PAIntA scaling coefficients\n\n        # Compute scaling coefficients for the similarity matrix\n        # Select the cross attention values for the correct tokens only!\n\n        batch_size, dims, channels = cross_attention_probs.shape\n        batch_size = batch_size // attn.heads\n        cross_attention_probs = cross_attention_probs.reshape((batch_size, attn.heads, dims, channels))  # B, D, HW, T\n\n        cross_attention_probs = cross_attention_probs.mean(dim=1)  # B, HW, T\n        cross_attention_probs = cross_attention_probs[..., self.token_idx].sum(dim=-1)  # B, HW\n        cross_attention_probs = cross_attention_probs.reshape((batch_size,) + shape)  # , B, H, W\n\n        gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(\n            self_attention_output_hidden_states.device\n        )\n        cross_attention_probs = gaussian_smoothing(cross_attention_probs[:, None])[:, 0]  # optional smoothing B, H, W\n\n        # Median normalization\n        cross_attention_probs = cross_attention_probs.reshape(batch_size, -1)  # B, HW\n   
     cross_attention_probs = (\n            cross_attention_probs - cross_attention_probs.median(dim=-1, keepdim=True).values\n        ) / cross_attention_probs.max(dim=-1, keepdim=True).values\n        cross_attention_probs = cross_attention_probs.clip(0, 1)\n\n        c = (1 - m) * cross_attention_probs.reshape(batch_size, 1, -1) + m\n        c = c.repeat_interleave(attn.heads, 0)  # BD, HW\n        if self.do_classifier_free_guidance:\n            c = torch.cat([c, c])  # 2BD, HW\n\n        # Rescaling the original self-attention matrix\n        self_attention_scores_rescaled = self_attention_scores * c\n        self_attention_probs_rescaled = self_attention_scores_rescaled.softmax(dim=-1)\n\n        # Continuing the self attention normally using the new matrix\n        hidden_states = torch.bmm(self_attention_probs_rescaled, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + input_hidden_states\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):\n    def get_tokenized_prompt(self, prompt):\n        out = self.tokenizer(prompt)\n        return [self.tokenizer.decode(x) for x in out[\"input_ids\"]]\n\n    def init_attn_processors(\n        self,\n        mask,\n        token_idx,\n        use_painta=True,\n        use_rasg=True,\n        painta_scale_factors=[2, 4],  # 64x64 -> [16x16, 32x32]\n        rasg_scale_factor=4,  # 64x64 -> 16x16\n        self_attention_layer_name=\"attn1\",\n        cross_attention_layer_name=\"attn2\",\n        
list_of_painta_layer_names=None,\n        list_of_rasg_layer_names=None,\n    ):\n        default_processor = AttnProcessor()\n\n        painta_scale_factors = [x * self.vae_scale_factor for x in painta_scale_factors]\n        rasg_scale_factor = self.vae_scale_factor * rasg_scale_factor\n\n        attn_processors = {}\n        for x in self.unet.attn_processors:\n            if (list_of_painta_layer_names is None and self_attention_layer_name in x) or (\n                list_of_painta_layer_names is not None and x in list_of_painta_layer_names\n            ):\n                if use_painta:\n                    transformer_block = self.unet.get_submodule(x.replace(\".attn1.processor\", \"\"))\n                    attn_processors[x] = PAIntAAttnProcessor(\n                        transformer_block, mask, token_idx, self.do_classifier_free_guidance, painta_scale_factors\n                    )\n                else:\n                    attn_processors[x] = default_processor\n            elif (list_of_rasg_layer_names is None and cross_attention_layer_name in x) or (\n                list_of_rasg_layer_names is not None and x in list_of_rasg_layer_names\n            ):\n                if use_rasg:\n                    attn_processors[x] = RASGAttnProcessor(mask, token_idx, rasg_scale_factor)\n                else:\n                    attn_processors[x] = default_processor\n\n        self.unet.set_attn_processor(attn_processors)\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        mask_image: PipelineImageInput = None,\n        
masked_image_latents: torch.Tensor = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        padding_mask_crop: Optional[int] = None,\n        strength: float = 1.0,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        positive_prompt: str | None = \"\",\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.01,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        clip_skip: int = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        use_painta=True,\n        use_rasg=True,\n        self_attention_layer_name=\".attn1\",\n        cross_attention_layer_name=\".attn2\",\n        painta_scale_factors=[2, 4],  # 16 x 16 and 32 x 32\n        rasg_scale_factor=4,  # 16x16 by default\n        list_of_painta_layer_names=None,\n        list_of_rasg_layer_names=None,\n        **kwargs,\n    ):\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n   
             \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        #\n        prompt_no_positives = prompt\n        if isinstance(prompt, list):\n            prompt = [x + positive_prompt for x in prompt]\n        else:\n            prompt = prompt + positive_prompt\n\n        # 1. Check inputs\n        self.check_inputs(\n            prompt,\n            image,\n            mask_image,\n            height,\n            width,\n            strength,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n            padding_mask_crop,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # assert batch_size == 1, \"Does not work with batch size > 1 currently\"\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None:\n            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True\n            image_embeds, negative_image_embeds = self.encode_image(\n                ip_adapter_image, device, num_images_per_prompt, output_hidden_state\n            )\n            if self.do_classifier_free_guidance:\n                image_embeds = torch.cat([negative_image_embeds, image_embeds])\n\n        # 4. 
Set timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps=num_inference_steps, strength=strength, device=device\n        )\n        # check that number of inference steps is not < 1 - as this doesn't make sense\n        if num_inference_steps < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n            )\n        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise\n        is_strength_max = strength == 1.0\n\n        # 5. Preprocess mask and image\n\n        if padding_mask_crop is not None:\n            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)\n            resize_mode = \"fill\"\n        else:\n            crops_coords = None\n            resize_mode = \"default\"\n\n        original_image = image\n        init_image = self.image_processor.preprocess(\n            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode\n        )\n        init_image = init_image.to(dtype=torch.float32)\n\n        # 6. 
Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        num_channels_unet = self.unet.config.in_channels\n        return_image_latents = num_channels_unet == 4\n\n        latents_outputs = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n            image=init_image,\n            timestep=latent_timestep,\n            is_strength_max=is_strength_max,\n            return_noise=True,\n            return_image_latents=return_image_latents,\n        )\n\n        if return_image_latents:\n            latents, noise, image_latents = latents_outputs\n        else:\n            latents, noise = latents_outputs\n\n        # 7. Prepare mask latent variables\n        mask_condition = self.mask_processor.preprocess(\n            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords\n        )\n\n        if masked_image_latents is None:\n            masked_image = init_image * (mask_condition < 0.5)\n        else:\n            masked_image = masked_image_latents\n\n        mask, masked_image_latents = self.prepare_mask_latents(\n            mask_condition,\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            self.do_classifier_free_guidance,\n        )\n\n        # 7.5 Setting up HD-Painter\n\n        # Get the indices of the tokens to be modified by both RASG and PAIntA\n        token_idx = list(range(1, self.get_tokenized_prompt(prompt_no_positives).index(\"<|endoftext|>\"))) + [\n            self.get_tokenized_prompt(prompt).index(\"<|endoftext|>\")\n        ]\n\n        # Setting up the attention processors\n        self.init_attn_processors(\n    
        mask_condition,\n            token_idx,\n            use_painta,\n            use_rasg,\n            painta_scale_factors=painta_scale_factors,\n            rasg_scale_factor=rasg_scale_factor,\n            self_attention_layer_name=self_attention_layer_name,\n            cross_attention_layer_name=cross_attention_layer_name,\n            list_of_painta_layer_names=list_of_painta_layer_names,\n            list_of_rasg_layer_names=list_of_rasg_layer_names,\n        )\n\n        # 8. Check that sizes of mask, masked image and latents match\n        if num_channels_unet == 9:\n            # default case for stable-diffusion-v1-5/stable-diffusion-inpainting\n            num_channels_mask = mask.shape[1]\n            num_channels_masked_image = masked_image_latents.shape[1]\n            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:\n                raise ValueError(\n                    f\"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects\"\n                    f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                    f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                    f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                    \" `pipeline.unet` or your `mask_image` or `image` input.\"\n                )\n        elif num_channels_unet != 4:\n            raise ValueError(\n                f\"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.\"\n            )\n\n        # 9. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        if use_rasg:\n            extra_step_kwargs[\"generator\"] = None\n\n        # 9.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = {\"image_embeds\": image_embeds} if ip_adapter_image is not None else None\n\n        # 9.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 10. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        painta_active = True\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                if t < 500 and painta_active:\n                    self.init_attn_processors(\n                        mask_condition,\n                        token_idx,\n                        False,\n                        use_rasg,\n                        painta_scale_factors=painta_scale_factors,\n                        rasg_scale_factor=rasg_scale_factor,\n                        self_attention_layer_name=self_attention_layer_name,\n                        cross_attention_layer_name=cross_attention_layer_name,\n                        list_of_painta_layer_names=list_of_painta_layer_names,\n                        list_of_rasg_layer_names=list_of_rasg_layer_names,\n                    )\n                    painta_active = False\n\n                
with torch.enable_grad():\n                    self.unet.zero_grad()\n                    latents = latents.detach()\n                    latents.requires_grad = True\n\n                    # expand the latents if we are doing classifier free guidance\n                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                    # concat latents, mask, masked_image_latents in the channel dimension\n                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                    if num_channels_unet == 9:\n                        latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n                    self.scheduler.latents = latents\n                    self.encoder_hidden_states = prompt_embeds\n                    for attn_processor in self.unet.attn_processors.values():\n                        attn_processor.encoder_hidden_states = prompt_embeds\n\n                    # predict the noise residual\n                    noise_pred = self.unet(\n                        latent_model_input,\n                        t,\n                        encoder_hidden_states=prompt_embeds,\n                        timestep_cond=timestep_cond,\n                        cross_attention_kwargs=self.cross_attention_kwargs,\n                        added_cond_kwargs=added_cond_kwargs,\n                        return_dict=False,\n                    )[0]\n\n                    # perform guidance\n                    if self.do_classifier_free_guidance:\n                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                    if use_rasg:\n                        # Perform RASG\n                        _, _, height, width = mask_condition.shape  # 512 x 512\n                        scale_factor = 
self.vae_scale_factor * rasg_scale_factor  # 8 * 4 = 32\n\n                        # TODO: Fix for > 1 batch_size\n                        rasg_mask = F.interpolate(\n                            mask_condition, (height // scale_factor, width // scale_factor), mode=\"bicubic\"\n                        )[0, 0]  # mode is nearest by default, B, H, W\n\n                        # Aggregate the saved attention maps\n                        attn_map = []\n                        for processor in self.unet.attn_processors.values():\n                            if hasattr(processor, \"attention_scores\") and processor.attention_scores is not None:\n                                if self.do_classifier_free_guidance:\n                                    attn_map.append(processor.attention_scores.chunk(2)[1])  # (B/2) x H, 256, 77\n                                else:\n                                    attn_map.append(processor.attention_scores)  # B x H, 256, 77 ?\n\n                        attn_map = (\n                            torch.cat(attn_map)\n                            .mean(0)\n                            .permute(1, 0)\n                            .reshape((-1, height // scale_factor, width // scale_factor))\n                        )  # 77, 16, 16\n\n                        # Compute the attention score\n                        attn_score = -sum(\n                            [\n                                F.binary_cross_entropy_with_logits(x - 1.0, rasg_mask.to(device))\n                                for x in attn_map[token_idx]\n                            ]\n                        )\n\n                        # Backward the score and compute the gradients\n                        attn_score.backward()\n\n                        # Normalzie the gradients and compute the noise component\n                        variance_noise = latents.grad.detach()\n                        # print(\"VARIANCE SHAPE\", variance_noise.shape)\n                        
variance_noise -= torch.mean(variance_noise, [1, 2, 3], keepdim=True)\n                        variance_noise /= torch.std(variance_noise, [1, 2, 3], keepdim=True)\n                    else:\n                        variance_noise = None\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(\n                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False, variance_noise=variance_noise\n                )[0]\n\n                if num_channels_unet == 4:\n                    init_latents_proper = image_latents\n                    if self.do_classifier_free_guidance:\n                        init_mask, _ = mask.chunk(2)\n                    else:\n                        init_mask = mask\n\n                    if i < len(timesteps) - 1:\n                        noise_timestep = timesteps[i + 1]\n                        init_latents_proper = self.scheduler.add_noise(\n                            init_latents_proper, noise, torch.tensor([noise_timestep])\n                        )\n\n                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    mask = callback_outputs.pop(\"mask\", mask)\n                    masked_image_latents = callback_outputs.pop(\"masked_image_latents\", masked_image_latents)\n\n                # call the callback, if provided\n     
           if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            condition_kwargs = {}\n            if isinstance(self.vae, AsymmetricAutoencoderKL):\n                init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)\n                init_image_condition = init_image.clone()\n                init_image = self._encode_vae_image(init_image, generator=generator)\n                mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)\n                condition_kwargs = {\"image\": init_image_condition, \"mask\": mask_condition}\n            image = self.vae.decode(\n                latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs\n            )[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        if padding_mask_crop is not None:\n            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, 
nsfw_content_detected=has_nsfw_concept)\n\n\n# ============= Utility Functions ============== #\n\n\nclass GaussianSmoothing(nn.Module):\n    \"\"\"\n    Apply gaussian smoothing on a\n    1d, 2d or 3d tensor. Filtering is performed separately for each channel\n    in the input using a depthwise convolution.\n\n    Args:\n        channels (`int` or `sequence`):\n            Number of channels of the input tensors. The output will have this number of channels as well.\n        kernel_size (`int` or `sequence`):\n            Size of the Gaussian kernel.\n        sigma (`float` or `sequence`):\n            Standard deviation of the Gaussian kernel.\n        dim (`int`, *optional*, defaults to `2`):\n            The number of dimensions of the data. Default is 2 (spatial dimensions).\n    \"\"\"\n\n    def __init__(self, channels, kernel_size, sigma, dim=2):\n        super(GaussianSmoothing, self).__init__()\n        if isinstance(kernel_size, numbers.Number):\n            kernel_size = [kernel_size] * dim\n        if isinstance(sigma, numbers.Number):\n            sigma = [sigma] * dim\n\n        # The gaussian kernel is the product of the\n        # gaussian function of each dimension.\n        kernel = 1\n        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])\n        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):\n            mean = (size - 1) / 2\n            # Gaussian density exp(-(x - mean)^2 / (2 * std^2)), so that `sigma` is the standard deviation\n            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)\n\n        # Make sure sum of values in gaussian kernel equals 1.\n        kernel = kernel / torch.sum(kernel)\n\n        # Reshape to depthwise convolutional weight\n        kernel = kernel.view(1, 1, *kernel.size())\n        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))\n\n        self.register_buffer(\"weight\", kernel)\n        self.groups = channels\n\n        if dim == 1:\n            self.conv = F.conv1d\n        elif dim == 2:\n        
    self.conv = F.conv2d\n        elif dim == 3:\n            self.conv = F.conv3d\n        else:\n            raise RuntimeError(\"Only 1, 2 and 3 dimensions are supported. Received {}.\".format(dim))\n\n    def forward(self, input):\n        \"\"\"\n        Apply gaussian filter to input.\n\n        Args:\n            input (`torch.Tensor` of shape `(N, C, H, W)`):\n                Input to apply Gaussian filter on.\n\n        Returns:\n            `torch.Tensor`:\n                The filtered output tensor with the same shape as the input.\n        \"\"\"\n        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding=\"same\")\n\n\ndef get_attention_scores(\n    self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None\n) -> torch.Tensor:\n    r\"\"\"\n    Compute the attention scores.\n\n    Args:\n        query (`torch.Tensor`): The query tensor.\n        key (`torch.Tensor`): The key tensor.\n        attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.\n\n    Returns:\n        `torch.Tensor`: The attention probabilities/scores.\n    \"\"\"\n    if self.upcast_attention:\n        query = query.float()\n        key = key.float()\n\n    if attention_mask is None:\n        baddbmm_input = torch.empty(\n            query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device\n        )\n        beta = 0\n    else:\n        baddbmm_input = attention_mask\n        beta = 1\n\n    attention_scores = torch.baddbmm(\n        baddbmm_input,\n        query,\n        key.transpose(-1, -2),\n        beta=beta,\n        alpha=self.scale,\n    )\n    del baddbmm_input\n\n    if self.upcast_softmax:\n        attention_scores = attention_scores.float()\n\n    return attention_scores\n"
  },
  {
    "path": "examples/community/iadb.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nimport torch\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.configuration_utils import ConfigMixin\nfrom diffusers.pipelines.pipeline_utils import ImagePipelineOutput\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\n\n\nclass IADBScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"\n    IADBScheduler is a scheduler for the Iterative α-(de)Blending denoising method. It is simple and minimalist.\n\n    For more details, see the original paper: https://huggingface.co/papers/2305.03486 and the blog post: https://ggx-research.github.io/publication/2023/05/10/publication-iadb.html\n    \"\"\"\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timestep: int,\n        x_alpha: torch.Tensor,\n    ) -> torch.Tensor:\n        \"\"\"\n        Predict the sample at the previous timestep by reversing the ODE. Core function to propagate the diffusion\n        process from the learned model outputs (most often the predicted noise).\n\n        Args:\n            model_output (`torch.Tensor`): direct output from learned diffusion model. 
It is the direction from x0 to x1.\n            timestep (`int`): index of the current step in the diffusion chain.\n            x_alpha (`torch.Tensor`): x_alpha sample for the current timestep\n\n        Returns:\n            `torch.Tensor`: the sample at the previous timestep\n\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n            )\n\n        alpha = timestep / self.num_inference_steps\n        alpha_next = (timestep + 1) / self.num_inference_steps\n\n        d = model_output\n\n        x_alpha = x_alpha + (alpha_next - alpha) * d\n\n        return x_alpha\n\n    def set_timesteps(self, num_inference_steps: int):\n        self.num_inference_steps = num_inference_steps\n\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        alpha: torch.Tensor,\n    ) -> torch.Tensor:\n        return original_samples * alpha + noise * (1 - alpha)\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n\n\nclass IADBPipeline(DiffusionPipeline):\n    r\"\"\"\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Parameters:\n        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image. 
Can be one of\n            [`DDPMScheduler`] or [`DDIMScheduler`].\n    \"\"\"\n\n    def __init__(self, unet, scheduler):\n        super().__init__()\n\n        self.register_modules(unet=unet, scheduler=scheduler)\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        batch_size: int = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        num_inference_steps: int = 50,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n    ) -> Union[ImagePipelineOutput, Tuple]:\n        r\"\"\"\n        Args:\n            batch_size (`int`, *optional*, defaults to 1):\n                The number of images to generate.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A torch generator to make generation deterministic.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is\n            True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        # Sample gaussian noise to begin loop\n        if isinstance(self.unet.config.sample_size, int):\n            image_shape = (\n                batch_size,\n                self.unet.config.in_channels,\n                self.unet.config.sample_size,\n                self.unet.config.sample_size,\n            )\n        else:\n            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        image = torch.randn(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)\n\n        # set step values\n        self.scheduler.set_timesteps(num_inference_steps)\n        x_alpha = image.clone()\n        for t in self.progress_bar(range(num_inference_steps)):\n            alpha = t / num_inference_steps\n\n            # 1. predict noise model_output\n            model_output = self.unet(x_alpha, torch.tensor(alpha, device=x_alpha.device)).sample\n\n            # 2. step\n            x_alpha = self.scheduler.step(model_output, t, x_alpha)\n\n        image = (x_alpha * 0.5 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image,)\n\n        return ImagePipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/imagic_stable_diffusion.py",
    "content": "\"\"\"\nmodeled after the textual_inversion.py / train_dreambooth.py and the work\nof justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb\n\"\"\"\n\nimport inspect\nimport warnings\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\nfrom diffusers.utils import logging\n\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef preprocess(image):\n    w, h = image.size\n    w, h = (x - x % 32 for x in (w, 
h))  # resize to integer multiple of 32\n    image = image.resize((w, h), resample=PIL_INTERPOLATION[\"lanczos\"])\n    image = np.array(image).astype(np.float32) / 255.0\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image)\n    return 2.0 * image - 1.0\n\n\nclass ImagicStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for imagic image editing.\n    See paper here: https://huggingface.co/papers/2210.09276\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    def train(\n        self,\n        prompt: Union[str, List[str]],\n        image: Union[torch.Tensor, PIL.Image.Image],\n        height: Optional[int] = 512,\n        width: Optional[int] = 512,\n        generator: torch.Generator | None = None,\n        embedding_learning_rate: float = 0.001,\n        diffusion_model_learning_rate: float = 2e-6,\n        text_embedding_optimization_steps: int = 500,\n        model_fine_tuning_optimization_steps: int = 1000,\n        **kwargs,\n    ):\n        r\"\"\"\n        Optimizes the text embedding and then fine-tunes the UNet to reconstruct the input `image`.\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            height (`int`, *optional*, 
defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2 of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        accelerator = Accelerator(\n            gradient_accumulation_steps=1,\n            mixed_precision=\"fp16\",\n        )\n\n        if \"torch_device\" in kwargs:\n            device = kwargs.pop(\"torch_device\")\n            warnings.warn(\n                \"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0.\"\n                \" Consider using `pipe.to(torch_device)` instead.\"\n            )\n\n            if device is None:\n                device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n            self.to(device)\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        # Freeze vae and unet\n        self.vae.requires_grad_(False)\n        self.unet.requires_grad_(False)\n        self.text_encoder.requires_grad_(False)\n        self.unet.eval()\n        self.vae.eval()\n        self.text_encoder.eval()\n\n        if accelerator.is_main_process:\n            accelerator.init_trackers(\n                \"imagic\",\n         
       config={\n                    \"embedding_learning_rate\": embedding_learning_rate,\n                    \"text_embedding_optimization_steps\": text_embedding_optimization_steps,\n                },\n            )\n\n        # get text embeddings for prompt\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = torch.nn.Parameter(\n            self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True\n        )\n        text_embeddings = text_embeddings.detach()\n        text_embeddings.requires_grad_()\n        text_embeddings_orig = text_embeddings.clone()\n\n        # Initialize the optimizer\n        optimizer = torch.optim.Adam(\n            [text_embeddings],  # only optimize the embeddings\n            lr=embedding_learning_rate,\n        )\n\n        if isinstance(image, PIL.Image.Image):\n            image = preprocess(image)\n\n        latents_dtype = text_embeddings.dtype\n        image = image.to(device=self.device, dtype=latents_dtype)\n        init_latent_image_dist = self.vae.encode(image).latent_dist\n        image_latents = init_latent_image_dist.sample(generator=generator)\n        image_latents = 0.18215 * image_latents\n\n        progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)\n        progress_bar.set_description(\"Steps\")\n\n        global_step = 0\n\n        logger.info(\"First optimizing the text embedding to better reconstruct the init image\")\n        for _ in range(text_embedding_optimization_steps):\n            with accelerator.accumulate(text_embeddings):\n                # Sample noise that we'll add to the latents\n                noise = torch.randn(image_latents.shape).to(image_latents.device)\n                timesteps = torch.randint(1000, 
(1,), device=image_latents.device)\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)\n\n                # Predict the noise residual\n                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample\n\n                loss = F.mse_loss(noise_pred, noise, reduction=\"none\").mean([1, 2, 3]).mean()\n                accelerator.backward(loss)\n\n                optimizer.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n            logs = {\"loss\": loss.detach().item()}  # , \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n        accelerator.wait_for_everyone()\n\n        text_embeddings.requires_grad_(False)\n\n        # Now we fine tune the unet to better reconstruct the image\n        self.unet.requires_grad_(True)\n        self.unet.train()\n        optimizer = torch.optim.Adam(\n            self.unet.parameters(),  # only optimize unet\n            lr=diffusion_model_learning_rate,\n        )\n        progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)\n\n        logger.info(\"Next fine tuning the entire model to better reconstruct the init image\")\n        for _ in range(model_fine_tuning_optimization_steps):\n            with accelerator.accumulate(self.unet.parameters()):\n                # Sample noise that we'll add to the latents\n                noise = torch.randn(image_latents.shape).to(image_latents.device)\n                timesteps = torch.randint(1000, (1,), 
device=image_latents.device)\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)\n\n                # Predict the noise residual\n                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample\n\n                loss = F.mse_loss(noise_pred, noise, reduction=\"none\").mean([1, 2, 3]).mean()\n                accelerator.backward(loss)\n\n                optimizer.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n            logs = {\"loss\": loss.detach().item()}  # , \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n        accelerator.wait_for_everyone()\n        self.text_embeddings_orig = text_embeddings_orig\n        self.text_embeddings = text_embeddings\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        alpha: float = 1.2,\n        height: Optional[int] = 512,\n        width: Optional[int] = 512,\n        num_inference_steps: Optional[int] = 50,\n        generator: torch.Generator | None = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        guidance_scale: float = 7.5,\n        eta: float = 0.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n        Args:\n            alpha (`float`, *optional*, defaults to 1.2):\n                The interpolation factor between the original and optimized text embeddings. 
A value closer to 0\n                will produce an image that resembles the original input image.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2 of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n        if self.text_embeddings is None:\n            raise ValueError(\"Please run `pipe.train()` before trying to generate an image.\")\n        if self.text_embeddings_orig is None:\n            raise ValueError(\"Please run `pipe.train()` before trying to generate an image.\")\n\n        text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings\n\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens = [\"\"]\n            max_length = self.tokenizer.model_max_length\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if self.device.type == \"mps\":\n            # randn does not exist on mps\n            latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                self.device\n            )\n        else:\n            latents = torch.randn(latents_shape, generator=generator, device=self.device, 
dtype=latents_dtype)\n\n        # set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        for i, t in enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n        latents = 1 / 0.18215 * latents\n        image = 
self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(\n                self.device\n            )\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)\n            )\n        else:\n            has_nsfw_concept = None\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/img2img_inpainting.py",
    "content": "import inspect\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\nfrom diffusers.utils import deprecate, logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef prepare_mask_and_masked_image(image, mask):\n    image = np.array(image.convert(\"RGB\"))\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    mask = np.array(mask.convert(\"L\"))\n    mask = mask.astype(np.float32) / 255.0\n    mask = mask[None, None]\n    mask[mask < 0.5] = 0\n    mask[mask >= 0.5] = 1\n    mask = torch.from_numpy(mask)\n\n    masked_image = image * (mask < 0.5)\n\n    return mask, masked_image\n\n\ndef check_size(image, height, width):\n    if isinstance(image, PIL.Image.Image):\n        w, h = image.size\n    elif isinstance(image, torch.Tensor):\n        *_, h, w = image.shape\n\n    if h != height or w != width:\n        raise ValueError(f\"Image size should be {height}x{width}, but got {h}x{w}\")\n\n\ndef overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] 
= (0, 0)):\n    inner_image = inner_image.convert(\"RGBA\")\n    image = image.convert(\"RGB\")\n\n    image.paste(inner_image, paste_offset, inner_image)\n    image = image.convert(\"RGB\")\n\n    return image\n\n\nclass ImageToImageInpaintingPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. 
For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        image: Union[torch.Tensor, PIL.Image.Image],\n        inner_image: Union[torch.Tensor, PIL.Image.Image],\n        mask_image: Union[torch.Tensor, PIL.Image.Image],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            inner_image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be overlaid onto `image`. 
Non-transparent\n                regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with\n                the last channel representing the alpha channel, which will be used to blend `inner_image` with\n                `image`. If not provided, it will be forcibly cast to RGBA.\n            mask_image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        # `None` is not an `int`, so this single check also rejects a missing value\n        if not isinstance(callback_steps, int) or callback_steps <= 0:\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        # check if input sizes are correct\n        check_size(image, height, width)\n        check_size(inner_image, height, width)\n        check_size(mask_image, height, width)\n\n        # get prompt text embeddings\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n\n        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:\n            removed_text = 
self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"]\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        num_channels_latents = self.vae.config.latent_channels\n        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not exist on mps\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", 
dtype=latents_dtype).to(\n                    self.device\n                )\n            else:\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents.shape != latents_shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents = latents.to(self.device)\n\n        # overlay the inner image\n        image = overlay_inner_image(image, inner_image)\n\n        # prepare mask and masked_image\n        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)\n        mask = mask.to(device=self.device, dtype=text_embeddings.dtype)\n        masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)\n\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))\n\n        # encode the mask image into latents space so we can concatenate it to the latents\n        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)\n        masked_image_latents = 0.18215 * masked_image_latents\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)\n        masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n        masked_image_latents = (\n            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n        )\n\n        num_channels_mask = mask.shape[1]\n        num_channels_masked_image = masked_image_latents.shape[1]\n\n        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:\n   
         raise ValueError(\n                f\"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects\"\n                f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                \" `pipeline.unet` or your `mask_image` or `image` input.\"\n            )\n\n        # set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        for i, t in enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n            # concat latents, mask, masked_image_latents in the channel dimension\n            latent_model_input = torch.cat([latent_model_input, mask, 
masked_image_latents], dim=1)\n\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n            # call the callback, if provided\n            if callback is not None and i % callback_steps == 0:\n                step_idx = i // getattr(self.scheduler, \"order\", 1)\n                callback(step_idx, t, latents)\n\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(\n                self.device\n            )\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)\n            )\n        else:\n            has_nsfw_concept = None\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/instaflow_one_step.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport inspect\r\nfrom typing import Any, Callable, Dict, List, Optional, Union\r\n\r\nimport torch\r\nfrom packaging import version\r\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\r\n\r\nfrom diffusers.configuration_utils import FrozenDict\r\nfrom diffusers.image_processor import VaeImageProcessor\r\nfrom diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\r\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\r\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\r\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\r\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\r\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\r\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\r\nfrom diffusers.utils import (\r\n    deprecate,\r\n    logging,\r\n)\r\nfrom diffusers.utils.torch_utils import randn_tensor\r\n\r\n\r\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\r\n\r\n\r\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\r\n    \"\"\"\r\n    Rescale `noise_cfg` according to `guidance_rescale`. 
Based on findings of [Common Diffusion Noise Schedules and\r\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\r\n    \"\"\"\r\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\r\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\r\n    # rescale the results from guidance (fixes overexposure)\r\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\r\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\r\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\r\n    return noise_cfg\r\n\r\n\r\nclass InstaFlowPipeline(\r\n    DiffusionPipeline,\r\n    StableDiffusionMixin,\r\n    TextualInversionLoaderMixin,\r\n    StableDiffusionLoraLoaderMixin,\r\n    FromSingleFileMixin,\r\n):\r\n    r\"\"\"\r\n    Pipeline for text-to-image generation using Rectified Flow and Euler discretization.\r\n    This customized pipeline is based on StableDiffusionPipeline from the official Diffusers library (0.21.4)\r\n\r\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\r\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\r\n\r\n    The pipeline also inherits the following loading methods:\r\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\r\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\r\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\r\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\r\n\r\n    Args:\r\n        vae ([`AutoencoderKL`]):\r\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\r\n        text_encoder ([`~transformers.CLIPTextModel`]):\r\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\r\n        tokenizer ([`~transformers.CLIPTokenizer`]):\r\n            A `CLIPTokenizer` to tokenize text.\r\n        unet ([`UNet2DConditionModel`]):\r\n            A `UNet2DConditionModel` to denoise the encoded image latents.\r\n        scheduler ([`SchedulerMixin`]):\r\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\r\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\r\n        safety_checker ([`StableDiffusionSafetyChecker`]):\r\n            Classification module that estimates whether generated images could be considered offensive or harmful.\r\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\r\n            about a model's potential harms.\r\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\r\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\r\n    \"\"\"\r\n\r\n    model_cpu_offload_seq = \"text_encoder->unet->vae\"\r\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\r\n    _exclude_from_cpu_offload = [\"safety_checker\"]\r\n\r\n    def __init__(\r\n        self,\r\n        vae: AutoencoderKL,\r\n        text_encoder: CLIPTextModel,\r\n        tokenizer: CLIPTokenizer,\r\n        unet: UNet2DConditionModel,\r\n        scheduler: KarrasDiffusionSchedulers,\r\n        safety_checker: StableDiffusionSafetyChecker,\r\n        feature_extractor: CLIPImageProcessor,\r\n        requires_safety_checker: bool = True,\r\n    ):\r\n        super().__init__()\r\n\r\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\r\n            deprecation_message = (\r\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\r\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\r\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\r\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\r\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\r\n                \" file\"\r\n            )\r\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\r\n            new_config = dict(scheduler.config)\r\n            new_config[\"steps_offset\"] = 1\r\n            scheduler._internal_dict = FrozenDict(new_config)\r\n\r\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\r\n            deprecation_message = (\r\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\r\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\r\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\r\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\r\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\r\n            )\r\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\r\n            new_config = dict(scheduler.config)\r\n            new_config[\"clip_sample\"] = False\r\n            scheduler._internal_dict = FrozenDict(new_config)\r\n\r\n        if safety_checker is None and requires_safety_checker:\r\n            logger.warning(\r\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\r\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\r\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\r\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\r\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\r\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\r\n            )\r\n\r\n        if safety_checker is not None and feature_extractor is None:\r\n            raise ValueError(\r\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\r\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\r\n            )\r\n\r\n        is_unet_version_less_0_9_0 = (\r\n            unet is not None\r\n            and hasattr(unet.config, \"_diffusers_version\")\r\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\r\n        )\r\n        is_unet_sample_size_less_64 = (\r\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\r\n        )\r\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\r\n            deprecation_message = (\r\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\r\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\r\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\r\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\r\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\r\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\r\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\r\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\r\n                \" the `unet/config.json` file\"\r\n            )\r\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\r\n            new_config = dict(unet.config)\r\n            new_config[\"sample_size\"] = 64\r\n            unet._internal_dict = FrozenDict(new_config)\r\n\r\n        self.register_modules(\r\n            vae=vae,\r\n            text_encoder=text_encoder,\r\n            tokenizer=tokenizer,\r\n            unet=unet,\r\n            scheduler=scheduler,\r\n            safety_checker=safety_checker,\r\n            feature_extractor=feature_extractor,\r\n        )\r\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\r\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\r\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\r\n\r\n    def _encode_prompt(\r\n        self,\r\n        prompt,\r\n        device,\r\n        num_images_per_prompt,\r\n        do_classifier_free_guidance,\r\n        negative_prompt=None,\r\n        prompt_embeds: Optional[torch.Tensor] = None,\r\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\r\n        lora_scale: Optional[float] = None,\r\n    ):\r\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. 
Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\r\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\r\n\r\n        prompt_embeds_tuple = self.encode_prompt(\r\n            prompt=prompt,\r\n            device=device,\r\n            num_images_per_prompt=num_images_per_prompt,\r\n            do_classifier_free_guidance=do_classifier_free_guidance,\r\n            negative_prompt=negative_prompt,\r\n            prompt_embeds=prompt_embeds,\r\n            negative_prompt_embeds=negative_prompt_embeds,\r\n            lora_scale=lora_scale,\r\n        )\r\n\r\n        # concatenate for backwards comp\r\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\r\n\r\n        return prompt_embeds\r\n\r\n    def encode_prompt(\r\n        self,\r\n        prompt,\r\n        device,\r\n        num_images_per_prompt,\r\n        do_classifier_free_guidance,\r\n        negative_prompt=None,\r\n        prompt_embeds: Optional[torch.Tensor] = None,\r\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\r\n        lora_scale: Optional[float] = None,\r\n    ):\r\n        r\"\"\"\r\n        Encodes the prompt into text encoder hidden states.\r\n\r\n        Args:\r\n            prompt (`str` or `List[str]`, *optional*):\r\n                prompt to be encoded\r\n            device: (`torch.device`):\r\n                torch device\r\n            num_images_per_prompt (`int`):\r\n                number of images that should be generated per prompt\r\n            do_classifier_free_guidance (`bool`):\r\n                whether to use classifier free guidance or not\r\n            negative_prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\r\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\r\n                not greater than `1`).\r\n            prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\r\n                provided, text embeddings will be generated from `prompt` input argument.\r\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\r\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\r\n                argument.\r\n            lora_scale (`float`, *optional*):\r\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\r\n        \"\"\"\r\n        # set lora scale so that monkey patched LoRA\r\n        # function of text encoder can correctly access it\r\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\r\n            self._lora_scale = lora_scale\r\n\r\n            # dynamically adjust the LoRA scale\r\n            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\r\n\r\n        if prompt is not None and isinstance(prompt, str):\r\n            batch_size = 1\r\n        elif prompt is not None and isinstance(prompt, list):\r\n            batch_size = len(prompt)\r\n        else:\r\n            batch_size = prompt_embeds.shape[0]\r\n\r\n        if prompt_embeds is None:\r\n            # textual inversion: process multi-vector tokens if necessary\r\n            if isinstance(self, TextualInversionLoaderMixin):\r\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\r\n\r\n            text_inputs = self.tokenizer(\r\n                prompt,\r\n                padding=\"max_length\",\r\n                max_length=self.tokenizer.model_max_length,\r\n    
            truncation=True,\r\n                return_tensors=\"pt\",\r\n            )\r\n            text_input_ids = text_inputs.input_ids\r\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\r\n\r\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\r\n                text_input_ids, untruncated_ids\r\n            ):\r\n                removed_text = self.tokenizer.batch_decode(\r\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\r\n                )\r\n                logger.warning(\r\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\r\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\r\n                )\r\n\r\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\r\n                attention_mask = text_inputs.attention_mask.to(device)\r\n            else:\r\n                attention_mask = None\r\n\r\n            prompt_embeds = self.text_encoder(\r\n                text_input_ids.to(device),\r\n                attention_mask=attention_mask,\r\n            )\r\n            prompt_embeds = prompt_embeds[0]\r\n\r\n        if self.text_encoder is not None:\r\n            prompt_embeds_dtype = self.text_encoder.dtype\r\n        elif self.unet is not None:\r\n            prompt_embeds_dtype = self.unet.dtype\r\n        else:\r\n            prompt_embeds_dtype = prompt_embeds.dtype\r\n\r\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\r\n\r\n        bs_embed, seq_len, _ = prompt_embeds.shape\r\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\r\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n        prompt_embeds = prompt_embeds.view(bs_embed * 
num_images_per_prompt, seq_len, -1)\r\n\r\n        # get unconditional embeddings for classifier free guidance\r\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\r\n            uncond_tokens: List[str]\r\n            if negative_prompt is None:\r\n                uncond_tokens = [\"\"] * batch_size\r\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\r\n                raise TypeError(\r\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\r\n                    f\" {type(prompt)}.\"\r\n                )\r\n            elif isinstance(negative_prompt, str):\r\n                uncond_tokens = [negative_prompt]\r\n            elif batch_size != len(negative_prompt):\r\n                raise ValueError(\r\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\r\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\r\n                    \" the batch size of `prompt`.\"\r\n                )\r\n            else:\r\n                uncond_tokens = negative_prompt\r\n\r\n            # textual inversion: procecss multi-vector tokens if necessary\r\n            if isinstance(self, TextualInversionLoaderMixin):\r\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\r\n\r\n            max_length = prompt_embeds.shape[1]\r\n            uncond_input = self.tokenizer(\r\n                uncond_tokens,\r\n                padding=\"max_length\",\r\n                max_length=max_length,\r\n                truncation=True,\r\n                return_tensors=\"pt\",\r\n            )\r\n\r\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\r\n                attention_mask = uncond_input.attention_mask.to(device)\r\n            else:\r\n                attention_mask = None\r\n\r\n            negative_prompt_embeds = self.text_encoder(\r\n                uncond_input.input_ids.to(device),\r\n                attention_mask=attention_mask,\r\n            )\r\n            negative_prompt_embeds = negative_prompt_embeds[0]\r\n\r\n        if do_classifier_free_guidance:\r\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\r\n            seq_len = negative_prompt_embeds.shape[1]\r\n\r\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\r\n\r\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\r\n\r\n        return prompt_embeds, negative_prompt_embeds\r\n\r\n    def run_safety_checker(self, image, device, dtype):\r\n        if self.safety_checker is None:\r\n         
   has_nsfw_concept = None\r\n        else:\r\n            if torch.is_tensor(image):\r\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\r\n            else:\r\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\r\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\r\n            image, has_nsfw_concept = self.safety_checker(\r\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\r\n            )\r\n        return image, has_nsfw_concept\r\n\r\n    def decode_latents(self, latents):\r\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead\"\r\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\r\n\r\n        latents = 1 / self.vae.config.scaling_factor * latents\r\n        image = self.vae.decode(latents, return_dict=False)[0]\r\n        image = (image / 2 + 0.5).clamp(0, 1)\r\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\r\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\r\n        return image\r\n\r\n    def merge_dW_to_unet(pipe, dW_dict, alpha=1.0):\r\n        _tmp_sd = pipe.unet.state_dict()\r\n        for key in dW_dict.keys():\r\n            _tmp_sd[key] += dW_dict[key] * alpha\r\n        pipe.unet.load_state_dict(_tmp_sd, strict=False)\r\n        return pipe\r\n\r\n    def prepare_extra_step_kwargs(self, generator, eta):\r\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\r\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\r\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\r\n        # and should be between [0, 
1]\r\n\r\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\r\n        extra_step_kwargs = {}\r\n        if accepts_eta:\r\n            extra_step_kwargs[\"eta\"] = eta\r\n\r\n        # check if the scheduler accepts generator\r\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\r\n        if accepts_generator:\r\n            extra_step_kwargs[\"generator\"] = generator\r\n        return extra_step_kwargs\r\n\r\n    def check_inputs(\r\n        self,\r\n        prompt,\r\n        height,\r\n        width,\r\n        callback_steps,\r\n        negative_prompt=None,\r\n        prompt_embeds=None,\r\n        negative_prompt_embeds=None,\r\n    ):\r\n        if height % 8 != 0 or width % 8 != 0:\r\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\r\n\r\n        if (callback_steps is None) or (\r\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\r\n        ):\r\n            raise ValueError(\r\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\r\n                f\" {type(callback_steps)}.\"\r\n            )\r\n\r\n        if prompt is not None and prompt_embeds is not None:\r\n            raise ValueError(\r\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\r\n                \" only forward one of the two.\"\r\n            )\r\n        elif prompt is None and prompt_embeds is None:\r\n            raise ValueError(\r\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\r\n            )\r\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\r\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\r\n\r\n        if negative_prompt is not None and negative_prompt_embeds is not None:\r\n            raise ValueError(\r\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\r\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\r\n            )\r\n\r\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\r\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\r\n                raise ValueError(\r\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\r\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\r\n                    f\" {negative_prompt_embeds.shape}.\"\r\n                )\r\n\r\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\r\n        shape = (\r\n            batch_size,\r\n            num_channels_latents,\r\n            int(height) // self.vae_scale_factor,\r\n            int(width) // self.vae_scale_factor,\r\n        )\r\n        if isinstance(generator, list) and len(generator) != batch_size:\r\n            raise ValueError(\r\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\r\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\r\n            )\r\n\r\n        if latents is None:\r\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\r\n        else:\r\n            latents = latents.to(device)\r\n\r\n        # scale the initial noise by the standard deviation required by the scheduler\r\n        latents = latents * self.scheduler.init_noise_sigma\r\n        return latents\r\n\r\n    @torch.no_grad()\r\n    def __call__(\r\n        self,\r\n        prompt: Union[str, List[str]] = None,\r\n        height: Optional[int] = None,\r\n        width: Optional[int] = None,\r\n        num_inference_steps: int = 50,\r\n        guidance_scale: float = 7.5,\r\n        negative_prompt: Optional[Union[str, List[str]]] = None,\r\n        num_images_per_prompt: Optional[int] = 1,\r\n        eta: float = 0.0,\r\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\r\n        latents: Optional[torch.Tensor] = None,\r\n        prompt_embeds: Optional[torch.Tensor] = None,\r\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\r\n        output_type: str | None = \"pil\",\r\n        return_dict: bool = True,\r\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\r\n        callback_steps: int = 1,\r\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\r\n        guidance_rescale: float = 0.0,\r\n    ):\r\n        r\"\"\"\r\n        The call function to the pipeline for generation.\r\n\r\n        Args:\r\n            prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`.\r\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\r\n                The height in pixels of the generated image.\r\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\r\n                The width in pixels of the generated image.\r\n            num_inference_steps (`int`, *optional*, defaults to 50):\r\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\r\n                expense of slower inference.\r\n            guidance_scale (`float`, *optional*, defaults to 7.5):\r\n                A higher guidance scale value encourages the model to generate images closely linked to the text\r\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\r\n            negative_prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\r\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\r\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\r\n                The number of images to generate per prompt.\r\n            eta (`float`, *optional*, defaults to 0.0):\r\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. 
Only applies\r\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\r\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\r\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\r\n                generation deterministic.\r\n            latents (`torch.Tensor`, *optional*):\r\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\r\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\r\n                tensor is generated by sampling using the supplied random `generator`.\r\n            prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\r\n                provided, text embeddings are generated from the `prompt` input argument.\r\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\r\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\r\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\r\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\r\n            return_dict (`bool`, *optional*, defaults to `True`):\r\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\r\n                plain tuple.\r\n            callback (`Callable`, *optional*):\r\n                A function that is called every `callback_steps` steps during inference. 
The function is called with the\r\n                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\r\n            callback_steps (`int`, *optional*, defaults to 1):\r\n                The frequency at which the `callback` function is called. If not specified, the callback is called at\r\n                every step.\r\n            cross_attention_kwargs (`dict`, *optional*):\r\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\r\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\r\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\r\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\r\n                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when\r\n                using zero terminal SNR.\r\n\r\n        Examples:\r\n\r\n        Returns:\r\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\r\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\r\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\r\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\r\n                \"not-safe-for-work\" (nsfw) content.\r\n        \"\"\"\r\n        # 0. Default height and width to unet\r\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\r\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\r\n\r\n        # 1. Check inputs. 
Raise error if not correct\r\n        self.check_inputs(\r\n            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds\r\n        )\r\n\r\n        # 2. Define call parameters\r\n        if prompt is not None and isinstance(prompt, str):\r\n            batch_size = 1\r\n        elif prompt is not None and isinstance(prompt, list):\r\n            batch_size = len(prompt)\r\n        else:\r\n            batch_size = prompt_embeds.shape[0]\r\n\r\n        device = self._execution_device\r\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\r\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\r\n        # corresponds to doing no classifier free guidance.\r\n        do_classifier_free_guidance = guidance_scale > 1.0\r\n\r\n        # 3. Encode input prompt\r\n        text_encoder_lora_scale = (\r\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\r\n        )\r\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\r\n            prompt,\r\n            device,\r\n            num_images_per_prompt,\r\n            do_classifier_free_guidance,\r\n            negative_prompt,\r\n            prompt_embeds=prompt_embeds,\r\n            negative_prompt_embeds=negative_prompt_embeds,\r\n            lora_scale=text_encoder_lora_scale,\r\n        )\r\n        # For classifier free guidance, we need to do two forward passes.\r\n        # Here we concatenate the unconditional and text embeddings into a single batch\r\n        # to avoid doing two forward passes\r\n        if do_classifier_free_guidance:\r\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\r\n\r\n        # 4. Prepare timesteps\r\n        timesteps = [(1.0 - i / num_inference_steps) * 1000.0 for i in range(num_inference_steps)]\r\n\r\n        # 5. 
Prepare latent variables\r\n        num_channels_latents = self.unet.config.in_channels\r\n        latents = self.prepare_latents(\r\n            batch_size * num_images_per_prompt,\r\n            num_channels_latents,\r\n            height,\r\n            width,\r\n            prompt_embeds.dtype,\r\n            device,\r\n            generator,\r\n            latents,\r\n        )\r\n\r\n        # 6. Compute the step size of the Euler discretization\r\n        dt = 1.0 / num_inference_steps\r\n\r\n        # 7. Denoising loop of Euler discretization from t = 0 to t = 1\r\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\r\n            for i, t in enumerate(timesteps):\r\n                # expand the latents if we are doing classifier free guidance\r\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\r\n\r\n                vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * t\r\n\r\n                v_pred = self.unet(latent_model_input, vec_t, encoder_hidden_states=prompt_embeds).sample\r\n\r\n                # perform guidance\r\n                if do_classifier_free_guidance:\r\n                    v_pred_neg, v_pred_text = v_pred.chunk(2)\r\n                    v_pred = v_pred_neg + guidance_scale * (v_pred_text - v_pred_neg)\r\n\r\n                latents = latents + dt * v_pred\r\n\r\n                # call the callback, if provided\r\n                if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):\r\n                    progress_bar.update()\r\n                    if callback is not None and i % callback_steps == 0:\r\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\r\n                        callback(step_idx, t, latents)\r\n\r\n        if not output_type == \"latent\":\r\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, 
return_dict=False)[0]\r\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\r\n        else:\r\n            image = latents\r\n            has_nsfw_concept = None\r\n\r\n        if has_nsfw_concept is None:\r\n            do_denormalize = [True] * image.shape[0]\r\n        else:\r\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\r\n\r\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\r\n\r\n        # Offload all models\r\n        self.maybe_free_model_hooks()\r\n\r\n        if not return_dict:\r\n            return (image, has_nsfw_concept)\r\n\r\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\r\n"
  },
  {
    "path": "examples/community/interpolate_stable_diffusion.py",
    "content": "import inspect\nimport time\nfrom pathlib import Path\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\nfrom diffusers.utils import deprecate, logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef slerp(t, v0, v1, DOT_THRESHOLD=0.9995):\n    \"\"\"helper function to spherically interpolate two arrays v1 v2\"\"\"\n\n    if not isinstance(v0, np.ndarray):\n        inputs_are_torch = True\n        input_device = v0.device\n        v0 = v0.cpu().numpy()\n        v1 = v1.cpu().numpy()\n\n    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))\n    if np.abs(dot) > DOT_THRESHOLD:\n        v2 = (1 - t) * v0 + t * v1\n    else:\n        theta_0 = np.arccos(dot)\n        sin_theta_0 = np.sin(theta_0)\n        theta_t = theta_0 * t\n        sin_theta_t = np.sin(theta_t)\n        s0 = np.sin(theta_0 - theta_t) / sin_theta_0\n        s1 = sin_theta_t / sin_theta_0\n        v2 = s0 * v0 + s1 * v1\n\n    if inputs_are_torch:\n        v2 = torch.from_numpy(v2).to(input_device)\n\n    return v2\n\n\nclass StableDiffusionWalkPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might led to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. 
For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Optional[Union[str, List[str]]] = None,\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        text_embeddings: Optional[torch.Tensor] = None,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*, defaults to `None`):\n                The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            text_embeddings (`torch.Tensor`, *optional*, defaults to `None`):\n                Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of\n                `prompt` to avoid re-computing the embeddings. 
If not provided, the embeddings will be generated from\n                the supplied `prompt`.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if text_embeddings is None:\n            if isinstance(prompt, str):\n                batch_size = 1\n            elif isinstance(prompt, list):\n                batch_size = len(prompt)\n            else:\n                raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n            # get prompt text embeddings\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n\n            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:\n                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])\n     
           print(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n            text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]\n        else:\n            batch_size = text_embeddings.shape[0]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = self.tokenizer.model_max_length\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not work reproducibly on mps\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                    self.device\n     
           )\n            else:\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents.shape != latents_shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents = latents.to(self.device)\n\n        # set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        for i, t in enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n      
          noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n            # call the callback, if provided\n            if callback is not None and i % callback_steps == 0:\n                step_idx = i // getattr(self.scheduler, \"order\", 1)\n                callback(step_idx, t, latents)\n\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(\n                self.device\n            )\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)\n            )\n        else:\n            has_nsfw_concept = None\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n    def embed_text(self, text):\n        \"\"\"takes in text and turns it into text embeddings\"\"\"\n        text_input = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        with torch.no_grad():\n            embed = self.text_encoder(text_input.input_ids.to(self.device))[0]\n        return embed\n\n    def 
get_noise(self, seed, dtype=torch.float32, height=512, width=512):\n        \"\"\"Takes in random seed and returns corresponding noise vector\"\"\"\n        return torch.randn(\n            (1, self.unet.config.in_channels, height // 8, width // 8),\n            generator=torch.Generator(device=self.device).manual_seed(seed),\n            device=self.device,\n            dtype=dtype,\n        )\n\n    def walk(\n        self,\n        prompts: List[str],\n        seeds: List[int],\n        num_interpolation_steps: Optional[int] = 6,\n        output_dir: str | None = \"./dreams\",\n        name: str | None = None,\n        batch_size: Optional[int] = 1,\n        height: Optional[int] = 512,\n        width: Optional[int] = 512,\n        guidance_scale: Optional[float] = 7.5,\n        num_inference_steps: Optional[int] = 50,\n        eta: Optional[float] = 0.0,\n    ) -> List[str]:\n        \"\"\"\n        Walks through a series of prompts and seeds, interpolating between them and saving the results to disk.\n\n        Args:\n            prompts (`List[str]`):\n                List of prompts to generate images for.\n            seeds (`List[int]`):\n                List of seeds corresponding to provided prompts. Must be the same length as prompts.\n            num_interpolation_steps (`int`, *optional*, defaults to 6):\n                Number of interpolation steps to take between prompts.\n            output_dir (`str`, *optional*, defaults to `./dreams`):\n                Directory to save the generated images to.\n            name (`str`, *optional*, defaults to `None`):\n                Subdirectory of `output_dir` to save the generated images to. 
If `None`, the name will\n                be the current time.\n            batch_size (`int`, *optional*, defaults to 1):\n                Number of images to generate at once.\n            height (`int`, *optional*, defaults to 512):\n                Height of the generated images.\n            width (`int`, *optional*, defaults to 512):\n                Width of the generated images.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n\n        Returns:\n            `List[str]`: List of paths to the generated images.\n        \"\"\"\n        if not len(prompts) == len(seeds):\n            raise ValueError(\n                f\"Number of prompts and seeds must be equal. Got {len(prompts)} prompts and {len(seeds)} seeds.\"\n            )\n\n        name = name or time.strftime(\"%Y%m%d-%H%M%S\")\n        save_path = Path(output_dir) / name\n        save_path.mkdir(exist_ok=True, parents=True)\n\n        frame_idx = 0\n        frame_filepaths = []\n        for prompt_a, prompt_b, seed_a, seed_b in zip(prompts, prompts[1:], seeds, seeds[1:]):\n            # Embed Text\n            embed_a = self.embed_text(prompt_a)\n            embed_b = self.embed_text(prompt_b)\n\n            # Get Noise\n            noise_dtype = embed_a.dtype\n            noise_a = self.get_noise(seed_a, noise_dtype, height, width)\n            noise_b = self.get_noise(seed_b, noise_dtype, height, width)\n\n            noise_batch, embeds_batch = None, None\n            T = np.linspace(0.0, 1.0, num_interpolation_steps)\n            for i, t in enumerate(T):\n                noise = slerp(float(t), noise_a, noise_b)\n                embed = torch.lerp(embed_a, embed_b, t)\n\n                noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise], dim=0)\n                embeds_batch = embed if embeds_batch is None else torch.cat([embeds_batch, embed], dim=0)\n\n                batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]\n                if batch_is_ready:\n                    outputs = self(\n                        latents=noise_batch,\n                        text_embeddings=embeds_batch,\n                        height=height,\n                        width=width,\n                        guidance_scale=guidance_scale,\n                        eta=eta,\n                        
num_inference_steps=num_inference_steps,\n                    )\n                    noise_batch, embeds_batch = None, None\n\n                    for image in outputs[\"images\"]:\n                        frame_filepath = str(save_path / f\"frame_{frame_idx:06d}.png\")\n                        image.save(frame_filepath)\n                        frame_filepaths.append(frame_filepath)\n                        frame_idx += 1\n        return frame_filepaths\n"
  },
  {
    "path": "examples/community/ip_adapter_face_id.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom packaging import version\nfrom safetensors import safe_open\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.attention_processor import (\n    AttnProcessor,\n    AttnProcessor2_0,\n    IPAdapterAttnProcessor,\n    IPAdapterAttnProcessor2_0,\n)\nfrom diffusers.models.embeddings import MultiIPAdapterImageProjection\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    
_get_model_file,\n    deprecate,\n    logging,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass IPAdapterFullImageProjection(nn.Module):\n    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):\n        super().__init__()\n        from diffusers.models.attention import FeedForward\n\n        self.num_tokens = num_tokens\n        self.cross_attention_dim = cross_attention_dim\n        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn=\"gelu\")\n        self.norm = nn.LayerNorm(cross_attention_dim)\n\n    def forward(self, image_embeds: torch.Tensor):\n        x = self.ff(image_embeds)\n        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)\n        return self.norm(x)\n\n\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. 
If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass IPAdapterFaceIDStableDiffusionPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    StableDiffusionLoraLoaderMixin,\n    IPAdapterMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_name, **kwargs):\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        force_download = kwargs.pop(\"force_download\", False)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", None)\n        token = kwargs.pop(\"token\", None)\n        revision = kwargs.pop(\"revision\", None)\n        subfolder = kwargs.pop(\"subfolder\", None)\n\n        user_agent = {\"file_type\": \"attn_procs_weights\", \"framework\": \"pytorch\"}\n        model_file = _get_model_file(\n            pretrained_model_name_or_path_or_dict,\n            weights_name=weight_name,\n            cache_dir=cache_dir,\n   
         force_download=force_download,\n            proxies=proxies,\n            local_files_only=local_files_only,\n            token=token,\n            revision=revision,\n            subfolder=subfolder,\n            user_agent=user_agent,\n        )\n        if weight_name.endswith(\".safetensors\"):\n            state_dict = {\"image_proj\": {}, \"ip_adapter\": {}}\n            with safe_open(model_file, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    if key.startswith(\"image_proj.\"):\n                        state_dict[\"image_proj\"][key.replace(\"image_proj.\", \"\")] = f.get_tensor(key)\n                    elif key.startswith(\"ip_adapter.\"):\n                        state_dict[\"ip_adapter\"][key.replace(\"ip_adapter.\", \"\")] = f.get_tensor(key)\n        else:\n            state_dict = torch.load(model_file, map_location=\"cpu\")\n        self._load_ip_adapter_weights(state_dict)\n\n    def convert_ip_adapter_image_proj_to_diffusers(self, state_dict):\n        updated_state_dict = {}\n        clip_embeddings_dim_in = state_dict[\"proj.0.weight\"].shape[1]\n        clip_embeddings_dim_out = state_dict[\"proj.0.weight\"].shape[0]\n        multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in\n        norm_layer = \"norm.weight\"\n        cross_attention_dim = state_dict[norm_layer].shape[0]\n        num_tokens = state_dict[\"proj.2.weight\"].shape[0] // cross_attention_dim\n\n        image_projection = IPAdapterFullImageProjection(\n            cross_attention_dim=cross_attention_dim,\n            image_embed_dim=clip_embeddings_dim_in,\n            mult=multiplier,\n            num_tokens=num_tokens,\n        )\n\n        for key, value in state_dict.items():\n            diffusers_name = key.replace(\"proj.0\", \"ff.net.0.proj\")\n            diffusers_name = diffusers_name.replace(\"proj.2\", \"ff.net.2\")\n            updated_state_dict[diffusers_name] = value\n\n        
image_projection.load_state_dict(updated_state_dict)\n        return image_projection\n\n    def _load_ip_adapter_weights(self, state_dict):\n        num_image_text_embeds = 4\n\n        self.unet.encoder_hid_proj = None\n\n        # set ip-adapter cross-attention processors & load state_dict\n        attn_procs = {}\n        lora_dict = {}\n        key_id = 0\n        for name in self.unet.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else self.unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = self.unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = self.unet.config.block_out_channels[block_id]\n            if cross_attention_dim is None or \"motion_modules\" in name:\n                attn_processor_class = (\n                    AttnProcessor2_0 if hasattr(F, \"scaled_dot_product_attention\") else AttnProcessor\n                )\n                attn_procs[name] = attn_processor_class()\n\n                lora_dict.update(\n                    {f\"unet.{name}.to_k_lora.down.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_k_lora.down.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_q_lora.down.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_q_lora.down.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_v_lora.down.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_v_lora.down.weight\"]}\n                )\n                lora_dict.update(\n                    {\n                        
f\"unet.{name}.to_out_lora.down.weight\": state_dict[\"ip_adapter\"][\n                            f\"{key_id}.to_out_lora.down.weight\"\n                        ]\n                    }\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_k_lora.up.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_k_lora.up.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_q_lora.up.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_q_lora.up.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_v_lora.up.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_v_lora.up.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_out_lora.up.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_out_lora.up.weight\"]}\n                )\n                key_id += 1\n            else:\n                attn_processor_class = (\n                    IPAdapterAttnProcessor2_0 if hasattr(F, \"scaled_dot_product_attention\") else IPAdapterAttnProcessor\n                )\n                attn_procs[name] = attn_processor_class(\n                    hidden_size=hidden_size,\n                    cross_attention_dim=cross_attention_dim,\n                    scale=1.0,\n                    num_tokens=num_image_text_embeds,\n                ).to(dtype=self.dtype, device=self.device)\n\n                lora_dict.update(\n                    {f\"unet.{name}.to_k_lora.down.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_k_lora.down.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_q_lora.down.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_q_lora.down.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_v_lora.down.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_v_lora.down.weight\"]}\n          
      )\n                lora_dict.update(\n                    {\n                        f\"unet.{name}.to_out_lora.down.weight\": state_dict[\"ip_adapter\"][\n                            f\"{key_id}.to_out_lora.down.weight\"\n                        ]\n                    }\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_k_lora.up.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_k_lora.up.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_q_lora.up.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_q_lora.up.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_v_lora.up.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_v_lora.up.weight\"]}\n                )\n                lora_dict.update(\n                    {f\"unet.{name}.to_out_lora.up.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_out_lora.up.weight\"]}\n                )\n\n                value_dict = {}\n                value_dict.update({\"to_k_ip.0.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_k_ip.weight\"]})\n                value_dict.update({\"to_v_ip.0.weight\": state_dict[\"ip_adapter\"][f\"{key_id}.to_v_ip.weight\"]})\n                attn_procs[name].load_state_dict(value_dict)\n                key_id += 1\n\n        self.unet.set_attn_processor(attn_procs)\n\n        self.load_lora_weights(lora_dict, adapter_name=\"faceid\")\n        self.set_adapters([\"faceid\"], adapter_weights=[1.0])\n\n        # convert IP-Adapter Image Projection layers to diffusers\n        image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict[\"image_proj\"])\n        image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)]\n\n        self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)\n        self.unet.config.encoder_hid_dim_type = \"ip_image_proj\"\n\n    def 
set_ip_adapter_scale(self, scale):\n        unet = getattr(self, self.unet_name) if not hasattr(self, \"unet\") else self.unet\n        for attn_processor in unet.attn_processors.values():\n            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):\n                attn_processor.scale = [scale]\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: 
Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            
extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def 
clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        image_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, 
*optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. 
Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            image_embeds (`torch.Tensor`, *optional*):\n                Pre-generated image embeddings.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when\n                using zero terminal SNR.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n        # to deal with lora scaling and other possible forward hooks\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if image_embeds is not None:\n            image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to(\n                device=device, dtype=prompt_embeds.dtype\n    
        )\n            negative_image_embeds = torch.zeros_like(image_embeds)\n            if self.do_classifier_free_guidance:\n                image_embeds = torch.cat([negative_image_embeds, image_embeds])\n        image_embeds = [image_embeds]\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 6.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = {\"image_embeds\": image_embeds} if image_embeds is not None else {}\n\n        # 6.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, 
output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/kohya_hires_fix.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.configuration_utils import register_to_config\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models.autoencoders import AutoencoderKL\nfrom diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass UNet2DConditionModelHighResFix(UNet2DConditionModel):\n    r\"\"\"\n    A conditional 2D UNet model that applies Kohya fix proposed for high resolution image generation.\n\n    This model inherits from [`UNet2DConditionModel`]. 
Check the superclass documentation for learning about all the parameters.\n\n    Parameters:\n        high_res_fix (`List[Dict]`, *optional*, defaults to `[{'timestep': 600, 'scale_factor': 0.5, 'block_num': 1}]`):\n            Enables Kohya fix for high resolution generation. The activation maps are scaled based on the scale_factor up to the timestep at specified block_num.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(self, high_res_fix: List[Dict] = [{\"timestep\": 600, \"scale_factor\": 0.5, \"block_num\": 1}], **kwargs):\n        super().__init__(**kwargs)\n        if high_res_fix:\n            self.config.high_res_fix = sorted(high_res_fix, key=lambda x: x[\"timestep\"], reverse=True)\n\n    @classmethod\n    def _resize(cls, sample, target=None, scale_factor=1, mode=\"bicubic\"):\n        dtype = sample.dtype\n        if dtype == torch.bfloat16:\n            sample = sample.to(torch.float32)\n\n        if target is not None:\n            if sample.shape[-2:] != target.shape[-2:]:\n                sample = nn.functional.interpolate(sample, size=target.shape[-2:], mode=mode, align_corners=False)\n        elif scale_factor != 1:\n            sample = nn.functional.interpolate(sample, scale_factor=scale_factor, mode=mode, align_corners=False)\n\n        return sample.to(dtype)\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[UNet2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            class_labels (`torch.Tensor`, *optional*, defaults to `None`):\n                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.\n            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):\n                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed\n                through the `self.time_embedding` layer to obtain the timestep embeddings.\n            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):\n                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask\n                is kept, otherwise if `0` it is discarded. 
Mask will be converted into a bias, which adds large\n                negative values to the attention scores corresponding to \"discard\" tokens.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):\n                A tuple of tensors that if specified are added to the residuals of down unet blocks.\n            mid_block_additional_residual: (`torch.Tensor`, *optional*):\n                A tensor that if specified is added to the residual of the middle unet block.\n            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):\n                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n\n        Returns:\n            [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        for dim in sample.shape[-2:]:\n            if dim % default_overall_up_factor != 0:\n                # Forward upsample size to force interpolation output size.\n                forward_upsample_size = True\n                break\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. 
torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        t_emb = self.get_time_embed(sample=sample, timestep=timestep)\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)\n        if class_emb is not None:\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        aug_emb = self.get_aug_embed(\n            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n        )\n        if self.config.addition_embed_type == \"image_hint\":\n            aug_emb, hint = aug_emb\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        encoder_hidden_states = self.process_encoder_hidden_states(\n            
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n        )\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        # 2.5 GLIGEN position net\n        if cross_attention_kwargs is not None and cross_attention_kwargs.get(\"gligen\", None) is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            gligen_args = cross_attention_kwargs.pop(\"gligen\")\n            cross_attention_kwargs[\"gligen\"] = {\"objs\": self.position_net(**gligen_args)}\n\n        # 3. down\n        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated\n        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.\n        if cross_attention_kwargs is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            lora_scale = cross_attention_kwargs.pop(\"scale\", 1.0)\n        else:\n            lora_scale = 1.0\n\n        if USE_PEFT_BACKEND:\n            # weight the lora layers by setting `lora_scale` for each PEFT layer\n            scale_lora_layers(self, lora_scale)\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets\n        is_adapter = down_intrablock_additional_residuals is not None\n        # maintain backward compatibility for legacy usage, where\n        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg\n        #       but can only use one or the other\n        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:\n            deprecate(\n                \"T2I should not use down_block_additional_residuals\",\n                \"1.3.0\",\n                \"Passing intrablock residual connections with 
`down_block_additional_residuals` is deprecated \\\n                       and will be removed in diffusers 1.3.0.  `down_block_additional_residuals` should only be used \\\n                       for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. \",\n                standard_warn=False,\n            )\n            down_intrablock_additional_residuals = down_block_additional_residuals\n            is_adapter = True\n\n        down_block_res_samples = (sample,)\n        for down_i, downsample_block in enumerate(self.down_blocks):\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_intrablock_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    sample += down_intrablock_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n            # kohya high res fix\n            if self.config.high_res_fix:\n                for high_res_fix in self.config.high_res_fix:\n                    if timestep > high_res_fix[\"timestep\"] and down_i == high_res_fix[\"block_num\"]:\n         
               sample = self.__class__._resize(sample, scale_factor=high_res_fix[\"scale_factor\"])\n                        break\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            if hasattr(self.mid_block, \"has_cross_attention\") and self.mid_block.has_cross_attention:\n                sample = self.mid_block(\n                    sample,\n                    emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = self.mid_block(sample, emb)\n\n            # To support T2I-Adapter-XL\n            if (\n                is_adapter\n                and len(down_intrablock_additional_residuals) > 0\n                and sample.shape == down_intrablock_additional_residuals[0].shape\n            ):\n                sample += down_intrablock_additional_residuals.pop(0)\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # up scaling of kohya high res fix\n            if self.config.high_res_fix is not None:\n                if res_samples[0].shape[-2:] != sample.shape[-2:]:\n                    sample = self.__class__._resize(sample, target=res_samples[0])\n                    res_samples_up_sampled = (res_samples[0],)\n                    for res_sample in res_samples[1:]:\n                        res_samples_up_sampled += (self.__class__._resize(res_sample, target=res_samples[0]),)\n                    res_samples = res_samples_up_sampled\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    upsample_size=upsample_size,\n                )\n\n        # 6. 
post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if USE_PEFT_BACKEND:\n            # remove `lora_scale` from each PEFT layer\n            unscale_lora_layers(self, lora_scale)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNet2DConditionOutput(sample=sample)\n\n    @classmethod\n    def from_unet(cls, unet: UNet2DConditionModel, high_res_fix: list):\n        config = dict((unet.config))\n        config[\"high_res_fix\"] = high_res_fix\n        unet_high_res = cls(**config)\n        unet_high_res.load_state_dict(unet.state_dict())\n        unet_high_res.to(unet.dtype)\n        return unet_high_res\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import DiffusionPipeline\n\n        >>> pipe = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\",\n                                         custom_pipeline=\"kohya_hires_fix\",\n                                         torch_dtype=torch.float16,\n                                         high_res_fix=[{'timestep': 600,\n                                                        'scale_factor': 0.5,\n                                                        'block_num': 1}])\n        >>> pipe = pipe.to(\"cuda\")\n\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\n        >>> image = pipe(prompt, height=1000, width=1600).images[0]\n        ```\n\"\"\"\n\n\nclass StableDiffusionHighResFixPipeline(StableDiffusionPipeline):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion with Kohya fix for high resolution generation.\n\n    This model inherits from [`StableDiffusionPipeline`]. 
Check the superclass documentation for the generic methods.\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n        high_res_fix (`List[Dict]`, *optional*, defaults to `[{'timestep': 600, 'scale_factor': 0.5, 'block_num': 1}]`):\n            Enables the Kohya fix for high-resolution generation. For each entry, while the current timestep is\n            greater than `timestep`, the hidden states leaving down block `block_num` are resized by `scale_factor`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n        high_res_fix: List[Dict] = [{\"timestep\": 600, \"scale_factor\": 0.5, \"block_num\": 1}],\n    ):\n        super().__init__(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            
safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n            requires_safety_checker=requires_safety_checker,\n        )\n\n        unet = UNet2DConditionModelHighResFix.from_unet(unet=unet, high_res_fix=high_res_fix)\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n"
  },
  {
    "path": "examples/community/latent_consistency_img2img.py",
    "content": "# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion\n# and https://github.com/hojonathanho/diffusion\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging\nfrom diffusers.configuration_utils import register_to_config\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.utils import BaseOutput\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass LatentConsistencyModelImg2ImgPipeline(DiffusionPipeline):\n    _optional_components = [\"scheduler\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: \"LCMSchedulerWithTimestamp\",\n        
safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        scheduler = (\n            scheduler\n            if scheduler is not None\n            else LCMSchedulerWithTimestamp(\n                beta_start=0.00085, beta_end=0.0120, beta_schedule=\"scaled_linear\", prediction_type=\"epsilon\"\n            )\n        )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        prompt_embeds: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n        \"\"\"\n\n        if prompt is not None and isinstance(prompt, str):\n            pass\n        elif prompt is not None and isinstance(prompt, list):\n            len(prompt)\n        else:\n            prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = 
self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # Don't need to get uncond prompt embedding because of LCM Guided Distillation\n        return prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    def prepare_latents(\n        self,\n        image,\n        timestep,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        latents=None,\n        generator=None,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            
)\n\n        image = image.to(device=device, dtype=dtype)\n\n        # batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                init_latents = [\n                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = self.vae.encode(image).latent_dist.sample(generator)\n\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            # deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n        Args:\n        w: torch.Tensor: 1-D tensor of guidance scale values to embed\n        embedding_dim: int: dimension of the embeddings to generate\n        dtype: data type of the generated embeddings\n        Returns:\n        embedding vectors with shape `(len(w), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = 
torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        return timesteps, num_inference_steps - t_start\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        strength: float = 0.8,\n        height: Optional[int] = 768,\n        width: Optional[int] = 768,\n        guidance_scale: float = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        latents: Optional[torch.Tensor] = None,\n        num_inference_steps: int = 4,\n        lcm_origin_steps: int = 50,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # do_classifier_free_guidance = guidance_scale > 0.0  # In LCM Implementation:  cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)\n\n        # 3. 
Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            prompt_embeds=prompt_embeds,\n        )\n\n        # 3.5 Preprocess image\n        image = self.image_processor.preprocess(image)\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps)\n        # timesteps = self.scheduler.timesteps\n        # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)\n        timesteps = self.scheduler.timesteps\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        logger.debug(f\"timesteps: {timesteps}\")\n\n        # 5. Prepare latent variable\n        num_channels_latents = self.unet.config.in_channels\n        if latents is None:\n            latents = self.prepare_latents(\n                image,\n                latent_timestep,\n                batch_size * num_images_per_prompt,\n                num_channels_latents,\n                height,\n                width,\n                prompt_embeds.dtype,\n                device,\n                latents,\n            )\n        bs = batch_size * num_images_per_prompt\n\n        # 6. Get Guidance Scale Embedding\n        w = torch.tensor(guidance_scale).repeat(bs)\n        w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device=device, dtype=latents.dtype)\n\n        # 7. 
LCM MultiStep Sampling Loop:\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                ts = torch.full((bs,), t, device=device, dtype=torch.long)\n                latents = latents.to(prompt_embeds.dtype)\n\n                # model prediction (v-prediction, eps, x)\n                model_pred = self.unet(\n                    latents,\n                    ts,\n                    timestep_cond=w_embedding,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents, denoised = self.scheduler.step(model_pred, i, t, latents, return_dict=False)\n\n                # # call the callback, if provided\n                # if i == len(timesteps) - 1:\n                progress_bar.update()\n\n        denoised = denoised.to(prompt_embeds.dtype)\n        if not output_type == \"latent\":\n            image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = denoised\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n\n@dataclass\n# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM\nclass LCMSchedulerOutput(BaseOutput):\n   
 \"\"\"\n    Output class for the scheduler's `step` function output.\n    Args:\n        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):\n            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the\n            denoising loop.\n        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):\n            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.\n            `pred_original_sample` can be used to preview progress or for guidance.\n    \"\"\"\n\n    prev_sample: torch.Tensor\n    denoised: Optional[torch.Tensor] = None\n\n\n# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar\ndef betas_for_alpha_bar(\n    num_diffusion_timesteps,\n    max_beta=0.999,\n    alpha_transform_type=\"cosine\",\n):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of\n    (1-beta) over time from t = [0,1].\n    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up\n    to that part of the diffusion process.\n    Args:\n        num_diffusion_timesteps (`int`): the number of betas to produce.\n        max_beta (`float`): the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.\n                     Choose from `cosine` or `exp`\n    Returns:\n        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs\n    \"\"\"\n    if alpha_transform_type == \"cosine\":\n\n        def alpha_bar_fn(t):\n            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2\n\n    elif alpha_transform_type == \"exp\":\n\n        def alpha_bar_fn(t):\n            return math.exp(t * 
-12.0)\n\n    else:\n        raise ValueError(f\"Unsupported alpha_transform_type: {alpha_transform_type}\")\n\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))\n    return torch.tensor(betas, dtype=torch.float32)\n\n\ndef rescale_zero_terminal_snr(betas):\n    \"\"\"\n    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)\n    Args:\n        betas (`torch.Tensor`):\n            the betas that the scheduler is being initialized with.\n    Returns:\n        `torch.Tensor`: rescaled betas with zero terminal SNR\n    \"\"\"\n    # Convert betas to alphas_bar_sqrt\n    alphas = 1.0 - betas\n    alphas_cumprod = torch.cumprod(alphas, dim=0)\n    alphas_bar_sqrt = alphas_cumprod.sqrt()\n\n    # Store old values.\n    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()\n    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()\n\n    # Shift so the last timestep is zero.\n    alphas_bar_sqrt -= alphas_bar_sqrt_T\n\n    # Scale so the first timestep is back to the old value.\n    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)\n\n    # Convert alphas_bar_sqrt to betas\n    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt\n    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod\n    alphas = torch.cat([alphas_bar[0:1], alphas])\n    betas = 1 - alphas\n\n    return betas\n\n\nclass LCMSchedulerWithTimestamp(SchedulerMixin, ConfigMixin):\n    \"\"\"\n    This class modifies LCMScheduler to add a timestamp argument to set_timesteps\n\n\n    `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with\n    non-Markovian guidance.\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. 
Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        beta_start (`float`, defaults to 0.0001):\n            The starting `beta` value of inference.\n        beta_end (`float`, defaults to 0.02):\n            The final `beta` value.\n        beta_schedule (`str`, defaults to `\"linear\"`):\n            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from\n            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.\n        trained_betas (`np.ndarray`, *optional*):\n            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.\n        clip_sample (`bool`, defaults to `True`):\n            Clip the predicted sample for numerical stability.\n        clip_sample_range (`float`, defaults to 1.0):\n            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.\n        set_alpha_to_one (`bool`, defaults to `True`):\n            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step\n            there is no previous alpha. 
When this option is `True` the previous alpha product is fixed to `1`,\n            otherwise it uses the alpha value at step 0.\n        steps_offset (`int`, defaults to 0):\n            An offset added to the inference steps, as required by some model families.\n        prediction_type (`str`, defaults to `epsilon`, *optional*):\n            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),\n            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen\n            Video](https://imagen.research.google/video/paper.pdf) paper).\n        thresholding (`bool`, defaults to `False`):\n            Whether to use the \"dynamic thresholding\" method. This is unsuitable for latent-space diffusion models such\n            as Stable Diffusion.\n        dynamic_thresholding_ratio (`float`, defaults to 0.995):\n            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.\n        sample_max_value (`float`, defaults to 1.0):\n            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.\n        timestep_spacing (`str`, defaults to `\"leading\"`):\n            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and\n            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.\n        rescale_betas_zero_snr (`bool`, defaults to `False`):\n            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and\n            dark samples instead of limiting it to samples with medium brightness. 
Loosely related to\n            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).\n    \"\"\"\n\n    # _compatibles = [e.name for e in KarrasDiffusionSchedulers]\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        beta_start: float = 0.0001,\n        beta_end: float = 0.02,\n        beta_schedule: str = \"linear\",\n        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,\n        clip_sample: bool = True,\n        set_alpha_to_one: bool = True,\n        steps_offset: int = 0,\n        prediction_type: str = \"epsilon\",\n        thresholding: bool = False,\n        dynamic_thresholding_ratio: float = 0.995,\n        clip_sample_range: float = 1.0,\n        sample_max_value: float = 1.0,\n        timestep_spacing: str = \"leading\",\n        rescale_betas_zero_snr: bool = False,\n    ):\n        if trained_betas is not None:\n            self.betas = torch.tensor(trained_betas, dtype=torch.float32)\n        elif beta_schedule == \"linear\":\n            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)\n        elif beta_schedule == \"scaled_linear\":\n            # this schedule is very specific to the latent diffusion model.\n            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2\n        elif beta_schedule == \"squaredcos_cap_v2\":\n            # Glide cosine schedule\n            self.betas = betas_for_alpha_bar(num_train_timesteps)\n        else:\n            raise NotImplementedError(f\"{beta_schedule} is not implemented for {self.__class__}\")\n\n        # Rescale for zero SNR\n        if rescale_betas_zero_snr:\n            self.betas = rescale_zero_terminal_snr(self.betas)\n\n        self.alphas = 1.0 - self.betas\n        self.alphas_cumprod = 
torch.cumprod(self.alphas, dim=0)\n\n        # At every step in ddim, we are looking into the previous alphas_cumprod\n        # For the final step, there is no previous alphas_cumprod because we are already at 0\n        # `set_alpha_to_one` decides whether we set this parameter simply to one or\n        # whether we use the final alpha of the \"non-previous\" one.\n        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]\n\n        # standard deviation of the initial noise distribution\n        self.init_noise_sigma = 1.0\n\n        # setable values\n        self.num_inference_steps = None\n        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))\n\n    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:\n        \"\"\"\n        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the\n        current timestep.\n        Args:\n            sample (`torch.Tensor`):\n                The input sample.\n            timestep (`int`, *optional*):\n                The current timestep in the diffusion chain.\n        Returns:\n            `torch.Tensor`:\n                A scaled input sample.\n        \"\"\"\n        return sample\n\n    def _get_variance(self, timestep, prev_timestep):\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)\n\n        return variance\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample\n    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        \"Dynamic thresholding: At each 
sampling step we set s to a certain percentile absolute pixel value in xt0 (the\n        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by\n        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing\n        pixels from saturation at each step. We find that dynamic thresholding results in significantly better\n        photorealism as well as better image-text alignment, especially when using very large guidance weights.\"\n        https://huggingface.co/papers/2205.11487\n        \"\"\"\n        dtype = sample.dtype\n        batch_size, channels, height, width = sample.shape\n\n        if dtype not in (torch.float32, torch.float64):\n            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half\n\n        # Flatten sample for doing quantile calculation along each image\n        sample = sample.reshape(batch_size, channels * height * width)\n\n        abs_sample = sample.abs()  # \"a certain percentile absolute pixel value\"\n\n        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)\n        s = torch.clamp(\n            s, min=1, max=self.config.sample_max_value\n        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]\n\n        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0\n        sample = torch.clamp(sample, -s, s) / s  # \"we threshold xt0 to the range [-s, s] and then divide by s\"\n\n        sample = sample.reshape(batch_size, channels, height, width)\n        sample = sample.to(dtype)\n\n        return sample\n\n    def set_timesteps(\n        self, strength, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None\n    ):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n        Args:\n            num_inference_steps 
(`int`):\n                The number of diffusion steps used when generating samples with a pre-trained model.\n            strength (`float`):\n                Fraction of the `lcm_origin_steps` schedule to keep; the schedule is truncated to\n                `int(lcm_origin_steps * strength)` steps.\n            lcm_origin_steps (`int`):\n                The number of steps in the original LCM training schedule, spaced evenly over\n                `num_train_timesteps`.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps are moved.\n        \"\"\"\n\n        if num_inference_steps > self.config.num_train_timesteps:\n            raise ValueError(\n                f\"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:\"\n                f\" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle\"\n                f\" maximal {self.config.num_train_timesteps} timesteps.\"\n            )\n\n        self.num_inference_steps = num_inference_steps\n\n        # LCM Timesteps Setting:  # Linear Spacing\n        c = self.config.num_train_timesteps // lcm_origin_steps\n        lcm_origin_timesteps = (\n            np.asarray(list(range(1, int(lcm_origin_steps * strength) + 1))) * c - 1\n        )  # LCM Training  Steps Schedule\n        skipping_step = len(lcm_origin_timesteps) // num_inference_steps\n        timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]  # LCM Inference Steps Schedule\n\n        self.timesteps = torch.from_numpy(timesteps.copy()).to(device)\n\n    def get_scalings_for_boundary_condition_discrete(self, t):\n        self.sigma_data = 0.5  # Default: 0.5\n\n        # Dividing t by 0.1 makes c_skip almost a delta function at t=0.\n        c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)\n        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5\n        return c_skip, c_out\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timeindex: int,\n        timestep: int,\n        sample: torch.Tensor,\n        eta: float = 0.0,\n        use_clipped_model_output: bool = False,\n        generator=None,\n        variance_noise: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[LCMSchedulerOutput, Tuple]:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the 
SDE. This function propagates the diffusion\n        process from the learned model outputs (most often the predicted noise).\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model.\n            timeindex (`int`):\n                The index of the current timestep in `self.timesteps`.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            eta (`float`):\n                The weight of the noise added in a diffusion step.\n            use_clipped_model_output (`bool`, defaults to `False`):\n                If `True`, computes \"corrected\" `model_output` from the clipped predicted original sample. Necessary\n                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no\n                clipping has happened, \"corrected\" `model_output` would coincide with the one provided as input and\n                `use_clipped_model_output` has no effect.\n            generator (`torch.Generator`, *optional*):\n                A random number generator.\n            variance_noise (`torch.Tensor`):\n                Alternative to generating noise with `generator` by directly providing the noise for the variance\n                itself. 
Useful for methods such as [`CycleDiffusion`].\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.\n        Returns:\n            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a\n                tuple is returned where the first element is the sample tensor.\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n            )\n\n        # 1. get previous step value\n        prev_timeindex = timeindex + 1\n        if prev_timeindex < len(self.timesteps):\n            prev_timestep = self.timesteps[prev_timeindex]\n        else:\n            prev_timestep = timestep\n\n        # 2. compute alphas, betas\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        # 3. Get scalings for boundary conditions\n        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)\n\n        # 4. Different Parameterization:\n        parameterization = self.config.prediction_type\n\n        if parameterization == \"epsilon\":  # noise-prediction\n            pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()\n\n        elif parameterization == \"sample\":  # x-prediction\n            pred_x0 = model_output\n\n        elif parameterization == \"v_prediction\":  # v-prediction\n            pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output\n\n        # 4. 
Denoise model output using boundary conditions\n        denoised = c_out * pred_x0 + c_skip * sample\n\n        # 5. Sample z ~ N(0, I), For MultiStep Inference\n        # Noise is not used for one-step sampling.\n        if len(self.timesteps) > 1:\n            noise = torch.randn(model_output.shape).to(model_output.device)\n            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise\n        else:\n            prev_sample = denoised\n\n        if not return_dict:\n            return (prev_sample, denoised)\n\n        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.IntTensor,\n    ) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples\n        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)\n        timesteps = timesteps.to(original_samples.device)\n\n        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n        sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):\n            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):\n            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise\n        return noisy_samples\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity\n    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, 
timesteps: torch.IntTensor) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as sample\n        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)\n        timesteps = timesteps.to(sample.device)\n\n        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n        sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n        while len(sqrt_alpha_prod.shape) < len(sample.shape):\n            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):\n            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample\n        return velocity\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n"
  },
  {
    "path": "examples/community/latent_consistency_interpolate.py",
    "content": "import inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker\nfrom diffusers.schedulers import LCMScheduler\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> import numpy as np\n\n        >>> from diffusers import DiffusionPipeline\n\n        >>> pipe = DiffusionPipeline.from_pretrained(\"SimianLuo/LCM_Dreamshaper_v7\", custom_pipeline=\"latent_consistency_interpolate\")\n        >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.\n        >>> pipe.to(torch_device=\"cuda\", torch_dtype=torch.float32)\n\n        >>> prompts = [\"A cat\", \"A dog\", \"A horse\"]\n        >>> num_inference_steps = 4\n        >>> num_interpolation_steps = 24\n        >>> seed = 1337\n\n        >>> torch.manual_seed(seed)\n        >>> np.random.seed(seed)\n\n        >>> images = pipe(\n                prompt=prompts,\n                height=512,\n                width=512,\n                num_inference_steps=num_inference_steps,\n                
num_interpolation_steps=num_interpolation_steps,\n                guidance_scale=8.0,\n                embedding_interpolation_type=\"lerp\",\n                latent_interpolation_type=\"slerp\",\n                process_batch_size=4,  # Make it higher or lower based on your GPU memory\n                generator=torch.Generator().manual_seed(seed),\n            )\n\n        >>> # Save the images as a video\n        >>> from typing import List\n\n        >>> import imageio\n        >>> from PIL import Image\n\n        >>> def pil_to_video(images: List[Image.Image], filename: str, fps: int = 60) -> None:\n                frames = [np.array(image) for image in images]\n                with imageio.get_writer(filename, fps=fps) as video_writer:\n                    for frame in frames:\n                        video_writer.append_data(frame)\n\n        >>> pil_to_video(images, \"lcm_interpolate.mp4\", fps=24)\n        ```\n\"\"\"\n\n\ndef lerp(\n    v0: Union[torch.Tensor, np.ndarray],\n    v1: Union[torch.Tensor, np.ndarray],\n    t: Union[float, torch.Tensor, np.ndarray],\n) -> Union[torch.Tensor, np.ndarray]:\n    \"\"\"\n    Linearly interpolate between two vectors/tensors.\n\n    Args:\n        v0 (`torch.Tensor` or `np.ndarray`): First vector/tensor.\n        v1 (`torch.Tensor` or `np.ndarray`): Second vector/tensor.\n        t: (`float`, `torch.Tensor`, or `np.ndarray`):\n            Interpolation factor. If float, must be between 0 and 1. 
If np.ndarray or\n            torch.Tensor, must be one dimensional with values between 0 and 1.\n\n    Returns:\n        Union[torch.Tensor, np.ndarray]\n            Interpolated vector/tensor between v0 and v1.\n    \"\"\"\n    inputs_are_torch = False\n    t_is_float = False\n\n    if isinstance(v0, torch.Tensor):\n        inputs_are_torch = True\n        input_device = v0.device\n        v0 = v0.cpu().numpy()\n        v1 = v1.cpu().numpy()\n\n    if isinstance(t, torch.Tensor):\n        inputs_are_torch = True\n        input_device = t.device\n        t = t.cpu().numpy()\n    elif isinstance(t, float):\n        t_is_float = True\n        t = np.array([t])\n\n    t = t[..., None]\n    v0 = v0[None, ...]\n    v1 = v1[None, ...]\n    v2 = (1 - t) * v0 + t * v1\n\n    if t_is_float and v0.ndim > 1:\n        assert v2.shape[0] == 1\n        v2 = np.squeeze(v2, axis=0)\n    if inputs_are_torch:\n        v2 = torch.from_numpy(v2).to(input_device)\n\n    return v2\n\n\ndef slerp(\n    v0: Union[torch.Tensor, np.ndarray],\n    v1: Union[torch.Tensor, np.ndarray],\n    t: Union[float, torch.Tensor, np.ndarray],\n    DOT_THRESHOLD=0.9995,\n) -> Union[torch.Tensor, np.ndarray]:\n    \"\"\"\n    Spherical linear interpolation between two vectors/tensors.\n\n    Args:\n        v0 (`torch.Tensor` or `np.ndarray`): First vector/tensor.\n        v1 (`torch.Tensor` or `np.ndarray`): Second vector/tensor.\n        t: (`float`, `torch.Tensor`, or `np.ndarray`):\n            Interpolation factor. If float, must be between 0 and 1. 
If np.ndarray or\n            torch.Tensor, must be one dimensional with values between 0 and 1.\n        DOT_THRESHOLD (`float`, *optional*, default=0.9995):\n            Threshold for when to use linear interpolation instead of spherical interpolation.\n\n    Returns:\n        `torch.Tensor` or `np.ndarray`:\n            Interpolated vector/tensor between v0 and v1.\n    \"\"\"\n    inputs_are_torch = False\n    t_is_float = False\n\n    if isinstance(v0, torch.Tensor):\n        inputs_are_torch = True\n        input_device = v0.device\n        v0 = v0.cpu().numpy()\n        v1 = v1.cpu().numpy()\n\n    if isinstance(t, torch.Tensor):\n        inputs_are_torch = True\n        input_device = t.device\n        t = t.cpu().numpy()\n    elif isinstance(t, float):\n        t_is_float = True\n        t = np.array([t], dtype=v0.dtype)\n\n    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))\n    if np.abs(dot) > DOT_THRESHOLD:\n        # v0 and v1 are close to parallel\n        # Use linear interpolation instead\n        v2 = lerp(v0, v1, t)\n    else:\n        theta_0 = np.arccos(dot)\n        sin_theta_0 = np.sin(theta_0)\n        theta_t = theta_0 * t\n        sin_theta_t = np.sin(theta_t)\n        s0 = np.sin(theta_0 - theta_t) / sin_theta_0\n        s1 = sin_theta_t / sin_theta_0\n        s0 = s0[..., None]\n        s1 = s1[..., None]\n        v0 = v0[None, ...]\n        v1 = v1[None, ...]\n        v2 = s0 * v0 + s1 * v1\n\n    if t_is_float and v0.ndim > 1:\n        assert v2.shape[0] == 1\n        v2 = np.squeeze(v2, axis=0)\n    if inputs_are_torch:\n        v2 = torch.from_numpy(v2).to(input_device)\n\n    return v2\n\n\nclass LatentConsistencyModelWalkPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    StableDiffusionLoraLoaderMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using a latent consistency model.\n\n    This model inherits from 
[`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Currently only\n            supports [`LCMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n        requires_safety_checker (`bool`, *optional*, defaults to `True`):\n            Whether the pipeline requires a safety checker component.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"denoised\", \"prompt_embeds\", \"w_embedding\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: LCMScheduler,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or 
`List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            w (`torch.Tensor`):\n                1-D tensor of guidance scale values for which to generate embedding vectors\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be 
ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Currently StableDiffusionPipeline.check_inputs with negative prompt stuff removed\n    def check_inputs(\n        self,\n        prompt: Union[str, List[str]],\n        height: int,\n        width: int,\n        callback_steps: int,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                
f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n    @torch.no_grad()\n    def interpolate_embedding(\n        self,\n        start_embedding: torch.Tensor,\n        end_embedding: torch.Tensor,\n        num_interpolation_steps: Union[int, List[int]],\n        interpolation_type: str,\n    ) -> torch.Tensor:\n        if interpolation_type == \"lerp\":\n            interpolation_fn = lerp\n        elif interpolation_type == \"slerp\":\n            interpolation_fn = slerp\n        else:\n            raise ValueError(\n                f\"embedding_interpolation_type must be one of ['lerp', 'slerp'], got {interpolation_type}.\"\n            )\n\n        embedding = torch.cat([start_embedding, end_embedding])\n        steps = torch.linspace(0, 1, num_interpolation_steps, dtype=embedding.dtype).cpu().numpy()\n        steps = np.expand_dims(steps, axis=tuple(range(1, embedding.ndim)))\n        interpolations = []\n\n        # Interpolate between text embeddings\n        # TODO(aryan): Think of a better way of doing this\n        # See if it can be done parallelly instead\n        for i in range(embedding.shape[0] - 1):\n            interpolations.append(interpolation_fn(embedding[i], embedding[i + 1], steps).squeeze(dim=1))\n\n        interpolations = torch.cat(interpolations)\n        return interpolations\n\n    @torch.no_grad()\n    def interpolate_latent(\n        self,\n        start_latent: torch.Tensor,\n        end_latent: 
torch.Tensor,\n        num_interpolation_steps: Union[int, List[int]],\n        interpolation_type: str,\n    ) -> torch.Tensor:\n        if interpolation_type == \"lerp\":\n            interpolation_fn = lerp\n        elif interpolation_type == \"slerp\":\n            interpolation_fn = slerp\n        else:\n            raise ValueError(\n                f\"latent_interpolation_type must be one of ['lerp', 'slerp'], got {interpolation_type}.\"\n            )\n\n        latent = torch.cat([start_latent, end_latent])\n        steps = torch.linspace(0, 1, num_interpolation_steps, dtype=latent.dtype).cpu().numpy()\n        steps = np.expand_dims(steps, axis=tuple(range(1, latent.ndim)))\n        interpolations = []\n\n        # Interpolate between latents\n        # TODO: Think of a better way of doing this\n        # See if it can be done in parallel instead\n        for i in range(latent.shape[0] - 1):\n            interpolations.append(interpolation_fn(latent[i], latent[i + 1], steps).squeeze(dim=1))\n\n        return torch.cat(interpolations)\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 4,\n        num_interpolation_steps: int = 8,\n        original_inference_steps: int = None,\n        guidance_scale: float = 8.5,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: 
Optional[Dict[str, Any]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        embedding_interpolation_type: str = \"lerp\",\n        latent_interpolation_type: str = \"slerp\",\n        process_batch_size: int = 4,\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 4):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            num_interpolation_steps (`int`, *optional*, defaults to 8):\n                The number of interpolated embeddings and latents (and therefore images) generated between each pair\n                of consecutive prompts.\n            original_inference_steps (`int`, *optional*):\n                The original number of inference steps used to generate a linearly-spaced timestep schedule, from which\n                we will draw `num_inference_steps` evenly spaced timesteps as our final timestep schedule,\n                following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the\n                scheduler's `original_inference_steps` attribute.\n            guidance_scale (`float`, *optional*, defaults to 8.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. 
Guidance scale is enabled when `guidance_scale > 1`.\n                Note that the original latent consistency models paper uses a different CFG formulation where the\n                guidance scales are decreased by 1 (so in the paper formulation CFG is enabled when `guidance_scale >\n                0`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `\"pil\"`, `\"np\"`, or `\"pt\"`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in\n                the `._callback_tensor_inputs` attribute of your pipeline class.\n            embedding_interpolation_type (`str`, *optional*, defaults to `\"lerp\"`):\n                The type of interpolation to use for interpolating between text embeddings. 
Choose between `\"lerp\"` and `\"slerp\"`.\n            latent_interpolation_type (`str`, *optional*, defaults to `\"slerp\"`):\n                The type of interpolation to use for interpolating between latents. Choose between `\"lerp\"` and `\"slerp\"`.\n            process_batch_size (`int`, *optional*, defaults to 4):\n                The batch size to use for processing the images. This is useful when generating a large number of images\n                and you want to avoid running out of memory.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is always `None`, because this pipeline does not run the safety checker (a detected\n                \"not-safe-for-work\" (nsfw) concept would produce black images in the latent walk).\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(prompt, height, width, callback_steps, prompt_embeds, callback_on_step_end_tensor_inputs)\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n        if batch_size < 2:\n            raise ValueError(f\"`prompt` must have length of at least 2 but found {batch_size}\")\n        if num_images_per_prompt != 1:\n            raise ValueError(\"`num_images_per_prompt` must be `1` as no other value is supported yet\")\n        if prompt_embeds is not None:\n            raise ValueError(\"`prompt_embeds` must be None since it is not supported yet\")\n        if latents is not None:\n            raise ValueError(\"`latents` must be None since it is not supported yet\")\n\n        device = self._execution_device\n        # do_classifier_free_guidance = guidance_scale > 1.0\n\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        self.scheduler.set_timesteps(num_inference_steps, device, original_inference_steps=original_inference_steps)\n        timesteps = self.scheduler.timesteps\n        num_channels_latents = self.unet.config.in_channels\n        # bs = batch_size * num_images_per_prompt\n\n        # 3. 
Encode initial input prompt\n        prompt_embeds_1, _ = self.encode_prompt(\n            prompt[:1],\n            device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=False,\n            negative_prompt=None,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=None,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 4. Prepare initial latent variables\n        latents_1 = self.prepare_latents(\n            1,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds_1.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        images = []\n\n        # 5. Iterate over prompts and perform latent walk. Note that we do this two prompts at a time\n        #    otherwise the memory usage ends up being too high.\n        with self.progress_bar(total=batch_size - 1) as prompt_progress_bar:\n            for i in range(1, batch_size):\n                # 6. Encode current prompt\n                prompt_embeds_2, _ = self.encode_prompt(\n                    prompt[i : i + 1],\n                    device,\n                    num_images_per_prompt=num_images_per_prompt,\n                    do_classifier_free_guidance=False,\n                    negative_prompt=None,\n                    prompt_embeds=prompt_embeds,\n                    negative_prompt_embeds=None,\n                    lora_scale=lora_scale,\n                    clip_skip=self.clip_skip,\n                )\n\n                # 7. 
Prepare current latent variables\n                latents_2 = self.prepare_latents(\n                    1,\n                    num_channels_latents,\n                    height,\n                    width,\n                    prompt_embeds_2.dtype,\n                    device,\n                    generator,\n                    latents,\n                )\n\n                # 8. Interpolate between previous and current prompt embeddings and latents\n                inference_embeddings = self.interpolate_embedding(\n                    start_embedding=prompt_embeds_1,\n                    end_embedding=prompt_embeds_2,\n                    num_interpolation_steps=num_interpolation_steps,\n                    interpolation_type=embedding_interpolation_type,\n                )\n                inference_latents = self.interpolate_latent(\n                    start_latent=latents_1,\n                    end_latent=latents_2,\n                    num_interpolation_steps=num_interpolation_steps,\n                    interpolation_type=latent_interpolation_type,\n                )\n                next_prompt_embeds = inference_embeddings[-1:].detach().clone()\n                next_latents = inference_latents[-1:].detach().clone()\n                bs = num_interpolation_steps\n\n                # 9. Perform inference in batches. Note the use of `process_batch_size` to control the batch size\n                #    of the inference. 
This is useful for reducing memory usage and can be configured based on the\n                #    available GPU memory.\n                with self.progress_bar(\n                    total=(bs + process_batch_size - 1) // process_batch_size\n                ) as batch_progress_bar:\n                    for batch_index in range(0, bs, process_batch_size):\n                        batch_inference_latents = inference_latents[batch_index : batch_index + process_batch_size]\n                        batch_inference_embeddings = inference_embeddings[\n                            batch_index : batch_index + process_batch_size\n                        ]\n\n                        self.scheduler.set_timesteps(\n                            num_inference_steps, device, original_inference_steps=original_inference_steps\n                        )\n                        timesteps = self.scheduler.timesteps\n\n                        current_bs = batch_inference_embeddings.shape[0]\n                        w = torch.tensor(self.guidance_scale - 1).repeat(current_bs)\n                        w_embedding = self.get_guidance_scale_embedding(\n                            w, embedding_dim=self.unet.config.time_cond_proj_dim\n                        ).to(device=device, dtype=latents_1.dtype)\n\n                        # 10. 
Perform inference for current batch\n                        with self.progress_bar(total=num_inference_steps) as progress_bar:\n                            for index, t in enumerate(timesteps):\n                                batch_inference_latents = batch_inference_latents.to(batch_inference_embeddings.dtype)\n\n                                # model prediction (v-prediction, eps, x)\n                                model_pred = self.unet(\n                                    batch_inference_latents,\n                                    t,\n                                    timestep_cond=w_embedding,\n                                    encoder_hidden_states=batch_inference_embeddings,\n                                    cross_attention_kwargs=self.cross_attention_kwargs,\n                                    return_dict=False,\n                                )[0]\n\n                                # compute the previous noisy sample x_t -> x_t-1\n                                batch_inference_latents, denoised = self.scheduler.step(\n                                    model_pred, t, batch_inference_latents, **extra_step_kwargs, return_dict=False\n                                )\n                                if callback_on_step_end is not None:\n                                    callback_kwargs = {}\n                                    for k in callback_on_step_end_tensor_inputs:\n                                        callback_kwargs[k] = locals()[k]\n                                    callback_outputs = callback_on_step_end(self, index, t, callback_kwargs)\n\n                                    batch_inference_latents = callback_outputs.pop(\"latents\", batch_inference_latents)\n                                    batch_inference_embeddings = callback_outputs.pop(\n                                        \"prompt_embeds\", batch_inference_embeddings\n                                    )\n                                    w_embedding = 
callback_outputs.pop(\"w_embedding\", w_embedding)\n                                    denoised = callback_outputs.pop(\"denoised\", denoised)\n\n                                # call the callback, if provided\n                                if index == len(timesteps) - 1 or (\n                                    (index + 1) > num_warmup_steps and (index + 1) % self.scheduler.order == 0\n                                ):\n                                    progress_bar.update()\n                                    if callback is not None and index % callback_steps == 0:\n                                        step_idx = index // getattr(self.scheduler, \"order\", 1)\n                                        callback(step_idx, t, batch_inference_latents)\n\n                        denoised = denoised.to(batch_inference_embeddings.dtype)\n\n                        # Note: This is not supported because you would get black images in your latent walk if\n                        #       NSFW concept is detected\n                        # if not output_type == \"latent\":\n                        #     image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]\n                        #     image, has_nsfw_concept = self.run_safety_checker(image, device, inference_embeddings.dtype)\n                        # else:\n                        #     image = denoised\n                        #     has_nsfw_concept = None\n\n                        # if has_nsfw_concept is None:\n                        #     do_denormalize = [True] * image.shape[0]\n                        # else:\n                        #     do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n                        image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]\n                        do_denormalize = [True] * image.shape[0]\n                        has_nsfw_concept = None\n\n                        image = 
self.image_processor.postprocess(\n                            image, output_type=output_type, do_denormalize=do_denormalize\n                        )\n                        images.append(image)\n\n                        batch_progress_bar.update()\n\n                prompt_embeds_1 = next_prompt_embeds\n                latents_1 = next_latents\n\n                prompt_progress_bar.update()\n\n        # 11. Determine what should be returned\n        if output_type == \"pil\":\n            images = [image for image_list in images for image in image_list]\n        elif output_type == \"np\":\n            images = np.concatenate(images)\n        elif output_type == \"pt\":\n            images = torch.cat(images)\n        else:\n            raise ValueError(\"`output_type` must be one of 'pil', 'np' or 'pt'.\")\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (images, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/latent_consistency_txt2img.py",
    "content": "# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion\n# and https://github.com/hojonathanho/diffusion\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, ConfigMixin, DiffusionPipeline, SchedulerMixin, UNet2DConditionModel, logging\nfrom diffusers.configuration_utils import register_to_config\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.utils import BaseOutput\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass LatentConsistencyModelPipeline(DiffusionPipeline):\n    _optional_components = [\"scheduler\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: \"LCMScheduler\",\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        
requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        scheduler = (\n            scheduler\n            if scheduler is not None\n            else LCMScheduler(\n                beta_start=0.00085, beta_end=0.0120, beta_schedule=\"scaled_linear\", prediction_type=\"epsilon\"\n            )\n        )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        prompt_embeds: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n        \"\"\"\n\n        if prompt_embeds is None:\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = 
self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # Don't need to get uncond prompt embedding because of LCM Guided Distillation\n        return prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if latents is None:\n            latents = torch.randn(shape, dtype=dtype).to(device)\n        else:\n            latents = latents.to(device)\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def get_w_embedding(self, w, 
embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n        Args:\n        w: torch.Tensor: 1-D tensor of guidance scales to generate embedding vectors for\n        embedding_dim: int: dimension of the embeddings to generate\n        dtype: data type of the generated embeddings\n        Returns:\n        embedding vectors with shape `(len(w), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = 768,\n        width: Optional[int] = 768,\n        guidance_scale: float = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        latents: Optional[torch.Tensor] = None,\n        num_inference_steps: int = 4,\n        lcm_origin_steps: int = 50,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        # 1. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # do_classifier_free_guidance = guidance_scale > 0.0  # In LCM Implementation:  cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG)\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            prompt_embeds=prompt_embeds,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, lcm_origin_steps)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variable\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            latents,\n        )\n        bs = batch_size * num_images_per_prompt\n\n        # 6. Get Guidance Scale Embedding\n        w = torch.tensor(guidance_scale).repeat(bs)\n        w_embedding = self.get_w_embedding(w, embedding_dim=256).to(device=device, dtype=latents.dtype)\n\n        # 7. 
LCM MultiStep Sampling Loop:\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                ts = torch.full((bs,), t, device=device, dtype=torch.long)\n                latents = latents.to(prompt_embeds.dtype)\n\n                # model prediction (v-prediction, eps, x)\n                model_pred = self.unet(\n                    latents,\n                    ts,\n                    timestep_cond=w_embedding,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents, denoised = self.scheduler.step(model_pred, i, t, latents, return_dict=False)\n\n                # # call the callback, if provided\n                # if i == len(timesteps) - 1:\n                progress_bar.update()\n\n        denoised = denoised.to(prompt_embeds.dtype)\n        if not output_type == \"latent\":\n            image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = denoised\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n\n@dataclass\n# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM\nclass LCMSchedulerOutput(BaseOutput):\n   
 \"\"\"\n    Output class for the scheduler's `step` function output.\n    Args:\n        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):\n            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the\n            denoising loop.\n        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):\n            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.\n            `pred_original_sample` can be used to preview progress or for guidance.\n    \"\"\"\n\n    prev_sample: torch.Tensor\n    denoised: Optional[torch.Tensor] = None\n\n\n# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar\ndef betas_for_alpha_bar(\n    num_diffusion_timesteps,\n    max_beta=0.999,\n    alpha_transform_type=\"cosine\",\n):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of\n    (1-beta) over time from t = [0,1].\n    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up\n    to that part of the diffusion process.\n    Args:\n        num_diffusion_timesteps (`int`): the number of betas to produce.\n        max_beta (`float`): the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.\n                     Choose from `cosine` or `exp`\n    Returns:\n        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs\n    \"\"\"\n    if alpha_transform_type == \"cosine\":\n\n        def alpha_bar_fn(t):\n            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2\n\n    elif alpha_transform_type == \"exp\":\n\n        def alpha_bar_fn(t):\n            return math.exp(t * 
-12.0)\n\n    else:\n        raise ValueError(f\"Unsupported alpha_transform_type: {alpha_transform_type}\")\n\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))\n    return torch.tensor(betas, dtype=torch.float32)\n\n\ndef rescale_zero_terminal_snr(betas):\n    \"\"\"\n    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)\n    Args:\n        betas (`torch.Tensor`):\n            the betas that the scheduler is being initialized with.\n    Returns:\n        `torch.Tensor`: rescaled betas with zero terminal SNR\n    \"\"\"\n    # Convert betas to alphas_bar_sqrt\n    alphas = 1.0 - betas\n    alphas_cumprod = torch.cumprod(alphas, dim=0)\n    alphas_bar_sqrt = alphas_cumprod.sqrt()\n\n    # Store old values.\n    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()\n    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()\n\n    # Shift so the last timestep is zero.\n    alphas_bar_sqrt -= alphas_bar_sqrt_T\n\n    # Scale so the first timestep is back to the old value.\n    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)\n\n    # Convert alphas_bar_sqrt to betas\n    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt\n    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod\n    alphas = torch.cat([alphas_bar[0:1], alphas])\n    betas = 1 - alphas\n\n    return betas\n\n\nclass LCMScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"\n    `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with\n    non-Markovian guidance.\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. 
Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        beta_start (`float`, defaults to 0.0001):\n            The starting `beta` value of inference.\n        beta_end (`float`, defaults to 0.02):\n            The final `beta` value.\n        beta_schedule (`str`, defaults to `\"linear\"`):\n            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from\n            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.\n        trained_betas (`np.ndarray`, *optional*):\n            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.\n        clip_sample (`bool`, defaults to `True`):\n            Clip the predicted sample for numerical stability.\n        clip_sample_range (`float`, defaults to 1.0):\n            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.\n        set_alpha_to_one (`bool`, defaults to `True`):\n            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step\n            there is no previous alpha. 
When this option is `True` the previous alpha product is fixed to `1`,\n            otherwise it uses the alpha value at step 0.\n        steps_offset (`int`, defaults to 0):\n            An offset added to the inference steps, as required by some model families.\n        prediction_type (`str`, defaults to `epsilon`, *optional*):\n            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),\n            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen\n            Video](https://imagen.research.google/video/paper.pdf) paper).\n        thresholding (`bool`, defaults to `False`):\n            Whether to use the \"dynamic thresholding\" method. This is unsuitable for latent-space diffusion models such\n            as Stable Diffusion.\n        dynamic_thresholding_ratio (`float`, defaults to 0.995):\n            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.\n        sample_max_value (`float`, defaults to 1.0):\n            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.\n        timestep_spacing (`str`, defaults to `\"leading\"`):\n            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and\n            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.\n        rescale_betas_zero_snr (`bool`, defaults to `False`):\n            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and\n            dark samples instead of limiting it to samples with medium brightness. 
Loosely related to\n            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).\n    \"\"\"\n\n    # _compatibles = [e.name for e in KarrasDiffusionSchedulers]\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        beta_start: float = 0.0001,\n        beta_end: float = 0.02,\n        beta_schedule: str = \"linear\",\n        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,\n        clip_sample: bool = True,\n        set_alpha_to_one: bool = True,\n        steps_offset: int = 0,\n        prediction_type: str = \"epsilon\",\n        thresholding: bool = False,\n        dynamic_thresholding_ratio: float = 0.995,\n        clip_sample_range: float = 1.0,\n        sample_max_value: float = 1.0,\n        timestep_spacing: str = \"leading\",\n        rescale_betas_zero_snr: bool = False,\n    ):\n        if trained_betas is not None:\n            self.betas = torch.tensor(trained_betas, dtype=torch.float32)\n        elif beta_schedule == \"linear\":\n            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)\n        elif beta_schedule == \"scaled_linear\":\n            # this schedule is very specific to the latent diffusion model.\n            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2\n        elif beta_schedule == \"squaredcos_cap_v2\":\n            # Glide cosine schedule\n            self.betas = betas_for_alpha_bar(num_train_timesteps)\n        else:\n            raise NotImplementedError(f\"{beta_schedule} is not implemented for {self.__class__}\")\n\n        # Rescale for zero SNR\n        if rescale_betas_zero_snr:\n            self.betas = rescale_zero_terminal_snr(self.betas)\n\n        self.alphas = 1.0 - self.betas\n        self.alphas_cumprod = 
torch.cumprod(self.alphas, dim=0)\n\n        # At every step in ddim, we are looking into the previous alphas_cumprod\n        # For the final step, there is no previous alphas_cumprod because we are already at 0\n        # `set_alpha_to_one` decides whether we set this parameter simply to one or\n        # whether we use the final alpha of the \"non-previous\" one.\n        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]\n\n        # standard deviation of the initial noise distribution\n        self.init_noise_sigma = 1.0\n\n        # settable values\n        self.num_inference_steps = None\n        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))\n\n    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:\n        \"\"\"\n        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the\n        current timestep.\n        Args:\n            sample (`torch.Tensor`):\n                The input sample.\n            timestep (`int`, *optional*):\n                The current timestep in the diffusion chain.\n        Returns:\n            `torch.Tensor`:\n                A scaled input sample.\n        \"\"\"\n        return sample\n\n    def _get_variance(self, timestep, prev_timestep):\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)\n\n        return variance\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample\n    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        \"Dynamic thresholding: At each 
sampling step we set s to a certain percentile absolute pixel value in xt0 (the\n        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by\n        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing\n        pixels from saturation at each step. We find that dynamic thresholding results in significantly better\n        photorealism as well as better image-text alignment, especially when using very large guidance weights.\"\n        https://huggingface.co/papers/2205.11487\n        \"\"\"\n        dtype = sample.dtype\n        batch_size, channels, height, width = sample.shape\n\n        if dtype not in (torch.float32, torch.float64):\n            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half\n\n        # Flatten sample for doing quantile calculation along each image\n        sample = sample.reshape(batch_size, channels * height * width)\n\n        abs_sample = sample.abs()  # \"a certain percentile absolute pixel value\"\n\n        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)\n        s = torch.clamp(\n            s, min=1, max=self.config.sample_max_value\n        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]\n\n        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0\n        sample = torch.clamp(sample, -s, s) / s  # \"we threshold xt0 to the range [-s, s] and then divide by s\"\n\n        sample = sample.reshape(batch_size, channels, height, width)\n        sample = sample.to(dtype)\n\n        return sample\n\n    def set_timesteps(self, num_inference_steps: int, lcm_origin_steps: int, device: Union[str, torch.device] = None):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n        Args:\n            num_inference_steps (`int`):\n                
The number of diffusion steps used when generating samples with a pre-trained model.\n            lcm_origin_steps (`int`):\n                The number of steps in the LCM training schedule, from which the inference timesteps are subsampled.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved.\n        \"\"\"\n\n        if num_inference_steps > self.config.num_train_timesteps:\n            raise ValueError(\n                f\"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:\"\n                f\" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle\"\n                f\" maximal {self.config.num_train_timesteps} timesteps.\"\n            )\n\n        self.num_inference_steps = num_inference_steps\n\n        # LCM Timesteps Setting:  # Linear Spacing\n        c = self.config.num_train_timesteps // lcm_origin_steps\n        lcm_origin_timesteps = np.asarray(list(range(1, lcm_origin_steps + 1))) * c - 1  # LCM Training  Steps Schedule\n        skipping_step = len(lcm_origin_timesteps) // num_inference_steps\n        timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]  # LCM Inference Steps Schedule\n\n        self.timesteps = torch.from_numpy(timesteps.copy()).to(device)\n\n    def get_scalings_for_boundary_condition_discrete(self, t):\n        self.sigma_data = 0.5  # Default: 0.5\n\n        # By dividing 0.1: This is almost a delta function at t=0.\n        c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)\n        c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5\n        return c_skip, c_out\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timeindex: int,\n        timestep: int,\n        sample: torch.Tensor,\n        eta: float = 0.0,\n        use_clipped_model_output: bool = False,\n        generator=None,\n        variance_noise: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[LCMSchedulerOutput, Tuple]:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion\n        process from the learned model outputs (most often the predicted noise).\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model.\n            timeindex (`int`):\n                The index of `timestep` in `self.timesteps`.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            eta (`float`):\n                The weight of noise for added noise in diffusion step.\n            use_clipped_model_output (`bool`, defaults to `False`):\n                If `True`, computes \"corrected\" `model_output` from the clipped predicted original sample. Necessary\n                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no\n                clipping has happened, \"corrected\" `model_output` would coincide with the one provided as input and\n                `use_clipped_model_output` has no effect.\n            generator (`torch.Generator`, *optional*):\n                A random number generator.\n            variance_noise (`torch.Tensor`):\n                Alternative to generating noise with `generator` by directly providing the noise for the variance\n                itself. 
Useful for methods such as [`CycleDiffusion`].\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.\n        Returns:\n            [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a\n                tuple is returned where the first element is the sample tensor.\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n            )\n\n        # 1. get previous step value\n        prev_timeindex = timeindex + 1\n        if prev_timeindex < len(self.timesteps):\n            prev_timestep = self.timesteps[prev_timeindex]\n        else:\n            prev_timestep = timestep\n\n        # 2. compute alphas, betas\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        # 3. Get scalings for boundary conditions\n        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)\n\n        # 4. Different Parameterization:\n        parameterization = self.config.prediction_type\n\n        if parameterization == \"epsilon\":  # noise-prediction\n            pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()\n\n        elif parameterization == \"sample\":  # x-prediction\n            pred_x0 = model_output\n\n        elif parameterization == \"v_prediction\":  # v-prediction\n            pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output\n\n        # 4. 
Denoise model output using boundary conditions\n        denoised = c_out * pred_x0 + c_skip * sample\n\n        # 5. Sample z ~ N(0, I), For MultiStep Inference\n        # Noise is not used for one-step sampling.\n        if len(self.timesteps) > 1:\n            noise = torch.randn(model_output.shape).to(model_output.device)\n            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise\n        else:\n            prev_sample = denoised\n\n        if not return_dict:\n            return (prev_sample, denoised)\n\n        return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.IntTensor,\n    ) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples\n        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)\n        timesteps = timesteps.to(original_samples.device)\n\n        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n        sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):\n            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):\n            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise\n        return noisy_samples\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity\n    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, 
timesteps: torch.IntTensor) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as sample\n        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)\n        timesteps = timesteps.to(sample.device)\n\n        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n        sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n        while len(sqrt_alpha_prod.shape) < len(sample.shape):\n            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):\n            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample\n        return velocity\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n"
  },
  {
    "path": "examples/community/llm_grounded_diffusion.py",
    "content": "# Copyright 2025 Long Lian, the GLIGEN Authors, and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# This is a single file implementation of LMD+. See README.md for examples.\n\nimport ast\nimport gc\nimport inspect\nimport math\nimport warnings\nfrom collections.abc import Iterable\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.attention import Attention, GatedSelfAttentionDense\nfrom diffusers.models.attention_processor import AttnProcessor2_0\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines import DiffusionPipeline\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom 
diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import DiffusionPipeline\n\n        >>> pipe = DiffusionPipeline.from_pretrained(\n        ...     \"longlian/lmd_plus\",\n        ...     custom_pipeline=\"llm_grounded_diffusion\",\n        ...     custom_revision=\"main\",\n        ...     variant=\"fp16\", torch_dtype=torch.float16\n        ... )\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> # Generate an image described by the prompt and\n        >>> # insert objects described by text at the region defined by bounding boxes\n        >>> prompt = \"a waterfall and a modern high speed train in a beautiful forest with fall foliage\"\n        >>> boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]\n        >>> phrases = [\"a waterfall\", \"a modern high speed train\"]\n\n        >>> images = pipe(\n        ...     prompt=prompt,\n        ...     phrases=phrases,\n        ...     boxes=boxes,\n        ...     gligen_scheduled_sampling_beta=0.4,\n        ...     output_type=\"pil\",\n        ...     num_inference_steps=50,\n        ...     lmd_guidance_kwargs={}\n        ... 
).images\n\n        >>> images[0].save(\"./lmd_plus_generation.jpg\")\n\n        >>> # Generate directly from a text prompt and an LLM response\n        >>> prompt = \"a waterfall and a modern high speed train in a beautiful forest with fall foliage\"\n        >>> phrases, boxes, bg_prompt, neg_prompt = pipe.parse_llm_response(\\\"\"\"\n        [('a waterfall', [71, 105, 148, 258]), ('a modern high speed train', [255, 223, 181, 149])]\n        Background prompt: A beautiful forest with fall foliage\n        Negative prompt:\n        \\\"\"\")\n\n        >>> images = pipe(\n        ...     prompt=prompt,\n        ...     negative_prompt=neg_prompt,\n        ...     phrases=phrases,\n        ...     boxes=boxes,\n        ...     gligen_scheduled_sampling_beta=0.4,\n        ...     output_type=\"pil\",\n        ...     num_inference_steps=50,\n        ...     lmd_guidance_kwargs={}\n        ... ).images\n\n        >>> images[0].save(\"./lmd_plus_generation.jpg\")\n        ```\n\"\"\"\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n# All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]\n# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. 
The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.\nDEFAULT_GUIDANCE_ATTN_KEYS = [\n    (\"mid\", 0, 0, 0),\n    (\"up\", 1, 0, 0),\n    (\"up\", 1, 1, 0),\n    (\"up\", 1, 2, 0),\n]\n\n\ndef convert_attn_keys(key):\n    \"\"\"Convert the attention key from tuple format to the torch state format\"\"\"\n\n    if key[0] == \"mid\":\n        assert key[1] == 0, f\"mid block only has one block but the index is {key[1]}\"\n        return f\"{key[0]}_block.attentions.{key[2]}.transformer_blocks.{key[3]}.attn2.processor\"\n\n    return f\"{key[0]}_blocks.{key[1]}.attentions.{key[2]}.transformer_blocks.{key[3]}.attn2.processor\"\n\n\nDEFAULT_GUIDANCE_ATTN_KEYS = [convert_attn_keys(key) for key in DEFAULT_GUIDANCE_ATTN_KEYS]\n\n\ndef scale_proportion(obj_box, H, W):\n    # Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with \".5\".\n    x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H)\n    box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round((obj_box[3] - obj_box[1]) * H)\n    x_max, y_max = x_min + box_w, y_min + box_h\n\n    x_min, y_min = max(x_min, 0), max(y_min, 0)\n    x_max, y_max = min(x_max, W), min(y_max, H)\n\n    return x_min, y_min, x_max, y_max\n\n\n# Adapted from the parent class `AttnProcessor2_0`\nclass AttnProcessorWithHook(AttnProcessor2_0):\n    def __init__(\n        self,\n        attn_processor_key,\n        hidden_size,\n        cross_attention_dim,\n        hook=None,\n        fast_attn=True,\n        enabled=True,\n    ):\n        super().__init__()\n        self.attn_processor_key = attn_processor_key\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.hook = hook\n        self.fast_attn = fast_attn\n        self.enabled = enabled\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states,\n     
   encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        scale: float = 1.0,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        args = () if USE_PEFT_BACKEND else (scale,)\n        query = attn.to_q(hidden_states, *args)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states, *args)\n        value = attn.to_v(encoder_hidden_states, *args)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        if (self.hook is not None and self.enabled) or not self.fast_attn:\n            query_batch_dim = attn.head_to_batch_dim(query)\n            key_batch_dim = attn.head_to_batch_dim(key)\n            value_batch_dim = attn.head_to_batch_dim(value)\n            attention_probs = attn.get_attention_scores(query_batch_dim, key_batch_dim, attention_mask)\n\n        if self.hook is not None and self.enabled:\n            # Call the hook with query, key, value, and attention maps\n            self.hook(\n                
self.attn_processor_key,\n                query_batch_dim,\n                key_batch_dim,\n                value_batch_dim,\n                attention_probs,\n            )\n\n        if self.fast_attn:\n            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n            if attention_mask is not None:\n                # scaled_dot_product_attention expects attention_mask shape to be\n                # (batch, heads, source_length, target_length)\n                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n            # the output of sdp = (batch, num_heads, seq_len, head_dim)\n            # TODO: add support for attn.scale when we move to Torch 2.1\n            hidden_states = F.scaled_dot_product_attention(\n                query,\n                key,\n                value,\n                attn_mask=attention_mask,\n                dropout_p=0.0,\n                is_causal=False,\n            )\n            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n            hidden_states = hidden_states.to(query.dtype)\n        else:\n            hidden_states = torch.bmm(attention_probs, value)\n            hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states, *args)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass 
LLMGroundedDiffusionPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    StableDiffusionLoraLoaderMixin,\n    IPAdapterMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://huggingface.co/papers/2305.13655.\n\n    This model inherits from [`StableDiffusionPipeline`] and aims at implementing the pipeline with minimal modifications. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    This is a simplified implementation that does not perform latent or attention transfer from single object generation to overall generation. The final image is generated directly with attention and adapters control.\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n        requires_safety_checker (bool):\n            Whether a safety checker is needed for this pipeline.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    objects_text = \"Objects: \"\n    bg_prompt_text = \"Background prompt: \"\n    bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()\n    neg_prompt_text = \"Negative prompt: \"\n    neg_prompt_text_no_trailing_space = neg_prompt_text.rstrip()\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        # This is copied from StableDiffusionPipeline, with hook initializations for LMD+.\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n   
             f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. 
If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. 
If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n        # Initialize the attention hooks for LLM-grounded Diffusion\n        self.register_attn_hooks(unet)\n        self._saved_attn = None\n\n    def attn_hook(self, name, query, key, value, attention_probs):\n        if name in DEFAULT_GUIDANCE_ATTN_KEYS:\n            self._saved_attn[name] = attention_probs\n\n    @classmethod\n    def convert_box(cls, box, height, width):\n        # box: x, y, w, h (in 512 format) -> x_min, y_min, x_max, y_max\n        x_min, y_min = box[0] / width, box[1] / height\n        w_box, h_box = box[2] / width, box[3] / height\n\n        x_max, y_max = x_min + w_box, y_min + h_box\n\n        return x_min, y_min, x_max, y_max\n\n    @classmethod\n    def _parse_response_with_negative(cls, text):\n        if not text:\n            raise ValueError(\"LLM response is empty\")\n\n        if cls.objects_text in text:\n            text = 
text.split(cls.objects_text)[1]\n\n        text_split = text.split(cls.bg_prompt_text_no_trailing_space)\n        if len(text_split) == 2:\n            gen_boxes, text_rem = text_split\n        else:\n            raise ValueError(f\"LLM response is incomplete: {text}\")\n\n        text_split = text_rem.split(cls.neg_prompt_text_no_trailing_space)\n\n        if len(text_split) == 2:\n            bg_prompt, neg_prompt = text_split\n        else:\n            raise ValueError(f\"LLM response is incomplete: {text}\")\n\n        try:\n            gen_boxes = ast.literal_eval(gen_boxes)\n        except SyntaxError as e:\n            # Sometimes the response is in plain text\n            if \"No objects\" in gen_boxes or gen_boxes.strip() == \"\":\n                gen_boxes = []\n            else:\n                raise e\n        bg_prompt = bg_prompt.strip()\n        neg_prompt = neg_prompt.strip()\n\n        # LLM may return \"None\" to mean no negative prompt provided.\n        if neg_prompt == \"None\":\n            neg_prompt = \"\"\n\n        return gen_boxes, bg_prompt, neg_prompt\n\n    @classmethod\n    def parse_llm_response(cls, response, canvas_height=512, canvas_width=512):\n        # Infer from spec\n        gen_boxes, bg_prompt, neg_prompt = cls._parse_response_with_negative(text=response)\n\n        gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0])\n\n        phrases = [name for name, _ in gen_boxes]\n        boxes = [cls.convert_box(box, height=canvas_height, width=canvas_width) for _, box in gen_boxes]\n\n        return phrases, boxes, bg_prompt, neg_prompt\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        phrases,\n        boxes,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        phrase_indices=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have 
to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt is None and phrase_indices is None:\n            raise ValueError(\"If the prompt is None, the phrase_indices cannot be None\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if len(phrases) != len(boxes):\n            raise ValueError(\n                \"length of `phrases` and `boxes` has to be same, but\"\n                f\" got: `phrases` {len(phrases)} != `boxes` {len(boxes)}\"\n            )\n\n    def register_attn_hooks(self, unet):\n        \"\"\"Registering hooks to obtain the attention maps for guidance\"\"\"\n\n        attn_procs = {}\n\n        for name in unet.attn_processors.keys():\n            # Only obtain the queries and keys from cross-attention\n            if name.endswith(\"attn1.processor\") or name.endswith(\"fuser.attn.processor\"):\n                # Keep the same attn_processors for self-attention (no hooks for self-attention)\n                attn_procs[name] = unet.attn_processors[name]\n                continue\n\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n\n            attn_procs[name] = 
AttnProcessorWithHook(\n                attn_processor_key=name,\n                hidden_size=hidden_size,\n                cross_attention_dim=cross_attention_dim,\n                hook=self.attn_hook,\n                fast_attn=True,\n                # Not enabled by default\n                enabled=False,\n            )\n\n        unet.set_attn_processor(attn_procs)\n\n    def enable_fuser(self, enabled=True):\n        for module in self.unet.modules():\n            if isinstance(module, GatedSelfAttentionDense):\n                module.enabled = enabled\n\n    def enable_attn_hook(self, enabled=True):\n        for module in self.unet.attn_processors.values():\n            if isinstance(module, AttnProcessorWithHook):\n                module.enabled = enabled\n\n    def get_token_map(self, prompt, padding=\"do_not_pad\", verbose=False):\n        \"\"\"Get a list of mapping: prompt index to str (prompt in a list of token str)\"\"\"\n        fg_prompt_tokens = self.tokenizer([prompt], padding=padding, max_length=77, return_tensors=\"np\")\n        input_ids = fg_prompt_tokens[\"input_ids\"][0]\n\n        token_map = []\n        for ind, item in enumerate(input_ids.tolist()):\n            token = self.tokenizer._convert_id_to_token(item)\n\n            if verbose:\n                logger.info(f\"{ind}, {token} ({item})\")\n\n            token_map.append(token)\n\n        return token_map\n\n    def get_phrase_indices(\n        self,\n        prompt,\n        phrases,\n        token_map=None,\n        add_suffix_if_not_found=False,\n        verbose=False,\n    ):\n        for obj in phrases:\n            # Suffix the prompt with object name for attention guidance if object is not in the prompt, using \"|\" to separate the prompt and the suffix\n            if obj not in prompt:\n                prompt += \"| \" + obj\n\n        if token_map is None:\n            # We allow using a pre-computed token map.\n            token_map = self.get_token_map(prompt=prompt, 
padding=\"do_not_pad\", verbose=verbose)\n        token_map_str = \" \".join(token_map)\n\n        phrase_indices = []\n\n        for obj in phrases:\n            phrase_token_map = self.get_token_map(prompt=obj, padding=\"do_not_pad\", verbose=verbose)\n            # Remove <bos> and <eos> in substr\n            phrase_token_map = phrase_token_map[1:-1]\n            phrase_token_map_len = len(phrase_token_map)\n            phrase_token_map_str = \" \".join(phrase_token_map)\n\n            if verbose:\n                # Format the message up front: extra positional arguments to `logger.info` are treated as %-format arguments.\n                logger.info(f\"Full str: {token_map_str} Substr: {phrase_token_map_str} Phrase: {phrases}\")\n\n            # Count the number of tokens before the substring.\n            # The text before the substring ends with a separating space, which is dropped by subtracting one from the index.\n            obj_first_index = len(token_map_str[: token_map_str.index(phrase_token_map_str) - 1].split(\" \"))\n\n            obj_position = list(range(obj_first_index, obj_first_index + phrase_token_map_len))\n            phrase_indices.append(obj_position)\n\n        if add_suffix_if_not_found:\n            return phrase_indices, prompt\n\n        return phrase_indices\n\n    def add_ca_loss_per_attn_map_to_loss(\n        self,\n        loss,\n        attn_map,\n        object_number,\n        bboxes,\n        phrase_indices,\n        fg_top_p=0.2,\n        bg_top_p=0.2,\n        fg_weight=1.0,\n        bg_weight=1.0,\n    ):\n        # b is the number of heads, not batch\n        b, i, j = attn_map.shape\n        H = W = int(math.sqrt(i))\n        for obj_idx in range(object_number):\n            obj_loss = 0\n            mask = torch.zeros(size=(H, W), device=attn_map.device)\n            obj_boxes = bboxes[obj_idx]\n\n            # We support two level (one box per phrase) and three level (multiple boxes per phrase)\n            if not 
isinstance(obj_boxes[0], Iterable):\n                obj_boxes = [obj_boxes]\n\n            for obj_box in obj_boxes:\n                # x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)\n                x_min, y_min, x_max, y_max = scale_proportion(obj_box, H=H, W=W)\n                mask[y_min:y_max, x_min:x_max] = 1\n\n            for obj_position in phrase_indices[obj_idx]:\n                # Could potentially optimize to compute this for loop in batch.\n                # Could crop the ref cross attention before saving to save memory.\n\n                # shape: (b, H * W)\n                ca_map_obj = attn_map[:, :, obj_position]  # .reshape(b, H, W)\n                k_fg = (mask.sum() * fg_top_p).long().clamp_(min=1)\n                k_bg = ((1 - mask).sum() * bg_top_p).long().clamp_(min=1)\n\n                mask_1d = mask.view(1, -1)\n\n                # Max-based loss function\n\n                # Take the topk over spatial dimension, and then take the sum over heads dim\n                # The mean is over k_fg and k_bg dimension, so we don't need to sum and divide on our own.\n                obj_loss += (1 - (ca_map_obj * mask_1d).topk(k=k_fg).values.mean(dim=1)).sum(dim=0) * fg_weight\n                obj_loss += ((ca_map_obj * (1 - mask_1d)).topk(k=k_bg).values.mean(dim=1)).sum(dim=0) * bg_weight\n\n            loss += obj_loss / len(phrase_indices[obj_idx])\n\n        return loss\n\n    def compute_ca_loss(\n        self,\n        saved_attn,\n        bboxes,\n        phrase_indices,\n        guidance_attn_keys,\n        verbose=False,\n        **kwargs,\n    ):\n        \"\"\"\n        The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing this loss.\n        `AttnProcessor` will put attention maps into the `save_attn_to_dict`.\n\n        `index` is the 
timestep.\n        `ref_ca_word_token_only`: This has precedence over `ref_ca_last_token_only` (i.e., if both are enabled, we take the token from word rather than the last token).\n        `ref_ca_last_token_only`: `ref_ca_saved_attn` comes from the attention map of the last token of the phrase in single object generation, so we apply it only to the last token of the phrase in overall generation if this is set to True. If set to False, `ref_ca_saved_attn` will be applied to all the text tokens.\n        \"\"\"\n        loss = torch.tensor(0).float().cuda()\n        object_number = len(bboxes)\n        if object_number == 0:\n            return loss\n\n        for attn_key in guidance_attn_keys:\n            # We only have 1 cross attention for mid.\n\n            attn_map_integrated = saved_attn[attn_key]\n            if not attn_map_integrated.is_cuda:\n                attn_map_integrated = attn_map_integrated.cuda()\n            # Example dimension: [20, 64, 77]\n            attn_map = attn_map_integrated.squeeze(dim=0)\n\n            loss = self.add_ca_loss_per_attn_map_to_loss(\n                loss, attn_map, object_number, bboxes, phrase_indices, **kwargs\n            )\n\n        num_attn = len(guidance_attn_keys)\n\n        if num_attn > 0:\n            loss = loss / (object_number * num_attn)\n\n        return loss\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        gligen_scheduled_sampling_beta: float = 0.3,\n        phrases: List[str] = None,\n        boxes: List[List[float]] = None,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, 
List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        clip_skip: Optional[int] = None,\n        lmd_guidance_kwargs: Optional[Dict[str, Any]] = {},\n        phrase_indices: Optional[List[int]] = None,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            phrases (`List[str]`):\n                The phrases to guide what to include in each of the regions defined by the corresponding\n                `boxes`. 
There should only be one phrase per bounding box.\n            boxes (`List[List[float]]`):\n                The bounding boxes that identify rectangular regions of the image that are going to be filled with the\n                content described by the corresponding `phrases`. Each rectangular box is defined as a\n                `List[float]` of 4 elements `[xmin, ymin, xmax, ymax]` where each value is between [0,1].\n            gligen_scheduled_sampling_beta (`float`, defaults to 0.3):\n                Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image\n                Generation](https://huggingface.co/papers/2301.07093). Scheduled Sampling factor is only varied for\n                scheduled sampling during inference for improved quality and controllability.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that is called every `callback_steps` steps during inference. The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. 
If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when\n                using zero terminal SNR.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            lmd_guidance_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to `latent_lmd_guidance` function. Useful keys include `loss_scale` (the guidance strength), `loss_threshold` (when loss is lower than this value, the guidance is not applied anymore), `max_iter` (the number of iterations of guidance for each step), and `guidance_timesteps` (the number of diffusion timesteps to apply guidance on). See `latent_lmd_guidance` for implementation details.\n            phrase_indices (`list` of `list`, *optional*): The indices of the tokens of each phrase in the overall prompt. If omitted, the pipeline will match the first token subsequence. 
The pipeline will append the missing phrases to the end of the prompt by default.\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            callback_steps,\n            phrases,\n            boxes,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            phrase_indices,\n        )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n            if phrase_indices is None:\n                phrase_indices, prompt = self.get_phrase_indices(prompt, phrases, add_suffix_if_not_found=True)\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n            if phrase_indices is None:\n                phrase_indices = []\n                prompt_parsed = []\n                for prompt_item in prompt:\n                    (\n                        phrase_indices_parsed_item,\n                        prompt_parsed_item,\n                    ) = self.get_phrase_indices(prompt_item, phrases, add_suffix_if_not_found=True)\n                    phrase_indices.append(phrase_indices_parsed_item)\n                    prompt_parsed.append(prompt_parsed_item)\n                prompt = prompt_parsed\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            clip_skip=clip_skip,\n        )\n\n        cond_prompt_embeds = prompt_embeds\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None:\n            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)\n            if do_classifier_free_guidance:\n                image_embeds = torch.cat([negative_image_embeds, image_embeds])\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 5.1 Prepare GLIGEN variables\n        max_objs = 30\n        if len(boxes) > max_objs:\n            warnings.warn(\n                f\"More than {max_objs} objects found. 
Only first {max_objs} objects will be processed.\",\n                FutureWarning,\n            )\n            phrases = phrases[:max_objs]\n            boxes = boxes[:max_objs]\n\n        n_objs = len(boxes)\n        if n_objs:\n            # prepare batched input to the PositionNet (boxes, phrases, mask)\n            # Get tokens for phrases from pre-trained CLIPTokenizer\n            tokenizer_inputs = self.tokenizer(phrases, padding=True, return_tensors=\"pt\").to(device)\n            # For the token, we use the same pre-trained text encoder\n            # to obtain its text feature\n            _text_embeddings = self.text_encoder(**tokenizer_inputs).pooler_output\n\n        # For each entity, described in phrases, is denoted with a bounding box,\n        # we represent the location information as (xmin,ymin,xmax,ymax)\n        cond_boxes = torch.zeros(max_objs, 4, device=device, dtype=self.text_encoder.dtype)\n        if n_objs:\n            cond_boxes[:n_objs] = torch.tensor(boxes)\n        text_embeddings = torch.zeros(\n            max_objs,\n            self.unet.config.cross_attention_dim,\n            device=device,\n            dtype=self.text_encoder.dtype,\n        )\n        if n_objs:\n            text_embeddings[:n_objs] = _text_embeddings\n        # Generate a mask for each object that is entity described by phrases\n        masks = torch.zeros(max_objs, device=device, dtype=self.text_encoder.dtype)\n        masks[:n_objs] = 1\n\n        repeat_batch = batch_size * num_images_per_prompt\n        cond_boxes = cond_boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()\n        text_embeddings = text_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()\n        masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()\n        if do_classifier_free_guidance:\n            repeat_batch = repeat_batch * 2\n            cond_boxes = torch.cat([cond_boxes] * 2)\n            text_embeddings = torch.cat([text_embeddings] * 2)\n            masks 
= torch.cat([masks] * 2)\n            masks[: repeat_batch // 2] = 0\n        if cross_attention_kwargs is None:\n            cross_attention_kwargs = {}\n        cross_attention_kwargs[\"gligen\"] = {\n            \"boxes\": cond_boxes,\n            \"positive_embeddings\": text_embeddings,\n            \"masks\": masks,\n        }\n\n        num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))\n        self.enable_fuser(True)\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 6.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = {\"image_embeds\": image_embeds} if ip_adapter_image is not None else None\n\n        loss_attn = torch.tensor(10000.0)\n\n        # 7. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Scheduled sampling\n                if i == num_grounding_steps:\n                    self.enable_fuser(False)\n\n                if latents.shape[1] != 4:\n                    latents = torch.randn_like(latents[:, :4])\n\n                # 7.1 Perform LMD guidance\n                if boxes:\n                    latents, loss_attn = self.latent_lmd_guidance(\n                        cond_prompt_embeds,\n                        index=i,\n                        boxes=boxes,\n                        phrase_indices=phrase_indices,\n                        t=t,\n                        latents=latents,\n                        loss=loss_attn,\n                        **lmd_guidance_kwargs,\n                    )\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else 
latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                ).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload last model to 
CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n    @torch.set_grad_enabled(True)\n    def latent_lmd_guidance(\n        self,\n        cond_embeddings,\n        index,\n        boxes,\n        phrase_indices,\n        t,\n        latents,\n        loss,\n        *,\n        loss_scale=20,\n        loss_threshold=5.0,\n        max_iter=[3] * 5 + [2] * 5 + [1] * 5,\n        guidance_timesteps=15,\n        cross_attention_kwargs=None,\n        guidance_attn_keys=DEFAULT_GUIDANCE_ATTN_KEYS,\n        verbose=False,\n        clear_cache=False,\n        unet_additional_kwargs={},\n        guidance_callback=None,\n        **kwargs,\n    ):\n        scheduler, unet = self.scheduler, self.unet\n\n        iteration = 0\n\n        if index < guidance_timesteps:\n            if isinstance(max_iter, list):\n                max_iter = max_iter[index]\n\n            if verbose:\n                logger.info(\n                    f\"time index {index}, loss: {loss.item() / loss_scale:.3f} (de-scaled with scale {loss_scale:.1f}), loss threshold: {loss_threshold:.3f}\"\n                )\n\n            try:\n                self.enable_attn_hook(enabled=True)\n\n                while (\n                    loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < guidance_timesteps\n                ):\n                    self._saved_attn = {}\n\n                    latents.requires_grad_(True)\n                    latent_model_input = latents\n                    latent_model_input = scheduler.scale_model_input(latent_model_input, t)\n\n                    unet(\n                        latent_model_input,\n                        t,\n                        
encoder_hidden_states=cond_embeddings,\n                        cross_attention_kwargs=cross_attention_kwargs,\n                        **unet_additional_kwargs,\n                    )\n\n                    # update latents with guidance\n                    loss = (\n                        self.compute_ca_loss(\n                            saved_attn=self._saved_attn,\n                            bboxes=boxes,\n                            phrase_indices=phrase_indices,\n                            guidance_attn_keys=guidance_attn_keys,\n                            verbose=verbose,\n                            **kwargs,\n                        )\n                        * loss_scale\n                    )\n\n                    if torch.isnan(loss):\n                        raise RuntimeError(\"**Loss is NaN**\")\n\n                    # This callback allows visualizations.\n                    if guidance_callback is not None:\n                        guidance_callback(self, latents, loss, iteration, index)\n\n                    self._saved_attn = None\n\n                    grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]\n\n                    latents.requires_grad_(False)\n\n                    # Scaling with classifier guidance\n                    alpha_prod_t = scheduler.alphas_cumprod[t]\n                    # Classifier guidance: https://huggingface.co/papers/2105.05233\n                    # DDIM: https://huggingface.co/papers/2010.02502\n                    scale = (1 - alpha_prod_t) ** (0.5)\n                    latents = latents - scale * grad_cond\n\n                    iteration += 1\n\n                    if clear_cache:\n                        gc.collect()\n                        torch.cuda.empty_cache()\n\n                    if verbose:\n                        logger.info(\n                            f\"time index {index}, loss: {loss.item() / loss_scale:.3f}, loss threshold: {loss_threshold:.3f}, iteration: 
{iteration}\"\n                        )\n\n            finally:\n                self.enable_attn_hook(enabled=False)\n\n        return latents, loss\n\n    # Below are methods copied from StableDiffusionPipeline\n    # The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. 
Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device),\n                    attention_mask=attention_mask,\n                    output_hidden_states=True,\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        image_embeds = self.image_encoder(image).image_embeds\n        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n\n        uncond_image_embeds = torch.zeros_like(image_embeds)\n        return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) 
instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        
)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        
assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps\n    def num_timesteps(self):\n        return self._num_timesteps\n"
  },
  {
    "path": "examples/community/lpw_stable_diffusion.py",
    "content": "import inspect\nimport re\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\n# ------------------------------------------------------------------------------\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nre_attention = re.compile(\n    r\"\"\"\n\\\\\\(|\n\\\\\\)|\n\\\\\\[|\n\\\\]|\n\\\\\\\\|\n\\\\|\n\\(|\n\\[|\n:([+-]?[.\\d]+)\\)|\n\\)|\n]|\n[^\\\\()\\[\\]:]+|\n:\n\"\"\",\n    re.X,\n)\n\n\ndef parse_prompt_attention(text):\n    \"\"\"\n    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.\n    Accepted tokens are:\n      (abc) - increases attention to abc by a multiplier of 1.1\n      (abc:3.12) - increases attention to abc by a multiplier of 3.12\n      [abc] - decreases attention to abc by a multiplier of 1.1\n      \\\\( - literal character '('\n      \\\\[ - literal character '['\n      \\\\) - literal character ')'\n      \\\\] - literal character ']'\n     
 \\\\ - literal character '\\'\n      anything else - just text\n    >>> parse_prompt_attention('normal text')\n    [['normal text', 1.0]]\n    >>> parse_prompt_attention('an (important) word')\n    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]\n    >>> parse_prompt_attention('(unbalanced')\n    [['unbalanced', 1.1]]\n    >>> parse_prompt_attention('\\\\(literal\\\\]')\n    [['(literal]', 1.0]]\n    >>> parse_prompt_attention('(unnecessary)(parens)')\n    [['unnecessaryparens', 1.1]]\n    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')\n    [['a ', 1.0],\n     ['house', 1.5730000000000004],\n     [' ', 1.1],\n     ['on', 1.0],\n     [' a ', 1.1],\n     ['hill', 0.55],\n     [', sun, ', 1.1],\n     ['sky', 1.4641000000000006],\n     ['.', 1.1]]\n    \"\"\"\n\n    res = []\n    round_brackets = []\n    square_brackets = []\n\n    round_bracket_multiplier = 1.1\n    square_bracket_multiplier = 1 / 1.1\n\n    def multiply_range(start_position, multiplier):\n        for p in range(start_position, len(res)):\n            res[p][1] *= multiplier\n\n    for m in re_attention.finditer(text):\n        text = m.group(0)\n        weight = m.group(1)\n\n        if text.startswith(\"\\\\\"):\n            res.append([text[1:], 1.0])\n        elif text == \"(\":\n            round_brackets.append(len(res))\n        elif text == \"[\":\n            square_brackets.append(len(res))\n        elif weight is not None and len(round_brackets) > 0:\n            multiply_range(round_brackets.pop(), float(weight))\n        elif text == \")\" and len(round_brackets) > 0:\n            multiply_range(round_brackets.pop(), round_bracket_multiplier)\n        elif text == \"]\" and len(square_brackets) > 0:\n            multiply_range(square_brackets.pop(), square_bracket_multiplier)\n        else:\n            res.append([text, 1.0])\n\n    for pos in round_brackets:\n        multiply_range(pos, round_bracket_multiplier)\n\n    for pos in square_brackets:\n 
       multiply_range(pos, square_bracket_multiplier)\n\n    if len(res) == 0:\n        res = [[\"\", 1.0]]\n\n    # merge runs of identical weights\n    i = 0\n    while i + 1 < len(res):\n        if res[i][1] == res[i + 1][1]:\n            res[i][0] += res[i + 1][0]\n            res.pop(i + 1)\n        else:\n            i += 1\n\n    return res\n\n\ndef get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):\n    r\"\"\"\n    Tokenize a list of prompts and return its tokens with weights of each token.\n\n    No padding, starting or ending token is included.\n    \"\"\"\n    tokens = []\n    weights = []\n    truncated = False\n    for text in prompt:\n        texts_and_weights = parse_prompt_attention(text)\n        text_token = []\n        text_weight = []\n        for word, weight in texts_and_weights:\n            # tokenize and discard the starting and the ending token\n            token = pipe.tokenizer(word).input_ids[1:-1]\n            text_token += token\n            # copy the weight by length of token\n            text_weight += [weight] * len(token)\n            # stop if the text is too long (longer than truncation limit)\n            if len(text_token) > max_length:\n                truncated = True\n                break\n        # truncate\n        if len(text_token) > max_length:\n            truncated = True\n            text_token = text_token[:max_length]\n            text_weight = text_weight[:max_length]\n        tokens.append(text_token)\n        weights.append(text_weight)\n    if truncated:\n        logger.warning(\"Prompt was truncated. 
Try to shorten the prompt or increase max_embeddings_multiples\")\n    return tokens, weights\n\n\ndef pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):\n    r\"\"\"\n    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.\n    \"\"\"\n    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)\n    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length\n    for i in range(len(tokens)):\n        tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]\n        if no_boseos_middle:\n            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))\n        else:\n            w = []\n            if len(weights[i]) == 0:\n                w = [1.0] * weights_length\n            else:\n                for j in range(max_embeddings_multiples):\n                    w.append(1.0)  # weight for starting token in this chunk\n                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]\n                    w.append(1.0)  # weight for ending token in this chunk\n                w += [1.0] * (weights_length - len(w))\n            weights[i] = w[:]\n\n    return tokens, weights\n\n\ndef get_unweighted_text_embeddings(\n    pipe: DiffusionPipeline,\n    text_input: torch.Tensor,\n    chunk_length: int,\n    no_boseos_middle: Optional[bool] = True,\n    clip_skip: Optional[int] = None,\n):\n    \"\"\"\n    When the length of tokens is a multiple of the capacity of the text encoder,\n    it should be split into chunks and sent to the text encoder individually.\n    \"\"\"\n    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)\n    if max_embeddings_multiples > 1:\n        text_embeddings = []\n        for i in range(max_embeddings_multiples):\n            # extract the i-th chunk\n            text_input_chunk = 
text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()\n\n            # cover the head and the tail by the starting and the ending tokens\n            text_input_chunk[:, 0] = text_input[0, 0]\n            text_input_chunk[:, -1] = text_input[0, -1]\n            if clip_skip is None:\n                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device))\n                text_embedding = prompt_embeds[0]\n            else:\n                prompt_embeds = pipe.text_encoder(text_input_chunk.to(pipe.device), output_hidden_states=True)\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                text_embedding = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n            if no_boseos_middle:\n                if i == 0:\n                    # discard the ending token\n                    text_embedding = text_embedding[:, :-1]\n                elif i == max_embeddings_multiples - 1:\n                    # discard the starting token\n                    text_embedding = text_embedding[:, 1:]\n                else:\n                    # discard both starting and ending tokens\n                    text_embedding = text_embedding[:, 1:-1]\n\n            text_embeddings.append(text_embedding)\n        text_embeddings = torch.concat(text_embeddings, axis=1)\n    else:\n        if clip_skip is None:\n            clip_skip = 0\n        prompt_embeds = pipe.text_encoder(text_input, output_hidden_states=True)[-1][-(clip_skip + 1)]\n        text_embeddings = pipe.text_encoder.text_model.final_layer_norm(prompt_embeds)\n    return text_embeddings\n\n\ndef get_weighted_text_embeddings(\n    pipe: DiffusionPipeline,\n    prompt: Union[str, List[str]],\n    uncond_prompt: Optional[Union[str, List[str]]] = None,\n    max_embeddings_multiples: Optional[int] = 3,\n    no_boseos_middle: Optional[bool] = False,\n    skip_parsing: Optional[bool] = False,\n    skip_weighting: Optional[bool] = False,\n    clip_skip=None,\n    lora_scale=None,\n):\n    r\"\"\"\n    Prompts can be assigned with local weights using brackets. 
For example,\n    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',\n    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.\n\n    Also, to regularize the embedding, the weighted embedding would be scaled to preserve the original mean.\n\n    Args:\n        pipe (`DiffusionPipeline`):\n            Pipe to provide access to the tokenizer and the text encoder.\n        prompt (`str` or `List[str]`):\n            The prompt or prompts to guide the image generation.\n        uncond_prompt (`str` or `List[str]`):\n            The unconditional prompt or prompts to guide the image generation. If unconditional prompt\n            is provided, the embeddings of prompt and uncond_prompt are concatenated.\n        max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n            The max multiple length of prompt embeddings compared to the max output length of text encoder.\n        no_boseos_middle (`bool`, *optional*, defaults to `False`):\n            If the length of text token is multiples of the capacity of text encoder, whether to reserve the starting\n            and ending token in each of the chunk in the middle.\n        skip_parsing (`bool`, *optional*, defaults to `False`):\n            Skip the parsing of brackets.\n        skip_weighting (`bool`, *optional*, defaults to `False`):\n            Skip the weighting. 
When the parsing is skipped, it is forced True.\n    \"\"\"\n    # set lora scale so that monkey patched LoRA\n    # function of text encoder can correctly access it\n    if lora_scale is not None and isinstance(pipe, StableDiffusionLoraLoaderMixin):\n        pipe._lora_scale = lora_scale\n\n        # dynamically adjust the LoRA scale\n        if not USE_PEFT_BACKEND:\n            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)\n        else:\n            scale_lora_layers(pipe.text_encoder, lora_scale)\n    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2\n    if isinstance(prompt, str):\n        prompt = [prompt]\n\n    if not skip_parsing:\n        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)\n        if uncond_prompt is not None:\n            if isinstance(uncond_prompt, str):\n                uncond_prompt = [uncond_prompt]\n            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)\n    else:\n        prompt_tokens = [\n            token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids\n        ]\n        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]\n        if uncond_prompt is not None:\n            if isinstance(uncond_prompt, str):\n                uncond_prompt = [uncond_prompt]\n            uncond_tokens = [\n                token[1:-1]\n                for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids\n            ]\n            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]\n\n    # round up the longest length of tokens to a multiple of (model_max_length - 2)\n    max_length = max([len(token) for token in prompt_tokens])\n    if uncond_prompt is not None:\n        max_length = max(max_length, max([len(token) for token in uncond_tokens]))\n\n    max_embeddings_multiples = min(\n        
max_embeddings_multiples,\n        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,\n    )\n    max_embeddings_multiples = max(1, max_embeddings_multiples)\n    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2\n\n    # pad the length of tokens and weights\n    bos = pipe.tokenizer.bos_token_id\n    eos = pipe.tokenizer.eos_token_id\n    pad = getattr(pipe.tokenizer, \"pad_token_id\", eos)\n    prompt_tokens, prompt_weights = pad_tokens_and_weights(\n        prompt_tokens,\n        prompt_weights,\n        max_length,\n        bos,\n        eos,\n        pad,\n        no_boseos_middle=no_boseos_middle,\n        chunk_length=pipe.tokenizer.model_max_length,\n    )\n    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)\n    if uncond_prompt is not None:\n        uncond_tokens, uncond_weights = pad_tokens_and_weights(\n            uncond_tokens,\n            uncond_weights,\n            max_length,\n            bos,\n            eos,\n            pad,\n            no_boseos_middle=no_boseos_middle,\n            chunk_length=pipe.tokenizer.model_max_length,\n        )\n        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)\n\n    # get the embeddings\n    text_embeddings = get_unweighted_text_embeddings(\n        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, clip_skip=clip_skip\n    )\n    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)\n    if uncond_prompt is not None:\n        uncond_embeddings = get_unweighted_text_embeddings(\n            pipe,\n            uncond_tokens,\n            pipe.tokenizer.model_max_length,\n            no_boseos_middle=no_boseos_middle,\n            clip_skip=clip_skip,\n        )\n        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)\n\n    # assign weights to 
the prompts and normalize in the sense of mean\n    # TODO: should we normalize by chunk or in a whole (current implementation)?\n    if (not skip_parsing) and (not skip_weighting):\n        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)\n        text_embeddings *= prompt_weights.unsqueeze(-1)\n        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)\n        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)\n        if uncond_prompt is not None:\n            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)\n            uncond_embeddings *= uncond_weights.unsqueeze(-1)\n            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)\n            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)\n\n    if pipe.text_encoder is not None:\n        if isinstance(pipe, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(pipe.text_encoder, lora_scale)\n\n    if uncond_prompt is not None:\n        return text_embeddings, uncond_embeddings\n    return text_embeddings, None\n\n\ndef preprocess_image(image, batch_size):\n    w, h = image.size\n    w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8\n    image = image.resize((w, h), resample=PIL_INTERPOLATION[\"lanczos\"])\n    image = np.array(image).astype(np.float32) / 255.0\n    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)\n    image = torch.from_numpy(image)\n    return 2.0 * image - 1.0\n\n\ndef preprocess_mask(mask, batch_size, scale_factor=8):\n    if not isinstance(mask, torch.Tensor):\n        mask = mask.convert(\"L\")\n        w, h = mask.size\n        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8\n        mask = mask.resize((w // 
scale_factor, h // scale_factor), resample=PIL_INTERPOLATION[\"nearest\"])\n        mask = np.array(mask).astype(np.float32) / 255.0\n        mask = np.tile(mask, (4, 1, 1))\n        mask = np.vstack([mask[None]] * batch_size)\n        mask = 1 - mask  # repaint white, keep black\n        mask = torch.from_numpy(mask)\n        return mask\n\n    else:\n        valid_mask_channel_sizes = [1, 3]\n        # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)\n        if mask.shape[3] in valid_mask_channel_sizes:\n            mask = mask.permute(0, 3, 1, 2)\n        elif mask.shape[1] not in valid_mask_channel_sizes:\n            raise ValueError(\n                f\"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,\"\n                f\" but received mask of shape {tuple(mask.shape)}\"\n            )\n        # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape\n        mask = mask.mean(dim=1, keepdim=True)\n        h, w = mask.shape[-2:]\n        h, w = (x - x % 8 for x in (h, w))  # resize to integer multiple of 8\n        mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))\n        return mask\n\n\nclass StableDiffusionLongPromptWeightingPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    StableDiffusionLoraLoaderMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing\n    weighting in prompt.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(\n            requires_safety_checker=requires_safety_checker,\n        )\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        max_embeddings_multiples=3,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        clip_skip: Optional[int] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `list(int)`):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that 
should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if negative_prompt_embeds is None:\n            if negative_prompt is None:\n                negative_prompt = [\"\"] * batch_size\n            elif isinstance(negative_prompt, str):\n                negative_prompt = [negative_prompt] * batch_size\n            if batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n        if prompt_embeds is None or negative_prompt_embeds is None:\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n                if do_classifier_free_guidance and negative_prompt_embeds is None:\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)\n\n            prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(\n                pipe=self,\n                prompt=prompt,\n                uncond_prompt=negative_prompt if do_classifier_free_guidance else None,\n                max_embeddings_multiples=max_embeddings_multiples,\n                clip_skip=clip_skip,\n                lora_scale=lora_scale,\n            )\n            if prompt_embeds is None:\n                prompt_embeds = prompt_embeds1\n            if negative_prompt_embeds is None:\n                negative_prompt_embeds = negative_prompt_embeds1\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            bs_embed, seq_len, _ = negative_prompt_embeds.shape\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        strength,\n        callback_steps,\n  
      negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):\n        if is_text2img:\n            return self.scheduler.timesteps.to(device), num_inference_steps\n        else:\n            # get the original timestep using init_timestep\n            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n            t_start = max(num_inference_steps - init_timestep, 0)\n            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n            return timesteps, num_inference_steps - t_start\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n   
 def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def prepare_latents(\n        self,\n        image,\n        timestep,\n        num_images_per_prompt,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        if image is None:\n            batch_size = batch_size * num_images_per_prompt\n            shape = (\n                batch_size,\n                num_channels_latents,\n                int(height) // self.vae_scale_factor,\n                int(width) // self.vae_scale_factor,\n            )\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            if latents is None:\n                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            else:\n                latents = latents.to(device)\n\n            # scale the initial noise by the standard deviation required by the scheduler\n            latents = latents * self.scheduler.init_noise_sigma\n            return latents, None, None\n        else:\n            image = image.to(device=self.device, dtype=dtype)\n            init_latent_dist = self.vae.encode(image).latent_dist\n            init_latents = init_latent_dist.sample(generator=generator)\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n            # Expand init_latents for batch_size and num_images_per_prompt\n            init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)\n            init_latents_orig = init_latents\n\n            # add noise to latents using the timesteps\n            noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)\n            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n            latents = init_latents\n            return latents, init_latents_orig, noise\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        strength: float = 0.8,\n        num_images_per_prompt: Optional[int] = 1,\n        add_predicted_noise: Optional[bool] = False,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        
latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        max_embeddings_multiples: Optional[int] = 3,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        is_cancelled_callback: Optional[Callable[[], bool]] = None,\n        clip_skip: Optional[int] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, that will be used as the starting point for the\n                process.\n            mask_image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a\n                PIL image, it will be converted to a single channel (luminance) before use. 
If it's a tensor, it should\n                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.\n                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The\n                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added\n                noise will be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. 
A value of 1, therefore, essentially ignores `image`.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            add_predicted_noise (`bool`, *optional*, defaults to False):\n                Use predicted noise instead of random noise when constructing noisy versions of the original image in\n                the reverse diffusion process.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            is_cancelled_callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. If the function returns\n                `True`, the inference will be cancelled.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        Returns:\n            `None` if cancelled by `is_cancelled_callback`,\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        lora_scale = cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            max_embeddings_multiples,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            clip_skip=clip_skip,\n            lora_scale=lora_scale,\n        )\n        dtype = prompt_embeds.dtype\n\n        # 4. Preprocess image and mask\n        if isinstance(image, PIL.Image.Image):\n            image = preprocess_image(image, batch_size)\n        if image is not None:\n            image = image.to(device=self.device, dtype=dtype)\n        if isinstance(mask_image, PIL.Image.Image):\n            mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)\n        if mask_image is not None:\n            mask = mask_image.to(device=self.device, dtype=dtype)\n            mask = torch.cat([mask] * num_images_per_prompt)\n        else:\n            mask = None\n\n        # 5. set timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. 
Prepare latent variables\n        latents, init_latents_orig, noise = self.prepare_latents(\n            image,\n            latent_timestep,\n            num_images_per_prompt,\n            batch_size,\n            self.unet.config.in_channels,\n            height,\n            width,\n            dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                ).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                if mask is not None:\n                    # masking\n                    if add_predicted_noise:\n                        init_latents_proper 
= self.scheduler.add_noise(\n                            init_latents_orig, noise_pred_uncond, torch.tensor([t])\n                        )\n                    else:\n                        init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))\n                    latents = (init_latents_proper * mask) + (latents * (1 - mask))\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if i % callback_steps == 0:\n                        if callback is not None:\n                            step_idx = i // getattr(self.scheduler, \"order\", 1)\n                            callback(step_idx, t, latents)\n                        if is_cancelled_callback is not None and is_cancelled_callback():\n                            return None\n\n        if output_type == \"latent\":\n            image = latents\n            has_nsfw_concept = None\n        elif output_type == \"pil\":\n            # 9. Post-processing\n            image = self.decode_latents(latents)\n\n            # 10. Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n            # 11. Convert to PIL\n            image = self.numpy_to_pil(image)\n        else:\n            # 9. Post-processing\n            image = self.decode_latents(latents)\n\n            # 10. 
Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return image, has_nsfw_concept\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n    def text2img(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        max_embeddings_multiples: Optional[int] = 3,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        is_cancelled_callback: Optional[Callable[[], bool]] = None,\n        clip_skip=None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        r\"\"\"\n        Function for text-to-image generation.\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            is_cancelled_callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
If the function returns\n                `True`, the inference will be cancelled.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        Returns:\n            `None` if cancelled by `is_cancelled_callback`,\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        return self.__call__(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            prompt_embeds=prompt_embeds,\n            
negative_prompt_embeds=negative_prompt_embeds,\n            max_embeddings_multiples=max_embeddings_multiples,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            is_cancelled_callback=is_cancelled_callback,\n            clip_skip=clip_skip,\n            callback_steps=callback_steps,\n            cross_attention_kwargs=cross_attention_kwargs,\n        )\n\n    def img2img(\n        self,\n        image: Union[torch.Tensor, PIL.Image.Image],\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        max_embeddings_multiples: Optional[int] = 3,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        is_cancelled_callback: Optional[Callable[[], bool]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        r\"\"\"\n        Function for image-to-image generation.\n        Args:\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, that will be used as the starting point for the\n                process.\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.\n                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The\n                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added\n                noise will be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference. This parameter will be modulated by `strength`.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generate image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            is_cancelled_callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. If the function returns\n                `True`, the inference will be cancelled.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        Returns:\n            `None` if cancelled by `is_cancelled_callback`,\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        return self.__call__(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            image=image,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            strength=strength,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            
max_embeddings_multiples=max_embeddings_multiples,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            is_cancelled_callback=is_cancelled_callback,\n            callback_steps=callback_steps,\n            cross_attention_kwargs=cross_attention_kwargs,\n        )\n\n    def inpaint(\n        self,\n        image: Union[torch.Tensor, PIL.Image.Image],\n        mask_image: Union[torch.Tensor, PIL.Image.Image],\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        add_predicted_noise: Optional[bool] = False,\n        eta: Optional[float] = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        max_embeddings_multiples: Optional[int] = 3,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        is_cancelled_callback: Optional[Callable[[], bool]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        r\"\"\"\n        Function for inpaint.\n        Args:\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, that will be used as the starting point for the\n                process. This is the image whose masked region will be inpainted.\n            mask_image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. 
White pixels in the mask will be\n                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a\n                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should\n                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`\n                is 1, the denoising process will be run on the masked area for the full number of iterations specified\n                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more\n                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The reference number of denoising steps. More denoising steps usually lead to a higher quality image at\n                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            add_predicted_noise (`bool`, *optional*, defaults to False):\n                Use predicted noise instead of random noise when constructing noisy versions of the original image in\n                the reverse diffusion process.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            is_cancelled_callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. If the function returns\n                `True`, the inference will be cancelled.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        Returns:\n            `None` if cancelled by `is_cancelled_callback`,\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        return self.__call__(\n            
prompt=prompt,\n            negative_prompt=negative_prompt,\n            image=image,\n            mask_image=mask_image,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            strength=strength,\n            num_images_per_prompt=num_images_per_prompt,\n            add_predicted_noise=add_predicted_noise,\n            eta=eta,\n            generator=generator,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            max_embeddings_multiples=max_embeddings_multiples,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            is_cancelled_callback=is_cancelled_callback,\n            callback_steps=callback_steps,\n            cross_attention_kwargs=cross_attention_kwargs,\n        )\n"
  },
  {
    "path": "examples/community/lpw_stable_diffusion_onnx.py",
    "content": "import inspect\nimport re\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, SchedulerMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.utils import logging\n\n\ntry:\n    from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE\nexcept ImportError:\n    ORT_TO_NP_TYPE = {\n        \"tensor(bool)\": np.bool_,\n        \"tensor(int8)\": np.int8,\n        \"tensor(uint8)\": np.uint8,\n        \"tensor(int16)\": np.int16,\n        \"tensor(uint16)\": np.uint16,\n        \"tensor(int32)\": np.int32,\n        \"tensor(uint32)\": np.uint32,\n        \"tensor(int64)\": np.int64,\n        \"tensor(uint64)\": np.uint64,\n        \"tensor(float16)\": np.float16,\n        \"tensor(float)\": np.float32,\n        \"tensor(double)\": np.float64,\n    }\n\ntry:\n    from diffusers.utils import PIL_INTERPOLATION\nexcept ImportError:\n    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n        PIL_INTERPOLATION = {\n            \"linear\": PIL.Image.Resampling.BILINEAR,\n            \"bilinear\": PIL.Image.Resampling.BILINEAR,\n            \"bicubic\": PIL.Image.Resampling.BICUBIC,\n            \"lanczos\": PIL.Image.Resampling.LANCZOS,\n            \"nearest\": PIL.Image.Resampling.NEAREST,\n        }\n    else:\n        PIL_INTERPOLATION = {\n            \"linear\": PIL.Image.LINEAR,\n            \"bilinear\": PIL.Image.BILINEAR,\n            \"bicubic\": PIL.Image.BICUBIC,\n            \"lanczos\": PIL.Image.LANCZOS,\n            \"nearest\": PIL.Image.NEAREST,\n        }\n# ------------------------------------------------------------------------------\n\nlogger = logging.get_logger(__name__)  # pylint: 
disable=invalid-name\n\nre_attention = re.compile(\n    r\"\"\"\n\\\\\\(|\n\\\\\\)|\n\\\\\\[|\n\\\\]|\n\\\\\\\\|\n\\\\|\n\\(|\n\\[|\n:([+-]?[.\\d]+)\\)|\n\\)|\n]|\n[^\\\\()\\[\\]:]+|\n:\n\"\"\",\n    re.X,\n)\n\n\ndef parse_prompt_attention(text):\n    \"\"\"\n    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.\n    Accepted tokens are:\n      (abc) - increases attention to abc by a multiplier of 1.1\n      (abc:3.12) - increases attention to abc by a multiplier of 3.12\n      [abc] - decreases attention to abc by a multiplier of 1.1\n      \\\\( - literal character '('\n      \\\\[ - literal character '['\n      \\\\) - literal character ')'\n      \\\\] - literal character ']'\n      \\\\ - literal character '\\'\n      anything else - just text\n    >>> parse_prompt_attention('normal text')\n    [['normal text', 1.0]]\n    >>> parse_prompt_attention('an (important) word')\n    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]\n    >>> parse_prompt_attention('(unbalanced')\n    [['unbalanced', 1.1]]\n    >>> parse_prompt_attention('\\\\(literal\\\\]')\n    [['(literal]', 1.0]]\n    >>> parse_prompt_attention('(unnecessary)(parens)')\n    [['unnecessaryparens', 1.1]]\n    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')\n    [['a ', 1.0],\n     ['house', 1.5730000000000004],\n     [' ', 1.1],\n     ['on', 1.0],\n     [' a ', 1.1],\n     ['hill', 0.55],\n     [', sun, ', 1.1],\n     ['sky', 1.4641000000000006],\n     ['.', 1.1]]\n    \"\"\"\n\n    res = []\n    round_brackets = []\n    square_brackets = []\n\n    round_bracket_multiplier = 1.1\n    square_bracket_multiplier = 1 / 1.1\n\n    def multiply_range(start_position, multiplier):\n        for p in range(start_position, len(res)):\n            res[p][1] *= multiplier\n\n    for m in re_attention.finditer(text):\n        text = m.group(0)\n        weight = m.group(1)\n\n        if text.startswith(\"\\\\\"):\n            
res.append([text[1:], 1.0])\n        elif text == \"(\":\n            round_brackets.append(len(res))\n        elif text == \"[\":\n            square_brackets.append(len(res))\n        elif weight is not None and len(round_brackets) > 0:\n            multiply_range(round_brackets.pop(), float(weight))\n        elif text == \")\" and len(round_brackets) > 0:\n            multiply_range(round_brackets.pop(), round_bracket_multiplier)\n        elif text == \"]\" and len(square_brackets) > 0:\n            multiply_range(square_brackets.pop(), square_bracket_multiplier)\n        else:\n            res.append([text, 1.0])\n\n    for pos in round_brackets:\n        multiply_range(pos, round_bracket_multiplier)\n\n    for pos in square_brackets:\n        multiply_range(pos, square_bracket_multiplier)\n\n    if len(res) == 0:\n        res = [[\"\", 1.0]]\n\n    # merge runs of identical weights\n    i = 0\n    while i + 1 < len(res):\n        if res[i][1] == res[i + 1][1]:\n            res[i][0] += res[i + 1][0]\n            res.pop(i + 1)\n        else:\n            i += 1\n\n    return res\n\n\ndef get_prompts_with_weights(pipe, prompt: List[str], max_length: int):\n    r\"\"\"\n    Tokenize a list of prompts and return its tokens with weights of each token.\n\n    No padding, starting or ending token is included.\n    \"\"\"\n    tokens = []\n    weights = []\n    truncated = False\n    for text in prompt:\n        texts_and_weights = parse_prompt_attention(text)\n        text_token = []\n        text_weight = []\n        for word, weight in texts_and_weights:\n            # tokenize and discard the starting and the ending token\n            token = pipe.tokenizer(word, return_tensors=\"np\").input_ids[0, 1:-1]\n            text_token += list(token)\n            # copy the weight by length of token\n            text_weight += [weight] * len(token)\n            # stop if the text is too long (longer than truncation limit)\n            if len(text_token) > max_length:\n   
             truncated = True\n                break\n        # truncate\n        if len(text_token) > max_length:\n            truncated = True\n            text_token = text_token[:max_length]\n            text_weight = text_weight[:max_length]\n        tokens.append(text_token)\n        weights.append(text_weight)\n    if truncated:\n        logger.warning(\"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples\")\n    return tokens, weights\n\n\ndef pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):\n    r\"\"\"\n    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.\n    \"\"\"\n    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)\n    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length\n    for i in range(len(tokens)):\n        tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]\n        if no_boseos_middle:\n            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))\n        else:\n            w = []\n            if len(weights[i]) == 0:\n                w = [1.0] * weights_length\n            else:\n                for j in range(max_embeddings_multiples):\n                    w.append(1.0)  # weight for starting token in this chunk\n                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]\n                    w.append(1.0)  # weight for ending token in this chunk\n                w += [1.0] * (weights_length - len(w))\n            weights[i] = w[:]\n\n    return tokens, weights\n\n\ndef get_unweighted_text_embeddings(\n    pipe,\n    text_input: np.array,\n    chunk_length: int,\n    no_boseos_middle: Optional[bool] = True,\n):\n    \"\"\"\n    When the length of tokens is a multiple of the capacity of the text encoder,\n    it should be split into chunks and 
sent to the text encoder individually.\n    \"\"\"\n    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)\n    if max_embeddings_multiples > 1:\n        text_embeddings = []\n        for i in range(max_embeddings_multiples):\n            # extract the i-th chunk\n            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].copy()\n\n            # cover the head and the tail by the starting and the ending tokens\n            text_input_chunk[:, 0] = text_input[0, 0]\n            text_input_chunk[:, -1] = text_input[0, -1]\n\n            text_embedding = pipe.text_encoder(input_ids=text_input_chunk)[0]\n\n            if no_boseos_middle:\n                if i == 0:\n                    # discard the ending token\n                    text_embedding = text_embedding[:, :-1]\n                elif i == max_embeddings_multiples - 1:\n                    # discard the starting token\n                    text_embedding = text_embedding[:, 1:]\n                else:\n                    # discard both starting and ending tokens\n                    text_embedding = text_embedding[:, 1:-1]\n\n            text_embeddings.append(text_embedding)\n        text_embeddings = np.concatenate(text_embeddings, axis=1)\n    else:\n        text_embeddings = pipe.text_encoder(input_ids=text_input)[0]\n    return text_embeddings\n\n\ndef get_weighted_text_embeddings(\n    pipe,\n    prompt: Union[str, List[str]],\n    uncond_prompt: Optional[Union[str, List[str]]] = None,\n    max_embeddings_multiples: Optional[int] = 4,\n    no_boseos_middle: Optional[bool] = False,\n    skip_parsing: Optional[bool] = False,\n    skip_weighting: Optional[bool] = False,\n    **kwargs,\n):\n    r\"\"\"\n    Prompts can be assigned with local weights using brackets. 
For example,\n    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',\n    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.\n\n    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.\n\n    Args:\n        pipe (`OnnxStableDiffusionPipeline`):\n            Pipe to provide access to the tokenizer and the text encoder.\n        prompt (`str` or `List[str]`):\n            The prompt or prompts to guide the image generation.\n        uncond_prompt (`str` or `List[str]`):\n            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt\n            is provided, the embeddings of prompt and uncond_prompt are returned together as a tuple.\n        max_embeddings_multiples (`int`, *optional*, defaults to `4`):\n            The max multiple length of prompt embeddings compared to the max output length of text encoder.\n        no_boseos_middle (`bool`, *optional*, defaults to `False`):\n            If the text tokens span multiple chunks of the text encoder capacity, whether to drop the starting and\n            ending token embeddings at the boundaries between chunks.\n        skip_parsing (`bool`, *optional*, defaults to `False`):\n            Skip the parsing of brackets.\n        skip_weighting (`bool`, *optional*, defaults to `False`):\n            Skip the weighting. 
When the parsing is skipped, it is forced True.\n    \"\"\"\n    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2\n    if isinstance(prompt, str):\n        prompt = [prompt]\n\n    if not skip_parsing:\n        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)\n        if uncond_prompt is not None:\n            if isinstance(uncond_prompt, str):\n                uncond_prompt = [uncond_prompt]\n            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)\n    else:\n        prompt_tokens = [\n            token[1:-1]\n            for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True, return_tensors=\"np\").input_ids\n        ]\n        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]\n        if uncond_prompt is not None:\n            if isinstance(uncond_prompt, str):\n                uncond_prompt = [uncond_prompt]\n            uncond_tokens = [\n                token[1:-1]\n                for token in pipe.tokenizer(\n                    uncond_prompt,\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"np\",\n                ).input_ids\n            ]\n            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]\n\n    # round up the longest length of tokens to a multiple of (model_max_length - 2)\n    max_length = max([len(token) for token in prompt_tokens])\n    if uncond_prompt is not None:\n        max_length = max(max_length, max([len(token) for token in uncond_tokens]))\n\n    max_embeddings_multiples = min(\n        max_embeddings_multiples,\n        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,\n    )\n    max_embeddings_multiples = max(1, max_embeddings_multiples)\n    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2\n\n    # pad the length of tokens and weights\n    bos 
= pipe.tokenizer.bos_token_id\n    eos = pipe.tokenizer.eos_token_id\n    pad = getattr(pipe.tokenizer, \"pad_token_id\", eos)\n    prompt_tokens, prompt_weights = pad_tokens_and_weights(\n        prompt_tokens,\n        prompt_weights,\n        max_length,\n        bos,\n        eos,\n        pad,\n        no_boseos_middle=no_boseos_middle,\n        chunk_length=pipe.tokenizer.model_max_length,\n    )\n    prompt_tokens = np.array(prompt_tokens, dtype=np.int32)\n    if uncond_prompt is not None:\n        uncond_tokens, uncond_weights = pad_tokens_and_weights(\n            uncond_tokens,\n            uncond_weights,\n            max_length,\n            bos,\n            eos,\n            pad,\n            no_boseos_middle=no_boseos_middle,\n            chunk_length=pipe.tokenizer.model_max_length,\n        )\n        uncond_tokens = np.array(uncond_tokens, dtype=np.int32)\n\n    # get the embeddings\n    text_embeddings = get_unweighted_text_embeddings(\n        pipe,\n        prompt_tokens,\n        pipe.tokenizer.model_max_length,\n        no_boseos_middle=no_boseos_middle,\n    )\n    prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)\n    if uncond_prompt is not None:\n        uncond_embeddings = get_unweighted_text_embeddings(\n            pipe,\n            uncond_tokens,\n            pipe.tokenizer.model_max_length,\n            no_boseos_middle=no_boseos_middle,\n        )\n        uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)\n\n    # assign weights to the prompts and normalize in the sense of mean\n    # TODO: should we normalize by chunk or in a whole (current implementation)?\n    if (not skip_parsing) and (not skip_weighting):\n        previous_mean = text_embeddings.mean(axis=(-2, -1))\n        text_embeddings *= prompt_weights[:, :, None]\n        text_embeddings *= (previous_mean / text_embeddings.mean(axis=(-2, -1)))[:, None, None]\n        if uncond_prompt is not None:\n            previous_mean = 
uncond_embeddings.mean(axis=(-2, -1))\n            uncond_embeddings *= uncond_weights[:, :, None]\n            uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=(-2, -1)))[:, None, None]\n\n    # For classifier free guidance, the caller needs both the unconditional and the text embeddings.\n    # They are returned separately here; `_encode_prompt` concatenates them into a single batch\n    # to avoid doing two forward passes\n    if uncond_prompt is not None:\n        return text_embeddings, uncond_embeddings\n\n    return text_embeddings\n\n\ndef preprocess_image(image):\n    w, h = image.size\n    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32\n    image = image.resize((w, h), resample=PIL_INTERPOLATION[\"lanczos\"])\n    image = np.array(image).astype(np.float32) / 255.0\n    image = image[None].transpose(0, 3, 1, 2)\n    return 2.0 * image - 1.0\n\n\ndef preprocess_mask(mask, scale_factor=8):\n    mask = mask.convert(\"L\")\n    w, h = mask.size\n    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32\n    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION[\"nearest\"])\n    mask = np.array(mask).astype(np.float32) / 255.0\n    mask = np.tile(mask, (4, 1, 1))\n    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; the identity transpose is a no-op\n    mask = 1 - mask  # repaint white, keep black\n    return mask\n\n\nclass OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for\n    parsing weights in the prompt.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n    \"\"\"\n\n    if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse(\"0.9.0\"):\n\n        def __init__(\n            self,\n            vae_encoder: OnnxRuntimeModel,\n            vae_decoder: OnnxRuntimeModel,\n            text_encoder: OnnxRuntimeModel,\n            tokenizer: CLIPTokenizer,\n            unet: OnnxRuntimeModel,\n            scheduler: SchedulerMixin,\n            safety_checker: OnnxRuntimeModel,\n            feature_extractor: CLIPImageProcessor,\n            requires_safety_checker: bool = True,\n        ):\n            super().__init__(\n                vae_encoder=vae_encoder,\n                vae_decoder=vae_decoder,\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                unet=unet,\n                scheduler=scheduler,\n                safety_checker=safety_checker,\n                feature_extractor=feature_extractor,\n                requires_safety_checker=requires_safety_checker,\n            )\n            self.__init__additional__()\n\n    else:\n\n        def __init__(\n            self,\n            vae_encoder: OnnxRuntimeModel,\n            vae_decoder: OnnxRuntimeModel,\n            text_encoder: OnnxRuntimeModel,\n            tokenizer: CLIPTokenizer,\n            unet: OnnxRuntimeModel,\n            scheduler: SchedulerMixin,\n            safety_checker: OnnxRuntimeModel,\n            feature_extractor: CLIPImageProcessor,\n        ):\n            super().__init__(\n                vae_encoder=vae_encoder,\n                vae_decoder=vae_decoder,\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                unet=unet,\n                scheduler=scheduler,\n                safety_checker=safety_checker,\n                
feature_extractor=feature_extractor,\n            )\n            self.__init__additional__()\n\n    def __init__additional__(self):\n        self.unet.config.in_channels = 4\n        self.vae_scale_factor = 8\n\n    def _encode_prompt(\n        self,\n        prompt,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt,\n        max_embeddings_multiples,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                prompt to be encoded\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n        \"\"\"\n        batch_size = len(prompt) if isinstance(prompt, list) else 1\n\n        if negative_prompt is None:\n            negative_prompt = [\"\"] * batch_size\n        elif isinstance(negative_prompt, str):\n            negative_prompt = [negative_prompt] * batch_size\n        if batch_size != len(negative_prompt):\n            raise ValueError(\n                f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                \" the batch size of `prompt`.\"\n            )\n\n        embeddings = get_weighted_text_embeddings(\n            pipe=self,\n            prompt=prompt,\n            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,\n            max_embeddings_multiples=max_embeddings_multiples,\n        )\n        # without an unconditional prompt, only the text embeddings are returned\n        text_embeddings, uncond_embeddings = embeddings if do_classifier_free_guidance else (embeddings, None)\n\n        text_embeddings = text_embeddings.repeat(num_images_per_prompt, 0)\n        if do_classifier_free_guidance:\n            uncond_embeddings = uncond_embeddings.repeat(num_images_per_prompt, 0)\n            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])\n\n        return text_embeddings\n\n    def check_inputs(self, prompt, height, width, strength, callback_steps):\n        if not isinstance(prompt, str) and not isinstance(prompt, list):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n    def get_timesteps(self, num_inference_steps, strength, is_text2img):\n        if is_text2img:\n            return self.scheduler.timesteps, num_inference_steps\n        else:\n            # get the original timestep using init_timestep\n            offset = self.scheduler.config.get(\"steps_offset\", 0)\n            init_timestep = 
int(num_inference_steps * strength) + offset\n            init_timestep = min(init_timestep, num_inference_steps)\n\n            t_start = max(num_inference_steps - init_timestep + offset, 0)\n            timesteps = self.scheduler.timesteps[t_start:]\n            return timesteps, num_inference_steps - t_start\n\n    def run_safety_checker(self, image):\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(\n                self.numpy_to_pil(image), return_tensors=\"np\"\n            ).pixel_values.astype(image.dtype)\n            # The safety_checker throws an error if called directly with batch size > 1, so run it per image\n            images, has_nsfw_concept = [], []\n            for i in range(image.shape[0]):\n                image_i, has_nsfw_concept_i = self.safety_checker(\n                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]\n                )\n                images.append(image_i)\n                has_nsfw_concept.append(has_nsfw_concept_i[0])\n            image = np.concatenate(images)\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / 0.18215 * latents\n        # image = self.vae_decoder(latent_sample=latents)[0]\n        # the half-precision vae decoder seems to give strange results if batch size > 1, so decode per sample\n        image = np.concatenate(\n            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]\n        )\n        image = np.clip(image / 2 + 0.5, 0, 1)\n        image = image.transpose((0, 2, 3, 1))\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: 
https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, generator, latents=None):\n        if image is None:\n            shape = (\n                batch_size,\n                self.unet.config.in_channels,\n                height // self.vae_scale_factor,\n                width // self.vae_scale_factor,\n            )\n\n            if latents is None:\n                latents = torch.randn(shape, generator=generator, device=\"cpu\").numpy().astype(dtype)\n            else:\n                if latents.shape != shape:\n                    raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {shape}\")\n\n            # scale the initial noise by the standard deviation required by the scheduler\n            latents = (torch.from_numpy(latents) * self.scheduler.init_noise_sigma).numpy()\n            return latents, None, None\n        else:\n            init_latents = self.vae_encoder(sample=image)[0]\n            init_latents = 0.18215 * init_latents\n            init_latents = np.concatenate([init_latents] * batch_size, axis=0)\n            init_latents_orig = init_latents\n            shape = init_latents.shape\n\n            # add noise to latents using the timesteps\n            noise = torch.randn(shape, generator=generator, device=\"cpu\").numpy().astype(dtype)\n            latents = self.scheduler.add_noise(\n                
torch.from_numpy(init_latents), torch.from_numpy(noise), timestep\n            ).numpy()\n            return latents, init_latents_orig, noise\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        image: Union[np.ndarray, PIL.Image.Image] = None,\n        mask_image: Union[np.ndarray, PIL.Image.Image] = None,\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        strength: float = 0.8,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[np.ndarray] = None,\n        max_embeddings_multiples: Optional[int] = 3,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,\n        is_cancelled_callback: Optional[Callable[[], bool]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            image (`np.ndarray` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, that will be used as the starting point for the\n                process.\n            mask_image (`np.ndarray` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. 
White pixels in the mask will be\n                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a\n                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should\n                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.\n                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The\n                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added\n                noise will be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. 
A value of 1, therefore, essentially ignores `image`.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`np.ndarray`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.\n            is_cancelled_callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. If the function returns\n                `True`, the inference will be cancelled.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n\n        Returns:\n            `None` if cancelled by `is_cancelled_callback`,\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(prompt, height, width, strength, callback_steps)\n\n        # 2. Define call parameters\n        batch_size = 1 if isinstance(prompt, str) else len(prompt)\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        text_embeddings = self._encode_prompt(\n            prompt,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            max_embeddings_multiples,\n        )\n        dtype = text_embeddings.dtype\n\n        # 4. Preprocess image and mask\n        if isinstance(image, PIL.Image.Image):\n            image = preprocess_image(image)\n        if image is not None:\n            image = image.astype(dtype)\n        if isinstance(mask_image, PIL.Image.Image):\n            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)\n        if mask_image is not None:\n            mask = mask_image.astype(dtype)\n            mask = np.concatenate([mask] * batch_size * num_images_per_prompt)\n        else:\n            mask = None\n\n        # 5. set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n        timestep_dtype = next(\n            (input.type for input in self.unet.model.get_inputs() if input.name == \"timestep\"), \"tensor(float)\"\n        )\n        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, image is None)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. Prepare latent variables\n        latents, init_latents_orig, noise = self.prepare_latents(\n            image,\n            latent_timestep,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            dtype,\n            generator,\n            latents,\n        )\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. 
Denoising loop\n        for i, t in enumerate(self.progress_bar(timesteps)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)\n            latent_model_input = latent_model_input.numpy()\n\n            # predict the noise residual\n            noise_pred = self.unet(\n                sample=latent_model_input,\n                timestep=np.array([t], dtype=timestep_dtype),\n                encoder_hidden_states=text_embeddings,\n            )\n            noise_pred = noise_pred[0]\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            scheduler_output = self.scheduler.step(\n                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs\n            )\n            latents = scheduler_output.prev_sample.numpy()\n\n            if mask is not None:\n                # masking\n                init_latents_proper = self.scheduler.add_noise(\n                    torch.from_numpy(init_latents_orig),\n                    torch.from_numpy(noise),\n                    t,\n                ).numpy()\n                latents = (init_latents_proper * mask) + (latents * (1 - mask))\n\n            # call the callback, if provided\n            if i % callback_steps == 0:\n                if callback is not None:\n                    step_idx = i // getattr(self.scheduler, \"order\", 1)\n                    callback(step_idx, t, latents)\n                if is_cancelled_callback is not None and is_cancelled_callback():\n               
     return None\n\n        # 9. Post-processing\n        image = self.decode_latents(latents)\n\n        # 10. Run safety checker\n        image, has_nsfw_concept = self.run_safety_checker(image)\n\n        # 11. Convert to PIL\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return image, has_nsfw_concept\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n    def text2img(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[np.ndarray] = None,\n        max_embeddings_multiples: Optional[int] = 3,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function for text-to-image generation.\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`np.ndarray`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        return self.__call__(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            max_embeddings_multiples=max_embeddings_multiples,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n    def img2img(\n        self,\n        image: Union[np.ndarray, PIL.Image.Image],\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: torch.Generator | None = None,\n        max_embeddings_multiples: Optional[int] = 3,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,\n        callback_steps: int = 1,\n        
**kwargs,\n    ):\n        r\"\"\"\n        Function for image-to-image generation.\n        Args:\n            image (`np.ndarray` or `PIL.Image.Image`):\n                `Image`, or ndarray representing an image batch, that will be used as the starting point for the\n                process.\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.\n                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The\n                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added\n                noise will be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference. This parameter will be modulated by `strength`.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        return self.__call__(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            image=image,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            strength=strength,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            max_embeddings_multiples=max_embeddings_multiples,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n    def inpaint(\n        self,\n        image: Union[np.ndarray, PIL.Image.Image],\n        mask_image: Union[np.ndarray, PIL.Image.Image],\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: torch.Generator | None = None,\n        max_embeddings_multiples: Optional[int] = 3,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,\n        
callback_steps: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function for inpaint.\n        Args:\n            image (`np.ndarray` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, that will be used as the starting point for the\n                process. This is the image whose masked region will be inpainted.\n            mask_image (`np.ndarray` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a\n                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should\n                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`\n                is 1, the denoising process will be run on the masked area for the full number of iterations specified\n                in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more\n                noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The reference number of denoising steps. 
More denoising steps usually lead to a higher quality image at\n                the expense of slower inference. This parameter will be modulated by `strength`, as explained above.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            max_embeddings_multiples (`int`, *optional*, defaults to `3`):\n                The max multiple length of prompt embeddings compared to the max output length of text encoder.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        return self.__call__(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            image=image,\n            mask_image=mask_image,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            strength=strength,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            max_embeddings_multiples=max_embeddings_multiples,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            
callback_steps=callback_steps,\n            **kwargs,\n        )\n"
  },
  {
    "path": "examples/community/lpw_stable_diffusion_xl.py",
    "content": "## ----------------------------------------------------------\n# A SDXL pipeline can take unlimited weighted prompt\n#\n# Author: Andrew Zhu\n# GitHub: https://github.com/xhinker\n# Medium: https://medium.com/@xhinker\n## -----------------------------------------------------------\n\nimport inspect\nimport os\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom PIL import Image\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers import DiffusionPipeline, StableDiffusionXLPipeline\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_accelerate_available,\n    is_accelerate_version,\n    is_invisible_watermark_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\n\n\ndef parse_prompt_attention(text):\n    \"\"\"\n    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.\n    Accepted tokens are:\n      (abc) - increases attention to abc by a multiplier of 1.1\n      
(abc:3.12) - increases attention to abc by a multiplier of 3.12\n      [abc] - decreases attention to abc by a multiplier of 1.1\n      \\\\( - literal character '('\n      \\\\[ - literal character '['\n      \\\\) - literal character ')'\n      \\\\] - literal character ']'\n      \\\\ - literal character '\\'\n      anything else - just text\n\n    >>> parse_prompt_attention('normal text')\n    [['normal text', 1.0]]\n    >>> parse_prompt_attention('an (important) word')\n    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]\n    >>> parse_prompt_attention('(unbalanced')\n    [['unbalanced', 1.1]]\n    >>> parse_prompt_attention('\\\\(literal\\\\]')\n    [['(literal]', 1.0]]\n    >>> parse_prompt_attention('(unnecessary)(parens)')\n    [['unnecessaryparens', 1.1]]\n    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')\n    [['a ', 1.0],\n     ['house', 1.5730000000000004],\n     [' ', 1.1],\n     ['on', 1.0],\n     [' a ', 1.1],\n     ['hill', 0.55],\n     [', sun, ', 1.1],\n     ['sky', 1.4641000000000006],\n     ['.', 1.1]]\n    \"\"\"\n    import re\n\n    re_attention = re.compile(\n        r\"\"\"\n            \\\\\\(|\\\\\\)|\\\\\\[|\\\\]|\\\\\\\\|\\\\|\\(|\\[|:([+-]?[.\\d]+)\\)|\n            \\)|]|[^\\\\()\\[\\]:]+|:\n        \"\"\",\n        re.X,\n    )\n\n    re_break = re.compile(r\"\\s*\\bBREAK\\b\\s*\", re.S)\n\n    res = []\n    round_brackets = []\n    square_brackets = []\n\n    round_bracket_multiplier = 1.1\n    square_bracket_multiplier = 1 / 1.1\n\n    def multiply_range(start_position, multiplier):\n        for p in range(start_position, len(res)):\n            res[p][1] *= multiplier\n\n    for m in re_attention.finditer(text):\n        text = m.group(0)\n        weight = m.group(1)\n\n        if text.startswith(\"\\\\\"):\n            res.append([text[1:], 1.0])\n        elif text == \"(\":\n            round_brackets.append(len(res))\n        elif text == \"[\":\n            
square_brackets.append(len(res))\n        elif weight is not None and len(round_brackets) > 0:\n            multiply_range(round_brackets.pop(), float(weight))\n        elif text == \")\" and len(round_brackets) > 0:\n            multiply_range(round_brackets.pop(), round_bracket_multiplier)\n        elif text == \"]\" and len(square_brackets) > 0:\n            multiply_range(square_brackets.pop(), square_bracket_multiplier)\n        else:\n            parts = re.split(re_break, text)\n            for i, part in enumerate(parts):\n                if i > 0:\n                    res.append([\"BREAK\", -1])\n                res.append([part, 1.0])\n\n    for pos in round_brackets:\n        multiply_range(pos, round_bracket_multiplier)\n\n    for pos in square_brackets:\n        multiply_range(pos, square_bracket_multiplier)\n\n    if len(res) == 0:\n        res = [[\"\", 1.0]]\n\n    # merge runs of identical weights\n    i = 0\n    while i + 1 < len(res):\n        if res[i][1] == res[i + 1][1]:\n            res[i][0] += res[i + 1][0]\n            res.pop(i + 1)\n        else:\n            i += 1\n\n    return res\n\n\ndef get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):\n    \"\"\"\n    Get prompt token ids and weights. This function works for both prompts and negative prompts.\n\n    Args:\n        clip_tokenizer (CLIPTokenizer)\n            A CLIPTokenizer\n        prompt (str)\n            A prompt string with weights\n\n    Returns:\n        text_tokens (list)\n            A list of token ids\n        text_weight (list)\n            A list of the corresponding weight for each token id\n\n    Example:\n        import torch\n        from transformers import CLIPTokenizer\n\n        clip_tokenizer = CLIPTokenizer.from_pretrained(\n            \"stablediffusionapi/deliberate-v2\"\n            , subfolder = \"tokenizer\"\n            , dtype = torch.float16\n        )\n\n        token_id_list, token_weight_list = 
get_prompts_tokens_with_weights(\n            clip_tokenizer = clip_tokenizer\n            ,prompt = \"a (red:1.5) cat\"*70\n        )\n    \"\"\"\n    texts_and_weights = parse_prompt_attention(prompt)\n    text_tokens, text_weights = [], []\n    for word, weight in texts_and_weights:\n        # tokenize and discard the starting and the ending token\n        token = clip_tokenizer(word, truncation=False).input_ids[1:-1]  # truncation=False so prompts of any length can be tokenized\n        # the returned token is a 1d list: [320, 1125, 539, 320]\n\n        # merge the new tokens to the all tokens holder: text_tokens\n        text_tokens = [*text_tokens, *token]\n\n        # each token chunk will come with one weight, like ['red cat', 2.0]\n        # need to expand weight for each token.\n        chunk_weights = [weight] * len(token)\n\n        # append the weight back to the weight holder: text_weights\n        text_weights = [*text_weights, *chunk_weights]\n    return text_tokens, text_weights\n\n\ndef group_tokens_and_weights(token_ids: list, weights: list, pad_last_block=False):\n    \"\"\"\n    Produce tokens and weights in groups and pad the missing tokens\n\n    Args:\n        token_ids (list)\n            The token ids from tokenizer\n        weights (list)\n            The weights list from function get_prompts_tokens_with_weights\n        pad_last_block (bool)\n            Whether to pad the last token group to 75 tokens with eos\n    Returns:\n        new_token_ids (2d list)\n        new_weights (2d list)\n\n    Example:\n        token_groups,weight_groups = group_tokens_and_weights(\n            token_ids = token_id_list\n            , weights = token_weight_list\n        )\n    \"\"\"\n    bos, eos = 49406, 49407\n\n    # this will be a 2d list\n    new_token_ids = []\n    new_weights = []\n    while len(token_ids) >= 75:\n        # get the first 75 tokens\n        head_75_tokens = [token_ids.pop(0) for _ in range(75)]\n        head_75_weights = [weights.pop(0) for _ in 
range(75)]\n\n        # extract token ids and weights\n        temp_77_token_ids = [bos] + head_75_tokens + [eos]\n        temp_77_weights = [1.0] + head_75_weights + [1.0]\n\n        # add 77 token and weights chunk to the holder list\n        new_token_ids.append(temp_77_token_ids)\n        new_weights.append(temp_77_weights)\n\n    # pad the remaining tokens\n    if len(token_ids) > 0:\n        padding_len = 75 - len(token_ids) if pad_last_block else 0\n\n        temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]\n        new_token_ids.append(temp_77_token_ids)\n\n        temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]\n        new_weights.append(temp_77_weights)\n\n    return new_token_ids, new_weights\n\n\ndef get_weighted_text_embeddings_sdxl(\n    pipe: StableDiffusionXLPipeline,\n    prompt: str = \"\",\n    prompt_2: str = None,\n    neg_prompt: str = \"\",\n    neg_prompt_2: str = None,\n    num_images_per_prompt: int = 1,\n    device: Optional[torch.device] = None,\n    clip_skip: Optional[int] = None,\n    lora_scale: Optional[float] = None,\n):\n    \"\"\"\n    Process long, weighted prompts for Stable Diffusion XL with no length limit.\n\n    Args:\n        pipe (StableDiffusionXLPipeline)\n        prompt (str)\n        prompt_2 (str)\n        neg_prompt (str)\n        neg_prompt_2 (str)\n        num_images_per_prompt (int)\n        device (torch.device)\n        clip_skip (int)\n        lora_scale (float)\n    Returns:\n        prompt_embeds (torch.Tensor)\n        negative_prompt_embeds (torch.Tensor)\n        pooled_prompt_embeds (torch.Tensor)\n        negative_pooled_prompt_embeds (torch.Tensor)\n    \"\"\"\n    device = device or pipe._execution_device\n\n    # set lora scale so that monkey patched LoRA\n    # function of text encoder can correctly access it\n    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):\n        pipe._lora_scale = lora_scale\n\n        # dynamically adjust the LoRA scale\n        if pipe.text_encoder is not None:\n            if not USE_PEFT_BACKEND:\n                
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(pipe.text_encoder, lora_scale)\n\n        if pipe.text_encoder_2 is not None:\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)\n            else:\n                scale_lora_layers(pipe.text_encoder_2, lora_scale)\n\n    if prompt_2:\n        prompt = f\"{prompt} {prompt_2}\"\n\n    if neg_prompt_2:\n        neg_prompt = f\"{neg_prompt} {neg_prompt_2}\"\n\n    prompt_t1 = prompt_t2 = prompt\n    neg_prompt_t1 = neg_prompt_t2 = neg_prompt\n\n    if isinstance(pipe, TextualInversionLoaderMixin):\n        prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)\n        neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)\n        prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)\n        neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)\n\n    eos = pipe.tokenizer.eos_token_id\n\n    # tokenizer 1\n    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)\n    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)\n\n    # tokenizer 2\n    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)\n    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)\n\n    # padding the shorter one for prompt set 1\n    prompt_token_len = len(prompt_tokens)\n    neg_prompt_token_len = len(neg_prompt_tokens)\n\n    if prompt_token_len > neg_prompt_token_len:\n        # padding the neg_prompt with eos token\n        neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)\n        neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)\n    else:\n        # padding the prompt\n    
    prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)\n        prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)\n\n    # padding the shorter one for token set 2\n    prompt_token_len_2 = len(prompt_tokens_2)\n    neg_prompt_token_len_2 = len(neg_prompt_tokens_2)\n\n    if prompt_token_len_2 > neg_prompt_token_len_2:\n        # padding the neg_prompt with eos token\n        neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)\n        neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)\n    else:\n        # padding the prompt\n        prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)\n        prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)\n\n    embeds = []\n    neg_embeds = []\n\n    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())\n\n    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(\n        neg_prompt_tokens.copy(), neg_prompt_weights.copy()\n    )\n\n    prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(\n        prompt_tokens_2.copy(), prompt_weights_2.copy()\n    )\n\n    neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(\n        neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()\n    )\n\n    # get prompt embeddings one by one is not working.\n    for i in range(len(prompt_token_groups)):\n        # get positive prompt embeddings with weights\n        token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)\n        weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)\n\n        token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)\n\n        # 
use first text encoder\n        prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)\n\n        # use second text encoder\n        prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)\n        pooled_prompt_embeds = prompt_embeds_2[0]\n\n        if clip_skip is None:\n            prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]\n            prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]\n        else:\n            # \"2\" because SDXL always indexes from the penultimate layer.\n            prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-(clip_skip + 2)]\n            prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-(clip_skip + 2)]\n\n        prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]\n        token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)\n\n        for j in range(len(weight_tensor)):\n            if weight_tensor[j] != 1.0:\n                token_embedding[j] = (\n                    token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]\n                )\n\n        token_embedding = token_embedding.unsqueeze(0)\n        embeds.append(token_embedding)\n\n        # get negative prompt embeddings with weights\n        neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)\n        neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)\n        neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)\n\n        # use first text encoder\n        neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)\n        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]\n\n        # use second text encoder\n        neg_prompt_embeds_2 = 
pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)\n        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]\n        negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]\n\n        neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]\n        neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)\n\n        for z in range(len(neg_weight_tensor)):\n            if neg_weight_tensor[z] != 1.0:\n                neg_token_embedding[z] = (\n                    neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]\n                )\n\n        neg_token_embedding = neg_token_embedding.unsqueeze(0)\n        neg_embeds.append(neg_token_embedding)\n\n    prompt_embeds = torch.cat(embeds, dim=1)\n    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)\n\n    bs_embed, seq_len, _ = prompt_embeds.shape\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n    seq_len = negative_prompt_embeds.shape[1]\n    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(\n        bs_embed * num_images_per_prompt, -1\n    )\n    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(\n        bs_embed * num_images_per_prompt, -1\n    )\n\n    if pipe.text_encoder is not None:\n        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            
unscale_lora_layers(pipe.text_encoder, lora_scale)\n\n    if pipe.text_encoder_2 is not None:\n        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(pipe.text_encoder_2, lora_scale)\n\n    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n\n# -------------------------------------------------------------------------------------------------------------------------------\n# reuse the backbone code from StableDiffusionXLPipeline\n# -------------------------------------------------------------------------------------------------------------------------------\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        from diffusers import DiffusionPipeline\n        import torch\n\n        pipe = DiffusionPipeline.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\"\n            , torch_dtype       = torch.float16\n            , use_safetensors   = True\n            , variant           = \"fp16\"\n            , custom_pipeline   = \"lpw_stable_diffusion_xl\",\n        )\n\n        prompt = \"a white cat running on the grass\"*20\n        prompt2 = \"play a football\"*20\n        prompt = f\"{prompt},{prompt2}\"\n        neg_prompt = \"blur, low quality\"\n\n        pipe.to(\"cuda\")\n        images = pipe(\n            prompt                  = prompt\n            , negative_prompt       = neg_prompt\n        ).images[0]\n\n        pipe.to(\"cpu\")\n        torch.cuda.empty_cache()\n        images\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. 
Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass SDXLLongPromptWeightingPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion XL.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion XL uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`CLIPTextModelWithProjection`]):\n            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]):\n            Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->image_encoder->unet->vae\"\n    _optional_components = [\n        \"tokenizer\",\n        \"tokenizer_2\",\n        \"text_encoder\",\n        \"text_encoder_2\",\n        \"image_encoder\",\n        \"feature_extractor\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"negative_add_time_ids\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        feature_extractor: Optional[CLIPImageProcessor] = None,\n        image_encoder: Optional[CLIPVisionModelWithProjection] = None,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if 
getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True\n        )\n        self.default_sample_size = (\n            self.unet.config.sample_size\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\n            else 128\n        )\n\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n    def enable_model_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared\n        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`\n        method is called, and the model remains in GPU until the next model runs. 
Memory savings are lower than with\n        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.\n        \"\"\"\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n            from accelerate import cpu_offload_with_hook\n        else:\n            raise ImportError(\"`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.\")\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        if self.device.type != \"cpu\":\n            self.to(\"cpu\", silence_dtype_warnings=True)\n            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)\n\n        model_sequence = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n        model_sequence.extend([self.unet, self.vae])\n\n        hook = None\n        for cpu_offloaded_model in model_sequence:\n            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)\n\n        # We'll offload the last model manually.\n        self.final_offload_hook = hook\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text 
encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = 
self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n\n                prompt_embeds = text_encoder(\n                    text_input_ids.to(device),\n                    output_hidden_states=True,\n                )\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\n                    pooled_prompt_embeds = prompt_embeds[0]\n\n                prompt_embeds = prompt_embeds.hidden_states[-2]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n      
      negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if 
do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n       
         num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        strength,\n        callback_steps,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} 
and {width}.\")\n\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. 
Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):\n        # get the original timestep using init_timestep\n        if denoising_start is None:\n            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n            t_start = max(num_inference_steps - init_timestep, 0)\n        else:\n            t_start = 0\n\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        # Strength is irrelevant if we directly request a timestep to start at;\n        # that is, strength is determined by the denoising_start instead.\n        if denoising_start is not None:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_start * self.scheduler.config.num_train_timesteps)\n                )\n            )\n\n            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()\n            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:\n                # if the scheduler is a 2nd order scheduler we might have to do +1\n                # because `num_inference_steps` might be even given that every timestep\n                # (except the highest one) is duplicated. 
If `num_inference_steps` is even it would\n                # mean that we cut the timesteps in the middle of the denoising step\n                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1\n                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler\n                num_inference_steps = num_inference_steps + 1\n\n            # because t_n+1 >= t_n, we slice the timesteps starting from the end\n            timesteps = timesteps[-num_inference_steps:]\n            return timesteps, num_inference_steps\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(\n        self,\n        image,\n        mask,\n        width,\n        height,\n        num_channels_latents,\n        timestep,\n        batch_size,\n        num_images_per_prompt,\n        dtype,\n        device,\n        generator=None,\n        add_noise=True,\n        latents=None,\n        is_strength_max=True,\n        return_noise=False,\n        return_image_latents=False,\n    ):\n        batch_size *= num_images_per_prompt\n\n        if image is None:\n            shape = (\n                batch_size,\n                num_channels_latents,\n                int(height) // self.vae_scale_factor,\n                int(width) // self.vae_scale_factor,\n            )\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            if latents is None:\n                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            else:\n                latents = latents.to(device)\n\n            # scale the initial noise by the standard deviation required by the scheduler\n            latents = latents * self.scheduler.init_noise_sigma\n            return latents\n\n        elif mask is None:\n            if not isinstance(image, (torch.Tensor, Image.Image, list)):\n                raise ValueError(\n                    f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n                )\n\n            # Offload text encoder if `enable_model_cpu_offload` was enabled\n            if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n                self.text_encoder_2.to(\"cpu\")\n                torch.cuda.empty_cache()\n\n            image = image.to(device=device, dtype=dtype)\n\n            if image.shape[1] == 4:\n                init_latents = image\n\n            else:\n                # make sure the VAE is in float32 mode, as it overflows in float16\n                if self.vae.config.force_upcast:\n                    image = image.float()\n                    self.vae.to(dtype=torch.float32)\n\n                if isinstance(generator, list) and len(generator) != batch_size:\n                    raise ValueError(\n                        f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                        f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                    )\n\n                elif isinstance(generator, list):\n                    init_latents = [\n                        retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                        for i in range(batch_size)\n                    ]\n                    init_latents = torch.cat(init_latents, dim=0)\n                else:\n                    init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n                if self.vae.config.force_upcast:\n                    self.vae.to(dtype)\n\n                init_latents = init_latents.to(dtype)\n                init_latents = self.vae.config.scaling_factor * init_latents\n\n            if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n                # expand init_latents for batch_size\n                additional_image_per_prompt = batch_size // init_latents.shape[0]\n                init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n            elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n                raise ValueError(\n                    f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n                )\n            else:\n                init_latents = torch.cat([init_latents], dim=0)\n\n            if add_noise:\n                shape = init_latents.shape\n                noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n                # get latents\n                init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n\n            latents = init_latents\n            return latents\n\n        else:\n            shape = (\n                batch_size,\n                num_channels_latents,\n                int(height) // self.vae_scale_factor,\n                int(width) 
// self.vae_scale_factor,\n            )\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n                )\n\n            if (image is None or timestep is None) and not is_strength_max:\n                raise ValueError(\n                    \"Since strength < 1, initial latents are to be initialised as a combination of image and noise. \"\n                    \"However, either the image or the noise timestep has not been provided.\"\n                )\n\n            if image.shape[1] == 4:\n                image_latents = image.to(device=device, dtype=dtype)\n                image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n            elif return_image_latents or (latents is None and not is_strength_max):\n                image = image.to(device=device, dtype=dtype)\n                image_latents = self._encode_vae_image(image=image, generator=generator)\n                image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n\n            if latents is None and add_noise:\n                noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n                # if strength is 1. 
then initialise the latents to noise, else initialise to image + noise\n                latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)\n                # if pure noise then scale the initial latents by the scheduler's init sigma\n                latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents\n            elif add_noise:\n                noise = latents.to(device)\n                latents = noise * self.scheduler.init_noise_sigma\n            else:\n                noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n                latents = image_latents.to(device)\n\n            outputs = (latents,)\n\n            if return_noise:\n                outputs += (noise,)\n\n            if return_image_latents:\n                outputs += (image_latents,)\n\n            return outputs\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        dtype = image.dtype\n        if self.vae.config.force_upcast:\n            image = image.float()\n            self.vae.to(dtype=torch.float32)\n\n        if isinstance(generator, list):\n            image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        if self.vae.config.force_upcast:\n            self.vae.to(dtype)\n\n        image_latents = image_latents.to(dtype)\n        image_latents = self.vae.config.scaling_factor * image_latents\n\n        return image_latents\n\n    def prepare_mask_latents(\n        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance\n    ):\n        # resize the mask to latents shape as we concatenate the mask to 
the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(\n            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)\n        )\n        mask = mask.to(device=device, dtype=dtype)\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n\n        if masked_image is not None and masked_image.shape[1] == 4:\n            masked_image_latents = masked_image\n        else:\n            masked_image_latents = None\n\n        if masked_image is not None:\n            if masked_image_latents is None:\n                masked_image = masked_image.to(device=device, dtype=dtype)\n                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)\n\n            if masked_image_latents.shape[0] < batch_size:\n                if not batch_size % masked_image_latents.shape[0] == 0:\n                    raise ValueError(\n                        \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                        f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                        \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                    )\n                masked_image_latents = masked_image_latents.repeat(\n                    batch_size // masked_image_latents.shape[0], 1, 1, 1\n                )\n\n            masked_image_latents = (\n                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n            )\n\n            # aligning device to prevent device errors when concating it with the latent model input\n            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n\n        return mask, masked_image_latents\n\n    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. 
Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        return add_time_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def 
guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def denoising_start(self):\n        return self._denoising_start\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: str = None,\n        prompt_2: str | None = None,\n        image: Optional[PipelineImageInput] = None,\n        mask_image: Optional[PipelineImageInput] = None,\n        masked_image_latents: Optional[torch.Tensor] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        
negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str`):\n                The prompt to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str`):\n                The prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text encoders.\n            image (`PipelineImageInput`, *optional*):\n                `Image`, or tensor representing an image batch, that will be used as the starting point for the\n                process.\n            mask_image (`PipelineImageInput`, *optional*):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a\n                PIL image, it will be converted to a single channel (luminance) before use. 
If it's a tensor, it should\n                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.\n                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The\n                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added\n                noise will be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            denoising_start (`float`, *optional*):\n                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be\n                bypassed before it is initiated. 
Consequently, the initial part of the denoising process is skipped and\n                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,\n                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline\n                is integrated into a \"Mixture of Denoisers\" multi-pipeline setup, as detailed in [**Refine Image\n                Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be\n                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the\n                final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline\n                forms a part of a \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refine Image\n                Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str`):\n                The prompt not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                not greater than `1`).\n            negative_prompt_2 (`str`):\n                The prompt not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text encoders.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            ip_adapter_image: (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\n                of a plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16. of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. 
Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            strength,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._denoising_start = denoising_start\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if ip_adapter_image is not None:\n            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True\n            image_embeds, negative_image_embeds = self.encode_image(\n                ip_adapter_image, device, num_images_per_prompt, output_hidden_state\n            )\n            if self.do_classifier_free_guidance:\n                image_embeds = torch.cat([negative_image_embeds, image_embeds])\n\n        # 3. 
Encode input prompt\n        lora_scale = (\n            self._cross_attention_kwargs.get(\"scale\", None) if self._cross_attention_kwargs is not None else None\n        )\n\n        negative_prompt = negative_prompt if negative_prompt is not None else \"\"\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = get_weighted_text_embeddings_sdxl(\n            pipe=self,\n            prompt=prompt,\n            neg_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            clip_skip=clip_skip,\n            lora_scale=lora_scale,\n        )\n        dtype = prompt_embeds.dtype\n\n        if isinstance(image, Image.Image):\n            image = self.image_processor.preprocess(image, height=height, width=width)\n        if image is not None:\n            image = image.to(device=self.device, dtype=dtype)\n\n        if isinstance(mask_image, Image.Image):\n            mask = self.mask_processor.preprocess(mask_image, height=height, width=width)\n        else:\n            mask = mask_image\n        if mask_image is not None:\n            mask = mask.to(device=self.device, dtype=dtype)\n\n            if masked_image_latents is not None:\n                masked_image = masked_image_latents\n            elif image.shape[1] == 4:\n                # if image is in latent space, we can't mask it\n                masked_image = None\n            else:\n                masked_image = image * (mask < 0.5)\n        else:\n            mask = None\n\n        # 4. 
Prepare timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        if image is not None:\n            timesteps, num_inference_steps = self.get_timesteps(\n                num_inference_steps,\n                strength,\n                device,\n                denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,\n            )\n\n            # check that number of inference steps is not < 1 - as this doesn't make sense\n            if num_inference_steps < 1:\n                raise ValueError(\n                    f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                    f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n                )\n\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        is_strength_max = strength == 1.0\n        add_noise = True if self.denoising_start is None else False\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        num_channels_unet = self.unet.config.in_channels\n        return_image_latents = num_channels_unet == 4\n\n        latents = self.prepare_latents(\n            image=image,\n            mask=mask,\n            width=width,\n            height=height,\n            num_channels_latents=num_channels_unet,\n            timestep=latent_timestep,\n            batch_size=batch_size,\n            num_images_per_prompt=num_images_per_prompt,\n            dtype=prompt_embeds.dtype,\n            device=device,\n            generator=generator,\n            add_noise=add_noise,\n            latents=latents,\n            is_strength_max=is_strength_max,\n            return_noise=True,\n            return_image_latents=return_image_latents,\n        )\n\n        if mask is not None:\n            if return_image_latents:\n                latents, noise, image_latents = latents\n            else:\n                latents, noise = latents\n\n        # 5.1 Prepare mask latent variables\n        if mask is not None:\n            mask, masked_image_latents = self.prepare_mask_latents(\n                mask=mask,\n                masked_image=masked_image,\n                batch_size=batch_size * num_images_per_prompt,\n                height=height,\n                width=width,\n                dtype=prompt_embeds.dtype,\n                device=device,\n                generator=generator,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n            )\n\n            # Check that sizes of mask, masked image and latents match\n            if num_channels_unet == 9:\n                # default case for stable-diffusion-v1-5/stable-diffusion-inpainting\n                num_channels_mask = mask.shape[1]\n                num_channels_masked_image = masked_image_latents.shape[1]\n                if num_channels_latents + num_channels_mask + 
num_channels_masked_image != num_channels_unet:\n                    raise ValueError(\n                        f\"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects\"\n                        f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                        f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                        f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                        \" `pipeline.unet` or your `mask_image` or `image` input.\"\n                    )\n            elif num_channels_unet != 4:\n                raise ValueError(\n                    f\"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.\"\n                )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 6.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = {\"image_embeds\": image_embeds} if ip_adapter_image is not None else {}\n\n        height, width = latents.shape[-2:]\n        height = height * self.vae_scale_factor\n        width = width * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 7. 
Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n        add_time_ids = self._get_add_time_ids(\n            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype\n        )\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 7.1 Apply denoising_end\n        if (\n            self.denoising_end is not None\n            and self.denoising_start is not None\n            and denoising_value_valid(self.denoising_end)\n            and denoising_value_valid(self.denoising_start)\n            and self.denoising_start >= self.denoising_end\n        ):\n            raise ValueError(\n                f\"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: \"\n                + f\" {self.denoising_end} when using type float.\"\n            )\n        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 8. 
Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n\n        # 9. Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                if mask is not None and num_channels_unet == 9:\n                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n                # predict the noise residual\n                added_cond_kwargs.update({\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids})\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if 
self.do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if mask is not None and num_channels_unet == 4:\n                    init_latents_proper = image_latents\n\n                    if self.do_classifier_free_guidance:\n                        init_mask, _ = mask.chunk(2)\n                    else:\n                        init_mask = mask\n\n                    if i < len(timesteps) - 1:\n                        noise_timestep = timesteps[i + 1]\n                        init_latents_proper = self.scheduler.add_noise(\n                            init_latents_proper, noise, torch.tensor([noise_timestep])\n                        )\n\n                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", 
negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n  
          if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n            return StableDiffusionXLPipelineOutput(images=image)\n\n        # apply watermark if available\n        if self.watermark is not None:\n            image = self.watermark.apply_watermark(image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n\n    def text2img(\n        self,\n        prompt: str = None,\n        prompt_2: str | None = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        
crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling pipeline for text-to-image.\n\n        Refer to the documentation of the `__call__` method for parameter descriptions.\n        \"\"\"\n        return self.__call__(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            timesteps=timesteps,\n            denoising_start=denoising_start,\n            denoising_end=denoising_end,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            ip_adapter_image=ip_adapter_image,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            output_type=output_type,\n            return_dict=return_dict,\n            cross_attention_kwargs=cross_attention_kwargs,\n            guidance_rescale=guidance_rescale,\n            original_size=original_size,\n            crops_coords_top_left=crops_coords_top_left,\n            target_size=target_size,\n            clip_skip=clip_skip,\n            callback_on_step_end=callback_on_step_end,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            **kwargs,\n        )\n\n    def img2img(\n        self,\n        prompt: str = None,\n 
       prompt_2: str | None = None,\n        image: Optional[PipelineImageInput] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling pipeline for image-to-image.\n\n        Refer to the documentation of the `__call__` method for parameter descriptions.\n        \"\"\"\n        return self.__call__(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            image=image,\n            height=height,\n            width=width,\n            
strength=strength,\n            num_inference_steps=num_inference_steps,\n            timesteps=timesteps,\n            denoising_start=denoising_start,\n            denoising_end=denoising_end,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            ip_adapter_image=ip_adapter_image,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            output_type=output_type,\n            return_dict=return_dict,\n            cross_attention_kwargs=cross_attention_kwargs,\n            guidance_rescale=guidance_rescale,\n            original_size=original_size,\n            crops_coords_top_left=crops_coords_top_left,\n            target_size=target_size,\n            clip_skip=clip_skip,\n            callback_on_step_end=callback_on_step_end,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            **kwargs,\n        )\n\n    def inpaint(\n        self,\n        prompt: str = None,\n        prompt_2: str | None = None,\n        image: Optional[PipelineImageInput] = None,\n        mask_image: Optional[PipelineImageInput] = None,\n        masked_image_latents: Optional[torch.Tensor] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        
num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling pipeline for inpainting.\n\n        Refer to the documentation of the `__call__` method for parameter descriptions.\n        \"\"\"\n        return self.__call__(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            image=image,\n            mask_image=mask_image,\n            masked_image_latents=masked_image_latents,\n            height=height,\n            width=width,\n            strength=strength,\n            num_inference_steps=num_inference_steps,\n            timesteps=timesteps,\n            denoising_start=denoising_start,\n            denoising_end=denoising_end,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            
generator=generator,\n            latents=latents,\n            ip_adapter_image=ip_adapter_image,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            output_type=output_type,\n            return_dict=return_dict,\n            cross_attention_kwargs=cross_attention_kwargs,\n            guidance_rescale=guidance_rescale,\n            original_size=original_size,\n            crops_coords_top_left=crops_coords_top_left,\n            target_size=target_size,\n            clip_skip=clip_skip,\n            callback_on_step_end=callback_on_step_end,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            **kwargs,\n        )\n\n    # Override to properly handle the loading and unloading of the additional text encoder.\n    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):\n        # We could have accessed the unet config from `lora_state_dict()` too. 
We pass\n        # it here explicitly to be able to tell that it's coming from an SDXL\n        # pipeline.\n        state_dict, network_alphas = self.lora_state_dict(\n            pretrained_model_name_or_path_or_dict,\n            unet_config=self.unet.config,\n            **kwargs,\n        )\n        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)\n\n        text_encoder_state_dict = {k: v for k, v in state_dict.items() if \"text_encoder.\" in k}\n        if len(text_encoder_state_dict) > 0:\n            self.load_lora_into_text_encoder(\n                text_encoder_state_dict,\n                network_alphas=network_alphas,\n                text_encoder=self.text_encoder,\n                prefix=\"text_encoder\",\n                lora_scale=self.lora_scale,\n            )\n\n        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if \"text_encoder_2.\" in k}\n        if len(text_encoder_2_state_dict) > 0:\n            self.load_lora_into_text_encoder(\n                text_encoder_2_state_dict,\n                network_alphas=network_alphas,\n                text_encoder=self.text_encoder_2,\n                prefix=\"text_encoder_2\",\n                lora_scale=self.lora_scale,\n            )\n\n    @classmethod\n    def save_lora_weights(\n        cls,\n        save_directory: Union[str, os.PathLike],\n        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,\n        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,\n        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,\n        is_main_process: bool = True,\n        weight_name: str = None,\n        save_function: Callable = None,\n        safe_serialization: bool = False,\n    ):\n        state_dict = {}\n\n        def pack_weights(layers, prefix):\n            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers\n            
layers_state_dict = {f\"{prefix}.{module_name}\": param for module_name, param in layers_weights.items()}\n            return layers_state_dict\n\n        state_dict.update(pack_weights(unet_lora_layers, \"unet\"))\n\n        if text_encoder_lora_layers and text_encoder_2_lora_layers:\n            state_dict.update(pack_weights(text_encoder_lora_layers, \"text_encoder\"))\n            state_dict.update(pack_weights(text_encoder_2_lora_layers, \"text_encoder_2\"))\n\n        cls.write_lora_layers(\n            state_dict=state_dict,\n            save_directory=save_directory,\n            is_main_process=is_main_process,\n            weight_name=weight_name,\n            save_function=save_function,\n            safe_serialization=safe_serialization,\n        )\n\n    def _remove_text_encoder_monkey_patch(self):\n        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)\n        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)\n"
  },
  {
    "path": "examples/community/magic_mix.py",
    "content": "from typing import Union\n\nimport torch\nfrom PIL import Image\nfrom torchvision import transforms as tfms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DiffusionPipeline,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    UNet2DConditionModel,\n)\n\n\nclass MagicMixPipeline(DiffusionPipeline):\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],\n    ):\n        super().__init__()\n\n        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)\n\n    # convert PIL image to latents\n    def encode(self, img):\n        with torch.no_grad():\n            latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)\n            latent = 0.18215 * latent.latent_dist.sample()\n        return latent\n\n    # convert latents to PIL image\n    def decode(self, latent):\n        latent = (1 / 0.18215) * latent\n        with torch.no_grad():\n            img = self.vae.decode(latent).sample\n        img = (img / 2 + 0.5).clamp(0, 1)\n        img = img.detach().cpu().permute(0, 2, 3, 1).numpy()\n        img = (img * 255).round().astype(\"uint8\")\n        return Image.fromarray(img[0])\n\n    # convert prompt into text embeddings, also unconditional embeddings\n    def prep_text(self, prompt):\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]\n\n        uncond_input = self.tokenizer(\n            \"\",\n            
padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n        return torch.cat([uncond_embedding, text_embedding])\n\n    def __call__(\n        self,\n        img: Image.Image,\n        prompt: str,\n        kmin: float = 0.3,\n        kmax: float = 0.6,\n        mix_factor: float = 0.5,\n        seed: int = 42,\n        steps: int = 50,\n        guidance_scale: float = 7.5,\n    ) -> Image.Image:\n        tmin = steps - int(kmin * steps)\n        tmax = steps - int(kmax * steps)\n\n        text_embeddings = self.prep_text(prompt)\n\n        self.scheduler.set_timesteps(steps)\n\n        width, height = img.size\n        encoded = self.encode(img)\n\n        torch.manual_seed(seed)\n        noise = torch.randn(\n            (1, self.unet.config.in_channels, height // 8, width // 8),\n        ).to(self.device)\n\n        latents = self.scheduler.add_noise(\n            encoded,\n            noise,\n            timesteps=self.scheduler.timesteps[tmax],\n        )\n\n        input = torch.cat([latents] * 2)\n\n        input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])\n\n        with torch.no_grad():\n            pred = self.unet(\n                input,\n                self.scheduler.timesteps[tmax],\n                encoder_hidden_states=text_embeddings,\n            ).sample\n\n        pred_uncond, pred_text = pred.chunk(2)\n        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)\n\n        latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample\n\n        for i, t in enumerate(tqdm(self.scheduler.timesteps)):\n            if i > tmax:\n                if i < tmin:  # layout generation phase\n                    orig_latents = self.scheduler.add_noise(\n                        encoded,\n      
                  noise,\n                        timesteps=t,\n                    )\n\n                    input = (\n                        (mix_factor * latents) + (1 - mix_factor) * orig_latents\n                    )  # interpolating between layout noise and conditionally generated noise to preserve layout semantics\n                    input = torch.cat([input] * 2)\n\n                else:  # content generation phase\n                    input = torch.cat([latents] * 2)\n\n                input = self.scheduler.scale_model_input(input, t)\n\n                with torch.no_grad():\n                    pred = self.unet(\n                        input,\n                        t,\n                        encoder_hidden_states=text_embeddings,\n                    ).sample\n\n                pred_uncond, pred_text = pred.chunk(2)\n                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)\n\n                latents = self.scheduler.step(pred, t, latents).prev_sample\n\n        return self.decode(latents)\n"
  },
  {
    "path": "examples/community/marigold_depth_estimation.py",
    "content": "# Copyright 2025 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# --------------------------------------------------------------------------\n# If you find this code useful, we kindly ask you to cite our paper in your work.\n# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation\n# More information about the method can be found at https://marigoldmonodepth.github.io\n# --------------------------------------------------------------------------\n\n\nimport logging\nimport math\nfrom typing import Dict, Union\n\nimport matplotlib\nimport numpy as np\nimport torch\nfrom PIL import Image\nfrom PIL.Image import Resampling\nfrom scipy.optimize import minimize\nfrom torch.utils.data import DataLoader, TensorDataset\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DiffusionPipeline,\n    LCMScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.utils import BaseOutput, check_min_version\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\n\nclass MarigoldDepthOutput(BaseOutput):\n    \"\"\"\n    Output class for Marigold monocular depth prediction pipeline.\n\n    Args:\n        depth_np (`np.ndarray`):\n            Predicted depth map, with depth values in the range of [0, 1].\n        depth_colored (`None` or `PIL.Image.Image`):\n            Colorized depth map, with the shape of [3, H, W] and values in [0, 1].\n        uncertainty (`None` or `np.ndarray`):\n            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.\n    \"\"\"\n\n    depth_np: np.ndarray\n    depth_colored: Union[None, Image.Image]\n    uncertainty: Union[None, np.ndarray]\n\n\ndef get_pil_resample_method(method_str: str) -> Resampling:\n    resample_method_dic = {\n        \"bilinear\": Resampling.BILINEAR,\n        \"bicubic\": Resampling.BICUBIC,\n        \"nearest\": Resampling.NEAREST,\n    }\n    resample_method = resample_method_dic.get(method_str, None)\n    if resample_method is None:\n        # Report the name the caller passed; `resample_method` is always None on this branch.\n        raise ValueError(f\"Unknown resampling method: {method_str}\")\n    else:\n        return resample_method\n\n\nclass MarigoldPipeline(DiffusionPipeline):\n    \"\"\"\n    Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        unet (`UNet2DConditionModel`):\n            Conditional U-Net to denoise the depth latent, conditioned on image latent.\n        vae (`AutoencoderKL`):\n            Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps\n            to and from latent representations.\n        scheduler (`DDIMScheduler`):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents.\n        text_encoder (`CLIPTextModel`):\n            Text-encoder, for empty text embedding.\n        tokenizer (`CLIPTokenizer`):\n            CLIP tokenizer.\n    \"\"\"\n\n    rgb_latent_scale_factor = 0.18215\n    depth_latent_scale_factor = 0.18215\n\n    def __init__(\n        self,\n        unet: UNet2DConditionModel,\n        vae: AutoencoderKL,\n        scheduler: DDIMScheduler,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            unet=unet,\n            vae=vae,\n            scheduler=scheduler,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n        )\n\n        self.empty_text_embed = None\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        input_image: Image,\n        denoising_steps: int = 10,\n        ensemble_size: int = 10,\n        processing_res: int = 768,\n        match_input_res: bool = True,\n        resample_method: str = \"bilinear\",\n        batch_size: int = 0,\n        seed: Union[int, None] = None,\n        color_map: str = \"Spectral\",\n        show_progress_bar: bool = True,\n        ensemble_kwargs: Dict = None,\n    ) -> MarigoldDepthOutput:\n        \"\"\"\n        Function invoked when calling the pipeline.\n\n        Args:\n            input_image (`Image`):\n    
            Input RGB (or gray-scale) image.\n            processing_res (`int`, *optional*, defaults to `768`):\n                Maximum resolution of processing.\n                If set to 0: will not resize at all.\n            match_input_res (`bool`, *optional*, defaults to `True`):\n                Resize depth prediction to match input resolution.\n                Only valid if `processing_res` > 0.\n            resample_method (`str`, *optional*, defaults to `bilinear`):\n                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.\n            denoising_steps (`int`, *optional*, defaults to `10`):\n                Number of diffusion denoising steps (DDIM) during inference.\n            ensemble_size (`int`, *optional*, defaults to `10`):\n                Number of predictions to be ensembled.\n            batch_size (`int`, *optional*, defaults to `0`):\n                Inference batch size, no bigger than `ensemble_size`.\n                If set to 0, the script will automatically decide the proper batch size.\n            seed (`int`, *optional*, defaults to `None`):\n                Reproducibility seed.\n            show_progress_bar (`bool`, *optional*, defaults to `True`):\n                Display a progress bar of diffusion denoising.\n            color_map (`str`, *optional*, defaults to `\"Spectral\"`, pass `None` to skip colorized depth map generation):\n                Colormap used to colorize the depth map.\n            ensemble_kwargs (`dict`, *optional*, defaults to `None`):\n                Arguments for detailed ensembling settings.\n        Returns:\n            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:\n            - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]\n            - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the 
shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`\n            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)\n                    coming from ensembling. None if `ensemble_size = 1`\n        \"\"\"\n\n        device = self.device\n        input_size = input_image.size\n\n        if not match_input_res:\n            assert processing_res is not None, \"Value error: `resize_output_back` is only valid with \"\n        assert processing_res >= 0\n        assert ensemble_size >= 1\n\n        # Check if denoising step is reasonable\n        self._check_inference_step(denoising_steps)\n\n        resample_method: Resampling = get_pil_resample_method(resample_method)\n\n        # ----------------- Image Preprocess -----------------\n        # Resize image\n        if processing_res > 0:\n            input_image = self.resize_max_res(\n                input_image,\n                max_edge_resolution=processing_res,\n                resample_method=resample_method,\n            )\n        # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel\n        input_image = input_image.convert(\"RGB\")\n        image = np.asarray(input_image)\n\n        # Normalize rgb values\n        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]\n        rgb_norm = rgb / 255.0 * 2.0 - 1.0  #  [0, 255] -> [-1, 1]\n        rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)\n        rgb_norm = rgb_norm.to(device)\n        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0\n\n        # ----------------- Predicting depth -----------------\n        # Batch repeated input image\n        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)\n        single_rgb_dataset = TensorDataset(duplicated_rgb)\n        if batch_size > 0:\n            _bs = batch_size\n        else:\n            _bs = self._find_batch_size(\n                ensemble_size=ensemble_size,\n                
input_res=max(rgb_norm.shape[1:]),\n                dtype=self.dtype,\n            )\n\n        single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)\n\n        # Predict depth maps (batched)\n        depth_pred_ls = []\n        if show_progress_bar:\n            iterable = tqdm(single_rgb_loader, desc=\" \" * 2 + \"Inference batches\", leave=False)\n        else:\n            iterable = single_rgb_loader\n        for batch in iterable:\n            (batched_img,) = batch\n            depth_pred_raw = self.single_infer(\n                rgb_in=batched_img,\n                num_inference_steps=denoising_steps,\n                show_pbar=show_progress_bar,\n                seed=seed,\n            )\n            depth_pred_ls.append(depth_pred_raw.detach())\n        depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()\n        torch.cuda.empty_cache()  # clear vram cache for ensembling\n\n        # ----------------- Test-time ensembling -----------------\n        if ensemble_size > 1:\n            depth_pred, pred_uncert = self.ensemble_depths(depth_preds, **(ensemble_kwargs or {}))\n        else:\n            depth_pred = depth_preds\n            pred_uncert = None\n\n        # ----------------- Post processing -----------------\n        # Scale prediction to [0, 1]\n        min_d = torch.min(depth_pred)\n        max_d = torch.max(depth_pred)\n        depth_pred = (depth_pred - min_d) / (max_d - min_d)\n\n        # Convert to numpy\n        depth_pred = depth_pred.cpu().numpy().astype(np.float32)\n\n        # Resize back to original resolution\n        if match_input_res:\n            pred_img = Image.fromarray(depth_pred)\n            pred_img = pred_img.resize(input_size, resample=resample_method)\n            depth_pred = np.asarray(pred_img)\n\n        # Clip output range\n        depth_pred = depth_pred.clip(0, 1)\n\n        # Colorize\n        if color_map is not None:\n            depth_colored = self.colorize_depth_maps(\n    
            depth_pred, 0, 1, cmap=color_map\n            ).squeeze()  # [3, H, W], value in (0, 1)\n            depth_colored = (depth_colored * 255).astype(np.uint8)\n            depth_colored_hwc = self.chw2hwc(depth_colored)\n            depth_colored_img = Image.fromarray(depth_colored_hwc)\n        else:\n            depth_colored_img = None\n\n        return MarigoldDepthOutput(\n            depth_np=depth_pred,\n            depth_colored=depth_colored_img,\n            uncertainty=pred_uncert,\n        )\n\n    def _check_inference_step(self, n_step: int):\n        \"\"\"\n        Check if denoising step is reasonable\n        Args:\n            n_step (`int`): denoising steps\n        \"\"\"\n        assert n_step >= 1\n\n        if isinstance(self.scheduler, DDIMScheduler):\n            if n_step < 10:\n                logging.warning(\n                    f\"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference.\"\n                )\n        elif isinstance(self.scheduler, LCMScheduler):\n            if not 1 <= n_step <= 4:\n                logging.warning(f\"Non-optimal setting of denoising steps: {n_step}. 
Recommended setting is 1-4 steps.\")\n        else:\n            raise RuntimeError(f\"Unsupported scheduler type: {type(self.scheduler)}\")\n\n    def _encode_empty_text(self):\n        \"\"\"\n        Encode text embedding for empty prompt.\n        \"\"\"\n        prompt = \"\"\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"do_not_pad\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)\n        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)\n\n    @torch.no_grad()\n    def single_infer(\n        self,\n        rgb_in: torch.Tensor,\n        num_inference_steps: int,\n        seed: Union[int, None],\n        show_pbar: bool,\n    ) -> torch.Tensor:\n        \"\"\"\n        Perform an individual depth prediction without ensembling.\n\n        Args:\n            rgb_in (`torch.Tensor`):\n                Input RGB image.\n            num_inference_steps (`int`):\n                Number of diffusion denoising steps (DDIM) during inference.\n            seed (`int` or `None`):\n                Seed for the initial noise generator. `None` leaves the noise unseeded.\n            show_pbar (`bool`):\n                Display a progress bar of diffusion denoising.\n        Returns:\n            `torch.Tensor`: Predicted depth map.\n        \"\"\"\n        device = rgb_in.device\n\n        # Set timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps  # [T]\n\n        # Encode image\n        rgb_latent = self.encode_rgb(rgb_in)\n\n        # Initial depth map (noise)\n        if seed is None:\n            rand_num_generator = None\n        else:\n            rand_num_generator = torch.Generator(device=device)\n            rand_num_generator.manual_seed(seed)\n        depth_latent = torch.randn(\n            rgb_latent.shape,\n            device=device,\n            dtype=self.dtype,\n            
generator=rand_num_generator,\n        )  # [B, 4, h, w]\n\n        # Batched empty text embedding\n        if self.empty_text_embed is None:\n            self._encode_empty_text()\n        batch_empty_text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1))  # [B, 2, 1024]\n\n        # Denoising loop\n        if show_pbar:\n            iterable = tqdm(\n                enumerate(timesteps),\n                total=len(timesteps),\n                leave=False,\n                desc=\" \" * 4 + \"Diffusion denoising\",\n            )\n        else:\n            iterable = enumerate(timesteps)\n\n        for i, t in iterable:\n            unet_input = torch.cat([rgb_latent, depth_latent], dim=1)  # this order is important\n\n            # predict the noise residual\n            noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample  # [B, 4, h, w]\n\n            # compute the previous noisy sample x_t -> x_t-1\n            depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rand_num_generator).prev_sample\n\n        depth = self.decode_depth(depth_latent)\n\n        # clip prediction\n        depth = torch.clip(depth, -1.0, 1.0)\n        # shift to [0, 1]\n        depth = (depth + 1.0) / 2.0\n\n        return depth\n\n    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Encode RGB image into latent.\n\n        Args:\n            rgb_in (`torch.Tensor`):\n                Input RGB image to be encoded.\n\n        Returns:\n            `torch.Tensor`: Image latent.\n        \"\"\"\n        # encode\n        h = self.vae.encoder(rgb_in)\n        moments = self.vae.quant_conv(h)\n        mean, logvar = torch.chunk(moments, 2, dim=1)\n        # scale latent\n        rgb_latent = mean * self.rgb_latent_scale_factor\n        return rgb_latent\n\n    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Decode depth latent into 
depth map.\n\n        Args:\n            depth_latent (`torch.Tensor`):\n                Depth latent to be decoded.\n\n        Returns:\n            `torch.Tensor`: Decoded depth map.\n        \"\"\"\n        # scale latent\n        depth_latent = depth_latent / self.depth_latent_scale_factor\n        # decode\n        z = self.vae.post_quant_conv(depth_latent)\n        stacked = self.vae.decoder(z)\n        # mean of output channels\n        depth_mean = stacked.mean(dim=1, keepdim=True)\n        return depth_mean\n\n    @staticmethod\n    def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:\n        \"\"\"\n        Resize image to limit maximum edge length while keeping aspect ratio.\n\n        Args:\n            img (`Image.Image`):\n                Image to be resized.\n            max_edge_resolution (`int`):\n                Maximum edge length (pixel).\n            resample_method (`PIL.Image.Resampling`):\n                Resampling method used to resize images.\n\n        Returns:\n            `Image.Image`: Resized image.\n        \"\"\"\n        original_width, original_height = img.size\n        downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)\n\n        new_width = int(original_width * downscale_factor)\n        new_height = int(original_height * downscale_factor)\n\n        resized_img = img.resize((new_width, new_height), resample=resample_method)\n        return resized_img\n\n    @staticmethod\n    def colorize_depth_maps(depth_map, min_depth, max_depth, cmap=\"Spectral\", valid_mask=None):\n        \"\"\"\n        Colorize depth maps.\n        \"\"\"\n        assert len(depth_map.shape) >= 2, \"Invalid dimension\"\n\n        if isinstance(depth_map, torch.Tensor):\n            depth = depth_map.detach().clone().squeeze().numpy()\n        elif isinstance(depth_map, np.ndarray):\n            depth = depth_map.copy().squeeze()\n        # 
reshape to [ (B,) H, W ]\n        if depth.ndim < 3:\n            depth = depth[np.newaxis, :, :]\n\n        # colorize\n        cm = matplotlib.colormaps[cmap]\n        depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)\n        img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1\n        img_colored_np = np.rollaxis(img_colored_np, 3, 1)\n\n        if valid_mask is not None:\n            if isinstance(depth_map, torch.Tensor):\n                valid_mask = valid_mask.detach().numpy()\n            valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]\n            if valid_mask.ndim < 3:\n                valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]\n            else:\n                valid_mask = valid_mask[:, np.newaxis, :, :]\n            valid_mask = np.repeat(valid_mask, 3, axis=1)\n            img_colored_np[~valid_mask] = 0\n\n        if isinstance(depth_map, torch.Tensor):\n            img_colored = torch.from_numpy(img_colored_np).float()\n        elif isinstance(depth_map, np.ndarray):\n            img_colored = img_colored_np\n\n        return img_colored\n\n    @staticmethod\n    def chw2hwc(chw):\n        assert 3 == len(chw.shape)\n        if isinstance(chw, torch.Tensor):\n            hwc = torch.permute(chw, (1, 2, 0))\n        elif isinstance(chw, np.ndarray):\n            hwc = np.moveaxis(chw, 0, -1)\n        return hwc\n\n    @staticmethod\n    def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:\n        \"\"\"\n        Automatically search for suitable operating batch size.\n\n        Args:\n            ensemble_size (`int`):\n                Number of predictions to be ensembled.\n            input_res (`int`):\n                Operating resolution of the input image.\n\n        Returns:\n            `int`: Operating batch size.\n        \"\"\"\n        # Search table for suggested max. 
inference batch size\n        bs_search_table = [\n            # tested on A100-PCIE-80GB\n            {\"res\": 768, \"total_vram\": 79, \"bs\": 35, \"dtype\": torch.float32},\n            {\"res\": 1024, \"total_vram\": 79, \"bs\": 20, \"dtype\": torch.float32},\n            # tested on A100-PCIE-40GB\n            {\"res\": 768, \"total_vram\": 39, \"bs\": 15, \"dtype\": torch.float32},\n            {\"res\": 1024, \"total_vram\": 39, \"bs\": 8, \"dtype\": torch.float32},\n            {\"res\": 768, \"total_vram\": 39, \"bs\": 30, \"dtype\": torch.float16},\n            {\"res\": 1024, \"total_vram\": 39, \"bs\": 15, \"dtype\": torch.float16},\n            # tested on RTX3090, RTX4090\n            {\"res\": 512, \"total_vram\": 23, \"bs\": 20, \"dtype\": torch.float32},\n            {\"res\": 768, \"total_vram\": 23, \"bs\": 7, \"dtype\": torch.float32},\n            {\"res\": 1024, \"total_vram\": 23, \"bs\": 3, \"dtype\": torch.float32},\n            {\"res\": 512, \"total_vram\": 23, \"bs\": 40, \"dtype\": torch.float16},\n            {\"res\": 768, \"total_vram\": 23, \"bs\": 18, \"dtype\": torch.float16},\n            {\"res\": 1024, \"total_vram\": 23, \"bs\": 10, \"dtype\": torch.float16},\n            # tested on GTX1080Ti\n            {\"res\": 512, \"total_vram\": 10, \"bs\": 5, \"dtype\": torch.float32},\n            {\"res\": 768, \"total_vram\": 10, \"bs\": 2, \"dtype\": torch.float32},\n            {\"res\": 512, \"total_vram\": 10, \"bs\": 10, \"dtype\": torch.float16},\n            {\"res\": 768, \"total_vram\": 10, \"bs\": 5, \"dtype\": torch.float16},\n            {\"res\": 1024, \"total_vram\": 10, \"bs\": 3, \"dtype\": torch.float16},\n        ]\n\n        if not torch.cuda.is_available():\n            return 1\n\n        total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3\n        filtered_bs_search_table = [s for s in bs_search_table if s[\"dtype\"] == dtype]\n        for settings in sorted(\n            filtered_bs_search_table,\n         
   key=lambda k: (k[\"res\"], -k[\"total_vram\"]),\n        ):\n            if input_res <= settings[\"res\"] and total_vram >= settings[\"total_vram\"]:\n                bs = settings[\"bs\"]\n                if bs > ensemble_size:\n                    bs = ensemble_size\n                elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:\n                    bs = math.ceil(ensemble_size / 2)\n                return bs\n\n        return 1\n\n    @staticmethod\n    def ensemble_depths(\n        input_images: torch.Tensor,\n        regularizer_strength: float = 0.02,\n        max_iter: int = 2,\n        tol: float = 1e-3,\n        reduction: str = \"median\",\n        max_res: int = None,\n    ):\n        \"\"\"\n        Ensemble multiple affine-invariant depth images (each defined up to scale and shift)\n        by estimating a per-image scale and shift that aligns them.\n        \"\"\"\n\n        def inter_distances(tensors: torch.Tensor):\n            \"\"\"\n            Compute the signed difference between every pair of depth maps.\n            \"\"\"\n            distances = []\n            for i, j in torch.combinations(torch.arange(tensors.shape[0])):\n                arr1 = tensors[i : i + 1]\n                arr2 = tensors[j : j + 1]\n                distances.append(arr1 - arr2)\n            dist = torch.concatenate(distances, dim=0)\n            return dist\n\n        device = input_images.device\n        dtype = input_images.dtype\n        np_dtype = np.float32\n\n        original_input = input_images.clone()\n        n_img = input_images.shape[0]\n        ori_shape = input_images.shape\n\n        if max_res is not None:\n            scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))\n            if scale_factor < 1:\n                downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode=\"nearest\")\n                input_images = downscaler(torch.from_numpy(input_images)).numpy()\n\n        # init guess\n        _min = 
np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)\n        _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)\n        s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))\n        t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))\n        x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)\n\n        input_images = input_images.to(device)\n\n        # objective function\n        def closure(x):\n            l = len(x)\n            s = x[: int(l / 2)]\n            t = x[int(l / 2) :]\n            s = torch.from_numpy(s).to(dtype=dtype).to(device)\n            t = torch.from_numpy(t).to(dtype=dtype).to(device)\n\n            transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))\n            dists = inter_distances(transformed_arrays)\n            sqrt_dist = torch.sqrt(torch.mean(dists**2))\n\n            if \"mean\" == reduction:\n                pred = torch.mean(transformed_arrays, dim=0)\n            elif \"median\" == reduction:\n                pred = torch.median(transformed_arrays, dim=0).values\n            else:\n                raise ValueError\n\n            near_err = torch.sqrt((0 - torch.min(pred)) ** 2)\n            far_err = torch.sqrt((1 - torch.max(pred)) ** 2)\n\n            err = sqrt_dist + (near_err + far_err) * regularizer_strength\n            err = err.detach().cpu().numpy().astype(np_dtype)\n            return err\n\n        res = minimize(\n            closure,\n            x,\n            method=\"BFGS\",\n            tol=tol,\n            options={\"maxiter\": max_iter, \"disp\": False},\n        )\n        x = res.x\n        l = len(x)\n        s = x[: int(l / 2)]\n        t = x[int(l / 2) :]\n\n        # Prediction\n        s = torch.from_numpy(s).to(dtype=dtype).to(device)\n        t = torch.from_numpy(t).to(dtype=dtype).to(device)\n        transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)\n        if \"mean\" == reduction:\n 
           aligned_images = torch.mean(transformed_arrays, dim=0)\n            std = torch.std(transformed_arrays, dim=0)\n            uncertainty = std\n        elif \"median\" == reduction:\n            aligned_images = torch.median(transformed_arrays, dim=0).values\n            # MAD (median absolute deviation) as uncertainty indicator\n            abs_dev = torch.abs(transformed_arrays - aligned_images)\n            mad = torch.median(abs_dev, dim=0).values\n            uncertainty = mad\n        else:\n            raise ValueError(f\"Unknown reduction method: {reduction}\")\n\n        # Scale and shift to [0, 1]\n        _min = torch.min(aligned_images)\n        _max = torch.max(aligned_images)\n        aligned_images = (aligned_images - _min) / (_max - _min)\n        uncertainty /= _max - _min\n\n        return aligned_images, uncertainty\n"
  },
  {
    "path": "examples/community/masked_stable_diffusion_img2img.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\n\nfrom diffusers import StableDiffusionImg2ImgPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\n\n\nclass MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):\n    debug_save = False\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        mask: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image` or tensor representing an image batch to be used as the starting point. Can also accept image\n                latents as `image`, but if passing latents directly it is not encoded again.\n            strength (`float`, *optional*, defaults to 0.8):\n                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference. This parameter is modulated by `strength`.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale < 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that calls every `callback_steps` steps during inference. 
The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            mask (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):\n                A mask with non-zero elements for the area to be inpainted. If not specified, no mask is applied.\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n        # code adapted from parent class StableDiffusionImg2ImgPipeline\n\n        # 0. Check inputs. Raise error if not correct\n        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)\n\n        # 1. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 2. Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 3. Preprocess image\n        image = self.image_processor.preprocess(image)\n\n        # 4. set timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 5. 
Prepare latent variables\n        # it is sampled from the latent distribution of the VAE\n        latents = self.prepare_latents(\n            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator\n        )\n\n        # mean of the latent distribution\n        init_latents = [\n            self.vae.encode(image.to(device=device, dtype=prompt_embeds.dtype)[i : i + 1]).latent_dist.mean\n            for i in range(batch_size)\n        ]\n        init_latents = torch.cat(init_latents, dim=0)\n\n        # 6. create latent mask\n        latent_mask = self._make_latent_mask(latents, mask)\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if latent_mask is not 
None:\n                    latents = torch.lerp(init_latents * self.vae.config.scaling_factor, latents, latent_mask)\n                    noise_pred = torch.lerp(torch.zeros_like(noise_pred), noise_pred, latent_mask)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            scaled = latents / self.vae.config.scaling_factor\n            if latent_mask is not None:\n                # scaled = latents / self.vae.config.scaling_factor * latent_mask + init_latents * (1 - latent_mask)\n                scaled = torch.lerp(init_latents, scaled, latent_mask)\n            image = self.vae.decode(scaled, return_dict=False)[0]\n            if self.debug_save:\n                image_gen = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n                image_gen = self.image_processor.postprocess(image_gen, output_type=output_type, do_denormalize=[True])\n                image_gen[0].save(\"from_latent.png\")\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, 
do_denormalize=do_denormalize)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n    def _make_latent_mask(self, latents, mask):\n        # Without a mask, return None so that the caller skips masking\n        latent_mask = None\n        if mask is not None:\n            latent_mask = []\n            if not isinstance(mask, list):\n                tmp_mask = [mask]\n            else:\n                tmp_mask = mask\n            _, l_channels, l_height, l_width = latents.shape\n            for m in tmp_mask:\n                if not isinstance(m, PIL.Image.Image):\n                    if len(m.shape) == 2:\n                        m = m[..., np.newaxis]\n                    if m.max() > 1:\n                        m = m / 255.0\n                    m = self.image_processor.numpy_to_pil(m)[0]\n                if m.mode != \"L\":\n                    m = m.convert(\"L\")\n                resized = self.image_processor.resize(m, l_height, l_width)\n                if self.debug_save:\n                    resized.save(\"latent_mask.png\")\n                latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))\n            latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)\n            latent_mask = latent_mask / latent_mask.max()\n        return latent_mask\n"
  },
  {
    "path": "examples/community/masked_stable_diffusion_xl_img2img.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom PIL import Image, ImageFilter\n\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import (\n    StableDiffusionXLImg2ImgPipeline,\n    rescale_noise_cfg,\n    retrieve_latents,\n    retrieve_timesteps,\n)\nfrom diffusers.utils import (\n    deprecate,\n    is_torch_xla_available,\n    logging,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass MaskedStableDiffusionXLImg2ImgPipeline(StableDiffusionXLImg2ImgPipeline):\n    debug_save = 0\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: PipelineImageInput = None,\n        original_image: PipelineImageInput = None,\n        strength: float = 0.3,\n        num_inference_steps: Optional[int] = 50,\n        timesteps: List[int] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: Optional[float] = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = 
None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        mask: Union[\n            torch.FloatTensor,\n            Image.Image,\n            np.ndarray,\n            List[torch.FloatTensor],\n            List[Image.Image],\n            List[np.ndarray],\n        ] = None,\n        blur=24,\n        blur_compose=4,\n        sample_mode=\"sample\",\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            image (`PipelineImageInput`):\n                `Image` or tensor representing an image batch to be used as the starting point. 
This image might have a mask painted on it.\n            original_image (`PipelineImageInput`, *optional*):\n                `Image` or tensor representing an image batch to be used for blending with the result.\n            strength (`float`, *optional*, defaults to 0.3):\n                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference. This parameter is modulated by `strength`.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. 
Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\n                of a plain tuple.\n            callback (`Callable`, *optional*):\n                A function that is called every `callback_steps` steps during inference. The function is called with\n                the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. 
If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            blur (`int`, *optional*, defaults to 24):\n                Radius of the Gaussian blur applied to the mask.\n            blur_compose (`int`, *optional*, defaults to 4):\n                Radius of the Gaussian blur applied to the mask that is used to composite the decoded result with\n                `original_image`.\n            mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):\n                A mask with non-zero elements for the area to be inpainted. Required when `image` is not passed. When\n                `image` is passed, the mask is computed from the pixels where `image` differs from `original_image`.\n            sample_mode (`str`, *optional*, defaults to `\"sample\"`):\n                Controls latent initialisation for the inpaint area. One of `\"sample\"`, `\"argmax\"` or `\"random\"`.\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] is\n                returned, otherwise a `tuple` is returned where the first element is a list with the generated images.\n        \"\"\"\n        # code adapted from parent class StableDiffusionXLImg2ImgPipeline\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`\",\n            )\n      
  if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`\",\n            )\n\n        # 0. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            strength,\n            num_inference_steps,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._denoising_start = denoising_start\n        self._interrupt = False\n\n        # 1. 
Define call parameters\n        # mask is computed from difference between image and original_image\n        if image is not None:\n            neq = np.any(np.array(original_image) != np.array(image), axis=-1)\n            mask = neq.astype(np.uint8) * 255\n        else:\n            assert mask is not None\n\n        # blur_mask needs a single-channel PIL image, whether `mask` arrived as an array or as a PIL image\n        pil_mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)\n        if pil_mask.mode != \"L\":\n            pil_mask = pil_mask.convert(\"L\")\n        mask_blur = self.blur_mask(pil_mask, blur)\n        mask_compose = self.blur_mask(pil_mask, blur_compose)\n        if original_image is None:\n            original_image = image\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 2. Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            
clip_skip=self.clip_skip,\n        )\n\n        # 3. Preprocess image\n        input_image = image if image is not None else original_image\n        image = self.image_processor.preprocess(input_image)\n        original_image = self.image_processor.preprocess(original_image)\n\n        # 4. set timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps,\n            strength,\n            device,\n            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,\n        )\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        add_noise = True if self.denoising_start is None else False\n\n        # 5. Prepare latent variables\n        # It is sampled from the latent distribution of the VAE\n        # that's what we repaint\n        latents = self.prepare_latents(\n            image,\n            latent_timestep,\n            batch_size,\n            num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            add_noise,\n            sample_mode=sample_mode,\n        )\n\n        # mean of the latent distribution\n        # it is multiplied by self.vae.config.scaling_factor\n        non_paint_latents = self.prepare_latents(\n            original_image,\n            latent_timestep,\n            batch_size,\n            num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            add_noise=False,\n            sample_mode=\"argmax\",\n        )\n\n        if self.debug_save:\n            init_img_from_latents = self.latents_to_img(non_paint_latents)\n            init_img_from_latents[0].save(\"non_paint_latents.png\")\n    
    # 6. create latent mask\n        latent_mask = self._make_latent_mask(latents, mask)\n\n        # 7. Prepare extra step kwargs.\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        height, width = latents.shape[-2:]\n        height = height * self.vae_scale_factor\n        width = width * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 8. Prepare added time ids & embeddings\n        if negative_original_size is None:\n            negative_original_size = original_size\n        if negative_target_size is None:\n            negative_target_size = target_size\n\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            negative_original_size,\n            negative_crops_coords_top_left,\n            negative_target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        
add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device)\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 10. Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 10.1 Apply denoising_end\n        if (\n            self.denoising_end is not None\n            and self.denoising_start is not None\n            and denoising_value_valid(self.denoising_end)\n            and denoising_value_valid(self.denoising_start)\n            and self.denoising_start >= self.denoising_end\n        ):\n            raise ValueError(\n                f\"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: \"\n                + f\" {self.denoising_end} when using type float.\"\n            )\n        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 10.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = 
self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                shape = non_paint_latents.shape\n                noise = randn_tensor(shape, generator=generator, device=device, dtype=latents.dtype)\n                # noisy latent code of input image at current step\n                orig_latents_t = self.scheduler.add_noise(non_paint_latents, noise, t.unsqueeze(0))\n\n                # orig_latents_t * (1 - latent_mask) + latents * latent_mask\n                latents = torch.lerp(orig_latents_t, latents, latent_mask)\n\n                if self.debug_save:\n                    img1 = self.latents_to_img(latents)\n                    # zero-pad the timestep to three digits without reusing the loop index `i`\n                    t_str = str(t.int().item()).zfill(3)\n                    img1[0].save(f\"step{t_str}.png\")\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                    added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    
encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                    add_neg_time_ids = callback_outputs.pop(\"add_neg_time_ids\", add_neg_time_ids)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                
self.upcast_vae()\n            elif latents.dtype != self.vae.dtype:\n                if torch.backends.mps.is_available():\n                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                    self.vae = self.vae.to(latents.dtype)\n\n            if self.debug_save:\n                image_gen = self.latents_to_img(latents)\n                image_gen[0].save(\"from_latent.png\")\n\n            if latent_mask is not None:\n                # interpolate with latent mask\n                latents = torch.lerp(non_paint_latents, latents, latent_mask)\n\n            latents = self.denormalize(latents)\n            image = self.vae.decode(latents, return_dict=False)[0]\n            m = mask_compose.permute(2, 0, 1).unsqueeze(0).to(image)\n            img_compose = m * image + (1 - m) * original_image.to(image)\n            image = img_compose\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        # apply watermark if available\n        if self.watermark is not None:\n            image = self.watermark.apply_watermark(image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n\n    def _make_latent_mask(self, latents, mask):\n        if mask is not None:\n            latent_mask = []\n            if not isinstance(mask, list):\n                tmp_mask = [mask]\n            else:\n                tmp_mask = mask\n            _, l_channels, l_height, l_width = latents.shape\n            for m in tmp_mask:\n                if not isinstance(m, Image.Image):\n                    if len(m.shape) == 2:\n                        m = m[..., 
np.newaxis]\n                    if m.max() > 1:\n                        m = m / 255.0\n                    m = self.image_processor.numpy_to_pil(m)[0]\n                if m.mode != \"L\":\n                    m = m.convert(\"L\")\n                resized = self.image_processor.resize(m, l_height, l_width)\n                if self.debug_save:\n                    resized.save(\"latent_mask.png\")\n                latent_mask.append(np.repeat(np.array(resized)[np.newaxis, :, :], l_channels, axis=0))\n            latent_mask = torch.as_tensor(np.stack(latent_mask)).to(latents)\n            latent_mask = latent_mask / max(latent_mask.max(), 1)\n        return latent_mask\n\n    def prepare_latents(\n        self,\n        image,\n        timestep,\n        batch_size,\n        num_images_per_prompt,\n        dtype,\n        device,\n        generator=None,\n        add_noise=True,\n        sample_mode: str = \"sample\",\n    ):\n        if not isinstance(image, (torch.Tensor, Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        # Offload text encoder if `enable_model_cpu_offload` was enabled\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.text_encoder_2.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n        elif sample_mode == \"random\":\n            height, width = image.shape[-2:]\n            num_channels_latents = self.unet.config.in_channels\n            latents = self.random_latents(\n                batch_size,\n                num_channels_latents,\n                height,\n                width,\n                dtype,\n                device,\n                generator,\n        
    )\n            return self.vae.config.scaling_factor * latents\n        else:\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            if self.vae.config.force_upcast:\n                image = image.float()\n                self.vae.to(dtype=torch.float32)\n\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                init_latents = [\n                    retrieve_latents(\n                        self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode\n                    )\n                    for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode=sample_mode)\n\n            if self.vae.config.force_upcast:\n                self.vae.to(dtype)\n\n            init_latents = init_latents.to(dtype)\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            
init_latents = torch.cat([init_latents], dim=0)\n\n        if add_noise:\n            shape = init_latents.shape\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # get latents\n            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n\n        latents = init_latents\n\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def random_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def denormalize(self, latents):\n        # unscale/denormalize the latents\n        # denormalize with the mean and std if available and not None\n        has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n        has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n        if has_latents_mean and has_latents_std:\n            latents_mean = (\n                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n            )\n            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n        else:\n            latents = latents / self.vae.config.scaling_factor\n\n        return latents\n\n    def latents_to_img(self, latents):\n        l1 = self.denormalize(latents)\n        img1 = self.vae.decode(l1, return_dict=False)[0]\n        img1 = self.image_processor.postprocess(img1, output_type=\"pil\", do_denormalize=[True])\n        return img1\n\n    def blur_mask(self, pil_mask, blur):\n        mask_blur = pil_mask.filter(ImageFilter.GaussianBlur(radius=blur))\n        mask_blur = np.array(mask_blur)\n        return torch.from_numpy(np.tile(mask_blur / mask_blur.max(), (3, 1, 1)).transpose(1, 2, 0))\n"
  },
  {
    "path": "examples/community/matryoshka.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# Based on [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111).\n# Authors: Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly\n# Code: https://github.com/apple/ml-mdm with MIT license\n#\n# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).\n\n\nimport gc\nimport inspect\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom packaging import version\nfrom PIL import Image\nfrom torch import nn\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.configuration_utils import ConfigMixin, FrozenDict, LegacyConfigMixin, register_to_config\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    PeftAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n    UNet2DConditionLoadersMixin,\n)\nfrom diffusers.loaders.single_file_model import FromOriginalModelMixin\nfrom diffusers.models.activations import GELU, get_activation\nfrom 
diffusers.models.attention_processor import (\n    ADDED_KV_ATTENTION_PROCESSORS,\n    CROSS_ATTENTION_PROCESSORS,\n    Attention,\n    AttentionProcessor,\n    AttnAddedKVProcessor,\n    AttnProcessor,\n    FusedAttnProcessor2_0,\n)\nfrom diffusers.models.downsampling import Downsample2D\nfrom diffusers.models.embeddings import (\n    GaussianFourierProjection,\n    GLIGENTextBoundingboxProjection,\n    ImageHintTimeEmbedding,\n    ImageProjection,\n    ImageTimeEmbedding,\n    TextImageProjection,\n    TextImageTimeEmbedding,\n    TextTimeEmbedding,\n    TimestepEmbedding,\n    Timesteps,\n)\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.models.modeling_utils import LegacyModelMixin, ModelMixin\nfrom diffusers.models.resnet import ResnetBlock2D\nfrom diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D\nfrom diffusers.models.upsampling import Upsample2D\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    BaseOutput,\n    deprecate,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import apply_freeu, randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm  # type: ignore\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> from diffusers import DiffusionPipeline\n        >>> from diffusers.utils import make_image_grid\n\n        >>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64\n        >>> pipe = DiffusionPipeline.from_pretrained(\"tolgacangoz/matryoshka-diffusion-models\",\n        ...          
                               nesting_level=0,\n        ...                                         trust_remote_code=False,  # One needs to give permission for this code to run\n        ...                                         ).to(\"cuda\")\n\n        >>> prompt0 = \"a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree\"\n        >>> prompt = f\"breathtaking {prompt0}. award-winning, professional, highly detailed\"\n        >>> image = pipe(prompt, num_inference_steps=50).images\n        >>> make_image_grid(image, rows=1, cols=len(image))\n\n        >>> # pipe.change_nesting_level(<int>)  # 0, 1, or 2\n        >>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\n# Copied from diffusers.models.attention._chunked_feed_forward\ndef _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):\n    # \"feed_forward_chunk_size\" can be used to save memory\n    if hidden_states.shape[chunk_dim] % chunk_size != 0:\n        raise ValueError(\n            f\"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`.\"\n        )\n\n    num_chunks = hidden_states.shape[chunk_dim] // chunk_size\n    ff_output = torch.cat(\n        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],\n        dim=chunk_dim,\n    )\n    return ff_output\n\n\n@dataclass\nclass MatryoshkaDDIMSchedulerOutput(BaseOutput):\n    \"\"\"\n    Output class for the scheduler's `step` function output.\n\n    Args:\n        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):\n            Computed sample `(x_{t-1})` of previous timestep. 
`prev_sample` should be used as next model input in the\n            denoising loop.\n        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):\n            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.\n            `pred_original_sample` can be used to preview progress or for guidance.\n    \"\"\"\n\n    prev_sample: Union[torch.Tensor, List[torch.Tensor]]\n    pred_original_sample: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None\n\n\n# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar\ndef betas_for_alpha_bar(\n    num_diffusion_timesteps,\n    max_beta=0.999,\n    alpha_transform_type=\"cosine\",\n):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of\n    (1-beta) over time from t = [0,1].\n\n    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up\n    to that part of the diffusion process.\n\n\n    Args:\n        num_diffusion_timesteps (`int`): the number of betas to produce.\n        max_beta (`float`): the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.\n                     Choose from `cosine` or `exp`\n\n    Returns:\n        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs\n    \"\"\"\n    if alpha_transform_type == \"cosine\":\n\n        def alpha_bar_fn(t):\n            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2\n\n    elif alpha_transform_type == \"exp\":\n\n        def alpha_bar_fn(t):\n            return math.exp(t * -12.0)\n\n    else:\n        raise ValueError(f\"Unsupported alpha_transform_type: {alpha_transform_type}\")\n\n    betas = []\n    for i in 
range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))\n    return torch.tensor(betas, dtype=torch.float32)\n\n\n# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr\ndef rescale_zero_terminal_snr(betas):\n    \"\"\"\n    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)\n\n\n    Args:\n        betas (`torch.Tensor`):\n            the betas that the scheduler is being initialized with.\n\n    Returns:\n        `torch.Tensor`: rescaled betas with zero terminal SNR\n    \"\"\"\n    # Convert betas to alphas_bar_sqrt\n    alphas = 1.0 - betas\n    alphas_cumprod = torch.cumprod(alphas, dim=0)\n    alphas_bar_sqrt = alphas_cumprod.sqrt()\n\n    # Store old values.\n    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()\n    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()\n\n    # Shift so the last timestep is zero.\n    alphas_bar_sqrt -= alphas_bar_sqrt_T\n\n    # Scale so the first timestep is back to the old value.\n    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)\n\n    # Convert alphas_bar_sqrt to betas\n    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt\n    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod\n    alphas = torch.cat([alphas_bar[0:1], alphas])\n    betas = 1 - alphas\n\n    return betas\n\n\nclass MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"\n    `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with\n    non-Markovian guidance.\n\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. 
Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        beta_start (`float`, defaults to 0.0001):\n            The starting `beta` value of inference.\n        beta_end (`float`, defaults to 0.02):\n            The final `beta` value.\n        beta_schedule (`str`, defaults to `\"linear\"`):\n            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from\n            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.\n        trained_betas (`np.ndarray`, *optional*):\n            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.\n        clip_sample (`bool`, defaults to `True`):\n            Clip the predicted sample for numerical stability.\n        clip_sample_range (`float`, defaults to 1.0):\n            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.\n        set_alpha_to_one (`bool`, defaults to `True`):\n            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step\n            there is no previous alpha. 
When this option is `True` the previous alpha product is fixed to `1`,\n            otherwise it uses the alpha value at step 0.\n        steps_offset (`int`, defaults to 0):\n            An offset added to the inference steps, as required by some model families.\n        prediction_type (`str`, defaults to `epsilon`, *optional*):\n            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),\n            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen\n            Video](https://imagen.research.google/video/paper.pdf) paper).\n        thresholding (`bool`, defaults to `False`):\n            Whether to use the \"dynamic thresholding\" method. This is unsuitable for latent-space diffusion models such\n            as Stable Diffusion.\n        dynamic_thresholding_ratio (`float`, defaults to 0.995):\n            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.\n        sample_max_value (`float`, defaults to 1.0):\n            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.\n        timestep_spacing (`str`, defaults to `\"leading\"`):\n            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and\n            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.\n        rescale_betas_zero_snr (`bool`, defaults to `False`):\n            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and\n            dark samples instead of limiting it to samples with medium brightness. 
Loosely related to\n            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).\n    \"\"\"\n\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        beta_start: float = 0.0001,\n        beta_end: float = 0.02,\n        beta_schedule: str = \"linear\",\n        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,\n        clip_sample: bool = True,\n        set_alpha_to_one: bool = True,\n        steps_offset: int = 0,\n        prediction_type: str = \"epsilon\",\n        thresholding: bool = False,\n        dynamic_thresholding_ratio: float = 0.995,\n        clip_sample_range: float = 1.0,\n        sample_max_value: float = 1.0,\n        timestep_spacing: str = \"leading\",\n        rescale_betas_zero_snr: bool = False,\n    ):\n        if trained_betas is not None:\n            self.betas = torch.tensor(trained_betas, dtype=torch.float32)\n        elif beta_schedule == \"linear\":\n            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)\n        elif beta_schedule == \"scaled_linear\":\n            # this schedule is very specific to the latent diffusion model.\n            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2\n        elif beta_schedule == \"squaredcos_cap_v2\":\n            if self.config.timestep_spacing == \"matryoshka_style\":\n                self.betas = torch.cat((torch.tensor([0]), betas_for_alpha_bar(num_train_timesteps)))\n            else:\n                # Glide cosine schedule\n                self.betas = betas_for_alpha_bar(num_train_timesteps)\n        else:\n            raise NotImplementedError(f\"{beta_schedule} is not implemented for {self.__class__}\")\n\n        # Rescale for zero SNR\n        if rescale_betas_zero_snr:\n            
self.betas = rescale_zero_terminal_snr(self.betas)\n\n        self.alphas = 1.0 - self.betas\n        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)\n\n        # At every step in ddim, we are looking into the previous alphas_cumprod\n        # For the final step, there is no previous alphas_cumprod because we are already at 0\n        # `set_alpha_to_one` decides whether we set this parameter simply to one or\n        # whether we use the final alpha of the \"non-previous\" one.\n        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]\n\n        # standard deviation of the initial noise distribution\n        self.init_noise_sigma = 1.0\n\n        # setable values\n        self.num_inference_steps = None\n        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))\n\n        self.scales = None\n        self.schedule_shifted_power = 1.0\n\n    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:\n        \"\"\"\n        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the\n        current timestep.\n\n        Args:\n            sample (`torch.Tensor`):\n                The input sample.\n            timestep (`int`, *optional*):\n                The current timestep in the diffusion chain.\n\n        Returns:\n            `torch.Tensor`:\n                A scaled input sample.\n        \"\"\"\n        return sample\n\n    def _get_variance(self, timestep, prev_timestep):\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)\n\n        return variance\n\n    # Copied from 
diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample\n    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        \"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the\n        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by\n        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing\n        pixels from saturation at each step. We find that dynamic thresholding results in significantly better\n        photorealism as well as better image-text alignment, especially when using very large guidance weights.\"\n\n        https://huggingface.co/papers/2205.11487\n        \"\"\"\n        dtype = sample.dtype\n        batch_size, channels, *remaining_dims = sample.shape\n\n        if dtype not in (torch.float32, torch.float64):\n            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half\n\n        # Flatten sample for doing quantile calculation along each image\n        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))\n\n        abs_sample = sample.abs()  # \"a certain percentile absolute pixel value\"\n\n        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)\n        s = torch.clamp(\n            s, min=1, max=self.config.sample_max_value\n        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]\n        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0\n        sample = torch.clamp(sample, -s, s) / s  # \"we threshold xt0 to the range [-s, s] and then divide by s\"\n\n        sample = sample.reshape(batch_size, channels, *remaining_dims)\n        sample = sample.to(dtype)\n\n        return sample\n\n    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = 
None):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n\n        Args:\n            num_inference_steps (`int`):\n                The number of diffusion steps used when generating samples with a pre-trained model.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        \"\"\"\n\n        if num_inference_steps > self.config.num_train_timesteps:\n            raise ValueError(\n                f\"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:\"\n                f\" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle\"\n                f\" maximal {self.config.num_train_timesteps} timesteps.\"\n            )\n\n        self.num_inference_steps = num_inference_steps\n\n        # \"linspace\", \"leading\", \"trailing\" correspond to the annotations in Table 2 of https://huggingface.co/papers/2305.08891\n        if self.config.timestep_spacing == \"linspace\":\n            timesteps = (\n                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)\n                .round()[::-1]\n                .copy()\n                .astype(np.int64)\n            )\n        elif self.config.timestep_spacing == \"leading\":\n            step_ratio = self.config.num_train_timesteps // self.num_inference_steps\n            # creates integer timesteps by multiplying by ratio\n            # casting to int to avoid issues when num_inference_step is power of 3\n            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)\n            timesteps += self.config.steps_offset\n        elif self.config.timestep_spacing == \"trailing\":\n            step_ratio = self.config.num_train_timesteps / self.num_inference_steps\n            # creates integer timesteps by multiplying by ratio\n            # casting to int to avoid issues when num_inference_step is power of 3\n            timesteps = 
np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)\n            timesteps -= 1\n        elif self.config.timestep_spacing == \"matryoshka_style\":\n            step_ratio = (self.config.num_train_timesteps + 1) / (num_inference_steps + 1)\n            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1].copy().astype(np.int64)\n        else:\n            raise ValueError(\n                f\"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading', 'trailing', or 'matryoshka_style'.\"\n            )\n\n        self.timesteps = torch.from_numpy(timesteps).to(device)\n\n    def get_schedule_shifted(self, alpha_prod, scale_factor=None):\n        if (scale_factor is not None) and (scale_factor > 1):  # rescale noise schedule\n            scale_factor = scale_factor**self.schedule_shifted_power\n            snr = alpha_prod / (1 - alpha_prod)\n            scaled_snr = snr / scale_factor\n            alpha_prod = 1 / (1 + 1 / scaled_snr)\n        return alpha_prod\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timestep: int,\n        sample: torch.Tensor,\n        eta: float = 0.0,\n        use_clipped_model_output: bool = False,\n        generator=None,\n        variance_noise: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[MatryoshkaDDIMSchedulerOutput, Tuple]:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. 
This function propagates the diffusion\n        process from the learned model outputs (most often the predicted noise).\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from the learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            eta (`float`):\n                The weight of the noise added in each diffusion step.\n            use_clipped_model_output (`bool`, defaults to `False`):\n                If `True`, computes \"corrected\" `model_output` from the clipped predicted original sample. Necessary\n                because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no\n                clipping has happened, \"corrected\" `model_output` would coincide with the one provided as input and\n                `use_clipped_model_output` has no effect.\n            generator (`torch.Generator`, *optional*):\n                A random number generator.\n            variance_noise (`torch.Tensor`):\n                Alternative to generating noise with `generator` by directly providing the noise for the variance\n                itself. 
Useful for methods such as [`CycleDiffusion`].\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`MatryoshkaDDIMSchedulerOutput`] or `tuple`.\n\n        Returns:\n            [`MatryoshkaDDIMSchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`MatryoshkaDDIMSchedulerOutput`] is returned, otherwise a\n                tuple is returned where the first element is the sample tensor.\n\n        \"\"\"\n        if self.num_inference_steps is None:\n            raise ValueError(\n                \"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler\"\n            )\n\n        # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502\n        # Ideally, read the DDIM paper for an in-depth understanding\n\n        # Notation (<variable name> -> <name in paper>)\n        # - pred_noise_t -> e_theta(x_t, t)\n        # - pred_original_sample -> f_theta(x_t, t) or x_0\n        # - std_dev_t -> sigma_t\n        # - eta -> η\n        # - pred_sample_direction -> \"direction pointing to x_t\"\n        # - pred_prev_sample -> \"x_t-1\"\n\n        # 1. get previous step value (=t-1)\n        if self.config.timestep_spacing != \"matryoshka_style\":\n            prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps\n        else:\n            prev_timestep = self.timesteps[torch.nonzero(self.timesteps == timestep).item() + 1]\n\n        # 2. 
compute alphas, betas\n        alpha_prod_t = self.alphas_cumprod[timestep]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod\n\n        if self.config.timestep_spacing == \"matryoshka_style\" and len(model_output) > 1:\n            alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, s) for s in self.scales])\n            alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev, s) for s in self.scales])\n\n        beta_prod_t = 1 - alpha_prod_t\n\n        # 3. compute predicted original sample from predicted noise also called\n        # \"predicted x_0\" of formula (12) from https://huggingface.co/papers/2010.02502\n        if self.config.prediction_type == \"epsilon\":\n            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n            pred_epsilon = model_output\n        elif self.config.prediction_type == \"sample\":\n            pred_original_sample = model_output\n            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)\n        elif self.config.prediction_type == \"v_prediction\":\n            if len(model_output) > 1:\n                pred_original_sample = []\n                pred_epsilon = []\n                for m_o, s, a_p_t, b_p_t in zip(model_output, sample, alpha_prod_t, beta_prod_t):\n                    pred_original_sample.append((a_p_t**0.5) * s - (b_p_t**0.5) * m_o)\n                    pred_epsilon.append((a_p_t**0.5) * m_o + (b_p_t**0.5) * s)\n            else:\n                pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n                pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample\n        else:\n            raise ValueError(\n                f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or\"\n                \" 
`v_prediction`\"\n            )\n\n        # 4. Clip or threshold \"predicted x_0\"\n        if self.config.thresholding:\n            if len(model_output) > 1:\n                pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]\n            else:\n                pred_original_sample = self._threshold_sample(pred_original_sample)\n        elif self.config.clip_sample:\n            if len(model_output) > 1:\n                pred_original_sample = [\n                    p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)\n                    for p_o_s in pred_original_sample\n                ]\n            else:\n                pred_original_sample = pred_original_sample.clamp(\n                    -self.config.clip_sample_range, self.config.clip_sample_range\n                )\n\n        # 5. compute variance: \"sigma_t(η)\" -> see formula (16)\n        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)\n        variance = self._get_variance(timestep, prev_timestep)\n        std_dev_t = eta * variance ** (0.5)\n\n        if use_clipped_model_output:\n            # the pred_epsilon is always re-derived from the clipped x_0 in Glide\n            if len(model_output) > 1:\n                pred_epsilon = []\n                for s, a_p_t, p_o_s, b_p_t in zip(sample, alpha_prod_t, pred_original_sample, beta_prod_t):\n                    pred_epsilon.append((s - a_p_t ** (0.5) * p_o_s) / b_p_t ** (0.5))\n            else:\n                pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)\n\n        # 6. 
compute \"direction pointing to x_t\" of formula (12) from https://huggingface.co/papers/2010.02502\n        if len(model_output) > 1:\n            pred_sample_direction = []\n            for p_e, a_p_t_p in zip(pred_epsilon, alpha_prod_t_prev):\n                pred_sample_direction.append((1 - a_p_t_p - std_dev_t**2) ** (0.5) * p_e)\n        else:\n            pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon\n\n        # 7. compute x_t without \"random noise\" of formula (12) from https://huggingface.co/papers/2010.02502\n        if len(model_output) > 1:\n            prev_sample = []\n            for p_o_s, p_s_d, a_p_t_p in zip(pred_original_sample, pred_sample_direction, alpha_prod_t_prev):\n                prev_sample.append(a_p_t_p ** (0.5) * p_o_s + p_s_d)\n        else:\n            prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction\n\n        if eta > 0:\n            if variance_noise is not None and generator is not None:\n                raise ValueError(\n                    \"Cannot pass both generator and variance_noise. 
Please make sure that either `generator` or\"\n                    \" `variance_noise` stays `None`.\"\n                )\n\n            if variance_noise is None:\n                if len(model_output) > 1:\n                    variance_noise = []\n                    for m_o in model_output:\n                        variance_noise.append(\n                            randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype)\n                        )\n                else:\n                    variance_noise = randn_tensor(\n                        model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype\n                    )\n            if len(model_output) > 1:\n                prev_sample = [p_s + std_dev_t * v_n for v_n, p_s in zip(variance_noise, prev_sample)]\n            else:\n                variance = std_dev_t * variance_noise\n\n                prev_sample = prev_sample + variance\n\n        if not return_dict:\n            return (prev_sample,)\n\n        return MatryoshkaDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.IntTensor,\n    ) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples\n        # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement\n        # for the subsequent add_noise calls\n        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)\n        alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)\n        timesteps = timesteps.to(original_samples.device)\n\n        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n        sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n        while 
len(sqrt_alpha_prod.shape) < len(original_samples.shape):\n            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):\n            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise\n        return noisy_samples\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity\n    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as sample\n        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)\n        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)\n        timesteps = timesteps.to(sample.device)\n\n        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n        sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n        while len(sqrt_alpha_prod.shape) < len(sample.shape):\n            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):\n            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample\n        return velocity\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n\n\nclass CrossAttnDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n       
 transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        norm_type: str = \"layer_norm\",\n        num_attention_heads: int = 1,\n        cross_attention_dim: int = 1280,\n        cross_attention_norm: str | None = None,\n        output_scale_factor: float = 1.0,\n        downsample_padding: int = 1,\n        add_downsample: bool = True,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n        attention_type: str = \"default\",\n        attention_pre_only: bool = False,\n        attention_bias: bool = False,\n        use_attention_ffn: bool = True,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * num_layers\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                MatryoshkaTransformer2DModel(\n      
              num_attention_heads,\n                    out_channels // num_attention_heads,\n                    in_channels=out_channels,\n                    num_layers=transformer_layers_per_block[i],\n                    cross_attention_dim=cross_attention_dim,\n                    upcast_attention=upcast_attention,\n                    use_attention_ffn=use_attention_ffn,\n                )\n            )\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        additional_residuals: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:\n        if cross_attention_kwargs is not None:\n            if cross_attention_kwargs.get(\"scale\", None) is not None:\n                logger.warning(\"Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.\")\n\n        output_states = ()\n\n        blocks = list(zip(self.resnets, self.attentions))\n\n        for i, (resnet, attn) in enumerate(blocks):\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n            # apply additional residuals to the output of the last pair of resnet and attention blocks\n            if i == len(blocks) - 1 and additional_residuals is not None:\n                hidden_states = hidden_states + additional_residuals\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass UNetMidBlock2DCrossAttn(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        out_channels: Optional[int] = None,\n        dropout: float = 
0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_groups_out: Optional[int] = None,\n        resnet_pre_norm: bool = True,\n        norm_type: str = \"layer_norm\",\n        num_attention_heads: int = 1,\n        output_scale_factor: float = 1.0,\n        cross_attention_dim: int = 1280,\n        cross_attention_norm: str | None = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        upcast_attention: bool = False,\n        attention_type: str = \"default\",\n        attention_pre_only: bool = False,\n        attention_bias: bool = False,\n        use_attention_ffn: bool = True,\n    ):\n        super().__init__()\n\n        out_channels = out_channels or in_channels\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n\n        # support for variable transformer layers per block\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * num_layers\n\n        resnet_groups_out = resnet_groups_out or resnet_groups\n\n        # there is always at least one resnet\n        resnets = [\n            ResnetBlock2D(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                groups_out=resnet_groups_out,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                
non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n            )\n        ]\n        attentions = []\n\n        for i in range(num_layers):\n            attentions.append(\n                MatryoshkaTransformer2DModel(\n                    num_attention_heads,\n                    out_channels // num_attention_heads,\n                    in_channels=out_channels,\n                    num_layers=transformer_layers_per_block[i],\n                    cross_attention_dim=cross_attention_dim,\n                    upcast_attention=upcast_attention,\n                    use_attention_ffn=use_attention_ffn,\n                )\n            )\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=out_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups_out,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        temb: Optional[torch.Tensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if cross_attention_kwargs is not None:\n            if cross_attention_kwargs.get(\"scale\", None) is not None:\n                
logger.warning(\"Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.\")\n\n        hidden_states = self.resnets[0](hidden_states, temb)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)\n            else:\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                hidden_states = resnet(hidden_states, temb)\n\n        return hidden_states\n\n\nclass CrossAttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        prev_output_channel: int,\n        temb_channels: int,\n        resolution_idx: Optional[int] = None,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        norm_type: str = \"layer_norm\",\n        num_attention_heads: int = 1,\n        cross_attention_dim: int = 1280,\n    
    cross_attention_norm: str | None = None,\n        output_scale_factor: float = 1.0,\n        add_upsample: bool = True,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        only_cross_attention: bool = False,\n        upcast_attention: bool = False,\n        attention_type: str = \"default\",\n        attention_pre_only: bool = False,\n        attention_bias: bool = False,\n        use_attention_ffn: bool = True,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * num_layers\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                MatryoshkaTransformer2DModel(\n                    num_attention_heads,\n                    out_channels // num_attention_heads,\n                    in_channels=out_channels,\n                    num_layers=transformer_layers_per_block[i],\n                    cross_attention_dim=cross_attention_dim,\n                    
upcast_attention=upcast_attention,\n                    use_attention_ffn=use_attention_ffn,\n                )\n            )\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n        self.resolution_idx = resolution_idx\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        res_hidden_states_tuple: Tuple[torch.Tensor, ...],\n        temb: Optional[torch.Tensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        upsample_size: Optional[int] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if cross_attention_kwargs is not None:\n            if cross_attention_kwargs.get(\"scale\", None) is not None:\n                logger.warning(\"Passing `scale` to `cross_attention_kwargs` is deprecated. 
`scale` will be ignored.\")\n\n        is_freeu_enabled = (\n            getattr(self, \"s1\", None)\n            and getattr(self, \"s2\", None)\n            and getattr(self, \"b1\", None)\n            and getattr(self, \"b2\", None)\n        )\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n\n            # FreeU: Only operate on the first two stages\n            if is_freeu_enabled:\n                hidden_states, res_hidden_states = apply_freeu(\n                    self.resolution_idx,\n                    hidden_states,\n                    res_hidden_states,\n                    s1=self.s1,\n                    s2=self.s2,\n                    b1=self.b1,\n                    b2=self.b2,\n                )\n\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    
return_dict=False,\n                )[0]\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, upsample_size)\n\n        return hidden_states\n\n\n@dataclass\nclass MatryoshkaTransformer2DModelOutput(BaseOutput):\n    \"\"\"\n    The output of [`MatryoshkaTransformer2DModel`].\n\n    Args:\n        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`MatryoshkaTransformer2DModel`] is discrete):\n            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability\n            distributions for the unnoised latent pixels.\n    \"\"\"\n\n    sample: \"torch.Tensor\"  # noqa: F821\n\n\nclass MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"MatryoshkaTransformerBlock\"]\n\n    @register_to_config\n    def __init__(\n        self,\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 88,\n        in_channels: Optional[int] = None,\n        num_layers: int = 1,\n        cross_attention_dim: Optional[int] = None,\n        upcast_attention: bool = False,\n        use_attention_ffn: bool = True,\n    ):\n        super().__init__()\n        self.in_channels = self.config.num_attention_heads * self.config.attention_head_dim\n        self.gradient_checkpointing = False\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                MatryoshkaTransformerBlock(\n                    self.in_channels,\n                    self.config.num_attention_heads,\n                    self.config.attention_head_dim,\n                    cross_attention_dim=self.config.cross_attention_dim,\n                    upcast_attention=self.config.upcast_attention,\n                    use_attention_ffn=self.config.use_attention_ffn,\n        
        )\n                for _ in range(self.config.num_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        added_cond_kwargs: Dict[str, torch.Tensor] = None,\n        class_labels: Optional[torch.LongTensor] = None,\n        cross_attention_kwargs: Dict[str, Any] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ):\n        \"\"\"\n        The [`MatryoshkaTransformer2DModel`] forward method.\n\n        Args:\n            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous):\n                Input `hidden_states`.\n            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):\n                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to\n                self-attention.\n            timestep ( `torch.LongTensor`, *optional*):\n                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.\n            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):\n                Used to indicate class labels conditioning. 
Optional class labels to be applied as an embedding in\n                `AdaLayerZeroNorm`.\n            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            attention_mask ( `torch.Tensor`, *optional*):\n                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask\n                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large\n                negative values to the attention scores corresponding to \"discard\" tokens.\n            encoder_attention_mask ( `torch.Tensor`, *optional*):\n                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:\n\n                    * Mask `(batch, sequence_length)` True = keep, False = discard.\n                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.\n\n                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format\n                above. 
This bias will be added to the cross-attention scores.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~MatryoshkaTransformer2DModelOutput`] instead of a plain\n                tuple.\n\n        Returns:\n            If `return_dict` is True, a [`~MatryoshkaTransformer2DModelOutput`] is returned,\n            otherwise a `tuple` where the first element is the sample tensor.\n        \"\"\"\n        if cross_attention_kwargs is not None:\n            if cross_attention_kwargs.get(\"scale\", None) is not None:\n                logger.warning(\"Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.\")\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.\n        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.\n        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn)\n        if attention_mask is not None and attention_mask.ndim == 2:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # Blocks\n        for block in self.transformer_blocks:\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                hidden_states = self._gradient_checkpointing_func(\n                    block,\n                    hidden_states,\n                    attention_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    timestep,\n                    cross_attention_kwargs,\n                    class_labels,\n                )\n            else:\n                hidden_states = block(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    encoder_hidden_states=encoder_hidden_states,\n                    encoder_attention_mask=encoder_attention_mask,\n                    timestep=timestep,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    class_labels=class_labels,\n                )\n\n        # Output\n        output = hidden_states\n\n        if not return_dict:\n            return (output,)\n\n        return MatryoshkaTransformer2DModelOutput(sample=output)\n\n\nclass 
MatryoshkaTransformerBlock(nn.Module):\n    r\"\"\"\n    Matryoshka Transformer block.\n\n    Parameters:\n        dim (`int`):\n            The number of channels in the input and output.\n        num_attention_heads (`int`):\n            The number of heads to use for multi-head attention.\n        attention_head_dim (`int`):\n            The number of channels in each head.\n        cross_attention_dim (`int`, *optional*):\n            The size of the `encoder_hidden_states` vector for cross-attention. The cross-attention layer is only\n            created when this is set to a positive value.\n        upcast_attention (`bool`, *optional*, defaults to `False`):\n            Whether to upcast the attention computation to `float32`.\n        use_attention_ffn (`bool`, *optional*, defaults to `True`):\n            Whether to apply a `MatryoshkaFeedForward` layer after the attention layers.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        cross_attention_dim: Optional[int] = None,\n        upcast_attention: bool = False,\n        use_attention_ffn: bool = True,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.num_attention_heads = num_attention_heads\n        self.attention_head_dim = attention_head_dim\n        self.cross_attention_dim = cross_attention_dim\n\n        # Define 3 blocks.\n        # 1. Self-Attn\n        self.attn1 = Attention(\n            query_dim=dim,\n            cross_attention_dim=None,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            norm_num_groups=32,\n            bias=True,\n            upcast_attention=upcast_attention,\n            pre_only=True,\n            processor=MatryoshkaFusedAttnProcessor2_0(),\n        )\n        self.attn1.fuse_projections()\n        del self.attn1.to_q\n        del self.attn1.to_k\n        del self.attn1.to_v\n\n        # 2. Cross-Attn\n        if cross_attention_dim is not None and cross_attention_dim > 0:\n            self.attn2 = Attention(\n                query_dim=dim,\n                cross_attention_dim=cross_attention_dim,\n                cross_attention_norm=\"layer_norm\",\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                bias=True,\n                upcast_attention=upcast_attention,\n                pre_only=True,\n                processor=MatryoshkaFusedAttnProcessor2_0(),\n            )\n            self.attn2.fuse_projections()\n            del self.attn2.to_q\n            del self.attn2.to_k\n            del self.attn2.to_v\n\n        self.proj_out = nn.Linear(dim, dim)\n\n        if use_attention_ffn:\n            # 3. 
Feed-forward\n            self.ff = MatryoshkaFeedForward(dim)\n        else:\n            self.ff = None\n\n        # let chunk size default to None\n        self._chunk_size = None\n        self._chunk_dim = 0\n\n    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward\n    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):\n        # Sets chunk feed-forward\n        self._chunk_size = chunk_size\n        self._chunk_dim = dim\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        cross_attention_kwargs: Dict[str, Any] = None,\n        class_labels: Optional[torch.LongTensor] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n    ) -> torch.Tensor:\n        if cross_attention_kwargs is not None:\n            if cross_attention_kwargs.get(\"scale\", None) is not None:\n                logger.warning(\"Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.\")\n\n        # 1. Self-Attention\n        batch_size, channels, *spatial_dims = hidden_states.shape\n\n        attn_output, query = self.attn1(\n            hidden_states,\n            # **cross_attention_kwargs,\n        )\n\n        # 2. 
Cross-Attention\n        if self.cross_attention_dim is not None and self.cross_attention_dim > 0:\n            attn_output_cond = self.attn2(\n                hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=encoder_attention_mask,\n                self_attention_output=attn_output,\n                self_attention_query=query,\n                # **cross_attention_kwargs,\n            )\n\n        attn_output_cond = self.proj_out(attn_output_cond)\n        attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims)\n        hidden_states = hidden_states + attn_output_cond\n\n        if self.ff is not None:\n            # 3. Feed-forward\n            if self._chunk_size is not None:\n                # \"feed_forward_chunk_size\" can be used to save memory\n                ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)\n            else:\n                ff_output = self.ff(hidden_states)\n\n            hidden_states = ff_output + hidden_states\n\n        return hidden_states\n\n\nclass MatryoshkaFusedAttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses\n    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.\n    For cross-attention modules, key and value projection matrices are fused.\n\n    > [!WARNING]\n    > This API is currently 🧪 experimental in nature and can change in future.\n    \"\"\"\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\n                \"MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. 
Please upgrade PyTorch to > 2.x.\"\n            )\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        temb: Optional[torch.Tensor] = None,\n        self_attention_query: Optional[torch.Tensor] = None,\n        self_attention_output: Optional[torch.Tensor] = None,\n        *args,\n        **kwargs,\n    ) -> torch.Tensor:\n        if len(args) > 0 or kwargs.get(\"scale\", None) is not None:\n            deprecation_message = \"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.\"\n            deprecate(\"scale\", \"1.0.0\", deprecation_message)\n\n        residual = hidden_states\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states)\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous()\n\n        if encoder_hidden_states is None:\n            qkv = attn.to_qkv(hidden_states)\n            split_size = qkv.shape[-1] // 3\n            query, key, value = torch.split(qkv, split_size, dim=-1)\n        else:\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n            if self_attention_query is not None:\n                query = self_attention_query\n            else:\n                query = attn.to_q(hidden_states)\n\n            kv = attn.to_kv(encoder_hidden_states)\n       
     split_size = kv.shape[-1] // 2\n            key, value = torch.split(kv, split_size, dim=-1)\n\n        if attn.norm_q is not None:\n            query = attn.norm_q(query)\n        if attn.norm_k is not None:\n            key = attn.norm_k(key)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        if self_attention_output is None:\n            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        if attn.norm_q is not None:\n            query = attn.norm_q(query)\n        if attn.norm_k is not None:\n            key = attn.norm_k(key)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.to(query.dtype)\n\n        if self_attention_output is not None:\n            hidden_states = hidden_states + self_attention_output\n            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states if self_attention_output is not None else (hidden_states, query)\n\n\nclass MatryoshkaFeedForward(nn.Module):\n    r\"\"\"\n    A feed-forward layer for the Matryoshka models.\n\n    Parameters:\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n    ):\n        super().__init__()\n\n        self.group_norm = nn.GroupNorm(32, dim)\n        self.linear_gelu = GELU(dim, dim * 4)\n        self.linear_out = nn.Linear(dim * 4, dim)\n\n    def 
forward(self, x):\n        batch_size, channels, *spatial_dims = x.shape\n        x = self.group_norm(x)\n        x = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        x = self.linear_out(self.linear_gelu(x))\n        x = x.permute(0, 2, 1).view(batch_size, channels, *spatial_dims)\n        return x\n\n\ndef get_down_block(\n    down_block_type: str,\n    num_layers: int,\n    in_channels: int,\n    out_channels: int,\n    temb_channels: int,\n    add_downsample: bool,\n    resnet_eps: float,\n    resnet_act_fn: str,\n    norm_type: str = \"layer_norm\",\n    transformer_layers_per_block: int = 1,\n    num_attention_heads: Optional[int] = None,\n    resnet_groups: Optional[int] = None,\n    cross_attention_dim: Optional[int] = None,\n    downsample_padding: Optional[int] = None,\n    dual_cross_attention: bool = False,\n    use_linear_projection: bool = False,\n    only_cross_attention: bool = False,\n    upcast_attention: bool = False,\n    resnet_time_scale_shift: str = \"default\",\n    attention_type: str = \"default\",\n    attention_pre_only: bool = False,\n    resnet_skip_time_act: bool = False,\n    resnet_out_scale_factor: float = 1.0,\n    cross_attention_norm: str | None = None,\n    attention_head_dim: Optional[int] = None,\n    use_attention_ffn: bool = True,\n    downsample_type: str | None = None,\n    dropout: float = 0.0,\n):\n    # If attn head dim is not defined, we default it to the number of heads\n    if attention_head_dim is None:\n        logger.warning(\n            f\"It is recommended to provide `attention_head_dim` when calling `get_down_block`. 
Defaulting `attention_head_dim` to {num_attention_heads}.\"\n        )\n        attention_head_dim = num_attention_heads\n\n    down_block_type = down_block_type[7:] if down_block_type.startswith(\"UNetRes\") else down_block_type\n    if down_block_type == \"DownBlock2D\":\n        return DownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"CrossAttnDownBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnDownBlock2D\")\n        return CrossAttnDownBlock2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            norm_type=norm_type,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            cross_attention_dim=cross_attention_dim,\n            cross_attention_norm=cross_attention_norm,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            
attention_type=attention_type,\n            attention_pre_only=attention_pre_only,\n            use_attention_ffn=use_attention_ffn,\n        )\n\n\ndef get_mid_block(\n    mid_block_type: str,\n    temb_channels: int,\n    in_channels: int,\n    resnet_eps: float,\n    resnet_act_fn: str,\n    resnet_groups: int,\n    norm_type: str = \"layer_norm\",\n    output_scale_factor: float = 1.0,\n    transformer_layers_per_block: int = 1,\n    num_attention_heads: Optional[int] = None,\n    cross_attention_dim: Optional[int] = None,\n    dual_cross_attention: bool = False,\n    use_linear_projection: bool = False,\n    mid_block_only_cross_attention: bool = False,\n    upcast_attention: bool = False,\n    resnet_time_scale_shift: str = \"default\",\n    attention_type: str = \"default\",\n    attention_pre_only: bool = False,\n    resnet_skip_time_act: bool = False,\n    cross_attention_norm: str | None = None,\n    attention_head_dim: Optional[int] = 1,\n    dropout: float = 0.0,\n):\n    if mid_block_type == \"UNetMidBlock2DCrossAttn\":\n        return UNetMidBlock2DCrossAttn(\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            temb_channels=temb_channels,\n            dropout=dropout,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            norm_type=norm_type,\n            output_scale_factor=output_scale_factor,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            cross_attention_dim=cross_attention_dim,\n            cross_attention_norm=cross_attention_norm,\n            num_attention_heads=num_attention_heads,\n            resnet_groups=resnet_groups,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            upcast_attention=upcast_attention,\n            attention_type=attention_type,\n            attention_pre_only=attention_pre_only,\n        )\n\n\ndef get_up_block(\n 
   up_block_type: str,\n    num_layers: int,\n    in_channels: int,\n    out_channels: int,\n    prev_output_channel: int,\n    temb_channels: int,\n    add_upsample: bool,\n    resnet_eps: float,\n    resnet_act_fn: str,\n    norm_type: str = \"layer_norm\",\n    resolution_idx: Optional[int] = None,\n    transformer_layers_per_block: int = 1,\n    num_attention_heads: Optional[int] = None,\n    resnet_groups: Optional[int] = None,\n    cross_attention_dim: Optional[int] = None,\n    dual_cross_attention: bool = False,\n    use_linear_projection: bool = False,\n    only_cross_attention: bool = False,\n    upcast_attention: bool = False,\n    resnet_time_scale_shift: str = \"default\",\n    attention_type: str = \"default\",\n    attention_pre_only: bool = False,\n    resnet_skip_time_act: bool = False,\n    resnet_out_scale_factor: float = 1.0,\n    cross_attention_norm: str | None = None,\n    attention_head_dim: Optional[int] = None,\n    use_attention_ffn: bool = True,\n    upsample_type: str | None = None,\n    dropout: float = 0.0,\n) -> nn.Module:\n    # If attn head dim is not defined, we default it to the number of heads\n    if attention_head_dim is None:\n        logger.warning(\n            f\"It is recommended to provide `attention_head_dim` when calling `get_up_block`. 
Defaulting `attention_head_dim` to {num_attention_heads}.\"\n        )\n        attention_head_dim = num_attention_heads\n\n    up_block_type = up_block_type[7:] if up_block_type.startswith(\"UNetRes\") else up_block_type\n    if up_block_type == \"UpBlock2D\":\n        return UpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"CrossAttnUpBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnUpBlock2D\")\n        return CrossAttnUpBlock2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resolution_idx=resolution_idx,\n            dropout=dropout,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            norm_type=norm_type,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            cross_attention_norm=cross_attention_norm,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            
resnet_time_scale_shift=resnet_time_scale_shift,\n            attention_type=attention_type,\n            attention_pre_only=attention_pre_only,\n            use_attention_ffn=use_attention_ffn,\n        )\n\n\nclass MatryoshkaCombinedTimestepTextEmbedding(nn.Module):\n    def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, type):\n        super().__init__()\n        if type == \"unet\":\n            self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False)\n        elif type == \"nested_unet\":\n            self.cond_emb = None\n        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0)\n        self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim)\n\n    def forward(self, emb, encoder_hidden_states, added_cond_kwargs):\n        conditioning_mask = added_cond_kwargs.get(\"conditioning_mask\", None)\n        masked_cross_attention = added_cond_kwargs.get(\"masked_cross_attention\", False)\n        if self.cond_emb is not None and not added_cond_kwargs.get(\"from_nested\", False):\n            if conditioning_mask is None:\n                y = encoder_hidden_states.mean(dim=1)\n            else:\n                y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum(\n                    dim=1, keepdim=True\n                )\n            cond_emb = self.cond_emb(y)\n        else:\n            cond_emb = None\n\n        if not masked_cross_attention:\n            conditioning_mask = None\n\n        micro = added_cond_kwargs.get(\"micro_conditioning_scale\", None)\n        if micro is not None:\n            temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))\n            temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))\n            # if self.cond_emb is not None and not added_cond_kwargs.get(\"from_nested\", False):\n            return 
temb_micro_conditioning, conditioning_mask, cond_emb\n\n        return None, conditioning_mask, cond_emb\n\n\n@dataclass\nclass MatryoshkaUNet2DConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`MatryoshkaUNet2DConditionModel`].\n\n    Args:\n        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.\n    \"\"\"\n\n    sample: torch.Tensor = None\n    sample_inner: torch.Tensor = None\n\n\nclass MatryoshkaUNet2DConditionModel(\n    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin\n):\n    r\"\"\"\n    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.\n        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        mid_block_type (`str`, *optional*, defaults to 
`\"UNetMidBlock2DCrossAttn\"`):\n            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or\n            `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")`):\n            The tuple of upsample blocks to use.\n        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):\n            Whether to include self-attention in the basic transformer blocks, see\n            [`~models.attention.BasicTransformerBlock`].\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.\n            If `None`, normalization and activation layers are skipped in post-processing.\n        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling\n            blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for\n            [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.\n        num_attention_heads (`int`, *optional*):\n            The number of attention heads. If not defined, defaults to `attention_head_dim`\n        resnet_time_scale_shift (`str`, *optional*, defaults to `\"default\"`): Time scale shift config\n            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. 
Choose from `None`,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". \"text\" will use the `TextTimeEmbedding` layer.\n        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):\n            Dimension for the timestep embeddings.\n        num_class_embeds (`int`, *optional*, defaults to `None`):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        time_embedding_type (`str`, *optional*, defaults to `positional`):\n            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.\n        time_embedding_dim (`int`, *optional*, defaults to `None`):\n            An optional override for the dimension of the projected time embedding.\n        time_embedding_act_fn (`str`, *optional*, defaults to `None`):\n            Optional activation function to use only once on the time embeddings before they are passed to the rest of\n            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.\n        timestep_post_act (`str`, *optional*, defaults to `None`):\n            The second activation function to use in timestep embedding. 
Choose from `silu`, `mish` and `gelu`.\n        time_cond_proj_dim (`int`, *optional*, defaults to `None`):\n            The dimension of `cond_proj` layer in the timestep embedding.\n        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.\n        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.\n        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when\n            `class_embed_type=\"projection\"`. Required when `class_embed_type=\"projection\"`.\n        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time\n            embeddings with the class embeddings.\n        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):\n            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If\n            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the\n            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`\n            otherwise.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"MatryoshkaTransformerBlock\", \"ResnetBlock2D\", \"CrossAttnUpBlock2D\"]\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str, ...] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: str | None = \"UNetMidBlock2DCrossAttn\",\n        up_block_types: Tuple[str, ...] 
= (\n            \"UpBlock2D\",\n            \"CrossAttnUpBlock2D\",\n            \"CrossAttnUpBlock2D\",\n            \"CrossAttnUpBlock2D\",\n        ),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        dropout: float = 0.0,\n        act_fn: str = \"silu\",\n        norm_type: str = \"layer_norm\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,\n        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: str | None = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_attention_ffn: bool = True,\n        use_linear_projection: bool = False,\n        class_embed_type: str | None = None,\n        addition_embed_type: str | None = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: float = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: str | None = None,\n        timestep_post_act: str | None = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        attention_type: 
str = \"default\",\n        attention_pre_only: bool = False,\n        masked_cross_attention: bool = False,\n        micro_conditioning_scale: int = None,\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: str | None = None,\n        addition_embed_type_num_heads: int = 64,\n        temporal_mode: bool = False,\n        temporal_spatial_ds: bool = False,\n        skip_cond_emb: bool = False,\n        nesting: Optional[int] = False,\n    ):\n        super().__init__()\n\n        self.sample_size = sample_size\n\n        if num_attention_heads is not None:\n            raise ValueError(\n                \"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.\"\n            )\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. 
The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        self._check_config(\n            down_block_types=down_block_types,\n            up_block_types=up_block_types,\n            only_cross_attention=only_cross_attention,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            cross_attention_dim=cross_attention_dim,\n            transformer_layers_per_block=transformer_layers_per_block,\n            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,\n            attention_head_dim=attention_head_dim,\n            num_attention_heads=num_attention_heads,\n        )\n\n        # input\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        time_embed_dim, timestep_input_dim = self._set_time_proj(\n            time_embedding_type,\n            block_out_channels=block_out_channels,\n            flip_sin_to_cos=flip_sin_to_cos,\n            freq_shift=freq_shift,\n            time_embedding_dim=time_embedding_dim,\n        )\n\n        self.time_embedding = TimestepEmbedding(\n            time_embedding_dim // 4 if time_embedding_dim is not None else timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n            post_act_fn=timestep_post_act,\n            cond_proj_dim=time_cond_proj_dim,\n        )\n\n        self._set_encoder_hid_proj(\n            encoder_hid_dim_type,\n            cross_attention_dim=cross_attention_dim,\n            
encoder_hid_dim=encoder_hid_dim,\n        )\n\n        # class embedding\n        self._set_class_embedding(\n            class_embed_type,\n            act_fn=act_fn,\n            num_class_embeds=num_class_embeds,\n            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,\n            time_embed_dim=time_embed_dim,\n            timestep_input_dim=timestep_input_dim,\n        )\n\n        self._set_add_embedding(\n            addition_embed_type,\n            addition_embed_type_num_heads=addition_embed_type_num_heads,\n            addition_time_embed_dim=timestep_input_dim,\n            cross_attention_dim=cross_attention_dim,\n            encoder_hid_dim=encoder_hid_dim,\n            flip_sin_to_cos=flip_sin_to_cos,\n            freq_shift=freq_shift,\n            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,\n            time_embed_dim=time_embed_dim,\n        )\n\n        if time_embedding_act_fn is None:\n            self.time_embed_act = None\n        else:\n            self.time_embed_act = get_activation(time_embedding_act_fn)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            if mid_block_only_cross_attention is None:\n                mid_block_only_cross_attention = only_cross_attention\n\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if mid_block_only_cross_attention is None:\n            mid_block_only_cross_attention = False\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if 
isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        if class_embeddings_concat:\n            # The time embeddings are concatenated with the class embeddings. The dimension of the\n            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the\n            # regular time embeddings\n            blocks_time_embed_dim = time_embed_dim * 2\n        else:\n            blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                norm_type=norm_type,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                downsample_padding=downsample_padding,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                
resnet_time_scale_shift=resnet_time_scale_shift,\n                attention_type=attention_type,\n                attention_pre_only=attention_pre_only,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                use_attention_ffn=use_attention_ffn,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                dropout=dropout,\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        self.mid_block = get_mid_block(\n            mid_block_type,\n            temb_channels=blocks_time_embed_dim,\n            in_channels=block_out_channels[-1],\n            resnet_eps=norm_eps,\n            resnet_act_fn=act_fn,\n            norm_type=norm_type,\n            resnet_groups=norm_num_groups,\n            output_scale_factor=mid_block_scale_factor,\n            transformer_layers_per_block=1,\n            num_attention_heads=num_attention_heads[-1],\n            cross_attention_dim=cross_attention_dim[-1],\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            mid_block_only_cross_attention=mid_block_only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            attention_type=attention_type,\n            attention_pre_only=attention_pre_only,\n            resnet_skip_time_act=resnet_skip_time_act,\n            cross_attention_norm=cross_attention_norm,\n            attention_head_dim=attention_head_dim[-1],\n            dropout=dropout,\n        )\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = 
list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        reversed_transformer_layers_per_block = (\n            list(reversed(transformer_layers_per_block))\n            if reverse_transformer_layers_per_block is None\n            else reverse_transformer_layers_per_block\n        )\n        only_cross_attention = list(reversed(only_cross_attention))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                norm_type=norm_type,\n                resolution_idx=i,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                dual_cross_attention=dual_cross_attention,\n                
use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                attention_type=attention_type,\n                attention_pre_only=attention_pre_only,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                use_attention_ffn=use_attention_ffn,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                dropout=dropout,\n            )\n            self.up_blocks.append(up_block)\n\n        # out\n        if norm_num_groups is not None:\n            self.conv_norm_out = nn.GroupNorm(\n                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps\n            )\n\n            self.conv_act = get_activation(act_fn)\n\n        else:\n            self.conv_norm_out = None\n            self.conv_act = None\n\n        conv_out_padding = (conv_out_kernel - 1) // 2\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding\n        )\n\n        self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)\n\n        self.is_temporal = []\n\n    def _check_config(\n        self,\n        down_block_types: Tuple[str, ...],\n        up_block_types: Tuple[str, ...],\n        only_cross_attention: Union[bool, Tuple[bool]],\n        block_out_channels: Tuple[int, ...],\n        layers_per_block: Union[int, Tuple[int]],\n        cross_attention_dim: Union[int, Tuple[int]],\n        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],\n        reverse_transformer_layers_per_block: bool,\n        attention_head_dim: 
int,\n        num_attention_heads: Optional[Union[int, Tuple[int]]],\n    ):\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}.\"\n            )\n        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:\n            for layer_number_per_block in transformer_layers_per_block:\n                if isinstance(layer_number_per_block, list):\n                    raise ValueError(\"Must provide `reverse_transformer_layers_per_block` if using an asymmetrical UNet.\")\n\n    def _set_time_proj(\n        self,\n        time_embedding_type: str,\n        block_out_channels: Tuple[int, ...],\n        flip_sin_to_cos: bool,\n        freq_shift: float,\n        time_embedding_dim: int,\n    ) -> Tuple[int, int]:\n        if time_embedding_type == \"fourier\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2\n            if time_embed_dim % 2 != 0:\n                raise ValueError(f\"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.\")\n            self.time_proj = GaussianFourierProjection(\n                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos\n            )\n            timestep_input_dim = time_embed_dim\n        elif time_embedding_type == \"positional\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4\n\n            if self.model_type == \"unet\":\n                self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n            elif self.model_type == \"nested_unet\" and self.config.micro_conditioning_scale == 256:\n                self.time_proj = Timesteps(block_out_channels[0] * 4, flip_sin_to_cos, freq_shift)\n            elif self.model_type == 
\"nested_unet\" and self.config.micro_conditioning_scale == 1024:\n                self.time_proj = Timesteps(block_out_channels[0] * 4 * 2, flip_sin_to_cos, freq_shift)\n            timestep_input_dim = block_out_channels[0]\n        else:\n            raise ValueError(\n                f\"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.\"\n            )\n\n        return time_embed_dim, timestep_input_dim\n\n    def _set_encoder_hid_proj(\n        self,\n        encoder_hid_dim_type: str | None,\n        cross_attention_dim: Union[int, Tuple[int]],\n        encoder_hid_dim: Optional[int],\n    ):\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image_proj\"` (Kandinsky 2.1)`\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2\n            self.encoder_hid_proj = ImageProjection(\n                image_embed_dim=encoder_hid_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n    def _set_class_embedding(\n        self,\n        class_embed_type: str | None,\n        act_fn: str,\n        num_class_embeds: Optional[int],\n        projection_class_embeddings_input_dim: Optional[int],\n        time_embed_dim: int,\n        timestep_input_dim: int,\n    ):\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires 
`projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n    def _set_add_embedding(\n        self,\n        addition_embed_type: str,\n        addition_embed_type_num_heads: int,\n        addition_time_embed_dim: Optional[int],\n        flip_sin_to_cos: bool,\n        freq_shift: float,\n        cross_attention_dim: Optional[int],\n        encoder_hid_dim: Optional[int],\n        projection_class_embeddings_input_dim: Optional[int],\n        time_embed_dim: int,\n    ):\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, 
num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"matryoshka\":\n            self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding(\n                self.config.time_embedding_dim // 4\n                if self.config.time_embedding_dim is not None\n                else addition_time_embed_dim,\n                cross_attention_dim,\n                time_embed_dim,\n                self.model_type,  # if not self.config.nesting else \"inner_\" + self.model_type,\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the only current use\n            # case, `addition_embed_type == \"text_image\"` (Kandinsky 2.1)\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif addition_embed_type == \"image\":\n            # Kandinsky 2.2\n            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet\n            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type is not None:\n            raise ValueError(\n                f\"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'matryoshka', 'text_image', 'text_time', 'image', or 'image_hint'.\"\n 
           )\n\n    def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: Union[int, Tuple[int]]):\n        if attention_type in [\"gated\", \"gated-text-image\"]:\n            positive_len = 768\n            if isinstance(cross_attention_dim, int):\n                positive_len = cross_attention_dim\n            elif isinstance(cross_attention_dim, (list, tuple)):\n                positive_len = cross_attention_dim[0]\n\n            feature_type = \"text-only\" if attention_type == \"gated\" else \"text-image\"\n            self.position_net = GLIGENTextBoundingboxProjection(\n                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type\n            )\n\n    @property\n    def attn_processors(self) -> dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model,\n            indexed by weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor()\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of 
processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnAddedKVProcessor()\n        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call `set_default_attn_processor` when attention processors are of type 
{next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor)\n\n    def set_attention_slice(self, slice_size: Union[str, int, List[int]] = \"auto\"):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not 
isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):\n        r\"\"\"Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.\n\n        The suffixes after the scaling factors represent the stage blocks where they are being applied.\n\n        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that\n        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.\n\n        Args:\n            s1 (`float`):\n                Scaling factor for stage 1 to attenuate the contributions of the skip features. 
This is done to\n                mitigate the \"oversmoothing effect\" in the enhanced denoising process.\n            s2 (`float`):\n                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to\n                mitigate the \"oversmoothing effect\" in the enhanced denoising process.\n            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.\n            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.\n        \"\"\"\n        for upsample_block in self.up_blocks:\n            setattr(upsample_block, \"s1\", s1)\n            setattr(upsample_block, \"s2\", s2)\n            setattr(upsample_block, \"b1\", b1)\n            setattr(upsample_block, \"b2\", b2)\n\n    def disable_freeu(self):\n        \"\"\"Disables the FreeU mechanism.\"\"\"\n        freeu_keys = {\"s1\", \"s2\", \"b1\", \"b2\"}\n        for upsample_block in self.up_blocks:\n            for k in freeu_keys:\n                if hasattr(upsample_block, k):\n                    setattr(upsample_block, k, None)\n\n    def fuse_qkv_projections(self):\n        \"\"\"\n        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)\n        are fused. 
For cross-attention modules, key and value projection matrices are fused.\n\n        > [!WARNING]\n        > This API is 🧪 experimental.\n        \"\"\"\n        self.original_attn_processors = None\n\n        for _, attn_processor in self.attn_processors.items():\n            if \"Added\" in str(attn_processor.__class__.__name__):\n                raise ValueError(\"`fuse_qkv_projections()` is not supported for models having added KV projections.\")\n\n        self.original_attn_processors = self.attn_processors\n\n        for module in self.modules():\n            if isinstance(module, Attention):\n                module.fuse_projections(fuse=True)\n\n        self.set_attn_processor(FusedAttnProcessor2_0())\n\n    def unfuse_qkv_projections(self):\n        \"\"\"Disables the fused QKV projection if enabled.\n\n        > [!WARNING]\n        > This API is 🧪 experimental.\n\n        \"\"\"\n        if self.original_attn_processors is not None:\n            self.set_attn_processor(self.original_attn_processors)\n\n    def get_time_embed(\n        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]\n    ) -> Optional[torch.Tensor]:\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            is_npu = sample.device.type == \"npu\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if (is_mps or is_npu) else torch.float64\n            else:\n                dtype = torch.int32 if (is_mps or is_npu) else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n        return t_emb\n\n    def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:\n        class_emb = None\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n        return class_emb\n\n    def get_aug_embed(\n        self, emb: torch.Tensor, 
encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]\n    ) -> Optional[torch.Tensor]:\n        aug_emb = None\n        if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"matryoshka\":\n            aug_emb = self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n       
     add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = added_cond_kwargs.get(\"hint\")\n            aug_emb = self.add_embedding(image_embs, hint)\n        return aug_emb\n\n    def process_encoder_hidden_states(\n        self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]\n    ) -> torch.Tensor:\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param 
`encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"ip_image_proj\":\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            if hasattr(self, \"text_encoder_hid_proj\") and self.text_encoder_hid_proj is not None:\n                encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            image_embeds = self.encoder_hid_proj(image_embeds)\n            encoder_hidden_states = (encoder_hidden_states, image_embeds)\n        return encoder_hidden_states\n\n    @property\n    def model_type(self) -> str:\n        return \"unet\"\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: 
torch.Tensor,\n        cond_emb: Optional[torch.Tensor] = None,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n        from_nested: bool = False,\n    ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`MatryoshkaUNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.Tensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.Tensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            class_labels (`torch.Tensor`, *optional*, defaults to `None`):\n                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.\n            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):\n                Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed\n                through the `self.time_embedding` layer to obtain the timestep embeddings.\n            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):\n                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. 
If `1` the mask\n                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large\n                negative values to the attention scores corresponding to \"discard\" tokens.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):\n                A tuple of tensors that if specified are added to the residuals of down unet blocks.\n            mid_block_additional_residual: (`torch.Tensor`, *optional*):\n                A tensor that if specified is added to the residual of the middle unet block.\n            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):\n                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~MatryoshkaUNet2DConditionOutput`] instead of a plain\n                tuple.\n\n        Returns:\n            [`~MatryoshkaUNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, a [`~MatryoshkaUNet2DConditionOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if self.config.nesting:\n            sample, sample_feat = sample\n        if isinstance(sample, list) and len(sample) == 1:\n            sample = sample[0]\n\n        for dim in sample.shape[-2:]:\n            if dim % default_overall_up_factor != 0:\n                # Forward upsample size to force interpolation output size.\n                forward_upsample_size = True\n                break\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, 
query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        t_emb = self.get_time_embed(sample=sample, timestep=timestep)\n        emb = self.time_embedding(t_emb, timestep_cond)\n\n        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)\n        if class_emb is not None:\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        added_cond_kwargs = added_cond_kwargs or {}\n        added_cond_kwargs[\"masked_cross_attention\"] = self.config.masked_cross_attention\n        added_cond_kwargs[\"micro_conditioning_scale\"] = self.config.micro_conditioning_scale\n        added_cond_kwargs[\"from_nested\"] = from_nested\n        added_cond_kwargs[\"conditioning_mask\"] = encoder_attention_mask\n\n        if not from_nested:\n            encoder_hidden_states = self.process_encoder_hidden_states(\n                encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n            )\n\n            aug_emb, encoder_attention_mask, cond_emb = self.get_aug_embed(\n                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n            )\n        else:\n            aug_emb, encoder_attention_mask, _ = self.get_aug_embed(\n           
     emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n            )\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        if self.config.addition_embed_type == \"image_hint\":\n            aug_emb, hint = aug_emb\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb + cond_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n        if self.config.nesting:\n            sample = sample + sample_feat\n\n        # 2.5 GLIGEN position net\n        if cross_attention_kwargs is not None and cross_attention_kwargs.get(\"gligen\", None) is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            gligen_args = cross_attention_kwargs.pop(\"gligen\")\n            cross_attention_kwargs[\"gligen\"] = {\"objs\": self.position_net(**gligen_args)}\n\n        # 3. down\n        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated\n        # to the internal blocks and will raise deprecation warnings. 
this will be confusing for our users.\n        if cross_attention_kwargs is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            lora_scale = cross_attention_kwargs.pop(\"scale\", 1.0)\n        else:\n            lora_scale = 1.0\n\n        if USE_PEFT_BACKEND:\n            # weight the lora layers by setting `lora_scale` for each PEFT layer\n            scale_lora_layers(self, lora_scale)\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets\n        is_adapter = down_intrablock_additional_residuals is not None\n        # maintain backward compatibility for legacy usage, where\n        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg\n        #       but can only use one or the other\n        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:\n            deprecate(\n                \"T2I should not use down_block_additional_residuals\",\n                \"1.3.0\",\n                \"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \\\n                       and will be removed in diffusers 1.3.0.  `down_block_additional_residuals` should only be used \\\n                       for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. 
\",\n                standard_warn=False,\n            )\n            down_intrablock_additional_residuals = down_block_additional_residuals\n            is_adapter = True\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_intrablock_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    sample += down_intrablock_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. 
mid\n        if self.mid_block is not None:\n            if hasattr(self.mid_block, \"has_cross_attention\") and self.mid_block.has_cross_attention:\n                sample = self.mid_block(\n                    sample,\n                    emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = self.mid_block(sample, emb)\n\n            # To support T2I-Adapter-XL\n            if (\n                is_adapter\n                and len(down_intrablock_additional_residuals) > 0\n                and sample.shape == down_intrablock_additional_residuals[0].shape\n            ):\n                sample += down_intrablock_additional_residuals.pop(0)\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    upsample_size=upsample_size,\n                )\n\n        sample_inner = sample\n\n        # 6. 
post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample_inner)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if USE_PEFT_BACKEND:\n            # remove `lora_scale` from each PEFT layer\n            unscale_lora_layers(self, lora_scale)\n\n        if not return_dict:\n            return (sample,)\n\n        if self.config.nesting:\n            return MatryoshkaUNet2DConditionOutput(sample=sample, sample_inner=sample_inner)\n\n        return MatryoshkaUNet2DConditionOutput(sample=sample)\n\n\nclass NestedUNet2DConditionOutput(BaseOutput):\n    \"\"\"\n    Output type for the [`NestedUNet2DConditionModel`] model.\n    \"\"\"\n\n    sample: list = None\n    sample_inner: torch.Tensor = None\n\n\nclass NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):\n    \"\"\"\n    Nested UNet model with condition for image denoising.\n    \"\"\"\n\n    @register_to_config\n    def __init__(\n        self,\n        in_channels=3,\n        out_channels=3,\n        block_out_channels=(64, 128, 256),\n        cross_attention_dim=2048,\n        resnet_time_scale_shift=\"scale_shift\",\n        down_block_types=(\"DownBlock2D\", \"DownBlock2D\", \"DownBlock2D\"),\n        up_block_types=(\"UpBlock2D\", \"UpBlock2D\", \"UpBlock2D\"),\n        mid_block_type=None,\n        nesting=False,\n        flip_sin_to_cos=False,\n        transformer_layers_per_block=[0, 0, 0],\n        layers_per_block=[2, 2, 1],\n        masked_cross_attention=True,\n        micro_conditioning_scale=256,\n        addition_embed_type=\"matryoshka\",\n        skip_normalization=True,\n        time_embedding_dim=1024,\n        skip_inner_unet_input=False,\n        temporal_mode=False,\n        temporal_spatial_ds=False,\n        initialize_inner_with_pretrained=None,\n        use_attention_ffn=False,\n        act_fn=\"silu\",\n        addition_embed_type_num_heads=64,\n        addition_time_embed_dim=None,\n        
attention_head_dim=8,\n        attention_pre_only=False,\n        attention_type=\"default\",\n        center_input_sample=False,\n        class_embed_type=None,\n        class_embeddings_concat=False,\n        conv_in_kernel=3,\n        conv_out_kernel=3,\n        cross_attention_norm=None,\n        downsample_padding=1,\n        dropout=0.0,\n        dual_cross_attention=False,\n        encoder_hid_dim=None,\n        encoder_hid_dim_type=None,\n        freq_shift=0,\n        mid_block_only_cross_attention=None,\n        mid_block_scale_factor=1,\n        norm_eps=1e-05,\n        norm_num_groups=32,\n        norm_type=\"layer_norm\",\n        num_attention_heads=None,\n        num_class_embeds=None,\n        only_cross_attention=False,\n        projection_class_embeddings_input_dim=None,\n        resnet_out_scale_factor=1.0,\n        resnet_skip_time_act=False,\n        reverse_transformer_layers_per_block=None,\n        sample_size=None,\n        skip_cond_emb=False,\n        time_cond_proj_dim=None,\n        time_embedding_act_fn=None,\n        time_embedding_type=\"positional\",\n        timestep_post_act=None,\n        upcast_attention=False,\n        use_linear_projection=False,\n        is_temporal=None,\n        inner_config={},\n    ):\n        super().__init__(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            block_out_channels=block_out_channels,\n            cross_attention_dim=cross_attention_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            down_block_types=down_block_types,\n            up_block_types=up_block_types,\n            mid_block_type=mid_block_type,\n            nesting=nesting,\n            flip_sin_to_cos=flip_sin_to_cos,\n            transformer_layers_per_block=transformer_layers_per_block,\n            layers_per_block=layers_per_block,\n            masked_cross_attention=masked_cross_attention,\n            micro_conditioning_scale=micro_conditioning_scale,\n     
       addition_embed_type=addition_embed_type,\n            time_embedding_dim=time_embedding_dim,\n            temporal_mode=temporal_mode,\n            temporal_spatial_ds=temporal_spatial_ds,\n            use_attention_ffn=use_attention_ffn,\n            sample_size=sample_size,\n        )\n        # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim\n\n        if \"inner_config\" not in self.config.inner_config:\n            self.inner_unet = MatryoshkaUNet2DConditionModel(**self.config.inner_config)\n        else:\n            self.inner_unet = NestedUNet2DConditionModel(**self.config.inner_config)\n\n        if not self.config.skip_inner_unet_input:\n            self.in_adapter = nn.Conv2d(\n                self.config.block_out_channels[-1],\n                self.config.inner_config[\"block_out_channels\"][0],\n                kernel_size=3,\n                padding=1,\n            )\n        else:\n            self.in_adapter = None\n        self.out_adapter = nn.Conv2d(\n            self.config.inner_config[\"block_out_channels\"][0],\n            self.config.block_out_channels[-1],\n            kernel_size=3,\n            padding=1,\n        )\n\n        self.is_temporal = [self.config.temporal_mode and (not self.config.temporal_spatial_ds)]\n        if hasattr(self.inner_unet, \"is_temporal\"):\n            self.is_temporal = self.is_temporal + self.inner_unet.is_temporal\n\n        nest_ratio = int(2 ** (len(self.config.block_out_channels) - 1))\n        if self.is_temporal[0]:\n            nest_ratio = int(np.sqrt(nest_ratio))\n        if self.inner_unet.config.nesting and self.inner_unet.model_type == \"nested_unet\":\n            self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio\n        else:\n            self.nest_ratio = [nest_ratio]\n\n        # self.register_modules(inner_unet=self.inner_unet)\n\n    @property\n    def model_type(self):\n        return 
\"nested_unet\"\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        cond_emb: Optional[torch.Tensor] = None,\n        from_nested: bool = False,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]:\n        r\"\"\"\n        The [`NestedUNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.Tensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.Tensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            class_labels (`torch.Tensor`, *optional*, defaults to `None`):\n                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.\n            timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):\n                Conditional embeddings for timestep. 
If provided, the embeddings will be summed with the samples passed\n                through the `self.time_embedding` layer to obtain the timestep embeddings.\n            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):\n                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask\n                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large\n                negative values to the attention scores corresponding to \"discard\" tokens.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):\n                A tuple of tensors that if specified are added to the residuals of down unet blocks.\n            mid_block_additional_residual: (`torch.Tensor`, *optional*):\n                A tensor that if specified is added to the residual of the middle unet block.\n            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):\n                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain\n                tuple.\n\n        Returns:\n            [`~NestedUNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if self.config.nesting:\n            sample, sample_feat = sample\n        if isinstance(sample, list) and len(sample) == 1:\n            sample = sample[0]\n\n        # 2. 
input layer (normalize the input)\n        bsz = [x.size(0) for x in sample]\n        bh, bl = bsz[0], bsz[1]\n        x_t_low, sample = sample[1:], sample[0]\n\n        for dim in sample.shape[-2:]:\n            if dim % default_overall_up_factor != 0:\n                # Forward upsample size to force interpolation output size.\n                forward_upsample_size = True\n                break\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. 
time\n        t_emb = self.get_time_embed(sample=sample, timestep=timestep)\n        emb = self.time_embedding(t_emb, timestep_cond)\n\n        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)\n        if class_emb is not None:\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        if self.inner_unet.model_type == \"unet\":\n            added_cond_kwargs = added_cond_kwargs or {}\n            added_cond_kwargs[\"masked_cross_attention\"] = self.inner_unet.config.masked_cross_attention\n            added_cond_kwargs[\"micro_conditioning_scale\"] = self.config.micro_conditioning_scale\n            added_cond_kwargs[\"conditioning_mask\"] = encoder_attention_mask\n\n            if not self.config.nesting:\n                encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(\n                    encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n                )\n\n                aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.get_aug_embed(\n                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n                )\n                added_cond_kwargs[\"masked_cross_attention\"] = self.config.masked_cross_attention\n                aug_emb, __, _ = self.get_aug_embed(\n                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n                )\n            else:\n                aug_emb, cond_mask, _ = self.get_aug_embed(\n                    emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n                )\n\n        elif self.inner_unet.model_type == \"nested_unet\":\n            added_cond_kwargs = added_cond_kwargs or {}\n            added_cond_kwargs[\"masked_cross_attention\"] = 
self.inner_unet.inner_unet.config.masked_cross_attention\n            added_cond_kwargs[\"micro_conditioning_scale\"] = self.config.micro_conditioning_scale\n            added_cond_kwargs[\"conditioning_mask\"] = encoder_attention_mask\n\n            encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(\n                encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n            )\n\n            aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed(\n                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n            )\n\n            aug_emb, __, _ = self.get_aug_embed(\n                emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n            )\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        if self.config.addition_embed_type == \"image_hint\":\n            aug_emb, hint = aug_emb\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb + cond_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if not self.config.skip_normalization:\n            sample = sample / sample.std((1, 2, 3), keepdims=True)\n        if isinstance(sample, list) and len(sample) == 1:\n            sample = sample[0]\n        sample = self.conv_in(sample)\n        if self.config.nesting:\n            sample = sample + sample_feat\n\n        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated\n        # to the internal blocks and will raise deprecation warnings. 
This will be confusing for our users.\n        if cross_attention_kwargs is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            lora_scale = cross_attention_kwargs.pop(\"scale\", 1.0)\n        else:\n            lora_scale = 1.0\n\n        if USE_PEFT_BACKEND:\n            # weight the lora layers by setting `lora_scale` for each PEFT layer\n            scale_lora_layers(self, lora_scale)\n\n        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets\n        is_adapter = down_intrablock_additional_residuals is not None\n        # maintain backward compatibility for legacy usage, where\n        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg\n        #       but can only use one or the other\n        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:\n            deprecate(\n                \"T2I should not use down_block_additional_residuals\",\n                \"1.3.0\",\n                \"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \\\n                       and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \\\n                       for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. \",\n                standard_warn=False,\n            )\n            down_intrablock_additional_residuals = down_block_additional_residuals\n            is_adapter = True\n\n        # 3. 
downsample blocks in the outer layers\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_intrablock_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb[:bh],\n                    encoder_hidden_states=encoder_hidden_states[:bh],\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    sample += down_intrablock_additional_residuals.pop(0)\n\n            down_block_res_samples += res_samples\n\n        # 4. 
run inner unet\n        x_inner = self.in_adapter(sample) if self.in_adapter is not None else None\n        x_inner = (\n            torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner\n        )  # pad zeros for low-resolutions\n        inner_unet_output = self.inner_unet(\n            (x_t_low, x_inner),\n            timestep,\n            cond_emb=cond_emb,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=cond_mask,\n            from_nested=True,\n        )\n        x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner\n        x_inner = self.out_adapter(x_inner)\n        sample = sample + x_inner[:bh] if bh < bl else sample + x_inner\n\n        # 5. upsample blocks in the outer layers\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb[:bh],\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states[:bh],\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask,\n           
     )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    upsample_size=upsample_size,\n                )\n\n        # 6. post-process\n        # start from `sample` so that `sample_out` is defined even when `conv_norm_out` is not set\n        sample_out = sample\n        if self.conv_norm_out:\n            sample_out = self.conv_norm_out(sample_out)\n            sample_out = self.conv_act(sample_out)\n        sample_out = self.conv_out(sample_out)\n\n        if USE_PEFT_BACKEND:\n            # remove `lora_scale` from each PEFT layer\n            unscale_lora_layers(self, lora_scale)\n\n        # 7. output both low and high-res output\n        if isinstance(x_low, list):\n            out = [sample_out] + x_low\n        else:\n            out = [sample_out, x_low]\n        if self.config.nesting:\n            return NestedUNet2DConditionOutput(sample=out, sample_inner=sample)\n        if not return_dict:\n            return (out,)\n        else:\n            return NestedUNet2DConditionOutput(sample=out)\n\n\n@dataclass\nclass MatryoshkaPipelineOutput(BaseOutput):\n    \"\"\"\n    Output class for Matryoshka pipelines.\n\n    Args:\n        images (`List[PIL.Image.Image]` or `np.ndarray`)\n            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,\n            num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.\n    \"\"\"\n\n    images: Union[List[Image.Image], List[List[Image.Image]], np.ndarray, List[np.ndarray]]\n\n\nclass MatryoshkaPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    StableDiffusionLoraLoaderMixin,\n    IPAdapterMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Matryoshka Diffusion Models.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        text_encoder ([`~transformers.T5EncoderModel`]):\n            Frozen text-encoder ([flan-t5-xl](https://huggingface.co/google/flan-t5-xl)).\n        tokenizer ([`~transformers.T5Tokenizer`]):\n            A `T5Tokenizer` to tokenize text.\n        unet ([`MatryoshkaUNet2DConditionModel`]):\n            A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be\n            [`MatryoshkaDDIMScheduler`] or another scheduler with proper modifications; see the example usage in README.md.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet\"\n    _optional_components = [\"unet\", \"feature_extractor\", \"image_encoder\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        text_encoder: T5EncoderModel,\n        tokenizer: T5TokenizerFast,\n        scheduler: MatryoshkaDDIMScheduler,\n        unet: MatryoshkaUNet2DConditionModel = None,\n        feature_extractor: CLIPImageProcessor = None,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        trust_remote_code: bool = False,\n        nesting_level: int = 0,\n    ):\n        super().__init__()\n\n        if nesting_level == 0:\n            unet = MatryoshkaUNet2DConditionModel.from_pretrained(\n                \"tolgacangoz/matryoshka-diffusion-models\", subfolder=\"unet/nesting_level_0\"\n            )\n        elif nesting_level == 1:\n            unet = NestedUNet2DConditionModel.from_pretrained(\n                \"tolgacangoz/matryoshka-diffusion-models\", subfolder=\"unet/nesting_level_1\"\n            )\n        elif nesting_level == 2:\n            unet = NestedUNet2DConditionModel.from_pretrained(\n                \"tolgacangoz/matryoshka-diffusion-models\", subfolder=\"unet/nesting_level_2\"\n            )\n        else:\n            raise ValueError(\"Currently, nesting levels 0, 1, and 2 are supported.\")\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. 
`steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        # if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n        #     deprecation_message = (\n        #         f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n        #         \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n        #         \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n        #         \" future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n        #         \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n        #     )\n        #     deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n        #     new_config = dict(scheduler.config)\n        #     new_config[\"clip_sample\"] = False\n        #     scheduler._internal_dict = FrozenDict(new_config)\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. 
If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        if hasattr(unet, \"nest_ratio\"):\n            scheduler.scales = unet.nest_ratio + [1]\n            if nesting_level == 2:\n                scheduler.schedule_shifted_power = 2.0\n\n        self.register_modules(\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.register_to_config(nesting_level=nesting_level)\n        self.image_processor = VaeImageProcessor(do_resize=False)\n\n    def change_nesting_level(self, nesting_level: int):\n        if nesting_level == 0:\n            if hasattr(self.unet, \"nest_ratio\"):\n                self.scheduler.scales = None\n            self.unet = MatryoshkaUNet2DConditionModel.from_pretrained(\n                \"tolgacangoz/matryoshka-diffusion-models\", subfolder=\"unet/nesting_level_0\"\n            ).to(self.device)\n            self.config.nesting_level = 0\n        elif nesting_level == 1:\n            self.unet = NestedUNet2DConditionModel.from_pretrained(\n                \"tolgacangoz/matryoshka-diffusion-models\", subfolder=\"unet/nesting_level_1\"\n            ).to(self.device)\n            self.config.nesting_level = 1\n            self.scheduler.scales = self.unet.nest_ratio + [1]\n            self.scheduler.schedule_shifted_power = 1.0\n        elif nesting_level == 2:\n            self.unet = NestedUNet2DConditionModel.from_pretrained(\n                
\"tolgacangoz/matryoshka-diffusion-models\", subfolder=\"unet/nesting_level_2\"\n            ).to(self.device)\n            self.config.nesting_level = 2\n            self.scheduler.scales = self.unet.nest_ratio + [1]\n            self.scheduler.schedule_shifted_power = 2.0\n        else:\n            raise ValueError(\"Currently, nesting levels 0, 1, and 2 are supported.\")\n\n        gc.collect()\n        torch.cuda.empty_cache()\n\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                return_tensors=\"pt\",\n            )\n     
       text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because FLAN-T5-XL for this pipeline can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                prompt_attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                prompt_attention_mask = None\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != 
len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                return_tensors=\"pt\",\n            )\n            uncond_input_ids = uncond_input.input_ids\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                negative_prompt_attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                negative_prompt_attention_mask = None\n\n        if not do_classifier_free_guidance:\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. 
Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n        else:\n            max_len = max(len(text_input_ids[0]), len(uncond_input_ids[0]))\n            if len(text_input_ids[0]) < max_len:\n                text_input_ids = torch.cat(\n                    [text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)],\n                    dim=1,\n                )\n                prompt_attention_mask = torch.cat(\n                    [\n                        prompt_attention_mask,\n                        torch.zeros(\n                            batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device\n                        ),\n                    ],\n                    dim=1,\n                )\n            elif len(uncond_input_ids[0]) < max_len:\n                uncond_input_ids = torch.cat(\n                    [uncond_input_ids, torch.zeros(batch_size, max_len - len(uncond_input_ids[0]), dtype=torch.long)],\n                    dim=1,\n                )\n                negative_prompt_attention_mask = torch.cat(\n                    [\n                        negative_prompt_attention_mask,\n                        torch.zeros(\n                            batch_size,\n                            max_len - len(negative_prompt_attention_mask[0]),\n                            dtype=torch.long,\n                            device=device,\n                        ),\n                    ],\n              
      dim=1,\n                )\n            cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0)\n            cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)\n            prompt_embeds = self.text_encoder(\n                cfg_input_ids.to(device),\n                attention_mask=cfg_attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if not do_classifier_free_guidance:\n            return prompt_embeds, None, prompt_attention_mask, None\n        return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask\n\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, 
uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        image_embeds = []\n        if do_classifier_free_guidance:\n            negative_image_embeds = []\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n\n                image_embeds.append(single_image_embeds[None, :])\n                if do_classifier_free_guidance:\n                    negative_image_embeds.append(single_negative_image_embeds[None, :])\n        else:\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = 
single_image_embeds.chunk(2)\n                    negative_image_embeds.append(single_negative_image_embeds)\n                image_embeds.append(single_image_embeds)\n\n        ip_adapter_image_embeds = []\n        for i, single_image_embeds in enumerate(image_embeds):\n            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n            if do_classifier_free_guidance:\n                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)\n                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)\n\n            single_image_embeds = single_image_embeds.to(device=device)\n            ip_adapter_image_embeds.append(single_image_embeds)\n\n        return ip_adapter_image_embeds\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        
callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    def prepare_latents(\n        self, batch_size, num_channels_latents, height, width, dtype, device, generator, scales, latents=None\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height),\n            int(width),\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            if scales is not None:\n                out = [latents]\n                for s in scales[1:]:\n                    ratio = scales[0] // s\n                    sample_low = F.avg_pool2d(latents, ratio) * ratio\n                    sample_low = sample_low.normal_(generator=generator)\n                    out += [sample_low]\n                latents = out\n        else:\n            if scales is not None:\n                latents = [latent.to(device=device) for latent in latents]\n            else:\n                latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        if scales is not None:\n            latents = [latent * self.scheduler.init_noise_sigma for latent in latents]\n        else:\n            latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\n    ) -> torch.Tensor:\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            w (`torch.Tensor`):\n                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.\n            embedding_dim (`int`, *optional*, defaults to 512):\n                Dimension of the embeddings to generate.\n            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):\n                Data type of the generated embeddings.\n\n        Returns:\n            
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt 
(`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. 
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when\n                using zero terminal SNR.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during the inference. 
It is called with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~MatryoshkaPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~MatryoshkaPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 0. 
Default height and width to unet\n        height = height or self.unet.config.sample_size\n        width = width or self.unet.config.sample_size\n        # to deal with lora scaling and other possible forward hooks\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n        ) = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)])\n            attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])\n        else:\n            attention_masks = prompt_attention_mask\n\n        prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1)\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n        timesteps = timesteps[:-1]\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            self.scheduler.scales,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n        extra_step_kwargs |= {\"use_clipped_model_output\": True}\n\n        # 6.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)\n            else None\n        )\n\n        # 6.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                if self.do_classifier_free_guidance and isinstance(latents, list):\n                    latent_model_input = [latent.repeat(2, 1, 1, 1) for latent in latents]\n                elif self.do_classifier_free_guidance:\n                    latent_model_input = latents.repeat(2, 1, 1, 1)\n                else:\n                    latent_model_input = latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t - 1,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    encoder_attention_mask=attention_masks,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if isinstance(noise_pred, list) and self.do_classifier_free_guidance:\n                    for i, (noise_pred_uncond, noise_pred_text) in enumerate(noise_pred):\n                        noise_pred[i] = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                elif self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n     
           if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        image = latents\n\n        if self.scheduler.scales is not None:\n            for i, img in enumerate(image):\n                image[i] = self.image_processor.postprocess(img, output_type=output_type)[0]\n        else:\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return 
(image,)\n\n        return MatryoshkaPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/mixture_canvas.py",
    "content": "import re\nfrom copy import deepcopy\nfrom dataclasses import asdict, dataclass\nfrom enum import Enum\nfrom typing import List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom numpy import exp, pi, sqrt\nfrom torchvision.transforms.functional import resize\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\n\n\ndef preprocess_image(image):\n    from PIL import Image\n\n    \"\"\"Preprocess an input image\n\n    Same as\n    https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44\n    \"\"\"\n    w, h = image.size\n    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32\n    image = image.resize((w, h), resample=Image.LANCZOS)\n    image = np.array(image).astype(np.float32) / 255.0\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image)\n    return 2.0 * image - 1.0\n\n\n@dataclass\nclass CanvasRegion:\n    \"\"\"Class defining a rectangular region in the canvas\"\"\"\n\n    row_init: int  # Region starting row in pixel space (included)\n    row_end: int  # Region end row in pixel space (not included)\n    col_init: int  # Region starting column in pixel space (included)\n    col_end: int  # Region end column in pixel space (not included)\n    region_seed: int = None  # Seed for random operations in this region\n    noise_eps: float = 0.0  # Deviation of a zero-mean gaussian noise to be applied over the latents in this region. 
Useful for slightly \"rerolling\" latents\n\n    def __post_init__(self):\n        # Initialize arguments if not specified\n        if self.region_seed is None:\n            self.region_seed = np.random.randint(9999999999)\n        # Check coordinates are non-negative\n        for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:\n            if coord < 0:\n                raise ValueError(\n                    f\"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})\"\n                )\n        # Check coordinates are divisible by 8, else we end up with nasty rounding error when mapping to latent space\n        for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:\n            if coord // 8 != coord / 8:\n                raise ValueError(\n                    f\"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})\"\n                )\n        # Check noise eps is non-negative\n        if self.noise_eps < 0:\n            raise ValueError(f\"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}\")\n        # Compute coordinates for this region in latent space\n        self.latent_row_init = self.row_init // 8\n        self.latent_row_end = self.row_end // 8\n        self.latent_col_init = self.col_init // 8\n        self.latent_col_end = self.col_end // 8\n\n    @property\n    def width(self):\n        return self.col_end - self.col_init\n\n    @property\n    def height(self):\n        return self.row_end - self.row_init\n\n    def get_region_generator(self, device=\"cpu\"):\n        \"\"\"Creates a torch.Generator based on the random seed of this region\"\"\"\n        # Initialize region generator\n        return torch.Generator(device).manual_seed(self.region_seed)\n\n    @property\n    def __dict__(self):\n        return asdict(self)\n\n\nclass 
MaskModes(Enum):\n    \"\"\"Modes in which the influence of diffuser is masked\"\"\"\n\n    CONSTANT = \"constant\"\n    GAUSSIAN = \"gaussian\"\n    QUARTIC = \"quartic\"  # See https://en.wikipedia.org/wiki/Kernel_(statistics)\n\n\n@dataclass\nclass DiffusionRegion(CanvasRegion):\n    \"\"\"Abstract class defining a region where some class of diffusion process is acting\"\"\"\n\n    pass\n\n\n@dataclass\nclass Text2ImageRegion(DiffusionRegion):\n    \"\"\"Class defining a region where a text guided diffusion process is acting\"\"\"\n\n    prompt: str = \"\"  # Text prompt guiding the diffuser in this region\n    guidance_scale: float = 7.5  # Guidance scale of the diffuser in this region. If None, randomize\n    mask_type: MaskModes = MaskModes.GAUSSIAN.value  # Kind of weight mask applied to this region\n    mask_weight: float = 1.0  # Global weights multiplier of the mask\n    tokenized_prompt = None  # Tokenized prompt\n    encoded_prompt = None  # Encoded prompt\n\n    def __post_init__(self):\n        super().__post_init__()\n        # Mask weight cannot be negative\n        if self.mask_weight < 0:\n            raise ValueError(\n                f\"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}\"\n            )\n        # Mask type must be an actual known mask\n        if self.mask_type not in [e.value for e in MaskModes]:\n            raise ValueError(\n                f\"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({[e.value for e in MaskModes]})\"\n            )\n        # Randomize arguments if given as None\n        if self.guidance_scale is None:\n            self.guidance_scale = np.random.randint(5, 30)\n        # Clean prompt\n        self.prompt = re.sub(\" +\", \" \", self.prompt).replace(\"\\n\", \" \")\n\n    def tokenize_prompt(self, tokenizer):\n        \"\"\"Tokenizes the prompt for this diffusion region using a given tokenizer\"\"\"\n        
self.tokenized_prompt = tokenizer(\n            self.prompt,\n            padding=\"max_length\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n    def encode_prompt(self, text_encoder, device):\n        \"\"\"Encodes the previously tokenized prompt for this diffusion region using a given encoder\"\"\"\n        if self.tokenized_prompt is None:\n            raise ValueError(\"Prompt in diffusion region must be tokenized before encoding\")\n        self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0]\n\n\n@dataclass\nclass Image2ImageRegion(DiffusionRegion):\n    \"\"\"Class defining a region where an image guided diffusion process is acting\"\"\"\n\n    reference_image: torch.Tensor = None\n    strength: float = 0.8  # Strength of the image\n\n    def __post_init__(self):\n        super().__post_init__()\n        if self.reference_image is None:\n            raise ValueError(\"Must provide a reference image when creating an Image2ImageRegion\")\n        if self.strength < 0 or self.strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {self.strength}\")\n        # Rescale image to region shape\n        self.reference_image = resize(self.reference_image, size=[self.height, self.width])\n\n    def encode_reference_image(self, encoder, device, generator, cpu_vae=False):\n        \"\"\"Encodes the reference image for this Image2Image region into the latent space\"\"\"\n        # Place encoder in CPU or not following the parameter cpu_vae\n        if cpu_vae:\n            # Note here we use mean instead of sample, to avoid moving also generator to CPU, which is troublesome\n            self.reference_latents = encoder.cpu().encode(self.reference_image).latent_dist.mean.to(device)\n        else:\n            self.reference_latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample(\n   
             generator=generator\n            )\n        self.reference_latents = 0.18215 * self.reference_latents\n\n    @property\n    def __dict__(self):\n        # This class requires special casting to dict because of the reference_image tensor. Otherwise it cannot be casted to JSON\n\n        # Get all basic fields from parent class\n        super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__.keys()}\n        # Pack other fields\n        return {**super_fields, \"reference_image\": self.reference_image.cpu().tolist(), \"strength\": self.strength}\n\n\nclass RerollModes(Enum):\n    \"\"\"Modes in which the reroll regions operate\"\"\"\n\n    RESET = \"reset\"  # Completely reset the random noise in the region\n    EPSILON = \"epsilon\"  # Alter slightly the latents in the region\n\n\n@dataclass\nclass RerollRegion(CanvasRegion):\n    \"\"\"Class defining a rectangular canvas region in which initial latent noise will be rerolled\"\"\"\n\n    reroll_mode: RerollModes = RerollModes.RESET.value\n\n\n@dataclass\nclass MaskWeightsBuilder:\n    \"\"\"Auxiliary class to compute a tensor of weights for a given diffusion region\"\"\"\n\n    latent_space_dim: int  # Size of the U-net latent space\n    nbatch: int = 1  # Batch size in the U-net\n\n    def compute_mask_weights(self, region: DiffusionRegion) -> torch.tensor:\n        \"\"\"Computes a tensor of weights for a given diffusion region\"\"\"\n        MASK_BUILDERS = {\n            MaskModes.CONSTANT.value: self._constant_weights,\n            MaskModes.GAUSSIAN.value: self._gaussian_weights,\n            MaskModes.QUARTIC.value: self._quartic_weights,\n        }\n        return MASK_BUILDERS[region.mask_type](region)\n\n    def _constant_weights(self, region: DiffusionRegion) -> torch.tensor:\n        \"\"\"Computes a tensor of constant for a given diffusion region\"\"\"\n        latent_width = region.latent_col_end - region.latent_col_init\n        latent_height = 
region.latent_row_end - region.latent_row_init\n        return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight\n\n    def _gaussian_weights(self, region: DiffusionRegion) -> torch.tensor:\n        \"\"\"Generates a gaussian mask of weights for tile contributions\"\"\"\n        latent_width = region.latent_col_end - region.latent_col_init\n        latent_height = region.latent_row_end - region.latent_row_init\n\n        var = 0.01\n        midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1\n        x_probs = [\n            exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)\n            for x in range(latent_width)\n        ]\n        midpoint = (latent_height - 1) / 2\n        y_probs = [\n            exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)\n            for y in range(latent_height)\n        ]\n\n        weights = np.outer(y_probs, x_probs) * region.mask_weight\n        return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))\n\n    def _quartic_weights(self, region: DiffusionRegion) -> torch.tensor:\n        \"\"\"Generates a quartic mask of weights for tile contributions\n\n        The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.\n        \"\"\"\n        quartic_constant = 15.0 / 16.0\n\n        support = (np.array(range(region.latent_col_init, region.latent_col_end)) - region.latent_col_init) / (\n            region.latent_col_end - region.latent_col_init - 1\n        ) * 1.99 - (1.99 / 2.0)\n        x_probs = quartic_constant * np.square(1 - np.square(support))\n        support = (np.array(range(region.latent_row_init, region.latent_row_end)) - region.latent_row_init) / (\n            region.latent_row_end - region.latent_row_init - 1\n        ) * 1.99 - (1.99 / 2.0)\n        
y_probs = quartic_constant * np.square(1 - np.square(support))\n\n        weights = np.outer(y_probs, x_probs) * region.mask_weight\n        return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))\n\n\nclass StableDiffusionCanvasPipeline(DiffusionPipeline, StableDiffusionMixin):\n    \"\"\"Stable Diffusion pipeline that mixes several diffusers in the same canvas\"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    def decode_latents(self, latents, cpu_vae=False):\n        \"\"\"Decodes a given array of latents into pixel space\"\"\"\n        # scale and decode the image latents with vae\n        if cpu_vae:\n            lat = deepcopy(latents).cpu()\n            vae = deepcopy(self.vae).cpu()\n        else:\n            lat = latents\n            vae = self.vae\n\n        lat = 1 / 0.18215 * lat\n        image = vae.decode(lat).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        return self.numpy_to_pil(image)\n\n    def get_latest_timestep_img2img(self, num_inference_steps, strength):\n        \"\"\"Finds the latest timesteps where an img2img strength does not impose latents anymore\"\"\"\n        # get the original timestep using init_timestep\n        offset = self.scheduler.config.get(\"steps_offset\", 0)\n        init_timestep = 
int(num_inference_steps * (1 - strength)) + offset\n        init_timestep = min(init_timestep, num_inference_steps)\n\n        t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps - 1)\n        latest_timestep = self.scheduler.timesteps[t_start]\n\n        return latest_timestep\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        canvas_height: int,\n        canvas_width: int,\n        regions: List[DiffusionRegion],\n        num_inference_steps: Optional[int] = 50,\n        seed: Optional[int] = 12345,\n        reroll_regions: Optional[List[RerollRegion]] = None,\n        cpu_vae: Optional[bool] = False,\n        decode_steps: Optional[bool] = False,\n    ):\n        if reroll_regions is None:\n            reroll_regions = []\n        batch_size = 1\n\n        if decode_steps:\n            steps_images = []\n\n        # Prepare scheduler\n        self.scheduler.set_timesteps(num_inference_steps, device=self.device)\n\n        # Split diffusion regions by their kind\n        text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)]\n        image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)]\n\n        # Prepare text embeddings\n        for region in text2image_regions:\n            region.tokenize_prompt(self.tokenizer)\n            region.encode_prompt(self.text_encoder, self.device)\n\n        # Create original noisy latents using the timesteps\n        latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8)\n        generator = torch.Generator(self.device).manual_seed(seed)\n        init_noise = torch.randn(latents_shape, generator=generator, device=self.device)\n\n        # Reset latents in seed reroll regions, if requested\n        for region in reroll_regions:\n            if region.reroll_mode == RerollModes.RESET.value:\n                region_shape = (\n                    
latents_shape[0],\n                    latents_shape[1],\n                    region.latent_row_end - region.latent_row_init,\n                    region.latent_col_end - region.latent_col_init,\n                )\n                init_noise[\n                    :,\n                    :,\n                    region.latent_row_init : region.latent_row_end,\n                    region.latent_col_init : region.latent_col_end,\n                ] = torch.randn(region_shape, generator=region.get_region_generator(self.device), device=self.device)\n\n        # Apply epsilon noise to regions: first diffusion regions, then reroll regions\n        all_eps_rerolls = regions + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value]\n        for region in all_eps_rerolls:\n            if region.noise_eps > 0:\n                region_noise = init_noise[\n                    :,\n                    :,\n                    region.latent_row_init : region.latent_row_end,\n                    region.latent_col_init : region.latent_col_end,\n                ]\n                eps_noise = (\n                    torch.randn(\n                        region_noise.shape, generator=region.get_region_generator(self.device), device=self.device\n                    )\n                    * region.noise_eps\n                )\n                init_noise[\n                    :,\n                    :,\n                    region.latent_row_init : region.latent_row_end,\n                    region.latent_col_init : region.latent_col_end,\n                ] += eps_noise\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = init_noise * self.scheduler.init_noise_sigma\n\n        # Get unconditional embeddings for classifier free guidance in text2image regions\n        for region in text2image_regions:\n            max_length = region.tokenized_prompt.input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n          
      [\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            region.encoded_prompt = torch.cat([uncond_embeddings, region.encoded_prompt])\n\n        # Prepare image latents\n        for region in image2image_regions:\n            region.encode_reference_image(self.vae, device=self.device, generator=generator)\n\n        # Prepare mask of weights for each region\n        mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size)\n        mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions]\n\n        # Diffusion timesteps\n        for i, t in tqdm(enumerate(self.scheduler.timesteps)):\n            # Diffuse each region\n            noise_preds_regions = []\n\n            # text2image regions\n            for region in text2image_regions:\n                region_latents = latents[\n                    :,\n                    :,\n                    region.latent_row_init : region.latent_row_end,\n                    region.latent_col_init : region.latent_col_end,\n                ]\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([region_latents] * 2)\n                # scale model input following scheduler rules\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n                # predict the noise residual\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)[\"sample\"]\n                # perform guidance\n                
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                noise_preds_regions.append(noise_pred_region)\n\n            # Merge noise predictions for all tiles\n            noise_pred = torch.zeros(latents.shape, device=self.device)\n            contributors = torch.zeros(latents.shape, device=self.device)\n            # Add each tile contribution to overall latents\n            for region, noise_pred_region, mask_weights_region in zip(\n                text2image_regions, noise_preds_regions, mask_weights\n            ):\n                noise_pred[\n                    :,\n                    :,\n                    region.latent_row_init : region.latent_row_end,\n                    region.latent_col_init : region.latent_col_end,\n                ] += noise_pred_region * mask_weights_region\n                contributors[\n                    :,\n                    :,\n                    region.latent_row_init : region.latent_row_end,\n                    region.latent_col_init : region.latent_col_end,\n                ] += mask_weights_region\n            # Average overlapping areas with more than 1 contributor\n            noise_pred /= contributors\n            noise_pred = torch.nan_to_num(\n                noise_pred\n            )  # Replace NaNs by zeros: NaN can appear if a position is not covered by any DiffusionRegion\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n            # Image2Image regions: override latents generated by the scheduler\n            for region in image2image_regions:\n                influence_step = self.get_latest_timestep_img2img(num_inference_steps, region.strength)\n                # Only override in the timesteps before the last influence step of the image (given by its strength)\n               
 if t > influence_step:\n                    timestep = t.repeat(batch_size)\n                    region_init_noise = init_noise[\n                        :,\n                        :,\n                        region.latent_row_init : region.latent_row_end,\n                        region.latent_col_init : region.latent_col_end,\n                    ]\n                    region_latents = self.scheduler.add_noise(region.reference_latents, region_init_noise, timestep)\n                    latents[\n                        :,\n                        :,\n                        region.latent_row_init : region.latent_row_end,\n                        region.latent_col_init : region.latent_col_end,\n                    ] = region_latents\n\n            if decode_steps:\n                steps_images.append(self.decode_latents(latents, cpu_vae))\n\n        # scale and decode the image latents with vae\n        image = self.decode_latents(latents, cpu_vae)\n\n        output = {\"images\": image}\n        if decode_steps:\n            output = {**output, \"steps_images\": steps_images}\n        return output\n"
  },
  {
    "path": "examples/community/mixture_tiling.py",
    "content": "import inspect\nfrom copy import deepcopy\nfrom enum import Enum\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom tqdm.auto import tqdm\n\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\nfrom diffusers.utils import logging\n\n\ntry:\n    from ligo.segments import segment\n    from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\nexcept ImportError:\n    raise ImportError(\"Please install transformers and ligo-segments to use the mixture pipeline\")\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> from diffusers import LMSDiscreteScheduler, DiffusionPipeline\n\n        >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000)\n        >>> pipeline = DiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", scheduler=scheduler, custom_pipeline=\"mixture_tiling\")\n        >>> pipeline.to(\"cuda\")\n\n        >>> image = pipeline(\n        >>>     prompt=[[\n        >>>         \"A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\",\n        >>>         \"A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\",\n        >>>         \"An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece\"\n        >>>     ]],\n        >>>     tile_height=640,\n        >>>     
tile_width=640,\n        >>>     tile_row_overlap=0,\n        >>>     tile_col_overlap=256,\n        >>>     guidance_scale=8,\n        >>>     seed=7178915308,\n        >>>     num_inference_steps=50,\n        >>> )[\"images\"][0]\n        ```\n\"\"\"\n\n\ndef _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):\n    \"\"\"Given a tile's row and column numbers, returns the range of pixels affected by that tile in the overall image\n\n    Returns a tuple with:\n        - Starting coordinates of rows in pixel space\n        - Ending coordinates of rows in pixel space\n        - Starting coordinates of columns in pixel space\n        - Ending coordinates of columns in pixel space\n    \"\"\"\n    px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)\n    px_row_end = px_row_init + tile_height\n    px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)\n    px_col_end = px_col_init + tile_width\n    return px_row_init, px_row_end, px_col_init, px_col_end\n\n\ndef _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end):\n    \"\"\"Translates coordinates in pixel space to coordinates in latent space\"\"\"\n    return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8\n\n\ndef _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):\n    \"\"\"Given a tile's row and column numbers, returns the range of latents affected by that tile in the overall image\n\n    Returns a tuple with:\n        - Starting coordinates of rows in latent space\n        - Ending coordinates of rows in latent space\n        - Starting coordinates of columns in latent space\n        - Ending coordinates of columns in latent space\n    \"\"\"\n    px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(\n        tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\n    )\n    return 
_pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end)\n\n\ndef _tile2latent_exclusive_indices(\n    tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns\n):\n    \"\"\"Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image\n\n    Returns a tuple with:\n        - Starting coordinates of rows in latent space\n        - Ending coordinates of rows in latent space\n        - Starting coordinates of columns in latent space\n        - Ending coordinates of columns in latent space\n    \"\"\"\n    row_init, row_end, col_init, col_end = _tile2latent_indices(\n        tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\n    )\n    row_segment = segment(row_init, row_end)\n    col_segment = segment(col_init, col_end)\n    # Iterate over the rest of tiles, clipping the region for the current tile\n    for row in range(rows):\n        for column in range(columns):\n            if row != tile_row and column != tile_col:\n                clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices(\n                    row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap\n                )\n                row_segment = row_segment - segment(clip_row_init, clip_row_end)\n                col_segment = col_segment - segment(clip_col_init, clip_col_end)\n    # return row_init, row_end, col_init, col_end\n    return row_segment[0], row_segment[1], col_segment[0], col_segment[1]\n\n\nclass StableDiffusionExtrasMixin:\n    \"\"\"Mixin providing additional convenience method to Stable Diffusion pipelines\"\"\"\n\n    def decode_latents(self, latents, cpu_vae=False):\n        \"\"\"Decodes a given array of latents into pixel space\"\"\"\n        # scale and decode the image latents with vae\n        if cpu_vae:\n            lat = deepcopy(latents).cpu()\n            vae = deepcopy(self.vae).cpu()\n        
else:\n            lat = latents\n            vae = self.vae\n\n        lat = 1 / 0.18215 * lat\n        image = vae.decode(lat).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).numpy()\n\n        return self.numpy_to_pil(image)\n\n\nclass StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin):\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    class SeedTilesMode(Enum):\n        \"\"\"Modes in which the latents of a particular tile can be re-seeded\"\"\"\n\n        FULL = \"full\"\n        EXCLUSIVE = \"exclusive\"\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[List[str]]],\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        eta: Optional[float] = 0.0,\n        seed: Optional[int] = None,\n        tile_height: Optional[int] = 512,\n        tile_width: Optional[int] = 512,\n        tile_row_overlap: Optional[int] = 256,\n        tile_col_overlap: Optional[int] = 256,\n        guidance_scale_tiles: Optional[List[List[float]]] = None,\n        seed_tiles: Optional[List[List[int]]] = None,\n        seed_tiles_mode: Optional[Union[str, List[List[str]]]] = \"full\",\n        seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,\n        cpu_vae: Optional[bool] = False,\n 
   ):\n        r\"\"\"\n        Function to run the diffusion pipeline with tiling support.\n\n        Args:\n            prompt: either a single string (no tiling) or a list of lists with all the prompts to use (one list for each row of tiles). This will also define the tiling structure.\n            num_inference_steps: number of diffusion steps.\n            guidance_scale: classifier-free guidance scale.\n            eta: η from the DDIM paper; only used with the DDIMScheduler and ignored by other schedulers.\n            seed: general random seed to initialize latents.\n            tile_height: height in pixels of each grid tile.\n            tile_width: width in pixels of each grid tile.\n            tile_row_overlap: number of overlap pixels between tiles in consecutive rows.\n            tile_col_overlap: number of overlap pixels between tiles in consecutive columns.\n            guidance_scale_tiles: specific weights for classifier-free guidance in each tile. If None, the value provided in guidance_scale will be used.\n            seed_tiles: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard seed parameter.\n            seed_tiles_mode: either \"full\" or \"exclusive\". If \"full\", all the latents affected by the tile will be overridden. If \"exclusive\", only the latents that are affected exclusively by this tile (and no other tiles) will be overridden.\n            seed_reroll_regions: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over seed_tiles.\n            cpu_vae: the decoder from latent space to pixel space can require too much GPU RAM for large images. If you hit out-of-memory errors at the end of the generation process, try setting this parameter to True to run the decoder on the CPU. 
Slower, but should run without memory issues.\n\n        Examples:\n\n        Returns:\n            A PIL image with the generated image.\n\n        \"\"\"\n        if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):\n            raise ValueError(f\"`prompt` has to be a list of lists but is {type(prompt)}\")\n        grid_rows = len(prompt)\n        grid_cols = len(prompt[0])\n        if not all(len(row) == grid_cols for row in prompt):\n            raise ValueError(\"All prompt rows must have the same number of prompt columns\")\n        if not isinstance(seed_tiles_mode, str) and (\n            not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)\n        ):\n            raise ValueError(f\"`seed_tiles_mode` has to be a string or list of lists but is {type(seed_tiles_mode)}\")\n        if isinstance(seed_tiles_mode, str):\n            seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]\n\n        modes = [mode.value for mode in self.SeedTilesMode]\n        if any(mode not in modes for row in seed_tiles_mode for mode in row):\n            raise ValueError(f\"Seed tiles mode must be one of {modes}\")\n        if seed_reroll_regions is None:\n            seed_reroll_regions = []\n        batch_size = 1\n\n        # create original noisy latents using the timesteps\n        height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)\n        width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)\n        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)\n        generator = torch.Generator(\"cuda\").manual_seed(seed)\n        latents = torch.randn(latents_shape, generator=generator, device=self.device)\n\n        # overwrite latents for specific tiles if provided\n        if seed_tiles is not None:\n            for row in range(grid_rows):\n                for col in range(grid_cols):\n                    
if (seed_tile := seed_tiles[row][col]) is not None:\n                        mode = seed_tiles_mode[row][col]\n                        if mode == self.SeedTilesMode.FULL.value:\n                            row_init, row_end, col_init, col_end = _tile2latent_indices(\n                                row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\n                            )\n                        else:\n                            row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(\n                                row,\n                                col,\n                                tile_width,\n                                tile_height,\n                                tile_row_overlap,\n                                tile_col_overlap,\n                                grid_rows,\n                                grid_cols,\n                            )\n                        tile_generator = torch.Generator(\"cuda\").manual_seed(seed_tile)\n                        tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)\n                        latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(\n                            tile_shape, generator=tile_generator, device=self.device\n                        )\n\n        # overwrite again for seed reroll regions\n        for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions:\n            row_init, row_end, col_init, col_end = _pixel2latent_indices(\n                row_init, row_end, col_init, col_end\n            )  # to latent space coordinates\n            reroll_generator = torch.Generator(\"cuda\").manual_seed(seed_reroll)\n            region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)\n            latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(\n                region_shape, generator=reroll_generator, device=self.device\n           
 )\n\n        # Prepare scheduler\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n        extra_set_kwargs = {}\n        if accepts_offset:\n            extra_set_kwargs[\"offset\"] = 1\n        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas\n        if isinstance(self.scheduler, LMSDiscreteScheduler):\n            latents = latents * self.scheduler.sigmas[0]\n\n        # get prompts text embeddings\n        text_input = [\n            [\n                self.tokenizer(\n                    col,\n                    padding=\"max_length\",\n                    max_length=self.tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n                for col in row\n            ]\n            for row in prompt\n        ]\n        text_embeddings = [[self.text_encoder(col.input_ids.to(self.device))[0] for col in row] for row in text_input]\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0  # TODO: also active if any tile has guidance scale\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            for i in range(grid_rows):\n                for j in range(grid_cols):\n                    max_length = text_input[i][j].input_ids.shape[-1]\n                    uncond_input = self.tokenizer(\n                        [\"\"] * batch_size, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n                    )\n                    uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n                    # For classifier free guidance, we need to do two forward passes.\n                    # Here we concatenate the unconditional and text embeddings into a single batch\n                    # to avoid doing two forward passes\n                    text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]])\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # Mask for tile weights strength\n        tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size)\n\n        # Diffusion timesteps\n        for i, t in tqdm(enumerate(self.scheduler.timesteps)):\n            # Diffuse each tile\n            noise_preds = []\n            for row in range(grid_rows):\n                noise_preds_row = []\n   
             for col in range(grid_cols):\n                    px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(\n                        row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\n                    )\n                    tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]\n                    # expand the latents if we are doing classifier free guidance\n                    latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents\n                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n                    # predict the noise residual\n                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])[\n                        \"sample\"\n                    ]\n                    # perform guidance\n                    if do_classifier_free_guidance:\n                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                        guidance = (\n                            guidance_scale\n                            if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None\n                            else guidance_scale_tiles[row][col]\n                        )\n                        noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)\n                    # always record the tile prediction, with or without guidance\n                    noise_preds_row.append(noise_pred)\n                noise_preds.append(noise_preds_row)\n            # Stitch noise predictions for all tiles\n            noise_pred = torch.zeros(latents.shape, device=self.device)\n            contributors = torch.zeros(latents.shape, device=self.device)\n            # Add each tile contribution to overall latents\n            for row in range(grid_rows):\n                for col in range(grid_cols):\n                    px_row_init, px_row_end, px_col_init, px_col_end = 
_tile2latent_indices(\n                        row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\n                    )\n                    noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (\n                        noise_preds[row][col] * tile_weights\n                    )\n                    contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights\n            # Average overlapping areas with more than 1 contributor\n            noise_pred /= contributors\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n        # scale and decode the image latents with vae\n        image = self.decode_latents(latents, cpu_vae)\n\n        return {\"images\": image}\n\n    def _gaussian_weights(self, tile_width, tile_height, nbatches):\n        \"\"\"Generates a gaussian mask of weights for tile contributions\"\"\"\n        import numpy as np\n        from numpy import exp, pi, sqrt\n\n        latent_width = tile_width // 8\n        latent_height = tile_height // 8\n\n        var = 0.01\n        midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1\n        x_probs = [\n            exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)\n            for x in range(latent_width)\n        ]\n        midpoint = (latent_height - 1) / 2  # -1 because index goes from 0 to latent_height - 1\n        y_probs = [\n            exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)\n            for y in range(latent_height)\n        ]\n\n        weights = np.outer(y_probs, x_probs)\n        return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))\n"
  },
  {
    "path": "examples/community/mixture_tiling_sdxl.py",
    "content": "# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport inspect\r\nfrom enum import Enum\r\nfrom typing import Any, Dict, List, Optional, Tuple, Union\r\n\r\nimport torch\r\nfrom transformers import (\r\n    CLIPTextModel,\r\n    CLIPTextModelWithProjection,\r\n    CLIPTokenizer,\r\n)\r\n\r\nfrom diffusers.image_processor import VaeImageProcessor\r\nfrom diffusers.loaders import (\r\n    FromSingleFileMixin,\r\n    StableDiffusionXLLoraLoaderMixin,\r\n    TextualInversionLoaderMixin,\r\n)\r\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\r\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\r\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\r\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\r\nfrom diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler\r\nfrom diffusers.utils import (\r\n    USE_PEFT_BACKEND,\r\n    deprecate,\r\n    is_invisible_watermark_available,\r\n    is_torch_xla_available,\r\n    logging,\r\n    replace_example_docstring,\r\n    scale_lora_layers,\r\n    unscale_lora_layers,\r\n)\r\nfrom diffusers.utils.torch_utils import randn_tensor\r\n\r\n\r\ntry:\r\n    from ligo.segments import segment\r\nexcept ImportError:\r\n    raise ImportError(\"Please install transformers 
and ligo-segments to use the mixture pipeline\")\r\n\r\nif is_invisible_watermark_available():\r\n    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\r\n\r\nif is_torch_xla_available():\r\n    import torch_xla.core.xla_model as xm\r\n\r\n    XLA_AVAILABLE = True\r\nelse:\r\n    XLA_AVAILABLE = False\r\n\r\n\r\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\r\n\r\nEXAMPLE_DOC_STRING = \"\"\"\r\n    Examples:\r\n        ```py\r\n        >>> import torch\r\n        >>> from diffusers import StableDiffusionXLPipeline\r\n\r\n        >>> pipe = StableDiffusionXLPipeline.from_pretrained(\r\n        ...     \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\r\n        ... )\r\n        >>> pipe = pipe.to(\"cuda\")\r\n\r\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\r\n        >>> image = pipe(prompt).images[0]\r\n        ```\r\n\"\"\"\r\n\r\n\r\ndef _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):\r\n    \"\"\"Given a tile row and column numbers returns the range of pixels affected by that tile in the overall image\r\n\r\n    Returns a tuple with:\r\n        - Starting coordinates of rows in pixel space\r\n        - Ending coordinates of rows in pixel space\r\n        - Starting coordinates of columns in pixel space\r\n        - Ending coordinates of columns in pixel space\r\n    \"\"\"\r\n    px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)\r\n    px_row_end = px_row_init + tile_height\r\n    px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)\r\n    px_col_end = px_col_init + tile_width\r\n    return px_row_init 
// 8, px_row_end // 8, px_col_init // 8, px_col_end // 8\r\n\r\n\r\ndef _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):\r\n    \"\"\"Given a tile row and column numbers returns the range of latents affected by that tile in the overall image\r\n\r\n    Returns a tuple with:\r\n        - Starting coordinates of rows in latent space\r\n        - Ending coordinates of rows in latent space\r\n        - Starting coordinates of columns in latent space\r\n        - Ending coordinates of columns in latent space\r\n    \"\"\"\r\n    px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(\r\n        tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\r\n    )\r\n    return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end)\r\n\r\n\r\ndef _tile2latent_exclusive_indices(\r\n    tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns\r\n):\r\n    \"\"\"Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image\r\n\r\n    Returns a tuple with:\r\n        - Starting coordinates of rows in latent space\r\n        - Ending coordinates of rows in latent space\r\n        - Starting coordinates of columns in latent space\r\n        - Ending coordinates of columns in latent space\r\n    \"\"\"\r\n    row_init, row_end, col_init, col_end = _tile2latent_indices(\r\n        tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\r\n    )\r\n    row_segment = segment(row_init, row_end)\r\n    col_segment = segment(col_init, col_end)\r\n    # Iterate over the rest of tiles, clipping the region for the current tile\r\n    for row in range(rows):\r\n        for column in range(columns):\r\n            if row != tile_row and column != tile_col:\r\n                clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices(\r\n                
    row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap\r\n                )\r\n                row_segment = row_segment - segment(clip_row_init, clip_row_end)\r\n                col_segment = col_segment - segment(clip_col_init, clip_col_end)\r\n    # return row_init, row_end, col_init, col_end\r\n    return row_segment[0], row_segment[1], col_segment[0], col_segment[1]\r\n\r\n\r\ndef _get_crops_coords_list(num_rows, num_cols, output_width):\r\n    \"\"\"\r\n    Generates a list of lists of `crops_coords_top_left` tuples for focusing on\r\n    different horizontal parts of an image, and repeats this list for the specified\r\n    number of rows in the output structure.\r\n\r\n    This function calculates `crops_coords_top_left` tuples to create horizontal\r\n    focus variations (like left, center, right focus) based on `output_width`\r\n    and `num_cols` (which represents the number of horizontal focus points/columns).\r\n    It then repeats the *list* of these horizontal focus tuples `num_rows` times to\r\n    create the final list of lists output structure.\r\n\r\n    Args:\r\n        num_rows (int): The desired number of rows in the output list of lists.\r\n                          This determines how many times the list of horizontal\r\n                          focus variations will be repeated.\r\n        num_cols (int): The number of horizontal focus points (columns) to generate.\r\n                          This determines how many horizontal focus variations are\r\n                          created based on dividing the `output_width`.\r\n        output_width (int): The desired width of the output image.\r\n\r\n    Returns:\r\n        list[list[tuple[int, int]]]: A list of lists of tuples. Each inner list\r\n                                     contains `num_cols` tuples of `(ctop, cleft)`,\r\n                                     representing horizontal focus points. 
The outer list\r\n                                     contains `num_rows` such inner lists.\r\n    \"\"\"\r\n    crops_coords_list = []\r\n    if num_cols > 0:\r\n        section_width = output_width / num_cols\r\n        for i in range(num_cols):\r\n            cleft = int(round(i * section_width))\r\n            crops_coords_list.append((0, cleft))\r\n\r\n    result_list = []\r\n    for _ in range(num_rows):\r\n        result_list.append(list(crops_coords_list))\r\n\r\n    return result_list\r\n\r\n\r\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\r\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\r\n    r\"\"\"\r\n    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on\r\n    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are\r\n    Flawed](https://huggingface.co/papers/2305.08891).\r\n\r\n    Args:\r\n        noise_cfg (`torch.Tensor`):\r\n            The predicted noise tensor for the guided diffusion process.\r\n        noise_pred_text (`torch.Tensor`):\r\n            The predicted noise tensor for the text-guided diffusion process.\r\n        guidance_rescale (`float`, *optional*, defaults to 0.0):\r\n            A rescale factor applied to the noise predictions.\r\n\r\n    Returns:\r\n        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.\r\n    \"\"\"\r\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\r\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\r\n    # rescale the results from guidance (fixes overexposure)\r\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\r\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\r\n    
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\r\n    return noise_cfg\r\n\r\n\r\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\r\ndef retrieve_timesteps(\r\n    scheduler,\r\n    num_inference_steps: Optional[int] = None,\r\n    device: Optional[Union[str, torch.device]] = None,\r\n    timesteps: Optional[List[int]] = None,\r\n    sigmas: Optional[List[float]] = None,\r\n    **kwargs,\r\n):\r\n    r\"\"\"\r\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\r\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\r\n\r\n    Args:\r\n        scheduler (`SchedulerMixin`):\r\n            The scheduler to get timesteps from.\r\n        num_inference_steps (`int`):\r\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\r\n            must be `None`.\r\n        device (`str` or `torch.device`, *optional*):\r\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\r\n        timesteps (`List[int]`, *optional*):\r\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\r\n            `num_inference_steps` and `sigmas` must be `None`.\r\n        sigmas (`List[float]`, *optional*):\r\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\r\n            `num_inference_steps` and `timesteps` must be `None`.\r\n\r\n    Returns:\r\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\r\n        second element is the number of inference steps.\r\n    \"\"\"\r\n\r\n    if timesteps is not None and sigmas is not None:\r\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values\")\r\n    if timesteps is not None:\r\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\r\n        if not accepts_timesteps:\r\n            raise ValueError(\r\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\r\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\r\n            )\r\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n        num_inference_steps = len(timesteps)\r\n    elif sigmas is not None:\r\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\r\n        if not accept_sigmas:\r\n            raise ValueError(\r\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\r\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\r\n            )\r\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n        num_inference_steps = len(timesteps)\r\n    else:\r\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n    return timesteps, num_inference_steps\r\n\r\n\r\nclass StableDiffusionXLTilingPipeline(\r\n    DiffusionPipeline,\r\n    StableDiffusionMixin,\r\n    FromSingleFileMixin,\r\n    StableDiffusionXLLoraLoaderMixin,\r\n    TextualInversionLoaderMixin,\r\n):\r\n    r\"\"\"\r\n    Pipeline for text-to-image generation using Stable Diffusion XL.\r\n\r\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\r\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\r\n\r\n    The pipeline also inherits the following loading methods:\r\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\r\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\r\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\r\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\r\n\r\n    Args:\r\n        vae ([`AutoencoderKL`]):\r\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\r\n        text_encoder ([`CLIPTextModel`]):\r\n            Frozen text-encoder. Stable Diffusion XL uses the text portion of\r\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\r\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\r\n        text_encoder_2 ([` CLIPTextModelWithProjection`]):\r\n            Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of\r\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\r\n            specifically the\r\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\r\n            variant.\r\n        tokenizer (`CLIPTokenizer`):\r\n            Tokenizer of class\r\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\r\n        tokenizer_2 (`CLIPTokenizer`):\r\n            Second Tokenizer of class\r\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\r\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\r\n        scheduler ([`SchedulerMixin`]):\r\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\r\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\r\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `\"True\"`):\r\n            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of\r\n            `stabilityai/stable-diffusion-xl-base-1-0`.\r\n        add_watermarker (`bool`, *optional*):\r\n            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to\r\n            watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no\r\n            watermarker will be used.\r\n    \"\"\"\r\n\r\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->image_encoder->unet->vae\"\r\n    _optional_components = [\r\n        \"tokenizer\",\r\n        \"tokenizer_2\",\r\n        \"text_encoder\",\r\n        \"text_encoder_2\",\r\n    ]\r\n\r\n    def __init__(\r\n        self,\r\n        vae: AutoencoderKL,\r\n        text_encoder: CLIPTextModel,\r\n        text_encoder_2: CLIPTextModelWithProjection,\r\n        tokenizer: CLIPTokenizer,\r\n        tokenizer_2: CLIPTokenizer,\r\n        unet: UNet2DConditionModel,\r\n        scheduler: KarrasDiffusionSchedulers,\r\n        force_zeros_for_empty_prompt: bool = True,\r\n        add_watermarker: Optional[bool] = None,\r\n    ):\r\n        super().__init__()\r\n\r\n        self.register_modules(\r\n            vae=vae,\r\n            text_encoder=text_encoder,\r\n            text_encoder_2=text_encoder_2,\r\n            tokenizer=tokenizer,\r\n            tokenizer_2=tokenizer_2,\r\n            unet=unet,\r\n            scheduler=scheduler,\r\n        )\r\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\r\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\r\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\r\n\r\n        self.default_sample_size = (\r\n            self.unet.config.sample_size\r\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\r\n            else 128\r\n        )\r\n\r\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\r\n\r\n        if add_watermarker:\r\n            self.watermark = StableDiffusionXLWatermarker()\r\n        else:\r\n            self.watermark = None\r\n\r\n    
class SeedTilesMode(Enum):\r\n        \"\"\"Modes in which the latents of a particular tile can be re-seeded\"\"\"\r\n\r\n        FULL = \"full\"\r\n        EXCLUSIVE = \"exclusive\"\r\n\r\n    def encode_prompt(\r\n        self,\r\n        prompt: str,\r\n        prompt_2: str | None = None,\r\n        device: Optional[torch.device] = None,\r\n        num_images_per_prompt: int = 1,\r\n        do_classifier_free_guidance: bool = True,\r\n        negative_prompt: str | None = None,\r\n        negative_prompt_2: str | None = None,\r\n        prompt_embeds: Optional[torch.Tensor] = None,\r\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\r\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\r\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\r\n        lora_scale: Optional[float] = None,\r\n        clip_skip: Optional[int] = None,\r\n    ):\r\n        r\"\"\"\r\n        Encodes the prompt into text encoder hidden states.\r\n\r\n        Args:\r\n            prompt (`str` or `List[str]`, *optional*):\r\n                prompt to be encoded\r\n            prompt_2 (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\r\n                used in both text-encoders\r\n            device: (`torch.device`):\r\n                torch device\r\n            num_images_per_prompt (`int`):\r\n                number of images that should be generated per prompt\r\n            do_classifier_free_guidance (`bool`):\r\n                whether to use classifier free guidance or not\r\n            negative_prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\r\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\r\n                less than `1`).\r\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\r\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\r\n            prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\r\n                provided, text embeddings will be generated from `prompt` input argument.\r\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\r\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\r\n                argument.\r\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\r\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\r\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\r\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\r\n                input argument.\r\n            lora_scale (`float`, *optional*):\r\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\r\n            clip_skip (`int`, *optional*):\r\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\r\n                the output of the pre-final layer will be used for computing the prompt embeddings.\r\n        \"\"\"\r\n        device = device or self._execution_device\r\n\r\n        # set lora scale so that monkey patched LoRA\r\n        # function of text encoder can correctly access it\r\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\r\n            self._lora_scale = lora_scale\r\n\r\n            # dynamically adjust the LoRA scale\r\n            if self.text_encoder is not None:\r\n                if not USE_PEFT_BACKEND:\r\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\r\n                else:\r\n                    scale_lora_layers(self.text_encoder, lora_scale)\r\n\r\n            if self.text_encoder_2 is not None:\r\n                if not USE_PEFT_BACKEND:\r\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\r\n                else:\r\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\r\n\r\n        prompt = [prompt] if isinstance(prompt, str) else prompt\r\n\r\n        if prompt is not None:\r\n            batch_size = len(prompt)\r\n        else:\r\n            batch_size = prompt_embeds.shape[0]\r\n\r\n        # Define tokenizers and text encoders\r\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\r\n        text_encoders = (\r\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\r\n        )\r\n\r\n        if prompt_embeds is None:\r\n            prompt_2 = prompt_2 or prompt\r\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\r\n\r\n            # textual inversion: process multi-vector tokens if necessary\r\n            prompt_embeds_list = []\r\n            prompts = [prompt, prompt_2]\r\n            for prompt, tokenizer, text_encoder in 
zip(prompts, tokenizers, text_encoders):\r\n                if isinstance(self, TextualInversionLoaderMixin):\r\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\r\n\r\n                text_inputs = tokenizer(\r\n                    prompt,\r\n                    padding=\"max_length\",\r\n                    max_length=tokenizer.model_max_length,\r\n                    truncation=True,\r\n                    return_tensors=\"pt\",\r\n                )\r\n\r\n                text_input_ids = text_inputs.input_ids\r\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\r\n\r\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\r\n                    text_input_ids, untruncated_ids\r\n                ):\r\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\r\n                    logger.warning(\r\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\r\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\r\n                    )\r\n\r\n                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\r\n\r\n                # We are only ALWAYS interested in the pooled output of the final text encoder\r\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\r\n                    pooled_prompt_embeds = prompt_embeds[0]\r\n\r\n                if clip_skip is None:\r\n                    prompt_embeds = prompt_embeds.hidden_states[-2]\r\n                else:\r\n                    # \"2\" because SDXL always indexes from the penultimate layer.\r\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\r\n\r\n                prompt_embeds_list.append(prompt_embeds)\r\n\r\n            prompt_embeds = 
torch.concat(prompt_embeds_list, dim=-1)\r\n\r\n        # get unconditional embeddings for classifier free guidance\r\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\r\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\r\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\r\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\r\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\r\n            negative_prompt = negative_prompt or \"\"\r\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\r\n\r\n            # normalize str to list\r\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\r\n            negative_prompt_2 = (\r\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\r\n            )\r\n\r\n            uncond_tokens: List[str]\r\n            if prompt is not None and type(prompt) is not type(negative_prompt):\r\n                raise TypeError(\r\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\r\n                    f\" {type(prompt)}.\"\r\n                )\r\n            elif batch_size != len(negative_prompt):\r\n                raise ValueError(\r\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\r\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\r\n                    \" the batch size of `prompt`.\"\r\n                )\r\n            else:\r\n                uncond_tokens = [negative_prompt, negative_prompt_2]\r\n\r\n            negative_prompt_embeds_list = []\r\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\r\n                if isinstance(self, TextualInversionLoaderMixin):\r\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\r\n\r\n                max_length = prompt_embeds.shape[1]\r\n                uncond_input = tokenizer(\r\n                    negative_prompt,\r\n                    padding=\"max_length\",\r\n                    max_length=max_length,\r\n                    truncation=True,\r\n                    return_tensors=\"pt\",\r\n                )\r\n\r\n                negative_prompt_embeds = text_encoder(\r\n                    uncond_input.input_ids.to(device),\r\n                    output_hidden_states=True,\r\n                )\r\n\r\n                # We are only ALWAYS interested in the pooled output of the final text encoder\r\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\r\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\r\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\r\n\r\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\r\n\r\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\r\n\r\n        if self.text_encoder_2 is not None:\r\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\r\n        else:\r\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\r\n\r\n        bs_embed, seq_len, _ = prompt_embeds.shape\r\n        # duplicate text embeddings for each generation 
per prompt, using mps friendly method\r\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\r\n\r\n        if do_classifier_free_guidance:\r\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\r\n            seq_len = negative_prompt_embeds.shape[1]\r\n\r\n            if self.text_encoder_2 is not None:\r\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\r\n            else:\r\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\r\n\r\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\r\n\r\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\r\n            bs_embed * num_images_per_prompt, -1\r\n        )\r\n        if do_classifier_free_guidance:\r\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\r\n                bs_embed * num_images_per_prompt, -1\r\n            )\r\n\r\n        if self.text_encoder is not None:\r\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\r\n                # Retrieve the original scale by scaling back the LoRA layers\r\n                unscale_lora_layers(self.text_encoder, lora_scale)\r\n\r\n        if self.text_encoder_2 is not None:\r\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\r\n                # Retrieve the original scale by scaling back the LoRA layers\r\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\r\n\r\n        return prompt_embeds, negative_prompt_embeds, 
pooled_prompt_embeds, negative_pooled_prompt_embeds\r\n\r\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\r\n    def prepare_extra_step_kwargs(self, generator, eta):\r\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\r\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\r\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\r\n        # and should be between [0, 1]\r\n\r\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\r\n        extra_step_kwargs = {}\r\n        if accepts_eta:\r\n            extra_step_kwargs[\"eta\"] = eta\r\n\r\n        # check if the scheduler accepts generator\r\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\r\n        if accepts_generator:\r\n            extra_step_kwargs[\"generator\"] = generator\r\n        return extra_step_kwargs\r\n\r\n    def check_inputs(self, prompt, height, width, grid_cols, seed_tiles_mode, tiles_mode):\r\n        if height % 8 != 0 or width % 8 != 0:\r\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\r\n\r\n        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\r\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\r\n\r\n        if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):\r\n            raise ValueError(f\"`prompt` has to be a list of lists but is {type(prompt)}\")\r\n\r\n        if not all(len(row) == grid_cols for row in prompt):\r\n            raise ValueError(\"All prompt rows must have the same number of prompt columns\")\r\n\r\n        if not isinstance(seed_tiles_mode, str) and (\r\n          
  not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)\r\n        ):\r\n            raise ValueError(f\"`seed_tiles_mode` has to be a string or list of lists but is {type(seed_tiles_mode)}\")\r\n\r\n        if any(mode not in tiles_mode for row in seed_tiles_mode for mode in row):\r\n            raise ValueError(f\"Seed tiles mode must be one of {tiles_mode}\")\r\n\r\n    def _get_add_time_ids(\r\n        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None\r\n    ):\r\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\r\n\r\n        passed_add_embed_dim = (\r\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\r\n        )\r\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\r\n\r\n        if expected_add_embed_dim != passed_add_embed_dim:\r\n            raise ValueError(\r\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. 
Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\r\n            )\r\n\r\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\r\n        return add_time_ids\r\n\r\n    def _gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype):\r\n        \"\"\"Generates a gaussian mask of weights for tile contributions\"\"\"\r\n        import numpy as np\r\n        from numpy import exp, pi, sqrt\r\n\r\n        latent_width = tile_width // 8\r\n        latent_height = tile_height // 8\r\n\r\n        var = 0.01\r\n        midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1\r\n        x_probs = [\r\n            exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)\r\n            for x in range(latent_width)\r\n        ]\r\n        midpoint = latent_height / 2\r\n        y_probs = [\r\n            exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)\r\n            for y in range(latent_height)\r\n        ]\r\n\r\n        weights_np = np.outer(y_probs, x_probs)\r\n        weights_torch = torch.tensor(weights_np, device=device)\r\n        weights_torch = weights_torch.to(dtype)\r\n        return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))\r\n\r\n    def upcast_vae(self):\r\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\r\n        self.vae.to(dtype=torch.float32)\r\n\r\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\r\n    def get_guidance_scale_embedding(\r\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\r\n    ) -> torch.Tensor:\r\n        \"\"\"\r\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\r\n\r\n        Args:\r\n            w (`torch.Tensor`):\r\n                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.\r\n            embedding_dim (`int`, *optional*, defaults to 512):\r\n                Dimension of the embeddings to generate.\r\n            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):\r\n                Data type of the generated embeddings.\r\n\r\n        Returns:\r\n            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\r\n        \"\"\"\r\n        assert len(w.shape) == 1\r\n        w = w * 1000.0\r\n\r\n        half_dim = embedding_dim // 2\r\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\r\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\r\n        emb = w.to(dtype)[:, None] * emb[None, :]\r\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\r\n        if embedding_dim % 2 == 1:  # zero pad\r\n            emb = torch.nn.functional.pad(emb, (0, 1))\r\n        assert emb.shape == (w.shape[0], embedding_dim)\r\n        return emb\r\n\r\n    @property\r\n    def guidance_scale(self):\r\n        return self._guidance_scale\r\n\r\n    @property\r\n    def clip_skip(self):\r\n        return self._clip_skip\r\n\r\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\r\n    # of the Imagen paper: 
https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\r\n    # corresponds to doing no classifier free guidance.\r\n    @property\r\n    def do_classifier_free_guidance(self):\r\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\r\n\r\n    @property\r\n    def cross_attention_kwargs(self):\r\n        return self._cross_attention_kwargs\r\n\r\n    @property\r\n    def num_timesteps(self):\r\n        return self._num_timesteps\r\n\r\n    @property\r\n    def interrupt(self):\r\n        return self._interrupt\r\n\r\n    @torch.no_grad()\r\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\r\n    def __call__(\r\n        self,\r\n        prompt: Union[str, List[str]] = None,\r\n        height: Optional[int] = None,\r\n        width: Optional[int] = None,\r\n        num_inference_steps: int = 50,\r\n        guidance_scale: float = 5.0,\r\n        negative_prompt: Optional[Union[str, List[str]]] = None,\r\n        num_images_per_prompt: Optional[int] = 1,\r\n        eta: float = 0.0,\r\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\r\n        output_type: str | None = \"pil\",\r\n        return_dict: bool = True,\r\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\r\n        original_size: Optional[Tuple[int, int]] = None,\r\n        crops_coords_top_left: Optional[List[List[Tuple[int, int]]]] = None,\r\n        target_size: Optional[Tuple[int, int]] = None,\r\n        negative_original_size: Optional[Tuple[int, int]] = None,\r\n        negative_crops_coords_top_left: Optional[List[List[Tuple[int, int]]]] = None,\r\n        negative_target_size: Optional[Tuple[int, int]] = None,\r\n        clip_skip: Optional[int] = None,\r\n        tile_height: Optional[int] = 1024,\r\n        tile_width: Optional[int] = 1024,\r\n        tile_row_overlap: Optional[int] = 128,\r\n        tile_col_overlap: Optional[int] = 128,\r\n        guidance_scale_tiles: Optional[List[List[float]]] = 
None,\r\n        seed_tiles: Optional[List[List[int]]] = None,\r\n        seed_tiles_mode: Optional[Union[str, List[List[str]]]] = \"full\",\r\n        seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,\r\n        **kwargs,\r\n    ):\r\n        r\"\"\"\r\n        Function invoked when calling the pipeline for generation.\r\n\r\n        Args:\r\n            prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\r\n                instead.\r\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\r\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\r\n                Anything below 512 pixels won't work well for\r\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\r\n                and checkpoints that are not specifically fine-tuned on low resolutions.\r\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\r\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\r\n                Anything below 512 pixels won't work well for\r\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\r\n                and checkpoints that are not specifically fine-tuned on low resolutions.\r\n            num_inference_steps (`int`, *optional*, defaults to 50):\r\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\r\n                expense of slower inference.\r\n            guidance_scale (`float`, *optional*, defaults to 5.0):\r\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\r\n                `guidance_scale` is defined as `w` of equation 2 of the [Imagen\r\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\r\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\r\n                usually at the expense of lower image quality.\r\n            negative_prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\r\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\r\n                not greater than `1`).\r\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\r\n                The number of images to generate per prompt.\r\n            eta (`float`, *optional*, defaults to 0.0):\r\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\r\n                [`schedulers.DDIMScheduler`], will be ignored for others.\r\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\r\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\r\n                to make generation deterministic.\r\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\r\n                The output format of the generated image. 
Choose between\r\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\r\n            return_dict (`bool`, *optional*, defaults to `True`):\r\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\r\n                of a plain tuple.\r\n            cross_attention_kwargs (`dict`, *optional*):\r\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\r\n                `self.processor` in\r\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\r\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\r\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\r\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\r\n                explained in section 2.2 of\r\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\r\n            crops_coords_top_left (`List[List[Tuple[int, int]]]`, *optional*, defaults to (0, 0)):\r\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\r\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\r\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\r\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\r\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\r\n                For most cases, `target_size` should be set to the desired height and width of the generated image. 
If\r\n                not specified, it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\r\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\r\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\r\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\r\n                micro-conditioning as explained in section 2.2 of\r\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\r\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\r\n            negative_crops_coords_top_left (`List[List[Tuple[int, int]]]`, *optional*, defaults to (0, 0)):\r\n                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's\r\n                micro-conditioning as explained in section 2.2 of\r\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\r\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\r\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\r\n                To negatively condition the generation process based on a target image resolution. It should be the same\r\n                as the `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\r\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
For more\r\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\r\n            tile_height (`int`, *optional*, defaults to 1024):\r\n                Height of each grid tile in pixels.\r\n            tile_width (`int`, *optional*, defaults to 1024):\r\n                Width of each grid tile in pixels.\r\n            tile_row_overlap (`int`, *optional*, defaults to 128):\r\n                Number of overlapping pixels between tiles in consecutive rows.\r\n            tile_col_overlap (`int`, *optional*, defaults to 128):\r\n                Number of overlapping pixels between tiles in consecutive columns.\r\n            guidance_scale_tiles (`List[List[float]]`, *optional*):\r\n                Specific weights for classifier-free guidance in each tile. If `None`, the value provided in `guidance_scale` will be used.\r\n            seed_tiles (`List[List[int]]`, *optional*):\r\n                Specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard `generator` parameter.\r\n            seed_tiles_mode (`Union[str, List[List[str]]]`, *optional*, defaults to `\"full\"`):\r\n                Mode for seeding tiles, can be `\"full\"` or `\"exclusive\"`. If `\"full\"`, all the latents affected by the tile will be overridden. If `\"exclusive\"`, only the latents that are exclusively affected by this tile (and no other tiles) will be overridden.\r\n            seed_reroll_regions (`List[Tuple[int, int, int, int, int]]`, *optional*):\r\n                A list of tuples in the form of `(start_row, end_row, start_column, end_column, seed)` defining regions in pixel space for which the latents will be overridden using the given seed. 
Takes priority over `seed_tiles`.\r\n            **kwargs (`Dict[str, Any]`, *optional*):\r\n                 Additional optional keyword arguments to be passed to the `unet.__call__` and `scheduler.step` functions.\r\n\r\n        Examples:\r\n\r\n        Returns:\r\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLTilingPipelineOutput`] or `tuple`:\r\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLTilingPipelineOutput`] if `return_dict` is True, otherwise a\r\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\r\n        \"\"\"\r\n\r\n        # 0. Default height and width to unet\r\n        height = height or self.default_sample_size * self.vae_scale_factor\r\n        width = width or self.default_sample_size * self.vae_scale_factor\r\n\r\n        original_size = original_size or (height, width)\r\n        target_size = target_size or (height, width)\r\n        negative_original_size = negative_original_size or (height, width)\r\n        negative_target_size = negative_target_size or (height, width)\r\n\r\n        self._guidance_scale = guidance_scale\r\n        self._clip_skip = clip_skip\r\n        self._cross_attention_kwargs = cross_attention_kwargs\r\n        self._interrupt = False\r\n\r\n        grid_rows = len(prompt)\r\n        grid_cols = len(prompt[0])\r\n\r\n        tiles_mode = [mode.value for mode in self.SeedTilesMode]\r\n\r\n        if isinstance(seed_tiles_mode, str):\r\n            seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]\r\n\r\n        # 1. Check inputs. 
Raise error if not correct\r\n        self.check_inputs(\r\n            prompt,\r\n            height,\r\n            width,\r\n            grid_cols,\r\n            seed_tiles_mode,\r\n            tiles_mode,\r\n        )\r\n\r\n        if seed_reroll_regions is None:\r\n            seed_reroll_regions = []\r\n\r\n        batch_size = 1\r\n\r\n        device = self._execution_device\r\n\r\n        # update crops coords list\r\n        crops_coords_top_left = _get_crops_coords_list(grid_rows, grid_cols, tile_width)\r\n        if negative_original_size is not None and negative_target_size is not None:\r\n            negative_crops_coords_top_left = _get_crops_coords_list(grid_rows, grid_cols, tile_width)\r\n\r\n        # update the full canvas height and width from the tile size and tile overlap size\r\n        height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)\r\n        width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)\r\n\r\n        # 2. Encode input prompt\r\n        lora_scale = (\r\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\r\n        )\r\n        text_embeddings = [\r\n            [\r\n                self.encode_prompt(\r\n                    prompt=col,\r\n                    device=device,\r\n                    num_images_per_prompt=num_images_per_prompt,\r\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\r\n                    negative_prompt=negative_prompt,\r\n                    prompt_embeds=None,\r\n                    negative_prompt_embeds=None,\r\n                    pooled_prompt_embeds=None,\r\n                    negative_pooled_prompt_embeds=None,\r\n                    lora_scale=lora_scale,\r\n                    clip_skip=self.clip_skip,\r\n                )\r\n                for col in row\r\n            ]\r\n            for row in prompt\r\n        ]\r\n\r\n        # 3. 
Prepare latents\r\n        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)\r\n        dtype = text_embeddings[0][0][0].dtype\r\n        latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=dtype)\r\n\r\n        # 3.1 overwrite latents for specific tiles if provided\r\n        if seed_tiles is not None:\r\n            for row in range(grid_rows):\r\n                for col in range(grid_cols):\r\n                    if (seed_tile := seed_tiles[row][col]) is not None:\r\n                        mode = seed_tiles_mode[row][col]\r\n                        if mode == self.SeedTilesMode.FULL.value:\r\n                            row_init, row_end, col_init, col_end = _tile2latent_indices(\r\n                                row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\r\n                            )\r\n                        else:\r\n                            row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(\r\n                                row,\r\n                                col,\r\n                                tile_width,\r\n                                tile_height,\r\n                                tile_row_overlap,\r\n                                tile_col_overlap,\r\n                                grid_rows,\r\n                                grid_cols,\r\n                            )\r\n                        tile_generator = torch.Generator(device).manual_seed(seed_tile)\r\n                        tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)\r\n                        latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(\r\n                            tile_shape, generator=tile_generator, device=device\r\n                        )\r\n\r\n        # 3.2 overwrite again for seed reroll regions\r\n        for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions:\r\n 
           row_init, row_end, col_init, col_end = _pixel2latent_indices(\r\n                row_init, row_end, col_init, col_end\r\n            )  # to latent space coordinates\r\n            reroll_generator = torch.Generator(device).manual_seed(seed_reroll)\r\n            region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)\r\n            latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(\r\n                region_shape, generator=reroll_generator, device=device\r\n            )\r\n\r\n        # 4. Prepare timesteps\r\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\r\n        extra_set_kwargs = {}\r\n        if accepts_offset:\r\n            extra_set_kwargs[\"offset\"] = 1\r\n        timesteps, num_inference_steps = retrieve_timesteps(\r\n            self.scheduler, num_inference_steps, device, None, None, **extra_set_kwargs\r\n        )\r\n\r\n        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas\r\n        if isinstance(self.scheduler, LMSDiscreteScheduler):\r\n            latents = latents * self.scheduler.sigmas[0]\r\n\r\n        # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\r\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\r\n\r\n        # 6. 
Prepare added time ids & embeddings\r\n        # text_embeddings order: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\r\n        embeddings_and_added_time = []\r\n        for row in range(grid_rows):\r\n            addition_embed_type_row = []\r\n            for col in range(grid_cols):\r\n                # extract generated values\r\n                prompt_embeds = text_embeddings[row][col][0]\r\n                negative_prompt_embeds = text_embeddings[row][col][1]\r\n                pooled_prompt_embeds = text_embeddings[row][col][2]\r\n                negative_pooled_prompt_embeds = text_embeddings[row][col][3]\r\n\r\n                add_text_embeds = pooled_prompt_embeds\r\n                if self.text_encoder_2 is None:\r\n                    text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\r\n                else:\r\n                    text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\r\n                add_time_ids = self._get_add_time_ids(\r\n                    original_size,\r\n                    crops_coords_top_left[row][col],\r\n                    target_size,\r\n                    dtype=prompt_embeds.dtype,\r\n                    text_encoder_projection_dim=text_encoder_projection_dim,\r\n                )\r\n                if negative_original_size is not None and negative_target_size is not None:\r\n                    negative_add_time_ids = self._get_add_time_ids(\r\n                        negative_original_size,\r\n                        negative_crops_coords_top_left[row][col],\r\n                        negative_target_size,\r\n                        dtype=prompt_embeds.dtype,\r\n                        text_encoder_projection_dim=text_encoder_projection_dim,\r\n                    )\r\n                else:\r\n                    negative_add_time_ids = add_time_ids\r\n\r\n                if self.do_classifier_free_guidance:\r\n                    
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\r\n                    add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\r\n                    add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\r\n\r\n                prompt_embeds = prompt_embeds.to(device)\r\n                add_text_embeds = add_text_embeds.to(device)\r\n                add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\r\n                addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))\r\n            embeddings_and_added_time.append(addition_embed_type_row)\r\n\r\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\r\n\r\n        # 7. Mask for tile weights strength\r\n        tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size, device, torch.float32)\r\n\r\n        # 8. Denoising loop\r\n        self._num_timesteps = len(timesteps)\r\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\r\n            for i, t in enumerate(timesteps):\r\n                # Diffuse each tile\r\n                noise_preds = []\r\n                for row in range(grid_rows):\r\n                    noise_preds_row = []\r\n                    for col in range(grid_cols):\r\n                        if self.interrupt:\r\n                            continue\r\n                        px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(\r\n                            row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\r\n                        )\r\n                        tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]\r\n                        # expand the latents if we are doing classifier free guidance\r\n                        latent_model_input = (\r\n                            torch.cat([tile_latents] * 2) if 
self.do_classifier_free_guidance else tile_latents\r\n                        )\r\n                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\r\n\r\n                        # predict the noise residual\r\n                        added_cond_kwargs = {\r\n                            \"text_embeds\": embeddings_and_added_time[row][col][1],\r\n                            \"time_ids\": embeddings_and_added_time[row][col][2],\r\n                        }\r\n                        with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):\r\n                            noise_pred = self.unet(\r\n                                latent_model_input,\r\n                                t,\r\n                                encoder_hidden_states=embeddings_and_added_time[row][col][0],\r\n                                cross_attention_kwargs=self.cross_attention_kwargs,\r\n                                added_cond_kwargs=added_cond_kwargs,\r\n                                return_dict=False,\r\n                            )[0]\r\n\r\n                        # perform guidance\r\n                        if self.do_classifier_free_guidance:\r\n                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\r\n                            guidance = (\r\n                                guidance_scale\r\n                                if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None\r\n                                else guidance_scale_tiles[row][col]\r\n                            )\r\n                            noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)\r\n                        else:\r\n                            # without guidance the UNet output is already the tile prediction\r\n                            noise_pred_tile = noise_pred\r\n                        noise_preds_row.append(noise_pred_tile)\r\n                    noise_preds.append(noise_preds_row)\r\n\r\n                # Stitch noise predictions for all tiles\r\n                noise_pred = torch.zeros(latents.shape, device=device)\r\n 
               contributors = torch.zeros(latents.shape, device=device)\r\n\r\n                # Add each tile contribution to overall latents\r\n                for row in range(grid_rows):\r\n                    for col in range(grid_cols):\r\n                        px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(\r\n                            row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap\r\n                        )\r\n                        noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (\r\n                            noise_preds[row][col] * tile_weights\r\n                        )\r\n                        contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights\r\n\r\n                # Average overlapping areas with more than 1 contributor\r\n                noise_pred /= contributors\r\n                noise_pred = noise_pred.to(dtype)\r\n\r\n                # compute the previous noisy sample x_t -> x_t-1\r\n                latents_dtype = latents.dtype\r\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\r\n                if latents.dtype != latents_dtype:\r\n                    if torch.backends.mps.is_available():\r\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\r\n                        latents = latents.to(latents_dtype)\r\n\r\n                # update progress bar\r\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\r\n                    progress_bar.update()\r\n\r\n                if XLA_AVAILABLE:\r\n                    xm.mark_step()\r\n\r\n        if not output_type == \"latent\":\r\n            # make sure the VAE is in float32 mode, as it overflows in float16\r\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\r\n\r\n            if needs_upcasting:\r\n                self.upcast_vae()\r\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\r\n            elif latents.dtype != self.vae.dtype:\r\n                if torch.backends.mps.is_available():\r\n                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\r\n                    self.vae = self.vae.to(latents.dtype)\r\n\r\n            # unscale/denormalize the latents\r\n            # denormalize with the mean and std if available and not None\r\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\r\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\r\n            if has_latents_mean and has_latents_std:\r\n                latents_mean = (\r\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\r\n                )\r\n                latents_std = (\r\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\r\n                )\r\n                latents = latents * latents_std / 
self.vae.config.scaling_factor + latents_mean\r\n            else:\r\n                latents = latents / self.vae.config.scaling_factor\r\n\r\n            image = self.vae.decode(latents, return_dict=False)[0]\r\n\r\n            # cast back to fp16 if needed\r\n            if needs_upcasting:\r\n                self.vae.to(dtype=torch.float16)\r\n        else:\r\n            image = latents\r\n\r\n        if not output_type == \"latent\":\r\n            # apply watermark if available\r\n            if self.watermark is not None:\r\n                image = self.watermark.apply_watermark(image)\r\n\r\n            image = self.image_processor.postprocess(image, output_type=output_type)\r\n\r\n        # Offload all models\r\n        self.maybe_free_model_hooks()\r\n\r\n        if not return_dict:\r\n            return (image,)\r\n\r\n        return StableDiffusionXLPipelineOutput(images=image)\r\n"
  },
  {
    "path": "examples/community/mod_controlnet_tile_sr_sdxl.py",
    "content": "# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom transformers import (\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n)\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import (\n    AutoencoderKL,\n    ControlNetModel,\n    ControlNetUnionModel,\n    MultiControlNetModel,\n    UNet2DConditionModel,\n)\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.import_utils import is_invisible_watermark_available\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nif 
is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\n\nfrom diffusers.utils import is_torch_xla_available\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        import torch\n        from diffusers import DiffusionPipeline, ControlNetUnionModel, AutoencoderKL, UniPCMultistepScheduler\n        from diffusers.utils import load_image\n        from PIL import Image\n\n        device = \"cuda\"\n\n        # Initialize the models and pipeline\n        controlnet = ControlNetUnionModel.from_pretrained(\n            \"brad-twinkl/controlnet-union-sdxl-1.0-promax\", torch_dtype=torch.float16\n        ).to(device=device)\n        vae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16).to(device=device)\n\n        model_id = \"SG161222/RealVisXL_V5.0\"\n        pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(\n            model_id, controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant=\"fp16\"\n        ).to(device)\n\n        pipe.enable_model_cpu_offload()  # << Enable this if you have limited VRAM\n        pipe.enable_vae_tiling() # << Enable this if you have limited VRAM\n        pipe.enable_vae_slicing() # << Enable this if you have limited VRAM\n\n        # Set selected scheduler\n        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\n        # Load image\n        control_image = load_image(\"https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1.jpg\")\n        original_height = control_image.height\n        original_width = control_image.width\n        print(f\"Current resolution: H:{original_height} x W:{original_width}\")\n\n        
# Pre-upscale image for tiling\n        resolution = 4096\n        tile_gaussian_sigma = 0.3\n        max_tile_size = 1024 # or 1280\n\n        current_size = max(control_image.size)\n        scale_factor = max(2, resolution / current_size)\n        new_size = (int(control_image.width * scale_factor), int(control_image.height * scale_factor))\n        image = control_image.resize(new_size, Image.LANCZOS)\n\n        # Update target height and width\n        target_height = image.height\n        target_width = image.width\n        print(f\"Target resolution: H:{target_height} x W:{target_width}\")\n\n        # Calculate overlap size (`calculate_overlap` is a method of the pipeline)\n        normal_tile_overlap, border_tile_overlap = pipe.calculate_overlap(target_width, target_height)\n\n        # Set other params (`TileWeightingMethod` is an enum nested in the pipeline class)\n        tile_weighting_method = pipe.TileWeightingMethod.COSINE.value\n        guidance_scale = 4\n        num_inference_steps = 35\n        denoising_strength = 0.65\n        controlnet_strength = 1.0\n        prompt = \"high-quality, noise-free edges, high quality, 4k, hd, 8k\"\n        negative_prompt = \"blurry, pixelated, noisy, low resolution, artifacts, poor details\"\n\n        # Image generation\n        control_image = pipe(\n            image=image,\n            control_image=control_image,\n            control_mode=[6],\n            controlnet_conditioning_scale=float(controlnet_strength),\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            normal_tile_overlap=normal_tile_overlap,\n            border_tile_overlap=border_tile_overlap,\n            height=target_height,\n            width=target_width,\n            original_size=(original_width, original_height),\n            target_size=(target_width, target_height),\n            guidance_scale=guidance_scale,\n            strength=float(denoising_strength),\n            tile_weighting_method=tile_weighting_method,\n            max_tile_size=max_tile_size,\n            tile_gaussian_sigma=float(tile_gaussian_sigma),\n    
        num_inference_steps=num_inference_steps,\n        )[\"images\"][0]\n        ```\n\"\"\"\n\n\n# This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.\ndef _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):\n    \"\"\"\n    Calculate the adaptive tile size based on the image dimensions, ensuring the tile\n    respects the aspect ratio and stays within the specified size limits.\n    \"\"\"\n    width, height = image_size\n    aspect_ratio = width / height\n\n    if aspect_ratio > 1:\n        # Landscape orientation\n        tile_width = min(width, max_tile_size)\n        tile_height = min(int(tile_width / aspect_ratio), max_tile_size)\n    else:\n        # Portrait or square orientation\n        tile_height = min(height, max_tile_size)\n        tile_width = min(int(tile_height * aspect_ratio), max_tile_size)\n\n    # Ensure the tile size is not smaller than the base_tile_size\n    tile_width = max(tile_width, base_tile_size)\n    tile_height = max(tile_height, base_tile_size)\n\n    return tile_width, tile_height\n\n\n# Copied and adapted from https://github.com/huggingface/diffusers/blob/main/examples/community/mixture_tiling.py\ndef _tile2pixel_indices(\n    tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height\n):\n    \"\"\"Given a tile's row and column numbers, returns the range of pixels affected by that tile in the overall image\n\n    Returns a tuple with:\n        - Starting coordinates of rows in pixel space\n        - Ending coordinates of rows in pixel space\n        - Starting coordinates of columns in pixel space\n        - Ending coordinates of columns in pixel space\n    \"\"\"\n    # Calculate initial indices\n    px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)\n    px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)\n\n    # Calculate 
end indices\n    px_row_end = px_row_init + tile_height\n    px_col_end = px_col_init + tile_width\n\n    # Ensure the last tile does not exceed the image dimensions\n    px_row_end = min(px_row_end, image_height)\n    px_col_end = min(px_col_end, image_width)\n\n    return px_row_init, px_row_end, px_col_init, px_col_end\n\n\n# Copied and adapted from https://github.com/huggingface/diffusers/blob/main/examples/community/mixture_tiling.py\ndef _tile2latent_indices(\n    tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height\n):\n    \"\"\"Given a tile's row and column numbers, returns the range of latents affected by that tile in the overall image\n\n    Returns a tuple with:\n        - Starting coordinates of rows in latent space\n        - Ending coordinates of rows in latent space\n        - Starting coordinates of columns in latent space\n        - Ending coordinates of columns in latent space\n    \"\"\"\n    # Get pixel indices\n    px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(\n        tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, image_width, image_height\n    )\n\n    # Convert to latent space\n    latent_row_init = px_row_init // 8\n    latent_row_end = px_row_end // 8\n    latent_col_init = px_col_init // 8\n    latent_col_end = px_col_end // 8\n    latent_height = image_height // 8\n    latent_width = image_width // 8\n\n    # Ensure the last tile does not exceed the latent dimensions\n    latent_row_end = min(latent_row_end, latent_height)\n    latent_col_end = min(latent_col_end, latent_width)\n\n    return latent_row_init, latent_row_end, latent_col_init, latent_col_end\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, 
\"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\nclass StableDiffusionXLControlNetTileSRPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([` CLIPTextModelWithProjection`]):\n            Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        controlnet ([`ControlNetUnionModel`]):\n            Provides additional conditioning to the unet during the denoising process.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        requires_aesthetics_score (`bool`, *optional*, defaults to `False`):\n            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the\n            config of `stabilityai/stable-diffusion-xl-refiner-1-0`.\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of\n            `stabilityai/stable-diffusion-xl-base-1-0`.\n        add_watermarker (`bool`, *optional*):\n            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to\n            watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no\n            watermarker will be used.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->unet->vae\"\n    _optional_components = [\n        \"tokenizer\",\n        \"tokenizer_2\",\n        \"text_encoder\",\n        \"text_encoder_2\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: ControlNetUnionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        requires_aesthetics_score: bool = False,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        if not isinstance(controlnet, ControlNetUnionModel):\n            raise ValueError(\"Expected `controlnet` to be of type `ControlNetUnionModel`.\")\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True\n        )\n        add_watermarker = add_watermarker if 
add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)\n\n    def calculate_overlap(self, width, height, base_overlap=128):\n        \"\"\"\n        Calculates dynamic overlap based on the image's aspect ratio.\n\n        Args:\n            width (int): Width of the image in pixels.\n            height (int): Height of the image in pixels.\n            base_overlap (int, optional): Base overlap value in pixels. Defaults to 128.\n\n        Returns:\n            tuple: A tuple containing:\n                - row_overlap (int): Overlap between tiles in consecutive rows.\n                - col_overlap (int): Overlap between tiles in consecutive columns.\n        \"\"\"\n        ratio = height / width\n        if ratio < 1:  # Image is wider than tall\n            return base_overlap // 2, base_overlap\n        else:  # Image is taller than wide\n            return base_overlap, base_overlap * 2\n\n    class TileWeightingMethod(Enum):\n        \"\"\"Mode in which the tile weights will be generated\"\"\"\n\n        COSINE = \"Cosine\"\n        GAUSSIAN = \"Gaussian\"\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n       
 pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder, lora_scale)\n\n            if self.text_encoder_2 is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n        dtype = text_encoders[0].dtype\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n       
         if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n                text_encoder.to(dtype)\n                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\n                    pooled_prompt_embeds = prompt_embeds[0]\n\n                if clip_skip is None:\n                    prompt_embeds = prompt_embeds.hidden_states[-2]\n                else:\n                    # \"2\" because SDXL always indexes from the penultimate layer.\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier 
free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        if self.text_encoder_2 is not None:\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        else:\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            if self.text_encoder_2 is not None:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            else:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        image,\n        strength,\n        num_inference_steps,\n        normal_tile_overlap,\n        border_tile_overlap,\n        max_tile_size,\n        tile_gaussian_sigma,\n        tile_weighting_method,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n        if num_inference_steps is None:\n            raise 
ValueError(\"`num_inference_steps` cannot be None.\")\n        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:\n            raise ValueError(\n                f\"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type\"\n                f\" {type(num_inference_steps)}.\"\n            )\n        if normal_tile_overlap is None:\n            raise ValueError(\"`normal_tile_overlap` cannot be None.\")\n        elif not isinstance(normal_tile_overlap, int) or normal_tile_overlap < 64:\n            raise ValueError(\n                f\"`normal_tile_overlap` has to be an integer of at least 64 but is {normal_tile_overlap} of type\"\n                f\" {type(normal_tile_overlap)}.\"\n            )\n        if border_tile_overlap is None:\n            raise ValueError(\"`border_tile_overlap` cannot be None.\")\n        elif not isinstance(border_tile_overlap, int) or border_tile_overlap < 128:\n            raise ValueError(\n                f\"`border_tile_overlap` has to be an integer of at least 128 but is {border_tile_overlap} of type\"\n                f\" {type(border_tile_overlap)}.\"\n            )\n        if max_tile_size is None:\n            raise ValueError(\"`max_tile_size` cannot be None.\")\n        elif not isinstance(max_tile_size, int) or max_tile_size not in (1024, 1280):\n            raise ValueError(\n                f\"`max_tile_size` has to be either 1024 or 1280 but is {max_tile_size} of type {type(max_tile_size)}.\"\n            )\n        if tile_gaussian_sigma is None:\n            raise ValueError(\"`tile_gaussian_sigma` cannot be None.\")\n        elif not isinstance(tile_gaussian_sigma, float) or tile_gaussian_sigma <= 0:\n            raise ValueError(\n                f\"`tile_gaussian_sigma` has to be a positive float but is {tile_gaussian_sigma} of type\"\n                f\" {type(tile_gaussian_sigma)}.\"\n            )\n        if tile_weighting_method is None:\n            raise 
ValueError(\"`tile_weighting_method` cannot be None.\")\n        elif not isinstance(tile_weighting_method, str) or tile_weighting_method not in [\n            t.value for t in self.TileWeightingMethod\n        ]:\n            raise ValueError(\n                f\"`tile_weighting_method` has to be a string in ({[t.value for t in self.TileWeightingMethod]}) but is {tile_weighting_method} of type\"\n                f\" {type(tile_weighting_method)}.\"\n            )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(image, prompt)\n        elif (\n            isinstance(self.controlnet, ControlNetUnionModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)\n        ):\n            self.check_image(image, prompt)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetUnionModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)\n        ) or (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if 
isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n                )\n        else:\n            assert False\n\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. 
Make sure to provide the same number of elements to each list.\"\n            )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image\n    def check_image(self, image, prompt):\n        image_is_pil = isinstance(image, Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            
prompt_batch_size = len(prompt)\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, \"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, 
num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents\n    def prepare_latents(\n        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True\n    ):\n        if not isinstance(image, (torch.Tensor, Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        latents_mean = latents_std = None\n        if hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None:\n            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)\n        if hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None:\n            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)\n\n        # Offload text encoder if `enable_model_cpu_offload` was enabled\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.text_encoder_2.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            if self.vae.config.force_upcast:\n                image = image.float()\n                self.vae.to(dtype=torch.float32)\n\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:\n                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)\n                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:\n                    raise ValueError(\n                        f\"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} \"\n                    )\n\n                init_latents = [\n                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                    for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n            if self.vae.config.force_upcast:\n                self.vae.to(dtype)\n\n            init_latents = init_latents.to(dtype)\n            if latents_mean is not None and latents_std is not None:\n                latents_mean = latents_mean.to(device=device, dtype=dtype)\n                latents_std = latents_std.to(device=device, dtype=dtype)\n                init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std\n            else:\n                init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate 
`image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        if add_noise:\n            shape = init_latents.shape\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # get latents\n            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n\n        latents = init_latents\n\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids\n    def _get_add_time_ids(\n        self,\n        original_size,\n        crops_coords_top_left,\n        target_size,\n        aesthetic_score,\n        negative_aesthetic_score,\n        negative_original_size,\n        negative_crops_coords_top_left,\n        negative_target_size,\n        dtype,\n        text_encoder_projection_dim=None,\n    ):\n        if self.config.requires_aesthetics_score:\n            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))\n            add_neg_time_ids = list(\n                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)\n            )\n        else:\n            add_time_ids = list(original_size + crops_coords_top_left + target_size)\n            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if (\n            expected_add_embed_dim > passed_add_embed_dim\n            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects 
an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.\"\n            )\n        elif (\n            expected_add_embed_dim < passed_add_embed_dim\n            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.\"\n            )\n        elif expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. 
Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)\n\n        return add_time_ids, add_neg_time_ids\n\n    def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):\n        \"\"\"\n        Generates cosine weights as a PyTorch tensor for blending tiles.\n\n        Args:\n            tile_width (int): Width of the tile in pixels.\n            tile_height (int): Height of the tile in pixels.\n            nbatches (int): Number of batches.\n            device (torch.device): Device where the tensor will be allocated (e.g., 'cuda' or 'cpu').\n            dtype (torch.dtype): Data type of the tensor (e.g., torch.float32).\n\n        Returns:\n            torch.Tensor: A tensor containing cosine weights for blending tiles, expanded to match batch and channel dimensions.\n        \"\"\"\n        # Convert tile dimensions to latent space\n        latent_width = tile_width // 8\n        latent_height = tile_height // 8\n\n        # Generate x and y coordinates in latent space\n        x = np.arange(0, latent_width)\n        y = np.arange(0, latent_height)\n\n        # Calculate midpoints\n        midpoint_x = (latent_width - 1) / 2\n        midpoint_y = (latent_height - 1) / 2\n\n        # Compute cosine probabilities for x and y\n        x_probs = np.cos(np.pi * (x - midpoint_x) / latent_width)\n        y_probs = np.cos(np.pi * (y - midpoint_y) / latent_height)\n\n        # Create a 2D weight matrix using the outer product\n        weights_np = np.outer(y_probs, x_probs)\n\n        # Convert to a PyTorch tensor with the correct device and dtype\n        weights_torch = torch.tensor(weights_np, device=device, dtype=dtype)\n\n        # Expand for batch and channel dimensions\n        tile_weights_expanded = torch.tile(weights_torch, (nbatches, 
self.unet.config.in_channels, 1, 1))\n\n        return tile_weights_expanded\n\n    def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.05):\n        \"\"\"\n        Generates Gaussian weights as a PyTorch tensor for blending tiles in latent space.\n\n        Args:\n            tile_width (int): Width of the tile in pixels.\n            tile_height (int): Height of the tile in pixels.\n            nbatches (int): Number of batches.\n            device (torch.device): Device where the tensor will be allocated (e.g., 'cuda' or 'cpu').\n            dtype (torch.dtype): Data type of the tensor (e.g., torch.float32).\n            sigma (float, optional): Standard deviation of the Gaussian distribution. Controls the smoothness of the weights. Defaults to 0.05.\n\n        Returns:\n            torch.Tensor: A tensor containing Gaussian weights for blending tiles, expanded to match batch and channel dimensions.\n        \"\"\"\n        # Convert tile dimensions to latent space\n        latent_width = tile_width // 8\n        latent_height = tile_height // 8\n\n        # Generate Gaussian weights in latent space\n        x = np.linspace(-1, 1, latent_width)\n        y = np.linspace(-1, 1, latent_height)\n        xx, yy = np.meshgrid(x, y)\n        gaussian_weight = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))\n\n        # Convert to a PyTorch tensor with the correct device and dtype\n        weights_torch = torch.tensor(gaussian_weight, device=device, dtype=dtype)\n\n        # Expand for batch and channel dimensions\n        weights_expanded = weights_torch.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions\n        weights_expanded = weights_expanded.expand(nbatches, -1, -1, -1)  # Expand to the number of batches\n\n        return weights_expanded\n\n    def _get_num_tiles(self, height, width, tile_height, tile_width, normal_tile_overlap, border_tile_overlap):\n        \"\"\"\n        Calculates the number of tiles needed 
to cover an image, choosing the appropriate formula based on the\n        ratio between the image size and the tile size.\n\n        This function automatically selects between two formulas:\n        1. A universal formula for typical cases (image-to-tile ratio <= 6:1).\n        2. A specialized formula with border tile overlap for larger or atypical cases (image-to-tile ratio > 6:1).\n\n        Args:\n            height (int): Height of the image in pixels.\n            width (int): Width of the image in pixels.\n            tile_height (int): Height of each tile in pixels.\n            tile_width (int): Width of each tile in pixels.\n            normal_tile_overlap (int): Overlap between tiles in pixels for normal (non-border) tiles.\n            border_tile_overlap (int): Overlap between tiles in pixels for border tiles.\n\n        Returns:\n            tuple: A tuple containing:\n                - grid_rows (int): Number of rows in the tile grid.\n                - grid_cols (int): Number of columns in the tile grid.\n\n        Notes:\n            - The function uses the universal formula (without border_tile_overlap) for typical cases where the\n            image-to-tile ratio is 6:1 or smaller.\n            - For larger or atypical cases (image-to-tile ratio > 6:1), it uses a specialized formula that includes\n            border_tile_overlap to ensure complete coverage of the image, especially at the edges.\n        \"\"\"\n        # Calculate the ratio between the image size and the tile size\n        height_ratio = height / tile_height\n        width_ratio = width / tile_width\n\n        # If the ratio is greater than 6:1, use the formula with border_tile_overlap\n        if height_ratio > 6 or width_ratio > 6:\n            grid_rows = int(np.ceil((height - border_tile_overlap) / (tile_height - normal_tile_overlap))) + 1\n            grid_cols = int(np.ceil((width - border_tile_overlap) / (tile_width - normal_tile_overlap))) + 1\n        else:\n            
# Otherwise, use the universal formula\n            grid_rows = int(np.ceil((height - normal_tile_overlap) / (tile_height - normal_tile_overlap)))\n            grid_cols = int(np.ceil((width - normal_tile_overlap) / (tile_width - normal_tile_overlap)))\n\n        return grid_rows, grid_cols\n\n    def prepare_tiles(\n        self,\n        grid_rows,\n        grid_cols,\n        tile_weighting_method,\n        tile_width,\n        tile_height,\n        normal_tile_overlap,\n        border_tile_overlap,\n        width,\n        height,\n        tile_sigma,\n        batch_size,\n        device,\n        dtype,\n    ):\n        \"\"\"\n        Processes image tiles by dynamically adjusting overlap and calculating Gaussian or cosine weights.\n\n        Args:\n            grid_rows (int): Number of rows in the tile grid.\n            grid_cols (int): Number of columns in the tile grid.\n            tile_weighting_method (str): Method for weighting tiles. Options: \"Gaussian\" or \"Cosine\".\n            tile_width (int): Width of each tile in pixels.\n            tile_height (int): Height of each tile in pixels.\n            normal_tile_overlap (int): Overlap between tiles in pixels for normal tiles.\n            border_tile_overlap (int): Overlap between tiles in pixels for border tiles.\n            width (int): Width of the image in pixels.\n            height (int): Height of the image in pixels.\n            tile_sigma (float): Sigma parameter for Gaussian weighting.\n            batch_size (int): Batch size for weight tiles.\n            device (torch.device): Device where tensors will be allocated (e.g., 'cuda' or 'cpu').\n            dtype (torch.dtype): Data type of the tensors (e.g., torch.float32).\n\n        Returns:\n            tuple: A tuple containing:\n                - tile_weights (np.ndarray): Array of weights for each tile.\n                - tile_row_overlaps (np.ndarray): Array of row overlaps for each tile.\n                - tile_col_overlaps 
(np.ndarray): Array of column overlaps for each tile.\n        \"\"\"\n\n        # Create arrays to store dynamic overlaps and weights\n        tile_row_overlaps = np.full((grid_rows, grid_cols), normal_tile_overlap)\n        tile_col_overlaps = np.full((grid_rows, grid_cols), normal_tile_overlap)\n        tile_weights = np.empty((grid_rows, grid_cols), dtype=object)  # Stores Gaussian or cosine weights\n\n        # Iterate over tiles to adjust overlap and calculate weights\n        for row in range(grid_rows):\n            for col in range(grid_cols):\n                # Calculate the size of the current tile\n                px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(\n                    row, col, tile_width, tile_height, normal_tile_overlap, normal_tile_overlap, width, height\n                )\n                current_tile_width = px_col_end - px_col_init\n                current_tile_height = px_row_end - px_row_init\n                sigma = tile_sigma\n\n                # Adjust overlap for smaller tiles\n                if current_tile_width < tile_width:\n                    px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(\n                        row, col, tile_width, tile_height, border_tile_overlap, border_tile_overlap, width, height\n                    )\n                    current_tile_width = px_col_end - px_col_init\n                    tile_col_overlaps[row, col] = border_tile_overlap\n                    sigma = tile_sigma * 1.2\n                if current_tile_height < tile_height:\n                    px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(\n                        row, col, tile_width, tile_height, border_tile_overlap, border_tile_overlap, width, height\n                    )\n                    current_tile_height = px_row_end - px_row_init\n                    tile_row_overlaps[row, col] = border_tile_overlap\n                    sigma = tile_sigma * 1.2\n\n  
              # Calculate weights for the current tile\n                if tile_weighting_method == self.TileWeightingMethod.COSINE.value:\n                    tile_weights[row, col] = self._generate_cosine_weights(\n                        tile_width=current_tile_width,\n                        tile_height=current_tile_height,\n                        nbatches=batch_size,\n                        device=device,\n                        dtype=torch.float32,\n                    )\n                else:\n                    tile_weights[row, col] = self._generate_gaussian_weights(\n                        tile_width=current_tile_width,\n                        tile_height=current_tile_height,\n                        nbatches=batch_size,\n                        device=device,\n                        dtype=dtype,\n                        sigma=sigma,\n                    )\n\n        return tile_weights, tile_row_overlaps, tile_col_overlaps\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        control_image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.9999,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        control_mode: Optional[Union[int, List[int]]] = None,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        aesthetic_score: float = 6.0,\n        
negative_aesthetic_score: float = 2.5,\n        clip_skip: Optional[int] = None,\n        normal_tile_overlap: int = 64,\n        border_tile_overlap: int = 128,\n        max_tile_size: int = 1024,\n        tile_gaussian_sigma: float = 0.05,\n        tile_weighting_method: str = \"Cosine\",\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, *optional*):\n                The initial image to be used as the starting point for the image generation process. Can also accept\n                image latents as `image`, if passing latents directly, they will not be encoded again.\n            control_image (`PipelineImageInput`, *optional*):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance for Unet.\n                If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also\n                be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in\n                init, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single ControlNet.\n            height (`int`, *optional*):\n                The height in pixels of the generated image. If not provided, defaults to the height of `control_image`.\n            width (`int`, *optional*):\n                The width in pixels of the generated image. 
If not provided, defaults to the width of `control_image`.\n            strength (`float`, *optional*, defaults to 0.9999):\n                Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point, and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum, and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487).\n                Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages generating\n                images closely linked to the text `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/):\n                `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original UNet. 
If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                In this mode, the ControlNet encoder will try to recognize the content of the input image even if\n                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            control_mode (`int` or `List[int]`, *optional*):\n                The mode of ControlNet guidance. Can be used to specify different behaviors for multiple ControlNets.\n            original_size (`Tuple[int, int]`, *optional*):\n                If `original_size` is not the same as `target_size`, the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning.\n            crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning.\n            target_size (`Tuple[int, int]`, *optional*):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified, it will default to `(height, width)`. 
Part of SDXL's micro-conditioning.\n            negative_original_size (`Tuple[int, int]`, *optional*):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning.\n            negative_crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's\n                micro-conditioning.\n            negative_target_size (`Tuple[int, int]`, *optional*):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning.\n            aesthetic_score (`float`, *optional*, defaults to 6.0):\n                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.\n                Part of SDXL's micro-conditioning.\n            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):\n                Used to simulate an aesthetic score of the generated image by influencing the negative text condition.\n                Part of SDXL's micro-conditioning.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            normal_tile_overlap (`int`, *optional*, defaults to 64):\n                Number of overlapping pixels between tiles in consecutive rows.\n            border_tile_overlap (`int`, *optional*, defaults to 128):\n                Number of overlapping pixels between tiles at the borders.\n            max_tile_size (`int`, *optional*, defaults to 1024):\n                Maximum size of a tile in pixels.\n            tile_gaussian_sigma (`float`, *optional*, defaults to 0.05):\n                Sigma parameter for Gaussian weighting of tiles.\n            tile_weighting_method (`str`, *optional*, defaults to \"Cosine\"):\n                Method for weighting tiles. Options: \"Cosine\" or \"Gaussian\".\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`\n            containing the output images.\n        \"\"\"\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n\n        if not isinstance(control_image, list):\n            control_image = [control_image]\n        else:\n            control_image = control_image.copy()\n\n        if control_mode is None or isinstance(control_mode, list) and len(control_mode) == 0:\n            raise ValueError(\"The value for 
`control_mode` is expected!\")\n\n        if not isinstance(control_mode, list):\n            control_mode = [control_mode]\n\n        if len(control_image) != len(control_mode):\n            raise ValueError(\"Expected len(control_image) == len(control_mode)\")\n\n        num_control_type = controlnet.config.num_control_type\n\n        # 0. Set internal use parameters\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n        negative_original_size = negative_original_size or original_size\n        negative_target_size = negative_target_size or target_size\n        control_type = [0 for _ in range(num_control_type)]\n        control_type = torch.Tensor(control_type)\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n        batch_size = 1\n        device = self._execution_device\n        global_pool_conditions = controlnet.config.global_pool_conditions\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 1. 
Check inputs\n        for _image, control_idx in zip(control_image, control_mode):\n            control_type[control_idx] = 1\n            self.check_inputs(\n                prompt,\n                height,\n                width,\n                _image,\n                strength,\n                num_inference_steps,\n                normal_tile_overlap,\n                border_tile_overlap,\n                max_tile_size,\n                tile_gaussian_sigma,\n                tile_weighting_method,\n                controlnet_conditioning_scale,\n                control_guidance_start,\n                control_guidance_end,\n            )\n\n        # 2 Get tile width and tile height size\n        tile_width, tile_height = _adaptive_tile_size((width, height), max_tile_size=max_tile_size)\n\n        # 2.1 Calculate the number of tiles needed\n        grid_rows, grid_cols = self._get_num_tiles(\n            height, width, tile_height, tile_width, normal_tile_overlap, border_tile_overlap\n        )\n\n        # 2.2 Expand prompt to number of tiles\n        if not isinstance(prompt, list):\n            prompt = [[prompt] * grid_cols] * grid_rows\n\n        # 2.3 Update height and width tile size by tile size and tile overlap size\n        width = (grid_cols - 1) * (tile_width - normal_tile_overlap) + min(\n            tile_width, width - (grid_cols - 1) * (tile_width - normal_tile_overlap)\n        )\n        height = (grid_rows - 1) * (tile_height - normal_tile_overlap) + min(\n            tile_height, height - (grid_rows - 1) * (tile_height - normal_tile_overlap)\n        )\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        text_embeddings = [\n            [\n                self.encode_prompt(\n                    prompt=col,\n                    device=device,\n                    num_images_per_prompt=num_images_per_prompt,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    negative_prompt=negative_prompt,\n                    prompt_embeds=None,\n                    negative_prompt_embeds=None,\n                    pooled_prompt_embeds=None,\n                    negative_pooled_prompt_embeds=None,\n                    lora_scale=text_encoder_lora_scale,\n                    clip_skip=self.clip_skip,\n                )\n                for col in row\n            ]\n            for row in prompt\n        ]\n\n        # 4. Prepare latent image\n        image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n\n        # 4.1 Prepare controlnet_conditioning_image\n        control_image = self.prepare_control_image(\n            image=image,\n            width=width,\n            height=height,\n            batch_size=batch_size * num_images_per_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            device=device,\n            dtype=controlnet.dtype,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            guess_mode=guess_mode,\n        )\n        control_type = (\n            control_type.reshape(1, -1)\n            .to(device, dtype=controlnet.dtype)\n            .repeat(batch_size * num_images_per_prompt * 2, 1)\n        )\n\n        # 5. 
Prepare timesteps\n        accepts_offset = \"offset\" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())\n        extra_set_kwargs = {}\n        if accepts_offset:\n            extra_set_kwargs[\"offset\"] = 1\n        self.scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        self._num_timesteps = len(timesteps)\n\n        # 6. Prepare latent variables\n        dtype = text_embeddings[0][0][0].dtype\n        if latents is None:\n            latents = self.prepare_latents(\n                image_tensor,\n                latent_timestep,\n                batch_size,\n                num_images_per_prompt,\n                dtype,\n                device,\n                generator,\n                True,\n            )\n\n        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas\n        if isinstance(self.scheduler, LMSDiscreteScheduler):\n            latents = latents * self.scheduler.sigmas[0]\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. 
Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            controlnet_keep.append(\n                1.0\n                - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)\n            )\n\n        # 8.1 Prepare added time ids & embeddings\n        # text_embeddings order: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n        embeddings_and_added_time = []\n        crops_coords_top_left = negative_crops_coords_top_left = (tile_width, tile_height)\n        for row in range(grid_rows):\n            addition_embed_type_row = []\n            for col in range(grid_cols):\n                # extract generated values\n                prompt_embeds = text_embeddings[row][col][0]\n                negative_prompt_embeds = text_embeddings[row][col][1]\n                pooled_prompt_embeds = text_embeddings[row][col][2]\n                negative_pooled_prompt_embeds = text_embeddings[row][col][3]\n\n                if negative_original_size is None:\n                    negative_original_size = original_size\n                if negative_target_size is None:\n                    negative_target_size = target_size\n                add_text_embeds = pooled_prompt_embeds\n\n                if self.text_encoder_2 is None:\n                    text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n                else:\n                    text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n                add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n                    original_size,\n                    crops_coords_top_left,\n                    target_size,\n                    aesthetic_score,\n                    negative_aesthetic_score,\n                    negative_original_size,\n                    negative_crops_coords_top_left,\n                    
negative_target_size,\n                    dtype=prompt_embeds.dtype,\n                    text_encoder_projection_dim=text_encoder_projection_dim,\n                )\n                add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n\n                if self.do_classifier_free_guidance:\n                    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n                    add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n                    add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n                    add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)\n\n                prompt_embeds = prompt_embeds.to(device)\n                add_text_embeds = add_text_embeds.to(device)\n                add_time_ids = add_time_ids.to(device)\n                addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids))\n\n            embeddings_and_added_time.append(addition_embed_type_row)\n\n        # 9. Prepare tiles weights and latent overlaps size to denoising process\n        tile_weights, tile_row_overlaps, tile_col_overlaps = self.prepare_tiles(\n            grid_rows,\n            grid_cols,\n            tile_weighting_method,\n            tile_width,\n            tile_height,\n            normal_tile_overlap,\n            border_tile_overlap,\n            width,\n            height,\n            tile_gaussian_sigma,\n            batch_size,\n            device,\n            dtype,\n        )\n\n        # 10. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Diffuse each tile\n                noise_preds = []\n                for row in range(grid_rows):\n                    noise_preds_row = []\n                    for col in range(grid_cols):\n                        if self.interrupt:\n                            continue\n                        tile_row_overlap = tile_row_overlaps[row, col]\n                        tile_col_overlap = tile_col_overlaps[row, col]\n\n                        px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(\n                            row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height\n                        )\n\n                        tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]\n\n                        # expand the latents if we are doing classifier free guidance\n                        latent_model_input = (\n                            torch.cat([tile_latents] * 2)\n                            if self.do_classifier_free_guidance\n                            else tile_latents  # 1, 4, ...\n                        )\n                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                        # predict the noise residual\n                        added_cond_kwargs = {\n                            \"text_embeds\": embeddings_and_added_time[row][col][1],\n                            \"time_ids\": embeddings_and_added_time[row][col][2],\n                        }\n\n                        # controlnet(s) inference\n                        if guess_mode and self.do_classifier_free_guidance:\n                            # Infer ControlNet only for the conditional batch.\n                          
  control_model_input = tile_latents\n                            control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                            controlnet_prompt_embeds = embeddings_and_added_time[row][col][0].chunk(2)[1]\n                            controlnet_added_cond_kwargs = {\n                                \"text_embeds\": embeddings_and_added_time[row][col][1].chunk(2)[1],\n                                \"time_ids\": embeddings_and_added_time[row][col][2].chunk(2)[1],\n                            }\n                        else:\n                            control_model_input = latent_model_input\n                            controlnet_prompt_embeds = embeddings_and_added_time[row][col][0]\n                            controlnet_added_cond_kwargs = added_cond_kwargs\n\n                        if isinstance(controlnet_keep[i], list):\n                            cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                        else:\n                            controlnet_cond_scale = controlnet_conditioning_scale\n                            if isinstance(controlnet_cond_scale, list):\n                                controlnet_cond_scale = controlnet_cond_scale[0]\n                            cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                        px_row_init_pixel, px_row_end_pixel, px_col_init_pixel, px_col_end_pixel = _tile2pixel_indices(\n                            row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height\n                        )\n\n                        tile_control_image = control_image[\n                            :, :, px_row_init_pixel:px_row_end_pixel, px_col_init_pixel:px_col_end_pixel\n                        ]\n\n                        down_block_res_samples, mid_block_res_sample = self.controlnet(\n                            control_model_input,\n                            t,\n       
                     encoder_hidden_states=controlnet_prompt_embeds,\n                            controlnet_cond=[tile_control_image],\n                            control_type=control_type,\n                            control_type_idx=control_mode,\n                            conditioning_scale=cond_scale,\n                            guess_mode=guess_mode,\n                            added_cond_kwargs=controlnet_added_cond_kwargs,\n                            return_dict=False,\n                        )\n\n                        if guess_mode and self.do_classifier_free_guidance:\n                            # Inferred ControlNet only for the conditional batch.\n                            # To apply the output of ControlNet to both the unconditional and conditional batches,\n                            # add 0 to the unconditional batch to keep it unchanged.\n                            down_block_res_samples = [\n                                torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples\n                            ]\n                            mid_block_res_sample = torch.cat(\n                                [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]\n                            )\n\n                        # predict the noise residual\n                        with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):\n                            noise_pred = self.unet(\n                                latent_model_input,\n                                t,\n                                encoder_hidden_states=embeddings_and_added_time[row][col][0],\n                                cross_attention_kwargs=self.cross_attention_kwargs,\n                                down_block_additional_residuals=down_block_res_samples,\n                                mid_block_additional_residual=mid_block_res_sample,\n                                added_cond_kwargs=added_cond_kwargs,\n         
                       return_dict=False,\n                            )[0]\n\n                        # perform guidance\n                        if self.do_classifier_free_guidance:\n                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                            noise_pred_tile = noise_pred_uncond + guidance_scale * (\n                                noise_pred_text - noise_pred_uncond\n                            )\n                        else:\n                            # without guidance the UNet output is already the tile prediction\n                            noise_pred_tile = noise_pred\n                        noise_preds_row.append(noise_pred_tile)\n                    noise_preds.append(noise_preds_row)\n\n                # Stitch noise predictions for all tiles\n                noise_pred = torch.zeros(latents.shape, device=device)\n                contributors = torch.zeros(latents.shape, device=device)\n\n                # Add each tile contribution to overall latents\n                for row in range(grid_rows):\n                    for col in range(grid_cols):\n                        tile_row_overlap = tile_row_overlaps[row, col]\n                        tile_col_overlap = tile_col_overlaps[row, col]\n                        px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(\n                            row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, width, height\n                        )\n                        tile_weights_resized = tile_weights[row, col]\n\n                        noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += (\n                            noise_preds[row][col] * tile_weights_resized\n                        )\n                        contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights_resized\n\n                # Average overlapping areas with more than 1 contributor\n                noise_pred /= contributors\n                noise_pred = noise_pred.to(dtype)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = 
latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                # update progress bar\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n      
          )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n\n            # apply watermark if available\n            if self.watermark is not None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n        else:\n            image = latents\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        result = StableDiffusionXLPipelineOutput(images=image)\n        if not return_dict:\n            return (image,)\n\n        return result\n"
  },
  {
    "path": "examples/community/multilingual_stable_diffusion.py",
    "content": "import inspect\nfrom typing import Callable, List, Optional, Union\n\nimport torch\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTokenizer,\n    MBart50TokenizerFast,\n    MBartForConditionalGeneration,\n    pipeline,\n)\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\nfrom diffusers.utils import deprecate, logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef detect_language(pipe, prompt, batch_size):\n    \"\"\"helper function to detect language(s) of prompt\"\"\"\n\n    if batch_size == 1:\n        preds = pipe(prompt, top_k=1, truncation=True, max_length=128)\n        return preds[0][\"label\"]\n    else:\n        detected_languages = []\n        for p in prompt:\n            preds = pipe(p, top_k=1, truncation=True, max_length=128)\n            detected_languages.append(preds[0][\"label\"])\n\n        return detected_languages\n\n\ndef translate_prompt(prompt, translation_tokenizer, translation_model, device):\n    \"\"\"helper function to translate prompt to English\"\"\"\n\n    encoded_prompt = translation_tokenizer(prompt, return_tensors=\"pt\").to(device)\n    generated_tokens = translation_model.generate(**encoded_prompt, max_new_tokens=1000)\n    en_trans = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n\n    return en_trans[0]\n\n\nclass MultilingualStableDiffusion(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion in different languages.\n\n    This 
model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        detection_pipeline ([`pipeline`]):\n            Transformers pipeline to detect the prompt's language.\n        translation_model ([`MBartForConditionalGeneration`]):\n            Model to translate prompt to English, if necessary. Please refer to the\n            [model card](https://huggingface.co/docs/transformers/model_doc/mbart) for details.\n        translation_tokenizer ([`MBart50TokenizerFast`]):\n            Tokenizer of the translation model.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        detection_pipeline: pipeline,\n        translation_model: MBartForConditionalGeneration,\n        translation_tokenizer: MBart50TokenizerFast,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. 
For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        self.register_modules(\n            detection_pipeline=detection_pipeline,\n            translation_model=translation_model,\n            translation_tokenizer=translation_tokenizer,\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation. Can be in different languages.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the\n                text `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be 
divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        # detect language and translate if necessary\n        prompt_language = detect_language(self.detection_pipeline, prompt, batch_size)\n        if batch_size == 1 and prompt_language != \"en\":\n            prompt = translate_prompt(prompt, self.translation_tokenizer, self.translation_model, self.device)\n\n        if isinstance(prompt, list):\n            for index in range(batch_size):\n                if prompt_language[index] != \"en\":\n                    p = translate_prompt(\n                        prompt[index], self.translation_tokenizer, self.translation_model, self.device\n                    )\n                    prompt[index] = p\n\n        # get prompt text embeddings\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n\n        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:\n            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]\n\n        # duplicate text embeddings for each generation 
per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                # detect language and translate it if necessary\n                negative_prompt_language = detect_language(self.detection_pipeline, negative_prompt, batch_size)\n                if negative_prompt_language != \"en\":\n                    negative_prompt = translate_prompt(\n                        negative_prompt, self.translation_tokenizer, self.translation_model, self.device\n                    )\n                # `translate_prompt` returns a `str`, so no further type check is needed here\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                # detect language and translate it if necessary\n                if isinstance(negative_prompt, list):\n                    negative_prompt_languages = detect_language(self.detection_pipeline, negative_prompt, batch_size)\n                    for index in range(batch_size):\n                        if negative_prompt_languages[index] != \"en\":\n                            p = translate_prompt(\n                                negative_prompt[index], self.translation_tokenizer, self.translation_model, self.device\n                            )\n                            negative_prompt[index] = p\n                uncond_tokens = negative_prompt\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be 
generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not work reproducibly on mps\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                    self.device\n                )\n            else:\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents.shape != latents_shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents = latents.to(self.device)\n\n        # set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        for i, t in 
enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n            # call the callback, if provided\n            if callback is not None and i % callback_steps == 0:\n                step_idx = i // getattr(self.scheduler, \"order\", 1)\n                callback(step_idx, t, latents)\n\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(\n                self.device\n            )\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)\n            )\n        else:\n            has_nsfw_concept = None\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n           
 return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/one_step_unet.py",
    "content": "#!/usr/bin/env python3\nimport torch\n\nfrom diffusers import DiffusionPipeline\n\n\nclass UnetSchedulerOneForwardPipeline(DiffusionPipeline):\n    def __init__(self, unet, scheduler):\n        super().__init__()\n\n        self.register_modules(unet=unet, scheduler=scheduler)\n\n    def __call__(self):\n        image = torch.randn(\n            (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),\n        )\n        timestep = 1\n\n        model_output = self.unet(image, timestep).sample\n        scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample\n\n        result = scheduler_output - scheduler_output + torch.ones_like(scheduler_output)\n\n        return result\n"
  },
  {
    "path": "examples/community/pipeline_animatediff_controlnet.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.models.unets.unet_motion_model import MotionAdapter\nfrom diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.schedulers import (\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    EulerAncestralDiscreteScheduler,\n    EulerDiscreteScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n)\nfrom diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers\nfrom diffusers.utils.torch_utils import is_compiled_module, 
randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter\n        >>> from diffusers.pipelines import DiffusionPipeline\n        >>> from diffusers.schedulers import DPMSolverMultistepScheduler\n        >>> from PIL import Image\n\n        >>> motion_id = \"guoyww/animatediff-motion-adapter-v1-5-2\"\n        >>> adapter = MotionAdapter.from_pretrained(motion_id)\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/control_v11p_sd15_openpose\", torch_dtype=torch.float16)\n        >>> vae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16)\n\n        >>> model_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\n        >>> pipe = DiffusionPipeline.from_pretrained(\n        ...     model_id,\n        ...     motion_adapter=adapter,\n        ...     controlnet=controlnet,\n        ...     vae=vae,\n        ...     custom_pipeline=\"pipeline_animatediff_controlnet\",\n        ... ).to(device=\"cuda\", dtype=torch.float16)\n        >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(\n        ...     model_id, subfolder=\"scheduler\", clip_sample=False, timestep_spacing=\"linspace\", steps_offset=1, beta_schedule=\"linear\",\n        ... )\n        >>> pipe.enable_vae_slicing()\n\n        >>> conditioning_frames = []\n        >>> for i in range(1, 16 + 1):\n        ...     conditioning_frames.append(Image.open(f\"frame_{i}.png\"))\n\n        >>> prompt = \"astronaut in space, dancing\"\n        >>> negative_prompt = \"bad quality, worst quality, jpeg artifacts, ugly\"\n        >>> result = pipe(\n        ...     prompt=prompt,\n        ...     negative_prompt=negative_prompt,\n        ...     width=512,\n        ...     height=768,\n        ...     conditioning_frames=conditioning_frames,\n        ...     
num_inference_steps=12,\n        ... )\n\n        >>> from diffusers.utils import export_to_gif\n        >>> export_to_gif(result.frames[0], \"result.gif\")\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid\ndef tensor2vid(video: torch.Tensor, processor, output_type=\"np\"):\n    batch_size, channels, num_frames, height, width = video.shape\n    outputs = []\n    for batch_idx in range(batch_size):\n        batch_vid = video[batch_idx].permute(1, 0, 2, 3)\n        batch_output = processor.postprocess(batch_vid, output_type)\n\n        outputs.append(batch_output)\n\n    if output_type == \"np\":\n        outputs = np.stack(outputs)\n\n    elif output_type == \"pt\":\n        outputs = torch.stack(outputs)\n\n    elif not output_type == \"pil\":\n        raise ValueError(f\"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']\")\n\n    return outputs\n\n\nclass AnimateDiffControlNetPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-video generation.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer (`CLIPTokenizer`):\n            A [`~transformers.CLIPTokenizer`] to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.\n        motion_adapter ([`MotionAdapter`]):\n            A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->unet->vae\"\n    _optional_components = [\"feature_extractor\", \"image_encoder\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        motion_adapter: MotionAdapter,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: Union[\n            DDIMScheduler,\n            PNDMScheduler,\n            LMSDiscreteScheduler,\n            EulerDiscreteScheduler,\n            EulerAncestralDiscreteScheduler,\n            DPMSolverMultistepScheduler,\n        ],\n        feature_extractor: Optional[CLIPImageProcessor] = None,\n        image_encoder: Optional[CLIPVisionModelWithProjection] = None,\n    ):\n        super().__init__()\n        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            motion_adapter=motion_adapter,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, 
do_normalize=False\n        )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n           
 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n         
       raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if self.do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            image_embeds = ip_adapter_image_embeds\n        return image_embeds\n\n    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        batch_size, channels, num_frames, height, width = latents.shape\n        latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)\n\n        image = self.vae.decode(latents).sample\n        video = (\n            image[None, :]\n            .reshape(\n                (\n                    batch_size,\n                    
num_frames,\n                    -1,\n                )\n                + image.shape[2:]\n            )\n            .permute(0, 2, 1, 3, 4)\n        )\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        video = video.float()\n        return video\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        num_frames,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        image=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible 
by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. The conditionings will be fixed across the prompts.\"\n                )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(image, list):\n                raise TypeError(f\"For single controlnet, `image` must be of type `list` but got {type(image)}\")\n            if len(image) != num_frames:\n                raise ValueError(f\"Expected image to have length {num_frames} but got {len(image)=}\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(image, list) or not isinstance(image[0], list):\n                raise 
TypeError(f\"For multiple controlnets: `image` must be type list of lists but got {type(image)=}\")\n            if len(image[0]) != num_frames:\n                raise ValueError(f\"Expected length of image sublist as {num_frames} but got {len(image[0])=}\")\n            if any(len(img) != len(image[0]) for img in image):\n                raise ValueError(\"All conditioning frame batches for multicontrolnet must be same size\")\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"Only a single batch of multiple conditionings is supported at the moment.\")\n                if len(controlnet_conditioning_scale) != len(self.controlnet.nets):\n                    raise ValueError(\n                        \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must\"\n                        \" have the same length as the number of controlnets\"\n                    )\n        else:\n            assert False\n\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, 
(tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        
if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents\n    def prepare_latents(\n        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            num_frames,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image\n    def prepare_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_frames: Optional[int] = 16,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_videos_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[PipelineImageInput] = None,\n        conditioning_frames: Optional[List[PipelineImageInput]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or 
`List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated video.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated video.\n            num_frames (`int`, *optional*, defaults to 16):\n                The number of video frames that are generated. Defaults to 16 frames which at 8 frames per second\n                amounts to 2 seconds of video.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to higher-quality videos at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. 
Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`. Latents should be of shape\n                `(batch_size, num_channel, num_frames, height, width)`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.\n                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. 
It should contain the negative image embedding\n                if `do_classifier_free_guidance` is set to `True`.\n                If not provided, embeddings are computed from the `ip_adapter_image` input argument.\n            conditioning_frames (`List[PipelineImageInput]`, *optional*):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets\n                are specified, images must be passed as a list such that each element of the list can be correctly\n                batched for input to a single ControlNet.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or\n                `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead\n                of a plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. 
A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. It is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is\n                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n            
    mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        num_videos_per_prompt = 1\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt=prompt,\n            height=height,\n            width=width,\n            num_frames=num_frames,\n            callback_steps=callback_steps,\n            negative_prompt=negative_prompt,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            image=conditioning_frames,\n            controlnet_conditioning_scale=controlnet_conditioning_scale,\n            control_guidance_start=control_guidance_start,\n            control_guidance_end=control_guidance_end,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_videos_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt\n            )\n\n   
     if isinstance(controlnet, ControlNetModel):\n            conditioning_frames = self.prepare_image(\n                image=conditioning_frames,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_videos_per_prompt * num_frames,\n                num_images_per_prompt=num_videos_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n        elif isinstance(controlnet, MultiControlNetModel):\n            cond_prepared_frames = []\n            for frame_ in conditioning_frames:\n                prepared_frame = self.prepare_image(\n                    image=frame_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_videos_per_prompt * num_frames,\n                    num_images_per_prompt=num_videos_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n                cond_prepared_frames.append(prepared_frame)\n            conditioning_frames = cond_prepared_frames\n        else:\n            assert False\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n        self._num_timesteps = len(timesteps)\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            num_channels_latents,\n            num_frames,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if ip_adapter_image is not None or ip_adapter_image_embeds is not None\n            else None\n        )\n\n        # 7.1 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n                controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                control_model_input = torch.transpose(control_model_input, 1, 2)\n                control_model_input = control_model_input.reshape(\n                    (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])\n                )\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n              
      control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=conditioning_frames,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                )\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                ).sample\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 
or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        # 9. Post processing\n        if output_type == \"latent\":\n            video = latents\n        else:\n            video_tensor = self.decode_latents(latents)\n            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)\n\n        # 10. Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return AnimateDiffPipelineOutput(frames=video)\n"
  },
  {
    "path": "examples/community/pipeline_animatediff_img2video.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# Note:\n# This pipeline relies on a \"hack\" discovered by the community that allows\n# the generation of videos given an input image with AnimateDiff. It works\n# by creating a copy of the image `num_frames` times and progressively adding\n# more noise to the image based on the strength and latent interpolation method.\n\nimport inspect\nfrom types import FunctionType\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.models.unet_motion_model import MotionAdapter\nfrom diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.schedulers import (\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    EulerAncestralDiscreteScheduler,\n    EulerDiscreteScheduler,\n    LMSDiscreteScheduler,\n    
PNDMScheduler,\n)\nfrom diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler\n        >>> from diffusers.utils import export_to_gif, load_image\n\n        >>> model_id = \"SG161222/Realistic_Vision_V5.1_noVAE\"\n        >>> adapter = MotionAdapter.from_pretrained(\"guoyww/animatediff-motion-adapter-v1-5-2\")\n        >>> pipe = DiffusionPipeline.from_pretrained(model_id, motion_adapter=adapter, custom_pipeline=\"pipeline_animatediff_img2video\").to(\"cuda\")\n        >>> pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder=\"scheduler\", clip_sample=False, timestep_spacing=\"linspace\", beta_schedule=\"linear\", steps_offset=1)\n\n        >>> image = load_image(\"snail.png\")\n        >>> output = pipe(image=image, prompt=\"A snail moving on the ground\", strength=0.8, latent_interpolation_method=\"slerp\")\n        >>> frames = output.frames[0]\n        >>> export_to_gif(frames, \"animation.gif\")\n        ```\n\"\"\"\n\n\ndef lerp(\n    v0: torch.Tensor,\n    v1: torch.Tensor,\n    t: Union[float, torch.Tensor],\n) -> torch.Tensor:\n    r\"\"\"\n    Linear Interpolation between two tensors.\n\n    Args:\n        v0 (`torch.Tensor`): First tensor.\n        v1 (`torch.Tensor`): Second tensor.\n        t: (`float` or `torch.Tensor`): Interpolation factor.\n    \"\"\"\n    t_is_float = False\n    input_device = v0.device\n    v0 = v0.cpu().numpy()\n    v1 = v1.cpu().numpy()\n\n    if isinstance(t, torch.Tensor):\n        t = t.cpu().numpy()\n    else:\n        t_is_float = True\n        t = np.array([t], dtype=v0.dtype)\n\n    t = t[..., None]\n    v0 = v0[None, ...]\n    v1 = 
v1[None, ...]\n    v2 = (1 - t) * v0 + t * v1\n\n    if t_is_float and v0.ndim > 1:\n        assert v2.shape[0] == 1\n        v2 = np.squeeze(v2, axis=0)\n\n    v2 = torch.from_numpy(v2).to(input_device)\n    return v2\n\n\ndef slerp(\n    v0: torch.Tensor,\n    v1: torch.Tensor,\n    t: Union[float, torch.Tensor],\n    DOT_THRESHOLD: float = 0.9995,\n) -> torch.Tensor:\n    r\"\"\"\n    Spherical Linear Interpolation between two tensors.\n\n    Args:\n        v0 (`torch.Tensor`): First tensor.\n        v1 (`torch.Tensor`): Second tensor.\n        t: (`float` or `torch.Tensor`): Interpolation factor.\n        DOT_THRESHOLD (`float`):\n            Dot product threshold exceeding which linear interpolation will be used\n            because input tensors are close to parallel.\n    \"\"\"\n    t_is_float = False\n    input_device = v0.device\n    v0 = v0.cpu().numpy()\n    v1 = v1.cpu().numpy()\n\n    if isinstance(t, torch.Tensor):\n        t = t.cpu().numpy()\n    else:\n        t_is_float = True\n        t = np.array([t], dtype=v0.dtype)\n\n    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))\n\n    if np.abs(dot) > DOT_THRESHOLD:\n        # v0 and v1 are close to parallel, so use linear interpolation instead\n        v2 = lerp(v0, v1, t)\n    else:\n        theta_0 = np.arccos(dot)\n        sin_theta_0 = np.sin(theta_0)\n        theta_t = theta_0 * t\n        sin_theta_t = np.sin(theta_t)\n        s0 = np.sin(theta_0 - theta_t) / sin_theta_0\n        s1 = sin_theta_t / sin_theta_0\n        s0 = s0[..., None]\n        s1 = s1[..., None]\n        v0 = v0[None, ...]\n        v1 = v1[None, ...]\n        v2 = s0 * v0 + s1 * v1\n\n    if t_is_float and v0.ndim > 1:\n        assert v2.shape[0] == 1\n        v2 = np.squeeze(v2, axis=0)\n\n    v2 = torch.from_numpy(v2).to(input_device)\n    return v2\n\n\n# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid\ndef tensor2vid(video: torch.Tensor, processor, output_type=\"np\"):\n    
batch_size, channels, num_frames, height, width = video.shape\n    outputs = []\n    for batch_idx in range(batch_size):\n        batch_vid = video[batch_idx].permute(1, 0, 2, 3)\n        batch_output = processor.postprocess(batch_vid, output_type)\n\n        outputs.append(batch_output)\n\n    if output_type == \"np\":\n        outputs = np.stack(outputs)\n\n    elif output_type == \"pt\":\n        outputs = torch.stack(outputs)\n\n    elif not output_type == \"pil\":\n        raise ValueError(f\"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']\")\n\n    return outputs\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass AnimateDiffImgToVideoPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n):\n    r\"\"\"\n    Pipeline for image-to-video generation.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer (`CLIPTokenizer`):\n            A [`~transformers.CLIPTokenizer`] to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.\n        motion_adapter ([`MotionAdapter`]):\n            A [`MotionAdapter`] to be used in 
combination with `unet` to denoise the encoded video latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"feature_extractor\", \"image_encoder\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        motion_adapter: MotionAdapter,\n        scheduler: Union[\n            DDIMScheduler,\n            PNDMScheduler,\n            LMSDiscreteScheduler,\n            EulerDiscreteScheduler,\n            EulerAncestralDiscreteScheduler,\n            DPMSolverMultistepScheduler,\n        ],\n        feature_extractor: CLIPImageProcessor = None,\n        image_encoder: CLIPVisionModelWithProjection = None,\n    ):\n        super().__init__()\n        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            motion_adapter=motion_adapter,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        
do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n         
       raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if self.do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            image_embeds = ip_adapter_image_embeds\n        return image_embeds\n\n    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        batch_size, channels, num_frames, height, width = latents.shape\n        latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)\n\n        image = self.vae.decode(latents).sample\n        video = (\n            image[None, :]\n            .reshape(\n                (\n                    batch_size,\n                    
num_frames,\n                    -1,\n                )\n                + image.shape[2:]\n            )\n            .permute(0, 2, 1, 3, 4)\n        )\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        video = video.float()\n        return video\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        latent_interpolation_method=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive 
integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if latent_interpolation_method is not None:\n            if latent_interpolation_method not in [\"lerp\", \"slerp\"] and not isinstance(\n                latent_interpolation_method, FunctionType\n            ):\n                raise ValueError(\n                    \"`latent_interpolation_method` must be one of `lerp`, `slerp` or a Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]\"\n                )\n\n    def prepare_latents(\n        self,\n        image,\n        strength,\n        batch_size,\n        num_channels_latents,\n        num_frames,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n        latent_interpolation_method=\"slerp\",\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            num_frames,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n\n        if latents is None:\n            image = image.to(device=device, dtype=dtype)\n\n            if image.shape[1] == 4:\n                latents = image\n            else:\n                # make sure the VAE is in float32 mode, as it overflows in float16\n                if self.vae.config.force_upcast:\n                    image = image.float()\n                    self.vae.to(dtype=torch.float32)\n\n                if isinstance(generator, list):\n                    if len(generator) != 
batch_size:\n                        raise ValueError(\n                            f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                            f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n                        )\n\n                    init_latents = [\n                        retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                        for i in range(batch_size)\n                    ]\n                    init_latents = torch.cat(init_latents, dim=0)\n                else:\n                    init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n                if self.vae.config.force_upcast:\n                    self.vae.to(dtype)\n\n                init_latents = init_latents.to(dtype)\n                init_latents = self.vae.config.scaling_factor * init_latents\n                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n                latents = latents * self.scheduler.init_noise_sigma\n\n                if latent_interpolation_method == \"lerp\":\n\n                    def latent_cls(v0, v1, index):\n                        return lerp(v0, v1, index / num_frames * (1 - strength))\n                elif latent_interpolation_method == \"slerp\":\n\n                    def latent_cls(v0, v1, index):\n                        return slerp(v0, v1, index / num_frames * (1 - strength))\n                else:\n                    latent_cls = latent_interpolation_method\n\n                for i in range(num_frames):\n                    latents[:, :, i, :, :] = latent_cls(latents[:, :, i, :, :], init_latents, i)\n        else:\n            if shape != latents.shape:\n                # [B, C, F, H, W]\n                raise ValueError(f\"`latents` expected to have {shape=}, but found {latents.shape=}\")\n            latents = latents.to(device, 
dtype=dtype)\n\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        image: PipelineImageInput,\n        prompt: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_frames: int = 16,\n        num_inference_steps: int = 50,\n        timesteps: Optional[List[int]] = None,\n        guidance_scale: float = 7.5,\n        strength: float = 0.8,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_videos_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[PipelineImageInput] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: Optional[int] = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        clip_skip: Optional[int] = None,\n        latent_interpolation_method: Union[str, Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]] = \"slerp\",\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            image (`PipelineImageInput`):\n                The input image to condition the generation on.\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated video.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated video.\n            num_frames (`int`, *optional*, defaults to 16):\n                The number of video frames that are generated. Defaults to 16 frames which at 8 frames per second\n                amounts to 2 seconds of video.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to higher-quality videos at the\n                expense of slower inference.\n            strength (`float`, *optional*, defaults to 0.8):\n                Higher strength leads to more differences between original image and generated video.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. 
Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`. Latents should be of shape\n                `(batch_size, num_channel, num_frames, height, width)`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of IP-Adapters.\n                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. 
It should contain the negative image embedding\n                if `do_classifier_free_guidance` is set to `True`.\n                If not provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or\n                `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] instead\n                of a plain tuple.\n            callback (`Callable`, *optional*):\n                A function that is called every `callback_steps` steps during inference. The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            latent_interpolation_method (`str` or `Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]`, *optional*):\n                Must be one of \"lerp\", \"slerp\", or a callable that takes a random noisy latent, an image latent, and a frame index\n                as input and returns an initial latent for sampling.\n        Examples:\n\n        Returns:\n            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is\n                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        num_videos_per_prompt = 1\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt=prompt,\n            height=height,\n            width=width,\n            callback_steps=callback_steps,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            latent_interpolation_method=latent_interpolation_method,\n        )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_videos_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=clip_skip,\n        )\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt\n            )\n\n        # 4. Preprocess image\n        image = self.image_processor.preprocess(image, height=height, width=width)\n\n        # 5. 
Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            image=image,\n            strength=strength,\n            batch_size=batch_size * num_videos_per_prompt,\n            num_channels_latents=num_channels_latents,\n            num_frames=num_frames,\n            height=height,\n            width=width,\n            dtype=prompt_embeds.dtype,\n            device=device,\n            generator=generator,\n            latents=latents,\n            latent_interpolation_method=latent_interpolation_method,\n        )\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if ip_adapter_image is not None or ip_adapter_image_embeds is not None\n            else None\n        )\n\n        # 9. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                ).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        # 10. 
Post-processing\n        if output_type == \"latent\":\n            video = latents\n        else:\n            video_tensor = self.decode_latents(latents)\n            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)\n\n        # 11. Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return AnimateDiffPipelineOutput(frames=video)\n"
  },
  {
    "path": "examples/community/pipeline_animatediff_ipex.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport intel_extension_for_pytorch as ipex\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.models.unets.unet_motion_model import MotionAdapter\nfrom diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput\nfrom diffusers.pipelines.free_init_utils import FreeInitMixin\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.schedulers import (\n    DDIMScheduler,\n    DPMSolverMultistepScheduler,\n    EulerAncestralDiscreteScheduler,\n    EulerDiscreteScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n)\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.video_processor import VideoProcessor\n\n\nlogger = 
logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import MotionAdapter, AnimateDiffPipelineIpex, EulerDiscreteScheduler\n        >>> from diffusers.utils import export_to_gif\n        >>> from huggingface_hub import hf_hub_download\n        >>> from safetensors.torch import load_file\n\n        >>> device = \"cpu\"\n        >>> dtype = torch.float32\n\n        >>> # ByteDance/AnimateDiff-Lightning, a distilled version of AnimateDiff SD1.5 v2,\n        >>> # a lightning-fast text-to-video generation model which can generate videos\n        >>> # more than ten times faster than the original AnimateDiff.\n        >>> step = 8  # Options: [1,2,4,8]\n        >>> repo = \"ByteDance/AnimateDiff-Lightning\"\n        >>> ckpt = f\"animatediff_lightning_{step}step_diffusers.safetensors\"\n        >>> base = \"emilianJR/epiCRealism\"  # Choose your favorite base model.\n\n        >>> adapter = MotionAdapter().to(device, dtype)\n        >>> adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))\n\n        >>> pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)\n        >>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing=\"trailing\", beta_schedule=\"linear\")\n\n        >>> # For Float32\n        >>> pipe.prepare_for_ipex(torch.float32, prompt=\"A girl smiling\")\n        >>> # For BFloat16\n        >>> pipe.prepare_for_ipex(torch.bfloat16, prompt=\"A girl smiling\")\n\n        >>> # For Float32\n        >>> output = pipe(prompt=\"A girl smiling\", guidance_scale=1.0, num_inference_steps=step)\n        >>> # For BFloat16\n        >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n        ...     output = pipe(prompt=\"A girl smiling\", guidance_scale=1.0, num_inference_steps=step)\n\n        >>> frames = output.frames[0]\n        >>> export_to_gif(frames, 
\"animation.gif\")\n        ```\n\"\"\"\n\n\nclass AnimateDiffPipelineIpex(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    IPAdapterMixin,\n    LoraLoaderMixin,\n    FreeInitMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-video generation.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer (`CLIPTokenizer`):\n            A [`~transformers.CLIPTokenizer`] to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.\n        motion_adapter ([`MotionAdapter`]):\n            A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"feature_extractor\", \"image_encoder\", \"motion_adapter\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: Union[UNet2DConditionModel, UNetMotionModel],\n        motion_adapter: MotionAdapter,\n        scheduler: Union[\n            DDIMScheduler,\n            PNDMScheduler,\n            LMSDiscreteScheduler,\n            EulerDiscreteScheduler,\n            EulerAncestralDiscreteScheduler,\n            DPMSolverMultistepScheduler,\n        ],\n        feature_extractor: CLIPImageProcessor = None,\n        image_encoder: CLIPVisionModelWithProjection = None,\n    ):\n        super().__init__()\n        if isinstance(unet, UNet2DConditionModel):\n            unet = UNetMotionModel.from_unet2d(unet, motion_adapter)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            motion_adapter=motion_adapter,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        
do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, LoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only 
handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if self.text_encoder is not None:\n            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, 
lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != 
len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            repeat_dims = [1]\n            image_embeds = []\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                    single_negative_image_embeds = single_negative_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * 
len(single_negative_image_embeds.shape[1:]))\n                    )\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                else:\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                image_embeds.append(single_image_embeds)\n\n        return image_embeds\n\n    # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        batch_size, channels, num_frames, height, width = latents.shape\n        latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)\n\n        image = self.vae.decode(latents).sample\n        video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        video = video.float()\n        return video\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts 
generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. 
Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents\n    def prepare_latents(\n        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            num_frames,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_frames: Optional[int] = 16,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_videos_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: 
Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide video generation. If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated video.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated video.\n            num_frames (`int`, *optional*, defaults to 16):\n                The number of video frames that are generated. Defaults to 16 frames which at 8 frames per second\n                amounts to 2 seconds of video.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to higher-quality videos at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate videos closely linked to the text\n                `prompt` at the expense of lower visual quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in video generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale <= 1`).\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`. Latents should be of shape\n                `(batch_size, num_channel, num_frames, height, width)`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number\n                of IP-Adapters. 
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated video. Choose between `torch.Tensor`, `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead\n                of a plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that calls at the end of each denoising steps during the inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is\n                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.\n        \"\"\"\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        num_videos_per_prompt = 1\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_videos_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_videos_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            num_channels_latents,\n            num_frames,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if ip_adapter_image is not None or ip_adapter_image_embeds is not None\n            else None\n        )\n\n        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1\n        for free_init_iter in range(num_free_init_iters):\n            if self.free_init_enabled:\n                latents, timesteps = self._apply_free_init(\n                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator\n                )\n\n            self._num_timesteps = len(timesteps)\n            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n\n            # 8. Denoising loop\n            with self.progress_bar(total=self._num_timesteps) as progress_bar:\n                for i, t in enumerate(timesteps):\n                    # expand the latents if we are doing classifier free guidance\n                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                    # predict the noise residual\n                    noise_pred = self.unet(\n                        latent_model_input,\n                        t,\n                        encoder_hidden_states=prompt_embeds,\n                        # cross_attention_kwargs=cross_attention_kwargs,\n                        # added_cond_kwargs=added_cond_kwargs,\n                        # ).sample\n                    )[\"sample\"]\n\n                    # perform guidance\n                    if self.do_classifier_free_guidance:\n                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                  
      noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                    # compute the previous noisy sample x_t -> x_t-1\n                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                    if callback_on_step_end is not None:\n                        callback_kwargs = {}\n                        for k in callback_on_step_end_tensor_inputs:\n                            callback_kwargs[k] = locals()[k]\n                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                        latents = callback_outputs.pop(\"latents\", latents)\n                        prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                        negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                    # call the callback, if provided\n                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                        progress_bar.update()\n\n        # 9. Post processing\n        if output_type == \"latent\":\n            video = latents\n        else:\n            video_tensor = self.decode_latents(latents)\n            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)\n\n        # 10. 
Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return AnimateDiffPipelineOutput(frames=video)\n\n    @torch.no_grad()\n    def prepare_for_ipex(\n        self,\n        dtype=torch.float32,\n        prompt: Union[str, List[str]] = None,\n        num_frames: Optional[int] = 16,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_videos_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n    ):\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        num_videos_per_prompt = 1\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_videos_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            num_channels_latents,\n            num_frames,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1\n        for free_init_iter in range(num_free_init_iters):\n            if self.free_init_enabled:\n                latents, timesteps = self._apply_free_init(\n                    latents, free_init_iter, num_inference_steps, device, latents.dtype, generator\n                )\n\n        self._num_timesteps = len(timesteps)\n\n        dummy = timesteps[0]\n        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n        latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy)\n\n        self.unet = self.unet.to(memory_format=torch.channels_last)\n        self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)\n        self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)\n\n        unet_input_example = {\n            \"sample\": latent_model_input,\n            \"timestep\": dummy,\n            \"encoder_hidden_states\": prompt_embeds,\n        }\n\n        fake_latents = 1 / self.vae.config.scaling_factor * latents\n        batch_size, channels, num_frames, height, width = fake_latents.shape\n        fake_latents = fake_latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)\n        vae_decoder_input_example = fake_latents\n\n        # optimize with ipex\n        if dtype == torch.bfloat16:\n            self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True)\n            self.vae.decoder = 
ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)\n            self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)\n        elif dtype == torch.float32:\n            self.unet = ipex.optimize(\n                self.unet.eval(),\n                dtype=torch.float32,\n                inplace=True,\n                # sample_input=unet_input_example,\n                level=\"O1\",\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n            self.vae.decoder = ipex.optimize(\n                self.vae.decoder.eval(),\n                dtype=torch.float32,\n                inplace=True,\n                level=\"O1\",\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n            self.text_encoder = ipex.optimize(\n                self.text_encoder.eval(),\n                dtype=torch.float32,\n                inplace=True,\n                level=\"O1\",\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n        else:\n            raise ValueError(\" The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32' !\")\n\n        # trace unet model to get better performance on IPEX\n        with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():\n            unet_trace_model = torch.jit.trace(\n                self.unet, example_kwarg_inputs=unet_input_example, check_trace=False, strict=False\n            )\n            unet_trace_model = torch.jit.freeze(unet_trace_model)\n            self.unet.forward = unet_trace_model.forward\n\n        # trace vae.decoder model to get better performance on IPEX\n        with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():\n            vae_decoder_trace_model = torch.jit.trace(\n                self.vae.decoder, vae_decoder_input_example, check_trace=False, 
strict=False\n            )\n            vae_decoder_trace_model = torch.jit.freeze(vae_decoder_trace_model)\n            self.vae.decoder.forward = vae_decoder_trace_model.forward\n"
  },
  {
    "path": "examples/community/pipeline_controlnet_xl_kolors.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import (\n    AutoencoderKL,\n    ControlNetModel,\n    ImageProjection,\n    MultiControlNetModel,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    is_invisible_watermark_available,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nif is_invisible_watermark_available():\n    from 
diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import KolorsControlNetPipeline, ControlNetModel\n        >>> from diffusers.utils import load_image\n        >>> import numpy as np\n        >>> import cv2\n        >>> from PIL import Image\n\n        >>> prompt = \"aerial view, a futuristic research complex in a bright foggy jungle, hard lighting\"\n        >>> negative_prompt = \"low quality, bad quality, sketches\"\n\n        >>> # download an image\n        >>> image = load_image(\n        ...     \"https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png\"\n        ... )\n\n        >>> # initialize the models and pipeline\n        >>> controlnet_conditioning_scale = 0.5  # recommended for good generalization\n        >>> controlnet = ControlNetModel.from_pretrained(\n        ...     \"Kwai-Kolors/Kolors-ControlNet-Canny\",\n        ...     use_safetensors=True,\n        ...     torch_dtype=torch.float16\n        ... )\n\n        >>> pipe = KolorsControlNetPipeline.from_pretrained(\n        ...     \"Kwai-Kolors/Kolors-diffusers\",\n        ...     controlnet=controlnet,\n        ...     variant=\"fp16\",\n        ...     use_safetensors=True,\n        ...     torch_dtype=torch.float16\n        ... )\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> # get canny image\n        >>> image = np.array(image)\n        >>> image = cv2.Canny(image, 100, 200)\n        >>> image = image[:, :, None]\n        >>> image = np.concatenate([image, image, image], axis=2)\n        >>> canny_image = Image.fromarray(image)\n\n        >>> # generate image\n        >>> image = pipe(\n        ...     prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\nclass KolorsControlNetPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    FromSingleFileMixin,\n    IPAdapterMixin,\n):\n    r\"\"\"\n    Pipeline for image-to-image generation using Kolors with ControlNet guidance.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.safetensors` files\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`ChatGLMModel`]):\n            Frozen text-encoder. 
Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).\n        tokenizer (`ChatGLMTokenizer`):\n            Tokenizer of class\n            [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets\n            as a list, the outputs from each ControlNet are added together to create one combined additional\n            conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        requires_aesthetics_score (`bool`, *optional*, defaults to `\"False\"`):\n            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference.\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `\"True\"`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. 
Also see the config of\n            `Kwai-Kolors/Kolors-diffusers`.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n\n    _optional_components = [\n        \"tokenizer\",\n        \"text_encoder\",\n        \"feature_extractor\",\n        \"image_encoder\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"negative_add_time_ids\",\n        \"image\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: ChatGLMModel,\n        tokenizer: ChatGLMTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        requires_aesthetics_score: bool = False,\n        force_zeros_for_empty_prompt: bool = True,\n        feature_extractor: CLIPImageProcessor = None,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)\n\n    def encode_prompt(\n        self,\n        prompt,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer]\n        text_encoders = [self.text_encoder]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=256,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(self._execution_device)\n                output = text_encoder(\n                    input_ids=text_inputs[\"input_ids\"],\n                    attention_mask=text_inputs[\"attention_mask\"],\n                    position_ids=text_inputs[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n                prompt_embeds = 
output.hidden_states[-2].permute(1, 0, 2).clone()\n                pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]\n                bs_embed, seq_len, _ = prompt_embeds.shape\n                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = prompt_embeds_list[0]\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            negative_prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                # textual inversion: process multi-vector tokens if necessary\n                if isinstance(self, TextualInversionLoaderMixin):\n                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    uncond_tokens,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(self._execution_device)\n                output = text_encoder(\n                    input_ids=uncond_input[\"input_ids\"],\n                    attention_mask=uncond_input[\"attention_mask\"],\n                    position_ids=uncond_input[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n                negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()\n                negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]\n\n                if do_classifier_free_guidance:\n                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n                    seq_len = negative_prompt_embeds.shape[1]\n\n                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                    negative_prompt_embeds = negative_prompt_embeds.view(\n                        batch_size * 
num_images_per_prompt, seq_len, -1\n                    )\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = negative_prompt_embeds_list[0]\n\n        bs_embed = pooled_prompt_embeds.shape[0]\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        image_embeds = []\n        if do_classifier_free_guidance:\n            negative_image_embeds = []\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
\"\n                    f\"Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n\n                image_embeds.append(single_image_embeds[None, :])\n                if do_classifier_free_guidance:\n                    negative_image_embeds.append(single_negative_image_embeds[None, :])\n        else:\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = 
uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        num_inference_steps,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n        
callback_on_step_end_tensor_inputs=None,\n    ):\n        if num_inference_steps is None:\n            raise ValueError(\"`num_inference_steps` cannot be None.\")\n        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:\n            raise ValueError(\n                f\"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type\"\n                f\" {type(num_inference_steps)}.\"\n            )\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. 
Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. The conditionings will be fixed across the prompts.\"\n                )\n\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                
    \" the same length as the number of controlnets\"\n                )\n        else:\n            assert False\n\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. 
Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and 
isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of 
length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents_t2i(\n        self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None\n    ):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids\n    def _get_add_time_ids(\n        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None\n    ):\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        return add_time_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        guess_mode: bool = False,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,\n        control_guidance_start: 
Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also\n                be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized according to them. 
If multiple ControlNets are specified in\n                init, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single controlnet.\n            height (`int`, *optional*, defaults to the size of image):\n                The height in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to the size of image):\n                The width in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            strength (`float`, *optional*, defaults to 0.8):\n                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2 of the [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`]; ignored for other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. 
If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the controlnet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the controlnet stops applying.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. 
Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during inference, with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`\n            containing the output images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, 
MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            image,\n            num_inference_steps,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        # 3.1. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 3.2 Encode ip_adapter_image\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        if isinstance(controlnet, ControlNetModel):\n            image = self.prepare_control_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            control_images = []\n\n            for control_image_ in image:\n                control_image_ = 
self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            image = control_images\n            height, width = image[0].shape[-2:]\n        else:\n            assert False\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6.5 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 7.2 Prepare added time ids & embeddings\n        if isinstance(image, list):\n            original_size = original_size or image[0].shape[-2:]\n        else:\n            original_size = original_size or image.shape[-2:]\n        target_size = target_size or (height, width)\n\n        # 7. Prepare added time ids & embeddings\n        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n\n        add_text_embeds = pooled_prompt_embeds\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], 
dim=0)\n            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        # patch diffusers controlnet instance forward, undo\n        # after denoising loop\n\n        patched_cn_models = []\n        if isinstance(self.controlnet, MultiControlNetModel):\n            cn_models_to_patch = self.controlnet.nets\n        else:\n            cn_models_to_patch = [self.controlnet]\n\n        for cn_model in cn_models_to_patch:\n            cn_og_forward = cn_model.forward\n\n            # Bind the per-iteration model and original forward as keyword defaults so each patched\n            # forward keeps its own pair instead of the last one seen by the loop (late-binding closure).\n            def _cn_patch_forward(*args, cn_model=cn_model, cn_og_forward=cn_og_forward, **kwargs):\n                encoder_hidden_states = kwargs[\"encoder_hidden_states\"]\n                if cn_model.encoder_hid_proj is not None and cn_model.config.encoder_hid_dim_type == \"text_proj\":\n                    # Ensure encoder_hidden_states is on the same device as the projection layer\n                    encoder_hidden_states = encoder_hidden_states.to(cn_model.encoder_hid_proj.weight.device)\n                    encoder_hidden_states = cn_model.encoder_hid_proj(encoder_hidden_states)\n                kwargs.pop(\"encoder_hidden_states\")\n                return cn_og_forward(*args, encoder_hidden_states=encoder_hidden_states, **kwargs)\n\n            cn_model.forward = _cn_patch_forward\n            patched_cn_models.append((cn_model, cn_og_forward))\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        try:\n            with self.progress_bar(total=num_inference_steps) as progress_bar:\n                for i, t in enumerate(timesteps):\n                    # expand the latents if we are doing classifier free guidance\n                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                    added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n                    # controlnet(s) inference\n                    if guess_mode and self.do_classifier_free_guidance:\n                        # Infer ControlNet only for the conditional batch.\n                        control_model_input = latents\n                        control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                        controlnet_added_cond_kwargs = {\n                            \"text_embeds\": add_text_embeds.chunk(2)[1],\n                            \"time_ids\": add_time_ids.chunk(2)[1],\n                        }\n                    else:\n                        control_model_input = latent_model_input\n                        controlnet_prompt_embeds = prompt_embeds\n                        controlnet_added_cond_kwargs = added_cond_kwargs\n\n                    if isinstance(controlnet_keep[i], list):\n                        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                    else:\n                        controlnet_cond_scale = controlnet_conditioning_scale\n                        if isinstance(controlnet_cond_scale, list):\n                            controlnet_cond_scale = controlnet_cond_scale[0]\n                     
   cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                    down_block_res_samples, mid_block_res_sample = self.controlnet(\n                        control_model_input,\n                        t,\n                        encoder_hidden_states=controlnet_prompt_embeds,\n                        controlnet_cond=image,\n                        conditioning_scale=cond_scale,\n                        guess_mode=guess_mode,\n                        added_cond_kwargs=controlnet_added_cond_kwargs,\n                        return_dict=False,\n                    )\n\n                    if guess_mode and self.do_classifier_free_guidance:\n                        # Inferred ControlNet only for the conditional batch.\n                        # To apply the output of ControlNet to both the unconditional and conditional batches,\n                        # add 0 to the unconditional batch to keep it unchanged.\n                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                        mid_block_res_sample = torch.cat(\n                            [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]\n                        )\n\n                    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                        added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                    # predict the noise residual\n                    noise_pred = self.unet(\n                        latent_model_input,\n                        t,\n                        encoder_hidden_states=prompt_embeds,\n                        timestep_cond=timestep_cond,\n                        cross_attention_kwargs=self.cross_attention_kwargs,\n                        down_block_additional_residuals=down_block_res_samples,\n                        mid_block_additional_residual=mid_block_res_sample,\n                        added_cond_kwargs=added_cond_kwargs,\n              
          return_dict=False,\n                    )[0]\n\n                    # perform guidance\n                    if self.do_classifier_free_guidance:\n                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                    # compute the previous noisy sample x_t -> x_t-1\n                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                    if callback_on_step_end is not None:\n                        callback_kwargs = {}\n                        for k in callback_on_step_end_tensor_inputs:\n                            callback_kwargs[k] = locals()[k]\n                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                        latents = callback_outputs.pop(\"latents\", latents)\n                        prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                        negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                        add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                        negative_pooled_prompt_embeds = callback_outputs.pop(\n                            \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                        )\n                        add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                        negative_add_time_ids = callback_outputs.pop(\"negative_add_time_ids\", negative_add_time_ids)\n                        image = callback_outputs.pop(\"image\", image)\n\n                    # call the callback, if provided\n                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                        progress_bar.update()\n                 
       if callback is not None and i % callback_steps == 0:\n                            step_idx = i // getattr(self.scheduler, \"order\", 1)\n                            callback(step_idx, t, latents)\n\n        finally:\n            for cn_and_og in patched_cn_models:\n                cn_and_og[0].forward = cn_and_og[1]\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            latents = latents / self.vae.config.scaling_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n            return StableDiffusionXLPipelineOutput(images=image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_controlnet_xl_kolors_img2img.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import (\n    AutoencoderKL,\n    ControlNetModel,\n    ImageProjection,\n    MultiControlNetModel,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    is_invisible_watermark_available,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nif is_invisible_watermark_available():\n    from 
diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> import numpy as np\n        >>> from PIL import Image\n\n        >>> from transformers import DPTImageProcessor, DPTForDepthEstimation\n        >>> from diffusers import ControlNetModel, KolorsControlNetImg2ImgPipeline\n        >>> from diffusers.utils import load_image\n\n        >>> depth_estimator = DPTForDepthEstimation.from_pretrained(\"Intel/dpt-hybrid-midas\").to(\"cuda\")\n        >>> feature_extractor = DPTImageProcessor.from_pretrained(\"Intel/dpt-hybrid-midas\")\n        >>> controlnet = ControlNetModel.from_pretrained(\n        ...     \"Kwai-Kolors/Kolors-ControlNet-Depth\",\n        ...     use_safetensors=True,\n        ...     torch_dtype=torch.float16\n        ... )\n        >>> pipe = KolorsControlNetImg2ImgPipeline.from_pretrained(\n        ...     \"Kwai-Kolors/Kolors-diffusers\",\n        ...     controlnet=controlnet,\n        ...     variant=\"fp16\",\n        ...     use_safetensors=True,\n        ...     torch_dtype=torch.float16\n        ... )\n        >>> pipe.enable_model_cpu_offload()\n\n\n        >>> def get_depth_map(image):\n        ...     image = feature_extractor(images=image, return_tensors=\"pt\").pixel_values.to(\"cuda\")\n        ...\n        ...     with torch.no_grad(), torch.autocast(\"cuda\"):\n        ...         depth_map = depth_estimator(image).predicted_depth\n        ...\n        ...     depth_map = torch.nn.functional.interpolate(\n        ...         depth_map.unsqueeze(1),\n        ...         size=(1024, 1024),\n        ...         mode=\"bicubic\",\n        ...         align_corners=False,\n        ...     )\n        ...     depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)\n        ...     
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)\n        ...     depth_map = (depth_map - depth_min) / (depth_max - depth_min)\n        ...     image = torch.cat([depth_map] * 3, dim=1)\n        ...     image = image.permute(0, 2, 3, 1).cpu().numpy()[0]\n        ...     image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))\n        ...     return image\n\n\n        >>> prompt = \"A robot, 4k photo\"\n        >>> image = load_image(\n        ...     \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main\"\n        ...     \"/kandinsky/cat.png\"\n        ... ).resize((1024, 1024))\n        >>> controlnet_conditioning_scale = 0.5  # recommended for good generalization\n        >>> depth_image = get_depth_map(image)\n\n        >>> images = pipe(\n        ...     prompt,\n        ...     image=image,\n        ...     control_image=depth_image,\n        ...     strength=0.80,\n        ...     num_inference_steps=50,\n        ...     controlnet_conditioning_scale=controlnet_conditioning_scale,\n        ... 
).images\n        >>> images[0].save(\"robot_cat.png\")\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\nclass KolorsControlNetImg2ImgPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    FromSingleFileMixin,\n    IPAdapterMixin,\n):\n    r\"\"\"\n    Pipeline for image-to-image generation using Kolors with ControlNet guidance.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.safetensors` files\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`ChatGLMModel`]):\n            Frozen text-encoder. 
Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).\n        tokenizer (`ChatGLMTokenizer`):\n            Tokenizer of class\n            [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets\n            as a list, the outputs from each ControlNet are added together to create one combined additional\n            conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        requires_aesthetics_score (`bool`, *optional*, defaults to `False`):\n            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the\n            config of `stabilityai/stable-diffusion-xl-refiner-1-0`.\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. 
Also see the config of\n            `Kwai-Kolors/Kolors-diffusers`.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n\n    _optional_components = [\n        \"tokenizer\",\n        \"text_encoder\",\n        \"feature_extractor\",\n        \"image_encoder\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"add_neg_time_ids\",\n        \"control_image\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: ChatGLMModel,\n        tokenizer: ChatGLMTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        requires_aesthetics_score: bool = False,\n        force_zeros_for_empty_prompt: bool = True,\n        feature_extractor: CLIPImageProcessor = None,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)\n\n    def encode_prompt(\n        self,\n        prompt,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer]\n        text_encoders = [self.text_encoder]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=256,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(self._execution_device)\n                output = text_encoder(\n                    input_ids=text_inputs[\"input_ids\"],\n                    attention_mask=text_inputs[\"attention_mask\"],\n                    position_ids=text_inputs[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n    
            prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()\n                pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]\n                bs_embed, seq_len, _ = prompt_embeds.shape\n                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n            prompt_embeds = prompt_embeds_list[0]\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            # negative_prompt = negative_prompt or \"\"\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            negative_prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                # textual inversion: process multi-vector tokens if necessary\n                if isinstance(self, TextualInversionLoaderMixin):\n                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    uncond_tokens,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(self._execution_device)\n                output = text_encoder(\n                    input_ids=uncond_input[\"input_ids\"],\n                    attention_mask=uncond_input[\"attention_mask\"],\n                    position_ids=uncond_input[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n                negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()\n                negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]\n\n                if do_classifier_free_guidance:\n                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n                    seq_len = negative_prompt_embeds.shape[1]\n\n                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                    negative_prompt_embeds = negative_prompt_embeds.view(\n                        batch_size * 
num_images_per_prompt, seq_len, -1\n                    )\n\n                    # For classifier free guidance, we need to do two forward passes.\n                    # Here we concatenate the unconditional and text embeddings into a single batch\n                    # to avoid doing two forward passes\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n            negative_prompt_embeds = negative_prompt_embeds_list[0]\n\n        bs_embed = pooled_prompt_embeds.shape[0]\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        image_embeds = []\n        if do_classifier_free_guidance:\n            negative_image_embeds = []\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n\n                image_embeds.append(single_image_embeds[None, :])\n                if do_classifier_free_guidance:\n                    negative_image_embeds.append(single_negative_image_embeds[None, :])\n        else:\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    negative_image_embeds.append(single_negative_image_embeds)\n                image_embeds.append(single_image_embeds)\n\n        ip_adapter_image_embeds = []\n        for i, single_image_embeds in enumerate(image_embeds):\n            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n            if do_classifier_free_guidance:\n                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)\n                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)\n\n            single_image_embeds = single_image_embeds.to(device=device)\n            ip_adapter_image_embeds.append(single_image_embeds)\n\n        return ip_adapter_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, 
eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for others.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        strength,\n        num_inference_steps,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n        if num_inference_steps is None:\n            raise ValueError(\"`num_inference_steps` cannot be None.\")\n        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:\n            raise ValueError(\n                f\"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type\"\n                f\" {type(num_inference_steps)}.\"\n            )\n\n        if callback_steps is not None and (not 
isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. 
The conditionings will be fixed across the prompts.\"\n                )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in image):\n                raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif len(image) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets.\"\n                )\n\n            for image_ in image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type 
`float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n                elif len(controlnet_conditioning_scale) != len(self.controlnet.nets):\n                    raise ValueError(\n                        \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must\"\n                        \" have the same length as the number of controlnets\"\n                    )\n        else:\n            assert False\n\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. 
Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = 
isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, \"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents\n    def prepare_latents(\n        self, image, timestep, batch_size, 
num_images_per_prompt, dtype, device, generator=None, add_noise=True\n    ):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        # Offload text encoder if `enable_model_cpu_offload` was enabled\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            if self.vae.config.force_upcast:\n                image = image.float()\n                self.vae.to(dtype=torch.float32)\n\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                init_latents = [\n                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                    for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n            if self.vae.config.force_upcast:\n                self.vae.to(dtype)\n\n            init_latents = init_latents.to(dtype)\n\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        if add_noise:\n            shape = init_latents.shape\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # get latents\n            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n\n        latents = init_latents\n\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents_t2i(\n        self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None\n    ):\n        shape = (batch_size, 
num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids\n    def _get_add_time_ids(\n        self,\n        original_size,\n        crops_coords_top_left,\n        target_size,\n        aesthetic_score,\n        negative_aesthetic_score,\n        negative_original_size,\n        negative_crops_coords_top_left,\n        negative_target_size,\n        dtype,\n        text_encoder_projection_dim=None,\n    ):\n        if self.config.requires_aesthetics_score:\n            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))\n            add_neg_time_ids = list(\n                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)\n            )\n        else:\n            add_time_ids = list(original_size + crops_coords_top_left + target_size)\n            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)\n\n        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if (\n 
           expected_add_embed_dim > passed_add_embed_dim\n            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.\"\n            )\n        elif (\n            expected_add_embed_dim < passed_add_embed_dim\n            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.\"\n            )\n        elif expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. 
Please check `unet.config.time_embedding_type` and `text_encoder.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)\n\n        return add_time_ids, add_neg_time_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        control_image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        guess_mode: bool = False,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        
prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The initial image will be used as the starting point for the image generation process. Can also accept\n                image latents as `image`; if latents are passed directly, they are not encoded again.\n            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also\n                be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in\n                init, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single controlnet.\n            height (`int`, *optional*, defaults to the size of control_image):\n                The height in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to the size of control_image):\n                The width in pixels of the generated image. 
Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            strength (`float`, *optional*, defaults to 0.8):\n                Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of\n                IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The fraction of total steps (between 0 and 1) at which the controlnet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The fraction of total steps (between 0 and 1) at which the controlnet stops applying.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. 
Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be as same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            aesthetic_score (`float`, *optional*, defaults to 6.0):\n                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to\n                simulate an aesthetic score of the generated image by influencing the negative text condition.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during the inference. 
It is called with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in\n                the `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`\n            containing the output images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not 
isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            control_image,\n            strength,\n            num_inference_steps,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        # 3.1. Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 3.2 Encode ip_adapter_image\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. 
Prepare image and controlnet_conditioning_image\n        image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n\n        if isinstance(controlnet, ControlNetModel):\n            control_image = self.prepare_control_image(\n                image=control_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = control_image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            control_images = []\n\n            for control_image_ in control_image:\n                control_image_ = self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            control_image = control_images\n            height, width = control_image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        self._num_timesteps = len(timesteps)\n\n        # 6. 
Prepare latent variables\n\n        num_channels_latents = self.unet.config.in_channels\n        if latents is None:\n            if strength >= 1.0:\n                latents = self.prepare_latents_t2i(\n                    batch_size * num_images_per_prompt,\n                    num_channels_latents,\n                    height,\n                    width,\n                    prompt_embeds.dtype,\n                    device,\n                    generator,\n                    latents,\n                )\n            else:\n                latents = self.prepare_latents(\n                    image,\n                    latent_timestep,\n                    batch_size,\n                    num_images_per_prompt,\n                    prompt_embeds.dtype,\n                    device,\n                    generator,\n                    True,\n                )\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 7.2 Prepare added time ids & embeddings\n        if isinstance(control_image, list):\n            original_size = original_size or control_image[0].shape[-2:]\n        else:\n            original_size = original_size or control_image.shape[-2:]\n        target_size = target_size or (height, width)\n\n        # 7. 
Prepare added time ids & embeddings\n        if negative_original_size is None:\n            negative_original_size = original_size\n        if negative_target_size is None:\n            negative_target_size = target_size\n\n        add_text_embeds = pooled_prompt_embeds\n        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n\n        add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            negative_original_size,\n            negative_crops_coords_top_left,\n            negative_target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)\n            add_neg_time_ids = torch.cat([add_neg_time_ids, add_neg_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n        add_neg_time_ids = add_neg_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        # patch diffusers controlnet instance forward, undo\n        # after denoising loop\n\n        patched_cn_models = []\n        if isinstance(self.controlnet, MultiControlNetModel):\n            cn_models_to_patch = self.controlnet.nets\n        else:\n            cn_models_to_patch = [self.controlnet]\n\n        for cn_model in cn_models_to_patch:\n            cn_og_forward = cn_model.forward\n\n            # bind the loop variables as keyword defaults; a plain closure would only see the last controlnet\n            def _cn_patch_forward(*args, cn_model=cn_model, cn_og_forward=cn_og_forward, **kwargs):\n                
encoder_hidden_states = kwargs[\"encoder_hidden_states\"]\n                if cn_model.encoder_hid_proj is not None and cn_model.config.encoder_hid_dim_type == \"text_proj\":\n                    # Ensure encoder_hidden_states is on the same device as the projection layer\n                    encoder_hidden_states = encoder_hidden_states.to(cn_model.encoder_hid_proj.weight.device)\n                    encoder_hidden_states = cn_model.encoder_hid_proj(encoder_hidden_states)\n                kwargs.pop(\"encoder_hidden_states\")\n                return cn_og_forward(*args, encoder_hidden_states=encoder_hidden_states, **kwargs)\n\n            cn_model.forward = _cn_patch_forward\n            patched_cn_models.append((cn_model, cn_og_forward))\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n\n        try:\n            with self.progress_bar(total=num_inference_steps) as progress_bar:\n                for i, t in enumerate(timesteps):\n                    # expand the latents if we are doing classifier free guidance\n                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                    added_cond_kwargs = {\n                        \"text_embeds\": add_text_embeds,\n                        \"time_ids\": add_time_ids,\n                        \"neg_time_ids\": add_neg_time_ids,\n                    }\n\n                    # controlnet(s) inference\n                    if guess_mode and self.do_classifier_free_guidance:\n                        # Infer ControlNet only for the conditional batch.\n                        control_model_input = latents\n                        control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n          
              controlnet_added_cond_kwargs = {\n                            \"text_embeds\": add_text_embeds.chunk(2)[1],\n                            \"time_ids\": add_time_ids.chunk(2)[1],\n                            \"neg_time_ids\": add_neg_time_ids.chunk(2)[1],\n                        }\n                    else:\n                        control_model_input = latent_model_input\n                        controlnet_prompt_embeds = prompt_embeds\n                        controlnet_added_cond_kwargs = added_cond_kwargs\n\n                    if isinstance(controlnet_keep[i], list):\n                        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                    else:\n                        controlnet_cond_scale = controlnet_conditioning_scale\n                        if isinstance(controlnet_cond_scale, list):\n                            controlnet_cond_scale = controlnet_cond_scale[0]\n                        cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                    down_block_res_samples, mid_block_res_sample = self.controlnet(\n                        control_model_input,\n                        t,\n                        encoder_hidden_states=controlnet_prompt_embeds,\n                        controlnet_cond=control_image,\n                        conditioning_scale=cond_scale,\n                        guess_mode=guess_mode,\n                        added_cond_kwargs=controlnet_added_cond_kwargs,\n                        return_dict=False,\n                    )\n\n                    if guess_mode and self.do_classifier_free_guidance:\n                        # Inferred ControlNet only for the conditional batch.\n                        # To apply the output of ControlNet to both the unconditional and conditional batches,\n                        # add 0 to the unconditional batch to keep it unchanged.\n                        down_block_res_samples = 
[torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                        mid_block_res_sample = torch.cat(\n                            [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]\n                        )\n\n                    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                        added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                    # predict the noise residual\n                    noise_pred = self.unet(\n                        latent_model_input,\n                        t,\n                        encoder_hidden_states=prompt_embeds,\n                        cross_attention_kwargs=self.cross_attention_kwargs,\n                        down_block_additional_residuals=down_block_res_samples,\n                        mid_block_additional_residual=mid_block_res_sample,\n                        added_cond_kwargs=added_cond_kwargs,\n                        return_dict=False,\n                    )[0]\n\n                    # perform guidance\n                    if self.do_classifier_free_guidance:\n                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                    # compute the previous noisy sample x_t -> x_t-1\n                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                    if callback_on_step_end is not None:\n                        callback_kwargs = {}\n                        for k in callback_on_step_end_tensor_inputs:\n                            callback_kwargs[k] = locals()[k]\n                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                        latents = callback_outputs.pop(\"latents\", latents)\n                        prompt_embeds = callback_outputs.pop(\"prompt_embeds\", 
prompt_embeds)\n                        negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                        add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                        negative_pooled_prompt_embeds = callback_outputs.pop(\n                            \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                        )\n                        add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                        add_neg_time_ids = callback_outputs.pop(\"add_neg_time_ids\", add_neg_time_ids)\n                        control_image = callback_outputs.pop(\"control_image\", control_image)\n\n                    # call the callback, if provided\n                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                        progress_bar.update()\n                        if callback is not None and i % callback_steps == 0:\n                            step_idx = i // getattr(self.scheduler, \"order\", 1)\n                            callback(step_idx, t, latents)\n        finally:\n            for cn_and_og in patched_cn_models:\n                cn_and_og[0].forward = cn_and_og[1]\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = 
latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            latents = latents / self.vae.config.scaling_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n            return StableDiffusionXLPipelineOutput(images=image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_controlnet_xl_kolors_inpaint.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import (\n    AutoencoderKL,\n    ControlNetModel,\n    ImageProjection,\n    MultiControlNetModel,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import deprecate, is_invisible_watermark_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark import 
StableDiffusionXLWatermarker\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> from diffusers import KolorsControlNetInpaintPipeline, ControlNetModel\n        >>> from diffusers.utils import load_image\n        >>> from PIL import Image\n        >>> import numpy as np\n        >>> import torch\n        >>> import cv2\n\n        >>> init_image = load_image(\n        ...     \"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png\"\n        ... )\n        >>> init_image = init_image.resize((1024, 1024))\n\n        >>> generator = torch.Generator(device=\"cpu\").manual_seed(1)\n\n        >>> mask_image = load_image(\n        ...     \"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png\"\n        ... )\n        >>> mask_image = mask_image.resize((1024, 1024))\n\n\n        >>> def make_canny_condition(image):\n        ...     image = np.array(image)\n        ...     image = cv2.Canny(image, 100, 200)\n        ...     image = image[:, :, None]\n        ...     image = np.concatenate([image, image, image], axis=2)\n        ...     image = Image.fromarray(image)\n        ...     return image\n\n\n        >>> control_image = make_canny_condition(init_image)\n\n        >>> controlnet = ControlNetModel.from_pretrained(\n        ...     \"Kwai-Kolors/Kolors-ControlNet-Canny\",\n        ...     use_safetensors=True,\n        ...     torch_dtype=torch.float16\n        ... )\n        >>> pipe = KolorsControlNetInpaintPipeline.from_pretrained(\n        ...     \"Kwai-Kolors/Kolors-diffusers\",\n        ...     controlnet=controlnet,\n        ...     variant=\"fp16\",\n        ...     use_safetensors=True,\n        ...     torch_dtype=torch.float16\n        ... )\n\n        >>> pipe.enable_model_cpu_offload()\n\n        # generate image\n        >>> image = pipe(\n        ...     
\"a handsome man with ray-ban sunglasses\",\n        ...     num_inference_steps=20,\n        ...     generator=generator,\n        ...     eta=1.0,\n        ...     image=init_image,\n        ...     mask_image=mask_image,\n        ...     control_image=control_image,\n        ... ).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass KolorsControlNetInpaintPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    FromSingleFileMixin,\n    IPAdapterMixin,\n):\n    r\"\"\"\n    Pipeline for inpainting using Kolors with ControlNet guidance.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.safetensors` files\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`ChatGLMModel`]):\n            Frozen text-encoder. 
Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).\n        tokenizer (`ChatGLMTokenizer`):\n            Tokenizer of class\n            [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets\n            as a list, the outputs from each ControlNet are added together to create one combined additional\n            conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        requires_aesthetics_score (`bool`, *optional*, defaults to `False`):\n            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the\n            config of `stabilityai/stable-diffusion-xl-refiner-1-0`.\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. 
Also see the config of\n            `Kwai-Kolors/Kolors-diffusers`.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n\n    _optional_components = [\n        \"tokenizer\",\n        \"text_encoder\",\n        \"feature_extractor\",\n        \"image_encoder\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"add_neg_time_ids\",\n        \"mask\",\n        \"masked_image_latents\",\n        \"control_image\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: ChatGLMModel,\n        tokenizer: ChatGLMTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        requires_aesthetics_score: bool = False,\n        force_zeros_for_empty_prompt: bool = True,\n        feature_extractor: CLIPImageProcessor = None,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n   
     self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True\n        )\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)\n\n    def encode_prompt(\n        self,\n        prompt,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer]\n        text_encoders = [self.text_encoder]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=256,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(self._execution_device)\n                output = text_encoder(\n                    input_ids=text_inputs[\"input_ids\"],\n                    attention_mask=text_inputs[\"attention_mask\"],\n                    position_ids=text_inputs[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n                prompt_embeds = 
output.hidden_states[-2].permute(1, 0, 2).clone()\n                pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]\n                bs_embed, seq_len, _ = prompt_embeds.shape\n                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n                prompt_embeds_list.append(prompt_embeds)\n\n            # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n            prompt_embeds = prompt_embeds_list[0]\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            # negative_prompt = negative_prompt or \"\"\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            negative_prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                # textual inversion: process multi-vector tokens if necessary\n                if isinstance(self, TextualInversionLoaderMixin):\n                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    uncond_tokens,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(self._execution_device)\n                output = text_encoder(\n                    input_ids=uncond_input[\"input_ids\"],\n                    attention_mask=uncond_input[\"attention_mask\"],\n                    position_ids=uncond_input[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n                negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()\n                negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]\n\n                if do_classifier_free_guidance:\n                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n                    seq_len = negative_prompt_embeds.shape[1]\n\n                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                    negative_prompt_embeds = negative_prompt_embeds.view(\n                        batch_size * 
num_images_per_prompt, seq_len, -1\n                    )\n\n                    # For classifier free guidance, we need to do two forward passes.\n                    # Here we concatenate the unconditional and text embeddings into a single batch\n                    # to avoid doing two forward passes\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n            negative_prompt_embeds = negative_prompt_embeds_list[0]\n\n        bs_embed = pooled_prompt_embeds.shape[0]\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        image_embeds = []\n        if do_classifier_free_guidance:\n            negative_image_embeds = []\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
Got \"\n                    f\"{len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n\n                image_embeds.append(single_image_embeds[None, :])\n                if do_classifier_free_guidance:\n                    negative_image_embeds.append(single_negative_image_embeds[None, :])\n        else:\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    negative_image_embeds.append(single_negative_image_embeds)\n                image_embeds.append(single_image_embeds)\n\n        ip_adapter_image_embeds = []\n        for i, single_image_embeds in enumerate(image_embeds):\n            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n            if do_classifier_free_guidance:\n                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)\n                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)\n\n            single_image_embeds = single_image_embeds.to(device=device)\n            ip_adapter_image_embeds.append(single_image_embeds)\n\n        return ip_adapter_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def 
prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        strength,\n        num_inference_steps,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should in [0.0, 1.0] but is {strength}\")\n        if num_inference_steps is None:\n            raise ValueError(\"`num_inference_steps` cannot be None.\")\n        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:\n            raise ValueError(\n                f\"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type\"\n                f\" {type(num_inference_steps)}.\"\n            )\n\n        if 
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. 
The conditionings will be fixed across the prompts.\"\n                )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in image):\n                raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif len(image) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets.\"\n                )\n\n            for image_ in image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type 
`float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"Only a single batch of multiple conditionings is supported at the moment.\")\n                elif len(controlnet_conditioning_scale) != len(self.controlnet.nets):\n                    raise ValueError(\n                        \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must\"\n                        \" have the same length as the number of controlnets\"\n                    )\n        else:\n            assert False\n\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. 
Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = 
isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):\n        # get the original timestep using init_timestep\n        if denoising_start is None:\n            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n            t_start = max(num_inference_steps - init_timestep, 0)\n        else:\n            t_start = 0\n\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        # Strength is irrelevant if we directly request a timestep to start at;\n        # that is, strength is determined by the denoising_start instead.\n        if denoising_start is not None:\n            discrete_timestep_cutoff = int(\n                round(\n                    
self.scheduler.config.num_train_timesteps\n                    - (denoising_start * self.scheduler.config.num_train_timesteps)\n                )\n            )\n\n            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()\n            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:\n                # if the scheduler is a 2nd order scheduler we might have to do +1\n                # because `num_inference_steps` might be even given that every timestep\n                # (except the highest one) is duplicated. If `num_inference_steps` is even it would\n                # mean that we cut the timesteps in the middle of the denoising step\n                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1\n                # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler\n                num_inference_steps = num_inference_steps + 1\n\n            # because t_n+1 >= t_n, we slice the timesteps starting from the end\n            timesteps = timesteps[-num_inference_steps:]\n            return timesteps, num_inference_steps\n\n        return timesteps, num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents\n    def prepare_latents(\n        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True\n    ):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        # Offload text encoder if `enable_model_cpu_offload` was enabled\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n        image = 
image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            if self.vae.config.force_upcast:\n                image = image.float()\n                self.vae.to(dtype=torch.float32)\n\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                init_latents = [\n                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                    for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n            if self.vae.config.force_upcast:\n                self.vae.to(dtype)\n\n            init_latents = init_latents.to(dtype)\n\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            
init_latents = torch.cat([init_latents], dim=0)\n\n        if add_noise:\n            shape = init_latents.shape\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # get latents\n            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n\n        latents = init_latents\n\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents_t2i(\n        self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None\n    ):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids\n    def _get_add_time_ids(\n        self,\n        original_size,\n        crops_coords_top_left,\n        target_size,\n        aesthetic_score,\n        negative_aesthetic_score,\n        negative_original_size,\n        negative_crops_coords_top_left,\n        negative_target_size,\n        dtype,\n        text_encoder_projection_dim=None,\n    ):\n        if self.config.requires_aesthetics_score:\n            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))\n            add_neg_time_ids = list(\n                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)\n            )\n        else:\n            add_time_ids = list(original_size + crops_coords_top_left + target_size)\n            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)\n\n        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if (\n            expected_add_embed_dim > passed_add_embed_dim\n            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was 
created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.\"\n            )\n        elif (\n            expected_add_embed_dim < passed_add_embed_dim\n            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.\"\n            )\n        elif expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)\n\n        return add_time_ids, add_neg_time_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def denoising_start(self):\n        return self._denoising_start\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        dtype = image.dtype\n        if self.vae.config.force_upcast:\n            image = image.float()\n            self.vae.to(dtype=torch.float32)\n\n        if isinstance(generator, list):\n            image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        if self.vae.config.force_upcast:\n            self.vae.to(dtype)\n\n        image_latents = image_latents.to(dtype)\n        image_latents = self.vae.config.scaling_factor * image_latents\n\n        return image_latents\n\n    def prepare_mask_latents(\n        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance\n    ):\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that 
before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(\n            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)\n        )\n        mask = mask.to(device=device, dtype=dtype)\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n\n        if masked_image is not None and masked_image.shape[1] == 4:\n            masked_image_latents = masked_image\n        else:\n            masked_image_latents = None\n\n        if masked_image is not None:\n            if masked_image_latents is None:\n                masked_image = masked_image.to(device=device, dtype=dtype)\n                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)\n\n            if masked_image_latents.shape[0] < batch_size:\n                if not batch_size % masked_image_latents.shape[0] == 0:\n                    raise ValueError(\n                        \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                        f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                        \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                    )\n                masked_image_latents = masked_image_latents.repeat(\n                    batch_size // masked_image_latents.shape[0], 1, 1, 1\n                )\n\n            masked_image_latents = (\n                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n            )\n\n            # aligning device to prevent device errors when concating it with the latent model input\n            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n\n        return mask, masked_image_latents\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        mask_image: PipelineImageInput = None,\n        control_image: PipelineImageInput = None,\n        masked_image_latents: torch.Tensor = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        padding_mask_crop: Optional[int] = None,\n        strength: float = 0.9999,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        guess_mode: bool = False,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] 
= None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        guidance_rescale: float = 0.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is\n                used in both text-encoders.\n            image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            mask_image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also\n                be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in\n                init, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single controlnet.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. 
This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            padding_mask_crop (`int`, *optional*, defaults to `None`):\n                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to\n                image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region\n                with the same aspect ratio as the image that contains all masked areas, and then expand that area based\n                on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before\n                resizing to the original image size for inpainting. This is useful when the masked area is small while\n                the image is large and contains information irrelevant for inpainting, such as background.\n            strength (`float`, *optional*, defaults to 0.9999):\n                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be\n                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the\n                `strength`. 
The number of denoising steps depends on the amount of noise initially added. When\n                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of\n                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked\n                portion of the reference `image`. Note that in the case of `denoising_start` being declared as an\n                integer, the value of `strength` will be ignored.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            denoising_start (`float`, *optional*):\n                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be\n                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and\n                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,\n                strength will be ignored. 
The `denoising_start` parameter is particularly beneficial when this pipeline\n                is integrated into a \"Mixture of Denoisers\" multi-pipeline setup, as detailed in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be\n                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the\n                final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline\n                forms a part of a \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. 
It should be a list whose length equals the number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the controlnet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the controlnet stops applying.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. 
Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be as same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            aesthetic_score (`float`, *optional*, defaults to 6.0):\n                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to\n                simulate an aesthetic score of the generated image by influencing the negative text condition.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during the inference. 
It is called with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is `True`, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for 
control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            control_image,\n            strength,\n            num_inference_steps,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._denoising_start = denoising_start\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        # 3.1. Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 3.2 Encode ip_adapter_image\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. 
Prepare image, mask, and controlnet_conditioning_image\n        if isinstance(controlnet, ControlNetModel):\n            control_image = self.prepare_control_image(\n                image=control_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = control_image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            control_images = []\n\n            for control_image_ in control_image:\n                control_image_ = self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            control_image = control_images\n            height, width = control_image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. 
set timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps,\n            strength,\n            device,\n            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,\n        )\n        # check that number of inference steps is not < 1 - as this doesn't make sense\n        if num_inference_steps < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n            )\n        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise\n        is_strength_max = strength == 1.0\n\n        # 6. 
Preprocess mask and image\n        if padding_mask_crop is not None:\n            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)\n            resize_mode = \"fill\"\n        else:\n            crops_coords = None\n            resize_mode = \"default\"\n\n        original_image = image\n        init_image = self.image_processor.preprocess(\n            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode\n        )\n        init_image = init_image.to(dtype=torch.float32)\n\n        mask = self.mask_processor.preprocess(\n            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords\n        )\n\n        if masked_image_latents is not None:\n            masked_image = masked_image_latents\n        elif init_image.shape[1] == 4:\n            # if images are in latent space, we can't mask it\n            masked_image = None\n        else:\n            masked_image = init_image * (mask < 0.5)\n\n        # 7. 
Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        num_channels_unet = self.unet.config.in_channels\n        return_image_latents = num_channels_unet == 4\n\n        if latents is None:\n            if strength >= 1.0:\n                latents = self.prepare_latents_t2i(\n                    batch_size * num_images_per_prompt,\n                    num_channels_latents,\n                    height,\n                    width,\n                    prompt_embeds.dtype,\n                    device,\n                    generator,\n                    latents,\n                )\n            else:\n                latents = self.prepare_latents(\n                    init_image,\n                    latent_timestep,\n                    batch_size,\n                    num_images_per_prompt,\n                    prompt_embeds.dtype,\n                    device,\n                    generator,\n                    True,\n                )\n\n        # 8. Prepare mask latent variables\n        mask, masked_image_latents = self.prepare_mask_latents(\n            mask,\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            self.do_classifier_free_guidance,\n        )\n\n        # 9. Check that sizes of mask, masked image and latents match\n        if num_channels_unet == 9:\n            # default case for stable-diffusion-v1-5/stable-diffusion-inpainting\n            num_channels_mask = mask.shape[1]\n            num_channels_masked_image = masked_image_latents.shape[1]\n            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:\n                raise ValueError(\n                    f\"Incorrect configuration settings! 
The config of `pipeline.unet`: {self.unet.config} expects\"\n                    f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                    f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                    f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                    \" `pipeline.unet` or your `mask_image` or `image` input.\"\n                )\n        elif num_channels_unet != 4:\n            raise ValueError(\n                f\"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.\"\n            )\n\n        # 8.1. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8.2 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 9 Prepare added time ids & embeddings\n        if isinstance(control_image, list):\n            original_size = original_size or control_image[0].shape[-2:]\n        else:\n            original_size = original_size or control_image.shape[-2:]\n        target_size = target_size or (height, width)\n\n        if negative_original_size is None:\n            negative_original_size = original_size\n        if negative_target_size is None:\n            negative_target_size = target_size\n\n        add_text_embeds = pooled_prompt_embeds\n        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n\n   
     add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            negative_original_size,\n            negative_crops_coords_top_left,\n            negative_target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)\n            add_neg_time_ids = torch.cat([add_neg_time_ids, add_neg_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n        add_neg_time_ids = add_neg_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        # 10. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        if (\n            self.denoising_end is not None\n            and self.denoising_start is not None\n            and denoising_value_valid(self.denoising_end)\n            and denoising_value_valid(self.denoising_start)\n            and self.denoising_start >= self.denoising_end\n        ):\n            raise ValueError(\n                f\"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: \"\n                + f\" {self.denoising_end} when using type float.\"\n            )\n        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 11.1 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # patch diffusers controlnet instance forward, undo\n        # after denoising loop\n\n        patched_cn_models = []\n        if isinstance(self.controlnet, MultiControlNetModel):\n            cn_models_to_patch = self.controlnet.nets\n        else:\n            cn_models_to_patch = [self.controlnet]\n\n        for cn_model in cn_models_to_patch:\n         
   cn_og_forward = cn_model.forward\n\n            def _cn_patch_forward(*args, **kwargs):\n                encoder_hidden_states = kwargs[\"encoder_hidden_states\"]\n                if cn_model.encoder_hid_proj is not None and cn_model.config.encoder_hid_dim_type == \"text_proj\":\n                    # Ensure encoder_hidden_states is on the same device as the projection layer\n                    encoder_hidden_states = encoder_hidden_states.to(cn_model.encoder_hid_proj.weight.device)\n                    encoder_hidden_states = cn_model.encoder_hid_proj(encoder_hidden_states)\n                kwargs.pop(\"encoder_hidden_states\")\n                return cn_og_forward(*args, encoder_hidden_states=encoder_hidden_states, **kwargs)\n\n            cn_model.forward = _cn_patch_forward\n            patched_cn_models.append((cn_model, cn_og_forward))\n\n        try:\n            with self.progress_bar(total=num_inference_steps) as progress_bar:\n                for i, t in enumerate(timesteps):\n                    # expand the latents if we are doing classifier free guidance\n                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                    if num_channels_unet == 9:\n                        latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n                    added_cond_kwargs = {\n                        \"text_embeds\": add_text_embeds,\n                        \"time_ids\": add_time_ids,\n                        \"neg_time_ids\": add_neg_time_ids,\n                    }\n\n                    # controlnet(s) inference\n                    if guess_mode and self.do_classifier_free_guidance:\n                        # Infer ControlNet only for the conditional batch.\n                        control_model_input = latents\n                        
control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                        controlnet_added_cond_kwargs = {\n                            \"text_embeds\": add_text_embeds.chunk(2)[1],\n                            \"time_ids\": add_time_ids.chunk(2)[1],\n                            \"neg_time_ids\": add_neg_time_ids.chunk(2)[1],\n                        }\n                    else:\n                        control_model_input = latent_model_input\n                        controlnet_prompt_embeds = prompt_embeds\n                        controlnet_added_cond_kwargs = added_cond_kwargs\n\n                    if isinstance(controlnet_keep[i], list):\n                        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                    else:\n                        controlnet_cond_scale = controlnet_conditioning_scale\n                        if isinstance(controlnet_cond_scale, list):\n                            controlnet_cond_scale = controlnet_cond_scale[0]\n                        cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                    down_block_res_samples, mid_block_res_sample = self.controlnet(\n                        control_model_input,\n                        t,\n                        encoder_hidden_states=controlnet_prompt_embeds,\n                        controlnet_cond=control_image,\n                        conditioning_scale=cond_scale,\n                        guess_mode=guess_mode,\n                        added_cond_kwargs=controlnet_added_cond_kwargs,\n                        return_dict=False,\n                    )\n\n                    if guess_mode and self.do_classifier_free_guidance:\n                        # Inferred ControlNet only for the conditional batch.\n                        # To apply the output of ControlNet to both the unconditional and conditional 
batches,\n                        # add 0 to the unconditional batch to keep it unchanged.\n                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                        mid_block_res_sample = torch.cat(\n                            [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]\n                        )\n\n                    if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                        added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                    # predict the noise residual\n                    noise_pred = self.unet(\n                        latent_model_input,\n                        t,\n                        encoder_hidden_states=prompt_embeds,\n                        cross_attention_kwargs=self.cross_attention_kwargs,\n                        down_block_additional_residuals=down_block_res_samples,\n                        mid_block_additional_residual=mid_block_res_sample,\n                        added_cond_kwargs=added_cond_kwargs,\n                        return_dict=False,\n                    )[0]\n\n                    # perform guidance\n                    if self.do_classifier_free_guidance:\n                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                    # compute the previous noisy sample x_t -> x_t-1\n                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                    if callback_on_step_end is not None:\n                        callback_kwargs = {}\n                        for k in callback_on_step_end_tensor_inputs:\n                            callback_kwargs[k] = locals()[k]\n                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                     
   latents = callback_outputs.pop(\"latents\", latents)\n                        prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                        negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                        control_image = callback_outputs.pop(\"control_image\", control_image)\n\n                    # call the callback, if provided\n                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                        progress_bar.update()\n                        if callback is not None and i % callback_steps == 0:\n                            step_idx = i // getattr(self.scheduler, \"order\", 1)\n                            callback(step_idx, t, latents)\n        finally:\n            for cn_and_og in patched_cn_models:\n                cn_and_og[0].forward = cn_and_og[1]\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            latents = latents / self.vae.config.scaling_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = 
latents\n            return StableDiffusionXLPipelineOutput(images=image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_demofusion_sdxl.py",
    "content": "import inspect\nimport os\nimport random\nimport warnings\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport matplotlib.pyplot as plt\nimport torch\nimport torch.nn.functional as F\nfrom transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer\n\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    is_accelerate_available,\n    is_accelerate_version,\n    is_invisible_watermark_available,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark import (\n        StableDiffusionXLWatermarker,\n    )\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import StableDiffusionXLPipeline\n\n        >>> pipe = StableDiffusionXLPipeline.from_pretrained(\n        ...     \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16\n        ... 
)\n        >>> pipe = pipe.to(\"cuda\")\n\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\n        >>> image = pipe(prompt).images[0]\n        ```\n\"\"\"\n\n\ndef gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):\n    x_coord = torch.arange(kernel_size)\n    gaussian_1d = torch.exp(-((x_coord - (kernel_size - 1) / 2) ** 2) / (2 * sigma**2))\n    gaussian_1d = gaussian_1d / gaussian_1d.sum()\n    gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]\n    kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)\n\n    return kernel\n\n\ndef gaussian_filter(latents, kernel_size=3, sigma=1.0):\n    channels = latents.shape[1]\n    kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)\n    blurred_latents = F.conv2d(latents, kernel, padding=kernel_size // 2, groups=channels)\n\n    return blurred_latents\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\nclass DemoFusionSDXLPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    FromSingleFileMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion XL.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    In addition the pipeline inherits the following loading methods:\n        - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]\n        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]\n\n    as well as the following saving methods:\n        - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion XL uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([` CLIPTextModelWithProjection`]):\n            Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of\n            `stabilityai/stable-diffusion-xl-base-1.0`.\n        add_watermarker (`bool`, *optional*):\n            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to\n            watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no\n            watermarker will be used.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->unet->vae\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            scheduler=scheduler,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.default_sample_size = (\n            self.unet.config.sample_size\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\n            else 128\n        )\n\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str 
| None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # 
Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n\n                prompt_embeds = text_encoder(\n                    text_input_ids.to(device),\n                    output_hidden_states=True,\n                )\n\n                # We are only ALWAYS interested in the pooled output 
of the final text encoder\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\n                    pooled_prompt_embeds = prompt_embeds[0]\n\n                prompt_embeds = prompt_embeds.hidden_states[-2]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if 
do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return 
extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        num_images_per_prompt=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. 
Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        # DemoFusion specific checks\n        if max(height, width) % 1024 != 0:\n            raise ValueError(\n                f\"The larger of `height` and `width` has to be divisible by 1024, but got {height} and {width}.\"\n            )\n\n        if num_images_per_prompt != 1:\n            warnings.warn(\"num_images_per_prompt != 1 is not supported by DemoFusion and will be ignored.\")\n            num_images_per_prompt = 1\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. 
Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        return add_time_ids\n\n    def get_views(self, height, width, window_size=128, stride=64, random_jitter=False):\n        height //= self.vae_scale_factor\n        width //= self.vae_scale_factor\n        num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1\n        num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1\n        total_num_blocks = int(num_blocks_height * num_blocks_width)\n        views = []\n        for i in range(total_num_blocks):\n            h_start = int((i // num_blocks_width) * stride)\n            h_end = h_start + window_size\n            w_start = int((i % num_blocks_width) * stride)\n            w_end = w_start + window_size\n\n            if h_end > height:\n                h_start = int(h_start + height - h_end)\n                h_end = int(height)\n            if w_end > width:\n                w_start = int(w_start + width - w_end)\n                w_end = int(width)\n            if h_start < 0:\n                h_end = int(h_end - h_start)\n                h_start = 0\n            if w_start < 0:\n                w_end = int(w_end - w_start)\n                w_start = 0\n\n            if random_jitter:\n                jitter_range = (window_size - stride) // 4\n                w_jitter = 0\n                h_jitter = 0\n                if (w_start != 0) and (w_end != width):\n                    w_jitter = random.randint(-jitter_range, jitter_range)\n                elif (w_start == 0) and (w_end != width):\n                    w_jitter = random.randint(-jitter_range, 0)\n                elif (w_start != 0) and (w_end == width):\n                    w_jitter = random.randint(0, jitter_range)\n                if (h_start != 0) and (h_end != height):\n                
    h_jitter = random.randint(-jitter_range, jitter_range)\n                elif (h_start == 0) and (h_end != height):\n                    h_jitter = random.randint(-jitter_range, 0)\n                elif (h_start != 0) and (h_end == height):\n                    h_jitter = random.randint(0, jitter_range)\n                h_start += h_jitter + jitter_range\n                h_end += h_jitter + jitter_range\n                w_start += w_jitter + jitter_range\n                w_end += w_jitter + jitter_range\n\n            views.append((h_start, h_end, w_start, w_end))\n        return views\n\n    def tiled_decode(self, latents, current_height, current_width):\n        core_size = self.unet.config.sample_size // 4\n        core_stride = core_size\n        pad_size = self.unet.config.sample_size // 4 * 3\n        decoder_view_batch_size = 1\n\n        views = self.get_views(current_height, current_width, stride=core_stride, window_size=core_size)\n        views_batch = [views[i : i + decoder_view_batch_size] for i in range(0, len(views), decoder_view_batch_size)]\n        latents_ = F.pad(latents, (pad_size, pad_size, pad_size, pad_size), \"constant\", 0)\n        image = torch.zeros(latents.size(0), 3, current_height, current_width).to(latents.device)\n        count = torch.zeros_like(image).to(latents.device)\n        # get the latents corresponding to the current view coordinates\n        with self.progress_bar(total=len(views_batch)) as progress_bar:\n            for j, batch_view in enumerate(views_batch):\n                len(batch_view)\n                latents_for_view = torch.cat(\n                    [\n                        latents_[:, :, h_start : h_end + pad_size * 2, w_start : w_end + pad_size * 2]\n                        for h_start, h_end, w_start, w_end in batch_view\n                    ]\n                )\n                image_patch = self.vae.decode(latents_for_view / self.vae.config.scaling_factor, return_dict=False)[0]\n                
h_start, h_end, w_start, w_end = views[j]\n                h_start, h_end, w_start, w_end = (\n                    h_start * self.vae_scale_factor,\n                    h_end * self.vae_scale_factor,\n                    w_start * self.vae_scale_factor,\n                    w_end * self.vae_scale_factor,\n                )\n                p_h_start, p_h_end, p_w_start, p_w_end = (\n                    pad_size * self.vae_scale_factor,\n                    image_patch.size(2) - pad_size * self.vae_scale_factor,\n                    pad_size * self.vae_scale_factor,\n                    image_patch.size(3) - pad_size * self.vae_scale_factor,\n                )\n                image[:, :, h_start:h_end, w_start:w_end] += image_patch[:, :, p_h_start:p_h_end, p_w_start:p_w_end]\n                count[:, :, h_start:h_end, w_start:w_end] += 1\n                progress_bar.update()\n        image = image / count\n\n        return image\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = False,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        ################### DemoFusion specific parameters ####################\n        view_batch_size: int = 16,\n        multi_decoder: bool = True,\n        stride: Optional[int] = 64,\n        cosine_scale_1: 
Optional[float] = 3.0,\n        cosine_scale_2: Optional[float] = 1.0,\n        cosine_scale_3: Optional[float] = 1.0,\n        sigma: Optional[float] = 0.8,\n        show_image: bool = False,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `False`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\n                of a plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in\n                equation 16 of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). 
Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. 
Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            ################### DemoFusion specific parameters ####################\n            view_batch_size (`int`, defaults to 16):\n                The batch size for multiple denoising paths. Typically, a larger batch size can result in higher\n                efficiency but comes with increased GPU memory requirements.\n            multi_decoder (`bool`, defaults to True):\n                Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072,\n                a tiled decoder becomes necessary.\n            stride (`int`, defaults to 64):\n                The stride of moving local patches. A smaller stride is better for alleviating seam issues,\n                but it also introduces additional computational overhead and inference time.\n            cosine_scale_1 (`float`, defaults to 3):\n                Control the strength of skip-residual. For specific impacts, please refer to Appendix C\n                in the DemoFusion paper.\n            cosine_scale_2 (`float`, defaults to 1):\n                Control the strength of dilated sampling. For specific impacts, please refer to Appendix C\n                in the DemoFusion paper.\n            cosine_scale_3 (`float`, defaults to 1):\n                Control the strength of the gaussian filter. 
For specific impacts, please refer to Appendix C\n                in the DemoFusion paper.\n            sigma (`float`, defaults to 0.8):\n                The standard deviation of the gaussian filter.\n            show_image (`bool`, defaults to False):\n                Determine whether to show intermediate results during generation.\n\n        Examples:\n\n        Returns:\n            a `list` with the generated images at each phase.\n        \"\"\"\n\n        # 0. Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        x1_size = self.default_sample_size * self.vae_scale_factor\n\n        height_scale = height / x1_size\n        width_scale = width / x1_size\n        scale_num = int(max(height_scale, width_scale))\n        aspect_ratio = min(height_scale, width_scale) / max(height_scale, width_scale)\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            num_images_per_prompt,\n        )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n\n        timesteps = self.scheduler.timesteps\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height // scale_num,\n            width // scale_num,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n        add_time_ids = self._get_add_time_ids(\n            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype\n        )\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 7.1 Apply denoising_end\n        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        output_images = []\n\n        ############################################################### Phase 1 #################################################################\n\n        print(\"### Phase 1 Denoising ###\")\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                latents_for_view = latents\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = latents.repeat_interleave(2, dim=0) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, 
noise_pred_text = noise_pred[::2], noise_pred[1::2]\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n            anchor_mean = latents.mean()\n            anchor_std = latents.std()\n            if not output_type == \"latent\":\n                # make sure the VAE is in float32 mode, as it overflows in float16\n                needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n                if needs_upcasting:\n                    self.upcast_vae()\n                    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n                print(\"### Phase 1 Decoding ###\")\n                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n                # cast back to fp16 if needed\n                if needs_upcasting:\n                    self.vae.to(dtype=torch.float16)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n            if show_image:\n                plt.figure(figsize=(10, 10))\n          
      plt.imshow(image[0])\n                plt.axis(\"off\")  # Turn off axis numbers and ticks\n                plt.show()\n            output_images.append(image[0])\n\n        ####################################################### Phase 2+ #####################################################\n\n        for current_scale_num in range(2, scale_num + 1):\n            print(\"### Phase {} Denoising ###\".format(current_scale_num))\n            current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num\n            current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num\n            if height > width:\n                current_width = int(current_width * aspect_ratio)\n            else:\n                current_height = int(current_height * aspect_ratio)\n\n            latents = F.interpolate(\n                latents,\n                size=(int(current_height / self.vae_scale_factor), int(current_width / self.vae_scale_factor)),\n                mode=\"bicubic\",\n            )\n\n            noise_latents = []\n            noise = torch.randn_like(latents)\n            for timestep in timesteps:\n                noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))\n                noise_latents.append(noise_latent)\n            latents = noise_latents[0]\n\n            with self.progress_bar(total=num_inference_steps) as progress_bar:\n                for i, t in enumerate(timesteps):\n                    count = torch.zeros_like(latents)\n                    value = torch.zeros_like(latents)\n                    cosine_factor = (\n                        0.5\n                        * (\n                            1\n                            + torch.cos(\n                                torch.pi\n                                * (self.scheduler.config.num_train_timesteps - t)\n                                / self.scheduler.config.num_train_timesteps\n                 
           )\n                        ).cpu()\n                    )\n\n                    c1 = cosine_factor**cosine_scale_1\n                    latents = latents * (1 - c1) + noise_latents[i] * c1\n\n                    ############################################# MultiDiffusion #############################################\n\n                    views = self.get_views(\n                        current_height,\n                        current_width,\n                        stride=stride,\n                        window_size=self.unet.config.sample_size,\n                        random_jitter=True,\n                    )\n                    views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]\n\n                    jitter_range = (self.unet.config.sample_size - stride) // 4\n                    latents_ = F.pad(latents, (jitter_range, jitter_range, jitter_range, jitter_range), \"constant\", 0)\n\n                    count_local = torch.zeros_like(latents_)\n                    value_local = torch.zeros_like(latents_)\n\n                    for j, batch_view in enumerate(views_batch):\n                        vb_size = len(batch_view)\n\n                        # get the latents corresponding to the current view coordinates\n                        latents_for_view = torch.cat(\n                            [\n                                latents_[:, :, h_start:h_end, w_start:w_end]\n                                for h_start, h_end, w_start, w_end in batch_view\n                            ]\n                        )\n\n                        # expand the latents if we are doing classifier free guidance\n                        latent_model_input = latents_for_view\n                        latent_model_input = (\n                            latent_model_input.repeat_interleave(2, dim=0)\n                            if do_classifier_free_guidance\n                            else latent_model_input\n                
        )\n                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                        prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)\n                        add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)\n                        add_time_ids_input = []\n                        for h_start, h_end, w_start, w_end in batch_view:\n                            add_time_ids_ = add_time_ids.clone()\n                            add_time_ids_[:, 2] = h_start * self.vae_scale_factor\n                            add_time_ids_[:, 3] = w_start * self.vae_scale_factor\n                            add_time_ids_input.append(add_time_ids_)\n                        add_time_ids_input = torch.cat(add_time_ids_input)\n\n                        # predict the noise residual\n                        added_cond_kwargs = {\"text_embeds\": add_text_embeds_input, \"time_ids\": add_time_ids_input}\n                        noise_pred = self.unet(\n                            latent_model_input,\n                            t,\n                            encoder_hidden_states=prompt_embeds_input,\n                            cross_attention_kwargs=cross_attention_kwargs,\n                            added_cond_kwargs=added_cond_kwargs,\n                            return_dict=False,\n                        )[0]\n\n                        if do_classifier_free_guidance:\n                            noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]\n                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                        if do_classifier_free_guidance and guidance_rescale > 0.0:\n                            # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                            noise_pred = rescale_noise_cfg(\n                                noise_pred, noise_pred_text, guidance_rescale=guidance_rescale\n                            )\n\n                        # compute the previous noisy sample x_t -> x_t-1\n                        self.scheduler._init_step_index(t)\n                        latents_denoised_batch = self.scheduler.step(\n                            noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False\n                        )[0]\n\n                        # extract value from batch\n                        for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(\n                            latents_denoised_batch.chunk(vb_size), batch_view\n                        ):\n                            value_local[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised\n                            count_local[:, :, h_start:h_end, w_start:w_end] += 1\n\n                    value_local = value_local[\n                        :,\n                        :,\n                        jitter_range : jitter_range + current_height // self.vae_scale_factor,\n                        jitter_range : jitter_range + current_width // self.vae_scale_factor,\n                    ]\n                    count_local = count_local[\n                        :,\n                        :,\n                        jitter_range : jitter_range + current_height // self.vae_scale_factor,\n                        jitter_range : jitter_range + current_width // self.vae_scale_factor,\n                    ]\n\n                    c2 = cosine_factor**cosine_scale_2\n\n                    value += value_local / count_local * (1 - c2)\n                    count += torch.ones_like(value_local) * (1 - c2)\n\n                    ############################################# Dilated Sampling #############################################\n\n                    
views = [[h, w] for h in range(current_scale_num) for w in range(current_scale_num)]\n                    views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]\n\n                    h_pad = (current_scale_num - (latents.size(2) % current_scale_num)) % current_scale_num\n                    w_pad = (current_scale_num - (latents.size(3) % current_scale_num)) % current_scale_num\n                    latents_ = F.pad(latents, (w_pad, 0, h_pad, 0), \"constant\", 0)\n\n                    count_global = torch.zeros_like(latents_)\n                    value_global = torch.zeros_like(latents_)\n\n                    c3 = 0.99 * cosine_factor**cosine_scale_3 + 1e-2\n                    std_, mean_ = latents_.std(), latents_.mean()\n                    latents_gaussian = gaussian_filter(\n                        latents_, kernel_size=(2 * current_scale_num - 1), sigma=sigma * c3\n                    )\n                    latents_gaussian = (\n                        latents_gaussian - latents_gaussian.mean()\n                    ) / latents_gaussian.std() * std_ + mean_\n\n                    for j, batch_view in enumerate(views_batch):\n                        latents_for_view = torch.cat(\n                            [latents_[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view]\n                        )\n                        latents_for_view_gaussian = torch.cat(\n                            [latents_gaussian[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view]\n                        )\n\n                        vb_size = latents_for_view.size(0)\n\n                        # expand the latents if we are doing classifier free guidance\n                        latent_model_input = latents_for_view_gaussian\n                        latent_model_input = (\n                            latent_model_input.repeat_interleave(2, dim=0)\n                            if 
do_classifier_free_guidance\n                            else latent_model_input\n                        )\n                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                        prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)\n                        add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)\n                        add_time_ids_input = torch.cat([add_time_ids] * vb_size)\n\n                        # predict the noise residual\n                        added_cond_kwargs = {\"text_embeds\": add_text_embeds_input, \"time_ids\": add_time_ids_input}\n                        noise_pred = self.unet(\n                            latent_model_input,\n                            t,\n                            encoder_hidden_states=prompt_embeds_input,\n                            cross_attention_kwargs=cross_attention_kwargs,\n                            added_cond_kwargs=added_cond_kwargs,\n                            return_dict=False,\n                        )[0]\n\n                        if do_classifier_free_guidance:\n                            noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]\n                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                        if do_classifier_free_guidance and guidance_rescale > 0.0:\n                            # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                            noise_pred = rescale_noise_cfg(\n                                noise_pred, noise_pred_text, guidance_rescale=guidance_rescale\n                            )\n\n                        # compute the previous noisy sample x_t -> x_t-1\n                        self.scheduler._init_step_index(t)\n                        latents_denoised_batch = self.scheduler.step(\n                            noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False\n                        )[0]\n\n                        # extract value from batch\n                        for latents_view_denoised, (h, w) in zip(latents_denoised_batch.chunk(vb_size), batch_view):\n                            value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised\n                            count_global[:, :, h::current_scale_num, w::current_scale_num] += 1\n\n                    c2 = cosine_factor**cosine_scale_2\n\n                    value_global = value_global[:, :, h_pad:, w_pad:]\n\n                    value += value_global * c2\n                    count += torch.ones_like(value_global) * c2\n\n                    ###########################################################\n\n                    latents = torch.where(count > 0, value / count, value)\n\n                    # call the callback, if provided\n                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                        progress_bar.update()\n                        if callback is not None and i % callback_steps == 0:\n                            step_idx = i // getattr(self.scheduler, \"order\", 1)\n                            callback(step_idx, t, latents)\n\n                #########################################################################################################################################\n\n                latents = 
(latents - latents.mean()) / latents.std() * anchor_std + anchor_mean\n                if not output_type == \"latent\":\n                    # make sure the VAE is in float32 mode, as it overflows in float16\n                    needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n                    if needs_upcasting:\n                        self.upcast_vae()\n                        latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n                    print(\"### Phase {} Decoding ###\".format(current_scale_num))\n                    if multi_decoder:\n                        image = self.tiled_decode(latents, current_height, current_width)\n                    else:\n                        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n\n                    # cast back to fp16 if needed\n                    if needs_upcasting:\n                        self.vae.to(dtype=torch.float16)\n                else:\n                    image = latents\n\n                if not output_type == \"latent\":\n                    image = self.image_processor.postprocess(image, output_type=output_type)\n                    if show_image:\n                        plt.figure(figsize=(10, 10))\n                        plt.imshow(image[0])\n                        plt.axis(\"off\")  # Turn off axis numbers and ticks\n                        plt.show()\n                    output_images.append(image[0])\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        return output_images\n\n    # Override to properly handle the loading and unloading of the additional text encoder.\n    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):\n        # We could have accessed the unet config from `lora_state_dict()` too. 
We pass\n        # it here explicitly to be able to tell that it's coming from an SDXL\n        # pipeline.\n\n        # Remove any existing hooks.\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module\n        else:\n            raise ImportError(\"Offloading requires `accelerate v0.17.0` or higher.\")\n\n        is_model_cpu_offload = False\n        is_sequential_cpu_offload = False\n        recursive = False\n        for _, component in self.components.items():\n            if isinstance(component, torch.nn.Module):\n                if hasattr(component, \"_hf_hook\"):\n                    is_model_cpu_offload = isinstance(getattr(component, \"_hf_hook\"), CpuOffload)\n                    is_sequential_cpu_offload = (\n                        isinstance(getattr(component, \"_hf_hook\"), AlignDevicesHook)\n                        or hasattr(component._hf_hook, \"hooks\")\n                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)\n                    )\n                    logger.info(\n                        \"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. 
Then the LoRA parameters will be loaded and the hooks will be applied again.\"\n                    )\n                    recursive = is_sequential_cpu_offload\n                    remove_hook_from_module(component, recurse=recursive)\n        state_dict, network_alphas = self.lora_state_dict(\n            pretrained_model_name_or_path_or_dict,\n            unet_config=self.unet.config,\n            **kwargs,\n        )\n        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)\n\n        text_encoder_state_dict = {k: v for k, v in state_dict.items() if \"text_encoder.\" in k}\n        if len(text_encoder_state_dict) > 0:\n            self.load_lora_into_text_encoder(\n                text_encoder_state_dict,\n                network_alphas=network_alphas,\n                text_encoder=self.text_encoder,\n                prefix=\"text_encoder\",\n                lora_scale=self.lora_scale,\n            )\n\n        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if \"text_encoder_2.\" in k}\n        if len(text_encoder_2_state_dict) > 0:\n            self.load_lora_into_text_encoder(\n                text_encoder_2_state_dict,\n                network_alphas=network_alphas,\n                text_encoder=self.text_encoder_2,\n                prefix=\"text_encoder_2\",\n                lora_scale=self.lora_scale,\n            )\n\n        # Offload back.\n        if is_model_cpu_offload:\n            self.enable_model_cpu_offload()\n        elif is_sequential_cpu_offload:\n            self.enable_sequential_cpu_offload()\n\n    @classmethod\n    def save_lora_weights(\n        cls,\n        save_directory: Union[str, os.PathLike],\n        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,\n        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,\n        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,\n        
is_main_process: bool = True,\n        weight_name: str = None,\n        save_function: Callable = None,\n        safe_serialization: bool = True,\n    ):\n        state_dict = {}\n\n        def pack_weights(layers, prefix):\n            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers\n            layers_state_dict = {f\"{prefix}.{module_name}\": param for module_name, param in layers_weights.items()}\n            return layers_state_dict\n\n        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):\n            raise ValueError(\n                \"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`.\"\n            )\n\n        if unet_lora_layers:\n            state_dict.update(pack_weights(unet_lora_layers, \"unet\"))\n\n        if text_encoder_lora_layers and text_encoder_2_lora_layers:\n            state_dict.update(pack_weights(text_encoder_lora_layers, \"text_encoder\"))\n            state_dict.update(pack_weights(text_encoder_2_lora_layers, \"text_encoder_2\"))\n\n        cls.write_lora_layers(\n            state_dict=state_dict,\n            save_directory=save_directory,\n            is_main_process=is_main_process,\n            weight_name=weight_name,\n            save_function=save_function,\n            safe_serialization=safe_serialization,\n        )\n\n    def _remove_text_encoder_monkey_patch(self):\n        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)\n        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)\n"
  },
  {
    "path": "examples/community/pipeline_fabric.py",
    "content": "# Copyright 2025 FABRIC authors and the HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import List, Optional, Union\n\nimport torch\nfrom packaging import version\nfrom PIL import Image\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models.attention import BasicTransformerBlock\nfrom diffusers.models.attention_processor import LoRAAttnProcessor\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.schedulers import EulerAncestralDiscreteScheduler, KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> from diffusers import DiffusionPipeline\n        >>> import torch\n\n        >>> model_id = \"dreamlike-art/dreamlike-photoreal-2.0\"\n        >>> pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, 
custom_pipeline=\"pipeline_fabric\")\n        >>> pipe = pipe.to(\"cuda\")\n        >>> prompt = \"a giant standing in a fantasy landscape best quality\"\n        >>> liked = []  # list of images for positive feedback\n        >>> disliked = []  # list of images for negative feedback\n        >>> image = pipe(prompt, num_images=4, liked=liked, disliked=disliked).images[0]\n        ```\n\"\"\"\n\n\nclass FabricCrossAttnProcessor:\n    def __init__(self):\n        self.attention_probs = None\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        weights=None,\n        lora_scale=1.0,\n    ):\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if isinstance(attn.processor, LoRAAttnProcessor):\n            query = attn.to_q(hidden_states) + lora_scale * attn.processor.to_q_lora(hidden_states)\n        else:\n            query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        if isinstance(attn.processor, LoRAAttnProcessor):\n            key = attn.to_k(encoder_hidden_states) + lora_scale * attn.processor.to_k_lora(encoder_hidden_states)\n            value = attn.to_v(encoder_hidden_states) + lora_scale * attn.processor.to_v_lora(encoder_hidden_states)\n        else:\n            key = attn.to_k(encoder_hidden_states)\n            value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = 
attn.get_attention_scores(query, key, attention_mask)\n\n        if weights is not None:\n            if weights.shape[0] != 1:\n                weights = weights.repeat_interleave(attn.heads, dim=0)\n            attention_probs = attention_probs * weights[:, None]\n            attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)\n\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        if isinstance(attn.processor, LoRAAttnProcessor):\n            hidden_states = attn.to_out[0](hidden_states) + lora_scale * attn.processor.to_out_lora(hidden_states)\n        else:\n            hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        return hidden_states\n\n\nclass FabricPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion and conditioning the results using feedback images.\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`EulerAncestralDiscreteScheduler`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            unet=unet,\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n   
         negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                
max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # 
get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                
uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def get_unet_hidden_states(self, z_all, t, prompt_embd):\n        cached_hidden_states = []\n        for module in self.unet.modules():\n            if isinstance(module, BasicTransformerBlock):\n\n                def new_forward(self, hidden_states, *args, **kwargs):\n                    cached_hidden_states.append(hidden_states.clone().detach().cpu())\n                    return self.old_forward(hidden_states, *args, **kwargs)\n\n                module.attn1.old_forward = module.attn1.forward\n                module.attn1.forward = new_forward.__get__(module.attn1)\n\n        # run forward pass to cache hidden states, output can be discarded\n        _ = self.unet(z_all, t, encoder_hidden_states=prompt_embd)\n\n        # restore original forward pass\n        for module in self.unet.modules():\n            if isinstance(module, BasicTransformerBlock):\n                module.attn1.forward = module.attn1.old_forward\n                del 
module.attn1.old_forward\n\n        return cached_hidden_states\n\n    def unet_forward_with_cached_hidden_states(\n        self,\n        z_all,\n        t,\n        prompt_embd,\n        cached_pos_hiddens: Optional[List[torch.Tensor]] = None,\n        cached_neg_hiddens: Optional[List[torch.Tensor]] = None,\n        pos_weights=(0.8, 0.8),\n        neg_weights=(0.5, 0.5),\n    ):\n        if cached_pos_hiddens is None and cached_neg_hiddens is None:\n            return self.unet(z_all, t, encoder_hidden_states=prompt_embd)\n\n        local_pos_weights = torch.linspace(*pos_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist()\n        local_neg_weights = torch.linspace(*neg_weights, steps=len(self.unet.down_blocks) + 1)[:-1].tolist()\n        for block, pos_weight, neg_weight in zip(\n            self.unet.down_blocks + [self.unet.mid_block] + self.unet.up_blocks,\n            local_pos_weights + [pos_weights[1]] + local_pos_weights[::-1],\n            local_neg_weights + [neg_weights[1]] + local_neg_weights[::-1],\n        ):\n            for module in block.modules():\n                if isinstance(module, BasicTransformerBlock):\n\n                    def new_forward(\n                        self,\n                        hidden_states,\n                        pos_weight=pos_weight,\n                        neg_weight=neg_weight,\n                        **kwargs,\n                    ):\n                        cond_hiddens, uncond_hiddens = hidden_states.chunk(2, dim=0)\n                        batch_size, d_model = cond_hiddens.shape[:2]\n                        device, dtype = hidden_states.device, hidden_states.dtype\n\n                        weights = torch.ones(batch_size, d_model, device=device, dtype=dtype)\n                        out_pos = self.old_forward(hidden_states)\n                        out_neg = self.old_forward(hidden_states)\n\n                        if cached_pos_hiddens is not None:\n                            
cached_pos_hs = cached_pos_hiddens.pop(0).to(hidden_states.device)\n                            cond_pos_hs = torch.cat([cond_hiddens, cached_pos_hs], dim=1)\n                            pos_weights = weights.clone().repeat(1, 1 + cached_pos_hs.shape[1] // d_model)\n                            pos_weights[:, d_model:] = pos_weight\n                            attn_with_weights = FabricCrossAttnProcessor()\n                            out_pos = attn_with_weights(\n                                self,\n                                cond_hiddens,\n                                encoder_hidden_states=cond_pos_hs,\n                                weights=pos_weights,\n                            )\n                        else:\n                            out_pos = self.old_forward(cond_hiddens)\n\n                        if cached_neg_hiddens is not None:\n                            cached_neg_hs = cached_neg_hiddens.pop(0).to(hidden_states.device)\n                            uncond_neg_hs = torch.cat([uncond_hiddens, cached_neg_hs], dim=1)\n                            neg_weights = weights.clone().repeat(1, 1 + cached_neg_hs.shape[1] // d_model)\n                            neg_weights[:, d_model:] = neg_weight\n                            attn_with_weights = FabricCrossAttnProcessor()\n                            out_neg = attn_with_weights(\n                                self,\n                                uncond_hiddens,\n                                encoder_hidden_states=uncond_neg_hs,\n                                weights=neg_weights,\n                            )\n                        else:\n                            out_neg = self.old_forward(uncond_hiddens)\n\n                        out = torch.cat([out_pos, out_neg], dim=0)\n                        return out\n\n                    module.attn1.old_forward = module.attn1.forward\n                    module.attn1.forward = new_forward.__get__(module.attn1)\n\n        out = 
self.unet(z_all, t, encoder_hidden_states=prompt_embd)\n\n        # restore original forward pass\n        for module in self.unet.modules():\n            if isinstance(module, BasicTransformerBlock):\n                module.attn1.forward = module.attn1.old_forward\n                del module.attn1.old_forward\n\n        return out\n\n    def preprocess_feedback_images(self, images, vae, dim, device, dtype, generator) -> torch.Tensor:\n        images_t = [self.image_to_tensor(img, dim, dtype) for img in images]\n        images_t = torch.stack(images_t).to(device)\n        latents = vae.config.scaling_factor * vae.encode(images_t).latent_dist.sample(generator)\n\n        return torch.cat([latents], dim=0)\n\n    def check_inputs(\n        self,\n        prompt,\n        negative_prompt=None,\n        liked=None,\n        disliked=None,\n        height=None,\n        width=None,\n    ):\n        if prompt is None:\n            raise ValueError(\"Provide `prompt`. Cannot leave `prompt` undefined.\")\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and (\n            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)\n        ):\n            raise ValueError(f\"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}\")\n\n        if liked is not None and not isinstance(liked, list):\n            raise ValueError(f\"`liked` has to be of type `list` but is {type(liked)}\")\n\n        if disliked is not None and not isinstance(disliked, list):\n            raise ValueError(f\"`disliked` has to be of type `list` but is {type(disliked)}\")\n\n        if height is not None and not isinstance(height, int):\n            raise ValueError(f\"`height` has to be of type `int` but is {type(height)}\")\n\n        if width is not 
None and not isinstance(width, int):\n            raise ValueError(f\"`width` has to be of type `int` but is {type(width)}\")\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Optional[Union[str, List[str]]] = \"\",\n        negative_prompt: Optional[Union[str, List[str]]] = \"lowres, bad anatomy, bad hands, cropped, worst quality\",\n        liked: Optional[Union[List[str], List[Image.Image]]] = [],\n        disliked: Optional[Union[List[str], List[Image.Image]]] = [],\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        height: int = 512,\n        width: int = 512,\n        return_dict: bool = True,\n        num_images: int = 4,\n        guidance_scale: float = 7.0,\n        num_inference_steps: int = 20,\n        output_type: str | None = \"pil\",\n        feedback_start_ratio: float = 0.33,\n        feedback_end_ratio: float = 0.66,\n        min_weight: float = 0.05,\n        max_weight: float = 0.8,\n        neg_scale: float = 0.5,\n        pos_bottleneck_scale: float = 1.0,\n        neg_bottleneck_scale: float = 1.0,\n        latents: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation. Generate a trajectory of images with binary feedback. The\n        feedback can be given as a list of liked and disliked images.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`\n                instead.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale < 1`).\n            liked (`List[Image.Image]` or `List[str]`, *optional*):\n                Encourages images with liked features.\n            disliked (`List[Image.Image]` or `List[str]`, *optional*):\n                Discourages images with disliked features.\n            generator (`torch.Generator` or `List[torch.Generator]` or `int`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) or an `int` to\n                make generation deterministic.\n            height (`int`, *optional*, defaults to 512):\n                Height of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                Width of the generated image.\n            num_images (`int`, *optional*, defaults to 4):\n                The number of images to generate per prompt.\n            guidance_scale (`float`, *optional*, defaults to 7.0):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            num_inference_steps (`int`, *optional*, defaults to 20):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            feedback_start_ratio (`float`, *optional*, defaults to `.33`):\n                Start point for providing feedback (between 0 and 1).\n            feedback_end_ratio (`float`, *optional*, defaults to `.66`):\n                End point for providing feedback (between 0 and 1).\n            min_weight (`float`, *optional*, defaults to `.05`):\n                Minimum weight for feedback.\n            max_weight (`float`, *optional*, defaults to `0.8`):\n                Maximum weight for feedback.\n            neg_scale (`float`, *optional*, defaults to `.5`):\n                Scale factor for negative feedback.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.fabric.FabricPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n\n        \"\"\"\n\n        self.check_inputs(prompt, negative_prompt, liked, disliked)\n\n        device = self._execution_device\n        dtype = self.unet.dtype\n\n        if isinstance(prompt, str) and prompt is not None:\n            batch_size = 1\n        elif isinstance(prompt, list) and prompt is not None:\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if isinstance(negative_prompt, str):\n            negative_prompt = negative_prompt\n        elif 
isinstance(negative_prompt, list):\n            negative_prompt = negative_prompt\n        else:\n            assert len(negative_prompt) == batch_size\n\n        shape = (\n            batch_size * num_images,\n            self.unet.config.in_channels,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        latent_noise = randn_tensor(\n            shape,\n            device=device,\n            dtype=dtype,\n            generator=generator,\n        )\n\n        positive_latents = (\n            self.preprocess_feedback_images(liked, self.vae, (height, width), device, dtype, generator)\n            if liked and len(liked) > 0\n            else torch.tensor(\n                [],\n                device=device,\n                dtype=dtype,\n            )\n        )\n        negative_latents = (\n            self.preprocess_feedback_images(disliked, self.vae, (height, width), device, dtype, generator)\n            if disliked and len(disliked) > 0\n            else torch.tensor(\n                [],\n                device=device,\n                dtype=dtype,\n            )\n        )\n\n        do_classifier_free_guidance = guidance_scale > 0.1\n\n        (prompt_neg_embs, prompt_pos_embs) = self._encode_prompt(\n            prompt,\n            device,\n            num_images,\n            do_classifier_free_guidance,\n            negative_prompt,\n        ).split([num_images * batch_size, num_images * batch_size])\n\n        batched_prompt_embd = torch.cat([prompt_pos_embs, prompt_neg_embs], dim=0)\n\n        null_tokens = self.tokenizer(\n            [\"\"],\n            return_tensors=\"pt\",\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n        )\n\n        if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n            attention_mask = 
null_tokens.attention_mask.to(device)\n        else:\n            attention_mask = None\n\n        null_prompt_emb = self.text_encoder(\n            input_ids=null_tokens.input_ids.to(device),\n            attention_mask=attention_mask,\n        ).last_hidden_state\n\n        null_prompt_emb = null_prompt_emb.to(device=device, dtype=dtype)\n\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n        latent_noise = latent_noise * self.scheduler.init_noise_sigma\n\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n\n        ref_start_idx = round(len(timesteps) * feedback_start_ratio)\n        ref_end_idx = round(len(timesteps) * feedback_end_ratio)\n\n        with self.progress_bar(total=num_inference_steps) as pbar:\n            for i, t in enumerate(timesteps):\n                sigma = self.scheduler.sigma_t[t] if hasattr(self.scheduler, \"sigma_t\") else 0\n                if hasattr(self.scheduler, \"sigmas\"):\n                    sigma = self.scheduler.sigmas[i]\n\n                alpha_hat = 1 / (sigma**2 + 1)\n\n                z_single = self.scheduler.scale_model_input(latent_noise, t)\n                z_all = torch.cat([z_single] * 2, dim=0)\n                z_ref = torch.cat([positive_latents, negative_latents], dim=0)\n\n                if i >= ref_start_idx and i <= ref_end_idx:\n                    weight_factor = max_weight\n                else:\n                    weight_factor = min_weight\n\n                pos_ws = (weight_factor, weight_factor * pos_bottleneck_scale)\n                neg_ws = (weight_factor * neg_scale, weight_factor * neg_scale * neg_bottleneck_scale)\n\n                if z_ref.size(0) > 0 and weight_factor > 0:\n                    noise = torch.randn_like(z_ref)\n                    if isinstance(self.scheduler, EulerAncestralDiscreteScheduler):\n                        z_ref_noised = (alpha_hat**0.5 * z_ref + (1 - 
alpha_hat) ** 0.5 * noise).type(dtype)\n                    else:\n                        z_ref_noised = self.scheduler.add_noise(z_ref, noise, t)\n\n                    ref_prompt_embd = torch.cat(\n                        [null_prompt_emb] * (len(positive_latents) + len(negative_latents)), dim=0\n                    )\n                    cached_hidden_states = self.get_unet_hidden_states(z_ref_noised, t, ref_prompt_embd)\n\n                    n_pos, n_neg = positive_latents.shape[0], negative_latents.shape[0]\n                    cached_pos_hs, cached_neg_hs = [], []\n                    for hs in cached_hidden_states:\n                        cached_pos, cached_neg = hs.split([n_pos, n_neg], dim=0)\n                        cached_pos = cached_pos.view(1, -1, *cached_pos.shape[2:]).expand(num_images, -1, -1)\n                        cached_neg = cached_neg.view(1, -1, *cached_neg.shape[2:]).expand(num_images, -1, -1)\n                        cached_pos_hs.append(cached_pos)\n                        cached_neg_hs.append(cached_neg)\n\n                    if n_pos == 0:\n                        cached_pos_hs = None\n                    if n_neg == 0:\n                        cached_neg_hs = None\n                else:\n                    cached_pos_hs, cached_neg_hs = None, None\n                unet_out = self.unet_forward_with_cached_hidden_states(\n                    z_all,\n                    t,\n                    prompt_embd=batched_prompt_embd,\n                    cached_pos_hiddens=cached_pos_hs,\n                    cached_neg_hiddens=cached_neg_hs,\n                    pos_weights=pos_ws,\n                    neg_weights=neg_ws,\n                )[0]\n\n                noise_cond, noise_uncond = unet_out.chunk(2)\n                guidance = noise_cond - noise_uncond\n                noise_pred = noise_uncond + guidance_scale * guidance\n                latent_noise = self.scheduler.step(noise_pred, t, latent_noise)[0]\n\n                if i == 
len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    pbar.update()\n\n        y = self.vae.decode(latent_noise / self.vae.config.scaling_factor, return_dict=False)[0]\n        imgs = self.image_processor.postprocess(\n            y,\n            output_type=output_type,\n        )\n\n        if not return_dict:\n            return imgs\n\n        return StableDiffusionPipelineOutput(imgs, False)\n\n    def image_to_tensor(self, image: Union[str, Image.Image], dim: tuple, dtype):\n        \"\"\"\n        Convert a PIL image (or a path to an image file) to a torch tensor for further processing.\n        \"\"\"\n        if isinstance(image, str):\n            image = Image.open(image)\n        if image.mode != \"RGB\":\n            image = image.convert(\"RGB\")\n        image = self.image_processor.preprocess(image, height=dim[0], width=dim[1])[0]\n        return image.type(dtype)\n"
  },
  {
    "path": "examples/community/pipeline_faithdiff_stable_diffusion_xl.py",
    "content": "# Copyright 2025 Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab Team\n# and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport inspect\nfrom collections import OrderedDict\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport cv2\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer\n\nfrom diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    PeftAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n    UNet2DConditionLoadersMixin,\n)\nfrom diffusers.models import AutoencoderKL\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import DDPMScheduler, 
KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_invisible_watermark_available,\n    is_torch_version,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.outputs import BaseOutput\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import random\n        >>> import numpy as np\n        >>> import torch\n        >>> from diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler\n        >>> from huggingface_hub import hf_hub_download\n        >>> from diffusers.utils import load_image\n        >>> from PIL import Image\n        >>>\n        >>> device = \"cuda\"\n        >>> dtype = torch.float16\n        >>> MAX_SEED = np.iinfo(np.int32).max\n        >>>\n        >>> # Download weights for additional unet layers\n        >>> model_file = hf_hub_download(\n        ...     \"jychen9811/FaithDiff\",\n        ...     filename=\"FaithDiff.bin\", local_dir=\"./proc_data/faithdiff\", local_dir_use_symlinks=False\n        ... )\n        >>>\n        >>> # Initialize the models and pipeline\n        >>> vae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=dtype)\n        >>>\n        >>> model_id = \"SG161222/RealVisXL_V4.0\"\n        >>> pipe = DiffusionPipeline.from_pretrained(\n        ...     model_id,\n        ...     torch_dtype=dtype,\n        ...     vae=vae,\n        ...     unet=None, #<- Do not load with original model.\n        ...     
custom_pipeline=\"mixture_tiling_sdxl\",\n        ...     use_safetensors=True,\n        ...     variant=\"fp16\",\n        ... ).to(device)\n        >>>\n        >>> # Here we need use pipeline internal unet model\n        >>> pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder=\"unet\", variant=\"fp16\", use_safetensors=True)\n        >>>\n        >>> # Load additional layers to the model\n        >>> pipe.unet.load_additional_layers(weight_path=\"proc_data/faithdiff/FaithDiff.bin\", dtype=dtype)\n        >>>\n        >>> # Enable vae tiling\n        >>> pipe.set_encoder_tile_settings()\n        >>> pipe.enable_vae_tiling()\n        >>>\n        >>> # Optimization\n        >>> pipe.enable_model_cpu_offload()\n        >>>\n        >>> # Set selected scheduler\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>>\n        >>> #input params\n        >>> prompt = \"The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. 
\"\n        >>> upscale = 2 # scale here\n        >>> start_point = \"lr\" # or \"noise\"\n        >>> latent_tiled_overlap = 0.5\n        >>> latent_tiled_size = 1024\n        >>>\n        >>> # Load image\n        >>> lq_image = load_image(\"https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png\")\n        >>> original_height = lq_image.height\n        >>> original_width = lq_image.width\n        >>> print(f\"Current resolution: H:{original_height} x W:{original_width}\")\n        >>>\n        >>> width = original_width * int(upscale)\n        >>> height = original_height * int(upscale)\n        >>> print(f\"Final resolution: H:{height} x W:{width}\")\n        >>>\n        >>> # Restoration\n        >>> image = lq_image.resize((width, height), Image.LANCZOS)\n        >>> input_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image)\n        >>>\n        >>> generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))\n        >>> gen_image = pipe(lr_img=input_image,\n        ...                 prompt = prompt,\n        ...                 num_inference_steps=20,\n        ...                 guidance_scale=5,\n        ...                 generator=generator,\n        ...                 start_point=start_point,\n        ...                 height = height_now,\n        ...                 width=width_now,\n        ...                 overlap=latent_tiled_overlap,\n        ...                 target_size=(latent_tiled_size, latent_tiled_size)\n        ...                 
).images[0]\n        >>>\n        >>> cropped_image = gen_image.crop((0, 0, width_init, height_init))\n        >>> cropped_image.save(\"data/result.png\")\n        ```\n\"\"\"\n\n\ndef zero_module(module):\n    \"\"\"Zero out the parameters of a module and return it.\"\"\"\n    for p in module.parameters():\n        nn.init.zeros_(p)\n    return module\n\n\nclass Encoder(nn.Module):\n    \"\"\"Encoder layer of a variational autoencoder that encodes input into a latent representation.\"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 4,\n        down_block_types: Tuple[str, ...] = (\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n        ),\n        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),\n        layers_per_block: int = 2,\n        norm_num_groups: int = 32,\n        act_fn: str = \"silu\",\n        double_z: bool = True,\n        mid_block_add_attention: bool = True,\n    ):\n        super().__init__()\n        self.layers_per_block = layers_per_block\n\n        self.conv_in = nn.Conv2d(\n            in_channels,\n            block_out_channels[0],\n            kernel_size=3,\n            stride=1,\n            padding=1,\n        )\n\n        self.mid_block = None\n        self.down_blocks = nn.ModuleList([])\n        self.use_rgb = False\n        self.down_block_type = down_block_types\n        self.block_out_channels = block_out_channels\n\n        self.tile_sample_min_size = 1024\n        self.tile_latent_min_size = int(self.tile_sample_min_size / 8)\n        self.tile_overlap_factor = 0.25\n        self.use_tiling = False\n\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 
1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=self.layers_per_block,\n                in_channels=input_channel,\n                out_channels=output_channel,\n                add_downsample=not is_final_block,\n                resnet_eps=1e-6,\n                downsample_padding=0,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                attention_head_dim=output_channel,\n                temb_channels=None,\n            )\n            self.down_blocks.append(down_block)\n\n        self.mid_block = UNetMidBlock2D(\n            in_channels=block_out_channels[-1],\n            resnet_eps=1e-6,\n            resnet_act_fn=act_fn,\n            output_scale_factor=1,\n            resnet_time_scale_shift=\"default\",\n            attention_head_dim=block_out_channels[-1],\n            resnet_groups=norm_num_groups,\n            temb_channels=None,\n            add_attention=mid_block_add_attention,\n        )\n\n        self.gradient_checkpointing = False\n\n    def to_rgb_init(self):\n        \"\"\"Initialize layers to convert features to RGB.\"\"\"\n        self.to_rgbs = nn.ModuleList([])\n        self.use_rgb = True\n        for i, down_block_type in enumerate(self.down_block_type):\n            output_channel = self.block_out_channels[i]\n            self.to_rgbs.append(nn.Conv2d(output_channel, 3, kernel_size=3, padding=1))\n\n    def enable_tiling(self):\n        \"\"\"Enable tiling for large inputs.\"\"\"\n        self.use_tiling = True\n\n    def encode(self, sample: torch.FloatTensor) -> torch.FloatTensor:\n        \"\"\"Encode the input tensor into a latent representation.\"\"\"\n        sample = self.conv_in(sample)\n        if self.training and self.gradient_checkpointing:\n\n            def create_custom_forward(module):\n                def custom_forward(*inputs):\n                    return module(*inputs)\n\n                return 
custom_forward\n\n            if is_torch_version(\">=\", \"1.11.0\"):\n                for down_block in self.down_blocks:\n                    sample = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(down_block), sample, use_reentrant=False\n                    )\n                sample = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(self.mid_block), sample, use_reentrant=False\n                )\n            else:\n                for down_block in self.down_blocks:\n                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)\n                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)\n            return sample\n        else:\n            for down_block in self.down_blocks:\n                sample = down_block(sample)\n            sample = self.mid_block(sample)\n            return sample\n\n    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:\n        \"\"\"Blend two tensors vertically with a smooth transition.\"\"\"\n        blend_extent = min(a.shape[2], b.shape[2], blend_extent)\n        for y in range(blend_extent):\n            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)\n        return b\n\n    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:\n        \"\"\"Blend two tensors horizontally with a smooth transition.\"\"\"\n        blend_extent = min(a.shape[3], b.shape[3], blend_extent)\n        for x in range(blend_extent):\n            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)\n        return b\n\n    def tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:\n        \"\"\"Encode the input tensor using tiling for large inputs.\"\"\"\n        overlap_size = int(self.tile_sample_min_size * 
(1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n\n        rows = []\n        for i in range(0, x.shape[2], overlap_size):\n            row = []\n            for j in range(0, x.shape[3], overlap_size):\n                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]\n                tile = self.encode(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        moments = torch.cat(result_rows, dim=2)\n        return moments\n\n    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:\n        \"\"\"Forward pass of the encoder, using tiling if enabled for large inputs.\"\"\"\n        if self.use_tiling and (\n            sample.shape[-1] > self.tile_latent_min_size or sample.shape[-2] > self.tile_latent_min_size\n        ):\n            return self.tiled_encode(sample)\n        return self.encode(sample)\n\n\nclass ControlNetConditioningEmbedding(nn.Module):\n    \"\"\"A small network to preprocess conditioning inputs, inspired by ControlNet.\"\"\"\n\n    def __init__(self, conditioning_embedding_channels: int, conditioning_channels: int = 4):\n        super().__init__()\n        self.conv_in = nn.Conv2d(conditioning_channels, conditioning_channels, kernel_size=3, padding=1)\n        self.norm_in = nn.GroupNorm(num_channels=conditioning_channels, num_groups=32, eps=1e-6)\n        self.conv_out = 
zero_module(\n            nn.Conv2d(conditioning_channels, conditioning_embedding_channels, kernel_size=3, padding=1)\n        )\n\n    def forward(self, conditioning):\n        \"\"\"Process the conditioning input through the network.\"\"\"\n        conditioning = self.norm_in(conditioning)\n        embedding = self.conv_in(conditioning)\n        embedding = F.silu(embedding)\n        embedding = self.conv_out(embedding)\n        return embedding\n\n\nclass QuickGELU(nn.Module):\n    \"\"\"A fast approximation of the GELU activation function.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"Apply the QuickGELU activation to the input tensor.\"\"\"\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"Apply LayerNorm and preserve the input dtype.\"\"\"\n        orig_type = x.dtype\n        ret = super().forward(x)\n        return ret.type(orig_type)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    \"\"\"A transformer-style block with self-attention and an MLP.\"\"\"\n\n    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):\n        super().__init__()\n        self.attn = nn.MultiheadAttention(d_model, n_head)\n        self.ln_1 = LayerNorm(d_model)\n        self.mlp = nn.Sequential(\n            OrderedDict(\n                [\n                    (\"c_fc\", nn.Linear(d_model, d_model * 2)),\n                    (\"gelu\", QuickGELU()),\n                    (\"c_proj\", nn.Linear(d_model * 2, d_model)),\n                ]\n            )\n        )\n        self.ln_2 = LayerNorm(d_model)\n        self.attn_mask = attn_mask\n\n    def attention(self, x: torch.Tensor):\n        \"\"\"Apply self-attention to the input tensor.\"\"\"\n        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None\n        return self.attn(x, x, x, 
need_weights=False, attn_mask=self.attn_mask)[0]\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"Forward pass through the residual attention block.\"\"\"\n        x = x + self.attention(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\n@dataclass\nclass UNet2DConditionOutput(BaseOutput):\n    \"\"\"The output of UnifiedUNet2DConditionModel.\"\"\"\n\n    sample: torch.FloatTensor = None\n\n\nclass UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):\n    \"\"\"A unified 2D UNet model extending OriginalUNet2DConditionModel with custom functionality.\"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str, ...] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: str | None = \"UNetMidBlock2DCrossAttn\",\n        up_block_types: Tuple[str, ...] = (\n            \"UpBlock2D\",\n            \"CrossAttnUpBlock2D\",\n            \"CrossAttnUpBlock2D\",\n            \"CrossAttnUpBlock2D\",\n        ),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        dropout: float = 0.0,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,\n        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: str | None = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: str | None = None,\n        addition_embed_type: str | None = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: float = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: str | None = None,\n        timestep_post_act: str | None = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        attention_type: str = \"default\",\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: str | None = None,\n        addition_embed_type_num_heads: int = 64,\n    ):\n        \"\"\"Initialize the UnifiedUNet2DConditionModel.\"\"\"\n        super().__init__(\n            
sample_size=sample_size,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            center_input_sample=center_input_sample,\n            flip_sin_to_cos=flip_sin_to_cos,\n            freq_shift=freq_shift,\n            down_block_types=down_block_types,\n            mid_block_type=mid_block_type,\n            up_block_types=up_block_types,\n            only_cross_attention=only_cross_attention,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            downsample_padding=downsample_padding,\n            mid_block_scale_factor=mid_block_scale_factor,\n            dropout=dropout,\n            act_fn=act_fn,\n            norm_num_groups=norm_num_groups,\n            norm_eps=norm_eps,\n            cross_attention_dim=cross_attention_dim,\n            transformer_layers_per_block=transformer_layers_per_block,\n            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,\n            encoder_hid_dim=encoder_hid_dim,\n            encoder_hid_dim_type=encoder_hid_dim_type,\n            attention_head_dim=attention_head_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            class_embed_type=class_embed_type,\n            addition_embed_type=addition_embed_type,\n            addition_time_embed_dim=addition_time_embed_dim,\n            num_class_embeds=num_class_embeds,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            resnet_skip_time_act=resnet_skip_time_act,\n            resnet_out_scale_factor=resnet_out_scale_factor,\n            time_embedding_type=time_embedding_type,\n            time_embedding_dim=time_embedding_dim,\n            time_embedding_act_fn=time_embedding_act_fn,\n            timestep_post_act=timestep_post_act,\n            
time_cond_proj_dim=time_cond_proj_dim,\n            conv_in_kernel=conv_in_kernel,\n            conv_out_kernel=conv_out_kernel,\n            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,\n            attention_type=attention_type,\n            class_embeddings_concat=class_embeddings_concat,\n            mid_block_only_cross_attention=mid_block_only_cross_attention,\n            cross_attention_norm=cross_attention_norm,\n            addition_embed_type_num_heads=addition_embed_type_num_heads,\n        )\n\n        # Additional attributes\n        self.denoise_encoder = None\n        self.information_transformer_layes = None\n        self.condition_embedding = None\n        self.agg_net = None\n        self.spatial_ch_projs = None\n\n    def init_vae_encoder(self, dtype):\n        self.denoise_encoder = Encoder()\n        if dtype is not None:\n            self.denoise_encoder.dtype = dtype\n\n    def init_information_transformer_layes(self):\n        num_trans_channel = 640\n        num_trans_head = 8\n        num_trans_layer = 2\n        num_proj_channel = 320\n        self.information_transformer_layes = nn.Sequential(\n            *[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)]\n        )\n        self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))\n\n    def init_ControlNetConditioningEmbedding(self, channel=512):\n        self.condition_embedding = ControlNetConditioningEmbedding(320, channel)\n\n    def init_extra_weights(self):\n        self.agg_net = nn.ModuleList()\n\n    def load_additional_layers(\n        self, dtype: Optional[torch.dtype] = torch.float16, channel: int = 512, weight_path: str | None = None\n    ):\n        \"\"\"Load additional layers and weights from a file.\n\n        Args:\n            weight_path (str): Path to the weight file.\n            dtype (torch.dtype, optional): Data type for the loaded weights. 
Defaults to torch.float16.\n            channel (int): Number of channels of the conditioning input (`conditioning_channels`). Defaults to 512.\n        \"\"\"\n        if self.denoise_encoder is None:\n            self.init_vae_encoder(dtype)\n\n        if self.information_transformer_layes is None:\n            self.init_information_transformer_layes()\n\n        if self.condition_embedding is None:\n            self.init_ControlNetConditioningEmbedding(channel)\n\n        if self.agg_net is None:\n            self.init_extra_weights()\n\n        # Load weights if provided\n        if weight_path is not None:\n            state_dict = torch.load(weight_path, weights_only=False)\n            self.load_state_dict(state_dict, strict=True)\n\n        # Move all modules to the same device and dtype as the model\n        device = next(self.parameters()).device\n        if dtype is not None or device is not None:\n            self.to(device=device, dtype=dtype or next(self.parameters()).dtype)\n\n    def to(self, *args, **kwargs):\n        \"\"\"Override to() to move all additional modules to the same device and dtype.\"\"\"\n        super().to(*args, **kwargs)\n        for module in [\n            self.denoise_encoder,\n            self.information_transformer_layes,\n            self.condition_embedding,\n            self.agg_net,\n            self.spatial_ch_projs,\n        ]:\n            if module is not None:\n                module.to(*args, **kwargs)\n        return self\n\n    def load_state_dict(self, state_dict, strict=True):\n        \"\"\"Load state dictionary into the model.\n\n        Args:\n            state_dict (dict): State dictionary to load.\n            strict (bool, optional): Whether to strictly enforce that all keys match. 
Defaults to True.\n        \"\"\"\n        core_dict = {}\n        additional_dicts = {\n            \"denoise_encoder\": {},\n            \"information_transformer_layes\": {},\n            \"condition_embedding\": {},\n            \"agg_net\": {},\n            \"spatial_ch_projs\": {},\n        }\n\n        for key, value in state_dict.items():\n            if key.startswith(\"denoise_encoder.\"):\n                additional_dicts[\"denoise_encoder\"][key[len(\"denoise_encoder.\") :]] = value\n            elif key.startswith(\"information_transformer_layes.\"):\n                additional_dicts[\"information_transformer_layes\"][key[len(\"information_transformer_layes.\") :]] = value\n            elif key.startswith(\"condition_embedding.\"):\n                additional_dicts[\"condition_embedding\"][key[len(\"condition_embedding.\") :]] = value\n            elif key.startswith(\"agg_net.\"):\n                additional_dicts[\"agg_net\"][key[len(\"agg_net.\") :]] = value\n            elif key.startswith(\"spatial_ch_projs.\"):\n                additional_dicts[\"spatial_ch_projs\"][key[len(\"spatial_ch_projs.\") :]] = value\n            else:\n                core_dict[key] = value\n\n        super().load_state_dict(core_dict, strict=False)\n        for module_name, module_dict in additional_dicts.items():\n            module = getattr(self, module_name, None)\n            if module is not None and module_dict:\n                module.load_state_dict(module_dict, strict=strict)\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        input_embedding: Optional[torch.Tensor] = None,\n        add_sample: bool = True,\n        return_dict: bool = True,\n        use_condition_embedding: bool = True,\n    ) -> Union[UNet2DConditionOutput, Tuple]:\n        \"\"\"Forward pass prioritizing the original modified implementation.\n\n        Args:\n            sample (torch.FloatTensor): The noisy input tensor with shape `(batch, channel, height, width)`.\n            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.\n            encoder_hidden_states (torch.Tensor): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            class_labels (torch.Tensor, optional): Optional class labels for conditioning.\n            timestep_cond (torch.Tensor, optional): Conditional embeddings for timestep.\n            attention_mask (torch.Tensor, optional): An attention mask of shape `(batch, key_tokens)`.\n            cross_attention_kwargs (Dict[str, Any], optional): A kwargs dictionary for the AttentionProcessor.\n            added_cond_kwargs (Dict[str, torch.Tensor], optional): Additional embeddings to add to the UNet blocks.\n            down_block_additional_residuals (Tuple[torch.Tensor], optional): Residuals for down UNet blocks.\n            mid_block_additional_residual (torch.Tensor, optional): Residual for the middle UNet block.\n            down_intrablock_additional_residuals (Tuple[torch.Tensor], optional): Additional residuals within down blocks.\n            encoder_attention_mask (torch.Tensor, optional): A cross-attention mask of shape `(batch, sequence_length)`.\n            input_embedding (torch.Tensor, optional): Additional input 
embedding for preprocessing.\n            add_sample (bool): Whether to add the sample to the processed embedding. Defaults to True.\n            return_dict (bool): Whether to return a UNet2DConditionOutput. Defaults to True.\n            use_condition_embedding (bool): Whether to use the condition embedding. Defaults to True.\n\n        Returns:\n            Union[UNet2DConditionOutput, Tuple]: The processed sample tensor, either as a UNet2DConditionOutput or tuple.\n        \"\"\"\n        default_overall_up_factor = 2**self.num_upsamplers\n        forward_upsample_size = False\n        upsample_size = None\n\n        for dim in sample.shape[-2:]:\n            if dim % default_overall_up_factor != 0:\n                forward_upsample_size = True\n                break\n\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. 
time\n        t_emb = self.get_time_embed(sample=sample, timestep=timestep)\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)\n        if class_emb is not None:\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        aug_emb = self.get_aug_embed(\n            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n        )\n        if self.config.addition_embed_type == \"image_hint\":\n            aug_emb, hint = aug_emb\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        encoder_hidden_states = self.process_encoder_hidden_states(\n            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n        )\n\n        # 2. 
pre-process (following the original modified logic)\n        sample = self.conv_in(sample)  # [B, 4, H, W] -> [B, 320, H, W]\n        if (\n            input_embedding is not None\n            and self.condition_embedding is not None\n            and self.information_transformer_layes is not None\n        ):\n            if use_condition_embedding:\n                input_embedding = self.condition_embedding(input_embedding)  # [B, 320, H, W]\n            batch_size, channel, height, width = input_embedding.shape\n            concat_feat = (\n                torch.cat([sample, input_embedding], dim=1)\n                .view(batch_size, 2 * channel, height * width)\n                .transpose(1, 2)\n            )\n            concat_feat = self.information_transformer_layes(concat_feat)\n            feat_alpha = self.spatial_ch_projs(concat_feat).transpose(1, 2).view(batch_size, channel, height, width)\n            sample = sample + feat_alpha if add_sample else feat_alpha  # Update sample as in the original version\n\n        # 2.5 GLIGEN position net (kept from the original version)\n        if cross_attention_kwargs is not None and cross_attention_kwargs.get(\"gligen\", None) is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            gligen_args = cross_attention_kwargs.pop(\"gligen\")\n            cross_attention_kwargs[\"gligen\"] = {\"objs\": self.position_net(**gligen_args)}\n\n        # 3. 
down (continues the standard flow)\n        if cross_attention_kwargs is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            lora_scale = cross_attention_kwargs.pop(\"scale\", 1.0)\n        else:\n            lora_scale = 1.0\n\n        if USE_PEFT_BACKEND:\n            scale_lora_layers(self, lora_scale)\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        is_adapter = down_intrablock_additional_residuals is not None\n        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:\n            deprecate(\n                \"T2I should not use down_block_additional_residuals\",\n                \"1.3.0\",\n                \"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \\\n                       and will be removed in diffusers 1.3.0.  `down_block_additional_residuals` should only be used \\\n                       for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. 
\",\n                standard_warn=False,\n            )\n            down_intrablock_additional_residuals = down_block_additional_residuals\n            is_adapter = True\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                additional_residuals = {}\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_intrablock_additional_residuals.pop(0)\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n                if is_adapter and len(down_intrablock_additional_residuals) > 0:\n                    sample += down_intrablock_additional_residuals.pop(0)\n            down_block_res_samples += res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. 
mid\n        if self.mid_block is not None:\n            if hasattr(self.mid_block, \"has_cross_attention\") and self.mid_block.has_cross_attention:\n                sample = self.mid_block(\n                    sample,\n                    emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                sample = self.mid_block(sample, emb)\n            if (\n                is_adapter\n                and len(down_intrablock_additional_residuals) > 0\n                and sample.shape == down_intrablock_additional_residuals[0].shape\n            ):\n                sample += down_intrablock_additional_residuals.pop(0)\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n     
           )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    upsample_size=upsample_size,\n                )\n\n        # 6. post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if USE_PEFT_BACKEND:\n            unscale_lora_layers(self, lora_scale)\n\n        if not return_dict:\n            return (sample,)\n        return UNet2DConditionOutput(sample=sample)\n\n\nclass LocalAttention:\n    \"\"\"A class to handle local attention by splitting tensors into overlapping grids for processing.\"\"\"\n\n    def __init__(self, kernel_size=None, overlap=0.5):\n        \"\"\"Initialize the LocalAttention module.\n\n        Args:\n            kernel_size (tuple[int, int], optional): Size of the grid (height, width). Defaults to None.\n            overlap (float): Overlap factor between adjacent grids (0.0 to 1.0). 
Defaults to 0.5.\n        \"\"\"\n        super().__init__()\n        self.kernel_size = kernel_size\n        self.overlap = overlap\n\n    def grids_list(self, x):\n        \"\"\"Split the input tensor into a list of overlapping grid patches.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).\n\n        Returns:\n            list[torch.Tensor]: List of tensor patches.\n        \"\"\"\n        b, c, h, w = x.shape\n        self.original_size = (b, c, h, w)\n        assert b == 1\n        k1, k2 = self.kernel_size\n        if h < k1:\n            k1 = h\n        if w < k2:\n            k2 = w\n        num_row = (h - 1) // k1 + 1\n        num_col = (w - 1) // k2 + 1\n        self.nr = num_row\n        self.nc = num_col\n\n        import math\n\n        step_j = k2 if num_col == 1 else math.ceil(k2 * self.overlap)\n        step_i = k1 if num_row == 1 else math.ceil(k1 * self.overlap)\n        parts = []\n        idxes = []\n        i = 0\n        last_i = False\n        while i < h and not last_i:\n            j = 0\n            if i + k1 >= h:\n                i = h - k1\n                last_i = True\n            last_j = False\n            while j < w and not last_j:\n                if j + k2 >= w:\n                    j = w - k2\n                    last_j = True\n                parts.append(x[:, :, i : i + k1, j : j + k2])\n                idxes.append({\"i\": i, \"j\": j})\n                j = j + step_j\n            i = i + step_i\n        return parts\n\n    def grids(self, x):\n        \"\"\"Split the input tensor into overlapping grid patches and concatenate them.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).\n\n        Returns:\n            torch.Tensor: Concatenated tensor of all grid patches.\n        \"\"\"\n        b, c, h, w = x.shape\n        self.original_size = (b, c, h, w)\n        assert b == 1\n        k1, k2 = self.kernel_size\n   
     if h < k1:\n            k1 = h\n        if w < k2:\n            k2 = w\n        self.tile_weights = self._gaussian_weights(k2, k1)\n        num_row = (h - 1) // k1 + 1\n        num_col = (w - 1) // k2 + 1\n        self.nr = num_row\n        self.nc = num_col\n\n        import math\n\n        step_j = k2 if num_col == 1 else math.ceil(k2 * self.overlap)\n        step_i = k1 if num_row == 1 else math.ceil(k1 * self.overlap)\n        parts = []\n        idxes = []\n        i = 0\n        last_i = False\n        while i < h and not last_i:\n            j = 0\n            if i + k1 >= h:\n                i = h - k1\n                last_i = True\n            last_j = False\n            while j < w and not last_j:\n                if j + k2 >= w:\n                    j = w - k2\n                    last_j = True\n                parts.append(x[:, :, i : i + k1, j : j + k2])\n                idxes.append({\"i\": i, \"j\": j})\n                j = j + step_j\n            i = i + step_i\n        self.idxes = idxes\n        return torch.cat(parts, dim=0)\n\n    def _gaussian_weights(self, tile_width, tile_height):\n        \"\"\"Generate a Gaussian weight mask for tile contributions.\n\n        Args:\n            tile_width (int): Width of the tile.\n            tile_height (int): Height of the tile.\n\n        Returns:\n            torch.Tensor: Gaussian weight tensor of shape (channels, height, width).\n        \"\"\"\n        import numpy as np\n        from numpy import exp, pi, sqrt\n\n        latent_width = tile_width\n        latent_height = tile_height\n        var = 0.01\n        midpoint = (latent_width - 1) / 2\n        x_probs = [\n            exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)\n            for x in range(latent_width)\n        ]\n        midpoint = latent_height / 2\n        y_probs = [\n            exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 
* pi * var)\n            for y in range(latent_height)\n        ]\n        weights = np.outer(y_probs, x_probs)\n        return torch.tile(torch.tensor(weights, device=torch.device(\"cuda\")), (4, 1, 1))\n\n    def grids_inverse(self, outs):\n        \"\"\"Reconstruct the original tensor from processed grid patches with overlap blending.\n\n        Args:\n            outs (torch.Tensor): Processed grid patches.\n\n        Returns:\n            torch.Tensor: Reconstructed tensor of original size.\n        \"\"\"\n        preds = torch.zeros(self.original_size).to(outs.device)\n        b, c, h, w = self.original_size\n        count_mt = torch.zeros((b, 4, h, w)).to(outs.device)\n        k1, k2 = self.kernel_size\n\n        for cnt, each_idx in enumerate(self.idxes):\n            i = each_idx[\"i\"]\n            j = each_idx[\"j\"]\n            preds[0, :, i : i + k1, j : j + k2] += outs[cnt, :, :, :] * self.tile_weights\n            count_mt[0, :, i : i + k1, j : j + k2] += self.tile_weights\n\n        del outs\n        torch.cuda.empty_cache()\n        return preds / count_mt\n\n    def _pad(self, x):\n        \"\"\"Pad the input tensor to align with kernel size.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).\n\n        Returns:\n            tuple: Padded tensor and padding values.\n        \"\"\"\n        b, c, h, w = x.shape\n        k1, k2 = self.kernel_size\n        mod_pad_h = (k1 - h % k1) % k1\n        mod_pad_w = (k2 - w % k2) % k2\n        pad = (mod_pad_w // 2, mod_pad_w - mod_pad_w // 2, mod_pad_h // 2, mod_pad_h - mod_pad_h // 2)\n        x = F.pad(x, pad, \"reflect\")\n        return x, pad\n\n    def forward(self, x):\n        \"\"\"Apply local attention by splitting into grids and reconstructing.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).\n\n        Returns:\n            torch.Tensor: Processed tensor of original size.\n        
\"\"\"\n        b, c, h, w = x.shape\n        qkv = self.grids(x)\n        out = self.grids_inverse(qkv)\n        return out\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n\n    Args:\n        noise_cfg (torch.Tensor): Noise configuration tensor.\n        noise_pred_text (torch.Tensor): Predicted noise from text-conditioned model.\n        guidance_rescale (float): Rescaling factor for guidance. Defaults to 0.0.\n\n    Returns:\n        torch.Tensor: Rescaled noise configuration.\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    \"\"\"Retrieve latents from an encoder output.\n\n    Args:\n        encoder_output (torch.Tensor): Output from an encoder (e.g., VAE).\n        generator (torch.Generator, optional): Random generator for sampling. Defaults to None.\n        sample_mode (str): Sampling mode (\"sample\" or \"argmax\"). 
Defaults to \"sample\".\n\n    Returns:\n        torch.Tensor: Retrieved latent tensor.\n    \"\"\"\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. 
If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass FaithDiffStableDiffusionXLPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    FromSingleFileMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n    IPAdapterMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion XL.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion XL uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`CLIPTextModelWithProjection`]):\n            Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of\n            `stabilityai/stable-diffusion-xl-base-1-0`.\n        add_watermarker (`bool`, *optional*):\n            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to\n            watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no\n            watermarker will be used.\n    \"\"\"\n\n    unet_model = UNet2DConditionModel\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->unet->vae\"\n    _optional_components = [\"tokenizer\", \"tokenizer_2\", \"text_encoder\", \"text_encoder_2\", \"feature_extractor\", \"unet\"]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"negative_add_time_ids\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: OriginalUNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            scheduler=scheduler,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.DDPMScheduler = DDPMScheduler.from_config(self.scheduler.config, subfolder=\"scheduler\")\n        self.default_sample_size = self.unet.config.sample_size if unet is not None else 128\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n        
    self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = \"cuda\"  # device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder, lora_scale)\n\n            if self.text_encoder_2 is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n        dtype = text_encoders[0].dtype\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, 
text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n                text_encoder = text_encoder.to(dtype)\n                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                pooled_prompt_embeds = prompt_embeds[0]\n                if clip_skip is None:\n                    prompt_embeds = prompt_embeds.hidden_states[-2]\n                else:\n                    # \"2\" because SDXL always indexes from the penultimate layer.\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = 
negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        if self.text_encoder_2 is not None:\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        else:\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * 
num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            if self.text_encoder_2 is not None:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            else:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n     
   # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_image_size(self, x, padder_size=8):\n        # Get the image width and height\n        width, height = x.size\n        # Compute the padding needed to make height and width multiples of padder_size\n        mod_pad_h = (padder_size - height % padder_size) % padder_size\n        mod_pad_w = (padder_size - width % padder_size) % padder_size\n        x_np = np.array(x)\n        # Pad the bottom and right edges by replicating the border pixels\n        x_padded = cv2.copyMakeBorder(\n            x_np, top=0, bottom=mod_pad_h, left=0, right=mod_pad_w, borderType=cv2.BORDER_REPLICATE\n        )\n\n        x = PIL.Image.fromarray(x_padded)\n        # x = x.resize((width + mod_pad_w, height + mod_pad_h))\n\n        return x, width, height, width + mod_pad_w, height + mod_pad_h\n\n    def check_inputs(\n        self,\n        lr_img,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if lr_img is None:\n            raise 
ValueError(\"`lr_img` must be provided!\")\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. 
Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\n    ) -> torch.FloatTensor:\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            w (`torch.Tensor`):\n                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.\n            embedding_dim (`int`, *optional*, defaults to 512):\n                Dimension of the embeddings to generate.\n            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):\n                Data type of the generated embeddings.\n\n        Returns:\n            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    def set_encoder_tile_settings(\n        self,\n        denoise_encoder_tile_sample_min_size=1024,\n        denoise_encoder_sample_overlap_factor=0.25,\n        vae_sample_size=1024,\n        vae_tile_overlap_factor=0.25,\n    ):\n        self.unet.denoise_encoder.tile_sample_min_size = denoise_encoder_tile_sample_min_size\n        
self.unet.denoise_encoder.tile_overlap_factor = denoise_encoder_sample_overlap_factor\n        self.vae.config.sample_size = vae_sample_size\n        self.vae.tile_overlap_factor = vae_tile_overlap_factor\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n        self.unet.denoise_encoder.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n        self.unet.denoise_encoder.disable_tiling()\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    def prepare_image_latents(\n        self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None\n    ):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            image_latents = image\n        else:\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            # needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n            # if needs_upcasting:\n            #     image = image.float()\n            #     self.upcast_vae()\n            self.unet.denoise_encoder.to(device=image.device, dtype=image.dtype)\n            image_latents = self.unet.denoise_encoder(image)\n            self.unet.denoise_encoder.to(\"cpu\")\n            # cast back to fp16 if needed\n            # if needs_upcasting:\n            #     self.vae.to(dtype=torch.float16)\n\n        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:\n            # expand image_latents for batch_size\n            deprecation_message = (\n                f\"You have passed 
{batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // image_latents.shape[0]\n            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            image_latents = torch.cat([image_latents], dim=0)\n\n        if do_classifier_free_guidance:\n            image_latents = image_latents\n\n        if image_latents.dtype != self.vae.dtype:\n            image_latents = image_latents.to(dtype=self.vae.dtype)\n\n        return image_latents\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        lr_img: PipelineImageInput = None,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        start_point: str | None = \"noise\",\n        timesteps: List[int] = None,\n        denoising_end: Optional[float] = None,\n        overlap: float = 0.5,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, 
List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        add_sample: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            lr_img (PipelineImageInput, optional): Low-resolution input image for conditioning the generation process.\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. 
This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            start_point (str, *optional*):\n                The starting point for the generation process. Can be \"noise\" (random noise) or \"lr\" (low-resolution image).\n                Defaults to \"noise\".\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. 
As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)\n            overlap (float):\n                Overlap factor for local attention tiling (between 0.0 and 1.0). Controls the overlap between adjacent\n                grid patches during processing. Defaults to 0.5.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\n                of a plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16. 
of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. 
`callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            add_sample (bool):\n                Whether to include sample conditioning (e.g., low-resolution image) in the UNet during denoising.\n                Defaults to True.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. 
Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            lr_img,\n            prompt,\n            prompt_2,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._interrupt = False\n        self.tlc_vae_latents = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap)\n        self.tlc_vae_img = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap)\n\n        # 2. Define call parameters\n        batch_size = 1\n        num_images_per_prompt = 1\n\n        device = torch.device(\"cuda\")  # self._execution_device\n\n        # 3. 
Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        num_samples = num_images_per_prompt\n        with torch.inference_mode():\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = self.encode_prompt(\n                prompt,\n                num_images_per_prompt=num_samples,\n                do_classifier_free_guidance=True,\n                negative_prompt=negative_prompt,\n                lora_scale=lora_scale,\n            )\n\n        lr_img_list = [lr_img]\n        lr_img = self.image_processor.preprocess(lr_img_list, height=height, width=width).to(\n            device, dtype=prompt_embeds.dtype\n        )\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        image_latents = self.prepare_image_latents(\n            lr_img, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, self.do_classifier_free_guidance\n        )\n\n        image_latents = self.tlc_vae_img.grids(image_latents)\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n        if start_point == \"lr\":\n            latents_condition_image = self.vae.encode(lr_img * 2 - 1).latent_dist.sample()\n            latents_condition_image = latents_condition_image * self.vae.config.scaling_factor\n            start_steps_tensor = torch.randint(999, 999 + 1, (latents.shape[0],), device=latents.device)\n            start_steps_tensor = start_steps_tensor.long()\n            latents = self.DDPMScheduler.add_noise(latents_condition_image[0:1, ...], latents, start_steps_tensor)\n\n        latents = self.tlc_vae_latents.grids(latents)\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * image_latents.shape[0]\n\n        # 7. Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 8.1 Apply denoising_end\n        if (\n            self.denoising_end is not None\n            and isinstance(self.denoising_end, float)\n            and self.denoising_end > 0\n            and self.denoising_end < 1\n        ):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 9. Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n        sub_latents_num = latents.shape[0]\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if i >= 1:\n                    latents = self.tlc_vae_latents.grids(latents).to(dtype=latents.dtype)\n                if self.interrupt:\n                    continue\n                concat_grid = []\n                for sub_num in range(sub_latents_num):\n                    self.scheduler.__dict__.update(views_scheduler_status[sub_num])\n                    sub_latents = latents[sub_num, :, :, :].unsqueeze(0)\n                    img_sub_latents = image_latents[sub_num, :, :, :].unsqueeze(0)\n                    
latent_model_input = (\n                        torch.cat([sub_latents] * 2) if self.do_classifier_free_guidance else sub_latents\n                    )\n                    img_sub_latents = (\n                        torch.cat([img_sub_latents] * 2) if self.do_classifier_free_guidance else img_sub_latents\n                    )\n                    scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n                    pos_height = self.tlc_vae_latents.idxes[sub_num][\"i\"]\n                    pos_width = self.tlc_vae_latents.idxes[sub_num][\"j\"]\n                    add_time_ids = [\n                        torch.tensor([original_size]),\n                        torch.tensor([[pos_height, pos_width]]),\n                        torch.tensor([target_size]),\n                    ]\n                    add_time_ids = torch.cat(add_time_ids, dim=1).to(\n                        img_sub_latents.device, dtype=img_sub_latents.dtype\n                    )\n                    add_time_ids = add_time_ids.repeat(2, 1).to(dtype=img_sub_latents.dtype)\n\n                    # predict the noise residual\n                    added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                    with torch.amp.autocast(\n                        device.type, dtype=latents.dtype, enabled=latents.dtype != self.unet.dtype\n                    ):\n                        noise_pred = self.unet(\n                            scaled_latent_model_input,\n                            t,\n                            encoder_hidden_states=prompt_embeds,\n                            timestep_cond=timestep_cond,\n                            cross_attention_kwargs=self.cross_attention_kwargs,\n                            input_embedding=img_sub_latents,\n                            add_sample=add_sample,\n                            added_cond_kwargs=added_cond_kwargs,\n                            return_dict=False,\n           
             )[0]\n\n                    # perform guidance\n                    if self.do_classifier_free_guidance:\n                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                        # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                        noise_pred = rescale_noise_cfg(\n                            noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale\n                        )\n\n                    # compute the previous noisy sample x_t -> x_t-1\n                    latents_dtype = sub_latents.dtype\n                    sub_latents = self.scheduler.step(\n                        noise_pred, t, sub_latents, **extra_step_kwargs, return_dict=False\n                    )[0]\n\n                    views_scheduler_status[sub_num] = copy.deepcopy(self.scheduler.__dict__)\n                    concat_grid.append(sub_latents)\n                    if sub_latents.dtype != latents_dtype:\n                        if torch.backends.mps.is_available():\n                            # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                            sub_latents = sub_latents.to(latents_dtype)\n\n                    if callback_on_step_end is not None:\n                        callback_kwargs = {}\n                        for k in callback_on_step_end_tensor_inputs:\n                            callback_kwargs[k] = locals()[k]\n                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                        latents = callback_outputs.pop(\"latents\", latents)\n                        prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                        negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                        add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                        negative_pooled_prompt_embeds = callback_outputs.pop(\n                            \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                        )\n                        add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n                latents = self.tlc_vae_latents.grids_inverse(torch.cat(concat_grid, dim=0)).to(sub_latents.dtype)\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and 
self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n            elif latents.dtype != self.vae.dtype:\n                if torch.backends.mps.is_available():\n                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                    self.vae = self.vae.to(latents.dtype)\n\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            # apply watermark if available\n            if self.watermark is not None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        
self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_flux_differential_img2img.py",
    "content": "# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FluxLoraLoaderMixin\nfrom diffusers.models.autoencoders import AutoencoderKL\nfrom diffusers.models.transformers import FluxTransformer2DModel\nfrom diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers.utils import load_image\n        >>> from pipeline import FluxDifferentialImg2ImgPipeline\n\n 
       >>> image = load_image(\n        >>>     \"https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true\",\n        >>> )\n\n        >>> mask = load_image(\n        >>>     \"https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true\",\n        >>> )\n\n        >>> pipe = FluxDifferentialImg2ImgPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16)\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> prompt = \"painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art\"\n        >>> out = pipe(\n        >>>     prompt=prompt,\n        >>>     num_inference_steps=20,\n        >>>     guidance_scale=7.5,\n        >>>     image=image,\n        >>>     mask_image=mask,\n        >>>     strength=1.0,\n        >>> ).images[0]\n\n        >>> out.save(\"image.png\")\n        ```\n        \"\"\"\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return 
encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):\n    r\"\"\"\n    Differential Image to Image pipeline for the Flux family of models.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n\n    Args:\n        transformer ([`FluxTransformer2DModel`]):\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`T5TokenizerFast`):\n            Second Tokenizer of class\n            
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        text_encoder_2: T5EncoderModel,\n        tokenizer_2: T5TokenizerFast,\n        transformer: FluxTransformer2DModel,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, \"vae\", None) else 16\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        latent_channels = self.vae.config.latent_channels if getattr(self, \"vae\", None) else 16\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor,\n            vae_latent_channels=latent_channels,\n            do_normalize=False,\n            do_binarize=False,\n            do_convert_grayscale=True,\n        )\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 77\n        )\n        self.default_sample_size = 64\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 512,\n  
      device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer_2(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_2(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]\n\n        dtype = self.text_encoder_2.dtype\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        
num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n    ):\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer_max_length,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\n            )\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n        # Use pooled output of CLIPTextModel\n        prompt_embeds = prompt_embeds.pooler_output\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n      
  num_images_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        max_sequence_length: int = 512,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in all text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder, lora_scale)\n            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # We only use the pooled prompt output from the CLIPTextModel\n            pooled_prompt_embeds = self._get_clip_prompt_embeds(\n                prompt=prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n            )\n            prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=prompt_2,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, 
lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype\n        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n        return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        if isinstance(generator, list):\n            image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n\n        return image_latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(num_inference_steps * strength, num_inference_steps)\n\n        t_start = int(max(num_inference_steps - init_timestep, 0))\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, \"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, num_inference_steps - 
t_start\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        image,\n        mask_image,\n        strength,\n        height,\n        width,\n        output_type,\n        prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        padding_mask_crop=None,\n        max_sequence_length=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if padding_mask_crop is not None:\n            if not isinstance(image, PIL.Image.Image):\n                raise ValueError(\n                    f\"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}.\"\n                )\n            if not isinstance(mask_image, PIL.Image.Image):\n                raise ValueError(\n                    f\"The mask image should be a PIL image when inpainting mask crop, but is of type\"\n                    f\" {type(mask_image)}.\"\n                )\n            if output_type != \"pil\":\n                raise ValueError(f\"The output type should be PIL when inpainting mask crop, but is {output_type}.\")\n\n        if max_sequence_length is not None and max_sequence_length > 512:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\n        latent_image_ids = torch.zeros(height // 2, width // 2, 3)\n        
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids.reshape(\n            latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n\n        return latent_image_ids.to(device=device, dtype=dtype)\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents\n    def _pack_latents(latents, batch_size, num_channels_latents, height, width):\n        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)\n        latents = latents.permute(0, 2, 4, 1, 3, 5)\n        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)\n\n        return latents\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents\n    def _unpack_latents(latents, height, width, vae_scale_factor):\n        batch_size, num_patches, channels = latents.shape\n\n        height = height // vae_scale_factor\n        width = width // vae_scale_factor\n\n        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\n\n        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)\n\n        return latents\n\n    def prepare_latents(\n        self,\n        image,\n        timestep,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but 
requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        height = 2 * (int(height) // self.vae_scale_factor)\n        width = 2 * (int(width) // self.vae_scale_factor)\n\n        shape = (batch_size, num_channels_latents, height, width)\n        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)\n\n        image = image.to(device=device, dtype=dtype)\n        image_latents = self._encode_vae_image(image=image, generator=generator)\n\n        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // image_latents.shape[0]\n            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            image_latents = torch.cat([image_latents], dim=0)\n\n        if latents is None:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            latents = self.scheduler.scale_noise(image_latents, timestep, noise)\n        else:\n            noise = latents.to(device)\n            latents = noise\n\n        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)\n        image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)\n        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)\n        return latents, noise, image_latents, latent_image_ids\n\n    def prepare_mask_latents(\n        self,\n        mask,\n        masked_image,\n        batch_size,\n    
    num_channels_latents,\n        num_images_per_prompt,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n    ):\n        height = 2 * (int(height) // self.vae_scale_factor)\n        width = 2 * (int(width) // self.vae_scale_factor)\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(mask, size=(height, width))\n        mask = mask.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        masked_image = masked_image.to(device=device, dtype=dtype)\n\n        if masked_image.shape[1] == 16:\n            masked_image_latents = masked_image\n        else:\n            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)\n\n        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n        if masked_image_latents.shape[0] < batch_size:\n            if not batch_size % masked_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)\n\n        # aligning device to prevent device errors when concating it with the latent model input\n        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n\n        masked_image_latents = self._pack_latents(\n            masked_image_latents,\n            batch_size,\n            num_channels_latents,\n            height,\n            width,\n        )\n        mask = self._pack_latents(\n            mask.repeat(1, num_channels_latents, 1, 1),\n            batch_size,\n            num_channels_latents,\n            height,\n            width,\n        )\n\n        return mask, masked_image_latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def joint_attention_kwargs(self):\n        return self._joint_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: PipelineImageInput = None,\n        mask_image: PipelineImageInput = None,\n        masked_image_latents: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        padding_mask_crop: Optional[int] = None,\n        strength: float = 0.6,\n        num_inference_steps: int = 28,\n        timesteps: List[int] = None,\n 
       guidance_scale: float = 7.0,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 512,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`\n                will be used instead.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both\n                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list\n                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. 
If it is a numpy array or a\n                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image\n                latents as `image`; latents passed directly are not encoded again.\n            mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask\n                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a\n                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one\n                color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`, `(B,\n                H, W)`, `(1, H, W)`, or `(H, W)`. For a numpy array, it would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W,\n                1)`, or `(H, W)`.\n            masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`, *optional*):\n                `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask\n                latents tensor will be generated from `mask_image`.\n            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n            padding_mask_crop (`int`, *optional*, defaults to `None`):\n                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to\n                image and mask_image. 
If `padding_mask_crop` is not `None`, it will first find a rectangular region\n                with the same aspect ratio as the image that contains the entire masked area, and then expand that area based\n                on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before\n                resizing to the original image size for inpainting. This is useful when the masked area is small while\n                the image is large and contains information irrelevant for inpainting, such as background.\n            strength (`float`, *optional*, defaults to 1.0):\n                Indicates the extent to which the reference `image` is transformed. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. 
of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`\n            is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated\n            images.\n        \"\"\"\n\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            image,\n            mask_image,\n            strength,\n            height,\n            width,\n            output_type=output_type,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            padding_mask_crop=padding_mask_crop,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._joint_attention_kwargs = joint_attention_kwargs\n        self._interrupt = False\n\n        # 2. Preprocess mask and image\n        if padding_mask_crop is not None:\n            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)\n            resize_mode = \"fill\"\n        else:\n            crops_coords = None\n            resize_mode = \"default\"\n\n        original_image = image\n        init_image = self.image_processor.preprocess(\n            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode\n        )\n        init_image = init_image.to(dtype=torch.float32)\n\n        # 3. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        lora_scale = (\n            self.joint_attention_kwargs.get(\"scale\", None) if self.joint_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            text_ids,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            max_sequence_length=max_sequence_length,\n            lora_scale=lora_scale,\n        )\n\n        # 4.Prepare timesteps\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)\n        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)\n        mu = calculate_shift(\n            image_seq_len,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.15),\n        )\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            timesteps,\n            sigmas,\n            mu=mu,\n        )\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n\n        if num_inference_steps < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of 
pipeline\"\n                f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n            )\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels // 4\n\n        latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(\n            init_image,\n            latent_timestep,\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # start diff diff preparation\n        original_mask = self.mask_processor.preprocess(\n            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords\n        )\n\n        masked_image = init_image * original_mask\n        original_mask, _ = self.prepare_mask_latents(\n            original_mask,\n            masked_image,\n            batch_size,\n            num_channels_latents,\n            num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n        )\n\n        mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps\n        mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device)\n        masks = original_mask > mask_thresholds\n        # end diff diff preparation\n\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # handle guidance\n        if self.transformer.config.guidance_embeds:\n            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)\n            guidance = guidance.expand(latents.shape[0])\n        else:\n            guidance = None\n\n        # 6. 
Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latents.shape[0]).to(latents.dtype)\n                noise_pred = self.transformer(\n                    hidden_states=latents,\n                    timestep=timestep / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    joint_attention_kwargs=self.joint_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                # for 64 channel transformer only.\n                image_latent = original_image_latents\n\n                if i < len(timesteps) - 1:\n                    noise_timestep = timesteps[i + 1]\n                    image_latent = self.scheduler.scale_noise(\n                        original_image_latents, torch.tensor([noise_timestep]), noise\n                    )\n\n                    # start diff diff\n                    mask = masks[i].to(latents_dtype)\n                    latents = image_latent * mask + latents * (1 - mask)\n                    # end diff diff\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # update the progress bar on the last step and after each completed scheduler step past warmup\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            image = latents\n\n        else:\n            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)\n            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return FluxPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_flux_kontext_multiple_images.py",
    "content": "# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n    T5EncoderModel,\n    T5TokenizerFast,\n)\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, FluxTransformer2DModel\nfrom diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nPipelineSeveralImagesInput = Union[\n    Tuple[PIL.Image.Image, ...],\n    
Tuple[np.ndarray, ...],\n    Tuple[torch.Tensor, ...],\n    List[Tuple[PIL.Image.Image, ...]],\n    List[Tuple[np.ndarray, ...]],\n    List[Tuple[torch.Tensor, ...]],\n]\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import FluxKontextPipeline\n        >>> from diffusers.utils import load_image\n\n        >>> pipe = FluxKontextPipeline.from_pretrained(\n        ...     \"black-forest-labs/FLUX.1-Kontext-dev\", torch_dtype=torch.bfloat16\n        ... )\n        >>> pipe.to(\"cuda\")\n\n        >>> image = load_image(\n        ...     \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png\"\n        ... ).convert(\"RGB\")\n        >>> prompt = \"Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors\"\n        >>> image = pipe(\n        ...     image=image,\n        ...     prompt=prompt,\n        ...     guidance_scale=2.5,\n        ...     generator=torch.Generator().manual_seed(42),\n        ... 
).images[0]\n        >>> image.save(\"output.png\")\n        ```\n\"\"\"\n\nPREFERRED_KONTEXT_RESOLUTIONS = [\n    (672, 1568),\n    (688, 1504),\n    (720, 1456),\n    (752, 1392),\n    (800, 1328),\n    (832, 1248),\n    (880, 1184),\n    (944, 1104),\n    (1024, 1024),\n    (1104, 944),\n    (1184, 880),\n    (1248, 832),\n    (1328, 800),\n    (1392, 752),\n    (1456, 720),\n    (1504, 688),\n    (1568, 672),\n]\n\n\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. 
If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\nclass FluxKontextPipeline(\n    DiffusionPipeline,\n    FluxLoraLoaderMixin,\n    FromSingleFileMixin,\n    TextualInversionLoaderMixin,\n    FluxIPAdapterMixin,\n):\n    r\"\"\"\n    The Flux Kontext pipeline for text-to-image generation.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n\n    Args:\n        transformer ([`FluxTransformer2DModel`]):\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`T5TokenizerFast`):\n            Second Tokenizer of class\n            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->image_encoder->transformer->vae\"\n    _optional_components = [\"image_encoder\", \"feature_extractor\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        text_encoder_2: T5EncoderModel,\n        tokenizer_2: T5TokenizerFast,\n        transformer: FluxTransformer2DModel,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            transformer=transformer,\n            scheduler=scheduler,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** 
(len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible\n        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this\n        self.latent_channels = self.vae.config.latent_channels if getattr(self, \"vae\", None) else 16\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 77\n        )\n        self.default_sample_size = 128\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 512,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if isinstance(self, TextualInversionLoaderMixin):\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)\n\n        text_inputs = self.tokenizer_2(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_2(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, 
untruncated_ids):\n            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]\n\n        dtype = self.text_encoder_2.dtype\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n    ):\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if isinstance(self, TextualInversionLoaderMixin):\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer_max_length,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n        if 
untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\n            )\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n        # Use pooled output of CLIPTextModel\n        prompt_embeds = prompt_embeds.pooler_output\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        max_sequence_length: int = 512,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is\n                used in all text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder, lora_scale)\n            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # We only use the pooled prompt output from the CLIPTextModel\n            pooled_prompt_embeds = self._get_clip_prompt_embeds(\n                
prompt=prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n            )\n            prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=prompt_2,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype\n        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n        return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        image_embeds = self.image_encoder(image).image_embeds\n        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n        return image_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, 
num_images_per_prompt\n    ):\n        image_embeds = []\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image in ip_adapter_image:\n                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)\n                image_embeds.append(single_image_embeds[None, :])\n        else:\n            if not isinstance(ip_adapter_image_embeds, list):\n                ip_adapter_image_embeds = [ip_adapter_image_embeds]\n\n            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters.\"\n                )\n\n            for single_image_embeds in ip_adapter_image_embeds:\n                image_embeds.append(single_image_embeds)\n\n        ip_adapter_image_embeds = []\n        for single_image_embeds in image_embeds:\n            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n            single_image_embeds = single_image_embeds.to(device=device)\n            ip_adapter_image_embeds.append(single_image_embeds)\n\n        return ip_adapter_image_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        max_sequence_length=None,\n    ):\n        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:\n            logger.warning(\n                f\"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. 
Dimensions will be resized accordingly\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        if max_sequence_length is not None and max_sequence_length > 512:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\n        latent_image_ids = torch.zeros(height, width, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids.reshape(\n            latent_image_id_height * latent_image_id_width, 
latent_image_id_channels\n        )\n\n        return latent_image_ids.to(device=device, dtype=dtype)\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents\n    def _pack_latents(latents, batch_size, num_channels_latents, height, width):\n        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)\n        latents = latents.permute(0, 2, 4, 1, 3, 5)\n        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)\n\n        return latents\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents\n    def _unpack_latents(latents, height, width, vae_scale_factor):\n        batch_size, num_patches, channels = latents.shape\n\n        # VAE applies 8x compression on images but we must also account for packing which requires\n        # latent height and width to be divisible by 2.\n        height = 2 * (int(height) // (vae_scale_factor * 2))\n        width = 2 * (int(width) // (vae_scale_factor * 2))\n\n        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\n\n        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)\n\n        return latents\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        if isinstance(generator, list):\n            image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=\"argmax\")\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode=\"argmax\")\n\n        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n\n        return 
image_latents\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.vae.enable_slicing()\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_slicing()\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    def preprocess_image(self, image: PipelineImageInput, _auto_resize: bool, multiple_of: int) -> torch.Tensor:\n        img = image[0] if isinstance(image, list) else image\n        image_height, image_width = self.image_processor.get_default_height_width(img)\n        aspect_ratio = image_width / image_height\n        if _auto_resize:\n            # Kontext is trained on specific resolutions, using one of them is recommended\n            _, image_width, image_height = min(\n                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS\n            )\n        image_width = image_width // multiple_of * multiple_of\n        image_height = image_height // multiple_of * multiple_of\n        image = self.image_processor.resize(image, image_height, image_width)\n        image = self.image_processor.preprocess(image, image_height, image_width)\n        return image\n\n    def preprocess_images(\n        self,\n        images: PipelineSeveralImagesInput,\n        _auto_resize: bool,\n        multiple_of: int,\n    ) -> torch.Tensor:\n        # TODO for reviewer: I'm not sure what's the best way to implement this part given the philosophy of the repo.\n        # The solutions I thought about are:\n        # - Make the `resize` and `preprocess` methods of `VaeImageProcessor` more generic (using TypeVar for instance)\n        # - Start by converting the image to a List[Tuple[ {image_format} ]], to unify the processing logic\n        # - Or duplicate the code, as 
done here.\n        # What do you think ?\n\n        # convert multiple_images to a list of tuple, to simplify following logic\n        if not isinstance(images, list):\n            images = [images]\n        # now multiple_images is a list of tuples.\n\n        img = images[0][0]\n        image_height, image_width = self.image_processor.get_default_height_width(img)\n        aspect_ratio = image_width / image_height\n        if _auto_resize:\n            # Kontext is trained on specific resolutions, using one of them is recommended\n            _, image_width, image_height = min(\n                (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS\n            )\n        image_width = image_width // multiple_of * multiple_of\n        image_height = image_height // multiple_of * multiple_of\n        n_image_per_batch = len(images[0])\n        output_images = []\n        for i in range(n_image_per_batch):\n            image = [batch_images[i] for batch_images in images]\n            image = self.image_processor.resize(image, image_height, image_width)\n            image = self.image_processor.preprocess(image, image_height, image_width)\n            output_images.append(image)\n        return output_images\n\n    def prepare_latents(\n        self,\n        images: Optional[list[torch.Tensor]],\n        batch_size: int,\n        num_channels_latents: int,\n        height: int,\n        width: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n    ):\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        # VAE applies 8x compression on images but we must also account for packing which requires\n        # latent height and width to be divisible by 2.\n        height = 2 * (int(height) // (self.vae_scale_factor * 2))\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\n        shape = (batch_size, num_channels_latents, height, width)\n\n        all_image_latents = []\n        all_image_ids = []\n        image_latents = images_ids = None\n        if images is not None:\n            for i, image in enumerate(images):\n                image = image.to(device=device, dtype=dtype)\n                if image.shape[1] != self.latent_channels:\n                    image_latents = self._encode_vae_image(image=image, generator=generator)\n                else:\n                    image_latents = image\n                if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:\n                    # expand init_latents for batch_size\n                    additional_image_per_prompt = batch_size // image_latents.shape[0]\n                    image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)\n                elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:\n                    raise ValueError(\n                        f\"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.\"\n                    )\n                else:\n                    image_latents = torch.cat([image_latents], dim=0)\n\n                image_latent_height, image_latent_width = image_latents.shape[2:]\n                image_latents = self._pack_latents(\n                    image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width\n                )\n                image_ids = self._prepare_latent_image_ids(\n                    batch_size, 
image_latent_height // 2, image_latent_width // 2, device, dtype\n                )\n                # image ids are the same as latent ids with the first dimension set to 1 instead of 0\n                image_ids[..., 0] = 1\n\n                # set the image ids to the correct position in the latent grid\n                image_ids[..., 2] += i * (image_latent_height // 2)\n\n                all_image_ids.append(image_ids)\n                all_image_latents.append(image_latents)\n\n            image_latents = torch.cat(all_image_latents, dim=1)\n            images_ids = torch.cat(all_image_ids, dim=0)\n\n        latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)\n        else:\n            latents = latents.to(device=device, dtype=dtype)\n\n        return latents, image_latents, latent_ids, images_ids\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def joint_attention_kwargs(self):\n        return self._joint_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def current_timestep(self):\n        return self._current_timestep\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        image: Optional[PipelineImageInput] = None,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        negative_prompt: Union[str, List[str]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        true_cfg_scale: float = 1.0,\n        height: Optional[int] = None,\n        
width: Optional[int] = None,\n        num_inference_steps: int = 28,\n        sigmas: Optional[List[float]] = None,\n        guidance_scale: float = 3.5,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        negative_ip_adapter_image: Optional[PipelineImageInput] = None,\n        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 512,\n        max_area: int = 1024**2,\n        _auto_resize: bool = True,\n        multiple_images: Optional[PipelineSeveralImagesInput] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both\n                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list\n                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. 
If it is a numpy array or a\n                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image\n                latents as `image`, but if passing latents directly it is not encoded again.\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`\n                will be used instead.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is\n                not greater than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.\n            true_cfg_scale (`float`, *optional*, defaults to 1.0):\n                When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.\n            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n            num_inference_steps (`int`, *optional*, defaults to 28):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, *optional*, defaults to 3.5):\n                Guidance scale as defined in [Classifier-Free Diffusion\n                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.\n                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting\n                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to\n                the text `prompt`, usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            negative_ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional negative image input to work with IP Adapters.\n            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated negative image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not\n                provided, embeddings are computed from the `negative_ip_adapter_image` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int` defaults to 512):\n                Maximum sequence length to use with the `prompt`.\n            max_area (`int`, defaults to `1024 ** 2`):\n                The maximum area of the generated image in pixels. The height and width will be adjusted to fit this\n                area while maintaining the aspect ratio.\n            multiple_images (`PipelineSeveralImagesInput`, *optional*):\n                A list of images to be used as reference images for the generation. If provided, the pipeline will\n                merge the reference images in the latent space.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`\n            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated\n            images.\n        \"\"\"\n\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_height, original_width = height, width\n        aspect_ratio = width / height\n        width = round((max_area * aspect_ratio) ** 0.5)\n        height = round((max_area / aspect_ratio) ** 0.5)\n\n        multiple_of = self.vae_scale_factor * 2\n        width = width // multiple_of * multiple_of\n        height = height // multiple_of * multiple_of\n\n        if height != original_height or width != original_width:\n            logger.warning(\n                f\"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements.\"\n            )\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._joint_attention_kwargs = joint_attention_kwargs\n        self._current_timestep = None\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        lora_scale = (\n            self.joint_attention_kwargs.get(\"scale\", None) if self.joint_attention_kwargs is not None else None\n        )\n        has_neg_prompt = negative_prompt is not None or (\n            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None\n        )\n        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt\n        (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            text_ids,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            max_sequence_length=max_sequence_length,\n            lora_scale=lora_scale,\n        )\n        if 
do_true_cfg:\n            (\n                negative_prompt_embeds,\n                negative_pooled_prompt_embeds,\n                negative_text_ids,\n            ) = self.encode_prompt(\n                prompt=negative_prompt,\n                prompt_2=negative_prompt_2,\n                prompt_embeds=negative_prompt_embeds,\n                pooled_prompt_embeds=negative_pooled_prompt_embeds,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                lora_scale=lora_scale,\n            )\n\n        # 3. Preprocess image\n        if image is not None and multiple_images is not None:\n            raise ValueError(\"Cannot pass both `image` and `multiple_images`. Please use only one of them.\")\n        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):\n            image = [self.preprocess_image(image, _auto_resize=True, multiple_of=multiple_of)]\n        if multiple_images is not None:\n            image = self.preprocess_images(\n                multiple_images,\n                _auto_resize=_auto_resize,\n                multiple_of=multiple_of,\n            )\n\n        # 4. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels // 4\n        latents, image_latents, latent_ids, image_ids = self.prepare_latents(\n            image,\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n        if image_ids is not None:\n            latent_ids = torch.cat([latent_ids, image_ids], dim=0)  # dim 0 is sequence dimension\n\n        # 5. 
Prepare timesteps\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas\n        image_seq_len = latents.shape[1]\n        mu = calculate_shift(\n            image_seq_len,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.15),\n        )\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            sigmas=sigmas,\n            mu=mu,\n        )\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # handle guidance\n        if self.transformer.config.guidance_embeds:\n            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)\n            guidance = guidance.expand(latents.shape[0])\n        else:\n            guidance = None\n\n        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (\n            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None\n        ):\n            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)\n            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters\n\n        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (\n            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None\n        ):\n            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)\n            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters\n\n        if self.joint_attention_kwargs is None:\n            self._joint_attention_kwargs = 
{}\n\n        image_embeds = None\n        negative_image_embeds = None\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n            )\n        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:\n            negative_image_embeds = self.prepare_ip_adapter_image_embeds(\n                negative_ip_adapter_image,\n                negative_ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n            )\n\n        # 6. Denoising loop\n        # We set the index here to remove DtoH sync, helpful especially during compilation.\n        # Check out more details here: https://github.com/huggingface/diffusers/pull/11696\n        self.scheduler.set_begin_index(0)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                self._current_timestep = t\n                if image_embeds is not None:\n                    self._joint_attention_kwargs[\"ip_adapter_image_embeds\"] = image_embeds\n\n                latent_model_input = latents\n                if image_latents is not None:\n                    latent_model_input = torch.cat([latents, image_latents], dim=1)\n                timestep = t.expand(latents.shape[0]).to(latents.dtype)\n\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    timestep=timestep / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n          
          img_ids=latent_ids,\n                    joint_attention_kwargs=self.joint_attention_kwargs,\n                    return_dict=False,\n                )[0]\n                noise_pred = noise_pred[:, : latents.size(1)]\n\n                if do_true_cfg:\n                    if negative_image_embeds is not None:\n                        self._joint_attention_kwargs[\"ip_adapter_image_embeds\"] = negative_image_embeds\n                    neg_noise_pred = self.transformer(\n                        hidden_states=latent_model_input,\n                        timestep=timestep / 1000,\n                        guidance=guidance,\n                        pooled_projections=negative_pooled_prompt_embeds,\n                        encoder_hidden_states=negative_prompt_embeds,\n                        txt_ids=negative_text_ids,\n                        img_ids=latent_ids,\n                        joint_attention_kwargs=self.joint_attention_kwargs,\n                        return_dict=False,\n                    )[0]\n                    neg_noise_pred = neg_noise_pred[:, : latents.size(1)]\n                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # update the progress bar on the last step, or after warmup once per `scheduler.order` steps\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        self._current_timestep = None\n\n        if output_type == \"latent\":\n            image = latents\n        else:\n            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)\n            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return FluxPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_flux_rf_inversion.py",
    "content": "# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.\n# modeled after RF Inversion: https://rf-inversion.github.io/, authored by Litu Rout, Yujia Chen, Nataniel Ruiz,\n# Constantine Caramanis, Sanjay Shakkottai and Wen-Sheng Chu.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin\nfrom diffusers.models.autoencoders import AutoencoderKL\nfrom diffusers.models.transformers import FluxTransformer2DModel\nfrom diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: 
disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> import requests\n        >>> import PIL\n        >>> from io import BytesIO\n        >>> from diffusers import DiffusionPipeline\n\n        >>> pipe = DiffusionPipeline.from_pretrained(\n        ...    \"black-forest-labs/FLUX.1-dev\",\n        ...    torch_dtype=torch.bfloat16,\n        ...    custom_pipeline=\"pipeline_flux_rf_inversion\")\n        >>> pipe.to(\"cuda\")\n\n        >>> def download_image(url):\n        ...     response = requests.get(url)\n        ...     return PIL.Image.open(BytesIO(response.content)).convert(\"RGB\")\n\n\n        >>> img_url = \"https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg\"\n        >>> image = download_image(img_url)\n\n        >>> inverted_latents, image_latents, latent_image_ids = pipe.invert(image=image, num_inversion_steps=28, gamma=0.5)\n\n        >>> edited_image = pipe(\n        ...     prompt=\"a tomato\",\n        ...     inverted_latents=inverted_latents,\n        ...     image_latents=image_latents,\n        ...     latent_image_ids=latent_image_ids,\n        ...     start_timestep=0,\n        ...     stop_timestep=.25,\n        ...     num_inference_steps=28,\n        ...     eta=0.9,\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass RFInversionFluxPipeline(\n    DiffusionPipeline,\n    FluxLoraLoaderMixin,\n    FromSingleFileMixin,\n    TextualInversionLoaderMixin,\n):\n    r\"\"\"\n    The Flux pipeline for text-to-image generation.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n\n    Args:\n        transformer ([`FluxTransformer2DModel`]):\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`T5TokenizerFast`):\n            Second Tokenizer of class\n            
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        text_encoder_2: T5EncoderModel,\n        tokenizer_2: T5TokenizerFast,\n        transformer: FluxTransformer2DModel,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 77\n        )\n        self.default_sample_size = 128\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 512,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if isinstance(self, 
TextualInversionLoaderMixin):\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)\n\n        text_inputs = self.tokenizer_2(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_2(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]\n\n        dtype = self.text_encoder_2.dtype\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n    ):\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else 
prompt\n        batch_size = len(prompt)\n\n        if isinstance(self, TextualInversionLoaderMixin):\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer_max_length,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\n            )\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n        # Use pooled output of CLIPTextModel\n        prompt_embeds = prompt_embeds.pooler_output\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        prompt_embeds: 
Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        max_sequence_length: int = 512,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in all text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder, lora_scale)\n            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # We only use the pooled prompt output from the CLIPTextModel\n            pooled_prompt_embeds = self._get_clip_prompt_embeds(\n                prompt=prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n            )\n            prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=prompt_2,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, 
lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype\n        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n        return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    @torch.no_grad()\n    # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image\n    def encode_image(self, image, dtype=None, height=None, width=None, resize_mode=\"default\", crops_coords=None):\n        image = self.image_processor.preprocess(\n            image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords\n        )\n        resized = self.image_processor.postprocess(image=image, output_type=\"pil\")\n\n        if max(image.shape[-2:]) > self.vae.config[\"sample_size\"] * 1.5:\n            logger.warning(\n                \"Your input images far exceed the default resolution of the underlying diffusion model. \"\n                \"The output images may contain severe artifacts! 
\"\n                \"Consider down-sampling the input using the `height` and `width` parameters\"\n            )\n        image = image.to(dtype)\n\n        x0 = self.vae.encode(image.to(self._execution_device)).latent_dist.sample()\n        x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n        x0 = x0.to(dtype)\n        return x0, resized\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        inverted_latents,\n        image_latents,\n        latent_image_ids,\n        height,\n        width,\n        start_timestep,\n        stop_timestep,\n        prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        max_sequence_length=None,\n    ):\n        if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:\n            raise ValueError(\n                f\"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if max_sequence_length is not None and max_sequence_length > 512:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\n\n        if inverted_latents is not None and (image_latents is None or latent_image_ids is None):\n            raise ValueError(\n                \"If `inverted_latents` are provided, `image_latents` and `latent_image_ids` also have to be passed. 
\"\n            )\n        # check start_timestep and stop_timestep\n        if start_timestep < 0 or start_timestep > stop_timestep:\n            raise ValueError(f\"`start_timestep` should be in [0, stop_timestep] but is {start_timestep}\")\n\n    @staticmethod\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\n        latent_image_ids = torch.zeros(height, width, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids.reshape(\n            latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n\n        return latent_image_ids.to(device=device, dtype=dtype)\n\n    @staticmethod\n    def _pack_latents(latents, batch_size, num_channels_latents, height, width):\n        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)\n        latents = latents.permute(0, 2, 4, 1, 3, 5)\n        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)\n\n        return latents\n\n    @staticmethod\n    def _unpack_latents(latents, height, width, vae_scale_factor):\n        batch_size, num_patches, channels = latents.shape\n\n        height = height // vae_scale_factor\n        width = width // vae_scale_factor\n\n        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\n\n        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)\n\n        return latents\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. 
When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`.\"\n        deprecate(\n            \"enable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`.\"\n        deprecate(\n            \"disable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    def prepare_latents_inversion(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        image_latents,\n    ):\n        height = int(height) // self.vae_scale_factor\n        width = int(width) // self.vae_scale_factor\n\n        latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)\n\n        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)\n\n        return latents, latent_image_ids\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        # VAE applies 8x compression on images but we must also account for packing which requires\n        # latent height and width to be divisible by 2.\n        height = 2 * (int(height) // (self.vae_scale_factor * 2))\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\n\n        shape = (batch_size, num_channels_latents, height, width)\n\n        if latents is not None:\n            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)\n            return latents.to(device=device, dtype=dtype), latent_image_ids\n\n        if isinstance(generator, list) and 
len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)\n\n        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)\n\n        return latents, latent_image_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength=1.0):\n        # get the original timestep using init_timestep\n        init_timestep = min(num_inference_steps * strength, num_inference_steps)\n\n        t_start = int(max(num_inference_steps - init_timestep, 0))\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        sigmas = self.scheduler.sigmas[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, \"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, sigmas, num_inference_steps - t_start\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def joint_attention_kwargs(self):\n        return self._joint_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        
inverted_latents: Optional[torch.FloatTensor] = None,\n        image_latents: Optional[torch.FloatTensor] = None,\n        latent_image_ids: Optional[torch.FloatTensor] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        eta: float = 1.0,\n        decay_eta: Optional[bool] = False,\n        eta_decay_power: Optional[float] = 1.0,\n        strength: float = 1.0,\n        start_timestep: float = 0,\n        stop_timestep: float = 0.25,\n        num_inference_steps: int = 28,\n        sigmas: Optional[List[float]] = None,\n        timesteps: List[int] = None,\n        guidance_scale: float = 3.5,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 512,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt`\n                will be used instead.\n            inverted_latents (`torch.Tensor`, *optional*):\n                The inverted latents from `pipe.invert`.\n            image_latents (`torch.Tensor`, *optional*):\n                The image latents from `pipe.invert`.\n            latent_image_ids (`torch.Tensor`, *optional*):\n                The latent image ids from `pipe.invert`.\n            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n            eta (`float`, *optional*, defaults to 1.0):\n                The controller guidance, balancing faithfulness & editability:\n                higher eta - better faithfulness, less editability. For more significant edits, lower the value of eta.\n            num_inference_steps (`int`, *optional*, defaults to 28):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 3.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. 
of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from the `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from the `prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that calls at the end of each denoising steps during the inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`\n            is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated\n            images.\n        \"\"\"\n\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            inverted_latents,\n            image_latents,\n            latent_image_ids,\n            height,\n            width,\n            start_timestep,\n            stop_timestep,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._joint_attention_kwargs = joint_attention_kwargs\n        self._interrupt = False\n        do_rf_inversion = inverted_latents is not None\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        lora_scale = (\n            self.joint_attention_kwargs.get(\"scale\", None) if self.joint_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            text_ids,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            max_sequence_length=max_sequence_length,\n            lora_scale=lora_scale,\n        )\n\n        # 4. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels // 4\n        if do_rf_inversion:\n            latents = inverted_latents\n        else:\n            latents, latent_image_ids = self.prepare_latents(\n                batch_size * num_images_per_prompt,\n                num_channels_latents,\n                height,\n                width,\n                prompt_embeds.dtype,\n                device,\n                generator,\n                latents,\n            )\n\n        # 5. 
Prepare timesteps\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas\n        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)\n        mu = calculate_shift(\n            image_seq_len,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.15),\n        )\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            timesteps,\n            sigmas,\n            mu=mu,\n        )\n        if do_rf_inversion:\n            start_timestep = int(start_timestep * num_inference_steps)\n            stop_timestep = min(int(stop_timestep * num_inference_steps), num_inference_steps)\n            timesteps, sigmas, num_inference_steps = self.get_timesteps(num_inference_steps, strength)\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # handle guidance\n        if self.transformer.config.guidance_embeds:\n            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)\n            guidance = guidance.expand(latents.shape[0])\n        else:\n            guidance = None\n\n        if do_rf_inversion:\n            y_0 = image_latents.clone()\n        # 6. 
Denoising loop / Controlled Reverse ODE, Algorithm 2 from: https://huggingface.co/papers/2410.10792\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if do_rf_inversion:\n                    # ti (current timestep) as annotated in algorithm 2 - i/num_inference_steps.\n                    t_i = 1 - t / 1000\n                    dt = torch.tensor(1 / (len(timesteps) - 1), device=device)\n\n                if self.interrupt:\n                    continue\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latents.shape[0]).to(latents.dtype)\n\n                noise_pred = self.transformer(\n                    hidden_states=latents,\n                    timestep=timestep / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    joint_attention_kwargs=self.joint_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                latents_dtype = latents.dtype\n                if do_rf_inversion:\n                    v_t = -noise_pred\n                    v_t_cond = (y_0 - latents) / (1 - t_i)\n                    eta_t = eta if start_timestep <= i < stop_timestep else 0.0\n                    if decay_eta:\n                        eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power  # Decay eta over the loop\n                    v_hat_t = v_t + eta_t * (v_t_cond - v_t)\n\n                    # SDE Eq: 17 from https://huggingface.co/papers/2410.10792\n                    latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])\n                else:\n                    # compute the previous noisy sample x_t -> x_t-1\n                    latents 
= self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            image = latents\n\n        else:\n            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)\n            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return FluxPipelineOutput(images=image)\n\n    @torch.no_grad()\n    def invert(\n        self,\n        image: PipelineImageInput,\n        source_prompt: str = \"\",\n        source_guidance_scale=0.0,\n        num_inversion_steps: int = 28,\n        strength: float = 1.0,\n        
gamma: float = 0.5,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        timesteps: List[int] = None,\n        dtype: Optional[torch.dtype] = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        r\"\"\"\n        Performs Algorithm 1: Controlled Forward ODE from https://huggingface.co/papers/2410.10792\n\n        Args:\n            image (`PipelineImageInput`):\n                Input for the image(s) that are to be edited. Multiple input images must share the same aspect\n                ratio.\n            source_prompt (`str`, *optional*, defaults to `\"\"`):\n                The prompt describing the source image. Defaults to an empty prompt, as done in the original paper.\n            source_guidance_scale (`float`, *optional*, defaults to 0.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `source_guidance_scale` corresponds to `w` of equation 2 of the [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). For this algorithm, it's better to keep it 0.\n            num_inversion_steps (`int`, *optional*, defaults to 28):\n                The number of discretization steps.\n            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. 
This is set to 1024 by default for the best results.\n            gamma (`float`, *optional*, defaults to 0.5):\n                The controller guidance for the forward ODE, balancing faithfulness & editability:\n                higher eta - better faithfulness, less editability. For more significant edits, lower the value of eta.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inversion_steps` is\n                passed will be used. Must be in descending order.\n        \"\"\"\n        dtype = dtype or self.text_encoder.dtype\n        batch_size = 1\n        self._joint_attention_kwargs = joint_attention_kwargs\n        num_channels_latents = self.transformer.config.in_channels // 4\n\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n        device = self._execution_device\n\n        # 1. prepare image\n        image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype)\n        image_latents, latent_image_ids = self.prepare_latents_inversion(\n            batch_size, num_channels_latents, height, width, dtype, device, image_latents\n        )\n\n        # 2. 
prepare timesteps\n        sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps)\n        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)\n        mu = calculate_shift(\n            image_seq_len,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.15),\n        )\n        timesteps, num_inversion_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inversion_steps,\n            device,\n            timesteps,\n            sigmas,\n            mu=mu,\n        )\n        timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength)\n\n        # 3. prepare text embeddings\n        (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            text_ids,\n        ) = self.encode_prompt(\n            prompt=source_prompt,\n            prompt_2=source_prompt,\n            device=device,\n        )\n        # 4. 
handle guidance\n        if self.transformer.config.guidance_embeds:\n            guidance = torch.full([1], source_guidance_scale, device=device, dtype=torch.float32)\n        else:\n            guidance = None\n\n        # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt\n        Y_t = image_latents\n        y_1 = torch.randn_like(Y_t)\n        N = len(sigmas)\n\n        # forward ODE loop\n        with self.progress_bar(total=N - 1) as progress_bar:\n            for i in range(N - 1):\n                t_i = torch.tensor(i / (N), dtype=Y_t.dtype, device=device)\n                # t_i is already a tensor on the right device/dtype; repeat it along the batch dimension\n                timestep = t_i.repeat(batch_size)\n\n                # get the unconditional vector field\n                u_t_i = self.transformer(\n                    hidden_states=Y_t,\n                    timestep=timestep,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    joint_attention_kwargs=self.joint_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # get the conditional vector field\n                u_t_i_cond = (y_1 - Y_t) / (1 - t_i)\n\n                # controlled vector field\n                # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt\n                u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i)\n                Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1])\n                progress_bar.update()\n\n        # return the inverted latents (start point for the denoising loop), encoded image & latent image ids\n        return Y_t, image_latents, latent_image_ids\n"
  },
  {
    "path": "examples/community/pipeline_flux_semantic_guidance.py",
    "content": "# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n    T5EncoderModel,\n    T5TokenizerFast,\n)\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin\nfrom diffusers.models.autoencoders import AutoencoderKL\nfrom diffusers.models.transformers import FluxTransformer2DModel\nfrom diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        
```py\n        >>> import torch\n        >>> from diffusers import DiffusionPipeline\n\n        >>> pipe = DiffusionPipeline.from_pretrained(\n        >>>     \"black-forest-labs/FLUX.1-dev\",\n        >>>     custom_pipeline=\"pipeline_flux_semantic_guidance\",\n        >>>     torch_dtype=torch.bfloat16\n        >>> )\n        >>> pipe.to(\"cuda\")\n        >>> prompt = \"A cat holding a sign that says hello world\"\n        >>> image = pipe(\n        >>>     prompt=prompt,\n        >>>     num_inference_steps=28,\n        >>>     guidance_scale=3.5,\n        >>>     editing_prompt=[\"cat\", \"dog\"],  # changes from cat to dog.\n        >>>     reverse_editing_direction=[True, False],\n        >>>     edit_warmup_steps=[6, 8],\n        >>>     edit_guidance_scale=[6, 6.5],\n        >>>     edit_threshold=[0.89, 0.89],\n        >>>     edit_cooldown_steps = [25, 27],\n        >>>     edit_momentum_scale=0.3,\n        >>>     edit_mom_beta=0.6,\n        >>>     generator=torch.Generator(device=\"cuda\").manual_seed(6543),\n        >>> ).images[0]\n        >>> image.save(\"semantic_flux.png\")\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. 
Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass FluxSemanticGuidancePipeline(\n    DiffusionPipeline,\n    FluxLoraLoaderMixin,\n    FromSingleFileMixin,\n    TextualInversionLoaderMixin,\n    FluxIPAdapterMixin,\n):\n    r\"\"\"\n    The Flux pipeline for text-to-image generation with semantic guidance.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n\n    Args:\n        transformer ([`FluxTransformer2DModel`]):\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the 
[clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`T5TokenizerFast`):\n            Second Tokenizer of class\n            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->image_encoder->transformer->vae\"\n    _optional_components = [\"image_encoder\", \"feature_extractor\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        text_encoder_2: T5EncoderModel,\n        tokenizer_2: T5TokenizerFast,\n        transformer: FluxTransformer2DModel,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            transformer=transformer,\n            scheduler=scheduler,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        # Flux latents are turned into 2x2 
patches and packed. This means the latent width and height has to be divisible\n        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 77\n        )\n        self.default_sample_size = 128\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 512,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if isinstance(self, TextualInversionLoaderMixin):\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)\n\n        text_inputs = self.tokenizer_2(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_2(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because 
`max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]\n\n        dtype = self.text_encoder_2.dtype\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n    ):\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if isinstance(self, TextualInversionLoaderMixin):\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer_max_length,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n       
         \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\n            )\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n        # Use pooled output of CLIPTextModel\n        prompt_embeds = prompt_embeds.pooler_output\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        max_sequence_length: int = 512,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in all text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder, lora_scale)\n            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # We only use the pooled prompt output from the CLIPTextModel\n            pooled_prompt_embeds = self._get_clip_prompt_embeds(\n                prompt=prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n            )\n            prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=prompt_2,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n        if self.text_encoder is not None:\n  
          if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype\n        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n        return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    def encode_text_with_editing(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        editing_prompt: Optional[List[str]] = None,\n        editing_prompt_2: Optional[List[str]] = None,\n        editing_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 512,\n        lora_scale: Optional[float] = None,\n    ):\n        \"\"\"\n        Encode text prompts with editing prompts and negative prompts for semantic guidance.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide image generation.\n            prompt_2 (`str` or `List[str]`):\n                The prompt or prompts to guide image generation for second tokenizer.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            editing_prompt (`List[str]`, *optional*):\n                The editing prompts for semantic guidance.\n            editing_prompt_2 (`List[str]`, *optional*):\n                The editing prompts for semantic guidance for the second tokenizer. If not defined, `editing_prompt`\n                is used.\n            editing_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-computed embeddings for editing prompts.\n            pooled_editing_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-computed pooled embeddings for editing prompts.\n            device (`torch.device`, *optional*):\n                The device to use for computation.\n            num_images_per_prompt (`int`, defaults to 1):\n                Number of images to generate per prompt.\n            max_sequence_length (`int`, defaults to 512):\n                Maximum sequence length for text encoding.\n            lora_scale (`float`, *optional*):\n                Scale factor for LoRA layers if used.\n\n        Returns:\n            `tuple`:\n                A tuple containing the prompt embeddings, pooled prompt embeddings, editing prompt embeddings,\n                pooled editing prompt embeddings, text IDs, editing text IDs, and the number of enabled editing\n                prompts.\n        \"\"\"\n        device = device or self._execution_device\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(\"Prompt must be provided as string or list of strings\")\n\n        # Get base prompt embeddings\n        prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            
pooled_prompt_embeds=pooled_prompt_embeds,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            max_sequence_length=max_sequence_length,\n            lora_scale=lora_scale,\n        )\n\n        # Handle editing prompts\n        if editing_prompt_embeds is not None:\n            enabled_editing_prompts = int(editing_prompt_embeds.shape[0])\n            edit_text_ids = []\n        elif editing_prompt is not None:\n            editing_prompt_embeds = []\n            pooled_editing_prompt_embeds = []\n            edit_text_ids = []\n\n            editing_prompt_2 = editing_prompt if editing_prompt_2 is None else editing_prompt_2\n            for edit_1, edit_2 in zip(editing_prompt, editing_prompt_2):\n                e_prompt_embeds, pooled_embeds, e_ids = self.encode_prompt(\n                    prompt=edit_1,\n                    prompt_2=edit_2,\n                    device=device,\n                    num_images_per_prompt=num_images_per_prompt,\n                    max_sequence_length=max_sequence_length,\n                    lora_scale=lora_scale,\n                )\n                editing_prompt_embeds.append(e_prompt_embeds)\n                pooled_editing_prompt_embeds.append(pooled_embeds)\n                edit_text_ids.append(e_ids)\n\n            enabled_editing_prompts = len(editing_prompt)\n\n        else:\n            edit_text_ids = []\n            enabled_editing_prompts = 0\n\n        if enabled_editing_prompts:\n            for idx in range(enabled_editing_prompts):\n                editing_prompt_embeds[idx] = torch.cat([editing_prompt_embeds[idx]] * batch_size, dim=0)\n                pooled_editing_prompt_embeds[idx] = torch.cat([pooled_editing_prompt_embeds[idx]] * batch_size, dim=0)\n\n        return (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            editing_prompt_embeds,\n            pooled_editing_prompt_embeds,\n            text_ids,\n            edit_text_ids,\n      
      enabled_editing_prompts,\n        )\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        image_embeds = self.image_encoder(image).image_embeds\n        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n        return image_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt\n    ):\n        image_embeds = []\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers\n            ):\n                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)\n\n                image_embeds.append(single_image_embeds[None, :])\n        else:\n            for single_image_embeds in ip_adapter_image_embeds:\n                image_embeds.append(single_image_embeds)\n\n        ip_adapter_image_embeds = []\n        for i, single_image_embeds in enumerate(image_embeds):\n            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n            single_image_embeds = single_image_embeds.to(device=device)\n            ip_adapter_image_embeds.append(single_image_embeds)\n\n        return ip_adapter_image_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        max_sequence_length=None,\n    ):\n        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:\n            logger.warning(\n                f\"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. 
Dimensions will be resized accordingly\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. 
Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        if max_sequence_length is not None and max_sequence_length > 512:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\n        latent_image_ids = torch.zeros(height, width, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids.reshape(\n            latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n\n        return latent_image_ids.to(device=device, dtype=dtype)\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents\n    def _pack_latents(latents, batch_size, num_channels_latents, height, width):\n        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)\n        latents = latents.permute(0, 2, 4, 1, 3, 5)\n        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)\n\n        return latents\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents\n    def _unpack_latents(latents, height, width, vae_scale_factor):\n        batch_size, num_patches, channels = latents.shape\n\n        # VAE applies 8x compression on images but we must also account for packing which requires\n        # latent height and width to be divisible by 2.\n        height = 2 * 
(int(height) // (vae_scale_factor * 2))\n        width = 2 * (int(width) // (vae_scale_factor * 2))\n\n        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\n\n        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)\n\n        return latents\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.vae.enable_slicing()\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_slicing()\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. 
Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        # VAE applies 8x compression on images but we must also account for packing which requires\n        # latent height and width to be divisible by 2.\n        height = 2 * (int(height) // (self.vae_scale_factor * 2))\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\n\n        shape = (batch_size, num_channels_latents, height, width)\n\n        if latents is not None:\n            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)\n            return latents.to(device=device, dtype=dtype), latent_image_ids\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size 
of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)\n\n        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)\n\n        return latents, latent_image_ids\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def joint_attention_kwargs(self):\n        return self._joint_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        negative_prompt: Union[str, List[str]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        true_cfg_scale: float = 1.0,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 28,\n        sigmas: Optional[List[float]] = None,\n        guidance_scale: float = 3.5,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        negative_ip_adapter_image: Optional[PipelineImageInput] = None,\n        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        
negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 512,\n        editing_prompt: Optional[Union[str, List[str]]] = None,\n        editing_prompt_2: Optional[Union[str, List[str]]] = None,\n        editing_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_editing_prompt_embeds: Optional[torch.FloatTensor] = None,\n        reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,\n        edit_guidance_scale: Optional[Union[float, List[float]]] = 5,\n        edit_warmup_steps: Optional[Union[int, List[int]]] = 8,\n        edit_cooldown_steps: Optional[Union[int, List[int]]] = None,\n        edit_threshold: Optional[Union[float, List[float]]] = 0.9,\n        edit_momentum_scale: Optional[float] = 0.1,\n        edit_mom_beta: Optional[float] = 0.4,\n        edit_weights: Optional[List[float]] = None,\n        sem_guidance: Optional[List[torch.Tensor]] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`\n                will be used instead.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is\n                not greater than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.\n            true_cfg_scale (`float`, *optional*, defaults to 1.0):\n                Enables true classifier-free guidance when greater than 1.0 and a `negative_prompt` is provided.\n            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n            num_inference_steps (`int`, *optional*, defaults to 28):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, *optional*, defaults to 3.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). 
Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. 
If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            negative_ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional negative image input to work with IP Adapters.\n            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated negative image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not\n                provided, embeddings are computed from the `negative_ip_adapter_image` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.\n            editing_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image editing. If not defined, no editing will be performed.\n            editing_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image editing. 
If not defined, `editing_prompt` is used instead.\n            editing_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings for editing. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, text embeddings will be generated from `editing_prompt` input argument.\n            pooled_editing_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings for editing. If not provided, pooled text embeddings will be\n                generated from `editing_prompt` input argument.\n            reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):\n                Whether to reverse the editing direction for each editing prompt.\n            edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):\n                Guidance scale for the editing process. If provided as a list, each value corresponds to an editing prompt.\n            edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 8):\n                Number of warmup steps for editing guidance. If provided as a list, each value corresponds to an editing prompt.\n            edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to None):\n                Number of cooldown steps for editing guidance. If provided as a list, each value corresponds to an editing prompt.\n            edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):\n                Threshold for editing guidance. If provided as a list, each value corresponds to an editing prompt.\n            edit_momentum_scale (`float`, *optional*, defaults to 0.1):\n                Scale of momentum to be added to the editing guidance at each diffusion step.\n            edit_mom_beta (`float`, *optional*, defaults to 0.4):\n                Beta value for momentum calculation in editing guidance.\n            edit_weights (`List[float]`, *optional*):\n                Weights for each editing prompt.\n            sem_guidance (`List[torch.Tensor]`, *optional*):\n                Pre-generated semantic guidance. 
If provided, it will be used instead of calculating guidance from editing prompts.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`\n            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated\n            images.\n        \"\"\"\n\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._joint_attention_kwargs = joint_attention_kwargs\n        self._interrupt = False\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if editing_prompt:\n            enable_edit_guidance = True\n            if isinstance(editing_prompt, str):\n                editing_prompt = [editing_prompt]\n            enabled_editing_prompts = len(editing_prompt)\n        elif editing_prompt_embeds is not None:\n            enable_edit_guidance = True\n            enabled_editing_prompts = editing_prompt_embeds.shape[0]\n        else:\n            enabled_editing_prompts = 0\n            enable_edit_guidance = False\n\n        has_neg_prompt = negative_prompt is not None or (\n            negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None\n        )\n        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt\n\n        device = self._execution_device\n\n        lora_scale = (\n            self.joint_attention_kwargs.get(\"scale\", None) if self.joint_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            editing_prompts_embeds,\n            pooled_editing_prompt_embeds,\n            text_ids,\n            edit_text_ids,\n            enabled_editing_prompts,\n        ) = self.encode_text_with_editing(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            editing_prompt=editing_prompt,\n            editing_prompt_2=editing_prompt_2,\n            pooled_editing_prompt_embeds=pooled_editing_prompt_embeds,\n            lora_scale=lora_scale,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            max_sequence_length=max_sequence_length,\n        )\n\n        if do_true_cfg:\n            (\n                
negative_prompt_embeds,\n                negative_pooled_prompt_embeds,\n                _,\n            ) = self.encode_prompt(\n                prompt=negative_prompt,\n                prompt_2=negative_prompt_2,\n                prompt_embeds=negative_prompt_embeds,\n                pooled_prompt_embeds=negative_pooled_prompt_embeds,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                lora_scale=lora_scale,\n            )\n            negative_prompt_embeds = torch.cat([negative_prompt_embeds] * batch_size, dim=0)\n            negative_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds] * batch_size, dim=0)\n\n        # 4. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels // 4\n        latents, latent_image_ids = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 5. 
Prepare timesteps\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas\n        image_seq_len = latents.shape[1]\n        mu = calculate_shift(\n            image_seq_len,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.15),\n        )\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            sigmas=sigmas,\n            mu=mu,\n        )\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        edit_momentum = None\n        if edit_warmup_steps:\n            tmp_e_warmup_steps = edit_warmup_steps if isinstance(edit_warmup_steps, list) else [edit_warmup_steps]\n            min_edit_warmup_steps = min(tmp_e_warmup_steps)\n        else:\n            min_edit_warmup_steps = 0\n\n        if edit_cooldown_steps:\n            tmp_e_cooldown_steps = (\n                edit_cooldown_steps if isinstance(edit_cooldown_steps, list) else [edit_cooldown_steps]\n            )\n            max_edit_cooldown_steps = min(max(tmp_e_cooldown_steps), num_inference_steps)\n        else:\n            max_edit_cooldown_steps = num_inference_steps\n\n        # handle guidance\n        if self.transformer.config.guidance_embeds:\n            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)\n            guidance = guidance.expand(latents.shape[0])\n        else:\n            guidance = None\n\n        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (\n            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None\n        ):\n            
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)\n        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (\n            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None\n        ):\n            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)\n\n        if self.joint_attention_kwargs is None:\n            self._joint_attention_kwargs = {}\n\n        image_embeds = None\n        negative_image_embeds = None\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n            )\n        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:\n            negative_image_embeds = self.prepare_ip_adapter_image_embeds(\n                negative_ip_adapter_image,\n                negative_ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n            )\n\n        # 6. 
Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                if image_embeds is not None:\n                    self._joint_attention_kwargs[\"ip_adapter_image_embeds\"] = image_embeds\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latents.shape[0]).to(latents.dtype)\n\n                # handle guidance\n                if self.transformer.config.guidance_embeds:\n                    guidance = torch.tensor([guidance_scale], device=device)\n                    guidance = guidance.expand(latents.shape[0])\n                else:\n                    guidance = None\n\n                noise_pred = self.transformer(\n                    hidden_states=latents,\n                    timestep=timestep / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    joint_attention_kwargs=self.joint_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:\n                    noise_pred_edit_concepts = []\n                    for e_embed, pooled_e_embed, e_text_id in zip(\n                        editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids\n                    ):\n                        noise_pred_edit = self.transformer(\n                            hidden_states=latents,\n                            timestep=timestep / 1000,\n                            guidance=guidance,\n                            pooled_projections=pooled_e_embed,\n                            
encoder_hidden_states=e_embed,\n                            txt_ids=e_text_id,\n                            img_ids=latent_image_ids,\n                            joint_attention_kwargs=self.joint_attention_kwargs,\n                            return_dict=False,\n                        )[0]\n                        noise_pred_edit_concepts.append(noise_pred_edit)\n\n                if do_true_cfg:\n                    if negative_image_embeds is not None:\n                        self._joint_attention_kwargs[\"ip_adapter_image_embeds\"] = negative_image_embeds\n                    noise_pred_uncond = self.transformer(\n                        hidden_states=latents,\n                        timestep=timestep / 1000,\n                        guidance=guidance,\n                        pooled_projections=negative_pooled_prompt_embeds,\n                        encoder_hidden_states=negative_prompt_embeds,\n                        txt_ids=text_ids,\n                        img_ids=latent_image_ids,\n                        joint_attention_kwargs=self.joint_attention_kwargs,\n                        return_dict=False,\n                    )[0]\n                    noise_guidance = true_cfg_scale * (noise_pred - noise_pred_uncond)\n                else:\n                    noise_pred_uncond = noise_pred\n                    noise_guidance = noise_pred\n\n                if edit_momentum is None:\n                    edit_momentum = torch.zeros_like(noise_guidance)\n\n                if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:\n                    concept_weights = torch.zeros(\n                        (enabled_editing_prompts, noise_guidance.shape[0]),\n                        device=device,\n                        dtype=noise_guidance.dtype,\n                    )\n                    noise_guidance_edit = torch.zeros(\n                        (enabled_editing_prompts, *noise_guidance.shape),\n                        
device=device,\n                        dtype=noise_guidance.dtype,\n                    )\n\n                    warmup_inds = []\n                    for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):\n                        if isinstance(edit_guidance_scale, list):\n                            edit_guidance_scale_c = edit_guidance_scale[c]\n                        else:\n                            edit_guidance_scale_c = edit_guidance_scale\n\n                        if isinstance(edit_threshold, list):\n                            edit_threshold_c = edit_threshold[c]\n                        else:\n                            edit_threshold_c = edit_threshold\n                        if isinstance(reverse_editing_direction, list):\n                            reverse_editing_direction_c = reverse_editing_direction[c]\n                        else:\n                            reverse_editing_direction_c = reverse_editing_direction\n                        if edit_weights:\n                            edit_weight_c = edit_weights[c]\n                        else:\n                            edit_weight_c = 1.0\n                        if isinstance(edit_warmup_steps, list):\n                            edit_warmup_steps_c = edit_warmup_steps[c]\n                        else:\n                            edit_warmup_steps_c = edit_warmup_steps\n\n                        if isinstance(edit_cooldown_steps, list):\n                            edit_cooldown_steps_c = edit_cooldown_steps[c]\n                        elif edit_cooldown_steps is None:\n                            edit_cooldown_steps_c = i + 1\n                        else:\n                            edit_cooldown_steps_c = edit_cooldown_steps\n                        if i >= edit_warmup_steps_c:\n                            warmup_inds.append(c)\n                        if i >= edit_cooldown_steps_c:\n                            noise_guidance_edit[c, :, :, :] = 
torch.zeros_like(noise_pred_edit_concept)\n                            continue\n\n                        if do_true_cfg:\n                            noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond\n                        else:  # simple sega\n                            noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred\n                        tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2))\n\n                        tmp_weights = torch.full_like(tmp_weights, edit_weight_c)  # * (1 / enabled_editing_prompts)\n                        if reverse_editing_direction_c:\n                            noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1\n                        concept_weights[c, :] = tmp_weights\n\n                        noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c\n\n                        # torch.quantile function expects float32\n                        if noise_guidance_edit_tmp.dtype == torch.float32:\n                            tmp = torch.quantile(\n                                torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2),\n                                edit_threshold_c,\n                                dim=2,\n                                keepdim=False,\n                            )\n                        else:\n                            tmp = torch.quantile(\n                                torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32),\n                                edit_threshold_c,\n                                dim=2,\n                                keepdim=False,\n                            ).to(noise_guidance_edit_tmp.dtype)\n\n                        noise_guidance_edit_tmp = torch.where(\n                            torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None],\n                            noise_guidance_edit_tmp,\n                            
torch.zeros_like(noise_guidance_edit_tmp),\n                        )\n\n                        noise_guidance_edit[c, :, :, :] = noise_guidance_edit_tmp\n\n                    warmup_inds = torch.tensor(warmup_inds).to(device)\n                    if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:\n                        concept_weights = concept_weights.to(\"cpu\")  # Offload to cpu\n                        noise_guidance_edit = noise_guidance_edit.to(\"cpu\")\n\n                        concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)\n                        concept_weights_tmp = torch.where(\n                            concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp\n                        )\n                        concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)\n\n                        noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)\n                        noise_guidance_edit_tmp = torch.einsum(\n                            \"cb,cbij->bij\", concept_weights_tmp, noise_guidance_edit_tmp\n                        )\n                        noise_guidance_edit_tmp = noise_guidance_edit_tmp\n                        noise_guidance = noise_guidance + noise_guidance_edit_tmp\n\n                        del noise_guidance_edit_tmp\n                        del concept_weights_tmp\n                        concept_weights = concept_weights.to(device)\n                        noise_guidance_edit = noise_guidance_edit.to(device)\n\n                    concept_weights = torch.where(\n                        concept_weights < 0, torch.zeros_like(concept_weights), concept_weights\n                    )\n\n                    concept_weights = torch.nan_to_num(concept_weights)\n\n                    noise_guidance_edit = torch.einsum(\"cb,cbij->bij\", concept_weights, noise_guidance_edit)\n\n                    
noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum\n\n                    edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit\n\n                    if warmup_inds.shape[0] == len(noise_pred_edit_concepts):\n                        noise_guidance = noise_guidance + noise_guidance_edit\n\n                if sem_guidance is not None:\n                    edit_guidance = sem_guidance[i].to(device)\n                    noise_guidance = noise_guidance + edit_guidance\n\n                if do_true_cfg:\n                    noise_pred = noise_guidance + noise_pred_uncond\n                else:\n                    noise_pred = noise_guidance\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                # call the callback, if provided\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # update the progress bar\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            image = latents\n\n        else:\n            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)\n            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return FluxPipelineOutput(\n            image,\n        )\n"
  },
  {
    "path": "examples/community/pipeline_flux_with_cfg.py",
    "content": "# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast\n\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin\nfrom diffusers.models.autoencoders import AutoencoderKL\nfrom diffusers.models.transformers import FluxTransformer2DModel\nfrom diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import FluxPipeline\n\n        >>> pipe = 
FluxPipeline.from_pretrained(\"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16)\n        >>> pipe.to(\"cuda\")\n        >>> prompt = \"A cat holding a sign that says hello world\"\n        >>> # Depending on the variant being used, the pipeline call will slightly vary.\n        >>> # Refer to the pipeline documentation for more details.\n        >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]\n        >>> image.save(\"flux.png\")\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):\n    r\"\"\"\n    The Flux pipeline for text-to-image generation.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n\n    Args:\n        transformer ([`FluxTransformer2DModel`]):\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`T5TokenizerFast`):\n            Second Tokenizer of class\n            
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        text_encoder_2: T5EncoderModel,\n        tokenizer_2: T5TokenizerFast,\n        transformer: FluxTransformer2DModel,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, \"vae\", None) else 16\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 77\n        )\n        self.default_sample_size = 64\n\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 512,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer_2(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            
truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_2(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]\n\n        dtype = self.text_encoder_2.dtype\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n    ):\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer_max_length,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        
text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\n            )\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n        # Use pooled output of CLIPTextModel\n        prompt_embeds = prompt_embeds.pooler_output\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        negative_prompt: Union[str, List[str]] = None,\n        negative_prompt_2: Union[str, List[str]] = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        max_sequence_length: int = 512,\n        lora_scale: Optional[float] = None,\n        do_true_cfg: bool = False,\n    ):\n        device = device or self._execution_device\n\n        # Set LoRA scale if 
applicable\n        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            if self.text_encoder is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder, lora_scale)\n            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if do_true_cfg and negative_prompt is not None:\n            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_batch_size = len(negative_prompt)\n\n            if negative_batch_size != batch_size:\n                raise ValueError(\n                    f\"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})\"\n                )\n\n            # Concatenate prompts\n            prompts = prompt + negative_prompt\n            prompts_2 = (\n                prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None\n            )\n        else:\n            prompts = prompt\n            prompts_2 = prompt_2\n\n        if prompt_embeds is None:\n            if prompts_2 is None:\n                prompts_2 = prompts\n\n            # Get pooled prompt embeddings from CLIPTextModel\n            pooled_prompt_embeds = self._get_clip_prompt_embeds(\n                prompt=prompts,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n            )\n            prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=prompts_2,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n            if do_true_cfg and negative_prompt is not None:\n              
  # Split embeddings back into positive and negative parts\n                total_batch_size = batch_size * num_images_per_prompt\n                positive_indices = slice(0, total_batch_size)\n                negative_indices = slice(total_batch_size, 2 * total_batch_size)\n\n                positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]\n                negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]\n\n                positive_prompt_embeds = prompt_embeds[positive_indices]\n                negative_prompt_embeds = prompt_embeds[negative_indices]\n\n                pooled_prompt_embeds = positive_pooled_prompt_embeds\n                prompt_embeds = positive_prompt_embeds\n\n        # Unscale LoRA layers\n        if self.text_encoder is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype\n        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n        if do_true_cfg and negative_prompt is not None:\n            return (\n                prompt_embeds,\n                pooled_prompt_embeds,\n                text_ids,\n                negative_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            )\n        else:\n            return prompt_embeds, pooled_prompt_embeds, text_ids, None, None\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        
negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        max_sequence_length=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. 
Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        if max_sequence_length is not None and max_sequence_length > 512:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\n\n    @staticmethod\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\n        latent_image_ids = torch.zeros(height // 2, width // 2, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids.reshape(\n            latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n\n        return latent_image_ids.to(device=device, dtype=dtype)\n\n    @staticmethod\n    def _pack_latents(latents, batch_size, num_channels_latents, height, width):\n        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)\n        latents = latents.permute(0, 2, 4, 1, 3, 5)\n        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)\n\n        return latents\n\n    @staticmethod\n    def _unpack_latents(latents, height, width, vae_scale_factor):\n        batch_size, num_patches, channels = latents.shape\n\n        height = height // vae_scale_factor\n        width = width 
// vae_scale_factor\n\n        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\n\n        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)\n\n        return latents\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`.\"\n        deprecate(\n            \"enable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`.\"\n        deprecate(\n            \"disable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. 
This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        height = 2 * (int(height) // self.vae_scale_factor)\n        width = 2 * (int(width) // self.vae_scale_factor)\n\n        shape = (batch_size, num_channels_latents, height, width)\n\n        if latents is not None:\n            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)\n            return latents.to(device=device, dtype=dtype), latent_image_ids\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)\n\n        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)\n\n        return latents, latent_image_ids\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def joint_attention_kwargs(self):\n        return self._joint_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        negative_prompt: Union[str, List[str]] = None,  #\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        true_cfg: float = 1.0,  #\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 28,\n        timesteps: List[int] = None,\n        guidance_scale: float = 3.5,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, 
Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 512,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`\n                will be used instead.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Only used when `true_cfg` is greater than 1.\n            true_cfg (`float`, *optional*, defaults to 1.0):\n                Scale for true classifier-free guidance. It is enabled when `true_cfg > 1` and a `negative_prompt` is\n                passed.\n            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n            num_inference_steps (`int`, *optional*, defaults to 28):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 3.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. 
of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`\n            is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated\n            images.\n        \"\"\"\n\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._joint_attention_kwargs = joint_attention_kwargs\n        self._interrupt = False\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        lora_scale = (\n            self.joint_attention_kwargs.get(\"scale\", None) if self.joint_attention_kwargs is not None else None\n        )\n        do_true_cfg = true_cfg > 1 and negative_prompt is not None\n        (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            text_ids,\n            negative_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            max_sequence_length=max_sequence_length,\n            lora_scale=lora_scale,\n            do_true_cfg=do_true_cfg,\n        )\n\n        if do_true_cfg:\n            # Concatenate embeddings\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)\n\n        # 4. 
Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels // 4\n        latents, latent_image_ids = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 5. Prepare timesteps\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)\n        image_seq_len = latents.shape[1]\n        mu = calculate_shift(\n            image_seq_len,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.15),\n        )\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            timesteps,\n            sigmas,\n            mu=mu,\n        )\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # 6. 
Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents\n\n                # handle guidance\n                if self.transformer.config.guidance_embeds:\n                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)\n                    guidance = guidance.expand(latent_model_input.shape[0])\n                else:\n                    guidance = None\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)\n\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    timestep=timestep / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    joint_attention_kwargs=self.joint_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                if do_true_cfg:\n                    neg_noise_pred, noise_pred = noise_pred.chunk(2)\n                    noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            image = latents\n\n        else:\n            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)\n            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return FluxPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_hunyuandit_differential_img2img.py",
    "content": "# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom transformers import (\n    BertModel,\n    BertTokenizer,\n    CLIPImageProcessor,\n    T5EncoderModel,\n    T5Tokenizer,\n)\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, HunyuanDiT2DModel\nfrom diffusers.models.embeddings import get_2d_rotary_pos_embed\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import (\n    StableDiffusionSafetyChecker,\n)\nfrom diffusers.schedulers import DDPMScheduler\nfrom diffusers.utils import (\n    deprecate,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers 
import FlowMatchEulerDiscreteScheduler\n        >>> from diffusers.utils import load_image\n        >>> from PIL import Image\n        >>> from torchvision import transforms\n        >>> from pipeline_hunyuandit_differential_img2img import HunyuanDiTDifferentialImg2ImgPipeline\n        >>> pipe = HunyuanDiTDifferentialImg2ImgPipeline.from_pretrained(\n        >>>     \"Tencent-Hunyuan/HunyuanDiT-Diffusers\", torch_dtype=torch.float16\n        >>> ).to(\"cuda\")\n        >>> source_image = load_image(\n        >>>     \"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png\"\n        >>> )\n        >>> map = load_image(\n        >>>     \"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask_2.png\"\n        >>> )\n        >>> prompt = \"a green pear\"\n        >>> negative_prompt = \"blurry\"\n        >>> image = pipe(\n        >>>     prompt=prompt,\n        >>>     negative_prompt=negative_prompt,\n        >>>     image=source_image,\n        >>>     num_inference_steps=28,\n        >>>     guidance_scale=4.5,\n        >>>     strength=1.0,\n        >>>     map=map,\n        >>> ).images[0]\n\n        ```\n\"\"\"\n\nSTANDARD_RATIO = np.array(\n    [\n        1.0,  # 1:1\n        4.0 / 3.0,  # 4:3\n        3.0 / 4.0,  # 3:4\n        16.0 / 9.0,  # 16:9\n        9.0 / 16.0,  # 9:16\n    ]\n)\nSTANDARD_SHAPE = [\n    [(1024, 1024), (1280, 1280)],  # 1:1\n    [(1024, 768), (1152, 864), (1280, 960)],  # 4:3\n    [(768, 1024), (864, 1152), (960, 1280)],  # 3:4\n    [(1280, 768)],  # 16:9\n    [(768, 1280)],  # 9:16\n]\nSTANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]\nSUPPORTED_SHAPE = [\n    (1024, 1024),\n    (1280, 1280),  # 1:1\n    (1024, 768),\n    (1152, 864),\n    (1280, 960),  # 4:3\n    (768, 1024),\n    (864, 1152),\n    (960, 1280),  # 3:4\n    (1280, 768),  # 16:9\n    (768, 1280),  # 9:16\n]\n\n\ndef 
map_to_standard_shapes(target_width, target_height):\n    target_ratio = target_width / target_height\n    closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))\n    closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))\n    width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]\n    return width, height\n\n\ndef get_resize_crop_region_for_grid(src, tgt_size):\n    th = tw = tgt_size\n    h, w = src\n\n    r = h / w\n\n    # resize\n    if r > 1:\n        resize_height = th\n        resize_width = int(round(th / h * w))\n    else:\n        resize_width = tw\n        resize_height = int(round(tw / w * h))\n\n    crop_top = int(round((th - resize_height) / 2.0))\n    crop_left = int(round((tw - resize_width) / 2.0))\n\n    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor,\n    generator: torch.Generator | None = None,\n    sample_mode: str = \"sample\",\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):\n    r\"\"\"\n    Differential Pipeline for English/Chinese-to-image generation using HunyuanDiT.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    HunyuanDiT uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by\n    ourselves)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
We use\n            `sdxl-vae-fp16-fix`.\n        text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n            HunyuanDiT uses a fine-tuned [bilingual CLIP].\n        tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):\n            A `BertTokenizer` or `CLIPTokenizer` to tokenize text.\n        transformer ([`HunyuanDiT2DModel`]):\n            The HunyuanDiT model designed by Tencent Hunyuan.\n        text_encoder_2 (`T5EncoderModel`):\n            The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.\n        tokenizer_2 (`T5Tokenizer`):\n            The tokenizer for the mT5 embedder.\n        scheduler ([`DDPMScheduler`]):\n            A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->transformer->vae\"\n    _optional_components = [\n        \"safety_checker\",\n        \"feature_extractor\",\n        \"text_encoder_2\",\n        \"tokenizer_2\",\n        \"text_encoder\",\n        \"tokenizer\",\n    ]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"prompt_embeds_2\",\n        \"negative_prompt_embeds_2\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: BertModel,\n        tokenizer: BertTokenizer,\n        transformer: HunyuanDiT2DModel,\n        scheduler: DDPMScheduler,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n        text_encoder_2=T5EncoderModel,\n        tokenizer_2=T5Tokenizer,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n     
       text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            transformer=transformer,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            text_encoder_2=text_encoder_2,\n        )\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. 
If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor,\n            do_normalize=False,\n            do_convert_grayscale=True,\n        )\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n        self.default_sample_size = (\n            self.transformer.config.sample_size\n            if hasattr(self, \"transformer\") and self.transformer is not None\n            else 128\n        )\n\n    # copied from diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: str,\n        device: torch.device = None,\n        dtype: torch.dtype = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        max_sequence_length: Optional[int] = None,\n        text_encoder_index: int = 0,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device (`torch.device`):\n                torch device\n            dtype (`torch.dtype`):\n                torch dtype\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            
do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            prompt_attention_mask (`torch.Tensor`, *optional*):\n                Attention mask for the prompt. Required when `prompt_embeds` is passed directly.\n            negative_prompt_attention_mask (`torch.Tensor`, *optional*):\n                Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.\n            max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.\n            text_encoder_index (`int`, *optional*):\n                Index of the text encoder to use. 
`0` for clip and `1` for T5.\n        \"\"\"\n        if dtype is None:\n            if self.text_encoder_2 is not None:\n                dtype = self.text_encoder_2.dtype\n            elif self.transformer is not None:\n                dtype = self.transformer.dtype\n            else:\n                dtype = None\n\n        if device is None:\n            device = self._execution_device\n\n        tokenizers = [self.tokenizer, self.tokenizer_2]\n        text_encoders = [self.text_encoder, self.text_encoder_2]\n\n        tokenizer = tokenizers[text_encoder_index]\n        text_encoder = text_encoders[text_encoder_index]\n\n        if max_sequence_length is None:\n            if text_encoder_index == 0:\n                max_length = 77\n            if text_encoder_index == 1:\n                max_length = 256\n        else:\n            max_length = max_sequence_length\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_attention_mask=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can 
only handle sequences up to\"\n                    f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            prompt_attention_mask = text_inputs.attention_mask.to(device)\n            prompt_embeds = text_encoder(\n                text_input_ids.to(device),\n                attention_mask=prompt_attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n            prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)\n\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            negative_prompt_attention_mask = uncond_input.attention_mask.to(device)\n            negative_prompt_embeds = text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=negative_prompt_attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return (\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n        )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = 
self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        prompt_attention_mask=None,\n        negative_prompt_attention_mask=None,\n        prompt_embeds_2=None,\n        negative_prompt_embeds_2=None,\n        prompt_attention_mask_2=None,\n        negative_prompt_attention_mask_2=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 
8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is None and prompt_embeds_2 is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds_2`. 
Cannot leave both `prompt` and `prompt_embeds_2` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if prompt_embeds is not None and prompt_attention_mask is None:\n            raise ValueError(\"Must provide `prompt_attention_mask` when specifying `prompt_embeds`.\")\n\n        if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:\n            raise ValueError(\"Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:\n            raise ValueError(\"Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.\")\n\n        if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:\n            raise ValueError(\n                \"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`.\"\n            )\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n        if prompt_embeds_2 is not None and 
negative_prompt_embeds_2 is not None:\n            if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:\n                raise ValueError(\n                    \"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`\"\n                    f\" {negative_prompt_embeds_2.shape}.\"\n                )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, \"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        image,\n        timestep,\n        dtype,\n        device,\n        generator=None,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n\n        image = image.to(device=device, dtype=dtype)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                
f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n        elif isinstance(generator, list):\n            init_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)\n            ]\n            init_latents = torch.cat(init_latents, dim=0)\n\n        else:\n            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        init_latents = init_latents * self.vae.config.scaling_factor\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            deprecation_message = (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\n                \"len(prompt) != len(image)\",\n                \"1.0.0\",\n                deprecation_message,\n                standard_warn=False,\n            )\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        strength: float = 0.8,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: Optional[int] = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        guidance_scale: Optional[float] = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_embeds_2: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds_2: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        prompt_attention_mask_2: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback_on_step_end: Optional[\n            Union[\n                Callable[[int, int, Dict], None],\n                PipelineCallback,\n                MultiPipelineCallbacks,\n            ]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n     
   guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = (1024, 1024),\n        target_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        use_resolution_binning: bool = True,\n        map: PipelineImageInput = None,\n        denoising_start: Optional[float] = None,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation with HunyuanDiT.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both\n                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list\n                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a\n                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image\n                latents as `image`, but if passing latents directly it is not encoded again.\n            strength (`float`, *optional*, defaults to 0.8):\n                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. 
A value of 1\n                essentially ignores `image`.\n            height (`int`):\n                The height in pixels of the generated image.\n            width (`int`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference. This parameter is modulated by `strength`.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            prompt_embeds_2 (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            negative_prompt_embeds_2 (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds_2` are generated from the `negative_prompt` input argument.\n            prompt_attention_mask (`torch.Tensor`, *optional*):\n                Attention mask for the prompt. 
Required when `prompt_embeds` is passed directly.\n            prompt_attention_mask_2 (`torch.Tensor`, *optional*):\n                Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.\n            negative_prompt_attention_mask (`torch.Tensor`, *optional*):\n                Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.\n            negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):\n                Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A callback function or a list of callback functions to be called at the end of each denoising step.\n            callback_on_step_end_tensor_inputs (`List[str]`, *optional*):\n                A list of tensor inputs that should be passed to the callback function. If not defined, all tensor\n                inputs will be passed.\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise\n                Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n            original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):\n                The original size of the image. 
Used to calculate the time ids.\n            target_size (`Tuple[int, int]`, *optional*):\n                The target size of the image. Used to calculate the time ids.\n            crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):\n                The top left coordinates of the crop. Used to calculate the time ids.\n            use_resolution_binning (`bool`, *optional*, defaults to `True`):\n                Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest\n                standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,\n                768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.\n            map (`PipelineImageInput`):\n                The change map for differential diffusion. It is resized to the latent resolution and compared against\n                a per-step threshold that increases over the denoising loop. Regions whose map value exceeds the\n                threshold are reset to the noised `image` at that step, so higher values preserve more of `image`.\n            denoising_start (`float`, *optional*):\n                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be\n                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and\n                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,\n                strength will be ignored. 
The `denoising_start` parameter is particularly beneficial when this pipeline\n                is integrated into a \"Mixture of Denoisers\" multi-pipeline setup, as detailed in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 0. default height and width\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n        height = int((height // 16) * 16)\n        width = int((width // 16) * 16)\n\n        if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:\n            width, height = map_to_standard_shapes(width, height)\n            height = int(height)\n            width = int(width)\n            logger.warning(f\"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}\")\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n            prompt_embeds_2,\n            negative_prompt_embeds_2,\n            prompt_attention_mask_2,\n            negative_prompt_attention_mask_2,\n            callback_on_step_end_tensor_inputs,\n        )\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. Encode input prompt\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            dtype=self.transformer.dtype,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            max_sequence_length=77,\n            text_encoder_index=0,\n        )\n        (\n            prompt_embeds_2,\n            negative_prompt_embeds_2,\n            prompt_attention_mask_2,\n            negative_prompt_attention_mask_2,\n        ) = self.encode_prompt(\n   
         prompt=prompt,\n            device=device,\n            dtype=self.transformer.dtype,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds_2,\n            negative_prompt_embeds=negative_prompt_embeds_2,\n            prompt_attention_mask=prompt_attention_mask_2,\n            negative_prompt_attention_mask=negative_prompt_attention_mask_2,\n            max_sequence_length=256,\n            text_encoder_index=1,\n        )\n\n        # 4. Preprocess image\n        init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        map = self.mask_processor.preprocess(\n            map,\n            height=height // self.vae_scale_factor,\n            width=width // self.vae_scale_factor,\n        ).to(device)\n\n        # 5. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n\n        # begin diff diff change\n        total_time_steps = num_inference_steps\n        # end diff diff change\n\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            init_image,\n            latent_timestep,\n            prompt_embeds.dtype,\n            device,\n            generator,\n        )\n\n        # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. create image_rotary_emb, style embedding & time ids\n        grid_height = height // 8 // self.transformer.config.patch_size\n        grid_width = width // 8 // self.transformer.config.patch_size\n        base_size = 512 // 8 // self.transformer.config.patch_size\n        grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)\n        image_rotary_emb = get_2d_rotary_pos_embed(\n            self.transformer.inner_dim // self.transformer.num_heads,\n            grid_crops_coords,\n            (grid_height, grid_width),\n            device=device,\n            output_type=\"pt\",\n        )\n\n        style = torch.tensor([0], device=device)\n\n        target_size = target_size or (height, width)\n        add_time_ids = list(original_size + target_size + crops_coords_top_left)\n        add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])\n            prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])\n            prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])\n            add_time_ids = torch.cat([add_time_ids] * 2, dim=0)\n            style = torch.cat([style] * 2, dim=0)\n\n        prompt_embeds = prompt_embeds.to(device=device)\n        prompt_attention_mask = prompt_attention_mask.to(device=device)\n        prompt_embeds_2 = prompt_embeds_2.to(device=device)\n        prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)\n        add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(\n            batch_size * 
num_images_per_prompt, 1\n        )\n        style = style.to(device=device).repeat(batch_size * num_images_per_prompt)\n        # 9. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        # preparations for diff diff\n        original_with_noise = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            init_image,\n            timesteps,\n            prompt_embeds.dtype,\n            device,\n            generator,\n        )\n        thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps\n        thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)\n        masks = map.squeeze() > (thresholds + (denoising_start or 0))\n        # end diff diff preparations\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n                # diff diff\n                if i == 0 and denoising_start is None:\n                    latents = original_with_noise[:1]\n                else:\n                    mask = masks[i].unsqueeze(0).to(latents.dtype)\n                    mask = mask.unsqueeze(1)  # fit shape\n                    latents = original_with_noise[i] * mask + latents * (1 - mask)\n                # end diff diff\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input\n                t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(\n                    
dtype=latent_model_input.dtype\n                )\n\n                # predict the noise residual\n                noise_pred = self.transformer(\n                    latent_model_input,\n                    t_expand,\n                    encoder_hidden_states=prompt_embeds,\n                    text_embedding_mask=prompt_attention_mask,\n                    encoder_hidden_states_t5=prompt_embeds_2,\n                    text_embedding_mask_t5=prompt_attention_mask_2,\n                    image_meta_size=add_time_ids,\n                    style=style,\n                    image_rotary_emb=image_rotary_emb,\n                    return_dict=False,\n                )[0]\n\n                noise_pred, _ = noise_pred.chunk(2, dim=1)\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    prompt_embeds_2 = callback_outputs.pop(\"prompt_embeds_2\", prompt_embeds_2)\n                    negative_prompt_embeds_2 = callback_outputs.pop(\n                        \"negative_prompt_embeds_2\", negative_prompt_embeds_2\n                    )\n\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = 
self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/pipeline_kolors_differential_img2img.py",
    "content": "# Copyright 2025 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport PIL.Image\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.pipelines.kolors.pipeline_output import KolorsPipelineOutput\nfrom diffusers.pipelines.kolors.text_encoder import ChatGLMModel\nfrom diffusers.pipelines.kolors.tokenizer import ChatGLMTokenizer\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import 
torch\n        >>> from diffusers import KolorsDifferentialImg2ImgPipeline\n        >>> from diffusers.utils import load_image\n\n        >>> pipe = KolorsDifferentialImg2ImgPipeline.from_pretrained(\n        ...     \"Kwai-Kolors/Kolors-diffusers\", variant=\"fp16\", torch_dtype=torch.float16\n        ... )\n        >>> pipe = pipe.to(\"cuda\")\n        >>> url = (\n        ...     \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/bunny_source.png\"\n        ... )\n\n\n        >>> init_image = load_image(url)\n        >>> prompt = \"high quality image of a capybara wearing sunglasses. In the background of the image there are trees, poles, grass and other objects. At the bottom of the object there is the road., 8k, highly detailed.\"\n        >>> image = pipe(prompt, image=init_image).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves 
timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass KolorsDifferentialImg2ImgPipeline(\n    DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin\n):\n    r\"\"\"\n    Pipeline for differential image-to-image generation using Kolors.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`ChatGLMModel`]):\n            Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).\n        tokenizer (`ChatGLMTokenizer`):\n            Tokenizer of class\n            [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `False`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. 
Also see the config of\n            `Kwai-Kolors/Kolors-diffusers`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\n        \"image_encoder\",\n        \"feature_extractor\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"negative_add_time_ids\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: ChatGLMModel,\n        tokenizer: ChatGLMTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n        force_zeros_for_empty_prompt: bool = False,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True\n        )\n\n        self.default_sample_size = (\n            self.unet.config.sample_size\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\n            else 128\n        )\n\n    # Copied from 
diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 256,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.\n        \"\"\"\n        # from IPython import embed; embed(); exit()\n        device = device or self._execution_device\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer]\n        text_encoders = [self.text_encoder]\n\n        if prompt_embeds is None:\n            prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=max_sequence_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(device)\n                output = text_encoder(\n                    input_ids=text_inputs[\"input_ids\"],\n               
     attention_mask=text_inputs[\"attention_mask\"],\n                    position_ids=text_inputs[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n\n                # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]\n                # clone to have a contiguous tensor\n                prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()\n                # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size]\n                pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()\n                bs_embed, seq_len, _ = prompt_embeds.shape\n                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = prompt_embeds_list[0]\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise 
ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            negative_prompt_embeds_list = []\n\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                uncond_input = tokenizer(\n                    uncond_tokens,\n                    padding=\"max_length\",\n                    max_length=max_sequence_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(device)\n                output = text_encoder(\n                    input_ids=uncond_input[\"input_ids\"],\n                    attention_mask=uncond_input[\"attention_mask\"],\n                    position_ids=uncond_input[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n\n                # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size]\n                # clone to have a contiguous tensor\n                negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()\n                # [max_sequence_length, batch, hidden_size] -> [batch, hidden_size]\n                negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()\n\n                if do_classifier_free_guidance:\n                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n                    seq_len = negative_prompt_embeds.shape[1]\n\n                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n          
          negative_prompt_embeds = negative_prompt_embeds.view(\n                        batch_size * num_images_per_prompt, seq_len, -1\n                    )\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = negative_prompt_embeds_list[0]\n\n        bs_embed = pooled_prompt_embeds.shape[0]\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n          
  image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        image_embeds = []\n        if do_classifier_free_guidance:\n            negative_image_embeds = []\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n\n                image_embeds.append(single_image_embeds[None, :])\n                if do_classifier_free_guidance:\n                    negative_image_embeds.append(single_negative_image_embeds[None, :])\n        else:\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    negative_image_embeds.append(single_negative_image_embeds)\n                image_embeds.append(single_image_embeds)\n\n        ip_adapter_image_embeds = []\n        for i, single_image_embeds in enumerate(image_embeds):\n            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n            if do_classifier_free_guidance:\n                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)\n                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)\n\n            single_image_embeds = single_image_embeds.to(device=device)\n            ip_adapter_image_embeds.append(single_image_embeds)\n\n        return ip_adapter_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, 
eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        strength,\n        num_inference_steps,\n        height,\n        width,\n        negative_prompt=None,\n        prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        max_sequence_length=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        if not isinstance(num_inference_steps, int) or num_inference_steps <= 0:\n            raise ValueError(\n                f\"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type\"\n                f\" {type(num_inference_steps)}.\"\n            )\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in 
self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. 
Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n        if max_sequence_length is not None and max_sequence_length > 256:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}\")\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):\n        # get the original timestep using init_timestep\n        if denoising_start is None:\n            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n            t_start = max(num_inference_steps - init_timestep, 0)\n        else:\n            t_start = 0\n\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        # Strength is irrelevant if we directly request a timestep to start at;\n        # that is, strength is determined by the denoising_start instead.\n        if denoising_start is not None:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_start * self.scheduler.config.num_train_timesteps)\n                )\n            )\n\n            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()\n            if 
self.scheduler.order == 2 and num_inference_steps % 2 == 0:\n                # if the scheduler is a 2nd order scheduler we might have to do +1\n                # because `num_inference_steps` might be even given that every timestep\n                # (except the highest one) is duplicated. If `num_inference_steps` is even it would\n                # mean that we cut the timesteps in the middle of the denoising step\n                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1\n                # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler\n                num_inference_steps = num_inference_steps + 1\n\n            # because t_n+1 >= t_n, we slice the timesteps starting from the end\n            timesteps = timesteps[-num_inference_steps:]\n            return timesteps, num_inference_steps\n\n        return timesteps, num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents\n    def prepare_latents(\n        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True\n    ):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        latents_mean = latents_std = None\n        if hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None:\n            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)\n        if hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None:\n            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)\n\n        # Offload text encoder if `enable_model_cpu_offload` was enabled\n        if hasattr(self, 
\"final_offload_hook\") and self.final_offload_hook is not None:\n            self.text_encoder_2.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            if self.vae.config.force_upcast:\n                image = image.float()\n                self.vae.to(dtype=torch.float32)\n\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:\n                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)\n                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:\n                    raise ValueError(\n                        f\"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} \"\n                    )\n\n                init_latents = [\n                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                    for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n            if self.vae.config.force_upcast:\n                self.vae.to(dtype)\n\n            init_latents = init_latents.to(dtype)\n            if latents_mean is not None and 
latents_std is not None:\n                latents_mean = latents_mean.to(device=device, dtype=dtype)\n                latents_std = latents_std.to(device=device, dtype=dtype)\n                init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std\n            else:\n                init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        if add_noise:\n            shape = init_latents.shape\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # get latents\n            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n\n        latents = init_latents\n\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids\n    def _get_add_time_ids(\n        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None\n    ):\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != 
passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        return add_time_ids\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\n    ) -> torch.Tensor:\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            w (`torch.Tensor`):\n                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.\n            embedding_dim (`int`, *optional*, defaults to 512):\n                Dimension of the embeddings to generate.\n            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):\n                Data type of the generated embeddings.\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if 
embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def denoising_start(self):\n        return self._denoising_start\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        strength: float = 0.3,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: 
Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 256,\n        map: PipelineImageInput = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):\n                The image(s) to modify with the pipeline.\n            strength (`float`, *optional*, defaults to 0.3):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`\n                will be used as a starting point, adding more noise to it the larger the `strength`. 
The number of\n                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will\n                be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of\n                `denoising_start` being declared as an integer, the value of `strength` will be ignored.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints\n                that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [Kwai-Kolors/Kolors-diffusers](https://huggingface.co/Kwai-Kolors/Kolors-diffusers) and checkpoints\n                that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. 
Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            denoising_start (`float`, *optional*):\n                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be\n                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and\n                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,\n                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline\n                is integrated into a \"Mixture of Denoisers\" multi-pipeline setup, as detailed in [**Refine Image\n                Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. 
The `denoising_end` parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2 of the [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number\n                of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.kolors.KolorsPipelineOutput`] instead of a plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. 
Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during inference, with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.kolors.KolorsPipelineOutput`] or `tuple`: [`~pipelines.kolors.KolorsPipelineOutput`] if\n            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the\n            generated images.\n        \"\"\"\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 0. Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            strength,\n            num_inference_steps,\n            height,\n            width,\n            negative_prompt,\n            prompt_embeds,\n            pooled_prompt_embeds,\n            negative_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._denoising_start = denoising_start\n        self._interrupt = False\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. Encode input prompt\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n\n        # 4. Preprocess image\n        init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n\n        map = self.mask_processor.preprocess(\n            map, height=height // self.vae_scale_factor, width=width // self.vae_scale_factor\n        ).to(device)\n\n        # 5. 
Prepare timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n\n        # begin diff diff change\n        total_time_steps = num_inference_steps\n        # end diff diff change\n\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps,\n            strength,\n            device,\n            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,\n        )\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        add_noise = True if self.denoising_start is None else False\n\n        # 6. Prepare latent variables\n        if latents is None:\n            latents = self.prepare_latents(\n                init_image,\n                latent_timestep,\n                batch_size,\n                num_images_per_prompt,\n                prompt_embeds.dtype,\n                device,\n                generator,\n                add_noise,\n            )\n\n        # 7. Prepare extra step kwargs.\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        height, width = latents.shape[-2:]\n        height = height * self.vae_scale_factor\n        width = width * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 8. 
Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 9. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # preparations for diff diff\n        original_with_noise = self.prepare_latents(\n            init_image, timesteps, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator\n        )\n        thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps\n        thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)\n        masks = map.squeeze() > thresholds\n        # end diff diff preparations\n\n        # 9.1 Apply denoising_end\n        if (\n            self.denoising_end is not None\n            and self.denoising_start is not None\n            and denoising_value_valid(self.denoising_end)\n            and denoising_value_valid(self.denoising_start)\n            and self.denoising_start >= self.denoising_end\n        ):\n            raise ValueError(\n                f\"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: \"\n                + f\" {self.denoising_end} when using type float.\"\n            )\n        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 9.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # diff diff\n                if i == 0:\n                    latents = original_with_noise[:1]\n                else:\n                    mask = masks[i].unsqueeze(0).to(latents.dtype)\n                    mask = mask.unsqueeze(1)  # fit shape\n                    latents = original_with_noise[i] * mask + latents * (1 - mask)\n                # end diff diff\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                    added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * 
(noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                    negative_add_time_ids = callback_outputs.pop(\"negative_add_time_ids\", negative_add_time_ids)\n\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, 
as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n            elif latents.dtype != self.vae.dtype:\n                if torch.backends.mps.is_available():\n                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                    self.vae = self.vae.to(latents.dtype)\n\n            # unscale/denormalize the latents\n            latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return KolorsPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_kolors_inpainting.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    is_invisible_watermark_available,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\n\nif is_torch_xla_available():\n    import 
torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import KolorsInpaintPipeline\n        >>> from diffusers.utils import load_image\n\n        >>> pipe = KolorsInpaintPipeline.from_pretrained(\n        ...     \"Kwai-Kolors/Kolors-diffusers\",\n        ...     torch_dtype=torch.float16,\n        ...     variant=\"fp16\",\n        ...     use_safetensors=True\n        ... )\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> img_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\n        >>> mask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\n        >>> init_image = load_image(img_url).convert(\"RGB\")\n        >>> mask_image = load_image(mask_url).convert(\"RGB\")\n\n        >>> prompt = \"A majestic tiger sitting on a bench\"\n        >>> image = pipe(\n        ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80\n        ... ).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\ndef mask_pil_to_torch(mask, height, width):\n    # preprocess mask\n    if isinstance(mask, (PIL.Image.Image, np.ndarray)):\n        mask = [mask]\n\n    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):\n        mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]\n        mask = np.concatenate([np.array(m.convert(\"L\"))[None, None, :] for m in mask], axis=0)\n        mask = mask.astype(np.float32) / 255.0\n    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):\n        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)\n\n    mask = torch.from_numpy(mask)\n    return mask\n\n\ndef prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):\n    \"\"\"\n    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be\n    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the\n    ``image`` and ``1`` for the ``mask``.\n\n    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. 
The ``mask`` will be\n    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.\n\n    Args:\n        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``\n            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.\n        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``\n            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.\n\n\n    Raises:\n        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.\n        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.\n        ValueError: ``mask`` and ``image`` should have the same spatial dimensions.\n        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).\n\n    Returns:\n        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4\n            dimensions: ``batch x channels x height x width``.\n    \"\"\"\n\n    # TODO(Yiyi) - need to clean this up later\n    deprecation_message = \"The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. 
Please use VaeImageProcessor.preprocess instead\"\n    deprecate(\n        \"prepare_mask_and_masked_image\",\n        \"0.30.0\",\n        deprecation_message,\n    )\n    if image is None:\n        raise ValueError(\"`image` input cannot be undefined.\")\n\n    if mask is None:\n        raise ValueError(\"`mask_image` input cannot be undefined.\")\n\n    if isinstance(image, torch.Tensor):\n        if not isinstance(mask, torch.Tensor):\n            mask = mask_pil_to_torch(mask, height, width)\n\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        # Batch and add channel dim for single mask\n        if mask.ndim == 2:\n            mask = mask.unsqueeze(0).unsqueeze(0)\n\n        # Batch single mask or add channel dim\n        if mask.ndim == 3:\n            # Single batched mask, no channel dim or single mask not batched but channel dim\n            if mask.shape[0] == 1:\n                mask = mask.unsqueeze(0)\n\n            # Batched masks no channel dim\n            else:\n                mask = mask.unsqueeze(1)\n\n        assert image.ndim == 4 and mask.ndim == 4, \"Image and Mask must have 4 dimensions\"\n        # assert image.shape[-2:] == mask.shape[-2:], \"Image and Mask must have the same spatial dimensions\"\n        assert image.shape[0] == mask.shape[0], \"Image and Mask must have the same batch size\"\n\n        # Check image is in [-1, 1]\n        # if image.min() < -1 or image.max() > 1:\n        #    raise ValueError(\"Image should be in [-1, 1] range\")\n\n        # Check mask is in [0, 1]\n        if mask.min() < 0 or mask.max() > 1:\n            raise ValueError(\"Mask should be in [0, 1] range\")\n\n        # Binarize mask\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n\n        # Image as float32\n        image = image.to(dtype=torch.float32)\n    elif isinstance(mask, torch.Tensor):\n        raise TypeError(f\"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not\")\n    else:\n       
 # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            # resize all images w.r.t. the passed height and width\n            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n        mask = mask_pil_to_torch(mask, height, width)\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n\n    if image.shape[1] == 4:\n        # images are in latent space and thus can't\n        # be masked; set masked_image to None\n        # we assume that the checkpoint is not an inpainting\n        # checkpoint. TODO(Yiyi) - need to clean this up later\n        masked_image = None\n    else:\n        masked_image = image * (mask < 0.5)\n\n    # n.b. 
ensure backwards compatibility as old function does not return image\n    if return_image:\n        return mask, masked_image, image\n\n    return mask, masked_image\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass KolorsInpaintPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    FromSingleFileMixin,\n    IPAdapterMixin,\n):\n    r\"\"\"\n    Pipeline for text-guided image inpainting using Kolors.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.safetensors` files\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`ChatGLMModel`]):\n            Frozen text-encoder. 
Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).\n        tokenizer (`ChatGLMTokenizer`):\n            Tokenizer of class\n            [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        requires_aesthetics_score (`bool`, *optional*, defaults to `\"False\"`):\n            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference.\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `\"True\"`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of\n            `Kwai-Kolors/Kolors-diffusers`.\n        add_watermarker (`bool`, *optional*):\n            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to\n            watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no\n            watermarker will be used.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n\n    _optional_components = [\n        \"tokenizer\",\n        \"text_encoder\",\n        \"image_encoder\",\n        \"feature_extractor\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"add_neg_time_ids\",\n        \"mask\",\n        \"masked_image_latents\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: ChatGLMModel,\n        tokenizer: ChatGLMTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n        requires_aesthetics_score: bool = False,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n            scheduler=scheduler,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, 
do_convert_grayscale=True\n        )\n\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, 
do_classifier_free_guidance\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            repeat_dims = [1]\n            image_embeds = []\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, 
*(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                    single_negative_image_embeds = single_negative_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))\n                    )\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                else:\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                image_embeds.append(single_image_embeds)\n\n        return image_embeds\n\n    def encode_prompt(\n        self,\n        prompt,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer]\n        text_encoders = [self.text_encoder]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=256,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(self._execution_device)\n                output = text_encoder(\n                    input_ids=text_inputs[\"input_ids\"],\n                    attention_mask=text_inputs[\"attention_mask\"],\n                    position_ids=text_inputs[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n                prompt_embeds = 
output.hidden_states[-2].permute(1, 0, 2).clone()\n                pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]\n                bs_embed, seq_len, _ = prompt_embeds.shape\n                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n                prompt_embeds_list.append(prompt_embeds)\n\n            # prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n            prompt_embeds = prompt_embeds_list[0]\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            # negative_prompt = negative_prompt or \"\"\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            negative_prompt_embeds_list = []\n            for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n                # textual inversion: process multi-vector tokens if necessary\n                if isinstance(self, TextualInversionLoaderMixin):\n                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    uncond_tokens,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                ).to(self._execution_device)\n                output = text_encoder(\n                    input_ids=uncond_input[\"input_ids\"],\n                    attention_mask=uncond_input[\"attention_mask\"],\n                    position_ids=uncond_input[\"position_ids\"],\n                    output_hidden_states=True,\n                )\n                negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()\n                negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone()  # [batch_size, 4096]\n\n                if do_classifier_free_guidance:\n                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n                    seq_len = negative_prompt_embeds.shape[1]\n\n                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n                    negative_prompt_embeds = negative_prompt_embeds.view(\n                        batch_size * 
num_images_per_prompt, seq_len, -1\n                    )\n\n                    # For classifier free guidance, we need to do two forward passes.\n                    # Here we concatenate the unconditional and text embeddings into a single batch\n                    # to avoid doing two forward passes\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n            negative_prompt_embeds = negative_prompt_embeds_list[0]\n\n        bs_embed = pooled_prompt_embeds.shape[0]\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n 
           extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        mask_image,\n        height,\n        width,\n        strength,\n        callback_steps,\n        output_type,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        padding_mask_crop=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n        if padding_mask_crop is not None:\n            if not isinstance(image, PIL.Image.Image):\n                raise ValueError(\n                    f\"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}.\"\n                )\n            if not isinstance(mask_image, PIL.Image.Image):\n                raise ValueError(\n                    f\"The mask image should be a PIL image when inpainting mask crop, but is of type\"\n                    f\" {type(mask_image)}.\"\n                )\n            if output_type != \"pil\":\n                raise ValueError(f\"The output type should be PIL when inpainting mask crop, 
but is {output_type}.\")\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n        image=None,\n        timestep=None,\n        is_strength_max=True,\n        add_noise=True,\n        return_noise=False,\n        return_image_latents=False,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if (image is None or timestep is None) and not is_strength_max:\n            raise ValueError(\n                \"Since strength < 1. 
initial latents are to be initialised as a combination of Image + Noise.\"\n                \"However, either the image or the noise timestep has not been provided.\"\n            )\n\n        if image.shape[1] == 4:\n            image_latents = image.to(device=device, dtype=dtype)\n            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n        elif return_image_latents or (latents is None and not is_strength_max):\n            image = image.to(device=device, dtype=dtype)\n            image_latents = self._encode_vae_image(image=image, generator=generator)\n            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n\n        if latents is None and add_noise:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # if strength is 1. then initialise the latents to noise, else initial to image + noise\n            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)\n            # if pure noise then scale the initial latents by the  Scheduler's init sigma\n            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents\n        elif add_noise:\n            noise = latents.to(device)\n            latents = noise * self.scheduler.init_noise_sigma\n        else:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            latents = image_latents.to(device)\n\n        outputs = (latents,)\n\n        if return_noise:\n            outputs += (noise,)\n\n        if return_image_latents:\n            outputs += (image_latents,)\n\n        return outputs\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        dtype = image.dtype\n        if self.vae.config.force_upcast:\n            image = image.float()\n            self.vae.to(dtype=torch.float32)\n\n        if isinstance(generator, list):\n      
      image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        if self.vae.config.force_upcast:\n            self.vae.to(dtype)\n\n        image_latents = image_latents.to(dtype)\n        image_latents = self.vae.config.scaling_factor * image_latents\n\n        return image_latents\n\n    def prepare_mask_latents(\n        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance\n    ):\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(\n            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)\n        )\n        mask = mask.to(device=device, dtype=dtype)\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. 
Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n\n        if masked_image is not None and masked_image.shape[1] == 4:\n            masked_image_latents = masked_image\n        else:\n            masked_image_latents = None\n\n        if masked_image is not None:\n            if masked_image_latents is None:\n                masked_image = masked_image.to(device=device, dtype=dtype)\n                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)\n\n            if masked_image_latents.shape[0] < batch_size:\n                if not batch_size % masked_image_latents.shape[0] == 0:\n                    raise ValueError(\n                        \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\n                        f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                        \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                    )\n                masked_image_latents = masked_image_latents.repeat(\n                    batch_size // masked_image_latents.shape[0], 1, 1, 1\n                )\n\n            masked_image_latents = (\n                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n            )\n\n            # aligning device to prevent device errors when concating it with the latent model input\n            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n\n        return mask, masked_image_latents\n\n    # Copied from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):\n        # get the original timestep using init_timestep\n        if denoising_start is None:\n            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n            t_start = max(num_inference_steps - init_timestep, 0)\n        else:\n            t_start = 0\n\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        # Strength is irrelevant if we directly request a timestep to start at;\n        # that is, strength is determined by the denoising_start instead.\n        if denoising_start is not None:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_start * self.scheduler.config.num_train_timesteps)\n                )\n            )\n\n            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()\n            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:\n                # if the scheduler is a 2nd order scheduler we might have to do +1\n                # because `num_inference_steps` might be even given that every timestep\n                # (except the highest one) is duplicated. If `num_inference_steps` is even it would\n                # mean that we cut the timesteps in the middle of the denoising step\n                # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1\n                # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler\n                num_inference_steps = num_inference_steps + 1\n\n            # because t_n+1 >= t_n, we slice the timesteps starting from the end\n            timesteps = timesteps[-num_inference_steps:]\n            return timesteps, num_inference_steps\n\n        return timesteps, num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids\n    def _get_add_time_ids(\n        self,\n        original_size,\n        crops_coords_top_left,\n        target_size,\n        aesthetic_score,\n        negative_aesthetic_score,\n        negative_original_size,\n        negative_crops_coords_top_left,\n        negative_target_size,\n        dtype,\n        text_encoder_projection_dim=None,\n    ):\n        if self.config.requires_aesthetics_score:\n            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))\n            add_neg_time_ids = list(\n                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)\n            )\n        else:\n            add_time_ids = list(original_size + crops_coords_top_left + target_size)\n            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)\n\n        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if (\n            expected_add_embed_dim > passed_add_embed_dim\n            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} 
was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.\"\n            )\n        elif (\n            expected_add_embed_dim < passed_add_embed_dim\n            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.\"\n            )\n        elif expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)\n\n        return add_time_ids, add_neg_time_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\n    ) -> torch.Tensor:\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            w (`torch.Tensor`):\n                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.\n            embedding_dim (`int`, *optional*, defaults to 512):\n                Dimension of the embeddings to generate.\n            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):\n                Data type of the generated embeddings.\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def denoising_start(self):\n        return self._denoising_start\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        mask_image: PipelineImageInput = None,\n        masked_image_latents: torch.Tensor = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        padding_mask_crop: Optional[int] = None,\n        strength: float = 0.9999,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: 
Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            mask_image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. 
If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            padding_mask_crop (`int`, *optional*, defaults to `None`):\n                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to\n                image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region\n                with the same aspect ratio as the image that contains all masked areas, and then expand that area based\n                on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before\n                resizing to the original image size for inpainting. 
This is useful when the masked area is small while\n                the image is large and contains information irrelevant for inpainting, such as background.\n            strength (`float`, *optional*, defaults to 0.9999):\n                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be\n                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the\n                `strength`. The number of denoising steps depends on the amount of noise initially added. When\n                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of\n                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked\n                portion of the reference `image`. Note that when `denoising_start` is specified,\n                the value of `strength` will be ignored.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            denoising_start (`float`, *optional*):\n                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be\n                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and\n                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,\n                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline\n                is integrated into a \"Mixture of Denoisers\" multi-pipeline setup, as detailed in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be\n                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the\n                final 20% of the scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline\n                forms a part of a \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). 
Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. 
Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            aesthetic_score (`float`, *optional*, defaults to 6.0):\n                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to\n                simulate an aesthetic score of the generated image by influencing the negative text condition.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during inference, with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. 
Check inputs\n        self.check_inputs(\n            prompt,\n            image,\n            mask_image,\n            height,\n            width,\n            strength,\n            callback_steps,\n            output_type,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n            padding_mask_crop,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._denoising_start = denoising_start\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. set timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps,\n            strength,\n            device,\n            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,\n        )\n        # check that number of inference steps is not < 1 - as this doesn't make sense\n        if num_inference_steps < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n            )\n        # at which timestep to set the initial noise (n.b. 
50% if strength is 0.5)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise\n        is_strength_max = strength == 1.0\n\n        # 5. Preprocess mask and image\n        if padding_mask_crop is not None:\n            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)\n            resize_mode = \"fill\"\n        else:\n            crops_coords = None\n            resize_mode = \"default\"\n\n        original_image = image\n        init_image = self.image_processor.preprocess(\n            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode\n        )\n        init_image = init_image.to(dtype=torch.float32)\n\n        mask = self.mask_processor.preprocess(\n            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords\n        )\n\n        if masked_image_latents is not None:\n            masked_image = masked_image_latents\n        elif init_image.shape[1] == 4:\n            # if images are in latent space, we can't mask it\n            masked_image = None\n        else:\n            masked_image = init_image * (mask < 0.5)\n\n        # 6. 
Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        num_channels_unet = self.unet.config.in_channels\n        return_image_latents = num_channels_unet == 4\n\n        add_noise = True if self.denoising_start is None else False\n        latents_outputs = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n            image=init_image,\n            timestep=latent_timestep,\n            is_strength_max=is_strength_max,\n            add_noise=add_noise,\n            return_noise=True,\n            return_image_latents=return_image_latents,\n        )\n\n        if return_image_latents:\n            latents, noise, image_latents = latents_outputs\n        else:\n            latents, noise = latents_outputs\n\n        # 7. Prepare mask latent variables\n        mask, masked_image_latents = self.prepare_mask_latents(\n            mask,\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            self.do_classifier_free_guidance,\n        )\n\n        # 8. Check that sizes of mask, masked image and latents match\n        if num_channels_unet == 9:\n            # default case for stable-diffusion-v1-5/stable-diffusion-inpainting\n            num_channels_mask = mask.shape[1]\n            num_channels_masked_image = masked_image_latents.shape[1]\n            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:\n                raise ValueError(\n                    f\"Incorrect configuration settings! 
The config of `pipeline.unet`: {self.unet.config} expects\"\n                    f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                    f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                    f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                    \" `pipeline.unet` or your `mask_image` or `image` input.\"\n                )\n        elif num_channels_unet != 4:\n            raise ValueError(\n                f\"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.\"\n            )\n        # 8.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 9. Recompute height and width from the latents shape\n        height, width = latents.shape[-2:]\n        height = height * self.vae_scale_factor\n        width = width * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 10. 
Prepare added time ids & embeddings\n        if negative_original_size is None:\n            negative_original_size = original_size\n        if negative_target_size is None:\n            negative_target_size = target_size\n\n        add_text_embeds = pooled_prompt_embeds\n        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n\n        add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            negative_original_size,\n            negative_crops_coords_top_left,\n            negative_target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device)\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 11. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        if (\n            self.denoising_end is not None\n            and self.denoising_start is not None\n            and denoising_value_valid(self.denoising_end)\n            and denoising_value_valid(self.denoising_start)\n            and self.denoising_start >= self.denoising_end\n        ):\n            raise ValueError(\n                f\"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: \"\n                + f\" {self.denoising_end} when using type float.\"\n            )\n        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 11.1 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = 
torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                # scale the model input for the current timestep\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # concat latents, mask, masked_image_latents in the channel dimension\n                if num_channels_unet == 9:\n                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                    added_cond_kwargs[\"image_embeds\"] = image_embeds\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if num_channels_unet == 4:\n                    init_latents_proper = image_latents\n                    if self.do_classifier_free_guidance:\n                        init_mask, _ = mask.chunk(2)\n                    else:\n                        init_mask = mask\n\n                    if i < len(timesteps) - 1:\n                        noise_timestep = timesteps[i + 1]\n                        init_latents_proper = self.scheduler.add_noise(\n                            init_latents_proper, noise, torch.tensor([noise_timestep])\n                        )\n\n                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    
add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                    add_neg_time_ids = callback_outputs.pop(\"add_neg_time_ids\", add_neg_time_ids)\n                    mask = callback_outputs.pop(\"mask\", mask)\n                    masked_image_latents = callback_outputs.pop(\"masked_image_latents\", masked_image_latents)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n            elif latents.dtype != self.vae.dtype:\n                if torch.backends.mps.is_available():\n                    # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                    self.vae = self.vae.to(latents.dtype)\n\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            return StableDiffusionXLPipelineOutput(images=latents)\n\n        # apply watermark if available\n        if self.watermark is not None:\n            image = self.watermark.apply_watermark(image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        if padding_mask_crop is not None:\n            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_null_text_inversion.py",
    "content": "import inspect\nimport os\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as nnf\nfrom PIL import Image\nfrom torch.optim.adam import Adam\nfrom tqdm import tqdm\n\nfrom diffusers import StableDiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps=None,\n    device=None,\n    timesteps=None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass NullTextPipeline(StableDiffusionPipeline):\n    def get_noise_pred(self, latents, t, context):\n        latents_input = torch.cat([latents] * 2)\n        guidance_scale = 7.5\n        noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)[\"sample\"]\n        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)\n        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)\n        latents = self.prev_step(noise_pred, t, latents)\n        return latents\n\n    def get_noise_pred_single(self, latents, t, context):\n        noise_pred = self.unet(latents, t, encoder_hidden_states=context)[\"sample\"]\n        return noise_pred\n\n    @torch.no_grad()\n    def image2latent(self, image_path):\n        image = Image.open(image_path).convert(\"RGB\")\n        image = np.array(image)\n        image = torch.from_numpy(image).float() / 127.5 - 1\n        image = image.permute(2, 0, 1).unsqueeze(0).to(self.device)\n        latents = self.vae.encode(image)[\"latent_dist\"].mean\n        latents = latents * 0.18215\n        return latents\n\n    @torch.no_grad()\n    def latent2image(self, latents):\n        latents = 1 / 0.18215 * latents.detach()\n        image = self.vae.decode(latents)[\"sample\"].detach()\n        image = self.processor.postprocess(image, output_type=\"pil\")[0]\n        return image\n\n    def prev_step(self, model_output, timestep, sample):\n        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps\n        
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]\n        alpha_prod_t_prev = (\n            self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod\n        )\n        beta_prod_t = 1 - alpha_prod_t\n        pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output\n        prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction\n        return prev_sample\n\n    def next_step(self, model_output, timestep, sample):\n        timestep, next_timestep = (\n            min(timestep - self.scheduler.config.num_train_timesteps // self.num_inference_steps, 999),\n            timestep,\n        )\n        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod\n        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]\n        beta_prod_t = 1 - alpha_prod_t\n        next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output\n        next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction\n        return next_sample\n\n    def null_optimization(self, latents, context, num_inner_steps, epsilon):\n        uncond_embeddings, cond_embeddings = context.chunk(2)\n        uncond_embeddings_list = []\n        latent_cur = latents[-1]\n        bar = tqdm(total=num_inner_steps * self.num_inference_steps)\n        for i in range(self.num_inference_steps):\n            uncond_embeddings = uncond_embeddings.clone().detach()\n            uncond_embeddings.requires_grad = True\n            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1.0 - i / 100.0))\n            latent_prev = latents[len(latents) - i - 2]\n            t = self.scheduler.timesteps[i]\n            with torch.no_grad():\n         
       noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)\n            for j in range(num_inner_steps):\n                noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)\n                noise_pred = noise_pred_uncond + 7.5 * (noise_pred_cond - noise_pred_uncond)\n                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)\n                loss = nnf.mse_loss(latents_prev_rec, latent_prev)\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n                loss_item = loss.item()\n                bar.update()\n                if loss_item < epsilon + i * 2e-5:\n                    break\n            for j in range(j + 1, num_inner_steps):\n                bar.update()\n            uncond_embeddings_list.append(uncond_embeddings[:1].detach())\n            with torch.no_grad():\n                context = torch.cat([uncond_embeddings, cond_embeddings])\n                latent_cur = self.get_noise_pred(latent_cur, t, context)\n        bar.close()\n        return uncond_embeddings_list\n\n    @torch.no_grad()\n    def ddim_inversion_loop(self, latent, context):\n        self.scheduler.set_timesteps(self.num_inference_steps)\n        _, cond_embeddings = context.chunk(2)\n        all_latent = [latent]\n        latent = latent.clone().detach()\n        with torch.no_grad():\n            for i in range(0, self.num_inference_steps):\n                t = self.scheduler.timesteps[len(self.scheduler.timesteps) - i - 1]\n                noise_pred = self.unet(latent, t, encoder_hidden_states=cond_embeddings)[\"sample\"]\n                latent = self.next_step(noise_pred, t, latent)\n                all_latent.append(latent)\n        return all_latent\n\n    def get_context(self, prompt):\n        uncond_input = self.tokenizer(\n            [\"\"], padding=\"max_length\", max_length=self.tokenizer.model_max_length, return_tensors=\"pt\"\n        )\n  
      uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n        text_input = self.tokenizer(\n            [prompt],\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]\n        context = torch.cat([uncond_embeddings, text_embeddings])\n        return context\n\n    def invert(\n        self, image_path: str, prompt: str, num_inner_steps=10, early_stop_epsilon=1e-6, num_inference_steps=50\n    ):\n        self.num_inference_steps = num_inference_steps\n        context = self.get_context(prompt)\n        latent = self.image2latent(image_path)\n        ddim_latents = self.ddim_inversion_loop(latent, context)\n        if os.path.exists(image_path + \".pt\"):\n            uncond_embeddings = torch.load(image_path + \".pt\")\n        else:\n            uncond_embeddings = self.null_optimization(ddim_latents, context, num_inner_steps, early_stop_epsilon)\n            uncond_embeddings = torch.stack(uncond_embeddings, 0)\n            torch.save(uncond_embeddings, image_path + \".pt\")\n        return ddim_latents[-1], uncond_embeddings\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt,\n        uncond_embeddings,\n        inverted_latent,\n        num_inference_steps: int = 50,\n        timesteps=None,\n        guidance_scale=7.5,\n        negative_prompt=None,\n        num_images_per_prompt=1,\n        generator=None,\n        latents=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        output_type=\"pil\",\n    ):\n        self._guidance_scale = guidance_scale\n        # 0. 
Default height and width to unet\n        height = self.unet.config.sample_size * self.vae_scale_factor\n        width = self.unet.config.sample_size * self.vae_scale_factor\n        # to deal with lora scaling and other possible forward hook\n        callback_steps = None\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n        )\n        # 2. Define call parameter\n        device = self._execution_device\n        # 3. Encode input prompt\n        prompt_embeds, _ = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        latents = inverted_latent\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=uncond_embeddings[i])[\"sample\"]\n                noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[\"sample\"]\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n                progress_bar.update()\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            
]\n        else:\n            image = latents\n        image = self.image_processor.postprocess(\n            image, output_type=output_type, do_denormalize=[True] * image.shape[0]\n        )\n        # Offload all models\n        self.maybe_free_model_hooks()\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=False)\n"
  },
  {
    "path": "examples/community/pipeline_prompt2prompt.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport abc\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom packaging import version\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel\nfrom diffusers.configuration_utils import FrozenDict, deprecate\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models.attention import Attention\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import (\n    StableDiffusionSafetyChecker,\n)\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    logging,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)\n\n\n# Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\nclass Prompt2PromptPipeline(\n    DiffusionPipeline,\n    TextualInversionLoaderMixin,\n    StableDiffusionLoraLoaderMixin,\n    IPAdapterMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. 
Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device),\n                    attention_mask=attention_mask,\n                    output_hidden_states=True,\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. 
Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = 
\"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: Optional[int] = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n                The keyword arguments to configure the edit are:\n                - edit_type (`str`): The edit type to apply. One of `replace`, `refine`, or `reweight`.\n                - n_cross_replace (`float`, *optional*, defaults to 0.4): Fraction of the diffusion steps during\n                  which cross attention is replaced.\n                - n_self_replace (`float`, *optional*, defaults to 0.4): Fraction of the diffusion steps during\n                  which self attention is replaced.\n                - local_blend_words (`List[str]`, *optional*, defaults to `None`): Determines which area should be\n                  changed. If `None`, the whole image can be changed.\n                - equalizer_words (`List[str]`, *optional*, defaults to `None`): Required for edit type `reweight`.\n                  Determines which words should be enhanced.\n                - equalizer_strengths (`List[float]`, *optional*, defaults to `None`): Required for edit type\n                  `reweight`. Determines how much the words in `equalizer_words` should be enhanced.\n\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). 
Guidance rescale factor should fix overexposure when\n                using zero terminal SNR.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n\n        self.controller = create_controller(\n            prompt,\n            cross_attention_kwargs,\n            num_inference_steps,\n            tokenizer=self.tokenizer,\n            device=self.device,\n        )\n        self.register_attention_control(self.controller)  # add attention controller\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(prompt, height, width, callback_steps)\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # step callback\n                latents = self.controller.step_callback(latents)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # 8. 
Post-processing\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        # 9. Run safety checker\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n    def register_attention_control(self, controller):\n        attn_procs = {}\n        cross_att_count = 0\n        for name in self.unet.attn_processors.keys():\n            (None if name.endswith(\"attn1.processor\") else self.unet.config.cross_attention_dim)\n            if name.startswith(\"mid_block\"):\n                self.unet.config.block_out_channels[-1]\n                place_in_unet = \"mid\"\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                list(reversed(self.unet.config.block_out_channels))[block_id]\n                place_in_unet = \"up\"\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                self.unet.config.block_out_channels[block_id]\n                place_in_unet = \"down\"\n            else:\n                continue\n            cross_att_count += 1\n            attn_procs[name] = 
P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet)\n\n        self.unet.set_attn_processor(attn_procs)\n        controller.num_att_layers = cross_att_count\n\n\nclass P2PCrossAttnProcessor:\n    def __init__(self, controller, place_in_unet):\n        super().__init__()\n        self.controller = controller\n        self.place_in_unet = place_in_unet\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n    ):\n        batch_size, sequence_length, _ = hidden_states.shape\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        query = attn.to_q(hidden_states)\n\n        is_cross = encoder_hidden_states is not None\n        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n\n        # one line change\n        self.controller(attention_probs, is_cross, self.place_in_unet)\n\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        return hidden_states\n\n\ndef create_controller(\n    prompts: List[str],\n    cross_attention_kwargs: Dict,\n    num_inference_steps: int,\n    tokenizer,\n    device,\n) -> AttentionControl:\n    edit_type = cross_attention_kwargs.get(\"edit_type\", None)\n    local_blend_words = cross_attention_kwargs.get(\"local_blend_words\", None)\n    equalizer_words = 
cross_attention_kwargs.get(\"equalizer_words\", None)\n    equalizer_strengths = cross_attention_kwargs.get(\"equalizer_strengths\", None)\n    n_cross_replace = cross_attention_kwargs.get(\"n_cross_replace\", 0.4)\n    n_self_replace = cross_attention_kwargs.get(\"n_self_replace\", 0.4)\n\n    # only replace\n    if edit_type == \"replace\" and local_blend_words is None:\n        return AttentionReplace(\n            prompts,\n            num_inference_steps,\n            n_cross_replace,\n            n_self_replace,\n            tokenizer=tokenizer,\n            device=device,\n        )\n\n    # replace + localblend\n    if edit_type == \"replace\" and local_blend_words is not None:\n        lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)\n        return AttentionReplace(\n            prompts,\n            num_inference_steps,\n            n_cross_replace,\n            n_self_replace,\n            lb,\n            tokenizer=tokenizer,\n            device=device,\n        )\n\n    # only refine\n    if edit_type == \"refine\" and local_blend_words is None:\n        return AttentionRefine(\n            prompts,\n            num_inference_steps,\n            n_cross_replace,\n            n_self_replace,\n            tokenizer=tokenizer,\n            device=device,\n        )\n\n    # refine + localblend\n    if edit_type == \"refine\" and local_blend_words is not None:\n        lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)\n        return AttentionRefine(\n            prompts,\n            num_inference_steps,\n            n_cross_replace,\n            n_self_replace,\n            lb,\n            tokenizer=tokenizer,\n            device=device,\n        )\n\n    # reweight\n    if edit_type == \"reweight\":\n        assert equalizer_words is not None and equalizer_strengths is not None, (\n            \"To use reweight edit, please specify equalizer_words and equalizer_strengths.\"\n        )\n      
  assert len(equalizer_words) == len(equalizer_strengths), (\n            \"equalizer_words and equalizer_strengths must be of same length.\"\n        )\n        equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)\n        return AttentionReweight(\n            prompts,\n            num_inference_steps,\n            n_cross_replace,\n            n_self_replace,\n            tokenizer=tokenizer,\n            device=device,\n            equalizer=equalizer,\n        )\n\n    raise ValueError(f\"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight.\")\n\n\nclass AttentionControl(abc.ABC):\n    def step_callback(self, x_t):\n        return x_t\n\n    def between_steps(self):\n        return\n\n    @property\n    def num_uncond_att_layers(self):\n        return 0\n\n    @abc.abstractmethod\n    def forward(self, attn, is_cross: bool, place_in_unet: str):\n        raise NotImplementedError\n\n    def __call__(self, attn, is_cross: bool, place_in_unet: str):\n        if self.cur_att_layer >= self.num_uncond_att_layers:\n            h = attn.shape[0]\n            attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)\n        self.cur_att_layer += 1\n        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:\n            self.cur_att_layer = 0\n            self.cur_step += 1\n            self.between_steps()\n        return attn\n\n    def reset(self):\n        self.cur_step = 0\n        self.cur_att_layer = 0\n\n    def __init__(self):\n        self.cur_step = 0\n        self.num_att_layers = -1\n        self.cur_att_layer = 0\n\n\nclass EmptyControl(AttentionControl):\n    def forward(self, attn, is_cross: bool, place_in_unet: str):\n        return attn\n\n\nclass AttentionStore(AttentionControl):\n    @staticmethod\n    def get_empty_store():\n        return {\n            \"down_cross\": [],\n            \"mid_cross\": [],\n            \"up_cross\": [],\n         
   \"down_self\": [],\n            \"mid_self\": [],\n            \"up_self\": [],\n        }\n\n    def forward(self, attn, is_cross: bool, place_in_unet: str):\n        key = f\"{place_in_unet}_{'cross' if is_cross else 'self'}\"\n        if attn.shape[1] <= 32**2:  # avoid memory overhead\n            self.step_store[key].append(attn)\n        return attn\n\n    def between_steps(self):\n        if len(self.attention_store) == 0:\n            self.attention_store = self.step_store\n        else:\n            for key in self.attention_store:\n                for i in range(len(self.attention_store[key])):\n                    self.attention_store[key][i] += self.step_store[key][i]\n        self.step_store = self.get_empty_store()\n\n    def get_average_attention(self):\n        average_attention = {\n            key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store\n        }\n        return average_attention\n\n    def reset(self):\n        super(AttentionStore, self).reset()\n        self.step_store = self.get_empty_store()\n        self.attention_store = {}\n\n    def __init__(self):\n        super(AttentionStore, self).__init__()\n        self.step_store = self.get_empty_store()\n        self.attention_store = {}\n\n\nclass LocalBlend:\n    def __call__(self, x_t, attention_store):\n        k = 1\n        maps = attention_store[\"down_cross\"][2:4] + attention_store[\"up_cross\"][:3]\n        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words) for item in maps]\n        maps = torch.cat(maps, dim=1)\n        maps = (maps * self.alpha_layers).sum(-1).mean(1)\n        mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))\n        mask = F.interpolate(mask, size=(x_t.shape[2:]))\n        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]\n        mask = mask.gt(self.threshold)\n        mask = (mask[:1] + mask[1:]).float()\n        x_t = x_t[:1] + 
mask * (x_t - x_t[:1])\n        return x_t\n\n    def __init__(\n        self,\n        prompts: List[str],\n        words: List[Union[str, List[str]]],\n        tokenizer,\n        device,\n        threshold=0.3,\n        max_num_words=77,\n    ):\n        self.max_num_words = max_num_words\n\n        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)\n        for i, (prompt, words_) in enumerate(zip(prompts, words)):\n            if isinstance(words_, str):\n                words_ = [words_]\n            for word in words_:\n                ind = get_word_inds(prompt, word, tokenizer)\n                alpha_layers[i, :, :, :, :, ind] = 1\n        self.alpha_layers = alpha_layers.to(device)\n        self.threshold = threshold\n\n\nclass AttentionControlEdit(AttentionStore, abc.ABC):\n    def step_callback(self, x_t):\n        if self.local_blend is not None:\n            x_t = self.local_blend(x_t, self.attention_store)\n        return x_t\n\n    def replace_self_attention(self, attn_base, att_replace):\n        if att_replace.shape[2] <= 16**2:\n            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)\n        else:\n            return att_replace\n\n    @abc.abstractmethod\n    def replace_cross_attention(self, attn_base, att_replace):\n        raise NotImplementedError\n\n    def forward(self, attn, is_cross: bool, place_in_unet: str):\n        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)\n        # FIXME not replace correctly\n        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):\n            h = attn.shape[0] // (self.batch_size)\n            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])\n            attn_base, attn_repalce = attn[0], attn[1:]\n            if is_cross:\n                alpha_words = self.cross_replace_alpha[self.cur_step]\n                attn_repalce_new = (\n                    self.replace_cross_attention(attn_base, 
attn_repalce) * alpha_words\n                    + (1 - alpha_words) * attn_repalce\n                )\n                attn[1:] = attn_repalce_new\n            else:\n                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)\n            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])\n        return attn\n\n    def __init__(\n        self,\n        prompts,\n        num_steps: int,\n        cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],\n        self_replace_steps: Union[float, Tuple[float, float]],\n        local_blend: Optional[LocalBlend],\n        tokenizer,\n        device,\n    ):\n        super(AttentionControlEdit, self).__init__()\n        # add tokenizer and device here\n\n        self.tokenizer = tokenizer\n        self.device = device\n\n        self.batch_size = len(prompts)\n        self.cross_replace_alpha = get_time_words_attention_alpha(\n            prompts, num_steps, cross_replace_steps, self.tokenizer\n        ).to(self.device)\n        if isinstance(self_replace_steps, float):\n            self_replace_steps = 0, self_replace_steps\n        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])\n        self.local_blend = local_blend  # defined outside and passed in\n\n\nclass AttentionReplace(AttentionControlEdit):\n    def replace_cross_attention(self, attn_base, att_replace):\n        return torch.einsum(\"hpw,bwn->bhpn\", attn_base, self.mapper)\n\n    def __init__(\n        self,\n        prompts,\n        num_steps: int,\n        cross_replace_steps: float,\n        self_replace_steps: float,\n        local_blend: Optional[LocalBlend] = None,\n        tokenizer=None,\n        device=None,\n    ):\n        super(AttentionReplace, self).__init__(\n            prompts,\n            num_steps,\n            cross_replace_steps,\n            self_replace_steps,\n            local_blend,\n            tokenizer,\n            device,\n        
)\n        self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)\n\n\nclass AttentionRefine(AttentionControlEdit):\n    def replace_cross_attention(self, attn_base, att_replace):\n        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)\n        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)\n        return attn_replace\n\n    def __init__(\n        self,\n        prompts,\n        num_steps: int,\n        cross_replace_steps: float,\n        self_replace_steps: float,\n        local_blend: Optional[LocalBlend] = None,\n        tokenizer=None,\n        device=None,\n    ):\n        super(AttentionRefine, self).__init__(\n            prompts,\n            num_steps,\n            cross_replace_steps,\n            self_replace_steps,\n            local_blend,\n            tokenizer,\n            device,\n        )\n        self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)\n        self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)\n        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])\n\n\nclass AttentionReweight(AttentionControlEdit):\n    def replace_cross_attention(self, attn_base, att_replace):\n        if self.prev_controller is not None:\n            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)\n        attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]\n        return attn_replace\n\n    def __init__(\n        self,\n        prompts,\n        num_steps: int,\n        cross_replace_steps: float,\n        self_replace_steps: float,\n        equalizer,\n        local_blend: Optional[LocalBlend] = None,\n        controller: Optional[AttentionControlEdit] = None,\n        tokenizer=None,\n        device=None,\n    ):\n        super(AttentionReweight, self).__init__(\n            prompts,\n            num_steps,\n            cross_replace_steps,\n            
self_replace_steps,\n            local_blend,\n            tokenizer,\n            device,\n        )\n        self.equalizer = equalizer.to(self.device)\n        self.prev_controller = controller\n\n\n### util functions for all Edits\ndef update_alpha_time_word(\n    alpha,\n    bounds: Union[float, Tuple[float, float]],\n    prompt_ind: int,\n    word_inds: Optional[torch.Tensor] = None,\n):\n    if isinstance(bounds, float):\n        bounds = 0, bounds\n    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])\n    if word_inds is None:\n        word_inds = torch.arange(alpha.shape[2])\n    alpha[:start, prompt_ind, word_inds] = 0\n    alpha[start:end, prompt_ind, word_inds] = 1\n    alpha[end:, prompt_ind, word_inds] = 0\n    return alpha\n\n\ndef get_time_words_attention_alpha(\n    prompts,\n    num_steps,\n    cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],\n    tokenizer,\n    max_num_words=77,\n):\n    if not isinstance(cross_replace_steps, dict):\n        cross_replace_steps = {\"default_\": cross_replace_steps}\n    if \"default_\" not in cross_replace_steps:\n        cross_replace_steps[\"default_\"] = (0.0, 1.0)\n    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)\n    for i in range(len(prompts) - 1):\n        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps[\"default_\"], i)\n    for key, item in cross_replace_steps.items():\n        if key != \"default_\":\n            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]\n            for i, ind in enumerate(inds):\n                if len(ind) > 0:\n                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)\n    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)\n    return alpha_time_words\n\n\n### util functions for LocalBlend and ReplacementEdit\ndef get_word_inds(text: str, 
word_place: int, tokenizer):\n    split_text = text.split(\" \")\n    if isinstance(word_place, str):\n        word_place = [i for i, word in enumerate(split_text) if word_place == word]\n    elif isinstance(word_place, int):\n        word_place = [word_place]\n    out = []\n    if len(word_place) > 0:\n        words_encode = [tokenizer.decode([item]).strip(\"#\") for item in tokenizer.encode(text)][1:-1]\n        cur_len, ptr = 0, 0\n\n        for i in range(len(words_encode)):\n            cur_len += len(words_encode[i])\n            if ptr in word_place:\n                out.append(i + 1)\n            if cur_len >= len(split_text[ptr]):\n                ptr += 1\n                cur_len = 0\n    return np.array(out)\n\n\n### util functions for ReplacementEdit\ndef get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):\n    words_x = x.split(\" \")\n    words_y = y.split(\" \")\n    if len(words_x) != len(words_y):\n        raise ValueError(\n            f\"attention replacement edit can only be applied on prompts with the same length\"\n            f\" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.\"\n        )\n    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]\n    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]\n    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]\n    mapper = np.zeros((max_len, max_len))\n    i = j = 0\n    cur_inds = 0\n    while i < max_len and j < max_len:\n        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:\n            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]\n            if len(inds_source_) == len(inds_target_):\n                mapper[inds_source_, inds_target_] = 1\n            else:\n                ratio = 1 / len(inds_target_)\n                for i_t in inds_target_:\n                    mapper[inds_source_, i_t] = ratio\n            cur_inds += 1\n            i += 
len(inds_source_)\n            j += len(inds_target_)\n        elif cur_inds < len(inds_source):\n            mapper[i, j] = 1\n            i += 1\n            j += 1\n        else:\n            mapper[j, j] = 1\n            i += 1\n            j += 1\n\n    return torch.from_numpy(mapper).float()\n\n\ndef get_replacement_mapper(prompts, tokenizer, max_len=77):\n    x_seq = prompts[0]\n    mappers = []\n    for i in range(1, len(prompts)):\n        mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)\n        mappers.append(mapper)\n    return torch.stack(mappers)\n\n\n### util functions for ReweightEdit\ndef get_equalizer(\n    text: str,\n    word_select: Union[int, Tuple[int, ...]],\n    values: Union[List[float], Tuple[float, ...]],\n    tokenizer,\n):\n    if isinstance(word_select, (int, str)):\n        word_select = (word_select,)\n    equalizer = torch.ones(len(values), 77)\n    values = torch.tensor(values, dtype=torch.float32)\n    for word in word_select:\n        inds = get_word_inds(text, word, tokenizer)\n        equalizer[:, inds] = values\n    return equalizer\n\n\n### util functions for RefinementEdit\nclass ScoreParams:\n    def __init__(self, gap, match, mismatch):\n        self.gap = gap\n        self.match = match\n        self.mismatch = mismatch\n\n    def mis_match_char(self, x, y):\n        if x != y:\n            return self.mismatch\n        else:\n            return self.match\n\n\ndef get_matrix(size_x, size_y, gap):\n    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)\n    matrix[0, 1:] = (np.arange(size_y) + 1) * gap\n    matrix[1:, 0] = (np.arange(size_x) + 1) * gap\n    return matrix\n\n\ndef get_traceback_matrix(size_x, size_y):\n    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)\n    matrix[0, 1:] = 1\n    matrix[1:, 0] = 2\n    matrix[0, 0] = 4\n    return matrix\n\n\ndef global_align(x, y, score):\n    matrix = get_matrix(len(x), len(y), score.gap)\n    trace_back = 
get_traceback_matrix(len(x), len(y))\n    for i in range(1, len(x) + 1):\n        for j in range(1, len(y) + 1):\n            left = matrix[i, j - 1] + score.gap\n            up = matrix[i - 1, j] + score.gap\n            diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])\n            matrix[i, j] = max(left, up, diag)\n            if matrix[i, j] == left:\n                trace_back[i, j] = 1\n            elif matrix[i, j] == up:\n                trace_back[i, j] = 2\n            else:\n                trace_back[i, j] = 3\n    return matrix, trace_back\n\n\ndef get_aligned_sequences(x, y, trace_back):\n    x_seq = []\n    y_seq = []\n    i = len(x)\n    j = len(y)\n    mapper_y_to_x = []\n    while i > 0 or j > 0:\n        if trace_back[i, j] == 3:\n            x_seq.append(x[i - 1])\n            y_seq.append(y[j - 1])\n            i = i - 1\n            j = j - 1\n            mapper_y_to_x.append((j, i))\n        elif trace_back[i][j] == 1:\n            x_seq.append(\"-\")\n            y_seq.append(y[j - 1])\n            j = j - 1\n            mapper_y_to_x.append((j, -1))\n        elif trace_back[i][j] == 2:\n            x_seq.append(x[i - 1])\n            y_seq.append(\"-\")\n            i = i - 1\n        elif trace_back[i][j] == 4:\n            break\n    mapper_y_to_x.reverse()\n    return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)\n\n\ndef get_mapper(x: str, y: str, tokenizer, max_len=77):\n    x_seq = tokenizer.encode(x)\n    y_seq = tokenizer.encode(y)\n    score = ScoreParams(0, 1, -1)\n    matrix, trace_back = global_align(x_seq, y_seq, score)\n    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]\n    alphas = torch.ones(max_len)\n    alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()\n    mapper = torch.zeros(max_len, dtype=torch.int64)\n    mapper[: mapper_base.shape[0]] = mapper_base[:, 1]\n    mapper[mapper_base.shape[0] :] = len(y_seq) + torch.arange(max_len - len(y_seq))\n    
return mapper, alphas\n\n\ndef get_refinement_mapper(prompts, tokenizer, max_len=77):\n    x_seq = prompts[0]\n    mappers, alphas = [], []\n    for i in range(1, len(prompts)):\n        mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)\n        mappers.append(mapper)\n        alphas.append(alpha)\n    return torch.stack(mappers), torch.stack(alphas)\n"
  },
  {
    "path": "examples/community/pipeline_sdxl_style_aligned.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# Based on [Style Aligned Image Generation via Shared Attention](https://huggingface.co/papers/2312.02133).\n# Authors: Amir Hertz, Andrey Voynov, Shlomi Fruchter, Daniel Cohen-Or\n# Project Page: https://style-aligned-gen.github.io/\n# Code: https://github.com/google/style-aligned\n#\n# Adapted to Diffusers by [Aryan V S](https://github.com/a-r-r-o-w/).\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.attention_processor import (\n    Attention,\n    AttnProcessor2_0,\n)\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import 
StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_invisible_watermark_available,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> from typing import List\n\n        >>> import torch\n        >>> from diffusers.pipelines.pipeline_utils import DiffusionPipeline\n        >>> from PIL import Image\n\n        >>> model_id = \"a-r-r-o-w/dreamshaper-xl-turbo\"\n        >>> pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant=\"fp16\", custom_pipeline=\"pipeline_sdxl_style_aligned\")\n        >>> pipe = pipe.to(\"cuda\")\n\n        # Enable memory saving techniques\n        >>> pipe.enable_vae_slicing()\n        >>> pipe.enable_vae_tiling()\n\n        >>> prompt = [\n        ...     \"a toy train. macro photo. 3d game asset\",\n        ...     \"a toy airplane. macro photo. 3d game asset\",\n        ...     \"a toy bicycle. macro photo. 3d game asset\",\n        ...     \"a toy car. macro photo. 3d game asset\",\n        ... ]\n        >>> negative_prompt = \"low quality, worst quality, \"\n\n        >>> # Enable StyleAligned\n        >>> pipe.enable_style_aligned(\n        ...     share_group_norm=False,\n        ...     share_layer_norm=False,\n        ...     share_attention=True,\n        ...     adain_queries=True,\n        ...     adain_keys=True,\n        ...     
adain_values=False,\n        ...     full_attention_share=False,\n        ...     shared_score_scale=1.0,\n        ...     shared_score_shift=0.0,\n        ...     only_self_level=0.0,\n        ... )\n\n        >>> # Run inference\n        >>> images = pipe(\n        ...     prompt=prompt,\n        ...     negative_prompt=negative_prompt,\n        ...     guidance_scale=2,\n        ...     height=1024,\n        ...     width=1024,\n        ...     num_inference_steps=10,\n        ...     generator=torch.Generator().manual_seed(42),\n        ... ).images\n\n        >>> # Disable StyleAligned if you do not wish to use it anymore\n        >>> pipe.disable_style_aligned()\n        ```\n\"\"\"\n\n\ndef expand_first(feat: torch.Tensor, scale: float = 1.0) -> torch.Tensor:\n    b = feat.shape[0]\n    feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)\n    if scale == 1:\n        feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])\n    else:\n        feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)\n        feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)\n    return feat_style.reshape(*feat.shape)\n\n\ndef concat_first(feat: torch.Tensor, dim: int = 2, scale: float = 1.0) -> torch.Tensor:\n    feat_style = expand_first(feat, scale=scale)\n    return torch.cat((feat, feat_style), dim=dim)\n\n\ndef calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:\n    feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()\n    feat_mean = feat.mean(dim=-2, keepdims=True)\n    return feat_mean, feat_std\n\n\ndef adain(feat: torch.Tensor) -> torch.Tensor:\n    feat_mean, feat_std = calc_mean_std(feat)\n    feat_style_mean = expand_first(feat_mean)\n    feat_style_std = expand_first(feat_std)\n    feat = (feat - feat_mean) / feat_std\n    feat = feat * feat_style_std + feat_style_mean\n    return feat\n\n\ndef get_switch_vec(total_num_layers, level):\n    if level == 0:\n        return 
torch.zeros(total_num_layers, dtype=torch.bool)\n    if level == 1:\n        return torch.ones(total_num_layers, dtype=torch.bool)\n    to_flip = level > 0.5\n    if to_flip:\n        level = 1 - level\n    num_switch = int(level * total_num_layers)\n    vec = torch.arange(total_num_layers)\n    vec = vec % (total_num_layers // num_switch)\n    vec = vec == 0\n    if to_flip:\n        vec = ~vec\n    return vec\n\n\nclass SharedAttentionProcessor(AttnProcessor2_0):\n    def __init__(\n        self,\n        share_attention: bool = True,\n        adain_queries: bool = True,\n        adain_keys: bool = True,\n        adain_values: bool = False,\n        full_attention_share: bool = False,\n        shared_score_scale: float = 1.0,\n        shared_score_shift: float = 0.0,\n    ):\n        r\"\"\"Shared Attention Processor as proposed in the StyleAligned paper.\"\"\"\n        super().__init__()\n        self.share_attention = share_attention\n        self.adain_queries = adain_queries\n        self.adain_keys = adain_keys\n        self.adain_values = adain_values\n        self.full_attention_share = full_attention_share\n        self.shared_score_scale = shared_score_scale\n        self.shared_score_shift = shared_score_shift\n\n    def shifted_scaled_dot_product_attention(\n        self, attn: Attention, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor\n    ) -> torch.Tensor:\n        logits = torch.einsum(\"bhqd,bhkd->bhqk\", query, key) * attn.scale\n        logits[:, :, :, query.shape[2] :] += self.shared_score_shift\n        probs = logits.softmax(-1)\n        return torch.einsum(\"bhqk,bhkd->bhqd\", probs, value)\n\n    def shared_call(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        **kwargs,\n    ):\n        residual = hidden_states\n        input_ndim = hidden_states.ndim\n        if 
input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n        key = attn.to_k(hidden_states)\n        value = attn.to_v(hidden_states)\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        if self.adain_queries:\n            query = adain(query)\n        if self.adain_keys:\n            key = adain(key)\n        if self.adain_values:\n            value = adain(value)\n        if self.share_attention:\n            key = concat_first(key, -2, scale=self.shared_score_scale)\n            value = concat_first(value, -2)\n            if self.shared_score_shift != 0:\n                hidden_states = self.shifted_scaled_dot_product_attention(attn, query, key, value)\n            else:\n                hidden_states = F.scaled_dot_product_attention(\n                    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n               
 )\n        else:\n            hidden_states = F.scaled_dot_product_attention(\n                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n            )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n        return hidden_states\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        **kwargs,\n    ):\n        if self.full_attention_share:\n            b, n, d = hidden_states.shape\n            k = 2\n            # Equivalent to einops.rearrange(hidden_states, \"(k b) n d -> k (b n) d\", k=2)\n            hidden_states = hidden_states.reshape(k, (b // k) * n, d)\n            hidden_states = super().__call__(\n                attn,\n                hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                **kwargs,\n            )\n            # Equivalent to einops.rearrange(hidden_states, \"k (b n) d -> (k b) n d\", n=n)\n            hidden_states = hidden_states.reshape(b, n, d)\n        else:\n            hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs)\n\n        return 
hidden_states\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access 
latents of provided encoder_output\")\n\n\nclass StyleAlignedSDXLPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    FromSingleFileMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n    IPAdapterMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion XL.\n\n    This pipeline also adds experimental support for [StyleAligned](https://huggingface.co/papers/2312.02133). It can\n    be enabled/disabled using `.enable_style_aligned()` or `.disable_style_aligned()` respectively.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion XL uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`CLIPTextModelWithProjection`]):\n            Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of\n            `stabilityai/stable-diffusion-xl-base-1-0`.\n        add_watermarker (`bool`, *optional*):\n            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to\n            watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no\n            watermarker will be used.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->image_encoder->unet->vae\"\n    _optional_components = [\n        \"tokenizer\",\n        \"tokenizer_2\",\n        \"text_encoder\",\n        \"text_encoder_2\",\n        \"image_encoder\",\n        \"feature_extractor\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"negative_add_time_ids\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            scheduler=scheduler,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.mask_processor = VaeImageProcessor(\n            
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True\n        )\n\n        self.default_sample_size = (\n            self.unet.config.sample_size\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\n            else 128\n        )\n\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder, lora_scale)\n\n            if self.text_encoder_2 is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if 
self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n\n                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\n                    pooled_prompt_embeds = prompt_embeds[0]\n\n                if clip_skip is None:\n     
               prompt_embeds = prompt_embeds.hidden_states[-2]\n                else:\n                    # \"2\" because SDXL always indexes from the penultimate layer.\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        if self.text_encoder_2 is not None:\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        else:\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            if self.text_encoder_2 is not None:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            else:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in 
set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. 
Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):\n        # get the original timestep using init_timestep\n        if denoising_start is None:\n            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n            t_start = max(num_inference_steps - init_timestep, 0)\n        else:\n            t_start = 0\n\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        # Strength is irrelevant if we directly request a timestep to start at;\n        # that is, strength is determined by the denoising_start instead.\n        if denoising_start is not None:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_start * self.scheduler.config.num_train_timesteps)\n                )\n            )\n\n            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()\n            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:\n                # if the scheduler is a 2nd order scheduler we might have to do +1\n                # because `num_inference_steps` might be even given that every timestep\n                # (except the highest one) is duplicated. If `num_inference_steps` is even it would\n                # mean that we cut the timesteps in the middle of the denoising step\n                # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1\n                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler\n                num_inference_steps = num_inference_steps + 1\n\n            # because t_n+1 >= t_n, we slice the timesteps starting from the end\n            timesteps = timesteps[-num_inference_steps:]\n            return timesteps, num_inference_steps\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(\n        self,\n        image,\n        mask,\n        width,\n        height,\n        num_channels_latents,\n        timestep,\n        batch_size,\n        num_images_per_prompt,\n        dtype,\n        device,\n        generator=None,\n        add_noise=True,\n        latents=None,\n        is_strength_max=True,\n        return_noise=False,\n        return_image_latents=False,\n    ):\n        batch_size *= num_images_per_prompt\n\n        if image is None:\n            shape = (\n                batch_size,\n                num_channels_latents,\n                int(height) // self.vae_scale_factor,\n                int(width) // self.vae_scale_factor,\n            )\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            if latents is None:\n                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            else:\n                latents = latents.to(device)\n\n            # scale the initial noise by the standard deviation required by the scheduler\n            latents = latents * self.scheduler.init_noise_sigma\n            return latents\n\n        elif mask is None:\n            if not isinstance(image, (torch.Tensor, Image.Image, list)):\n                raise ValueError(\n                    f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n                )\n\n            # Offload text encoder if `enable_model_cpu_offload` was enabled\n            if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n                self.text_encoder_2.to(\"cpu\")\n                torch.cuda.empty_cache()\n\n            image = image.to(device=device, dtype=dtype)\n\n            if image.shape[1] == 4:\n                init_latents = image\n\n            else:\n                # make sure the VAE is in float32 mode, as it overflows in float16\n                if self.vae.config.force_upcast:\n                    image = image.float()\n                    self.vae.to(dtype=torch.float32)\n\n                if isinstance(generator, list) and len(generator) != batch_size:\n                    raise ValueError(\n                        f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                        f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                    )\n\n                elif isinstance(generator, list):\n                    init_latents = [\n                        retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                        for i in range(batch_size)\n                    ]\n                    init_latents = torch.cat(init_latents, dim=0)\n                else:\n                    init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n                if self.vae.config.force_upcast:\n                    self.vae.to(dtype)\n\n                init_latents = init_latents.to(dtype)\n                init_latents = self.vae.config.scaling_factor * init_latents\n\n            if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n                # expand init_latents for batch_size\n                additional_image_per_prompt = batch_size // init_latents.shape[0]\n                init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n            elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n                raise ValueError(\n                    f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n                )\n            else:\n                init_latents = torch.cat([init_latents], dim=0)\n\n            if add_noise:\n                shape = init_latents.shape\n                noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n                # get latents\n                init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n\n            latents = init_latents\n            return latents\n\n        else:\n            shape = (\n                batch_size,\n                num_channels_latents,\n                int(height) // self.vae_scale_factor,\n                int(width) 
// self.vae_scale_factor,\n            )\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n                )\n\n            if (image is None or timestep is None) and not is_strength_max:\n                raise ValueError(\n                    \"Since strength < 1, initial latents are to be initialised as a combination of Image + Noise. \"\n                    \"However, either the image or the noise timestep has not been provided.\"\n                )\n\n            if image.shape[1] == 4:\n                image_latents = image.to(device=device, dtype=dtype)\n                image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n            elif return_image_latents or (latents is None and not is_strength_max):\n                image = image.to(device=device, dtype=dtype)\n                image_latents = self._encode_vae_image(image=image, generator=generator)\n                image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n\n            if latents is None and add_noise:\n                noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n                # if strength is 1. 
then initialise the latents to noise, else initial to image + noise\n                latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)\n                # if pure noise then scale the initial latents by the  Scheduler's init sigma\n                latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents\n            elif add_noise:\n                noise = latents.to(device)\n                latents = noise * self.scheduler.init_noise_sigma\n            else:\n                noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n                latents = image_latents.to(device)\n\n            outputs = (latents,)\n\n            if return_noise:\n                outputs += (noise,)\n\n            if return_image_latents:\n                outputs += (image_latents,)\n\n            return outputs\n\n    def prepare_mask_latents(\n        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance\n    ):\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(\n            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)\n        )\n        mask = mask.to(device=device, dtype=dtype)\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. 
Make sure the total\"\n                    \" requested batch size is divisible by the number of masks that you pass.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n\n        if masked_image is not None and masked_image.shape[1] == 4:\n            masked_image_latents = masked_image\n        else:\n            masked_image_latents = None\n\n        if masked_image is not None:\n            if masked_image_latents is None:\n                masked_image = masked_image.to(device=device, dtype=dtype)\n                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)\n\n            if masked_image_latents.shape[0] < batch_size:\n                if not batch_size % masked_image_latents.shape[0] == 0:\n                    raise ValueError(\n                        \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\n                        f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                        \" Make sure the total requested batch size is divisible by the number of images that you pass.\"\n                    )\n                masked_image_latents = masked_image_latents.repeat(\n                    batch_size // masked_image_latents.shape[0], 1, 1, 1\n                )\n\n            masked_image_latents = (\n                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n            )\n\n            # aligning device to prevent device errors when concatenating it with the latent model input\n            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n\n        return mask, masked_image_latents\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        dtype = image.dtype\n        if 
self.vae.config.force_upcast:\n            image = image.float()\n            self.vae.to(dtype=torch.float32)\n\n        if isinstance(generator, list):\n            image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        if self.vae.config.force_upcast:\n            self.vae.to(dtype)\n\n        image_latents = image_latents.to(dtype)\n        image_latents = self.vae.config.scaling_factor * image_latents\n\n        return image_latents\n\n    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        return add_time_ids\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    def _enable_shared_attention_processors(\n        self,\n        share_attention: bool,\n        adain_queries: bool,\n        adain_keys: bool,\n        adain_values: bool,\n        full_attention_share: bool,\n        shared_score_scale: float,\n        shared_score_shift: float,\n        only_self_level: float,\n    ):\n        r\"\"\"Helper method to enable usage of Shared Attention Processor.\"\"\"\n        attn_procs = {}\n        num_self_layers = len([name for name in self.unet.attn_processors.keys() if \"attn1\" in name])\n\n        only_self_vec = get_switch_vec(num_self_layers, only_self_level)\n\n        for i, name in enumerate(self.unet.attn_processors.keys()):\n            is_self_attention = \"attn1\" in name\n            if is_self_attention:\n                if only_self_vec[i // 2]:\n                    attn_procs[name] = AttnProcessor2_0()\n                else:\n                    attn_procs[name] = SharedAttentionProcessor(\n                        share_attention=share_attention,\n                        adain_queries=adain_queries,\n                        adain_keys=adain_keys,\n                        adain_values=adain_values,\n                        full_attention_share=full_attention_share,\n                        shared_score_scale=shared_score_scale,\n                        shared_score_shift=shared_score_shift,\n                    )\n            else:\n                attn_procs[name] = AttnProcessor2_0()\n\n        self.unet.set_attn_processor(attn_procs)\n\n    def _disable_shared_attention_processors(self):\n        r\"\"\"\n        Helper method to disable usage of the Shared Attention Processor. 
All processors\n        are reset to the default Attention Processor for pytorch versions above 2.0.\n        \"\"\"\n        attn_procs = {}\n\n        for i, name in enumerate(self.unet.attn_processors.keys()):\n            attn_procs[name] = AttnProcessor2_0()\n\n        self.unet.set_attn_processor(attn_procs)\n\n    def _register_shared_norm(self, share_group_norm: bool = True, share_layer_norm: bool = True):\n        r\"\"\"Helper method to register shared group/layer normalization layers.\"\"\"\n\n        def register_norm_forward(norm_layer: Union[nn.GroupNorm, nn.LayerNorm]) -> Union[nn.GroupNorm, nn.LayerNorm]:\n            if not hasattr(norm_layer, \"orig_forward\"):\n                setattr(norm_layer, \"orig_forward\", norm_layer.forward)\n            orig_forward = norm_layer.orig_forward\n\n            def forward_(hidden_states: torch.Tensor) -> torch.Tensor:\n                n = hidden_states.shape[-2]\n                hidden_states = concat_first(hidden_states, dim=-2)\n                hidden_states = orig_forward(hidden_states)\n                return hidden_states[..., :n, :]\n\n            norm_layer.forward = forward_\n            return norm_layer\n\n        def get_norm_layers(pipeline_, norm_layers_: Dict[str, List[Union[nn.GroupNorm, nn.LayerNorm]]]):\n            if isinstance(pipeline_, nn.LayerNorm) and share_layer_norm:\n                norm_layers_[\"layer\"].append(pipeline_)\n            if isinstance(pipeline_, nn.GroupNorm) and share_group_norm:\n                norm_layers_[\"group\"].append(pipeline_)\n            else:\n                for layer in pipeline_.children():\n                    get_norm_layers(layer, norm_layers_)\n\n        norm_layers = {\"group\": [], \"layer\": []}\n        get_norm_layers(self.unet, norm_layers)\n\n        norm_layers_list = []\n        for key in [\"group\", \"layer\"]:\n            for layer in norm_layers[key]:\n                norm_layers_list.append(register_norm_forward(layer))\n\n      
  return norm_layers_list\n\n    @property\n    def style_aligned_enabled(self):\n        r\"\"\"Returns whether StyleAligned has been enabled in the pipeline or not.\"\"\"\n        return hasattr(self, \"_style_aligned_norm_layers\") and self._style_aligned_norm_layers is not None\n\n    def enable_style_aligned(\n        self,\n        share_group_norm: bool = True,\n        share_layer_norm: bool = True,\n        share_attention: bool = True,\n        adain_queries: bool = True,\n        adain_keys: bool = True,\n        adain_values: bool = False,\n        full_attention_share: bool = False,\n        shared_score_scale: float = 1.0,\n        shared_score_shift: float = 0.0,\n        only_self_level: float = 0.0,\n    ):\n        r\"\"\"\n        Enables the StyleAligned mechanism as in https://huggingface.co/papers/2312.02133.\n\n        Args:\n            share_group_norm (`bool`, defaults to `True`):\n                Whether or not to use shared group normalization layers.\n            share_layer_norm (`bool`, defaults to `True`):\n                Whether or not to use shared layer normalization layers.\n            share_attention (`bool`, defaults to `True`):\n                Whether or not to use attention sharing between batch images.\n            adain_queries (`bool`, defaults to `True`):\n                Whether or not to apply the AdaIn operation on attention queries.\n            adain_keys (`bool`, defaults to `True`):\n                Whether or not to apply the AdaIn operation on attention keys.\n            adain_values (`bool`, defaults to `False`):\n                Whether or not to apply the AdaIn operation on attention values.\n            full_attention_share (`bool`, defaults to `False`):\n                Whether or not to use full attention sharing between all images in a batch. 
Can\n                lead to content leakage within each batch and some loss in diversity.\n            shared_score_scale (`float`, defaults to `1.0`):\n                Scale for shared attention.\n            shared_score_shift (`float`, defaults to `0.0`):\n                Shift for shared attention.\n            only_self_level (`float`, defaults to `0.0`):\n                Level passed to `get_switch_vec` to select which self-attention layers keep the default attention\n                processor instead of the shared one.\n        \"\"\"\n        self._style_aligned_norm_layers = self._register_shared_norm(share_group_norm, share_layer_norm)\n        self._enable_shared_attention_processors(\n            share_attention=share_attention,\n            adain_queries=adain_queries,\n            adain_keys=adain_keys,\n            adain_values=adain_values,\n            full_attention_share=full_attention_share,\n            shared_score_scale=shared_score_scale,\n            shared_score_shift=shared_score_shift,\n            only_self_level=only_self_level,\n        )\n\n    def disable_style_aligned(self):\n        r\"\"\"Disables the StyleAligned mechanism if it had been previously enabled.\"\"\"\n        if self.style_aligned_enabled:\n            for layer in self._style_aligned_norm_layers:\n                layer.forward = layer.orig_forward\n\n            self._style_aligned_norm_layers = None\n            self._disable_shared_attention_processors()\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert 
len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def denoising_start(self):\n        return self._denoising_start\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: Optional[PipelineImageInput] = None,\n        mask_image: Optional[PipelineImageInput] = None,\n        masked_image_latents: Optional[torch.Tensor] = None,\n        strength: float = 
0.3,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is\n                used in both text-encoders\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. 
As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\n                of a plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16. of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. 
Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. 
When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            height=height,\n            width=width,\n            callback_steps=callback_steps,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._denoising_start = denoising_start\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 4. Preprocess image and mask_image\n        if image is not None:\n            image = self.image_processor.preprocess(image, height=height, width=width)\n            image = image.to(device=self.device, dtype=prompt_embeds.dtype)\n\n        if mask_image is not None:\n            mask = self.mask_processor.preprocess(mask_image, height=height, width=width)\n            mask = mask.to(device=self.device, dtype=prompt_embeds.dtype)\n\n            if masked_image_latents is not None:\n                masked_image = masked_image_latents\n            elif image.shape[1] == 4:\n                # if image is in latent space, we can't mask it\n                masked_image = None\n            else:\n                masked_image = image * (mask < 0.5)\n        else:\n            mask = None\n\n        # 4. 
Prepare timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        if image is not None:\n            timesteps, num_inference_steps = self.get_timesteps(\n                num_inference_steps,\n                strength,\n                device,\n                denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,\n            )\n\n            # check that number of inference steps is not < 1 - as this doesn't make sense\n            if num_inference_steps < 1:\n                raise ValueError(\n                    f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                    f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n                )\n\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        is_strength_max = strength == 1.0\n        add_noise = True if self.denoising_start is None else False\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        num_channels_unet = self.unet.config.in_channels\n        return_image_latents = num_channels_unet == 4\n\n        latents = self.prepare_latents(\n            image=image,\n            mask=mask,\n            width=width,\n            height=height,\n            num_channels_latents=num_channels_latents,\n            timestep=latent_timestep,\n            batch_size=batch_size * num_images_per_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            dtype=prompt_embeds.dtype,\n            device=device,\n            generator=generator,\n            add_noise=add_noise,\n            latents=latents,\n            is_strength_max=is_strength_max,\n            return_noise=True,\n            return_image_latents=return_image_latents,\n        )\n\n        if mask is not None:\n            if return_image_latents:\n                latents, noise, image_latents = latents\n            else:\n                latents, noise = latents\n\n            mask, masked_image_latents = self.prepare_mask_latents(\n                mask=mask,\n                masked_image=masked_image,\n                batch_size=batch_size * num_images_per_prompt,\n                height=height,\n                width=width,\n                dtype=prompt_embeds.dtype,\n                device=device,\n                generator=generator,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n            )\n\n            # Check that sizes of mask, masked image and latents match\n            if num_channels_unet == 9:\n                # default case for stable-diffusion-v1-5/stable-diffusion-inpainting\n                num_channels_mask = mask.shape[1]\n                num_channels_masked_image = masked_image_latents.shape[1]\n                if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:\n                    raise 
ValueError(\n                        f\"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects\"\n                        f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                        f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                        f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                        \" `pipeline.unet` or your `mask_image` or `image` input.\"\n                    )\n            elif num_channels_unet != 4:\n                raise ValueError(\n                    f\"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.\"\n                )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        height, width = latents.shape[-2:]\n        height = height * self.vae_scale_factor\n        width = width * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 7. 
Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n        add_time_ids = self._get_add_time_ids(\n            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype\n        )\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        if ip_adapter_image is not None:\n            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True\n            image_embeds, negative_image_embeds = self.encode_image(\n                ip_adapter_image, device, num_images_per_prompt, output_hidden_state\n            )\n            if self.do_classifier_free_guidance:\n                image_embeds = torch.cat([negative_image_embeds, image_embeds])\n                image_embeds = image_embeds.to(device)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 8.1 Apply denoising_end\n        if (\n            self.denoising_end is not None\n            and isinstance(self.denoising_end, float)\n            and self.denoising_end > 0\n            and self.denoising_end < 1\n        ):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 9. Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                if ip_adapter_image is not None:\n                    
added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if mask is not None and num_channels_unet == 4:\n                    init_latents_proper = image_latents\n\n                    if self.do_classifier_free_guidance:\n                        init_mask, _ = mask.chunk(2)\n                    else:\n                        init_mask = mask\n\n                    if i < len(timesteps) - 1:\n                        noise_timestep = timesteps[i + 1]\n                        init_latents_proper = self.scheduler.add_noise(\n                            init_latents_proper, noise, torch.tensor([noise_timestep])\n                        )\n\n                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents\n\n                if callback_on_step_end is not None:\n                    
callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                
self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            # apply watermark if available\n            if self.watermark is not None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_3_differential_img2img.py",
    "content": "# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Callable, Dict, List, Optional, Union\n\nimport torch\nfrom transformers import (\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    T5EncoderModel,\n    T5TokenizerFast,\n)\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.models.autoencoders import AutoencoderKL\nfrom diffusers.models.transformers import SD3Transformer2DModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n\n        >>> from diffusers import AutoPipelineForImage2Image\n        >>> from diffusers.utils import load_image\n\n        >>> device = \"cuda\"\n        >>> model_id_or_path = \"stabilityai/stable-diffusion-3-medium-diffusers\"\n        >>> 
pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\n        >>> pipe = pipe.to(device)\n\n        >>> url = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n        >>> init_image = load_image(url).resize((512, 512))\n\n        >>> prompt = \"cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k\"\n\n        >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):\n    r\"\"\"\n    Args:\n        transformer ([`SD3Transformer2DModel`]):\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModelWithProjection`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,\n            with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`\n            as its 
dimension.\n        text_encoder_2 ([`CLIPTextModelWithProjection`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        text_encoder_3 ([`T5EncoderModel`]):\n            Frozen text-encoder. Stable Diffusion 3 uses\n            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the\n            [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_3 (`T5TokenizerFast`):\n            Tokenizer of class\n            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->text_encoder_3->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\", \"negative_pooled_prompt_embeds\"]\n\n    def __init__(\n        self,\n        transformer: SD3Transformer2DModel,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer_2: CLIPTokenizer,\n        text_encoder_3: T5EncoderModel,\n        tokenizer_3: T5TokenizerFast,\n    ):\n        super().__init__()\n\n        
self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            text_encoder_3=text_encoder_3,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            tokenizer_3=tokenizer_3,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels\n        )\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True\n        )\n\n        self.tokenizer_max_length = self.tokenizer.model_max_length\n        self.default_sample_size = self.transformer.config.sample_size\n\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 256,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if self.text_encoder_3 is None:\n            return torch.zeros(\n                (\n                    batch_size * num_images_per_prompt,\n                    self.tokenizer_max_length,\n                    self.transformer.config.joint_attention_dim,\n                ),\n                device=device,\n                dtype=dtype,\n            )\n\n        text_inputs = self.tokenizer_3(\n            
prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_3(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]\n\n        dtype = self.text_encoder_3.dtype\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n        clip_skip: Optional[int] = None,\n        clip_model_index: int = 0,\n    ):\n        device = device or self._execution_device\n\n        clip_tokenizers = [self.tokenizer, self.tokenizer_2]\n        clip_text_encoders = [self.text_encoder, self.text_encoder_2]\n\n        tokenizer = 
clip_tokenizers[clip_model_index]\n        text_encoder = clip_text_encoders[clip_model_index]\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\n            )\n        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n        pooled_prompt_embeds = prompt_embeds[0]\n\n        if clip_skip is None:\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n        else:\n            prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n        return prompt_embeds, pooled_prompt_embeds\n\n   
 # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        prompt_3: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        negative_prompt_3: Optional[Union[str, List[str]]] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        clip_skip: Optional[int] = None,\n        max_sequence_length: int = 256,\n    ):\n        r\"\"\"\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in all text-encoders\n            prompt_3 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is\n                used in all text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and\n                `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            prompt_3 = prompt_3 or prompt\n            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3\n\n            prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(\n                prompt=prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                clip_skip=clip_skip,\n                clip_model_index=0,\n            )\n            prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(\n                prompt=prompt_2,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                clip_skip=clip_skip,\n                clip_model_index=1,\n            )\n            clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)\n\n            t5_prompt_embed = self._get_t5_prompt_embeds(\n                prompt=prompt_3,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n            clip_prompt_embeds = torch.nn.functional.pad(\n                clip_prompt_embeds, (0, 
t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])\n            )\n\n            prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)\n            pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n            negative_prompt_3 = negative_prompt_3 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n            negative_prompt_3 = (\n                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3\n            )\n\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n\n            negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(\n                negative_prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                clip_skip=None,\n                clip_model_index=0,\n            )\n            negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(\n                negative_prompt_2,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                clip_skip=None,\n                clip_model_index=1,\n            )\n            negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)\n\n            t5_negative_prompt_embed = self._get_t5_prompt_embeds(\n                prompt=negative_prompt_3,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n            negative_clip_prompt_embeds = torch.nn.functional.pad(\n                negative_clip_prompt_embeds,\n                (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),\n            )\n\n            negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)\n            negative_pooled_prompt_embeds = torch.cat(\n                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1\n            )\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        prompt_3,\n        strength,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        negative_prompt_3=None,\n        
prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        max_sequence_length=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_3 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n        elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):\n            raise ValueError(f\"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_3 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. 
Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        if max_sequence_length is not None and max_sequence_length > 512:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\n\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(num_inference_steps * strength, num_inference_steps)\n\n        t_start = int(max(num_inference_steps - init_timestep, 0))\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(\n        self, batch_size, num_channels_latents, height, width, image, timestep, dtype, device, generator=None\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n\n        image = image.to(device=device, dtype=dtype)\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n        elif isinstance(generator, list):\n            init_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)\n            ]\n            init_latents = torch.cat(init_latents, dim=0)\n        else:\n            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)\n        latents = init_latents.to(device=device, dtype=dtype)\n\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        prompt_3: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        image: PipelineImageInput = None,\n        strength: float = 0.6,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        negative_prompt_3: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 256,\n        map: PipelineImageInput = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n     
   Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`\n                will be used instead.\n            prompt_3 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`\n                will be used instead.\n            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. 
of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used instead\n            negative_prompt_3 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and\n                `text_encoder_3`. If not defined, `negative_prompt` is used instead\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead\n                of a plain tuple.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. 
The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        # 0. Default height and width\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            prompt_3,\n            strength,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            negative_prompt_3=negative_prompt_3,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            prompt_3=prompt_3,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            negative_prompt_3=negative_prompt_3,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            device=device,\n            clip_skip=self.clip_skip,\n            num_images_per_prompt=num_images_per_prompt,\n            
max_sequence_length=max_sequence_length,\n        )\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)\n\n        # 3. Preprocess image\n        init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n\n        map = self.mask_processor.preprocess(\n            map, height=height // self.vae_scale_factor, width=width // self.vae_scale_factor\n        ).to(device)\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # begin diff diff change\n        total_time_steps = num_inference_steps\n        # end diff diff change\n\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels\n        if latents is None:\n            latents = self.prepare_latents(\n                batch_size * num_images_per_prompt,\n                num_channels_latents,\n                height,\n                width,\n                init_image,\n                latent_timestep,\n                prompt_embeds.dtype,\n                device,\n                generator,\n            )\n\n        # 6. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # preparations for diff diff\n        original_with_noise = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            init_image,\n            timesteps,\n            prompt_embeds.dtype,\n            device,\n            generator,\n        )\n        thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps\n        thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)\n        masks = map.squeeze() > thresholds\n        # end diff diff preparations\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # diff diff\n                if i == 0:\n                    latents = original_with_noise[:1]\n                else:\n                    mask = masks[i].unsqueeze(0).to(latents.dtype)\n                    mask = mask.unsqueeze(1)  # fit shape\n                    latents = original_with_noise[i] * mask + latents * (1 - mask)\n                # end diff diff\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latent_model_input.shape[0])\n\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    timestep=timestep,\n                    encoder_hidden_states=prompt_embeds,\n                    pooled_projections=pooled_prompt_embeds,\n                    return_dict=False,\n               
 )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            image = latents\n\n        else:\n            latents = (latents / self.vae.config.scaling_factor) + 
self.vae.config.shift_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusion3PipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py",
    "content": "# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport PIL.Image\nimport torch\nfrom transformers import (\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    SiglipImageProcessor,\n    SiglipVisionModel,\n    T5EncoderModel,\n    T5TokenizerFast,\n)\n\nfrom ...image_processor import PipelineImageInput, VaeImageProcessor\nfrom ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin\nfrom ...models.autoencoders import AutoencoderKL\nfrom ...models.transformers import SD3Transformer2DModel\nfrom ...schedulers import FlowMatchEulerDiscreteScheduler\nfrom ...utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom ...utils.torch_utils import randn_tensor\nfrom ..pipeline_utils import DiffusionPipeline\nfrom .pipeline_output import StableDiffusion3PipelineOutput\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import 
StableDiffusion3InstructPix2PixPipeline\n        >>> from diffusers.utils import load_image\n\n        >>> resolution = 1024\n        >>> image = load_image(\n        ...     \"https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png\"\n        ... ).resize((resolution, resolution))\n        >>> edit_instruction = \"Turn sky into a cloudy one\"\n\n        >>> pipe = StableDiffusion3InstructPix2PixPipeline.from_pretrained(\n        ...     \"your_own_model_path\", torch_dtype=torch.float16\n        ... ).to(\"cuda\")\n\n        >>> edited_image = pipe(\n        ...     prompt=edit_instruction,\n        ...     image=image,\n        ...     height=resolution,\n        ...     width=resolution,\n        ...     guidance_scale=7.5,\n        ...     image_guidance_scale=1.5,\n        ...     num_inference_steps=30,\n        ... ).images[0]\n        >>> edited_image\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided 
encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusion3InstructPix2PixPipeline(\n    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin\n):\n    r\"\"\"\n    Args:\n        transformer ([`SD3Transformer2DModel`]):\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and 
from latent representations.\n        text_encoder ([`CLIPTextModelWithProjection`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,\n            with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`\n            as its dimension.\n        text_encoder_2 ([`CLIPTextModelWithProjection`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        text_encoder_3 ([`T5EncoderModel`]):\n            Frozen text-encoder. Stable Diffusion 3 uses\n            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the\n            [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_3 (`T5TokenizerFast`):\n            Tokenizer of class\n            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).\n        image_encoder (`SiglipVisionModel`, *optional*):\n            Pre-trained Vision Model for IP Adapter.\n        feature_extractor (`SiglipImageProcessor`, *optional*):\n            Image processor for IP Adapter.\n    \"\"\"\n\n    model_cpu_offload_seq = 
\"text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae\"\n    _optional_components = [\"image_encoder\", \"feature_extractor\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\", \"negative_pooled_prompt_embeds\"]\n\n    def __init__(\n        self,\n        transformer: SD3Transformer2DModel,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer_2: CLIPTokenizer,\n        text_encoder_3: T5EncoderModel,\n        tokenizer_3: T5TokenizerFast,\n        image_encoder: SiglipVisionModel = None,\n        feature_extractor: SiglipImageProcessor = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            text_encoder_3=text_encoder_3,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            tokenizer_3=tokenizer_3,\n            transformer=transformer,\n            scheduler=scheduler,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 77\n        )\n        self.default_sample_size = (\n            self.transformer.config.sample_size\n            if hasattr(self, \"transformer\") and self.transformer is not None\n            else 128\n        )\n        self.patch_size = (\n            self.transformer.config.patch_size if hasattr(self, 
\"transformer\") and self.transformer is not None else 2\n        )\n\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 256,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if self.text_encoder_3 is None:\n            return torch.zeros(\n                (\n                    batch_size * num_images_per_prompt,\n                    self.tokenizer_max_length,\n                    self.transformer.config.joint_attention_dim,\n                ),\n                device=device,\n                dtype=dtype,\n            )\n\n        text_inputs = self.tokenizer_3(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_3(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]\n\n        dtype = self.text_encoder_3.dtype\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        _, 
seq_len, _ = prompt_embeds.shape\n\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n        clip_skip: Optional[int] = None,\n        clip_model_index: int = 0,\n    ):\n        device = device or self._execution_device\n\n        clip_tokenizers = [self.tokenizer, self.tokenizer_2]\n        clip_text_encoders = [self.text_encoder, self.text_encoder_2]\n\n        tokenizer = clip_tokenizers[clip_model_index]\n        text_encoder = clip_text_encoders[clip_model_index]\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\n            )\n        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n        pooled_prompt_embeds = prompt_embeds[0]\n\n      
  if clip_skip is None:\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n        else:\n            prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n        return prompt_embeds, pooled_prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        prompt_3: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        negative_prompt_3: Optional[Union[str, List[str]]] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        clip_skip: Optional[int] = None,\n        max_sequence_length: int = 256,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and 
`text_encoder_2`. If not defined, `prompt` is\n                used in all text-encoders\n            prompt_3 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is\n                used in all text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.\n            negative_prompt_3 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and\n                `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder, lora_scale)\n            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = 
prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            prompt_3 = prompt_3 or prompt\n            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3\n\n            prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(\n                prompt=prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                clip_skip=clip_skip,\n                clip_model_index=0,\n            )\n            prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(\n                prompt=prompt_2,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                clip_skip=clip_skip,\n                clip_model_index=1,\n            )\n            clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)\n\n            t5_prompt_embed = self._get_t5_prompt_embeds(\n                prompt=prompt_3,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n            clip_prompt_embeds = torch.nn.functional.pad(\n                clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])\n            )\n\n            prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)\n            pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n            negative_prompt_3 = negative_prompt_3 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if 
isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n            negative_prompt_3 = (\n                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3\n            )\n\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n\n            negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(\n                negative_prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                clip_skip=None,\n                clip_model_index=0,\n            )\n            negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(\n                negative_prompt_2,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n                clip_skip=None,\n                clip_model_index=1,\n            )\n            negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)\n\n            t5_negative_prompt_embed = self._get_t5_prompt_embeds(\n                prompt=negative_prompt_3,\n                num_images_per_prompt=num_images_per_prompt,\n               
 max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n            negative_clip_prompt_embeds = torch.nn.functional.pad(\n                negative_clip_prompt_embeds,\n                (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),\n            )\n\n            negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)\n            negative_pooled_prompt_embeds = torch.cat(\n                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        prompt_3,\n        height,\n        width,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        negative_prompt_3=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        max_sequence_length=None,\n    ):\n        if (\n            height % (self.vae_scale_factor * self.patch_size) != 0\n            or width % (self.vae_scale_factor * self.patch_size) != 0\n        ):\n            raise ValueError(\n                f\"`height` and `width` have to be divisible by {self.vae_scale_factor * 
self.patch_size} but are {height} and {width}.\"\n                f\" You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_3 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n        elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):\n            raise ValueError(f\"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_3 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. 
Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n        if max_sequence_length is not None and max_sequence_length > 512:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        if latents is not None:\n            return latents.to(device=device, dtype=dtype)\n\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        return latents\n\n    def prepare_image_latents(\n        self,\n        image,\n        batch_size,\n        num_images_per_prompt,\n        dtype,\n        device,\n        generator,\n        do_classifier_free_guidance,\n    ):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == self.vae.config.latent_channels:\n            image_latents = image\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), sample_mode=\"argmax\", generator=generator)\n\n        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n\n        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:\n            # expand image_latents for batch_size\n            deprecation_message = (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // image_latents.shape[0]\n            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            image_latents = torch.cat([image_latents], dim=0)\n\n        if do_classifier_free_guidance:\n            uncond_image_latents = torch.zeros_like(image_latents)\n            image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0)\n\n        return image_latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def image_guidance_scale(self):\n        return self._image_guidance_scale\n\n    @property\n    def skip_guidance_layers(self):\n        return self._skip_guidance_layers\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1.0 and self.image_guidance_scale >= 1.0\n\n    @property\n    def joint_attention_kwargs(self):\n        return self._joint_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image\n    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:\n        \"\"\"Encodes the given image into a feature representation using a pre-trained image encoder.\n\n        Args:\n            image (`PipelineImageInput`):\n                Input image to be encoded.\n            device: (`torch.device`):\n                Torch device.\n\n        Returns:\n            `torch.Tensor`: The encoded image feature representation.\n        \"\"\"\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=self.dtype)\n\n        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n\n    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[torch.Tensor] = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n    ) -> torch.Tensor:\n        \"\"\"Prepares image embeddings for use in the IP-Adapter.\n\n        Either `ip_adapter_image` or `ip_adapter_image_embeds` 
must be passed.\n\n        Args:\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                The input image to extract features from for IP-Adapter.\n            ip_adapter_image_embeds (`torch.Tensor`, *optional*):\n                Precomputed image embeddings.\n            device: (`torch.device`, *optional*):\n                Torch device.\n            num_images_per_prompt (`int`, defaults to 1):\n                Number of images that should be generated per prompt.\n            do_classifier_free_guidance (`bool`, defaults to True):\n                Whether to use classifier free guidance or not.\n        \"\"\"\n        device = device or self._execution_device\n\n        if ip_adapter_image_embeds is not None:\n            if do_classifier_free_guidance:\n                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)\n            else:\n                single_image_embeds = ip_adapter_image_embeds\n        elif ip_adapter_image is not None:\n            single_image_embeds = self.encode_image(ip_adapter_image, device)\n            if do_classifier_free_guidance:\n                single_negative_image_embeds = torch.zeros_like(single_image_embeds)\n        else:\n            raise ValueError(\"Neither `ip_adapter_image` nor `ip_adapter_image_embeds` was provided.\")\n\n        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n\n        if do_classifier_free_guidance:\n            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)\n            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)\n\n        return image_embeds.to(device=device)\n\n    def enable_sequential_cpu_offload(self, *args, **kwargs):\n        if self.image_encoder is not None and \"image_encoder\" not in self._exclude_from_cpu_offload:\n            logger.warning(\n                \"`pipe.enable_sequential_cpu_offload()` might 
fail for `image_encoder` if it uses \"\n                \"`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling \"\n                \"`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`.\"\n            )\n\n        super().enable_sequential_cpu_offload(*args, **kwargs)\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        prompt_3: Optional[Union[str, List[str]]] = None,\n        image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 28,\n        sigmas: Optional[List[float]] = None,\n        guidance_scale: float = 7.0,\n        image_guidance_scale: float = 1.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        negative_prompt_3: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        
callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 256,\n        skip_guidance_layers: List[int] = None,\n        skip_layer_guidance_scale: float = 2.8,\n        skip_layer_guidance_stop: float = 0.2,\n        skip_layer_guidance_start: float = 0.01,\n        mu: Optional[float] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`\n                will be used instead.\n            prompt_3 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`\n                will be used instead.\n            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n            num_inference_steps (`int`, *optional*, defaults to 28):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, *optional*, defaults to 7.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            image_guidance_scale (`float`, *optional*, defaults to 1.5):\n                Image guidance scale is to push the generated image towards the initial image `image`. Image guidance\n                scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to\n                generate images that are closely linked to the source image `image`, usually at the expense of lower\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used instead\n            negative_prompt_3 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and\n                `text_encoder_3`. 
If not defined, `negative_prompt` is used instead\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`torch.Tensor`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,\n                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to\n                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of\n                a plain tuple.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. 
`callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`.\n            skip_guidance_layers (`List[int]`, *optional*):\n                A list of integers that specify layers to skip during guidance. If not provided, all layers will be\n                used for guidance. If provided, the guidance will only be applied to the layers specified in the list.\n                Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].\n            skip_layer_guidance_scale (`float`, *optional*, defaults to 2.8): The scale of the guidance for the layers specified in\n                `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`\n                with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers\n                with a scale of `1`.\n            skip_layer_guidance_stop (`float`, *optional*, defaults to 0.2): The fraction of steps at which the guidance for the layers specified in\n                `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in\n                `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by\n                StabilityAI for Stable Diffusion 3.5 Medium is 0.2.\n            skip_layer_guidance_start (`float`, *optional*, defaults to 0.01): The fraction of steps at which the guidance for the layers specified in\n                `skip_guidance_layers` will start. 
The guidance will be applied to the layers specified in\n                `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by\n                StabilityAI for Stable Diffusion 3.5 Medium is 0.01.\n            mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            prompt_3,\n            height,\n            width,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            negative_prompt_3=negative_prompt_3,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._image_guidance_scale = image_guidance_scale\n        self._skip_layer_guidance_scale = skip_layer_guidance_scale\n        self._clip_skip = clip_skip\n        self._joint_attention_kwargs = joint_attention_kwargs\n        self._interrupt = False\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        lora_scale = (\n            self.joint_attention_kwargs.get(\"scale\", None) if self.joint_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            prompt_3=prompt_3,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            negative_prompt_3=negative_prompt_3,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            device=device,\n            clip_skip=self.clip_skip,\n            num_images_per_prompt=num_images_per_prompt,\n            max_sequence_length=max_sequence_length,\n            lora_scale=lora_scale,\n        )\n\n        if self.do_classifier_free_guidance:\n            if skip_guidance_layers is not None:\n                original_prompt_embeds = prompt_embeds\n                original_pooled_prompt_embeds = pooled_prompt_embeds\n            # The extra concat similar to how it's done in SD InstructPix2Pix.\n            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)\n            pooled_prompt_embeds = torch.cat(\n                [pooled_prompt_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], 
dim=0\n            )\n\n        # 4. Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n        # 5. Prepare image latents\n        image = self.image_processor.preprocess(image)\n        image_latents = self.prepare_image_latents(\n            image,\n            batch_size,\n            num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            self.do_classifier_free_guidance,\n        )\n\n        # 6. Check that shapes of latents and image match the DiT (SD3) in_channels\n        num_channels_image = image_latents.shape[1]\n        if num_channels_latents + num_channels_image != self.transformer.config.in_channels:\n            raise ValueError(\n                f\"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects\"\n                f\" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                f\" `num_channels_image`: {num_channels_image} \"\n                f\" = {num_channels_latents + num_channels_image}. Please verify the config of\"\n                \" `pipeline.transformer` or your `image` input.\"\n            )\n\n        # 7. 
Prepare timesteps\n        scheduler_kwargs = {}\n        if self.scheduler.config.get(\"use_dynamic_shifting\", None) and mu is None:\n            _, _, height, width = latents.shape\n            image_seq_len = (height // self.transformer.config.patch_size) * (\n                width // self.transformer.config.patch_size\n            )\n            mu = calculate_shift(\n                image_seq_len,\n                self.scheduler.config.get(\"base_image_seq_len\", 256),\n                self.scheduler.config.get(\"max_image_seq_len\", 4096),\n                self.scheduler.config.get(\"base_shift\", 0.5),\n                self.scheduler.config.get(\"max_shift\", 1.16),\n            )\n            scheduler_kwargs[\"mu\"] = mu\n        elif mu is not None:\n            scheduler_kwargs[\"mu\"] = mu\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            sigmas=sigmas,\n            **scheduler_kwargs,\n        )\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # 8. Prepare image embeddings\n        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:\n            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n            if self.joint_attention_kwargs is None:\n                self._joint_attention_kwargs = {\"ip_adapter_image_embeds\": ip_adapter_image_embeds}\n            else:\n                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)\n\n        # 9. 
Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                # The latents are expanded 3 times because for pix2pix the guidance\n                # is applied for both the text and the input image.\n                latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latent_model_input.shape[0])\n                scaled_latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)\n\n                noise_pred = self.transformer(\n                    hidden_states=scaled_latent_model_input,\n                    timestep=timestep,\n                    encoder_hidden_states=prompt_embeds,\n                    pooled_projections=pooled_prompt_embeds,\n                    joint_attention_kwargs=self.joint_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)\n                    noise_pred = (\n                        noise_pred_uncond\n                        + self.guidance_scale * (noise_pred_text - noise_pred_image)\n                        + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)\n                    )\n                    should_skip_layers = (\n                        True\n                        if i > num_inference_steps * skip_layer_guidance_start\n                        and i < num_inference_steps * skip_layer_guidance_stop\n                        else False\n                    )\n                    
if skip_guidance_layers is not None and should_skip_layers:\n                        timestep = t.expand(latents.shape[0])\n                        latent_model_input = latents\n                        noise_pred_skip_layers = self.transformer(\n                            hidden_states=latent_model_input,\n                            timestep=timestep,\n                            encoder_hidden_states=original_prompt_embeds,\n                            pooled_projections=original_pooled_prompt_embeds,\n                            joint_attention_kwargs=self.joint_attention_kwargs,\n                            return_dict=False,\n                            skip_layers=skip_guidance_layers,\n                        )[0]\n                        noise_pred = (\n                            noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale\n                        )\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                # call the callback, if provided\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    image_latents = callback_outputs.pop(\"image_latents\", image_latents)\n\n                # update the progress bar on the last step and after each completed scheduler step past warmup\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            image = latents\n\n        else:\n            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\n            latents = latents.to(dtype=self.vae.dtype)\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusion3PipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_boxdiff.py",
    "content": "# Copyright 2025 Jingyang Zhang and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport abc\nimport inspect\nimport math\nimport numbers\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    
unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import StableDiffusionPipeline\n\n        >>> pipe = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\n        >>> pipe = pipe.to(\"cuda\")\n\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\n        >>> image = pipe(prompt).images[0]\n        ```\n\"\"\"\n\n\nclass GaussianSmoothing(nn.Module):\n    \"\"\"\n    Copied from official repo: https://github.com/showlab/BoxDiff/blob/master/utils/gaussian_smoothing.py\n    Apply gaussian smoothing on a\n    1d, 2d or 3d tensor. Filtering is performed separately for each channel\n    in the input using a depthwise convolution.\n    Arguments:\n        channels (int, sequence): Number of channels of the input tensors. 
Output will\n            have this number of channels as well.\n        kernel_size (int, sequence): Size of the gaussian kernel.\n        sigma (float, sequence): Standard deviation of the gaussian kernel.\n        dim (int, optional): The number of dimensions of the data.\n            Default value is 2 (spatial).\n    \"\"\"\n\n    def __init__(self, channels, kernel_size, sigma, dim=2):\n        super(GaussianSmoothing, self).__init__()\n        if isinstance(kernel_size, numbers.Number):\n            kernel_size = [kernel_size] * dim\n        if isinstance(sigma, numbers.Number):\n            sigma = [sigma] * dim\n\n        # The gaussian kernel is the product of the\n        # gaussian function of each dimension.\n        kernel = 1\n        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])\n        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):\n            mean = (size - 1) / 2\n            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))\n\n        # Make sure sum of values in gaussian kernel equals 1.\n        kernel = kernel / torch.sum(kernel)\n\n        # Reshape to depthwise convolutional weight\n        kernel = kernel.view(1, 1, *kernel.size())\n        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))\n\n        self.register_buffer(\"weight\", kernel)\n        self.groups = channels\n\n        if dim == 1:\n            self.conv = F.conv1d\n        elif dim == 2:\n            self.conv = F.conv2d\n        elif dim == 3:\n            self.conv = F.conv3d\n        else:\n            raise RuntimeError(\"Only 1, 2 and 3 dimensions are supported. 
Received {}.\".format(dim))\n\n    def forward(self, input):\n        \"\"\"\n        Apply gaussian filter to input.\n        Arguments:\n            input (torch.Tensor): Input to apply gaussian filter on.\n        Returns:\n            filtered (torch.Tensor): Filtered output.\n        \"\"\"\n        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)\n\n\nclass AttendExciteCrossAttnProcessor:\n    def __init__(self, attnstore, place_in_unet):\n        super().__init__()\n        self.attnstore = attnstore\n        self.place_in_unet = place_in_unet\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n    ) -> torch.Tensor:\n        batch_size, sequence_length, _ = hidden_states.shape\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=1)\n        query = attn.to_q(hidden_states)\n\n        is_cross = encoder_hidden_states is not None\n        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        self.attnstore(attention_probs, is_cross, self.place_in_unet)\n\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        return hidden_states\n\n\nclass AttentionControl(abc.ABC):\n    def step_callback(self, x_t):\n       
 return x_t\n\n    def between_steps(self):\n        return\n\n    # @property\n    # def num_uncond_att_layers(self):\n    #     return 0\n\n    @abc.abstractmethod\n    def forward(self, attn, is_cross: bool, place_in_unet: str):\n        raise NotImplementedError\n\n    def __call__(self, attn, is_cross: bool, place_in_unet: str):\n        if self.cur_att_layer >= self.num_uncond_att_layers:\n            self.forward(attn, is_cross, place_in_unet)\n        self.cur_att_layer += 1\n        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:\n            self.cur_att_layer = 0\n            self.cur_step += 1\n            self.between_steps()\n\n    def reset(self):\n        self.cur_step = 0\n        self.cur_att_layer = 0\n\n    def __init__(self):\n        self.cur_step = 0\n        self.num_att_layers = -1\n        self.cur_att_layer = 0\n\n\nclass AttentionStore(AttentionControl):\n    @staticmethod\n    def get_empty_store():\n        return {\"down_cross\": [], \"mid_cross\": [], \"up_cross\": [], \"down_self\": [], \"mid_self\": [], \"up_self\": []}\n\n    def forward(self, attn, is_cross: bool, place_in_unet: str):\n        key = f\"{place_in_unet}_{'cross' if is_cross else 'self'}\"\n        if attn.shape[1] <= 32**2:  # avoid memory overhead\n            self.step_store[key].append(attn)\n        return attn\n\n    def between_steps(self):\n        self.attention_store = self.step_store\n        if self.save_global_store:\n            with torch.no_grad():\n                if len(self.global_store) == 0:\n                    self.global_store = self.step_store\n                else:\n                    for key in self.global_store:\n                        for i in range(len(self.global_store[key])):\n                            self.global_store[key][i] += self.step_store[key][i].detach()\n        self.step_store = self.get_empty_store()\n        self.step_store = self.get_empty_store()\n\n    def get_average_attention(self):\n   
     average_attention = self.attention_store\n        return average_attention\n\n    def get_average_global_attention(self):\n        average_attention = {\n            key: [item / self.cur_step for item in self.global_store[key]] for key in self.attention_store\n        }\n        return average_attention\n\n    def reset(self):\n        super(AttentionStore, self).reset()\n        self.step_store = self.get_empty_store()\n        self.attention_store = {}\n        self.global_store = {}\n\n    def __init__(self, save_global_store=False):\n        \"\"\"\n        Initialize an empty AttentionStore\n        :param step_index: used to visualize only a specific step in the diffusion process\n        \"\"\"\n        super(AttentionStore, self).__init__()\n        self.save_global_store = save_global_store\n        self.step_store = self.get_empty_store()\n        self.attention_store = {}\n        self.global_store = {}\n        self.curr_step_index = 0\n        self.num_uncond_att_layers = 0\n\n\ndef aggregate_attention(\n    attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int\n) -> torch.Tensor:\n    \"\"\"Aggregates the attention across the different layers and heads at the specified resolution.\"\"\"\n    out = []\n    attention_maps = attention_store.get_average_attention()\n\n    # for k, v in attention_maps.items():\n    #     for vv in v:\n    #         print(vv.shape)\n    # exit()\n\n    num_pixels = res**2\n    for location in from_where:\n        for item in attention_maps[f\"{location}_{'cross' if is_cross else 'self'}\"]:\n            if item.shape[1] == num_pixels:\n                cross_maps = item.reshape(1, -1, res, res, item.shape[-1])[select]\n                out.append(cross_maps)\n    out = torch.cat(out, dim=0)\n    out = out.sum(0) / out.shape[0]\n    return out\n\n\ndef register_attention_control(model, controller):\n    attn_procs = {}\n    cross_att_count = 0\n    for name in 
model.unet.attn_processors.keys():\n        # cross_attention_dim = None if name.endswith(\"attn1.processor\") else model.unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            # hidden_size = model.unet.config.block_out_channels[-1]\n            place_in_unet = \"mid\"\n        elif name.startswith(\"up_blocks\"):\n            # block_id = int(name[len(\"up_blocks.\")])\n            # hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]\n            place_in_unet = \"up\"\n        elif name.startswith(\"down_blocks\"):\n            # block_id = int(name[len(\"down_blocks.\")])\n            # hidden_size = model.unet.config.block_out_channels[block_id]\n            place_in_unet = \"down\"\n        else:\n            continue\n\n        cross_att_count += 1\n        attn_procs[name] = AttendExciteCrossAttnProcessor(attnstore=controller, place_in_unet=place_in_unet)\n    model.unet.set_attn_processor(attn_procs)\n    controller.num_att_layers = cross_att_count\n\n\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. 
If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusionBoxDiffPipeline(\n    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion with BoxDiff.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`.\"\n        deprecate(\n            \"enable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`.\"\n        deprecate(\n            \"disable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. 
Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n      
  Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
text_inputs, prompt_embeds, negative_prompt_embeds\n\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, 
clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        boxdiff_phrases,\n        boxdiff_boxes,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 
!= 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if boxdiff_phrases is not None or boxdiff_boxes is not None:\n            if not (boxdiff_phrases is not None and boxdiff_boxes is not None):\n                raise ValueError(\"Either both `boxdiff_phrases` and `boxdiff_boxes` must be passed or none of them.\")\n\n            if not isinstance(boxdiff_phrases, list) or not isinstance(boxdiff_boxes, list):\n                raise ValueError(\"`boxdiff_phrases` and `boxdiff_boxes` must be lists.\")\n\n            if len(boxdiff_phrases) != len(boxdiff_boxes):\n                raise ValueError(\n                    \"`boxdiff_phrases` and `boxdiff_boxes` must have the same length,\"\n                    f\" got: `boxdiff_phrases` {len(boxdiff_phrases)} != `boxdiff_boxes`\"\n                    f\" {len(boxdiff_boxes)}.\"\n                )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):\n        r\"\"\"Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.\n\n        The suffixes after the scaling factors represent the stages where they are being applied.\n\n        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values\n        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.\n\n        Args:\n            s1 (`float`):\n                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            s2 (`float`):\n                Scaling factor for stage 2 to attenuate the contributions of the skip features. 
This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.\n            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.\n        \"\"\"\n        if not hasattr(self, \"unet\"):\n            raise ValueError(\"The pipeline must have `unet` for using FreeU.\")\n        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)\n\n    def disable_freeu(self):\n        \"\"\"Disables the FreeU mechanism if enabled.\"\"\"\n        self.unet.disable_freeu()\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections\n    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):\n        \"\"\"\n        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,\n        key, value) are fused. 
For cross-attention modules, key and value projection matrices are fused.\n\n        > [!WARNING]\n        > This API is 🧪 experimental.\n\n        Args:\n            unet (`bool`, defaults to `True`): To apply fusion on the UNet.\n            vae (`bool`, defaults to `True`): To apply fusion on the VAE.\n        \"\"\"\n        self.fusing_unet = False\n        self.fusing_vae = False\n\n        if unet:\n            self.fusing_unet = True\n            self.unet.fuse_qkv_projections()\n            self.unet.set_attn_processor(FusedAttnProcessor2_0())\n\n        if vae:\n            if not isinstance(self.vae, AutoencoderKL):\n                raise ValueError(\"`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.\")\n\n            self.fusing_vae = True\n            self.vae.fuse_qkv_projections()\n            self.vae.set_attn_processor(FusedAttnProcessor2_0())\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections\n    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):\n        \"\"\"Disable QKV projection fusion if enabled.\n\n        > [!WARNING]\n        > This API is 🧪 experimental.\n\n        Args:\n            unet (`bool`, defaults to `True`): To apply fusion on the UNet.\n            vae (`bool`, defaults to `True`): To apply fusion on the VAE.\n\n        \"\"\"\n        if unet:\n            if not self.fusing_unet:\n                logger.warning(\"The UNet was not initially fused for QKV projections. Doing nothing.\")\n            else:\n                self.unet.unfuse_qkv_projections()\n                self.fusing_unet = False\n\n        if vae:\n            if not self.fusing_vae:\n                logger.warning(\"The VAE was not initially fused for QKV projections. 
Doing nothing.\")\n            else:\n                self.vae.unfuse_qkv_projections()\n                self.fusing_vae = False\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    def _compute_max_attention_per_index(\n        self,\n        attention_maps: torch.Tensor,\n        indices_to_alter: List[int],\n        smooth_attentions: bool = False,\n        sigma: float = 0.5,\n        kernel_size: int = 3,\n        normalize_eot: bool = False,\n        bboxes: List[int] = None,\n        L: int = 1,\n        P: float = 0.2,\n    ) -> List[torch.Tensor]:\n        \"\"\"Computes the maximum attention value for each of the tokens we wish to alter.\"\"\"\n        last_idx = -1\n        if normalize_eot:\n            prompt = self.prompt\n            if isinstance(self.prompt, list):\n                prompt = self.prompt[0]\n            last_idx = len(self.tokenizer(prompt)[\"input_ids\"]) - 1\n        attention_for_text = attention_maps[:, :, 1:last_idx]\n        attention_for_text *= 100\n        attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)\n\n        # Shift indices since we removed the first token \"1:last_idx\"\n        indices_to_alter = [index - 1 for index in indices_to_alter]\n\n        # Extract the maximum values\n        max_indices_list_fg = []\n        max_indices_list_bg = []\n        dist_x = []\n        dist_y = []\n\n        cnt = 0\n        for i in indices_to_alter:\n            image = attention_for_text[:, :, i]\n\n            # TODO\n            # box = [max(round(b / (512 / image.shape[0])), 0) for b in bboxes[cnt]]\n            # x1, y1, x2, y2 = box\n            H, W = image.shape\n            x1 = 
min(max(round(bboxes[cnt][0] * W), 0), W)\n            y1 = min(max(round(bboxes[cnt][1] * H), 0), H)\n            x2 = min(max(round(bboxes[cnt][2] * W), 0), W)\n            y2 = min(max(round(bboxes[cnt][3] * H), 0), H)\n            box = [x1, y1, x2, y2]\n            cnt += 1\n\n            # coordinates to masks\n            obj_mask = torch.zeros_like(image)\n            ones_mask = torch.ones([y2 - y1, x2 - x1], dtype=obj_mask.dtype).to(obj_mask.device)\n            obj_mask[y1:y2, x1:x2] = ones_mask\n            bg_mask = 1 - obj_mask\n\n            if smooth_attentions:\n                smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(image.device)\n                input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode=\"reflect\")\n                image = smoothing(input).squeeze(0).squeeze(0)\n\n            # Inner-Box constraint\n            k = (obj_mask.sum() * P).long()\n            max_indices_list_fg.append((image * obj_mask).reshape(-1).topk(k)[0].mean())\n\n            # Outer-Box constraint\n            k = (bg_mask.sum() * P).long()\n            max_indices_list_bg.append((image * bg_mask).reshape(-1).topk(k)[0].mean())\n\n            # Corner Constraint\n            gt_proj_x = torch.max(obj_mask, dim=0)[0]\n            gt_proj_y = torch.max(obj_mask, dim=1)[0]\n            corner_mask_x = torch.zeros_like(gt_proj_x)\n            corner_mask_y = torch.zeros_like(gt_proj_y)\n\n            # create gt according to the number config.L\n            N = gt_proj_x.shape[0]\n            corner_mask_x[max(box[0] - L, 0) : min(box[0] + L + 1, N)] = 1.0\n            corner_mask_x[max(box[2] - L, 0) : min(box[2] + L + 1, N)] = 1.0\n            corner_mask_y[max(box[1] - L, 0) : min(box[1] + L + 1, N)] = 1.0\n            corner_mask_y[max(box[3] - L, 0) : min(box[3] + L + 1, N)] = 1.0\n            dist_x.append((F.l1_loss(image.max(dim=0)[0], gt_proj_x, reduction=\"none\") * corner_mask_x).mean())\n       
     dist_y.append((F.l1_loss(image.max(dim=1)[0], gt_proj_y, reduction=\"none\") * corner_mask_y).mean())\n\n        return max_indices_list_fg, max_indices_list_bg, dist_x, dist_y\n\n    def _aggregate_and_get_max_attention_per_token(\n        self,\n        attention_store: AttentionStore,\n        indices_to_alter: List[int],\n        attention_res: int = 16,\n        smooth_attentions: bool = False,\n        sigma: float = 0.5,\n        kernel_size: int = 3,\n        normalize_eot: bool = False,\n        bboxes: List[int] = None,\n        L: int = 1,\n        P: float = 0.2,\n    ):\n        \"\"\"Aggregates the attention for each token and computes the max activation value for each token to alter.\"\"\"\n        attention_maps = aggregate_attention(\n            attention_store=attention_store,\n            res=attention_res,\n            from_where=(\"up\", \"down\", \"mid\"),\n            is_cross=True,\n            select=0,\n        )\n        max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y = self._compute_max_attention_per_index(\n            attention_maps=attention_maps,\n            indices_to_alter=indices_to_alter,\n            smooth_attentions=smooth_attentions,\n            sigma=sigma,\n            kernel_size=kernel_size,\n            normalize_eot=normalize_eot,\n            bboxes=bboxes,\n            L=L,\n            P=P,\n        )\n        return max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y\n\n    @staticmethod\n    def _compute_loss(\n        max_attention_per_index_fg: List[torch.Tensor],\n        max_attention_per_index_bg: List[torch.Tensor],\n        dist_x: List[torch.Tensor],\n        dist_y: List[torch.Tensor],\n        return_losses: bool = False,\n    ) -> torch.Tensor:\n        \"\"\"Computes the attend-and-excite loss using the maximum attention value for each token.\"\"\"\n        losses_fg = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index_fg]\n        losses_bg 
= [max(0, curr_max) for curr_max in max_attention_per_index_bg]\n        loss = sum(losses_fg) + sum(losses_bg) + sum(dist_x) + sum(dist_y)\n        if return_losses:\n            return max(losses_fg), losses_fg\n        else:\n            return max(losses_fg), loss\n\n    @staticmethod\n    def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:\n        \"\"\"Update the latent according to the computed loss.\"\"\"\n        grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0]\n        latents = latents - step_size * grad_cond\n        return latents\n\n    def _perform_iterative_refinement_step(\n        self,\n        latents: torch.Tensor,\n        indices_to_alter: List[int],\n        loss_fg: torch.Tensor,\n        threshold: float,\n        text_embeddings: torch.Tensor,\n        text_input,\n        attention_store: AttentionStore,\n        step_size: float,\n        t: int,\n        attention_res: int = 16,\n        smooth_attentions: bool = True,\n        sigma: float = 0.5,\n        kernel_size: int = 3,\n        max_refinement_steps: int = 20,\n        normalize_eot: bool = False,\n        bboxes: List[int] = None,\n        L: int = 1,\n        P: float = 0.2,\n    ):\n        \"\"\"\n        Performs the iterative latent refinement introduced in the paper. 
Here, we continuously update the latent\n        code according to our loss objective until the given threshold is reached for all tokens.\n        \"\"\"\n        iteration = 0\n        target_loss = max(0, 1.0 - threshold)\n\n        while loss_fg > target_loss:\n            iteration += 1\n\n            latents = latents.clone().detach().requires_grad_(True)\n            # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample\n            self.unet.zero_grad()\n\n            # Get max activation value for each subject token\n            (\n                max_attention_per_index_fg,\n                max_attention_per_index_bg,\n                dist_x,\n                dist_y,\n            ) = self._aggregate_and_get_max_attention_per_token(\n                attention_store=attention_store,\n                indices_to_alter=indices_to_alter,\n                attention_res=attention_res,\n                smooth_attentions=smooth_attentions,\n                sigma=sigma,\n                kernel_size=kernel_size,\n                normalize_eot=normalize_eot,\n                bboxes=bboxes,\n                L=L,\n                P=P,\n            )\n\n            loss_fg, losses_fg = self._compute_loss(\n                max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True\n            )\n\n            if loss_fg != 0:\n                latents = self._update_latent(latents, loss_fg, step_size)\n\n            # with torch.no_grad():\n            #     noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[0].unsqueeze(0)).sample\n            #     noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample\n\n            # try:\n            #     low_token = np.argmax([l.item() if not isinstance(l, int) else l for l in losses_fg])\n            # except Exception as e:\n            #     print(e)  # catch edge case :)\n 
           #     low_token = np.argmax(losses_fg)\n\n            # low_word = self.tokenizer.decode(text_input.input_ids[0][indices_to_alter[low_token]])\n            # print(f'\\t Try {iteration}. {low_word} has a max attention of {max_attention_per_index_fg[low_token]}')\n\n            if iteration >= max_refinement_steps:\n                # print(f'\\t Exceeded max number of iterations ({max_refinement_steps})! '\n                #       f'Finished with a max attention of {max_attention_per_index_fg[low_token]}')\n                break\n\n        # Run one more time but don't compute gradients and update the latents.\n        # We just need to compute the new loss - the grad update will occur below\n        latents = latents.clone().detach().requires_grad_(True)\n        # noise_pred_text = self.unet(latents, t, encoder_hidden_states=text_embeddings[1].unsqueeze(0)).sample\n        self.unet.zero_grad()\n\n        # Get max activation value for each subject token\n        (\n            max_attention_per_index_fg,\n            max_attention_per_index_bg,\n            dist_x,\n            dist_y,\n        ) = self._aggregate_and_get_max_attention_per_token(\n            attention_store=attention_store,\n            indices_to_alter=indices_to_alter,\n            attention_res=attention_res,\n            smooth_attentions=smooth_attentions,\n            sigma=sigma,\n            kernel_size=kernel_size,\n            normalize_eot=normalize_eot,\n            bboxes=bboxes,\n            L=L,\n            P=P,\n        )\n        loss_fg, losses_fg = self._compute_loss(\n            max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y, return_losses=True\n        )\n        # print(f\"\\t Finished with loss of: {loss_fg}\")\n        return loss_fg, latents, max_attention_per_index_fg\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        
boxdiff_phrases: List[str] = None,\n        boxdiff_boxes: List[List[float]] = None,  # TODO\n        boxdiff_kwargs: Optional[Dict[str, Any]] = {\n            \"attention_res\": 16,\n            \"P\": 0.2,\n            \"L\": 1,\n            \"max_iter_to_alter\": 25,\n            \"loss_thresholds\": {0: 0.05, 10: 0.5, 20: 0.8},\n            \"scale_factor\": 20,\n            \"scale_range\": (1.0, 0.5),\n            \"smooth_attentions\": True,\n            \"sigma\": 0.5,\n            \"kernel_size\": 3,\n            \"refine\": False,\n            \"normalize_eot\": True,\n        },\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`.\n\n            boxdiff_attention_res (`int`, *optional*, defaults to 16):\n                The resolution of the attention maps used for computing the BoxDiff loss.\n            boxdiff_P (`float`, *optional*, defaults to 0.2):\n                The fraction of the highest attention values inside (and outside) each box that is averaged in the\n                Inner-Box and Outer-Box constraints of the BoxDiff loss.\n            boxdiff_L (`int`, *optional*, defaults to 1):\n                The number of pixels around each box corner to be selected in the BoxDiff loss.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when\n                using zero terminal SNR.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # -1. Register attention control (for BoxDiff)\n        attention_store = AttentionStore()\n        register_attention_control(self, attention_store)\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n        # to deal with lora scaling and other possible forward hooks\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            boxdiff_phrases,\n            boxdiff_boxes,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n        self.prompt = prompt\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        text_inputs, prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None:\n            output_hidden_state = False if 
isinstance(self.unet.encoder_hid_proj, ImageProjection) else True\n            image_embeds, negative_image_embeds = self.encode_image(\n                ip_adapter_image, device, num_images_per_prompt, output_hidden_state\n            )\n            if self.do_classifier_free_guidance:\n                image_embeds = torch.cat([negative_image_embeds, image_embeds])\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 6.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = {\"image_embeds\": image_embeds} if ip_adapter_image is not None else None\n\n        # 6.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 6.3 Prepare BoxDiff inputs\n        # a) Indices to alter\n        input_ids = self.tokenizer(prompt)[\"input_ids\"]\n        decoded = [self.tokenizer.decode([t]) for t in input_ids]\n        indices_to_alter = []\n        bboxes = []\n        for phrase, box in 
zip(boxdiff_phrases, boxdiff_boxes):\n            # a phrase may not correspond to a single token; such phrases are skipped\n            if phrase not in decoded:\n                continue\n            indices_to_alter.append(decoded.index(phrase))\n            bboxes.append(box)\n\n        # b) A bunch of hyperparameters\n        attention_res = boxdiff_kwargs.get(\"attention_res\", 16)\n        smooth_attentions = boxdiff_kwargs.get(\"smooth_attentions\", True)\n        sigma = boxdiff_kwargs.get(\"sigma\", 0.5)\n        kernel_size = boxdiff_kwargs.get(\"kernel_size\", 3)\n        L = boxdiff_kwargs.get(\"L\", 1)\n        P = boxdiff_kwargs.get(\"P\", 0.2)\n        thresholds = boxdiff_kwargs.get(\"loss_thresholds\", {0: 0.05, 10: 0.5, 20: 0.8})\n        max_iter_to_alter = boxdiff_kwargs.get(\"max_iter_to_alter\", len(self.scheduler.timesteps) + 1)\n        scale_factor = boxdiff_kwargs.get(\"scale_factor\", 20)\n        refine = boxdiff_kwargs.get(\"refine\", False)\n        normalize_eot = boxdiff_kwargs.get(\"normalize_eot\", True)\n\n        scale_range = boxdiff_kwargs.get(\"scale_range\", (1.0, 0.5))\n        scale_range = np.linspace(scale_range[0], scale_range[1], len(self.scheduler.timesteps))\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # BoxDiff optimization\n                with torch.enable_grad():\n                    latents = latents.clone().detach().requires_grad_(True)\n\n                    # Forward pass of denoising with text conditioning\n                    noise_pred_text = self.unet(\n                        latents,\n                        t,\n                        encoder_hidden_states=prompt_embeds[1].unsqueeze(0),\n                        cross_attention_kwargs=cross_attention_kwargs,\n                    ).sample\n                    self.unet.zero_grad()\n\n                    # Get max activation value for each subject token\n                    (\n                        max_attention_per_index_fg,\n                        max_attention_per_index_bg,\n                        dist_x,\n                        dist_y,\n                    ) = self._aggregate_and_get_max_attention_per_token(\n                        attention_store=attention_store,\n                        indices_to_alter=indices_to_alter,\n                        attention_res=attention_res,\n                        smooth_attentions=smooth_attentions,\n                        sigma=sigma,\n                        kernel_size=kernel_size,\n                        normalize_eot=normalize_eot,\n                        bboxes=bboxes,\n                        L=L,\n                        P=P,\n                    )\n\n                    loss_fg, loss = self._compute_loss(\n                        max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y\n                    )\n\n                    # Refinement from attend-and-excite (not 
necessary)\n                    if refine and i in thresholds.keys() and loss_fg > 1.0 - thresholds[i]:\n                        del noise_pred_text\n                        torch.cuda.empty_cache()\n                        loss_fg, latents, max_attention_per_index_fg = self._perform_iterative_refinement_step(\n                            latents=latents,\n                            indices_to_alter=indices_to_alter,\n                            loss_fg=loss_fg,\n                            threshold=thresholds[i],\n                            text_embeddings=prompt_embeds,\n                            text_input=text_inputs,\n                            attention_store=attention_store,\n                            step_size=scale_factor * np.sqrt(scale_range[i]),\n                            t=t,\n                            attention_res=attention_res,\n                            smooth_attentions=smooth_attentions,\n                            sigma=sigma,\n                            kernel_size=kernel_size,\n                            normalize_eot=normalize_eot,\n                            bboxes=bboxes,\n                            L=L,\n                            P=P,\n                        )\n\n                    # Perform gradient update\n                    if i < max_iter_to_alter:\n                        _, loss = self._compute_loss(\n                            max_attention_per_index_fg, max_attention_per_index_bg, dist_x, dist_y\n                        )\n                        if loss != 0:\n                            latents = self._update_latent(\n                                latents=latents, loss=loss, step_size=scale_factor * np.sqrt(scale_range[i])\n                            )\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = 
self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i 
+ 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_pag.py",
    "content": "# Implementation of StableDiffusionPipeline with PAG\n# https://ku-cvlab.github.io/Perturbed-Attention-Guidance\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.attention_processor import Attention, AttnProcessor2_0, FusedAttnProcessor2_0\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import StableDiffusionPipeline\n        >>> pipe = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\n        >>> pipe = pipe.to(\"cuda\")\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\n        >>> image = pipe(prompt).images[0]\n        
```\n\"\"\"\n\n\nclass PAGIdentitySelfAttnProcessor:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        temb: Optional[torch.Tensor] = None,\n        *args,\n        **kwargs,\n    ) -> torch.Tensor:\n        if len(args) > 0 or kwargs.get(\"scale\", None) is not None:\n            deprecation_message = \"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.\"\n            deprecate(\"scale\", \"1.0.0\", deprecation_message)\n\n        residual = hidden_states\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        # chunk\n        hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)\n\n        # original path\n        batch_size, sequence_length, _ = hidden_states_org.shape\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n       
     attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states_org)\n        key = attn.to_k(hidden_states_org)\n        value = attn.to_v(hidden_states_org)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states_org = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states_org = hidden_states_org.to(query.dtype)\n\n        # linear proj\n        hidden_states_org = attn.to_out[0](hidden_states_org)\n        # dropout\n        hidden_states_org = attn.to_out[1](hidden_states_org)\n\n        if input_ndim == 4:\n            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        # perturbed path (identity attention)\n        batch_size, sequence_length, _ = hidden_states_ptb.shape\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, 
attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)\n\n        value = attn.to_v(hidden_states_ptb)\n\n        # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())\n        hidden_states_ptb = value\n\n        hidden_states_ptb = hidden_states_ptb.to(query.dtype)\n\n        # linear proj\n        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)\n        # dropout\n        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)\n\n        if input_ndim == 4:\n            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        # cat\n        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass PAGCFGIdentitySelfAttnProcessor:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        temb: Optional[torch.Tensor] = None,\n        *args,\n        **kwargs,\n    ) -> torch.Tensor:\n        if len(args) > 0 or kwargs.get(\"scale\", None) is not None:\n            deprecation_message = \"The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. 
`scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.\"\n            deprecate(\"scale\", \"1.0.0\", deprecation_message)\n\n        residual = hidden_states\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        # chunk\n        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)\n        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])\n\n        # original path\n        batch_size, sequence_length, _ = hidden_states_org.shape\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states_org)\n        key = attn.to_k(hidden_states_org)\n        value = attn.to_v(hidden_states_org)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 
2.1\n        hidden_states_org = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states_org = hidden_states_org.to(query.dtype)\n\n        # linear proj\n        hidden_states_org = attn.to_out[0](hidden_states_org)\n        # dropout\n        hidden_states_org = attn.to_out[1](hidden_states_org)\n\n        if input_ndim == 4:\n            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        # perturbed path (identity attention)\n        batch_size, sequence_length, _ = hidden_states_ptb.shape\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)\n\n        value = attn.to_v(hidden_states_ptb)\n        hidden_states_ptb = value\n        hidden_states_ptb = hidden_states_ptb.to(query.dtype)\n\n        # linear proj\n        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)\n        # dropout\n        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)\n\n        if input_ndim == 4:\n            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        # cat\n        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / 
attn.rescale_output_factor\n\n        return hidden_states\n\n\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. 
If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusionPAGPipeline(\n    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`.\"\n        deprecate(\n            \"enable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`.\"\n        deprecate(\n            \"disable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. 
Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n        Args:\n            
prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if self.do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            image_embeds = ip_adapter_image_embeds\n        return image_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    def 
decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible 
by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot pass both `ip_adapter_image` and `ip_adapter_image_embeds`.\"\n            )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):\n        r\"\"\"Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.\n        The suffixes after the scaling factors represent the stages where they are being applied.\n        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values\n        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.\n        Args:\n            s1 (`float`):\n                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            s2 (`float`):\n                Scaling factor for stage 2 to attenuate the contributions of the skip features. 
This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.\n            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.\n        \"\"\"\n        if not hasattr(self, \"unet\"):\n            raise ValueError(\"The pipeline must have `unet` for using FreeU.\")\n        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)\n\n    def disable_freeu(self):\n        \"\"\"Disables the FreeU mechanism if enabled.\"\"\"\n        self.unet.disable_freeu()\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections\n    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):\n        \"\"\"\n        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,\n        key, value) are fused. 
For cross-attention modules, key and value projection matrices are fused.\n        > [!WARNING]\n        > This API is 🧪 experimental.\n        Args:\n            unet (`bool`, defaults to `True`): To apply fusion on the UNet.\n            vae (`bool`, defaults to `True`): To apply fusion on the VAE.\n        \"\"\"\n        self.fusing_unet = False\n        self.fusing_vae = False\n\n        if unet:\n            self.fusing_unet = True\n            self.unet.fuse_qkv_projections()\n            self.unet.set_attn_processor(FusedAttnProcessor2_0())\n\n        if vae:\n            if not isinstance(self.vae, AutoencoderKL):\n                raise ValueError(\"`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.\")\n\n            self.fusing_vae = True\n            self.vae.fuse_qkv_projections()\n            self.vae.set_attn_processor(FusedAttnProcessor2_0())\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections\n    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):\n        \"\"\"Disable QKV projection fusion if enabled.\n        > [!WARNING]\n        > This API is 🧪 experimental.\n        Args:\n            unet (`bool`, defaults to `True`): To apply fusion on the UNet.\n            vae (`bool`, defaults to `True`): To apply fusion on the VAE.\n        \"\"\"\n        if unet:\n            if not self.fusing_unet:\n                logger.warning(\"The UNet was not initially fused for QKV projections. Doing nothing.\")\n            else:\n                self.unet.unfuse_qkv_projections()\n                self.fusing_unet = False\n\n        if vae:\n            if not self.fusing_vae:\n                logger.warning(\"The VAE was not initially fused for QKV projections. 
Doing nothing.\")\n            else:\n                self.vae.unfuse_qkv_projections()\n                self.fusing_vae = False\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    def pred_z0(self, sample, model_output, timestep):\n        alpha_prod_t = self.scheduler.alphas_cumprod[timestep].to(sample.device)\n\n        beta_prod_t = 1 - alpha_prod_t\n        if self.scheduler.config.prediction_type == \"epsilon\":\n            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n        elif self.scheduler.config.prediction_type == \"sample\":\n            pred_original_sample = model_output\n        elif 
self.scheduler.config.prediction_type == \"v_prediction\":\n            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n            # predict V\n            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample\n        else:\n            raise ValueError(\n                f\"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,\"\n                \" or `v_prediction`\"\n            )\n\n        return pred_original_sample\n\n    def pred_x0(self, latents, noise_pred, t, generator, device, prompt_embeds, output_type):\n        pred_z0 = self.pred_z0(latents, noise_pred, t)\n        pred_x0 = self.vae.decode(pred_z0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]\n        pred_x0, ____ = self.run_safety_checker(pred_x0, device, prompt_embeds.dtype)\n        do_denormalize = [True] * pred_x0.shape[0]\n        pred_x0 = self.image_processor.postprocess(pred_x0, output_type=output_type, do_denormalize=do_denormalize)\n\n        return pred_x0\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @property\n    def pag_scale(self):\n        return self._pag_scale\n\n    @property\n    def do_perturbed_attention_guidance(self):\n        return self._pag_scale > 0\n\n    @property\n    def pag_adaptive_scaling(self):\n        return self._pag_adaptive_scaling\n\n    @property\n    def do_pag_adaptive_scaling(self):\n        return self._pag_adaptive_scaling > 0\n\n    @property\n    def pag_applied_layers_index(self):\n        return self._pag_applied_layers_index\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        pag_scale: float = 0.0,\n        pag_adaptive_scaling: float = 0.0,\n        pag_applied_layers_index: List[str] = [\"d4\"],  # ['d4', 'd5', 'm0']\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = 
None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. 
Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when\n                using zero terminal SNR.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. 
The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n        Examples:\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. 
Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n        # to deal with lora scaling and other possible forward hooks\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n\n        self._pag_scale = pag_scale\n        self._pag_adaptive_scaling = pag_adaptive_scaling\n        self._pag_applied_layers_index = pag_applied_layers_index\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n\n        # cfg\n        if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n        # pag\n        elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:\n            prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])\n        # both\n        elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt\n            )\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 6.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)\n            else None\n        )\n\n        # 6.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. 
Denoising loop\n        if self.do_perturbed_attention_guidance:\n            down_layers = []\n            mid_layers = []\n            up_layers = []\n            for name, module in self.unet.named_modules():\n                if \"attn1\" in name and \"to\" not in name:\n                    layer_type = name.split(\".\")[0].split(\"_\")[0]\n                    if layer_type == \"down\":\n                        down_layers.append(module)\n                    elif layer_type == \"mid\":\n                        mid_layers.append(module)\n                    elif layer_type == \"up\":\n                        up_layers.append(module)\n                    else:\n                        raise ValueError(f\"Invalid layer type: {layer_type}\")\n\n        # change attention layer in UNet if use PAG\n        if self.do_perturbed_attention_guidance:\n            if self.do_classifier_free_guidance:\n                replace_processor = PAGCFGIdentitySelfAttnProcessor()\n            else:\n                replace_processor = PAGIdentitySelfAttnProcessor()\n\n            drop_layers = self.pag_applied_layers_index\n            for drop_layer in drop_layers:\n                try:\n                    if drop_layer[0] == \"d\":\n                        down_layers[int(drop_layer[1])].processor = replace_processor\n                    elif drop_layer[0] == \"m\":\n                        mid_layers[int(drop_layer[1])].processor = replace_processor\n                    elif drop_layer[0] == \"u\":\n                        up_layers[int(drop_layer[1])].processor = replace_processor\n                    else:\n                        raise ValueError(f\"Invalid layer type: {drop_layer[0]}\")\n                except IndexError:\n                    raise ValueError(\n                        f\"Invalid layer index: {drop_layer}. 
Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers.\"\n                    )\n\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # cfg\n                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:\n                    latent_model_input = torch.cat([latents] * 2)\n                # pag\n                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:\n                    latent_model_input = torch.cat([latents] * 2)\n                # both\n                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:\n                    latent_model_input = torch.cat([latents] * 3)\n                # no\n                else:\n                    latent_model_input = latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n\n                # cfg\n                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n\n                    delta = noise_pred_text - noise_pred_uncond\n                  
  noise_pred = noise_pred_uncond + self.guidance_scale * delta\n\n                # pag\n                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:\n                    noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)\n\n                    signal_scale = self.pag_scale\n                    if self.do_pag_adaptive_scaling:\n                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)\n                        if signal_scale < 0:\n                            signal_scale = 0\n\n                    noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)\n\n                # both\n                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:\n                    noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)\n\n                    signal_scale = self.pag_scale\n                    if self.do_pag_adaptive_scaling:\n                        signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000 - t)\n                        if signal_scale < 0:\n                            signal_scale = 0\n\n                    noise_pred = (\n                        noise_pred_text\n                        + (self.guidance_scale - 1.0) * (noise_pred_text - noise_pred_uncond)\n                        + signal_scale * (noise_pred_text - noise_pred_text_perturb)\n                    )\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, 
output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        # change attention layer in UNet if use PAG\n        if self.do_perturbed_attention_guidance:\n            drop_layers = self.pag_applied_layers_index\n            for drop_layer in drop_layers:\n                try:\n                    if drop_layer[0] == \"d\":\n                        down_layers[int(drop_layer[1])].processor = AttnProcessor2_0()\n                    elif drop_layer[0] == \"m\":\n                        mid_layers[int(drop_layer[1])].processor = AttnProcessor2_0()\n                    elif drop_layer[0] == \"u\":\n                        up_layers[int(drop_layer[1])].processor = AttnProcessor2_0()\n                    else:\n                        raise ValueError(f\"Invalid layer type: {drop_layer[0]}\")\n                except IndexError:\n                    raise ValueError(\n                        f\"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers.\"\n                    )\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_upscale_ldm3d.py",
    "content": "# Copyright 2025 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport PIL\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.image_processor import PipelineDepthInput, PipelineImageInput, VaeImageProcessorLDM3D\nfrom diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\nfrom diffusers.pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d import LDM3DPipelineOutput\nfrom diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```python\n        >>> from diffusers import StableDiffusionUpscaleLDM3DPipeline\n        >>> from PIL import Image\n        >>> from io import BytesIO\n 
       >>> import requests\n\n        >>> pipe = StableDiffusionUpscaleLDM3DPipeline.from_pretrained(\"Intel/ldm3d-sr\")\n        >>> pipe = pipe.to(\"cuda\")\n        >>> rgb_path = \"https://huggingface.co/Intel/ldm3d-sr/resolve/main/lemons_ldm3d_rgb.jpg\"\n        >>> depth_path = \"https://huggingface.co/Intel/ldm3d-sr/resolve/main/lemons_ldm3d_depth.png\"\n        >>> low_res_rgb = Image.open(BytesIO(requests.get(rgb_path).content)).convert(\"RGB\")\n        >>> low_res_depth = Image.open(BytesIO(requests.get(depth_path).content)).convert(\"L\")\n        >>> output = pipe(\n        ...     prompt=\"high quality high resolution uhd 4k image\",\n        ...     rgb=low_res_rgb,\n        ...     depth=low_res_depth,\n        ...     num_inference_steps=50,\n        ...     target_res=[1024, 1024],\n        ... )\n        >>> rgb_image, depth_image = output.rgb, output.depth\n        >>> rgb_image[0].save(\"hr_ldm3d_rgb.jpg\")\n        >>> depth_image[0].save(\"hr_ldm3d_depth.png\")\n        ```\n\"\"\"\n\n\nclass StableDiffusionUpscaleLDM3DPipeline(\n    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin\n):\n    r\"\"\"\n    Pipeline for text-to-image and 3D generation using LDM3D.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        low_res_scheduler ([`SchedulerMixin`]):\n            A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of\n            [`DDPMScheduler`].\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        low_res_scheduler: DDPMScheduler,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n        watermarker: Optional[Any] = None,\n        max_noise_level: int = 350,\n    ):\n        super().__init__()\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. 
For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            low_res_scheduler=low_res_scheduler,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            watermarker=watermarker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample=\"bilinear\")\n        # self.register_to_config(requires_safety_checker=requires_safety_checker)\n        self.register_to_config(max_noise_level=max_noise_level)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. 
Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            rgb_feature_extractor_input = feature_extractor_input[0]\n            safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        noise_level,\n        callback_steps,\n        negative_prompt=None,\n        
prompt_embeds=None,\n        negative_prompt_embeds=None,\n        target_res=None,\n    ):\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if (\n            not isinstance(image, torch.Tensor)\n            and not isinstance(image, PIL.Image.Image)\n            and not isinstance(image, np.ndarray)\n            and not isinstance(image, list)\n        ):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}\"\n            )\n\n        # verify batch size of prompt and image are same if image is a list or tensor or numpy array\n        if isinstance(image, (list, np.ndarray, torch.Tensor)):\n            if prompt is not None and isinstance(prompt, str):\n                batch_size = 1\n            elif prompt is not None and isinstance(prompt, list):\n                batch_size = len(prompt)\n            else:\n                batch_size = prompt_embeds.shape[0]\n\n            if isinstance(image, list):\n                image_batch_size = len(image)\n            else:\n                image_batch_size = image.shape[0]\n            if batch_size != image_batch_size:\n                raise ValueError(\n                    f\"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}.\"\n                    \" Please make sure that passed `prompt` matches the batch size of `image`.\"\n                )\n\n        # check noise level\n        if noise_level > self.config.max_noise_level:\n            raise ValueError(f\"`noise_level` 
has to be <= {self.config.max_noise_level} but is {noise_level}\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height, width)\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            if latents.shape != shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {shape}\")\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        rgb: PipelineImageInput = None,\n        depth: PipelineDepthInput = None,\n        num_inference_steps: int = 75,\n        guidance_scale: float = 9.0,\n        noise_level: int = 20,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        target_res: Optional[List[int]] = [1024, 1024],\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            rgb (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image` or tensor representing an RGB image batch to be upscaled.\n            depth (`PipelineDepthInput`):\n                Depth map batch matching `rgb`. It is preprocessed together with `rgb` and concatenated to it along\n                the channel dimension before VAE encoding.\n            num_inference_steps (`int`, *optional*, defaults to 75):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 9.0):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that calls every `callback_steps` steps during inference. The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. 
If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            rgb,\n            noise_level,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n        )\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        # 4. Preprocess image\n        rgb, depth = self.image_processor.preprocess(rgb, depth, target_res=target_res)\n        rgb = rgb.to(dtype=prompt_embeds.dtype, device=device)\n        depth = depth.to(dtype=prompt_embeds.dtype, device=device)\n\n        # 5. set timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 6. 
Encode low resolution image to latent space\n        image = torch.cat([rgb, depth], axis=1)\n        latent_space_image = self.vae.encode(image).latent_dist.sample(generator)\n        latent_space_image *= self.vae.scaling_factor\n        noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)\n        # noise_rgb = randn_tensor(rgb.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)\n        # rgb = self.low_res_scheduler.add_noise(rgb, noise_rgb, noise_level)\n        # noise_depth = randn_tensor(depth.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)\n        # depth = self.low_res_scheduler.add_noise(depth, noise_depth, noise_level)\n\n        batch_multiplier = 2 if do_classifier_free_guidance else 1\n        latent_space_image = torch.cat([latent_space_image] * batch_multiplier * num_images_per_prompt)\n        noise_level = torch.cat([noise_level] * latent_space_image.shape[0])\n\n        # 7. Prepare latent variables\n        height, width = latent_space_image.shape[2:]\n        num_channels_latents = self.vae.config.latent_channels\n\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 8. Check that sizes of image and latents match\n        num_channels_image = latent_space_image.shape[1]\n        if num_channels_latents + num_channels_image != self.unet.config.in_channels:\n            raise ValueError(\n                f\"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects\"\n                f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                f\" `num_channels_image`: {num_channels_image} \"\n                f\" = {num_channels_latents + num_channels_image}. 
Please verify the config of\"\n                \" `pipeline.unet` or your `image` input.\"\n            )\n\n        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 10. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n                # concat latents, mask, masked_image_latents in the channel dimension\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n                latent_model_input = torch.cat([latent_model_input, latent_space_image], dim=1)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    class_labels=noise_level,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 
self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            image = self.vae.decode(latents / self.vae.scaling_factor, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # 11. Apply watermark\n        if output_type == \"pil\" and self.watermarker is not None:\n            rgb = self.watermarker.apply_watermark(rgb)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return ((rgb, depth), has_nsfw_concept)\n\n        return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom PIL import Image\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_invisible_watermark_available,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark 
import StableDiffusionXLWatermarker\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom einops import rearrange, repeat\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import DDIMScheduler, DiffusionPipeline\n        >>> from diffusers.utils import load_image\n        >>> import torch.nn.functional as F\n        >>> from torchvision.transforms.functional import to_tensor, gaussian_blur\n\n        >>> dtype = torch.float16\n        >>> device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n        >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n\n        >>> pipeline = DiffusionPipeline.from_pretrained(\n        ...    \"stabilityai/stable-diffusion-xl-base-1.0\",\n        ...    custom_pipeline=\"pipeline_stable_diffusion_xl_attentive_eraser\",\n        ...    scheduler=scheduler,\n        ...    variant=\"fp16\",\n        ...    use_safetensors=True,\n        ...    torch_dtype=dtype,\n        ... ).to(device)\n\n\n        >>> def preprocess_image(image_path, device):\n        ...     image = to_tensor((load_image(image_path)))\n        ...     image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]\n        ...     if image.shape[1] != 3:\n        ...         image = image.expand(-1, 3, -1, -1)\n        ...     image = F.interpolate(image, (1024, 1024))\n        ...     image = image.to(dtype).to(device)\n        ...     return image\n\n        >>> def preprocess_mask(mask_path, device):\n        ...     mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))\n        ...     
mask = mask.unsqueeze_(0).float()  # 0 or 1\n        ...     mask = F.interpolate(mask, (1024, 1024))\n        ...     mask = gaussian_blur(mask, kernel_size=(77, 77))\n        ...     mask[mask < 0.1] = 0\n        ...     mask[mask >= 0.1] = 1\n        ...     mask = mask.to(dtype).to(device)\n        ...     return mask\n\n        >>> prompt = \"\" # Set prompt to null\n        >>> seed=123\n        >>> generator = torch.Generator(device=device).manual_seed(seed)\n        >>> source_image_path = \"https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png\"\n        >>> mask_path = \"https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png\"\n        >>> source_image = preprocess_image(source_image_path, device)\n        >>> mask = preprocess_mask(mask_path, device)\n\n        >>> image = pipeline(\n        ...     prompt=prompt,\n        ...     image=source_image,\n        ...     mask_image=mask,\n        ...     height=1024,\n        ...     width=1024,\n        ...     AAS=True, # enable AAS\n        ...     strength=0.8, # inpainting strength\n        ...     rm_guidance_scale=9, # removal guidance scale\n        ...     ss_steps = 9, # similarity suppression steps\n        ...     ss_scale = 0.3, # similarity suppression scale\n        ...     AAS_start_step=0, # AAS start step\n        ...     AAS_start_layer=34, # AAS start layer\n        ...     AAS_end_layer=70, # AAS end layer\n        ...     num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)\n        ...     generator=generator,\n        ...     guidance_scale=1,\n        ... 
).images[0]\n        >>> image.save('./removed_img.png')\n        >>> print(\"Object removal completed\")\n        ```\n\"\"\"\n\n\nclass AttentionBase:\n    def __init__(self):\n        self.cur_step = 0\n        self.num_att_layers = -1\n        self.cur_att_layer = 0\n\n    def after_step(self):\n        pass\n\n    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n        self.cur_att_layer += 1\n        if self.cur_att_layer == self.num_att_layers:\n            self.cur_att_layer = 0\n            self.cur_step += 1\n            # after step\n            self.after_step()\n        return out\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        out = torch.einsum(\"b i j, b j d -> b i d\", attn, v)\n        out = rearrange(out, \"(b h) n d -> b n (h d)\", h=num_heads)\n        return out\n\n    def reset(self):\n        self.cur_step = 0\n        self.cur_att_layer = 0\n\n\nclass AAS_XL(AttentionBase):\n    MODEL_TYPE = {\"SD\": 16, \"SDXL\": 70}\n\n    def __init__(\n        self,\n        start_step=4,\n        end_step=50,\n        start_layer=10,\n        end_layer=16,\n        layer_idx=None,\n        step_idx=None,\n        total_steps=50,\n        mask=None,\n        model_type=\"SD\",\n        ss_steps=9,\n        ss_scale=1.0,\n    ):\n        \"\"\"\n        Args:\n            start_step: the step to start AAS\n            start_layer: the layer to start AAS\n            layer_idx: list of the layers to apply AAS\n            step_idx: list the steps to apply AAS\n            total_steps: the total number of steps\n            mask: source mask with shape (h, w)\n            model_type: the model type, SD or SDXL\n        \"\"\"\n        super().__init__()\n        self.total_steps = total_steps\n        self.total_layers = self.MODEL_TYPE.get(model_type, 16)\n        
self.start_step = start_step\n        self.end_step = end_step\n        self.start_layer = start_layer\n        self.end_layer = end_layer\n        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))\n        self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))\n        self.mask = mask  # mask with shape (1, 1 ,h, w)\n        self.ss_steps = ss_steps\n        self.ss_scale = ss_scale\n        self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()\n        self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()\n        self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()\n        self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()\n\n    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):\n        B = q.shape[0] // num_heads\n        if is_mask_attn:\n            mask_flatten = mask.flatten(0)\n            if self.cur_step <= self.ss_steps:\n                # background\n                sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)\n\n                # object\n                sim_fg = self.ss_scale * sim\n                sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)\n                sim = torch.cat([sim_fg, sim_bg], dim=0)\n            else:\n                sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)\n\n        attn = sim.softmax(-1)\n        if len(attn) == 2 * len(v):\n            v = torch.cat([v] * 2)\n        out = torch.einsum(\"h i j, h j d -> h i d\", attn, v)\n        out = rearrange(out, \"(h1 h) (b n) d -> (h1 b) n (h d)\", b=B, h=num_heads)\n        return out\n\n    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):\n        \"\"\"\n        
Attention forward function\n        \"\"\"\n        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:\n            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)\n        H = int(np.sqrt(q.shape[1]))\n        if H == 16:\n            mask = self.mask_16.to(sim.device)\n        elif H == 32:\n            mask = self.mask_32.to(sim.device)\n        elif H == 64:\n            mask = self.mask_64.to(sim.device)\n        else:\n            mask = self.mask_128.to(sim.device)\n\n        q_wo, q_w = q.chunk(2)\n        k_wo, k_w = k.chunk(2)\n        v_wo, v_w = v.chunk(2)\n        sim_wo, sim_w = sim.chunk(2)\n        attn_wo, attn_w = attn.chunk(2)\n\n        out_source = self.attn_batch(\n            q_wo,\n            k_wo,\n            v_wo,\n            sim_wo,\n            attn_wo,\n            is_cross,\n            place_in_unet,\n            num_heads,\n            is_mask_attn=False,\n            mask=None,\n            **kwargs,\n        )\n        out_target = self.attn_batch(\n            q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask=mask, **kwargs\n        )\n\n        if self.mask is not None:\n            if out_target.shape[0] == 2:\n                out_target_fg, out_target_bg = out_target.chunk(2, 0)\n                mask = mask.reshape(-1, 1)  # (hw, 1)\n                out_target = out_target_fg * mask + out_target_bg * (1 - mask)\n            else:\n                out_target = out_target\n\n        out = torch.cat([out_source, out_target], dim=0)\n        return out\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. 
Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\ndef mask_pil_to_torch(mask, height, width):\n    # preprocess mask\n    if isinstance(mask, (PIL.Image.Image, np.ndarray)):\n        mask = [mask]\n\n    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):\n        mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]\n        mask = np.concatenate([np.array(m.convert(\"L\"))[None, None, :] for m in mask], axis=0)\n        mask = mask.astype(np.float32) / 255.0\n    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):\n        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)\n\n    mask = torch.from_numpy(mask)\n    return mask\n\n\ndef prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):\n    \"\"\"\n    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be\n    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the\n    ``image`` and ``1`` for the ``mask``.\n\n    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. 
The ``mask`` will be\n    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.\n\n    Args:\n        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``\n            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.\n        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``\n            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.\n\n\n    Raises:\n        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask\n        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.\n        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not\n            (ot the other way around).\n\n    Returns:\n        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4\n            dimensions: ``batch x channels x height x width``.\n    \"\"\"\n\n    if image is None:\n        raise ValueError(\"`image` input cannot be undefined.\")\n\n    if mask is None:\n        raise ValueError(\"`mask_image` input cannot be undefined.\")\n\n    if isinstance(image, torch.Tensor):\n        if not isinstance(mask, torch.Tensor):\n            mask = mask_pil_to_torch(mask, height, width)\n\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        # Batch and add channel dim for single mask\n        if mask.ndim == 2:\n            mask = mask.unsqueeze(0).unsqueeze(0)\n\n        # Batch single mask or add channel dim\n        if mask.ndim == 3:\n            # Single batched mask, no channel dim or single mask not batched but channel dim\n            if mask.shape[0] == 1:\n                mask = mask.unsqueeze(0)\n\n    
        # Batched masks no channel dim\n            else:\n                mask = mask.unsqueeze(1)\n\n        assert image.ndim == 4 and mask.ndim == 4, \"Image and Mask must have 4 dimensions\"\n        # assert image.shape[-2:] == mask.shape[-2:], \"Image and Mask must have the same spatial dimensions\"\n        assert image.shape[0] == mask.shape[0], \"Image and Mask must have the same batch size\"\n\n        # Check image is in [-1, 1]\n        # if image.min() < -1 or image.max() > 1:\n        #    raise ValueError(\"Image should be in [-1, 1] range\")\n\n        # Check mask is in [0, 1]\n        if mask.min() < 0 or mask.max() > 1:\n            raise ValueError(\"Mask should be in [0, 1] range\")\n\n        # Binarize mask\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n\n        # Image as float32\n        image = image.to(dtype=torch.float32)\n    elif isinstance(mask, torch.Tensor):\n        raise TypeError(f\"`mask` is a torch.Tensor but `image` (type: {type(image)} is not\")\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            # resize all images w.r.t passed height an width\n            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n        mask = mask_pil_to_torch(mask, height, width)\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n\n    if image.shape[1] == 4:\n        # images are in latent space and thus can't\n        # be 
masked, so set masked_image to None\n        # we assume that the checkpoint is not an inpainting\n        # checkpoint. TODO(Yiyi) - need to clean this up later\n        masked_image = None\n    else:\n        masked_image = image * (mask < 0.5)\n\n    # n.b. ensure backwards compatibility as old function does not return image\n    if return_image:\n        return mask, masked_image, image\n\n    return mask, masked_image\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. 
If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusionXL_AE_Pipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    FromSingleFileMixin,\n    IPAdapterMixin,\n):\n    r\"\"\"\n    Pipeline for object removal using Stable Diffusion XL.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion XL uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`CLIPTextModelWithProjection`]):\n            Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        requires_aesthetics_score (`bool`, *optional*, defaults to `False`):\n            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the\n            config of `stabilityai/stable-diffusion-xl-refiner-1-0`.\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of\n            `stabilityai/stable-diffusion-xl-base-1-0`.\n        add_watermarker (`bool`, *optional*):\n            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to\n            watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no\n            watermarker will be used.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->image_encoder->unet->vae\"\n\n    _optional_components = [\n        \"tokenizer\",\n        \"tokenizer_2\",\n        \"text_encoder\",\n        \"text_encoder_2\",\n        \"image_encoder\",\n        \"feature_extractor\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"add_neg_time_ids\",\n        \"mask\",\n        \"masked_image_latents\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n        requires_aesthetics_score: bool = False,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n            scheduler=scheduler,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n      
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True\n        )\n\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            repeat_dims = [1]\n            image_embeds = []\n            for single_image_embeds in ip_adapter_image_embeds:\n                if 
do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                    single_negative_image_embeds = single_negative_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))\n                    )\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                else:\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                image_embeds.append(single_image_embeds)\n\n        return image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 
(`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder, lora_scale)\n\n            if self.text_encoder_2 is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt is not None:\n            batch_size = len(prompt)\n      
  else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n\n                prompt_embeds = text_encoder(text_input_ids.to(device), 
output_hidden_states=True)\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                pooled_prompt_embeds = prompt_embeds[0]\n                if clip_skip is None:\n                    prompt_embeds = prompt_embeds.hidden_states[-2]\n                else:\n                    # \"2\" because SDXL always indexes from the penultimate layer.\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n      
              f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        if self.text_encoder_2 is not None:\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        else:\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, 
using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            if self.text_encoder_2 is not None:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            else:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        image,\n        mask_image,\n        height,\n        width,\n        strength,\n        callback_steps,\n        output_type,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        padding_mask_crop=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer 
but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n        if padding_mask_crop is not None:\n            if not isinstance(image, PIL.Image.Image):\n                raise ValueError(\n                    f\"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}.\"\n                )\n            if not isinstance(mask_image, PIL.Image.Image):\n                raise ValueError(\n                    f\"The mask image should be a PIL image when inpainting mask crop, but is of type\"\n                    f\" {type(mask_image)}.\"\n                )\n            if output_type != \"pil\":\n                raise ValueError(f\"The output type should be PIL when inpainting mask crop, but is {output_type}.\")\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. 
Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n        image=None,\n        timestep=None,\n        is_strength_max=True,\n        add_noise=True,\n        return_noise=False,\n        return_image_latents=False,\n    ):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if (image is None or timestep is None) and not is_strength_max:\n            raise ValueError(\n                \"Since strength < 1. 
initial latents are to be initialised as a combination of Image + Noise.\"\n                \"However, either the image or the noise timestep has not been provided.\"\n            )\n\n        if image.shape[1] == 4:\n            image_latents = image.to(device=device, dtype=dtype)\n            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n        elif return_image_latents or (latents is None and not is_strength_max):\n            image = image.to(device=device, dtype=dtype)\n            image_latents = self._encode_vae_image(image=image, generator=generator)\n            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n\n        if latents is None and add_noise:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # if strength is 1. then initialise the latents to noise, else initial to image + noise\n            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)\n            # if pure noise then scale the initial latents by the  Scheduler's init sigma\n            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents\n        elif add_noise:\n            noise = latents.to(device)\n            latents = noise * self.scheduler.init_noise_sigma\n        else:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            latents = image_latents.to(device)\n\n        outputs = (latents,)\n\n        if return_noise:\n            outputs += (noise,)\n\n        if return_image_latents:\n            outputs += (image_latents,)\n\n        return outputs\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        dtype = image.dtype\n        if self.vae.config.force_upcast:\n            image = image.float()\n            self.vae.to(dtype=torch.float32)\n\n        if isinstance(generator, list):\n      
      image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        if self.vae.config.force_upcast:\n            self.vae.to(dtype)\n\n        image_latents = image_latents.to(dtype)\n        image_latents = self.vae.config.scaling_factor * image_latents\n\n        return image_latents\n\n    def prepare_mask_latents(\n        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance\n    ):\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        # mask = torch.nn.functional.interpolate(\n        #    mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)\n        # )\n        mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()\n        mask = mask.to(device=device, dtype=dtype)\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. 
Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n\n        if masked_image is not None and masked_image.shape[1] == 4:\n            masked_image_latents = masked_image\n        else:\n            masked_image_latents = None\n\n        if masked_image is not None:\n            if masked_image_latents is None:\n                masked_image = masked_image.to(device=device, dtype=dtype)\n                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)\n\n            if masked_image_latents.shape[0] < batch_size:\n                if not batch_size % masked_image_latents.shape[0] == 0:\n                    raise ValueError(\n                        \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\n                        f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                        \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                    )\n                masked_image_latents = masked_image_latents.repeat(\n                    batch_size // masked_image_latents.shape[0], 1, 1, 1\n                )\n\n            masked_image_latents = (\n                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n            )\n\n            # aligning device to prevent device errors when concating it with the latent model input\n            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n\n        return mask, masked_image_latents\n\n    # Copied from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):\n        # get the original timestep using init_timestep\n        if denoising_start is None:\n            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n            t_start = max(num_inference_steps - init_timestep, 0)\n        else:\n            t_start = 0\n\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        # Strength is irrelevant if we directly request a timestep to start at;\n        # that is, strength is determined by the denoising_start instead.\n        if denoising_start is not None:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_start * self.scheduler.config.num_train_timesteps)\n                )\n            )\n\n            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()\n            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:\n                # if the scheduler is a 2nd order scheduler we might have to do +1\n                # because `num_inference_steps` might be even given that every timestep\n                # (except the highest one) is duplicated. If `num_inference_steps` is even it would\n                # mean that we cut the timesteps in the middle of the denoising step\n                # (between 1st and 2nd devirative) which leads to incorrect results. 
By adding 1\n                # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler\n                num_inference_steps = num_inference_steps + 1\n\n            # because t_n+1 >= t_n, we slice the timesteps starting from the end\n            timesteps = timesteps[-num_inference_steps:]\n            return timesteps, num_inference_steps\n\n        return timesteps, num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids\n    def _get_add_time_ids(\n        self,\n        original_size,\n        crops_coords_top_left,\n        target_size,\n        aesthetic_score,\n        negative_aesthetic_score,\n        negative_original_size,\n        negative_crops_coords_top_left,\n        negative_target_size,\n        dtype,\n        text_encoder_projection_dim=None,\n    ):\n        if self.config.requires_aesthetics_score:\n            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))\n            add_neg_time_ids = list(\n                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)\n            )\n        else:\n            add_time_ids = list(original_size + crops_coords_top_left + target_size)\n            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if (\n            expected_add_embed_dim > passed_add_embed_dim\n            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length 
{expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.\"\n            )\n        elif (\n            expected_add_embed_dim < passed_add_embed_dim\n            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.\"\n            )\n        elif expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)\n\n        return add_time_ids, add_neg_time_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    @property\n    def do_self_attention_redirection_guidance(self):  # SARG\n        return self._rm_guidance_scale > 1 and self._AAS\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 
. `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return (\n            self._guidance_scale > 1\n            and self.unet.config.time_cond_proj_dim is None\n            and not self.do_self_attention_redirection_guidance\n        )  # CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def denoising_start(self):\n        return self._denoising_start\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    def image2latent(self, image: torch.Tensor, generator: torch.Generator):\n        DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n        if type(image) is Image:\n            image = np.array(image)\n            image = torch.from_numpy(image).float() / 127.5 - 1\n            image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)\n        # input image density range [-1, 1]\n        # latents = self.vae.encode(image)['latent_dist'].mean\n        latents = self._encode_vae_image(image, generator)\n        # latents = retrieve_latents(self.vae.encode(image))\n        # latents = latents * self.vae.config.scaling_factor\n        return latents\n\n    def next_step(self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0.0, verbose=False):\n        \"\"\"\n        Inverse sampling for DDIM Inversion\n        \"\"\"\n        if verbose:\n            print(\"timestep: \", timestep)\n        next_step = timestep\n        timestep = min(timestep - self.scheduler.config.num_train_timesteps // 
self.scheduler.num_inference_steps, 999)\n        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod\n        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]\n        beta_prod_t = 1 - alpha_prod_t\n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output\n        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir\n        return x_next, pred_x0\n\n    @torch.no_grad()\n    def invert(\n        self,\n        image: torch.Tensor,\n        prompt,\n        num_inference_steps=50,\n        eta=0.0,\n        original_size: Tuple[int, int] = None,\n        target_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        return_intermediates=False,\n        **kwds,\n    ):\n        \"\"\"\n        Invert a real image into a noise map with deterministic DDIM inversion.\n        \"\"\"\n        DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n        batch_size = image.shape[0]\n        if isinstance(prompt, list):\n            if batch_size == 1:\n                image = image.expand(len(prompt), -1, -1, -1)\n        elif isinstance(prompt, str):\n            if batch_size > 1:\n                prompt = [prompt] * batch_size\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        prompt_2 = prompt\n        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n        # textual inversion: process multi-vector 
tokens if necessary\n        prompt_embeds_list = []\n        prompts = [prompt, prompt_2]\n        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n            text_inputs = tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)\n\n            # We are only ALWAYS interested in the pooled output of the final text encoder\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n            prompt_embeds_list.append(prompt_embeds)\n        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)\n\n        # define initial latents\n        latents = self.image2latent(image, generator=None)\n\n        start_latents = latents\n        height, width = latents.shape[-2:]\n        height = height * self.vae_scale_factor\n        width = width * 
self.vae_scale_factor\n\n        original_size = (height, width)\n        target_size = (height, width)\n        negative_original_size = original_size\n        negative_target_size = target_size\n\n        add_text_embeds = pooled_prompt_embeds\n        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            negative_original_size,\n            negative_crops_coords_top_left,\n            negative_target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n\n        add_time_ids = add_time_ids.repeat(batch_size, 1).to(DEVICE)\n\n        # interactive sampling\n        self.scheduler.set_timesteps(num_inference_steps)\n        latents_list = [latents]\n        pred_x0_list = []\n        # for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc=\"DDIM Inversion\")):\n        for i, t in enumerate(reversed(self.scheduler.timesteps)):\n            model_inputs = latents\n\n            # predict the noise\n            added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n            noise_pred = self.unet(\n                model_inputs, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs\n            ).sample\n\n            # compute the previous noise sample x_t-1 -> x_t\n            latents, pred_x0 = self.next_step(noise_pred, t, latents)\n            \"\"\"\n            if t >= 1 and t < 41:\n                latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)\n            else:\n                latents, pred_x0 = self.next_step(noise_pred, t, latents) \"\"\"\n\n            latents_list.append(latents)\n            pred_x0_list.append(pred_x0)\n\n        if 
return_intermediates:\n            # return the intermediate latents during inversion\n            # pred_x0_list = [self.latent2image(img, return_type=\"np\") for img in pred_x0_list]\n            # latents_list = [self.latent2image(img, return_type=\"np\") for img in latents_list]\n            return latents, latents_list, pred_x0_list\n        return latents, start_latents\n\n    def opt(\n        self,\n        model_output: torch.FloatTensor,\n        timestep: int,\n        x: torch.FloatTensor,\n    ):\n        \"\"\"\n        Predict the sample at the next step of the denoising process.\n        \"\"\"\n        ref_noise = model_output[:1, :, :, :].expand(model_output.shape)\n        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]\n        beta_prod_t = 1 - alpha_prod_t\n        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5\n        x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t) ** 0.5 * ref_noise\n        return x_opt, pred_x0\n\n    def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):\n        \"\"\"\n        Register an attention editor to the Diffusers pipeline, adapted from [Prompt-to-Prompt]\n        \"\"\"\n\n        def ca_forward(self, place_in_unet):\n            def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):\n                \"\"\"\n                The attention is similar to the original implementation of the LDM CrossAttention class,\n                except for some modifications to the attention\n                \"\"\"\n                if encoder_hidden_states is not None:\n                    context = encoder_hidden_states\n                if attention_mask is not None:\n                    mask = attention_mask\n\n                to_out = self.to_out\n                if isinstance(to_out, nn.modules.container.ModuleList):\n                    to_out = self.to_out[0]\n                else:\n                    to_out = self.to_out\n\n                h = 
self.heads\n                q = self.to_q(x)\n                is_cross = context is not None\n                context = context if is_cross else x\n                k = self.to_k(context)\n                v = self.to_v(context)\n                # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))\n                q, k, v = (rearrange(t, \"b n (h d) -> (b h) n d\", h=h) for t in (q, k, v))\n\n                sim = torch.einsum(\"b i d, b j d -> b i j\", q, k) * self.scale\n\n                if mask is not None:\n                    mask = rearrange(mask, \"b ... -> b (...)\")\n                    max_neg_value = -torch.finfo(sim.dtype).max\n                    mask = repeat(mask, \"b j -> (b h) () j\", h=h)\n                    mask = mask[:, None, :].repeat(h, 1, 1)\n                    sim.masked_fill_(~mask, max_neg_value)\n\n                attn = sim.softmax(dim=-1)\n                # the only difference\n                out = editor(q, k, v, sim, attn, is_cross, place_in_unet, self.heads, scale=self.scale)\n\n                return to_out(out)\n\n            return forward\n\n        def register_editor(net, count, place_in_unet):\n            for name, subnet in net.named_children():\n                if net.__class__.__name__ == \"Attention\":  # spatial Transformer layer\n                    net.forward = ca_forward(net, place_in_unet)\n                    return count + 1\n                elif hasattr(net, \"children\"):\n                    count = register_editor(subnet, count, place_in_unet)\n            return count\n\n        cross_att_count = 0\n        for net_name, net in unet.named_children():\n            if \"down\" in net_name:\n                cross_att_count += register_editor(net, 0, \"down\")\n            elif \"mid\" in net_name:\n                cross_att_count += register_editor(net, 0, \"mid\")\n            elif \"up\" in net_name:\n                cross_att_count += register_editor(net, 0, \"up\")\n        
editor.num_att_layers = cross_att_count\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: PipelineImageInput = None,\n        mask_image: PipelineImageInput = None,\n        masked_image_latents: torch.FloatTensor = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        padding_mask_crop: Optional[int] = None,\n        strength: float = 0.9999,\n        AAS: bool = True,  # AE parameter\n        rm_guidance_scale: float = 7.0,  # AE parameter\n        ss_steps: int = 9,  # AE parameter\n        ss_scale: float = 0.3,  # AE parameter\n        AAS_start_step: int = 0,  # AE parameter\n        AAS_start_layer: int = 34,  # AE parameter\n        AAS_end_layer: int = 70,  # AE parameter\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: 
Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            mask_image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. 
If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            padding_mask_crop (`int`, *optional*, defaults to `None`):\n                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If\n                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that\n                contains all masked areas, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on\n                the expanded area before resizing to the original image size for inpainting. 
This is useful when the masked area is small while the image is large\n                and contains information irrelevant for inpainting, such as background.\n            strength (`float`, *optional*, defaults to 0.9999):\n                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be\n                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the\n                `strength`. The number of denoising steps depends on the amount of noise initially added. When\n                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of\n                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked\n                portion of the reference `image`. Note that when `denoising_start` is specified, the value of\n                `strength` will be ignored.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            denoising_start (`float`, *optional*):\n                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be\n                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and\n                it is assumed that the passed `image` is a partly denoised image. 
Note that when this is specified,\n                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline\n                is integrated into a \"Mixture of Denoisers\" multi-pipeline setup, as detailed in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be\n                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the\n                final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline\n                forms a part of a \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of IP-Adapters.\n                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding\n                if `do_classifier_free_guidance` is set to `True`.\n                If not provided, embeddings are computed from the `ip_adapter_image` input argument.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. 
Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            aesthetic_score (`float`, *optional*, defaults to 6.0):\n                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to\n                simulate an aesthetic score of the generated image by influencing the negative text condition.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. 
Check inputs\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            image,\n            mask_image,\n            height,\n            width,\n            strength,\n            callback_steps,\n            output_type,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n            padding_mask_crop,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._denoising_start = denoising_start\n        self._interrupt = False\n\n        ########### AE parameters\n        self._num_timesteps = num_inference_steps\n        self._rm_guidance_scale = rm_guidance_scale\n        self._AAS = AAS\n        self._ss_steps = ss_steps\n        self._ss_scale = ss_scale\n        self._AAS_start_step = AAS_start_step\n        self._AAS_start_layer = AAS_start_layer\n        self._AAS_end_layer = AAS_end_layer\n        ###########\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 4. set timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps,\n            strength,\n            device,\n            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,\n        )\n        # check that number of inference steps is not < 1 - as this doesn't make sense\n        if num_inference_steps < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n            )\n        # at which timestep to set the initial noise (n.b. 
50% if strength is 0.5)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise\n        is_strength_max = strength == 1.0\n\n        # 5. Preprocess mask and image\n        if padding_mask_crop is not None:\n            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)\n            resize_mode = \"fill\"\n        else:\n            crops_coords = None\n            resize_mode = \"default\"\n\n        original_image = image\n        init_image = self.image_processor.preprocess(\n            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode\n        )\n        init_image = init_image.to(dtype=torch.float32)\n\n        mask = self.mask_processor.preprocess(\n            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords\n        )\n\n        if masked_image_latents is not None:\n            masked_image = masked_image_latents\n        elif init_image.shape[1] == 4:\n            # if images are in latent space, we can't mask it\n            masked_image = None\n        else:\n            masked_image = init_image * (mask < 0.5)\n\n        # 6. 
Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        num_channels_unet = self.unet.config.in_channels\n        return_image_latents = num_channels_unet == 4\n\n        add_noise = True if self.denoising_start is None else False\n        latents_outputs = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n            image=init_image,\n            timestep=latent_timestep,\n            is_strength_max=is_strength_max,\n            add_noise=add_noise,\n            return_noise=True,\n            return_image_latents=return_image_latents,\n        )\n\n        if return_image_latents:\n            latents, noise, image_latents = latents_outputs\n        else:\n            latents, noise = latents_outputs\n\n        # 7. Prepare mask latent variables\n        mask, masked_image_latents = self.prepare_mask_latents(\n            mask,\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            self.do_classifier_free_guidance,\n        )\n\n        # 8. Check that sizes of mask, masked image and latents match\n        if num_channels_unet == 9:\n            # default case for stable-diffusion-v1-5/stable-diffusion-inpainting\n            num_channels_mask = mask.shape[1]\n            num_channels_masked_image = masked_image_latents.shape[1]\n            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:\n                raise ValueError(\n                    f\"Incorrect configuration settings! 
The config of `pipeline.unet`: {self.unet.config} expects\"\n                    f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                    f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                    f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                    \" `pipeline.unet` or your `mask_image` or `image` input.\"\n                )\n        elif num_channels_unet != 4:\n            raise ValueError(\n                f\"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.\"\n            )\n        # 8.1 Prepare extra step kwargs.\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        height, width = latents.shape[-2:]\n        height = height * self.vae_scale_factor\n        width = width * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 10. 
Prepare added time ids & embeddings\n        if negative_original_size is None:\n            negative_original_size = original_size\n        if negative_target_size is None:\n            negative_target_size = target_size\n\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            negative_original_size,\n            negative_crops_coords_top_left,\n            negative_target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)\n\n        ###########\n        if self.do_self_attention_redirection_guidance:\n            prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)\n            add_neg_time_ids = add_neg_time_ids.repeat(2, 1)\n            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)\n        ############\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = 
add_time_ids.to(device)\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # apply AAS to modify the attention module\n        if self.do_self_attention_redirection_guidance:\n            self._AAS_end_step = int(strength * self._num_timesteps)\n            layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))\n            editor = AAS_XL(\n                self._AAS_start_step,\n                self._AAS_end_step,\n                self._AAS_start_layer,\n                self._AAS_end_layer,\n                layer_idx=layer_idx,\n                mask=mask_image,\n                model_type=\"SDXL\",\n                ss_steps=self._ss_steps,\n                ss_scale=self._ss_scale,\n            )\n            self.regiter_attention_editor_diffusers(self.unet, editor)\n\n        # 11. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        if (\n            self.denoising_end is not None\n            and self.denoising_start is not None\n            and denoising_value_valid(self.denoising_end)\n            and denoising_value_valid(self.denoising_start)\n            and self.denoising_start >= self.denoising_end\n        ):\n            raise ValueError(\n                f\"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: \"\n                + f\" {self.denoising_end} when using type float.\"\n            )\n        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 11.1 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = 
torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                # removal guidance\n                latent_model_input = (\n                    torch.cat([latents] * 2) if self.do_self_attention_redirection_guidance else latents\n                )  # CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not\n                # latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents\n\n                # concat latents, mask, masked_image_latents in the channel dimension\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n                # latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)\n\n                if num_channels_unet == 9:\n                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                    added_cond_kwargs[\"image_embeds\"] = image_embeds\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform SARG\n                if self.do_self_attention_redirection_guidance:\n                    noise_pred_wo, noise_pred_w = noise_pred.chunk(2)\n                    delta = noise_pred_w - noise_pred_wo\n                    noise_pred = noise_pred_wo + self._rm_guidance_scale 
* delta\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if num_channels_unet == 4:\n                    init_latents_proper = image_latents\n                    if self.do_classifier_free_guidance:\n                        init_mask, _ = mask.chunk(2)\n                    else:\n                        init_mask = mask\n\n                    if i < len(timesteps) - 1:\n                        noise_timestep = timesteps[i + 1]\n                        init_latents_proper = self.scheduler.add_noise(\n                            init_latents_proper, noise, torch.tensor([noise_timestep])\n                        )\n\n                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = 
callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                    add_neg_time_ids = callback_outputs.pop(\"add_neg_time_ids\", add_neg_time_ids)\n                    mask = callback_outputs.pop(\"mask\", mask)\n                    masked_image_latents = callback_outputs.pop(\"masked_image_latents\", masked_image_latents)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n            latents = latents[-1:]\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not 
None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            return StableDiffusionXLPipelineOutput(images=latents)\n\n        # apply watermark if available\n        if self.watermark is not None:\n            image = self.watermark.apply_watermark(image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        if padding_mask_crop is not None:\n            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py",
    "content": "# Copyright 2025 TencentARC and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    
Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import ControlNetModel, T2IAdapter, StableDiffusionXLAdapterPipeline, DDPMScheduler\n        >>> from diffusers.utils import load_image\n        >>> from controlnet_aux.midas import MidasDetector\n\n        >>> img_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\n        >>> mask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\n        >>> image = load_image(img_url).resize((1024, 1024))\n        >>> mask_image = load_image(mask_url).resize((1024, 1024))\n\n        >>> midas_depth = MidasDetector.from_pretrained(\n        ...    \"valhalla/t2iadapter-aux-models\", filename=\"dpt_large_384.pt\", model_type=\"dpt_large\"\n        ... ).to(\"cuda\")\n\n        >>> depth_image = midas_depth(\n        ...    image, detect_resolution=512, image_resolution=1024\n        ... )\n\n        >>> model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n\n        >>> adapter = T2IAdapter.from_pretrained(\n        ...     \"Adapter/t2iadapter\",\n        ...     subfolder=\"sketch_sdxl_1.0\",\n        ...     torch_dtype=torch.float16,\n        ...     adapter_type=\"full_adapter_xl\",\n        ... )\n\n        >>> controlnet = ControlNetModel.from_pretrained(\n        ...    \"diffusers/controlnet-depth-sdxl-1.0\",\n        ...    torch_dtype=torch.float16,\n        ...    variant=\"fp16\",\n        ...    use_safetensors=True\n        ... ).to(\"cuda\")\n\n        >>> scheduler = DDPMScheduler.from_pretrained(model_id, subfolder=\"scheduler\")\n\n        >>> pipe = StableDiffusionXLAdapterPipeline.from_pretrained(\n        ...     model_id,\n        ...     adapter=adapter,\n        ...     controlnet=controlnet,\n        ...     torch_dtype=torch.float16,\n        ...     variant=\"fp16\",\n        ...     scheduler=scheduler\n        ... 
).to(\"cuda\")\n\n        >>> strength = 0.5\n\n        >>> generator = torch.manual_seed(42)\n        >>> sketch_image_out = pipe(\n        ...     prompt=\"a photo of a tiger sitting on a park bench\",\n        ...     negative_prompt=\"extra digit, fewer digits, cropped, worst quality, low quality\",\n        ...     adapter_image=depth_image,\n        ...     control_image=mask_image,\n        ...     adapter_conditioning_scale=strength,\n        ...     controlnet_conditioning_scale=strength,\n        ...     generator=generator,\n        ...     guidance_scale=7.5,\n        ... ).images[0]\n        ```\n\"\"\"\n\n\ndef _preprocess_adapter_image(image, height, width):\n    if isinstance(image, torch.Tensor):\n        return image\n    elif isinstance(image, PIL.Image.Image):\n        image = [image]\n\n    if isinstance(image[0], PIL.Image.Image):\n        image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION[\"lanczos\"])) for i in image]\n        image = [\n            i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image\n        ]  # expand [h, w] or [h, w, c] to [b, h, w, c]\n        image = np.concatenate(image, axis=0)\n        image = np.array(image).astype(np.float32) / 255.0\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image)\n    elif isinstance(image[0], torch.Tensor):\n        if image[0].ndim == 3:\n            image = torch.stack(image, dim=0)\n        elif image[0].ndim == 4:\n            image = torch.cat(image, dim=0)\n        else:\n            raise ValueError(\n                f\"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but receive: {image[0].ndim}\"\n            )\n    return image\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. 
Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\nclass StableDiffusionXLControlNetAdapterPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    FromSingleFileMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter\n    https://huggingface.co/papers/2302.08453\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):\n            Provides additional conditioning to the unet during the denoising process. 
If you set multiple Adapters as a\n            list, the outputs from each Adapter are added together to create one combined additional conditioning.\n        adapter_weights (`List[float]`, *optional*, defaults to None):\n            List of floats representing the weights by which each adapter's output is multiplied before adding them\n            together.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->unet->vae\"\n    _optional_components = [\"tokenizer\", \"tokenizer_2\", \"text_encoder\", \"text_encoder_2\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],\n        controlnet: Union[ControlNetModel, MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        force_zeros_for_empty_prompt: bool = True,\n    ):\n        super().__init__()\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            adapter=adapter,\n            controlnet=controlnet,\n            scheduler=scheduler,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.default_sample_size = (\n            self.unet.config.sample_size\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\n            else 128\n        )\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder, lora_scale)\n\n            if self.text_encoder_2 is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if 
self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n\n                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\n                    pooled_prompt_embeds = prompt_embeds[0]\n\n                if clip_skip is None:\n     
               prompt_embeds = prompt_embeds.hidden_states[-2]\n                else:\n                    # \"2\" because SDXL always indexes from the penultimate layer.\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        if self.text_encoder_2 is not None:\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        else:\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            if self.text_encoder_2 is not None:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            else:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL 
images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n          
  raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. 
Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n    def check_conditions(\n        self,\n        prompt,\n        prompt_embeds,\n        adapter_image,\n        control_image,\n        adapter_conditioning_scale,\n        controlnet_conditioning_scale,\n        control_guidance_start,\n        control_guidance_end,\n    ):\n        # controlnet checks\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. 
Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n        # Check controlnet `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(control_image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(control_image, list):\n                raise TypeError(\"For multiple controlnets: `control_image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in control_image):\n                raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif len(control_image) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(self.controlnet.nets)} ControlNets.\"\n                )\n\n            for image_ in control_image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n              
  )\n        else:\n            assert False\n\n        # adapter checks\n        if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):\n            self.check_image(adapter_image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)\n        ):\n            if not isinstance(adapter_image, list):\n                raise TypeError(\"For multiple adapters: `adapter_image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in adapter_image):\n                raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif len(adapter_image) != len(self.adapter.adapters):\n                raise ValueError(\n                    f\"For multiple adapters: `adapter_image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(self.adapter.adapters)} Adapters.\"\n                )\n\n            for image_ in adapter_image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `adapter_conditioning_scale`\n        if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):\n            if not isinstance(adapter_conditioning_scale, float):\n                raise TypeError(\"For single adapter: `adapter_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)\n        ):\n            if isinstance(adapter_conditioning_scale, list):\n                if any(isinstance(i, list) for i in adapter_conditioning_scale):\n                    raise ValueError(\"A 
single batch of multiple conditionings are supported at the moment.\")\n            elif isinstance(adapter_conditioning_scale, list) and len(adapter_conditioning_scale) != len(\n                self.adapter.adapters\n            ):\n                raise ValueError(\n                    \"For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of adapters\"\n                )\n        else:\n            assert False\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids\n    def _get_add_time_ids(\n        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None\n    ):\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        return add_time_ids\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width\n    def _default_height_width(self, height, width, image):\n        # NOTE: It is possible that a list of images have different\n        # dimensions for each image, so just checking the first image\n        # is not _exactly_ correct, but it is simple.\n        while isinstance(image, list):\n            image = image[0]\n\n        if height is None:\n            if isinstance(image, PIL.Image.Image):\n                height = image.height\n            elif isinstance(image, torch.Tensor):\n                height = image.shape[-2]\n\n            # round down to nearest multiple of `self.adapter.downscale_factor`\n            height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor\n\n        if width is None:\n            if isinstance(image, PIL.Image.Image):\n                width = image.width\n            elif isinstance(image, torch.Tensor):\n                width = image.shape[-1]\n\n            # round down to nearest multiple of `self.adapter.downscale_factor`\n            width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor\n\n        return height, width\n\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        
image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        adapter_image: PipelineImageInput = None,\n        control_image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: 
Optional[Tuple[int, int]] = None,\n        adapter_conditioning_scale: Union[float, List[float]] = 1.0,\n        adapter_conditioning_factor: float = 1.0,\n        clip_skip: Optional[int] = None,\n        controlnet_conditioning_scale=1.0,\n        guess_mode: bool = False,\n        control_guidance_start: float = 0.0,\n        control_guidance_end: float = 1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            adapter_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):\n                The Adapter input condition. Adapter uses this input condition to generate guidance for the UNet. If the\n                type is specified as `torch.Tensor`, it is passed to Adapter as is. `PIL.Image.Image` can also be\n                accepted as an image. The control image is automatically resized to fit the output image.\n            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be\n                accepted as an image. The dimensions of the output image default to `image`'s dimensions. 
If height\n                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in\n                `init`, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single ControlNet.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`]\n                instead of a plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16. 
of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. 
Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the\n                residual in the original unet. 
If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list.\n            adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the\n                residual in the original unet. If multiple adapters are specified in init, you can set the\n                corresponding scale as a list.\n            adapter_conditioning_factor (`float`, *optional*, defaults to 1.0):\n                The fraction of timesteps for which adapter should be applied. If `adapter_conditioning_factor` is\n                `0.0`, adapter is not applied at all. If `adapter_conditioning_factor` is `1.0`, adapter is applied for\n                all timesteps. If `adapter_conditioning_factor` is `0.5`, adapter is applied for half of the timesteps.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n        adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter\n\n        # 0. 
Default height and width to unet\n\n        height, width = self._default_height_width(height, width, adapter_image)\n        device = self._execution_device\n\n        if isinstance(adapter, MultiAdapter):\n            adapter_input = []\n\n            for one_image in adapter_image:\n                one_image = _preprocess_adapter_image(one_image, height, width)\n                one_image = one_image.to(device=device, dtype=adapter.dtype)\n                adapter_input.append(one_image)\n        else:\n            adapter_input = _preprocess_adapter_image(adapter_image, height, width)\n            adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 0.1 align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n        if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float):\n            adapter_conditioning_scale = [adapter_conditioning_scale] * 
len(adapter.adapters)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            callback_steps,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n        )\n\n        self.check_conditions(\n            prompt,\n            prompt_embeds,\n            adapter_image,\n            control_image,\n            adapter_conditioning_scale,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            clip_skip=clip_skip,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. 
Prepare added time ids & embeddings & adapter features\n        if isinstance(adapter, MultiAdapter):\n            adapter_state = adapter(adapter_input, adapter_conditioning_scale)\n            for k, v in enumerate(adapter_state):\n                adapter_state[k] = v\n        else:\n            adapter_state = adapter(adapter_input)\n            for k, v in enumerate(adapter_state):\n                adapter_state[k] = v * adapter_conditioning_scale\n        if num_images_per_prompt > 1:\n            for k, v in enumerate(adapter_state):\n                adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)\n        if do_classifier_free_guidance:\n            for k, v in enumerate(adapter_state):\n                adapter_state[k] = torch.cat([v] * 2, dim=0)\n\n        # 7.2 Prepare control images\n        if isinstance(controlnet, ControlNetModel):\n            control_image = self.prepare_control_image(\n                image=control_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n        elif isinstance(controlnet, MultiControlNetModel):\n            control_images = []\n\n            for control_image_ in control_image:\n                control_image_ = self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                    
guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            control_image = control_images\n        else:\n            raise ValueError(f\"{controlnet.__class__} is not supported.\")\n\n        # 8.2 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            if isinstance(self.controlnet, MultiControlNetModel):\n                controlnet_keep.append(keeps)\n            else:\n                controlnet_keep.append(keeps[0])\n\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n         
   add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        # 8. Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 7.1 Apply denoising_end\n        if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n                if i < int(num_inference_steps * adapter_conditioning_factor):\n                    down_intrablock_additional_residuals = [state.clone() for state in adapter_state]\n                else:\n                    down_intrablock_additional_residuals = None\n\n                # ----------- ControlNet\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input_controlnet = torch.cat([latents] * 2) 
if do_classifier_free_guidance else latents\n\n                # concat latents, mask, masked_image_latents in the channel dimension\n                latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t)\n\n                # controlnet(s) inference\n                if guess_mode and do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                    controlnet_added_cond_kwargs = {\n                        \"text_embeds\": add_text_embeds.chunk(2)[1],\n                        \"time_ids\": add_time_ids.chunk(2)[1],\n                    }\n                else:\n                    control_model_input = latent_model_input_controlnet\n                    controlnet_prompt_embeds = prompt_embeds\n                    controlnet_added_cond_kwargs = added_cond_kwargs\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=control_image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    
added_cond_kwargs=controlnet_added_cond_kwargs,\n                    return_dict=False,\n                )\n\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # t2iadapter\n                    down_block_additional_residuals=down_block_res_samples,  # controlnet\n                    mid_block_additional_residual=mid_block_res_sample,  # controlnet\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n            return StableDiffusionXLPipelineOutput(images=image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py",
    "content": "# Copyright 2025 Jake Babbidge, TencentARC and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# ignore the entire file for precommit\n# type: ignore\n\nimport inspect\nfrom collections.abc import Callable\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL\nimport torch\nimport torch.nn.functional as F\nfrom transformers import (\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n)\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import (\n    AutoencoderKL,\n    ControlNetModel,\n    MultiAdapter,\n    T2IAdapter,\n    UNet2DConditionModel,\n)\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom 
diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import ControlNetModel, DiffusionPipeline, T2IAdapter\n        >>> from diffusers.utils import load_image\n        >>> from PIL import Image\n        >>> from controlnet_aux.midas import MidasDetector\n\n        >>> adapter = T2IAdapter.from_pretrained(\n        ...     \"TencentARC/t2i-adapter-sketch-sdxl-1.0\", torch_dtype=torch.float16, variant=\"fp16\"\n        ... ).to(\"cuda\")\n\n        >>> controlnet = ControlNetModel.from_pretrained(\n        ...    \"diffusers/controlnet-depth-sdxl-1.0\",\n        ...    torch_dtype=torch.float16,\n        ...    variant=\"fp16\",\n        ...    use_safetensors=True\n        ... ).to(\"cuda\")\n\n        >>> pipe = DiffusionPipeline.from_pretrained(\n        ...     \"diffusers/stable-diffusion-xl-1.0-inpainting-0.1\",\n        ...     torch_dtype=torch.float16,\n        ...     variant=\"fp16\",\n        ...     use_safetensors=True,\n        ...     custom_pipeline=\"stable_diffusion_xl_adapter_controlnet_inpaint\",\n        ...     adapter=adapter,\n        ...     controlnet=controlnet,\n        ... ).to(\"cuda\")\n\n        >>> prompt = \"a tiger sitting on a park bench\"\n        >>> img_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\n        >>> mask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n\n        >>> image = load_image(img_url).resize((1024, 1024))\n        >>> mask_image = load_image(mask_url).resize((1024, 1024))\n\n        >>> midas_depth = MidasDetector.from_pretrained(\n        ...    
\"valhalla/t2iadapter-aux-models\", filename=\"dpt_large_384.pt\", model_type=\"dpt_large\"\n        ... ).to(\"cuda\")\n\n        >>> depth_image = midas_depth(\n        ...    image, detect_resolution=512, image_resolution=1024\n        ... )\n\n        >>> strength = 0.4\n\n        >>> generator = torch.manual_seed(42)\n\n        >>> result_image = pipe(\n        ...     image=image,\n        ...     mask_image=mask_image,\n        ...     adapter_image=depth_image,\n        ...     control_image=depth_image,\n        ...     controlnet_conditioning_scale=strength,\n        ...     adapter_conditioning_scale=strength,\n        ...     strength=0.7,\n        ...     generator=generator,\n        ...     prompt=prompt,\n        ...     negative_prompt=\"extra digit, fewer digits, cropped, worst quality, low quality\",\n        ...     num_inference_steps=50,\n        ... ).images[0]\n        ```\n\"\"\"\n\n\ndef _preprocess_adapter_image(image, height, width):\n    if isinstance(image, torch.Tensor):\n        return image\n    elif isinstance(image, PIL.Image.Image):\n        image = [image]\n\n    if isinstance(image[0], PIL.Image.Image):\n        image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION[\"lanczos\"])) for i in image]\n        image = [\n            i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image\n        ]  # expand [h, w] or [h, w, c] to [b, h, w, c]\n        image = np.concatenate(image, axis=0)\n        image = np.array(image).astype(np.float32) / 255.0\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image)\n    elif isinstance(image[0], torch.Tensor):\n        if image[0].ndim == 3:\n            image = torch.stack(image, dim=0)\n        elif image[0].ndim == 4:\n            image = torch.cat(image, dim=0)\n        else:\n            raise ValueError(\n                f\"Invalid image tensor! 
Expecting image tensor with 3 or 4 dimension, but receive: {image[0].ndim}\"\n            )\n    return image\n\n\ndef mask_pil_to_torch(mask, height, width):\n    # preprocess mask\n    if isinstance(mask, Union[PIL.Image.Image, np.ndarray]):\n        mask = [mask]\n\n    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):\n        mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]\n        mask = np.concatenate([np.array(m.convert(\"L\"))[None, None, :] for m in mask], axis=0)\n        mask = mask.astype(np.float32) / 255.0\n    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):\n        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)\n\n    mask = torch.from_numpy(mask)\n    return mask\n\n\ndef prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):\n    \"\"\"\n    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be\n    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the\n    ``image`` and ``1`` for the ``mask``.\n\n    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be\n    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.\n\n    Args:\n        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``\n            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.\n        mask (_type_): The mask to apply to the image, i.e. 
regions to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``\n            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.\n\n\n    Raises:\n        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask\n        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.\n        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not\n            (or the other way around).\n\n    Returns:\n        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4\n            dimensions: ``batch x channels x height x width``.\n    \"\"\"\n\n    # TODO(Yiyi) - need to clean this up later\n    if image is None:\n        raise ValueError(\"`image` input cannot be undefined.\")\n\n    if mask is None:\n        raise ValueError(\"`mask_image` input cannot be undefined.\")\n\n    if isinstance(image, torch.Tensor):\n        if not isinstance(mask, torch.Tensor):\n            mask = mask_pil_to_torch(mask, height, width)\n\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        # Batch and add channel dim for single mask\n        if mask.ndim == 2:\n            mask = mask.unsqueeze(0).unsqueeze(0)\n\n        # Batch single mask or add channel dim\n        if mask.ndim == 3:\n            # Single batched mask, no channel dim or single mask not batched but channel dim\n            if mask.shape[0] == 1:\n                mask = mask.unsqueeze(0)\n\n            # Batched masks no channel dim\n            else:\n                mask = mask.unsqueeze(1)\n\n        assert image.ndim == 4 and mask.ndim == 4, \"Image and Mask must have 4 dimensions\"\n        # assert image.shape[-2:] == mask.shape[-2:], \"Image and Mask must have the same spatial dimensions\"\n        assert image.shape[0] == mask.shape[0], \"Image and Mask must have the 
same batch size\"\n\n        # Check image is in [-1, 1]\n        # if image.min() < -1 or image.max() > 1:\n        #    raise ValueError(\"Image should be in [-1, 1] range\")\n\n        # Check mask is in [0, 1]\n        if mask.min() < 0 or mask.max() > 1:\n            raise ValueError(\"Mask should be in [0, 1] range\")\n\n        # Binarize mask\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n\n        # Image as float32\n        image = image.to(dtype=torch.float32)\n    elif isinstance(mask, torch.Tensor):\n        raise TypeError(f\"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not\")\n    else:\n        # preprocess image\n        if isinstance(image, Union[PIL.Image.Image, np.ndarray]):\n            image = [image]\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            # resize all images to the passed height and width\n            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n        mask = mask_pil_to_torch(mask, height, width)\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n\n    if image.shape[1] == 4:\n        # images are in latent space and thus can't\n        # be masked set masked_image to None\n        # we assume that the checkpoint is not an inpainting\n        # checkpoint. #TODO(Yiyi) - need to clean this up later\n        masked_image = None\n    else:\n        masked_image = image * (mask < 0.5)\n\n    # n.b. 
ensure backwards compatibility as old function does not return image\n    if return_image:\n        return mask, masked_image, image\n\n    return mask, masked_image\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\nclass StableDiffusionXLControlNetAdapterInpaintPipeline(\n    DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter\n    https://huggingface.co/papers/2302.08453\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):\n            Provides additional conditioning to the unet during the denoising process. 
If you set multiple Adapters as a\n            list, the outputs from each Adapter are added together to create one combined additional conditioning.\n        adapter_weights (`List[float]`, *optional*, defaults to None):\n            List of floats giving the weight by which each adapter's output is multiplied before adding them\n            together.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n        requires_aesthetics_score (`bool`, *optional*, defaults to `False`):\n            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config\n            of `stabilityai/stable-diffusion-xl-refiner-1-0`.\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. 
Also see the config of\n            `stabilityai/stable-diffusion-xl-base-1-0`.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        adapter: Union[T2IAdapter, MultiAdapter],\n        controlnet: Union[ControlNetModel, MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        requires_aesthetics_score: bool = False,\n        force_zeros_for_empty_prompt: bool = True,\n    ):\n        super().__init__()\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            adapter=adapter,\n            controlnet=controlnet,\n            scheduler=scheduler,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.default_sample_size = (\n            self.unet.config.sample_size\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\n            else 128\n        )\n\n    # Copied from 
diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder, lora_scale)\n\n            if self.text_encoder_2 is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n                if isinstance(self, 
TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n\n                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\n                    pooled_prompt_embeds = prompt_embeds[0]\n\n                if clip_skip is None:\n                    prompt_embeds = prompt_embeds.hidden_states[-2]\n                else:\n                    # \"2\" because SDXL always indexes from the penultimate layer.\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is 
None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        if self.text_encoder_2 is not None:\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        else:\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            if self.text_encoder_2 is not None:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            else:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL 
images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n          
  raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. 
Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n    def check_conditions(\n        self,\n        prompt,\n        prompt_embeds,\n        adapter_image,\n        control_image,\n        adapter_conditioning_scale,\n        controlnet_conditioning_scale,\n        control_guidance_start,\n        control_guidance_end,\n    ):\n        # controlnet checks\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. 
Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n        # Check controlnet `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(control_image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(control_image, list):\n                raise TypeError(\"For multiple controlnets: `control_image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in control_image):\n                raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif len(control_image) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(control_image)} images and {len(self.controlnet.nets)} ControlNets.\"\n                )\n\n            for image_ in control_image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n              
  )\n        else:\n            assert False\n\n        # adapter checks\n        if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):\n            self.check_image(adapter_image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)\n        ):\n            if not isinstance(adapter_image, list):\n                raise TypeError(\"For multiple adapters: `adapter_image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in adapter_image):\n                raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif len(adapter_image) != len(self.adapter.adapters):\n                raise ValueError(\n                    f\"For multiple adapters: `image` must have the same length as the number of adapters, but got {len(adapter_image)} images and {len(self.adapter.adapters)} Adapters.\"\n                )\n\n            for image_ in adapter_image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `adapter_conditioning_scale`\n        if isinstance(self.adapter, T2IAdapter) or is_compiled and isinstance(self.adapter._orig_mod, T2IAdapter):\n            if not isinstance(adapter_conditioning_scale, float):\n                raise TypeError(\"For single adapter: `adapter_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.adapter, MultiAdapter) or is_compiled and isinstance(self.adapter._orig_mod, MultiAdapter)\n        ):\n            if isinstance(adapter_conditioning_scale, list):\n                if any(isinstance(i, list) for i in adapter_conditioning_scale):\n                    raise ValueError(\"A 
single batch of multiple conditionings is not supported at the moment.\")\n            elif isinstance(adapter_conditioning_scale, list) and len(adapter_conditioning_scale) != len(\n                self.adapter.adapters\n            ):\n                raise ValueError(\n                    \"For multiple adapters: When `adapter_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of adapters\"\n                )\n        else:\n            assert False\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n        image=None,\n        timestep=None,\n        is_strength_max=True,\n        add_noise=True,\n        return_noise=False,\n        return_image_latents=False,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if (image is None or timestep is None) and not is_strength_max:\n            raise ValueError(\n                \"Since strength < 1. 
initial latents are to be initialised as a combination of Image + Noise. \"\n                \"However, either the image or the noise timestep has not been provided.\"\n            )\n\n        if image.shape[1] == 4:\n            image_latents = image.to(device=device, dtype=dtype)\n        elif return_image_latents or (latents is None and not is_strength_max):\n            image = image.to(device=device, dtype=dtype)\n            image_latents = self._encode_vae_image(image=image, generator=generator)\n\n        image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n\n        if latents is None and add_noise:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # if strength is 1, initialise the latents to pure noise; otherwise to image + noise\n            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)\n            # if pure noise, scale the initial latents by the scheduler's init sigma\n            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents\n        elif add_noise:\n            noise = latents.to(device)\n            latents = noise * self.scheduler.init_noise_sigma\n        else:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            latents = image_latents.to(device)\n\n        outputs = (latents,)\n\n        if return_noise:\n            outputs += (noise,)\n\n        if return_image_latents:\n            outputs += (image_latents,)\n\n        return outputs\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        dtype = image.dtype\n        if self.vae.config.force_upcast:\n            image = image.float()\n            self.vae.to(dtype=torch.float32)\n\n        if isinstance(generator, list):\n            image_latents = [\n                self.vae.encode(image[i : i + 
1]).latent_dist.sample(generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)\n\n        if self.vae.config.force_upcast:\n            self.vae.to(dtype)\n\n        image_latents = image_latents.to(dtype)\n        image_latents = self.vae.config.scaling_factor * image_latents\n\n        return image_latents\n\n    def prepare_mask_latents(\n        self,\n        mask,\n        masked_image,\n        batch_size,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        do_classifier_free_guidance,\n    ):\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(\n            mask,\n            size=(\n                height // self.vae_scale_factor,\n                width // self.vae_scale_factor,\n            ),\n        )\n        mask = mask.to(device=device, dtype=dtype)\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. 
Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n\n        masked_image_latents = None\n        if masked_image is not None:\n            masked_image = masked_image.to(device=device, dtype=dtype)\n            masked_image_latents = self._encode_vae_image(masked_image, generator=generator)\n            if masked_image_latents.shape[0] < batch_size:\n                if not batch_size % masked_image_latents.shape[0] == 0:\n                    raise ValueError(\n                        \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\n                        f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                        \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                    )\n                masked_image_latents = masked_image_latents.repeat(\n                    batch_size // masked_image_latents.shape[0], 1, 1, 1\n                )\n\n            masked_image_latents = (\n                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n            )\n\n            # aligning device to prevent device errors when concating it with the latent model input\n            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n\n        return mask, masked_image_latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):\n        # get the original timestep using init_timestep\n        if denoising_start is None:\n 
           init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n            t_start = max(num_inference_steps - init_timestep, 0)\n        else:\n            t_start = 0\n\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        # Strength is irrelevant if we directly request a timestep to start at;\n        # that is, strength is determined by the denoising_start instead.\n        if denoising_start is not None:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_start * self.scheduler.config.num_train_timesteps)\n                )\n            )\n\n            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()\n            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:\n                # if the scheduler is a 2nd order scheduler we might have to do +1\n                # because `num_inference_steps` might be even given that every timestep\n                # (except the highest one) is duplicated. If `num_inference_steps` is even it would\n                # mean that we cut the timesteps in the middle of the denoising step\n                # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1\n                # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler\n                num_inference_steps = num_inference_steps + 1\n\n            # because t_n+1 >= t_n, we slice the timesteps starting from the end\n            timesteps = timesteps[-num_inference_steps:]\n            return timesteps, num_inference_steps\n\n        return timesteps, num_inference_steps - t_start\n\n    def _get_add_time_ids(\n        self,\n        original_size,\n        crops_coords_top_left,\n        target_size,\n        aesthetic_score,\n        negative_aesthetic_score,\n        dtype,\n        text_encoder_projection_dim=None,\n    ):\n        if self.config.requires_aesthetics_score:\n            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))\n            add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))\n        else:\n            add_time_ids = list(original_size + crops_coords_top_left + target_size)\n            add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if (\n            expected_add_embed_dim > passed_add_embed_dim\n            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. 
Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.\"\n            )\n        elif (\n            expected_add_embed_dim < passed_add_embed_dim\n            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.\"\n            )\n        elif expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)\n\n        return add_time_ids, add_neg_time_ids\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width\n    def _default_height_width(self, height, width, image):\n        # NOTE: It is possible that a list of images have different\n        # dimensions for each image, so just checking the first image\n        # is not _exactly_ correct, but it is simple.\n        while isinstance(image, list):\n            image = image[0]\n\n        if height is None:\n            if isinstance(image, PIL.Image.Image):\n                height = image.height\n            elif isinstance(image, torch.Tensor):\n                height = image.shape[-2]\n\n            # round down to nearest multiple of `self.adapter.downscale_factor`\n            height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor\n\n        if width is None:\n            if isinstance(image, PIL.Image.Image):\n                width = image.width\n            elif isinstance(image, torch.Tensor):\n                width = image.shape[-1]\n\n            # round down to nearest multiple of `self.adapter.downscale_factor`\n            width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor\n\n        return height, width\n\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        
image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Optional[Union[str, List[str]]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,\n        mask_image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,\n        adapter_image: PipelineImageInput = None,\n        control_image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.9999,\n        num_inference_steps: int = 50,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[Union[torch.Tensor]] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: 
Optional[Tuple[int, int]] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        adapter_conditioning_scale: Optional[Union[float, List[float]]] = 1.0,\n        cond_tau: float = 1.0,\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        controlnet_conditioning_scale=1.0,\n        guess_mode: bool = False,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text encoders.\n            image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            mask_image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            adapter_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):\n                The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. 
If the\n                type is specified as `torch.Tensor`, it is passed to Adapter as is. `PIL.Image.Image` can also be\n                accepted as an image. The control image is automatically resized to fit the output image.\n            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be\n                accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in\n                `init`, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single ControlNet.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            strength (`float`, *optional*, defaults to 0.9999):\n                Indicates the extent to which to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. 
A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            denoising_start (`float`, *optional*):\n                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be\n                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and\n                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,\n                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline\n                is integrated into a \"Mixture of Denoisers\" multi-pipeline setup, as detailed in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. 
The denoising_end parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionAdapterPipelineOutput`]\n                instead of a plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16. 
of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the\n                residual in the original unet. 
If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list.\n            adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the\n                residual in the original unet. If multiple adapters are specified in init, you can set the\n                corresponding scale as a list.\n            aesthetic_score (`float`, *optional*, defaults to 6.0):\n                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to\n                simulate an aesthetic score of the generated image by influencing the negative text condition.\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n        # 0. 
Default height and width to unet\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n        adapter = self.adapter._orig_mod if is_compiled_module(self.adapter) else self.adapter\n        height, width = self._default_height_width(height, width, adapter_image)\n        device = self._execution_device\n\n        if isinstance(adapter, MultiAdapter):\n            adapter_input = []\n            for one_image in adapter_image:\n                one_image = _preprocess_adapter_image(one_image, height, width)\n                one_image = one_image.to(device=device, dtype=adapter.dtype)\n                adapter_input.append(one_image)\n        else:\n            adapter_input = _preprocess_adapter_image(adapter_image, height, width)\n            adapter_input = adapter_input.to(device=device, dtype=adapter.dtype)\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 0.1 align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * 
len(controlnet.nets)\n        if isinstance(adapter, MultiAdapter) and isinstance(adapter_conditioning_scale, float):\n            adapter_conditioning_scale = [adapter_conditioning_scale] * len(adapter.nets)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            callback_steps,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n        )\n\n        self.check_conditions(\n            prompt,\n            prompt_embeds,\n            adapter_image,\n            control_image,\n            adapter_conditioning_scale,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n        )\n\n        # 4. Set timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps,\n            strength,\n            device,\n            denoising_start=denoising_start if denoising_value_valid(denoising_start) else None,\n        )\n        # check that number of inference steps is not < 1 - as this doesn't make sense\n        if num_inference_steps < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n            )\n        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise\n        is_strength_max = strength == 1.0\n\n        # 5. 
Preprocess mask and image - resizes image and mask w.r.t height and width\n        mask, masked_image, init_image = prepare_mask_and_masked_image(\n            image, mask_image, height, width, return_image=True\n        )\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        num_channels_unet = self.unet.config.in_channels\n        return_image_latents = num_channels_unet == 4\n\n        add_noise = denoising_start is None\n        latents_outputs = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n            image=init_image,\n            timestep=latent_timestep,\n            is_strength_max=is_strength_max,\n            add_noise=add_noise,\n            return_noise=True,\n            return_image_latents=return_image_latents,\n        )\n\n        if return_image_latents:\n            latents, noise, image_latents = latents_outputs\n        else:\n            latents, noise = latents_outputs\n\n        # 7. Prepare mask latent variables\n        mask, masked_image_latents = self.prepare_mask_latents(\n            mask,\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance,\n        )\n\n        # 8. 
Check that sizes of mask, masked image and latents match\n        if num_channels_unet == 9:\n            # default case for stable-diffusion-v1-5/stable-diffusion-inpainting\n            num_channels_mask = mask.shape[1]\n            num_channels_masked_image = masked_image_latents.shape[1]\n            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:\n                raise ValueError(\n                    f\"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects\"\n                    f\" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +\"\n                    f\" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}\"\n                    f\" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of\"\n                    \" `pipeline.unet` or your `mask_image` or `image` input.\"\n                )\n        elif num_channels_unet != 4:\n            raise ValueError(\n                f\"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}.\"\n            )\n\n        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 10. 
Prepare added time ids & embeddings & adapter features\n        if isinstance(adapter, MultiAdapter):\n            adapter_state = adapter(adapter_input, adapter_conditioning_scale)\n            for k, v in enumerate(adapter_state):\n                adapter_state[k] = v\n        else:\n            adapter_state = adapter(adapter_input)\n            for k, v in enumerate(adapter_state):\n                adapter_state[k] = v * adapter_conditioning_scale\n        if num_images_per_prompt > 1:\n            for k, v in enumerate(adapter_state):\n                adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)\n        if do_classifier_free_guidance:\n            for k, v in enumerate(adapter_state):\n                adapter_state[k] = torch.cat([v] * 2, dim=0)\n\n        # 10.2 Prepare control images\n        if isinstance(controlnet, ControlNetModel):\n            control_image = self.prepare_control_image(\n                image=control_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n        elif isinstance(controlnet, MultiControlNetModel):\n            control_images = []\n\n            for control_image_ in control_image:\n                control_image_ = self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                    
guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            control_image = control_images\n        else:\n            raise ValueError(f\"{controlnet.__class__} is not supported.\")\n\n        # 8.2 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            if isinstance(self.controlnet, MultiControlNetModel):\n                controlnet_keep.append(keeps)\n            else:\n                controlnet_keep.append(keeps[0])\n        # ----------------------------------------------------------------\n\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        
add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device)\n\n        # 11. Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 11.1 Apply denoising_end\n        if (\n            denoising_end is not None\n            and denoising_start is not None\n            and denoising_value_valid(denoising_end)\n            and denoising_value_valid(denoising_start)\n            and denoising_start >= denoising_end\n        ):\n            raise ValueError(\n                f\"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: \"\n                + f\" {denoising_end} when using type float.\"\n            )\n        elif denoising_end is not None and denoising_value_valid(denoising_end):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                if num_channels_unet == 9:\n                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n                # predict the noise residual\n                added_cond_kwargs = {\n                    \"text_embeds\": add_text_embeds,\n                    \"time_ids\": add_time_ids,\n 
               }\n\n                if i < int(num_inference_steps * cond_tau):\n                    down_block_additional_residuals = [state.clone() for state in adapter_state]\n                else:\n                    down_block_additional_residuals = None\n\n                # ----------- ControlNet\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input_controlnet = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n                # concat latents, mask, masked_image_latents in the channel dimension\n                latent_model_input_controlnet = self.scheduler.scale_model_input(latent_model_input_controlnet, t)\n\n                # controlnet(s) inference\n                if guess_mode and do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                    controlnet_added_cond_kwargs = {\n                        \"text_embeds\": add_text_embeds.chunk(2)[1],\n                        \"time_ids\": add_time_ids.chunk(2)[1],\n                    }\n                else:\n                    control_model_input = latent_model_input_controlnet\n                    controlnet_prompt_embeds = prompt_embeds\n                    controlnet_added_cond_kwargs = added_cond_kwargs\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = 
controlnet_cond_scale * controlnet_keep[i]\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=control_image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    added_cond_kwargs=controlnet_added_cond_kwargs,\n                    return_dict=False,\n                )\n\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                    down_intrablock_additional_residuals=down_block_additional_residuals,  # t2iadapter\n                    down_block_additional_residuals=down_block_res_samples,  # controlnet\n                    mid_block_additional_residual=mid_block_res_sample,  # controlnet\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(\n                        noise_pred,\n                        noise_pred_text,\n                        guidance_rescale=guidance_rescale,\n                    )\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(\n                    noise_pred,\n                    t,\n                    latents,\n                    **extra_step_kwargs,\n                    return_dict=False,\n                )[0]\n\n                if num_channels_unet == 4:\n                    init_latents_proper = image_latents\n                    if do_classifier_free_guidance:\n                        init_mask, _ = mask.chunk(2)\n                    else:\n                        init_mask = mask\n\n                    if i < len(timesteps) - 1:\n                        noise_timestep = timesteps[i + 1]\n                        init_latents_proper = self.scheduler.add_noise(\n                            init_latents_proper,\n                            noise,\n                            torch.tensor([noise_timestep]),\n                        )\n\n                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        # make sure the VAE is in float32 mode, as it overflows in float16\n        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:\n            self.upcast_vae()\n            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n        if output_type != \"latent\":\n            image = self.vae.decode(latents / 
self.vae.config.scaling_factor, return_dict=False)[0]\n        else:\n            image = latents\n            return StableDiffusionXLPipelineOutput(images=image)\n\n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_xl_differential_img2img.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torchvision\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_invisible_watermark_available,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark 
import StableDiffusionXLWatermarker\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import StableDiffusionXLImg2ImgPipeline\n        >>> from diffusers.utils import load_image\n\n        >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(\n        ...     \"stabilityai/stable-diffusion-xl-refiner-1.0\", torch_dtype=torch.float16\n        ... )\n        >>> pipe = pipe.to(\"cuda\")\n        >>> url = \"https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png\"\n\n        >>> init_image = load_image(url).convert(\"RGB\")\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\n        >>> image = pipe(prompt, image=init_image).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusionXLDifferentialImg2ImgPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    FromSingleFileMixin,\n    StableDiffusionXLLoraLoaderMixin,\n    IPAdapterMixin,\n):\n    r\"\"\"\n    Pipeline for differential image-to-image generation using Stable Diffusion XL.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    In addition the pipeline inherits the following loading methods:\n        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]\n        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]\n        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]\n\n    as well as the following saving methods:\n        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion XL uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`CLIPTextModelWithProjection`]):\n            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->image_encoder->unet->vae\"\n    _optional_components = [\n        \"tokenizer\",\n        \"tokenizer_2\",\n        \"text_encoder\",\n        \"text_encoder_2\",\n        \"image_encoder\",\n        \"feature_extractor\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"add_neg_time_ids\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n        requires_aesthetics_score: bool = False,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n            scheduler=scheduler,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder, lora_scale)\n\n            if self.text_encoder_2 is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n                if isinstance(self, 
TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n\n                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\n                    pooled_prompt_embeds = prompt_embeds[0]\n\n                if clip_skip is None:\n                    prompt_embeds = prompt_embeds.hidden_states[-2]\n                else:\n                    # \"2\" because SDXL always indexes from the penultimate layer.\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is 
None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        if self.text_encoder_2 is not None:\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        else:\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            if self.text_encoder_2 is not None:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            else:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        strength,\n        num_inference_steps,\n        callback_steps,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n        if num_inference_steps is None:\n            raise ValueError(\"`num_inference_steps` cannot be None.\")\n        elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:\n            raise ValueError(\n                f\"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type\"\n                f\" {type(num_inference_steps)}.\"\n            )\n        if callback_steps is not None and (not 
isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. 
Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):\n        # get the original timestep using init_timestep\n        if denoising_start is None:\n            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n            t_start = max(num_inference_steps - init_timestep, 0)\n        else:\n            t_start = 0\n\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        # Strength is irrelevant if we directly request a timestep to start at;\n        # that is, strength is determined by the denoising_start instead.\n        if denoising_start is not None:\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_start * self.scheduler.config.num_train_timesteps)\n                )\n            )\n\n            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()\n            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:\n                # if the scheduler is a 2nd order scheduler we might have to do +1\n                # because `num_inference_steps` might be even given that every timestep\n                # (except the highest one) is duplicated. 
If `num_inference_steps` is even it would\n                # mean that we cut the timesteps in the middle of the denoising step\n                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1\n                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler\n                num_inference_steps = num_inference_steps + 1\n\n            # because t_n+1 >= t_n, we slice the timesteps starting from the end\n            timesteps = timesteps[-num_inference_steps:]\n            return timesteps, num_inference_steps\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(\n        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True\n    ):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        # Offload text encoder if `enable_model_cpu_offload` was enabled\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.text_encoder_2.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            if self.vae.config.force_upcast:\n                image = image.float()\n                self.vae.to(dtype=torch.float32)\n\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}.
Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                init_latents = [\n                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                    for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n            if self.vae.config.force_upcast:\n                self.vae.to(dtype)\n\n            init_latents = init_latents.to(dtype)\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        if add_noise:\n            shape = init_latents.shape\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # get latents\n            init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n\n        latents = init_latents\n\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not 
isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. 
Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            image_embeds = []\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)\n                single_negative_image_embeds = torch.stack(\n                    [single_negative_image_embeds] * num_images_per_prompt, dim=0\n                )\n\n                if do_classifier_free_guidance:\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                    single_image_embeds = single_image_embeds.to(device)\n\n                image_embeds.append(single_image_embeds)\n        else:\n            repeat_dims = [1]\n            image_embeds = []\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                    single_negative_image_embeds = single_negative_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))\n                    )\n                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])\n                
else:\n                    single_image_embeds = single_image_embeds.repeat(\n                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))\n                    )\n                image_embeds.append(single_image_embeds)\n\n        return image_embeds\n\n    def _get_add_time_ids(\n        self,\n        original_size,\n        crops_coords_top_left,\n        target_size,\n        aesthetic_score,\n        negative_aesthetic_score,\n        negative_original_size,\n        negative_crops_coords_top_left,\n        negative_target_size,\n        dtype,\n        text_encoder_projection_dim=None,\n    ):\n        if self.config.requires_aesthetics_score:\n            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))\n            add_neg_time_ids = list(\n                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)\n            )\n        else:\n            add_time_ids = list(original_size + crops_coords_top_left + target_size)\n            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if (\n            expected_add_embed_dim > passed_add_embed_dim\n            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. 
Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model.\"\n            )\n        elif (\n            expected_add_embed_dim < passed_add_embed_dim\n            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim\n        ):\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model.\"\n            )\n        elif expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)\n\n        return add_time_ids, add_neg_time_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\n    ) -> torch.Tensor:\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            w (`torch.Tensor`):\n                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.\n            embedding_dim (`int`, *optional*, defaults to 512):\n                Dimension of the embeddings to generate.\n            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):\n                Data type of the generated embeddings.\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: 
https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def denoising_start(self):\n        return self._denoising_start\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        strength: float = 0.3,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_start: Optional[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        
ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        map: torch.Tensor = None,\n        original_image: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is\n                used in both text-encoders.\n            image (`torch.Tensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.Tensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):\n                The image(s) to modify with the pipeline.\n            strength (`float`, *optional*, defaults to 0.3):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`\n                will be used as a starting point, adding more noise to it the larger the `strength`. The number of\n                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will\n                be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that when\n                `denoising_start` is specified, the value of `strength` is ignored.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            denoising_start (`float`, *optional*):\n                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be\n                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and\n                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,\n                strength will be ignored.
The `denoising_start` parameter is particularly beneficial when this pipeline\n                is integrated into a \"Mixture of Denoisers\" multi-pipeline setup, as detailed in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be\n                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the\n                final 20% of the scheduler. The `denoising_end` parameter should ideally be utilized when this pipeline\n                forms a part of a \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation.
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of\n                IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`.\n                If not provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference.
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16. of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards.
Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            aesthetic_score (`float`, *optional*, defaults to 6.0):\n                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):\n                Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to\n                simulate an aesthetic score of the generated image by influencing the negative text condition.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`.
When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            strength,\n            num_inference_steps,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._denoising_start = denoising_start\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. Preprocess image\n        # image = self.image_processor.preprocess(image) #ideally we would have preprocess the image with diffusers, but for this POC we won't --- it throws a deprecated warning\n        map = torchvision.transforms.Resize(\n            tuple(s // self.vae_scale_factor for s in original_image.shape[2:]), antialias=None\n        )(map)\n\n        # 5. 
Prepare timesteps\n        def denoising_value_valid(dnv):\n            return isinstance(dnv, float) and 0 < dnv < 1\n\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # begin diff diff change\n        total_time_steps = num_inference_steps\n        # end diff diff change\n\n        timesteps, num_inference_steps = self.get_timesteps(\n            num_inference_steps,\n            strength,\n            device,\n            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,\n        )\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        add_noise = True if denoising_start is None else False\n        # 6. Prepare latent variables\n        latents = self.prepare_latents(\n            image,\n            latent_timestep,\n            batch_size,\n            num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            add_noise,\n        )\n        # 7. Prepare extra step kwargs.\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        height, width = latents.shape[-2:]\n        height = height * self.vae_scale_factor\n        width = width * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 8. 
Prepare added time ids & embeddings\n        if negative_original_size is None:\n            negative_original_size = original_size\n        if negative_target_size is None:\n            negative_target_size = target_size\n\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            negative_original_size,\n            negative_crops_coords_top_left,\n            negative_target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device)\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 9. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 9.1 Apply denoising_end\n        if (\n            denoising_end is not None\n            and denoising_start is not None\n            and denoising_value_valid(denoising_end)\n            and denoising_value_valid(denoising_start)\n            and denoising_start >= denoising_end\n        ):\n            raise ValueError(\n                f\"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: \"\n                + f\" {denoising_end} when using type float.\"\n            )\n        elif denoising_end is not None and denoising_value_valid(denoising_end):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # preparations for diff diff\n        original_with_noise = self.prepare_latents(\n            original_image, timesteps, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator\n        )\n        thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps\n        thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)\n        masks = map > (thresholds + (denoising_start or 0))\n        # end diff diff preparations\n\n        # 9.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, 
embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # diff diff\n                if i == 0 and denoising_start is None:\n                    latents = original_with_noise[:1]\n                else:\n                    mask = masks[i].unsqueeze(0)\n                    # cast mask to the same type as latents etc\n                    mask = mask.to(latents.dtype)\n                    mask = mask.unsqueeze(1)  # fit shape\n                    latents = original_with_noise[i] * mask + latents * (1 - mask)\n                # end diff diff\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                    added_cond_kwargs[\"image_embeds\"] = image_embeds\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = 
noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n                    else:\n                        raise ValueError(\n                            \"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. 
See also: https://github.com/huggingface/diffusers/pull/7446/.\"\n                        )\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                    add_neg_time_ids = callback_outputs.pop(\"add_neg_time_ids\", add_neg_time_ids)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = 
latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n            elif latents.dtype != self.vae.dtype:\n                if torch.backends.mps.is_available():\n                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                    self.vae = self.vae.to(latents.dtype)\n                else:\n                    raise ValueError(\n                        \"For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/.\"\n                    )\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        # apply watermark if available\n        if self.watermark is not None:\n            image = self.watermark.apply_watermark(image)\n\n  
      image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_xl_instandid_img2img.py",
    "content": "# Copyright 2025 The InstantX Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport math\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport cv2\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn as nn\n\nfrom diffusers import StableDiffusionXLControlNetImg2ImgPipeline\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.models import ControlNetModel\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput\nfrom diffusers.utils import (\n    deprecate,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module, is_torch_version\n\n\ntry:\n    import xformers\n    import xformers.ops\n\n    xformers_available = True\nexcept Exception:\n    xformers_available = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nlogger.warning(\n    \"To use instant id pipelines, please make sure you have the `insightface` library installed: `pip install insightface`.\"\n    \"Please refer to: https://huggingface.co/InstantX/InstantID for further instructions regarding inference\"\n)\n\n\ndef FeedForward(dim, mult=4):\n    inner_dim = int(dim * mult)\n    return nn.Sequential(\n        nn.LayerNorm(dim),\n        
nn.Linear(dim, inner_dim, bias=False),\n        nn.GELU(),\n        nn.Linear(inner_dim, dim, bias=False),\n    )\n\n\ndef reshape_tensor(x, heads):\n    bs, length, width = x.shape\n    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)\n    x = x.view(bs, length, heads, -1)\n    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\n    x = x.transpose(1, 2)\n    # make the transposed tensor contiguous; the shape stays (bs, n_heads, length, dim_per_head)\n    x = x.reshape(bs, heads, length, -1)\n    return x\n\n\nclass PerceiverAttention(nn.Module):\n    def __init__(self, *, dim, dim_head=64, heads=8):\n        super().__init__()\n        self.scale = dim_head**-0.5\n        self.dim_head = dim_head\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n    def forward(self, x, latents):\n        \"\"\"\n        Args:\n            x (torch.Tensor): image features\n                shape (b, n1, D)\n            latents (torch.Tensor): latent features\n                shape (b, n2, D)\n        \"\"\"\n        x = self.norm1(x)\n        latents = self.norm2(latents)\n\n        b, l, _ = latents.shape\n\n        q = self.to_q(latents)\n        kv_input = torch.cat((x, latents), dim=-2)\n        k, v = self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q = reshape_tensor(q, self.heads)\n        k = reshape_tensor(k, self.heads)\n        v = reshape_tensor(v, self.heads)\n\n        # attention\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        out = weight @ v\n\n        out = 
out.permute(0, 2, 1, 3).reshape(b, l, -1)\n\n        return self.to_out(out)\n\n\nclass Resampler(nn.Module):\n    def __init__(\n        self,\n        dim=1024,\n        depth=8,\n        dim_head=64,\n        heads=16,\n        num_queries=8,\n        embedding_dim=768,\n        output_dim=1024,\n        ff_mult=4,\n    ):\n        super().__init__()\n\n        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)\n\n        self.proj_in = nn.Linear(embedding_dim, dim)\n\n        self.proj_out = nn.Linear(dim, output_dim)\n        self.norm_out = nn.LayerNorm(output_dim)\n\n        self.layers = nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                nn.ModuleList(\n                    [\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n                        FeedForward(dim=dim, mult=ff_mult),\n                    ]\n                )\n            )\n\n    def forward(self, x):\n        latents = self.latents.repeat(x.size(0), 1, 1)\n        x = self.proj_in(x)\n\n        for attn, ff in self.layers:\n            latents = attn(x, latents) + latents\n            latents = ff(latents) + latents\n\n        latents = self.proj_out(latents)\n        return self.norm_out(latents)\n\n\nclass AttnProcessor(nn.Module):\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n    ):\n        super().__init__()\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, 
width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass IPAttnProcessor(nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the 
`encoder_hidden_states`.\n        scale (`float`, defaults to 1.0):\n            the weight scale of image prompt.\n        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):\n            The context length of the image features.\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.scale = scale\n        self.num_tokens = num_tokens\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # get encoder_hidden_states, ip_hidden_states\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            
encoder_hidden_states, ip_hidden_states = (\n                encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        if xformers_available:\n            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)\n        else:\n            attention_probs = attn.get_attention_scores(query, key, attention_mask)\n            hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # for ip-adapter\n        ip_key = self.to_k_ip(ip_hidden_states)\n        ip_value = self.to_v_ip(ip_hidden_states)\n\n        ip_key = attn.head_to_batch_dim(ip_key)\n        ip_value = attn.head_to_batch_dim(ip_value)\n\n        if xformers_available:\n            ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)\n        else:\n            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)\n            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)\n        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)\n\n        hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = 
hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):\n        # TODO attention_mask\n        query = query.contiguous()\n        key = key.contiguous()\n        value = value.contiguous()\n        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)\n        return hidden_states\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install opencv-python transformers accelerate insightface\n        >>> import diffusers\n        >>> from diffusers.utils import load_image\n        >>> from diffusers.models import ControlNetModel\n\n        >>> import cv2\n        >>> import torch\n        >>> import numpy as np\n        >>> from PIL import Image\n\n        >>> from insightface.app import FaceAnalysis\n        >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps\n\n        >>> # download 'antelopev2' under ./models\n        >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n        >>> app.prepare(ctx_id=0, det_size=(640, 640))\n\n        >>> # download models under ./checkpoints\n        >>> face_adapter = f'./checkpoints/ip-adapter.bin'\n        >>> controlnet_path = f'./checkpoints/ControlNetModel'\n\n        >>> # load IdentityNet\n        >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)\n\n        >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(\n        ...     \"stabilityai/stable-diffusion-xl-base-1.0\", controlnet=controlnet, torch_dtype=torch.float16\n        ... )\n        >>> pipe.cuda()\n\n        >>> # load adapter\n        >>> pipe.load_ip_adapter_instantid(face_adapter)\n\n        >>> prompt = \"analog film photo of a man. 
faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality\"\n        >>> negative_prompt = \"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured\"\n\n        >>> # load an image\n        >>> face_image = load_image(\"your-example.jpg\")\n\n        >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]\n        >>> face_emb = face_info['embedding']\n        >>> face_kps = draw_kps(face_image, face_info['kps'])\n\n        >>> pipe.set_ip_adapter_scale(0.8)\n\n        >>> # generate image\n        >>> image = pipe(\n        ...     prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\ndef draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):\n    stickwidth = 4\n    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])\n    kps = np.array(kps)\n\n    w, h = image_pil.size\n    out_img = np.zeros([h, w, 3])\n\n    for i in range(len(limbSeq)):\n        index = limbSeq[i]\n        color = color_list[index[0]]\n\n        x = kps[index][:, 0]\n        y = kps[index][:, 1]\n        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5\n        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))\n        polygon = cv2.ellipse2Poly(\n            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1\n        )\n        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)\n    out_img = (out_img * 0.6).astype(np.uint8)\n\n    for idx_kp, kp in enumerate(kps):\n        color = color_list[idx_kp]\n        x, y = kp\n        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)\n\n    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))\n    return out_img_pil\n\n\nclass StableDiffusionXLInstantIDImg2ImgPipeline(StableDiffusionXLControlNetImg2ImgPipeline):\n    def cuda(self, dtype=torch.float16, use_xformers=False):\n        self.to(\"cuda\", dtype)\n\n        if hasattr(self, \"image_proj_model\"):\n            self.image_proj_model.to(self.unet.device).to(self.unet.dtype)\n\n        if use_xformers:\n            if is_xformers_available():\n                import xformers\n                from packaging import version\n\n                xformers_version = version.parse(xformers.__version__)\n                if xformers_version == version.parse(\"0.0.16\"):\n                    logger.warning(\n                        \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. 
See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                    )\n                self.enable_xformers_memory_efficient_attention()\n            else:\n                raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):\n        self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)\n        self.set_ip_adapter(model_ckpt, num_tokens, scale)\n\n    def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):\n        image_proj_model = Resampler(\n            dim=1280,\n            depth=4,\n            dim_head=64,\n            heads=20,\n            num_queries=num_tokens,\n            embedding_dim=image_emb_dim,\n            output_dim=self.unet.config.cross_attention_dim,\n            ff_mult=4,\n        )\n\n        image_proj_model.eval()\n\n        self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)\n        state_dict = torch.load(model_ckpt, map_location=\"cpu\")\n        if \"image_proj\" in state_dict:\n            state_dict = state_dict[\"image_proj\"]\n        self.image_proj_model.load_state_dict(state_dict)\n\n        self.image_proj_model_in_features = image_emb_dim\n\n    def set_ip_adapter(self, model_ckpt, num_tokens, scale):\n        unet = self.unet\n        attn_procs = {}\n        for name in unet.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                
block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n            if cross_attention_dim is None:\n                attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)\n            else:\n                attn_procs[name] = IPAttnProcessor(\n                    hidden_size=hidden_size,\n                    cross_attention_dim=cross_attention_dim,\n                    scale=scale,\n                    num_tokens=num_tokens,\n                ).to(unet.device, dtype=unet.dtype)\n        unet.set_attn_processor(attn_procs)\n\n        state_dict = torch.load(model_ckpt, map_location=\"cpu\")\n        ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())\n        if \"ip_adapter\" in state_dict:\n            state_dict = state_dict[\"ip_adapter\"]\n        ip_layers.load_state_dict(state_dict)\n\n    def set_ip_adapter_scale(self, scale):\n        unet = getattr(self, self.unet_name) if not hasattr(self, \"unet\") else self.unet\n        for attn_processor in unet.attn_processors.values():\n            if isinstance(attn_processor, IPAttnProcessor):\n                attn_processor.scale = scale\n\n    def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):\n        if isinstance(prompt_image_emb, torch.Tensor):\n            prompt_image_emb = prompt_image_emb.clone().detach()\n        else:\n            prompt_image_emb = torch.tensor(prompt_image_emb)\n\n        prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)\n        prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])\n\n        if do_classifier_free_guidance:\n            prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)\n        else:\n            prompt_image_emb = torch.cat([prompt_image_emb], dim=0)\n\n        prompt_image_emb = self.image_proj_model(prompt_image_emb)\n        return 
prompt_image_emb\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: PipelineImageInput = None,\n        control_image: PipelineImageInput = None,\n        strength: float = 0.8,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        image_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        aesthetic_score: float = 6.0,\n        negative_aesthetic_score: float = 2.5,\n        
clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be\n                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in\n                `init`, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single ControlNet.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image. 
Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`\n                and `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If\n                not provided, pooled text embeddings are generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt\n                weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input\n                argument.\n            image_embeds (`torch.Tensor`, *optional*):\n                Pre-generated image embeddings.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. 
A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. 
Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. In most cases it\n                should be the same as `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in\n                the `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned containing the output images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        controlnet 
= self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            control_image,\n            strength,\n            num_inference_steps,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            None,\n            None,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3.1 Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt,\n            prompt_2,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 3.2 Encode image prompt\n        prompt_image_emb = self._encode_prompt_image_emb(\n            image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance\n        )\n        bs_embed, seq_len, _ = prompt_image_emb.shape\n      
  prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)\n        prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # 4. Prepare image and controlnet_conditioning_image\n        image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n\n        if isinstance(controlnet, ControlNetModel):\n            control_image = self.prepare_control_image(\n                image=control_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = control_image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            control_images = []\n\n            for control_image_ in control_image:\n                control_image_ = self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            control_image = control_images\n            height, width = control_image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. 
Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n        self._num_timesteps = len(timesteps)\n\n        # 6. Prepare latent variables\n        latents = self.prepare_latents(\n            image,\n            latent_timestep,\n            batch_size,\n            num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            True,\n        )\n\n        # # 6.5 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 7.2 Prepare added time ids & embeddings\n        if isinstance(control_image, list):\n            original_size = original_size or control_image[0].shape[-2:]\n        else:\n            original_size = original_size or control_image.shape[-2:]\n        target_size = target_size or (height, width)\n\n        if negative_original_size is None:\n            negative_original_size = original_size\n        if negative_target_size is None:\n            negative_target_size = target_size\n        add_text_embeds = pooled_prompt_embeds\n\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids, add_neg_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            aesthetic_score,\n            negative_aesthetic_score,\n            negative_original_size,\n            negative_crops_coords_top_left,\n            negative_target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = 
torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)\n            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        # add_time_ids was already repeated per image above; repeating again would break batch sizes > 1\n        add_time_ids = add_time_ids.to(device)\n        encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        is_unet_compiled = is_compiled_module(self.unet)\n        is_controlnet_compiled = is_compiled_module(self.controlnet)\n        is_torch_higher_equal_2_1 = is_torch_version(\">=\", \"2.1\")\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Relevant thread:\n                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428\n                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:\n                    torch._inductor.cudagraph_mark_step_begin()\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n                # controlnet(s) inference\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    
control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                    controlnet_added_cond_kwargs = {\n                        \"text_embeds\": add_text_embeds.chunk(2)[1],\n                        \"time_ids\": add_time_ids.chunk(2)[1],\n                    }\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n                    controlnet_added_cond_kwargs = added_cond_kwargs\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=prompt_image_emb,\n                    controlnet_cond=control_image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    added_cond_kwargs=controlnet_added_cond_kwargs,\n                    return_dict=False,\n                )\n\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in 
down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=encoder_hidden_states,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n               
     progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            # apply watermark if available\n            if self.watermark is not None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_xl_instantid.py",
    "content": "# Copyright 2025 The InstantX Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport math\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport cv2\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn as nn\n\nfrom diffusers import StableDiffusionXLControlNetPipeline\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.models import ControlNetModel\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput\nfrom diffusers.utils import (\n    deprecate,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module, is_torch_version\n\n\ntry:\n    import xformers\n    import xformers.ops\n\n    xformers_available = True\nexcept Exception:\n    xformers_available = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nlogger.warning(\n    \"To use instant id pipelines, please make sure you have the `insightface` library installed: `pip install insightface`.\"\n    \"Please refer to: https://huggingface.co/InstantX/InstantID for further instructions regarding inference\"\n)\n\n\ndef FeedForward(dim, mult=4):\n    inner_dim = int(dim * mult)\n    return nn.Sequential(\n        nn.LayerNorm(dim),\n        
nn.Linear(dim, inner_dim, bias=False),\n        nn.GELU(),\n        nn.Linear(inner_dim, dim, bias=False),\n    )\n\n\ndef reshape_tensor(x, heads):\n    bs, length, width = x.shape\n    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)\n    x = x.view(bs, length, heads, -1)\n    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)\n    x = x.transpose(1, 2)\n    # shape stays (bs, n_heads, length, dim_per_head); heads are not folded into the batch dim\n    x = x.reshape(bs, heads, length, -1)\n    return x\n\n\nclass PerceiverAttention(nn.Module):\n    def __init__(self, *, dim, dim_head=64, heads=8):\n        super().__init__()\n        self.scale = dim_head**-0.5\n        self.dim_head = dim_head\n        self.heads = heads\n        inner_dim = dim_head * heads\n\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)\n        self.to_out = nn.Linear(inner_dim, dim, bias=False)\n\n    def forward(self, x, latents):\n        \"\"\"\n        Args:\n            x (torch.Tensor): image features\n                shape (b, n1, D)\n            latents (torch.Tensor): latent features\n                shape (b, n2, D)\n        \"\"\"\n        x = self.norm1(x)\n        latents = self.norm2(latents)\n\n        b, l, _ = latents.shape\n\n        q = self.to_q(latents)\n        kv_input = torch.cat((x, latents), dim=-2)\n        k, v = self.to_kv(kv_input).chunk(2, dim=-1)\n\n        q = reshape_tensor(q, self.heads)\n        k = reshape_tensor(k, self.heads)\n        v = reshape_tensor(v, self.heads)\n\n        # attention\n        scale = 1 / math.sqrt(math.sqrt(self.dim_head))\n        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        out = weight @ v\n\n        out = 
out.permute(0, 2, 1, 3).reshape(b, l, -1)\n\n        return self.to_out(out)\n\n\nclass Resampler(nn.Module):\n    def __init__(\n        self,\n        dim=1024,\n        depth=8,\n        dim_head=64,\n        heads=16,\n        num_queries=8,\n        embedding_dim=768,\n        output_dim=1024,\n        ff_mult=4,\n    ):\n        super().__init__()\n\n        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)\n\n        self.proj_in = nn.Linear(embedding_dim, dim)\n\n        self.proj_out = nn.Linear(dim, output_dim)\n        self.norm_out = nn.LayerNorm(output_dim)\n\n        self.layers = nn.ModuleList([])\n        for _ in range(depth):\n            self.layers.append(\n                nn.ModuleList(\n                    [\n                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),\n                        FeedForward(dim=dim, mult=ff_mult),\n                    ]\n                )\n            )\n\n    def forward(self, x):\n        latents = self.latents.repeat(x.size(0), 1, 1)\n        x = self.proj_in(x)\n\n        for attn, ff in self.layers:\n            latents = attn(x, latents) + latents\n            latents = ff(latents) + latents\n\n        latents = self.proj_out(latents)\n        return self.norm_out(latents)\n\n\nclass AttnProcessor(nn.Module):\n    r\"\"\"\n    Default processor for performing attention-related computations.\n    \"\"\"\n\n    def __init__(\n        self,\n        hidden_size=None,\n        cross_attention_dim=None,\n    ):\n        super().__init__()\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, 
width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        attention_probs = attn.get_attention_scores(query, key, attention_mask)\n        hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass IPAttnProcessor(nn.Module):\n    r\"\"\"\n    Attention processor for IP-Adapter.\n    Args:\n        hidden_size (`int`):\n            The hidden size of the attention layer.\n        cross_attention_dim (`int`):\n            The number of channels in the 
`encoder_hidden_states`.\n        scale (`float`, defaults to 1.0):\n            The weight scale of the image prompt.\n        num_tokens (`int`, defaults to 4; use 16 for ip_adapter_plus):\n            The context length of the image features.\n    \"\"\"\n\n    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.cross_attention_dim = cross_attention_dim\n        self.scale = scale\n        self.num_tokens = num_tokens\n\n        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)\n\n    def __call__(\n        self,\n        attn,\n        hidden_states,\n        encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        else:\n            # get encoder_hidden_states, ip_hidden_states\n            end_pos = encoder_hidden_states.shape[1] - self.num_tokens\n            
encoder_hidden_states, ip_hidden_states = (\n                encoder_hidden_states[:, :end_pos, :],\n                encoder_hidden_states[:, end_pos:, :],\n            )\n            if attn.norm_cross:\n                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        query = attn.head_to_batch_dim(query)\n        key = attn.head_to_batch_dim(key)\n        value = attn.head_to_batch_dim(value)\n\n        if xformers_available:\n            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)\n        else:\n            attention_probs = attn.get_attention_scores(query, key, attention_mask)\n            hidden_states = torch.bmm(attention_probs, value)\n        hidden_states = attn.batch_to_head_dim(hidden_states)\n\n        # for ip-adapter\n        ip_key = self.to_k_ip(ip_hidden_states)\n        ip_value = self.to_v_ip(ip_hidden_states)\n\n        ip_key = attn.head_to_batch_dim(ip_key)\n        ip_value = attn.head_to_batch_dim(ip_value)\n\n        if xformers_available:\n            ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)\n        else:\n            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)\n            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)\n        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)\n\n        hidden_states = hidden_states + self.scale * ip_hidden_states\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = 
hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):\n        # TODO attention_mask\n        query = query.contiguous()\n        key = key.contiguous()\n        value = value.contiguous()\n        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)\n        return hidden_states\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install opencv-python transformers accelerate insightface\n        >>> import diffusers\n        >>> from diffusers.utils import load_image\n        >>> from diffusers.models import ControlNetModel\n\n        >>> import cv2\n        >>> import torch\n        >>> import numpy as np\n        >>> from PIL import Image\n\n        >>> from insightface.app import FaceAnalysis\n        >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps\n\n        >>> # download 'antelopev2' under ./models\n        >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])\n        >>> app.prepare(ctx_id=0, det_size=(640, 640))\n\n        >>> # download models under ./checkpoints\n        >>> face_adapter = f'./checkpoints/ip-adapter.bin'\n        >>> controlnet_path = f'./checkpoints/ControlNetModel'\n\n        >>> # load IdentityNet\n        >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)\n\n        >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(\n        ...     \"stabilityai/stable-diffusion-xl-base-1.0\", controlnet=controlnet, torch_dtype=torch.float16\n        ... )\n        >>> pipe.cuda()\n\n        >>> # load adapter\n        >>> pipe.load_ip_adapter_instantid(face_adapter)\n\n        >>> prompt = \"analog film photo of a man. 
faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality\"\n        >>> negative_prompt = \"(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured\"\n\n        >>> # load the reference face image\n        >>> face_image = load_image(\"your-example.jpg\")\n\n        >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]\n        >>> face_emb = face_info['embedding']\n        >>> face_kps = draw_kps(face_image, face_info['kps'])\n\n        >>> pipe.set_ip_adapter_scale(0.8)\n\n        >>> # generate image\n        >>> image = pipe(\n        ...     prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\ndef draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):\n    stickwidth = 4\n    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])\n    kps = np.array(kps)\n\n    w, h = image_pil.size\n    out_img = np.zeros([h, w, 3])\n\n    for i in range(len(limbSeq)):\n        index = limbSeq[i]\n        color = color_list[index[0]]\n\n        x = kps[index][:, 0]\n        y = kps[index][:, 1]\n        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5\n        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))\n        polygon = cv2.ellipse2Poly(\n            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1\n        )\n        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)\n    out_img = (out_img * 0.6).astype(np.uint8)\n\n    for idx_kp, kp in enumerate(kps):\n        color = color_list[idx_kp]\n        x, y = kp\n        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)\n\n    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))\n    return out_img_pil\n\n\nclass StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):\n    def cuda(self, dtype=torch.float16, use_xformers=False):\n        self.to(\"cuda\", dtype)\n\n        if hasattr(self, \"image_proj_model\"):\n            self.image_proj_model.to(self.unet.device).to(self.unet.dtype)\n\n        if use_xformers:\n            if is_xformers_available():\n                import xformers\n                from packaging import version\n\n                xformers_version = version.parse(xformers.__version__)\n                if xformers_version == version.parse(\"0.0.16\"):\n                    logger.warning(\n                        \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. 
See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                    )\n                self.enable_xformers_memory_efficient_attention()\n            else:\n                raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):\n        self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)\n        self.set_ip_adapter(model_ckpt, num_tokens, scale)\n\n    def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):\n        image_proj_model = Resampler(\n            dim=1280,\n            depth=4,\n            dim_head=64,\n            heads=20,\n            num_queries=num_tokens,\n            embedding_dim=image_emb_dim,\n            output_dim=self.unet.config.cross_attention_dim,\n            ff_mult=4,\n        )\n\n        image_proj_model.eval()\n\n        self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)\n        state_dict = torch.load(model_ckpt, map_location=\"cpu\")\n        if \"image_proj\" in state_dict:\n            state_dict = state_dict[\"image_proj\"]\n        self.image_proj_model.load_state_dict(state_dict)\n\n        self.image_proj_model_in_features = image_emb_dim\n\n    def set_ip_adapter(self, model_ckpt, num_tokens, scale):\n        unet = self.unet\n        attn_procs = {}\n        for name in unet.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                
block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n            if cross_attention_dim is None:\n                attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)\n            else:\n                attn_procs[name] = IPAttnProcessor(\n                    hidden_size=hidden_size,\n                    cross_attention_dim=cross_attention_dim,\n                    scale=scale,\n                    num_tokens=num_tokens,\n                ).to(unet.device, dtype=unet.dtype)\n        unet.set_attn_processor(attn_procs)\n\n        state_dict = torch.load(model_ckpt, map_location=\"cpu\")\n        ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())\n        if \"ip_adapter\" in state_dict:\n            state_dict = state_dict[\"ip_adapter\"]\n        ip_layers.load_state_dict(state_dict)\n\n    def set_ip_adapter_scale(self, scale):\n        unet = getattr(self, self.unet_name) if not hasattr(self, \"unet\") else self.unet\n        for attn_processor in unet.attn_processors.values():\n            if isinstance(attn_processor, IPAttnProcessor):\n                attn_processor.scale = scale\n\n    def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):\n        if isinstance(prompt_image_emb, torch.Tensor):\n            prompt_image_emb = prompt_image_emb.clone().detach()\n        else:\n            prompt_image_emb = torch.tensor(prompt_image_emb)\n\n        prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)\n        prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])\n\n        if do_classifier_free_guidance:\n            prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)\n        else:\n            prompt_image_emb = torch.cat([prompt_image_emb], dim=0)\n\n        prompt_image_emb = self.image_proj_model(prompt_image_emb)\n        return 
prompt_image_emb\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        image_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = 
[\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be\n                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in\n                `init`, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single ControlNet.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image. 
Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`\n                and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. 
Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, pooled text embeddings are generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt\n                weighting). 
If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input\n                argument.\n            image_embeds (`torch.Tensor`, *optional*):\n                Pre-generated image embeddings.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. 
A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. 
Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be as same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned containing the output images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        controlnet 
= self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            image,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3.1 Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt,\n            prompt_2,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 3.2 Encode image prompt\n        prompt_image_emb = self._encode_prompt_image_emb(\n            image_embeds, device, self.unet.dtype, self.do_classifier_free_guidance\n        )\n        bs_embed, seq_len, _ = prompt_image_emb.shape\n      
  prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)\n        prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # 4. Prepare image\n        if isinstance(controlnet, ControlNetModel):\n            image = self.prepare_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            images = []\n\n            for image_ in image:\n                image_ = self.prepare_image(\n                    image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                images.append(image_)\n\n            image = images\n            height, width = image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n        self._num_timesteps = len(timesteps)\n\n        # 6. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6.5 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 7.2 Prepare added time ids & embeddings\n        if isinstance(image, list):\n            original_size = original_size or image[0].shape[-2:]\n        else:\n            original_size = original_size or image.shape[-2:]\n        target_size = target_size or (height, width)\n\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = 
self.text_encoder_2.config.projection_dim\n\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n        encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        is_unet_compiled = is_compiled_module(self.unet)\n        is_controlnet_compiled = is_compiled_module(self.controlnet)\n        is_torch_higher_equal_2_1 = is_torch_version(\">=\", \"2.1\")\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Relevant thread:\n                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428\n                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:\n                    torch._inductor.cudagraph_mark_step_begin()\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n                # controlnet(s) inference\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                    controlnet_added_cond_kwargs = {\n                        \"text_embeds\": add_text_embeds.chunk(2)[1],\n                        \"time_ids\": add_time_ids.chunk(2)[1],\n                    }\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n                    controlnet_added_cond_kwargs = added_cond_kwargs\n\n                if isinstance(controlnet_keep[i], list):\n  
                  cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=prompt_image_emb,\n                    controlnet_cond=image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    added_cond_kwargs=controlnet_added_cond_kwargs,\n                    return_dict=False,\n                )\n\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=encoder_hidden_states,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    
added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            image = 
self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            # apply watermark if available\n            if self.watermark is not None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_xl_ipex.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport intel_extension_for_pytorch as ipex\nimport torch\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    StableDiffusionXLLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    is_invisible_watermark_available,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_invisible_watermark_available():\n    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = 
True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import StableDiffusionXLPipelineIpex\n\n        >>> # SDXL-Turbo, a distilled version of SDXL 1.0, trained for real-time synthesis\n        >>> pipe = StableDiffusionXLPipelineIpex.from_pretrained(\n        ...     \"stabilityai/sdxl-turbo\", low_cpu_mem_usage=True, use_safetensors=True\n        ... )\n\n        >>> num_inference_steps = 1\n        >>> guidance_scale = 0.0\n        >>> use_bf16 = True\n        >>> data_type = torch.bfloat16 if use_bf16 else torch.float32\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\n\n        >>> # value of image height/width should be consistent with the pipeline inference\n        >>> # For Float32\n        >>> pipe.prepare_for_ipex(torch.float32, prompt, height=512, width=512)\n        >>> # For BFloat16\n        >>> pipe.prepare_for_ipex(torch.bfloat16, prompt, height=512, width=512)\n\n        >>> # value of image height/width should be consistent with 'prepare_for_ipex()'\n        >>> # For Float32\n        >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]\n        >>> # For BFloat16\n        >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n        >>>     image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512, guidance_scale=guidance_scale).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). 
See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. 
If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusionXLPipelineIpex(\n    StableDiffusionXLPipeline,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion XL on IPEX.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    In addition, the pipeline inherits the following loading methods:\n        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]\n        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]\n\n    as well as the following saving methods:\n        - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion XL uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`CLIPTextModelWithProjection`]):\n            Second frozen text-encoder. 
Stable Diffusion XL uses the text and pool portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),\n            specifically the\n            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)\n            variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`CLIPTokenizer`):\n            Second Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):\n            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of\n            `stabilityai/stable-diffusion-xl-base-1.0`.\n        add_watermarker (`bool`, *optional*):\n            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to\n            watermark output images. 
If not defined, it will default to True if the package is installed, otherwise no\n            watermarker will be used.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->unet->vae\"\n    _optional_components = [\n        \"tokenizer\",\n        \"tokenizer_2\",\n        \"text_encoder\",\n        \"text_encoder_2\",\n        \"image_encoder\",\n        \"feature_extractor\",\n    ]\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"add_text_embeds\",\n        \"add_time_ids\",\n        \"negative_pooled_prompt_embeds\",\n        \"negative_add_time_ids\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        text_encoder_2: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        tokenizer_2: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        # super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            unet=unet,\n            scheduler=scheduler,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n        self.default_sample_size = (\n            
self.unet.config.sample_size\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\n            else 128\n        )\n\n        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()\n\n        if add_watermarker:\n            self.watermark = StableDiffusionXLWatermarker()\n        else:\n            self.watermark = None\n\n    def encode_prompt(\n        self,\n        prompt: str,\n        prompt_2: str | None = None,\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        negative_prompt_2: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder, lora_scale)\n\n            if self.text_encoder_2 is not None:\n                if not USE_PEFT_BACKEND:\n                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)\n                else:\n                    scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # Define tokenizers and text encoders\n        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]\n        text_encoders = (\n            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]\n        )\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # textual inversion: process multi-vector tokens if necessary\n            prompt_embeds_list = []\n            prompts = [prompt, prompt_2]\n            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):\n                if isinstance(self, 
TextualInversionLoaderMixin):\n                    prompt = self.maybe_convert_prompt(prompt, tokenizer)\n\n                text_inputs = tokenizer(\n                    prompt,\n                    padding=\"max_length\",\n                    max_length=tokenizer.model_max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                text_input_ids = text_inputs.input_ids\n                untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                    text_input_ids, untruncated_ids\n                ):\n                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                    logger.warning(\n                        \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                        f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                    )\n\n                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:\n                    pooled_prompt_embeds = prompt_embeds[0]\n\n                if clip_skip is None:\n                    prompt_embeds = prompt_embeds.hidden_states[-2]\n                else:\n                    # \"2\" because SDXL always indexes from the penultimate layer.\n                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]\n\n                prompt_embeds_list.append(prompt_embeds)\n\n            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n\n        # get unconditional embeddings for classifier free guidance\n        zero_out_negative_prompt = negative_prompt is 
None and self.config.force_zeros_for_empty_prompt\n        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:\n            negative_prompt_embeds = torch.zeros_like(prompt_embeds)\n            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)\n        elif do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt_2 = negative_prompt_2 or negative_prompt\n\n            # normalize str to list\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            negative_prompt_2 = (\n                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2\n            )\n\n            uncond_tokens: List[str]\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = [negative_prompt, negative_prompt_2]\n\n            negative_prompt_embeds_list = []\n            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):\n                if isinstance(self, TextualInversionLoaderMixin):\n                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)\n\n                max_length = prompt_embeds.shape[1]\n                uncond_input = tokenizer(\n                    negative_prompt,\n                    padding=\"max_length\",\n                    max_length=max_length,\n                    truncation=True,\n                    return_tensors=\"pt\",\n                )\n\n                negative_prompt_embeds = text_encoder(\n                    uncond_input.input_ids.to(device),\n                    output_hidden_states=True,\n                )\n                # We are only ALWAYS interested in the pooled output of the final text encoder\n                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:\n                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]\n                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]\n\n                negative_prompt_embeds_list.append(negative_prompt_embeds)\n\n            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)\n\n        if self.text_encoder_2 is not None:\n            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n        else:\n            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = 
prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            if self.text_encoder_2 is not None:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)\n            else:\n                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n            bs_embed * num_images_per_prompt, -1\n        )\n        if do_classifier_free_guidance:\n            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(\n                bs_embed * num_images_per_prompt, -1\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        image_embeds = self.image_encoder(image).image_embeds\n        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n\n        uncond_image_embeds = torch.zeros_like(image_embeds)\n        return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        negative_prompt_2=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        pooled_prompt_embeds=None,\n      
  negative_pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. 
Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.\"\n            )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def _get_add_time_ids(\n        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None\n    ):\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n\n        passed_add_embed_dim = (\n            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim\n        )\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.\"\n            )\n\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        return add_time_ids\n\n    def upcast_vae(self):\n        deprecate(\"upcast_vae\", \"1.0.0\", \"`upcast_vae` is deprecated. 
Please use `pipe.vae.to(torch.float32)`\")\n        self.vae.to(dtype=torch.float32)\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def guidance_rescale(self):\n        return self._guidance_rescale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def denoising_end(self):\n        return self._denoising_end\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        
negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. 
This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. 
of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\n                of a plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16 of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. 
Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. 
Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. 
When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        if ip_adapter_image is not None:\n            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)\n            if self.do_classifier_free_guidance:\n                image_embeds = torch.cat([negative_image_embeds, 
image_embeds])\n                image_embeds = image_embeds.to(device)\n\n        # 8. Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 8.1 Apply denoising_end\n        if (\n            self.denoising_end is not None\n            and isinstance(self.denoising_end, float)\n            and self.denoising_end > 0\n            and self.denoising_end < 1\n        ):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 9. Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                if ip_adapter_image is not None:\n            
        added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                # noise_pred = self.unet(\n                #     latent_model_input,\n                #     t,\n                #     encoder_hidden_states=prompt_embeds,\n                #     timestep_cond=timestep_cond,\n                #     cross_attention_kwargs=self.cross_attention_kwargs,\n                #     added_cond_kwargs=added_cond_kwargs,\n                #     return_dict=False,\n                # )[0]\n\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    added_cond_kwargs=added_cond_kwargs,\n                )[\"sample\"]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. 
in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                    negative_add_time_ids = callback_outputs.pop(\"negative_add_time_ids\", negative_add_time_ids)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            # make 
sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            # apply watermark if available\n            if self.watermark is not None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n\n    @torch.no_grad()\n    def prepare_for_ipex(\n        self,\n        dtype=torch.float32,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        
pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        # 0. Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = \"cpu\"\n        do_classifier_free_guidance = self.do_classifier_free_guidance\n\n        # 3. 
Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 7. 
Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        if ip_adapter_image is not None:\n            image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)\n            if self.do_classifier_free_guidance:\n                image_embeds = torch.cat([negative_image_embeds, image_embeds])\n                image_embeds = image_embeds.to(device)\n\n        dummy = torch.ones(1, dtype=torch.int32)\n        latent_model_input = 
torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n        latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy)\n\n        # predict the noise residual\n        added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n        if ip_adapter_image is not None:\n            added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n\n        self.unet = self.unet.to(memory_format=torch.channels_last)\n        self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)\n        self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)\n\n        unet_input_example = {\n            \"sample\": latent_model_input,\n            \"timestep\": dummy,\n            \"encoder_hidden_states\": prompt_embeds,\n            \"added_cond_kwargs\": added_cond_kwargs,\n        }\n\n        vae_decoder_input_example = latents\n\n        # optimize with ipex\n        if dtype == torch.bfloat16:\n            self.unet = ipex.optimize(\n                self.unet.eval(),\n                dtype=torch.bfloat16,\n                inplace=True,\n            )\n            self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)\n            self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)\n        elif dtype == torch.float32:\n            self.unet = ipex.optimize(\n                self.unet.eval(),\n          
      dtype=torch.float32,\n                inplace=True,\n                level=\"O1\",\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n            self.vae.decoder = ipex.optimize(\n                self.vae.decoder.eval(),\n                dtype=torch.float32,\n                inplace=True,\n                level=\"O1\",\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n            self.text_encoder = ipex.optimize(\n                self.text_encoder.eval(),\n                dtype=torch.float32,\n                inplace=True,\n                level=\"O1\",\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n        else:\n            raise ValueError(\" The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32' !\")\n\n        # trace unet model to get better performance on IPEX\n        with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():\n            unet_trace_model = torch.jit.trace(\n                self.unet, example_kwarg_inputs=unet_input_example, check_trace=False, strict=False\n            )\n            unet_trace_model = torch.jit.freeze(unet_trace_model)\n            self.unet.forward = unet_trace_model.forward\n\n        # trace vae.decoder model to get better performance on IPEX\n        with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():\n            vae_decoder_trace_model = torch.jit.trace(\n                self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False\n            )\n            vae_decoder_trace_model = torch.jit.freeze(vae_decoder_trace_model)\n            self.vae.decoder.forward = vae_decoder_trace_model.forward\n"
  },
  {
    "path": "examples/community/pipeline_stable_diffusion_xl_t5.py",
    "content": "# Copyright Philip Brown, ppbrown@github\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n###########################################################################\n# This pipeline attempts to use a model that has SDXL vae, T5 text encoder,\n# and SDXL unet.\n# At the present time, there are no pretrained models that give pleasing\n# output. So as yet, (2025/06/10) this pipeline is somewhat of a tech\n# demo proving that the pieces can at least be put together.\n# Hopefully, it will encourage someone with the hardware available to\n# throw enough resources into training one up.\n\n\nfrom typing import Optional\n\nimport torch.nn as nn\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n    T5EncoderModel,\n)\n\nfrom diffusers import DiffusionPipeline, StableDiffusionXLPipeline\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\n\n\n# Note: At this time, the intent is to use the T5 encoder mentioned\n# below, with zero changes.\n# Therefore, the model deliberately does not store the T5 encoder model bytes,\n# (Since they are not unique!)\n# but instead takes advantage of huggingface hub cache loading\n\nT5_NAME = \"mcmonkey/google_t5-v1_1-xxl_encoderonly\"\n\n# Caller is expected to load this, or equivalent, as model name for now\n#   eg: pipe = 
StableDiffusionXL_T5Pipeline(SDXL_NAME)\nSDXL_NAME = \"stabilityai/stable-diffusion-xl-base-1.0\"\n\n\nclass LinearWithDtype(nn.Linear):\n    @property\n    def dtype(self):\n        return self.weight.dtype\n\n\nclass StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):\n    _expected_modules = [\n        \"vae\",\n        \"unet\",\n        \"scheduler\",\n        \"tokenizer\",\n        \"image_encoder\",\n        \"feature_extractor\",\n        \"t5_encoder\",\n        \"t5_projection\",\n        \"t5_pooled_projection\",\n    ]\n\n    _optional_components = [\n        \"image_encoder\",\n        \"feature_extractor\",\n        \"t5_encoder\",\n        \"t5_projection\",\n        \"t5_pooled_projection\",\n    ]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        tokenizer: CLIPTokenizer,\n        t5_encoder=None,\n        t5_projection=None,\n        t5_pooled_projection=None,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        feature_extractor: CLIPImageProcessor = None,\n        force_zeros_for_empty_prompt: bool = True,\n        add_watermarker: Optional[bool] = None,\n    ):\n        DiffusionPipeline.__init__(self)\n\n        if t5_encoder is None:\n            self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)\n        else:\n            self.t5_encoder = t5_encoder\n\n        # ----- build T5 4096 => 2048 dim projection -----\n        if t5_projection is None:\n            self.t5_projection = LinearWithDtype(4096, 2048)  # trainable\n        else:\n            self.t5_projection = t5_projection\n        self.t5_projection.to(dtype=unet.dtype)\n        # ----- build T5 4096 => 1280 dim projection -----\n        if t5_pooled_projection is None:\n            self.t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable\n        else:\n            self.t5_pooled_projection = 
t5_pooled_projection\n        self.t5_pooled_projection.to(dtype=unet.dtype)\n\n        print(\"dtype of Linear is \", self.t5_projection.dtype)\n\n        self.register_modules(\n            vae=vae,\n            unet=unet,\n            scheduler=scheduler,\n            tokenizer=tokenizer,\n            t5_encoder=self.t5_encoder,\n            t5_projection=self.t5_projection,\n            t5_pooled_projection=self.t5_pooled_projection,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n        )\n        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n        self.default_sample_size = (\n            self.unet.config.sample_size\n            if hasattr(self, \"unet\") and self.unet is not None and hasattr(self.unet.config, \"sample_size\")\n            else 128\n        )\n\n        self.watermark = None\n\n        # Parts of original SDXL class complain if these attributes are not\n        # at least PRESENT\n        self.text_encoder = self.text_encoder_2 = None\n\n    # ------------------------------------------------------------------\n    #  Encode a text prompt (T5-XXL + 4096→2048 projection)\n    #  Returns exactly four tensors in the order SDXL’s __call__ expects.\n    # ------------------------------------------------------------------\n    def encode_prompt(\n        self,\n        prompt,\n        num_images_per_prompt: int = 1,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str | None = None,\n        **_,\n    ):\n        \"\"\"\n        Returns\n        -------\n        prompt_embeds                : Tensor [B, T, 2048]\n        negative_prompt_embeds       : Tensor [B, T, 2048] | None\n        pooled_prompt_embeds         : Tensor 
[B, 1280]\n        negative_pooled_prompt_embeds: Tensor [B, 1280]    | None\n        where B = batch * num_images_per_prompt\n        \"\"\"\n\n        # --- helper to tokenize on the pipeline’s device ----------------\n        def _tok(text: str):\n            tok_out = self.tokenizer(\n                text,\n                return_tensors=\"pt\",\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n            ).to(self.device)\n            return tok_out.input_ids, tok_out.attention_mask\n\n        # ---------- positive stream -------------------------------------\n        ids, mask = _tok(prompt)\n        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state  # [b, T, 4096]\n        tok_pos = self.t5_projection(h_pos)  # [b, T, 2048]\n        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1))  # [b, 1280]\n\n        # expand for multiple images per prompt\n        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)\n        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)\n\n        # ---------- negative / CFG stream --------------------------------\n        if do_classifier_free_guidance:\n            neg_text = \"\" if negative_prompt is None else negative_prompt\n            ids_n, mask_n = _tok(neg_text)\n            h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state\n            tok_neg = self.t5_projection(h_neg)\n            pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))\n\n            tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)\n            pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)\n        else:\n            tok_neg = pool_neg = None\n\n        # ----------------- final ordered return --------------------------\n        # 1) positive token embeddings\n        # 2) negative token embeddings (or None)\n        # 3) positive pooled embeddings\n        # 4) 
negative pooled embeddings (or None)\n        return tok_pos, tok_neg, pool_pos, pool_neg\n"
  },
  {
    "path": "examples/community/pipeline_stg_cogvideox.py",
    "content": "# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.\n# All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport math\nimport types\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom transformers import T5EncoderModel, T5Tokenizer\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.loaders import CogVideoXLoraLoaderMixin\nfrom diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel\nfrom diffusers.models.embeddings import get_3d_rotary_pos_embed\nfrom diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler\nfrom diffusers.utils import is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.video_processor import VideoProcessor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```python\n        >>> import torch\n        >>> from diffusers.utils import export_to_video\n        >>> from 
examples.community.pipeline_stg_cogvideox import CogVideoXSTGPipeline\n\n        >>> # Models: \"THUDM/CogVideoX-2b\" or \"THUDM/CogVideoX-5b\"\n        >>> pipe = CogVideoXSTGPipeline.from_pretrained(\"THUDM/CogVideoX-5b\", torch_dtype=torch.float16).to(\"cuda\")\n        >>> prompt = (\n        ...     \"A father and son building a treehouse together, their hands covered in sawdust and smiles on their faces, realistic style.\"\n        ... )\n        >>> pipe.transformer.to(memory_format=torch.channels_last)\n\n        >>> # Configure STG mode options\n        >>> stg_applied_layers_idx = [11]  # Layer indices from 0 to 41\n        >>> stg_scale = 1.0  # Set to 0.0 for CFG\n        >>> do_rescaling = False\n\n        >>> # Generate video frames with STG parameters\n        >>> frames = pipe(\n        ...     prompt=prompt,\n        ...     stg_applied_layers_idx=stg_applied_layers_idx,\n        ...     stg_scale=stg_scale,\n        ...     do_rescaling=do_rescaling,\n        ... ).frames[0]\n        >>> export_to_video(frames, \"output.mp4\", fps=8)\n        ```\n\"\"\"\n\n\ndef forward_with_stg(\n    self,\n    hidden_states: torch.Tensor,\n    encoder_hidden_states: torch.Tensor,\n    temb: torch.Tensor,\n    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    hidden_states_ptb = hidden_states[2:]\n    encoder_hidden_states_ptb = encoder_hidden_states[2:]\n\n    text_seq_length = encoder_hidden_states.size(1)\n\n    # norm & modulate\n    norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(\n        hidden_states, encoder_hidden_states, temb\n    )\n\n    # attention\n    attn_hidden_states, attn_encoder_hidden_states = self.attn1(\n        hidden_states=norm_hidden_states,\n        encoder_hidden_states=norm_encoder_hidden_states,\n        image_rotary_emb=image_rotary_emb,\n    )\n\n    hidden_states = hidden_states + gate_msa * attn_hidden_states\n    encoder_hidden_states = 
encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states\n\n    # norm & modulate\n    norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(\n        hidden_states, encoder_hidden_states, temb\n    )\n\n    # feed-forward\n    norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)\n    ff_output = self.ff(norm_hidden_states)\n\n    hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]\n    encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]\n\n    hidden_states[2:] = hidden_states_ptb\n    encoder_hidden_states[2:] = encoder_hidden_states_ptb\n\n    return hidden_states, encoder_hidden_states\n\n\n# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid\ndef get_resize_crop_region_for_grid(src, tgt_width, tgt_height):\n    tw = tgt_width\n    th = tgt_height\n    h, w = src\n    r = h / w\n    if r > (th / tw):\n        resize_height = th\n        resize_width = int(round(th / h * w))\n    else:\n        resize_width = tw\n        resize_height = int(round(tw / w * h))\n\n    crop_top = int(round((th - resize_height) / 2.0))\n    crop_left = int(round((tw - resize_width) / 2.0))\n\n    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):\n    r\"\"\"\n    Pipeline for text-to-video generation using CogVideoX.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.\n        text_encoder ([`T5EncoderModel`]):\n            Frozen text-encoder. 
CogVideoX uses\n            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the\n            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.\n        tokenizer (`T5Tokenizer`):\n            Tokenizer of class\n            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).\n        transformer ([`CogVideoXTransformer3DModel`]):\n            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.\n    \"\"\"\n\n    _optional_components = []\n    model_cpu_offload_seq = \"text_encoder->transformer->vae\"\n\n    _callback_tensor_inputs = [\n        \"latents\",\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n    ]\n\n    def __init__(\n        self,\n        tokenizer: T5Tokenizer,\n        text_encoder: T5EncoderModel,\n        vae: AutoencoderKLCogVideoX,\n        transformer: CogVideoXTransformer3DModel,\n        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],\n    ):\n        super().__init__()\n\n        self.register_modules(\n            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler\n        )\n        self.vae_scale_factor_spatial = (\n            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        )\n        self.vae_scale_factor_temporal = (\n            self.vae.config.temporal_compression_ratio if getattr(self, \"vae\", None) else 4\n        )\n        self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, \"vae\", None) else 0.7\n\n        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)\n\n    def _get_t5_prompt_embeds(\n        self,\n     
   prompt: Union[str, List[str]] = None,\n        num_videos_per_prompt: int = 1,\n        max_sequence_length: int = 226,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        _, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        do_classifier_free_guidance: bool = True,\n        
num_videos_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 226,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):\n                Whether to use classifier free guidance or not.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                Number of videos that should be generated per prompt.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            device: (`torch.device`, *optional*):\n                torch device\n            dtype: (`torch.dtype`, *optional*):\n                torch dtype\n        \"\"\"\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n\n            negative_prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=negative_prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        return prompt_embeds, negative_prompt_embeds\n\n    def prepare_latents(\n        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None\n    ):\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        shape = (\n            batch_size,\n            (num_frames - 1) // self.vae_scale_factor_temporal + 1,\n            num_channels_latents,\n            height // self.vae_scale_factor_spatial,\n            width // self.vae_scale_factor_spatial,\n        )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:\n        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]\n        latents = 1 / self.vae_scaling_factor_image * latents\n\n        frames = self.vae.decode(latents).sample\n        return frames\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        negative_prompt,\n        callback_on_step_end_tensor_inputs,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n        
        f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    def fuse_qkv_projections(self) -> None:\n        r\"\"\"Enables fused QKV projections.\"\"\"\n        self.fusing_transformer = True\n        self.transformer.fuse_qkv_projections()\n\n    def unfuse_qkv_projections(self) -> None:\n        r\"\"\"Disable QKV projection fusion if enabled.\"\"\"\n        if not self.fusing_transformer:\n            logger.warning(\"The Transformer was not initially fused for QKV projections. Doing nothing.\")\n        else:\n            self.transformer.unfuse_qkv_projections()\n            self.fusing_transformer = False\n\n    def _prepare_rotary_positional_embeddings(\n        self,\n        height: int,\n        width: int,\n        num_frames: int,\n        device: torch.device,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)\n        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)\n\n        p = self.transformer.config.patch_size\n        p_t = self.transformer.config.patch_size_t\n\n        base_size_width = self.transformer.config.sample_width // p\n        base_size_height = self.transformer.config.sample_height // p\n\n        if p_t is None:\n            # CogVideoX 1.0\n            grid_crops_coords = get_resize_crop_region_for_grid(\n                (grid_height, grid_width), base_size_width, base_size_height\n            )\n            
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(\n                embed_dim=self.transformer.config.attention_head_dim,\n                crops_coords=grid_crops_coords,\n                grid_size=(grid_height, grid_width),\n                temporal_size=num_frames,\n                device=device,\n            )\n        else:\n            # CogVideoX 1.5\n            base_num_frames = (num_frames + p_t - 1) // p_t\n\n            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(\n                embed_dim=self.transformer.config.attention_head_dim,\n                crops_coords=None,\n                grid_size=(grid_height, grid_width),\n                temporal_size=base_num_frames,\n                grid_type=\"slice\",\n                max_size=(base_size_height, base_size_width),\n                device=device,\n            )\n\n        return freqs_cos, freqs_sin\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def do_spatio_temporal_guidance(self):\n        return self._stg_scale > 0.0\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def attention_kwargs(self):\n        return self._attention_kwargs\n\n    @property\n    def current_timestep(self):\n        return self._current_timestep\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_frames: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: Optional[List[int]] = None,\n        guidance_scale: float = 6,\n        use_dynamic_cfg: bool = False,\n        num_videos_per_prompt: int = 1,\n        eta: float = 0.0,\n        
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: str = \"pil\",\n        return_dict: bool = True,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 226,\n        stg_applied_layers_idx: Optional[List[int]] = [11],\n        stg_scale: Optional[float] = 0.0,\n        do_rescaling: Optional[bool] = False,\n    ) -> Union[CogVideoXPipelineOutput, Tuple]:\n        \"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):\n                The height in pixels of the generated image. This is set to 480 by default for the best results.\n            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):\n                The width in pixels of the generated image. 
This is set to 720 by default for the best results.\n            num_frames (`int`, defaults to `48`):\n                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will\n                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where\n                num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that\n                needs to be satisfied is that of divisibility mentioned above.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 6):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                The number of videos to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generate image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead\n                of a plain tuple.\n            attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int`, defaults to `226`):\n                Maximum sequence length in encoded prompt. 
Must be consistent with\n                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:\n            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial\n        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial\n        num_frames = num_frames or self.transformer.config.sample_frames\n\n        num_videos_per_prompt = 1\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            negative_prompt,\n            callback_on_step_end_tensor_inputs,\n            prompt_embeds,\n            negative_prompt_embeds,\n        )\n        self._stg_scale = stg_scale\n        self._guidance_scale = guidance_scale\n        self._attention_kwargs = attention_kwargs\n        self._current_timestep = None\n        self._interrupt = False\n\n        if self.do_spatio_temporal_guidance:\n            for i in stg_applied_layers_idx:\n                self.transformer.transformer_blocks[i].forward = types.MethodType(\n                    forward_with_stg, self.transformer.transformer_blocks[i]\n                )\n\n        # 2. 
Default call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            negative_prompt,\n            do_classifier_free_guidance,\n            num_videos_per_prompt=num_videos_per_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            max_sequence_length=max_sequence_length,\n            device=device,\n        )\n        if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n        elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        self._num_timesteps = len(timesteps)\n\n        # 5. 
Prepare latents\n        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1\n\n        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t\n        patch_size_t = self.transformer.config.patch_size_t\n        additional_frames = 0\n        if patch_size_t is not None and latent_frames % patch_size_t != 0:\n            additional_frames = patch_size_t - latent_frames % patch_size_t\n            num_frames += additional_frames * self.vae_scale_factor_temporal\n\n        latent_channels = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            latent_channels,\n            num_frames,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. Create rotary embeds if required\n        image_rotary_emb = (\n            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)\n            if self.transformer.config.use_rotary_positional_embeddings\n            else None\n        )\n\n        # 8. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            # for DPM-solver++\n            old_pred_original_sample = None\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                self._current_timestep = t\n                if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n                    latent_model_input = torch.cat([latents] * 2)\n                elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n                    latent_model_input = torch.cat([latents] * 3)\n                else:\n                    latent_model_input = latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latent_model_input.shape[0])\n\n                # predict noise model_output\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep=timestep,\n                    image_rotary_emb=image_rotary_emb,\n                    attention_kwargs=attention_kwargs,\n                    return_dict=False,\n                )[0]\n                noise_pred = noise_pred.float()\n\n                # perform guidance\n                if use_dynamic_cfg:\n                    self._guidance_scale = 1 + guidance_scale * (\n                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2\n                    )\n                if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n          
          noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n                    noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)\n                    noise_pred = (\n                        noise_pred_uncond\n                        + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                        + self._stg_scale * (noise_pred_text - noise_pred_perturb)\n                    )\n\n                if do_rescaling:\n                    rescaling_scale = 0.7\n                    factor = noise_pred_text.std() / noise_pred.std()\n                    factor = rescaling_scale * factor + (1 - rescaling_scale)\n                    noise_pred = noise_pred * factor\n\n                # compute the previous noisy sample x_t -> x_t-1\n                if not isinstance(self.scheduler, CogVideoXDPMScheduler):\n                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n                else:\n                    latents, old_pred_original_sample = self.scheduler.step(\n                        noise_pred,\n                        old_pred_original_sample,\n                        t,\n                        timesteps[i - 1] if i > 0 else None,\n                        latents,\n                        **extra_step_kwargs,\n                        return_dict=False,\n                    )\n                latents = latents.to(prompt_embeds.dtype)\n\n                # call the callback, if provided\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = 
callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        self._current_timestep = None\n\n        if not output_type == \"latent\":\n            # Discard any padding frames that were added for CogVideoX 1.5\n            latents = latents[:, additional_frames:]\n            video = self.decode_latents(latents)\n            video = self.video_processor.postprocess_video(video=video, output_type=output_type)\n        else:\n            video = latents\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return CogVideoXPipelineOutput(frames=video)\n"
  },
  {
    "path": "examples/community/pipeline_stg_hunyuan_video.py",
    "content": "# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport types\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.loaders import HunyuanVideoLoraLoaderMixin\nfrom diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel\nfrom diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.video_processor import VideoProcessor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```python\n        >>> import torch\n        >>> from diffusers.utils import export_to_video\n        >>> from diffusers import HunyuanVideoTransformer3DModel\n        >>> from 
examples.community.pipeline_stg_hunyuan_video import HunyuanVideoSTGPipeline\n\n        >>> model_id = \"hunyuanvideo-community/HunyuanVideo\"\n        >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(\n        ...     model_id, subfolder=\"transformer\", torch_dtype=torch.bfloat16\n        ... )\n        >>> pipe = HunyuanVideoSTGPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)\n        >>> pipe.vae.enable_tiling()\n        >>> pipe.to(\"cuda\")\n\n        >>> # Configure STG mode options\n        >>> stg_applied_layers_idx = [2]  # Layer indices from 0 to 41\n        >>> stg_scale = 1.0 # Set 0.0 for CFG\n\n        >>> output = pipe(\n        ...     prompt=\"A wolf howling at the moon, with the moon subtly resembling a giant clock face, realistic style.\",\n        ...     height=320,\n        ...     width=512,\n        ...     num_frames=61,\n        ...     num_inference_steps=30,\n        ...     stg_applied_layers_idx=stg_applied_layers_idx,\n        ...     stg_scale=stg_scale,\n        ... ).frames[0]\n        >>> export_to_video(output, \"output.mp4\", fps=15)\n        ```\n\"\"\"\n\n\nDEFAULT_PROMPT_TEMPLATE = {\n    \"template\": (\n        \"<|start_header_id|>system<|end_header_id|>\\n\\nDescribe the video by detailing the following aspects: \"\n        \"1. The main content and theme of the video.\"\n        \"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\"\n        \"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\"\n        \"4. background environment, light, style and atmosphere.\"\n        \"5. 
camera angles, movements, and transitions used in the video:<|eot_id|>\"\n        \"<|start_header_id|>user<|end_header_id|>\\n\\n{}<|eot_id|>\"\n    ),\n    \"crop_start\": 95,\n}\n\n\ndef forward_with_stg(\n    self,\n    hidden_states: torch.Tensor,\n    encoder_hidden_states: torch.Tensor,\n    temb: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    return hidden_states, encoder_hidden_states\n\n\ndef forward_without_stg(\n    self,\n    hidden_states: torch.Tensor,\n    encoder_hidden_states: torch.Tensor,\n    temb: torch.Tensor,\n    attention_mask: Optional[torch.Tensor] = None,\n    freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    # 1. Input normalization\n    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)\n    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(\n        encoder_hidden_states, emb=temb\n    )\n\n    # 2. Joint attention\n    attn_output, context_attn_output = self.attn(\n        hidden_states=norm_hidden_states,\n        encoder_hidden_states=norm_encoder_hidden_states,\n        attention_mask=attention_mask,\n        image_rotary_emb=freqs_cis,\n    )\n\n    # 3. Modulation and residual connection\n    hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)\n    encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)\n\n    norm_hidden_states = self.norm2(hidden_states)\n    norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)\n\n    norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n    norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]\n\n    # 4. 
Feed-forward\n    ff_output = self.ff(norm_hidden_states)\n    context_ff_output = self.ff_context(norm_encoder_hidden_states)\n\n    hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output\n    encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output\n\n    return hidden_states, encoder_hidden_states\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):\n    r\"\"\"\n    Pipeline for text-to-video generation using HunyuanVideo.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    Args:\n        text_encoder ([`LlamaModel`]):\n            [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).\n        tokenizer (`LlamaTokenizer`):\n            Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).\n        transformer ([`HunyuanVideoTransformer3DModel`]):\n            Conditional Transformer to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKLHunyuanVideo`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.\n        text_encoder_2 ([`CLIPTextModel`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer_2 (`CLIPTokenizer`):\n            Tokenizer of class\n            
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->transformer->vae\"\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\"]\n\n    def __init__(\n        self,\n        text_encoder: LlamaModel,\n        tokenizer: LlamaTokenizerFast,\n        transformer: HunyuanVideoTransformer3DModel,\n        vae: AutoencoderKLHunyuanVideo,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        text_encoder_2: CLIPTextModel,\n        tokenizer_2: CLIPTokenizer,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            transformer=transformer,\n            scheduler=scheduler,\n            text_encoder_2=text_encoder_2,\n            tokenizer_2=tokenizer_2,\n        )\n\n        self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, \"vae\", None) else 4\n        self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, \"vae\", None) else 8\n        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)\n\n    def _get_llama_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_template: Dict[str, Any],\n        num_videos_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n        max_sequence_length: int = 256,\n        num_hidden_layers_to_skip: int = 2,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        prompt = [prompt_template[\"template\"].format(p) for p in prompt]\n\n        crop_start = prompt_template.get(\"crop_start\", 
None)\n        if crop_start is None:\n            prompt_template_input = self.tokenizer(\n                prompt_template[\"template\"],\n                padding=\"max_length\",\n                return_tensors=\"pt\",\n                return_length=False,\n                return_overflowing_tokens=False,\n                return_attention_mask=False,\n            )\n            crop_start = prompt_template_input[\"input_ids\"].shape[-1]\n            # Remove <|eot_id|> token and placeholder {}\n            crop_start -= 2\n\n        max_sequence_length += crop_start\n        text_inputs = self.tokenizer(\n            prompt,\n            max_length=max_sequence_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_attention_mask=True,\n        )\n        text_input_ids = text_inputs.input_ids.to(device=device)\n        prompt_attention_mask = text_inputs.attention_mask.to(device=device)\n\n        prompt_embeds = self.text_encoder(\n            input_ids=text_input_ids,\n            attention_mask=prompt_attention_mask,\n            output_hidden_states=True,\n        ).hidden_states[-(num_hidden_layers_to_skip + 1)]\n        prompt_embeds = prompt_embeds.to(dtype=dtype)\n\n        if crop_start is not None and crop_start > 0:\n            prompt_embeds = prompt_embeds[:, crop_start:]\n            prompt_attention_mask = prompt_attention_mask[:, crop_start:]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        _, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)\n        prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)\n        prompt_attention_mask = 
prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)\n\n        return prompt_embeds, prompt_attention_mask\n\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        num_videos_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n        max_sequence_length: int = 77,\n    ) -> torch.Tensor:\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder_2.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer_2(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_2(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)\n        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)\n\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]] = None,\n        
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,\n        num_videos_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n        max_sequence_length: int = 256,\n    ):\n        if prompt_embeds is None:\n            prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(\n                prompt,\n                prompt_template,\n                num_videos_per_prompt,\n                device=device,\n                dtype=dtype,\n                max_sequence_length=max_sequence_length,\n            )\n\n        if pooled_prompt_embeds is None:\n            if prompt_2 is None and pooled_prompt_embeds is None:\n                prompt_2 = prompt\n            pooled_prompt_embeds = self._get_clip_prompt_embeds(\n                prompt,\n                num_videos_per_prompt,\n                device=device,\n                dtype=dtype,\n                max_sequence_length=77,\n            )\n\n        return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        height,\n        width,\n        prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        prompt_template=None,\n    ):\n        if height % 16 != 0 or width % 16 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 16 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in 
callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if prompt_template is not None:\n            if not isinstance(prompt_template, dict):\n                raise ValueError(f\"`prompt_template` has to be of type `dict` but is {type(prompt_template)}\")\n            if \"template\" not in prompt_template:\n                raise ValueError(\n                    f\"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}\"\n                )\n\n    def prepare_latents(\n        self,\n        batch_size: int,\n        num_channels_latents: int = 32,\n        height: int = 720,\n        width: int = 1280,\n        num_frames: int = 129,\n        dtype: Optional[torch.dtype] = None,\n        device: 
Optional[torch.device] = None,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if latents is not None:\n            return latents.to(device=device, dtype=dtype)\n\n        shape = (\n            batch_size,\n            num_channels_latents,\n            num_frames,\n            int(height) // self.vae_scale_factor_spatial,\n            int(width) // self.vae_scale_factor_spatial,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        return latents\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`.\"\n        deprecate(\n            \"enable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. 
If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`.\"\n        deprecate(\n            \"disable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. 
Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def do_spatio_temporal_guidance(self):\n        return self._stg_scale > 0.0\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def attention_kwargs(self):\n        return self._attention_kwargs\n\n    @property\n    def current_timestep(self):\n        return self._current_timestep\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Union[str, List[str]] = None,\n        height: int = 720,\n        width: int = 1280,\n        num_frames: int = 129,\n        num_inference_steps: int = 50,\n        sigmas: List[float] = None,\n        guidance_scale: float = 6.0,\n        num_videos_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,\n        max_sequence_length: int = 256,\n        stg_applied_layers_idx: 
Optional[List[int]] = [2],\n        stg_scale: Optional[float] = 0.0,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`\n                will be used instead.\n            height (`int`, defaults to `720`):\n                The height in pixels of the generated video.\n            width (`int`, defaults to `1280`):\n                The width in pixels of the generated video.\n            num_frames (`int`, defaults to `129`):\n                The number of frames in the generated video.\n            num_inference_steps (`int`, defaults to `50`):\n                The number of denoising steps. More denoising steps usually lead to a higher quality video at the\n                expense of slower inference.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, defaults to `6.0`):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality. Note that the only available HunyuanVideo model is\n                CFG-distilled, which means that traditional guidance between unconditional and conditional latent is\n                not applied.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.\n            attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~HunyuanVideoPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned\n                whose only element is the generated video frames.\n        \"\"\"\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n            prompt_template,\n        )\n\n        self._stg_scale = stg_scale\n        self._guidance_scale = guidance_scale\n        self._attention_kwargs = attention_kwargs\n        self._current_timestep = None\n        self._interrupt = False\n\n        device = self._execution_device\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # 3. 
Encode input prompt\n        prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            prompt_template=prompt_template,\n            num_videos_per_prompt=num_videos_per_prompt,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            device=device,\n            max_sequence_length=max_sequence_length,\n        )\n\n        transformer_dtype = self.transformer.dtype\n        prompt_embeds = prompt_embeds.to(transformer_dtype)\n        prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)\n        if pooled_prompt_embeds is not None:\n            pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)\n\n        # 4. Prepare timesteps\n        sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            sigmas=sigmas,\n        )\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels\n        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            num_latent_frames,\n            torch.float32,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare guidance condition\n        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                self._current_timestep = t\n                latent_model_input = latents.to(transformer_dtype)\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latents.shape[0]).to(latents.dtype)\n\n                # `layer_idx` keeps the denoising step index `i` intact for the callback and progress bar below\n                if self.do_spatio_temporal_guidance:\n                    for layer_idx in stg_applied_layers_idx:\n                        self.transformer.transformer_blocks[layer_idx].forward = types.MethodType(\n                            forward_without_stg, self.transformer.transformer_blocks[layer_idx]\n                        )\n\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    timestep=timestep,\n                    encoder_hidden_states=prompt_embeds,\n                    encoder_attention_mask=prompt_attention_mask,\n                    pooled_projections=pooled_prompt_embeds,\n                    guidance=guidance,\n                    attention_kwargs=attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                if self.do_spatio_temporal_guidance:\n                    for layer_idx in stg_applied_layers_idx:\n                        self.transformer.transformer_blocks[layer_idx].forward = types.MethodType(\n                            forward_with_stg, self.transformer.transformer_blocks[layer_idx]\n                        )\n\n                    noise_pred_perturb = self.transformer(\n                        hidden_states=latent_model_input,\n                        timestep=timestep,\n                        encoder_hidden_states=prompt_embeds,\n                        
encoder_attention_mask=prompt_attention_mask,\n                        pooled_projections=pooled_prompt_embeds,\n                        guidance=guidance,\n                        attention_kwargs=attention_kwargs,\n                        return_dict=False,\n                    )[0]\n                    noise_pred = noise_pred + self._stg_scale * (noise_pred - noise_pred_perturb)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        self._current_timestep = None\n\n        if not output_type == \"latent\":\n            latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor\n            video = self.vae.decode(latents, return_dict=False)[0]\n            video = self.video_processor.postprocess_video(video, output_type=output_type)\n        else:\n            video = latents\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return HunyuanVideoPipelineOutput(frames=video)\n"
  },
  {
    "path": "examples/community/pipeline_stg_ltx.py",
    "content": "# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport types\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom transformers import T5EncoderModel, T5TokenizerFast\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin\nfrom diffusers.models.autoencoders import AutoencoderKLLTXVideo\nfrom diffusers.models.transformers import LTXVideoTransformer3DModel\nfrom diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.video_processor import VideoProcessor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers.utils import export_to_video\n        >>> from examples.community.pipeline_stg_ltx import LTXSTGPipeline\n\n        >>> pipe = 
LTXSTGPipeline.from_pretrained(\"Lightricks/LTX-Video\", torch_dtype=torch.bfloat16)\n        >>> pipe.to(\"cuda\")\n\n        >>> prompt = \"A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage.\"\n        >>> negative_prompt = \"worst quality, inconsistent motion, blurry, jittery, distorted\"\n\n        >>> # Configure STG mode options\n        >>> stg_applied_layers_idx = [19]  # Layer indices from 0 to 41\n        >>> stg_scale = 1.0 # Set 0.0 for CFG\n        >>> do_rescaling = False\n\n        >>> video = pipe(\n        ...     prompt=prompt,\n        ...     negative_prompt=negative_prompt,\n        ...     width=704,\n        ...     height=480,\n        ...     num_frames=161,\n        ...     num_inference_steps=50,\n        ...     stg_applied_layers_idx=stg_applied_layers_idx,\n        ...     stg_scale=stg_scale,\n        ...     
do_rescaling=do_rescaling,\n        ... ).frames[0]\n        >>> export_to_video(video, \"output.mp4\", fps=24)\n        ```\n\"\"\"\n\n\ndef forward_with_stg(\n    self,\n    hidden_states: torch.Tensor,\n    encoder_hidden_states: torch.Tensor,\n    temb: torch.Tensor,\n    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    encoder_attention_mask: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    hidden_states_ptb = hidden_states[2:]\n    encoder_hidden_states_ptb = encoder_hidden_states[2:]\n\n    batch_size = hidden_states.size(0)\n    norm_hidden_states = self.norm1(hidden_states)\n\n    num_ada_params = self.scale_shift_table.shape[0]\n    ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)\n    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)\n    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa\n\n    attn_hidden_states = self.attn1(\n        hidden_states=norm_hidden_states,\n        encoder_hidden_states=None,\n        image_rotary_emb=image_rotary_emb,\n    )\n    hidden_states = hidden_states + attn_hidden_states * gate_msa\n\n    attn_hidden_states = self.attn2(\n        hidden_states,\n        encoder_hidden_states=encoder_hidden_states,\n        image_rotary_emb=None,\n        attention_mask=encoder_attention_mask,\n    )\n    hidden_states = hidden_states + attn_hidden_states\n    norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp\n\n    ff_output = self.ff(norm_hidden_states)\n    hidden_states = hidden_states + ff_output * gate_mlp\n\n    hidden_states[2:] = hidden_states_ptb\n    encoder_hidden_states[2:] = encoder_hidden_states_ptb\n\n    return hidden_states\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    
max_shift: float = 1.16,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass LTXSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):\n    r\"\"\"\n    Pipeline for text-to-video generation.\n\n    Reference: https://github.com/Lightricks/LTX-Video\n\n    Args:\n        transformer ([`LTXVideoTransformer3DModel`]):\n            Conditional Transformer architecture to denoise the encoded video latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKLLTXVideo`]):\n            
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.\n        text_encoder ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`T5TokenizerFast`):\n            Tokenizer of class\n            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKLLTXVideo,\n        text_encoder: T5EncoderModel,\n        tokenizer: T5TokenizerFast,\n        transformer: LTXVideoTransformer3DModel,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n\n        self.vae_spatial_compression_ratio = (\n            self.vae.spatial_compression_ratio if getattr(self, \"vae\", None) is not None else 32\n        )\n        self.vae_temporal_compression_ratio = (\n            self.vae.temporal_compression_ratio if getattr(self, \"vae\", None) is not None else 8\n        )\n        self.transformer_spatial_patch_size = (\n            self.transformer.config.patch_size if getattr(self, \"transformer\", None) is not None else 1\n        )\n        self.transformer_temporal_patch_size = (\n            
self.transformer.config.patch_size_t if getattr(self, \"transformer\", None) is not None else 1\n        )\n\n        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if getattr(self, \"tokenizer\", None) is not None else 128\n        )\n\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_videos_per_prompt: int = 1,\n        max_sequence_length: int = 128,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        prompt_attention_mask = text_inputs.attention_mask\n        prompt_attention_mask = prompt_attention_mask.bool().to(device)\n\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        # 
duplicate text embeddings for each generation per prompt, using mps friendly method\n        _, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)\n\n        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)\n        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)\n\n        return prompt_embeds, prompt_attention_mask\n\n    # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        do_classifier_free_guidance: bool = True,\n        num_videos_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 128,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):\n                Whether to use classifier free guidance or not.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                Number of videos that should be generated per prompt.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            device (`torch.device`, *optional*):\n                torch device to place the resulting embeddings on\n            dtype (`torch.dtype`, *optional*):\n                torch dtype\n        \"\"\"\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(\n                prompt=prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt = batch_size * 
[negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n\n            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(\n                prompt=negative_prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_on_step_end_tensor_inputs=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        prompt_attention_mask=None,\n        negative_prompt_attention_mask=None,\n    ):\n        if height % 32 != 0 or width % 32 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 32 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found 
{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if prompt_embeds is not None and prompt_attention_mask is None:\n            raise ValueError(\"Must provide `prompt_attention_mask` when specifying `prompt_embeds`.\")\n\n        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:\n            raise ValueError(\"Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.\")\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:\n                raise ValueError(\n                    \"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but\"\n                    f\" got: 
`prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`\"\n                    f\" {negative_prompt_attention_mask.shape}.\"\n                )\n\n    @staticmethod\n    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:\n        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].\n        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:\n        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).\n        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features\n        batch_size, num_channels, num_frames, height, width = latents.shape\n        post_patch_num_frames = num_frames // patch_size_t\n        post_patch_height = height // patch_size\n        post_patch_width = width // patch_size\n        latents = latents.reshape(\n            batch_size,\n            -1,\n            post_patch_num_frames,\n            patch_size_t,\n            post_patch_height,\n            patch_size,\n            post_patch_width,\n            patch_size,\n        )\n        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)\n        return latents\n\n    @staticmethod\n    def _unpack_latents(\n        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1\n    ) -> torch.Tensor:\n        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)\n        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. 
This is the inverse operation of\n        # what happens in the `_pack_latents` method.\n        batch_size = latents.size(0)\n        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)\n        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)\n        return latents\n\n    @staticmethod\n    def _normalize_latents(\n        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0\n    ) -> torch.Tensor:\n        # Normalize latents across the channel dimension [B, C, F, H, W]\n        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents = (latents - latents_mean) * scaling_factor / latents_std\n        return latents\n\n    @staticmethod\n    def _denormalize_latents(\n        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0\n    ) -> torch.Tensor:\n        # Denormalize latents across the channel dimension [B, C, F, H, W]\n        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents = latents * latents_std / scaling_factor + latents_mean\n        return latents\n\n    def prepare_latents(\n        self,\n        batch_size: int = 1,\n        num_channels_latents: int = 128,\n        height: int = 512,\n        width: int = 704,\n        num_frames: int = 161,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if latents is not None:\n            return latents.to(device=device, dtype=dtype)\n\n        height = height // 
self.vae_spatial_compression_ratio\n        width = width // self.vae_spatial_compression_ratio\n        num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1\n\n        shape = (batch_size, num_channels_latents, num_frames, height, width)\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        latents = self._pack_latents(\n            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size\n        )\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1.0\n\n    @property\n    def do_spatio_temporal_guidance(self):\n        return self._stg_scale > 0.0\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def attention_kwargs(self):\n        return self._attention_kwargs\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        height: int = 512,\n        width: int = 704,\n        num_frames: int = 161,\n        frame_rate: int = 25,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 3,\n        num_videos_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n     
   latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        decode_timestep: Union[float, List[float]] = 0.0,\n        decode_noise_scale: Optional[Union[float, List[float]]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 128,\n        stg_applied_layers_idx: Optional[List[int]] = [19],\n        stg_scale: Optional[float] = 1.0,\n        do_rescaling: Optional[bool] = False,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            height (`int`, defaults to `512`):\n                The height in pixels of the generated video. Must be divisible by 32.\n            width (`int`, defaults to `704`):\n                The width in pixels of the generated video. Must be divisible by 32.\n            num_frames (`int`, defaults to `161`):\n                The number of video frames to generate\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, defaults to `3 `):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                The number of videos to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            prompt_attention_mask (`torch.Tensor`, *optional*):\n                Pre-generated attention mask for text embeddings.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. If not\n                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.\n            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):\n                Pre-generated attention mask for negative text embeddings.\n            decode_timestep (`float`, defaults to `0.0`):\n                The timestep at which generated video is decoded.\n            decode_noise_scale (`float`, defaults to `None`):\n                The interpolation factor between random noise and denoised latents at the decode timestep.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated video. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.\n            attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. 
The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int` defaults to `128 `):\n                Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is\n                returned where the first element is a list with the generated images.\n        \"\"\"\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt=prompt,\n            height=height,\n            width=width,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n        )\n\n        self._stg_scale = stg_scale\n        self._guidance_scale = guidance_scale\n        self._attention_kwargs = attention_kwargs\n        self._interrupt = False\n\n        if self.do_spatio_temporal_guidance:\n            for i in stg_applied_layers_idx:\n                self.transformer.transformer_blocks[i].forward = types.MethodType(\n                    forward_with_stg, self.transformer.transformer_blocks[i]\n                )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Prepare text embeddings\n        (\n            prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_embeds,\n            negative_prompt_attention_mask,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            num_videos_per_prompt=num_videos_per_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            max_sequence_length=max_sequence_length,\n            device=device,\n        )\n        if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)\n        elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)\n            prompt_attention_mask = torch.cat(\n                [negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0\n            )\n\n        # 4. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            num_frames,\n            torch.float32,\n            device,\n            generator,\n            latents,\n        )\n\n        # 5. 
Prepare timesteps\n        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1\n        latent_height = height // self.vae_spatial_compression_ratio\n        latent_width = width // self.vae_spatial_compression_ratio\n        video_sequence_length = latent_num_frames * latent_height * latent_width\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)\n        mu = calculate_shift(\n            video_sequence_length,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.16),\n        )\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            timesteps,\n            sigmas=sigmas,\n            mu=mu,\n        )\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # 6. Prepare micro-conditions\n        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio\n        rope_interpolation_scale = (\n            1 / latent_frame_rate,\n            self.vae_spatial_compression_ratio,\n            self.vae_spatial_compression_ratio,\n        )\n\n        # 7. 
Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n                    latent_model_input = torch.cat([latents] * 2)\n                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n                    latent_model_input = torch.cat([latents] * 3)\n                else:\n                    latent_model_input = latents\n\n                latent_model_input = latent_model_input.to(prompt_embeds.dtype)\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latent_model_input.shape[0])\n\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep=timestep,\n                    encoder_attention_mask=prompt_attention_mask,\n                    num_frames=latent_num_frames,\n                    height=latent_height,\n                    width=latent_width,\n                    rope_interpolation_scale=rope_interpolation_scale,\n                    attention_kwargs=attention_kwargs,\n                    return_dict=False,\n                )[0]\n                noise_pred = noise_pred.float()\n\n                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n                    noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)\n                    
noise_pred = (\n                        noise_pred_uncond\n                        + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                        + self._stg_scale * (noise_pred_text - noise_pred_perturb)\n                    )\n\n                if do_rescaling:\n                    rescaling_scale = 0.7\n                    factor = noise_pred_text.std() / noise_pred.std()\n                    factor = rescaling_scale * factor + (1 - rescaling_scale)\n                    noise_pred = noise_pred * factor\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            video = latents\n        else:\n            latents = self._unpack_latents(\n                latents,\n                latent_num_frames,\n                latent_height,\n                latent_width,\n                self.transformer_spatial_patch_size,\n                self.transformer_temporal_patch_size,\n            )\n            latents = self._denormalize_latents(\n                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor\n       
     )\n            latents = latents.to(prompt_embeds.dtype)\n\n            if not self.vae.config.timestep_conditioning:\n                timestep = None\n            else:\n                noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)\n                if not isinstance(decode_timestep, list):\n                    decode_timestep = [decode_timestep] * batch_size\n                if decode_noise_scale is None:\n                    decode_noise_scale = decode_timestep\n                elif not isinstance(decode_noise_scale, list):\n                    decode_noise_scale = [decode_noise_scale] * batch_size\n\n                timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)\n                decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[\n                    :, None, None, None, None\n                ]\n                latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise\n\n            video = self.vae.decode(latents, timestep, return_dict=False)[0]\n            video = self.video_processor.postprocess_video(video, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return LTXPipelineOutput(frames=video)\n"
  },
  {
    "path": "examples/community/pipeline_stg_ltx_image2video.py",
    "content": "# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport types\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom transformers import T5EncoderModel, T5TokenizerFast\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin\nfrom diffusers.models.autoencoders import AutoencoderKLLTXVideo\nfrom diffusers.models.transformers import LTXVideoTransformer3DModel\nfrom diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.video_processor import VideoProcessor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers.utils import export_to_video, load_image\n        >>> from 
examples.community.pipeline_stg_ltx_image2video import LTXImageToVideoSTGPipeline\n\n        >>> pipe = LTXImageToVideoSTGPipeline.from_pretrained(\"Lightricks/LTX-Video\", torch_dtype=torch.bfloat16)\n        >>> pipe.to(\"cuda\")\n\n        >>> image = load_image(\n        ...     \"https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/11.png\"\n        ... )\n        >>> prompt = \"A medieval fantasy scene featuring a rugged man with shoulder-length brown hair and a beard. He wears a dark leather tunic over a maroon shirt with intricate metal details. His facial expression is serious and intense, and he is making a gesture with his right hand, forming a small circle with his thumb and index finger. The warm golden lighting casts dramatic shadows on his face. The background includes an ornate stone arch and blurred medieval-style decor, creating an epic atmosphere.\"\n        >>> negative_prompt = \"worst quality, inconsistent motion, blurry, jittery, distorted\"\n\n        >>> # Configure STG mode options\n        >>> stg_applied_layers_idx = [19]  # Layer indices from 0 to 41\n        >>> stg_scale = 1.0 # Set 0.0 for CFG\n        >>> do_rescaling = False\n\n        >>> video = pipe(\n        ...     image=image,\n        ...     prompt=prompt,\n        ...     negative_prompt=negative_prompt,\n        ...     width=704,\n        ...     height=480,\n        ...     num_frames=161,\n        ...     num_inference_steps=50,\n        ...     stg_applied_layers_idx=stg_applied_layers_idx,\n        ...     stg_scale=stg_scale,\n        ...     
do_rescaling=do_rescaling,\n        >>> ).frames[0]\n        >>> export_to_video(video, \"output.mp4\", fps=24)\n        ```\n\"\"\"\n\n\ndef forward_with_stg(\n    self,\n    hidden_states: torch.Tensor,\n    encoder_hidden_states: torch.Tensor,\n    temb: torch.Tensor,\n    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n    encoder_attention_mask: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    hidden_states_ptb = hidden_states[2:]\n    encoder_hidden_states_ptb = encoder_hidden_states[2:]\n\n    batch_size = hidden_states.size(0)\n    norm_hidden_states = self.norm1(hidden_states)\n\n    num_ada_params = self.scale_shift_table.shape[0]\n    ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)\n    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)\n    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa\n\n    attn_hidden_states = self.attn1(\n        hidden_states=norm_hidden_states,\n        encoder_hidden_states=None,\n        image_rotary_emb=image_rotary_emb,\n    )\n    hidden_states = hidden_states + attn_hidden_states * gate_msa\n\n    attn_hidden_states = self.attn2(\n        hidden_states,\n        encoder_hidden_states=encoder_hidden_states,\n        image_rotary_emb=None,\n        attention_mask=encoder_attention_mask,\n    )\n    hidden_states = hidden_states + attn_hidden_states\n    norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp\n\n    ff_output = self.ff(norm_hidden_states)\n    hidden_states = hidden_states + ff_output * gate_mlp\n\n    hidden_states[2:] = hidden_states_ptb\n    encoder_hidden_states[2:] = encoder_hidden_states_ptb\n\n    return hidden_states\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    
max_shift: float = 1.16,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return 
encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\nclass LTXImageToVideoSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):\n    r\"\"\"\n    Pipeline for image-to-video generation.\n\n    Reference: https://github.com/Lightricks/LTX-Video\n\n    Args:\n        transformer ([`LTXVideoTransformer3DModel`]):\n            Conditional Transformer architecture to denoise the encoded video latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.\n        vae ([`AutoencoderKLLTXVideo`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`T5TokenizerFast`):\n            Tokenizer of class\n            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKLLTXVideo,\n        text_encoder: T5EncoderModel,\n        tokenizer: T5TokenizerFast,\n        transformer: LTXVideoTransformer3DModel,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            
vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n\n        self.vae_spatial_compression_ratio = (\n            self.vae.spatial_compression_ratio if getattr(self, \"vae\", None) is not None else 32\n        )\n        self.vae_temporal_compression_ratio = (\n            self.vae.temporal_compression_ratio if getattr(self, \"vae\", None) is not None else 8\n        )\n        self.transformer_spatial_patch_size = (\n            self.transformer.config.patch_size if getattr(self, \"transformer\", None) is not None else 1\n        )\n        self.transformer_temporal_patch_size = (\n            self.transformer.config.patch_size_t if getattr(self, \"transformer\", None) is not None else 1\n        )\n\n        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if getattr(self, \"tokenizer\", None) is not None else 128\n        )\n\n        self.default_height = 512\n        self.default_width = 704\n        self.default_frames = 121\n\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_videos_per_prompt: int = 1,\n        max_sequence_length: int = 128,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        
prompt_attention_mask = text_inputs.attention_mask\n        prompt_attention_mask = prompt_attention_mask.bool().to(device)\n\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        _, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)\n\n        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)\n        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)\n\n        return prompt_embeds, prompt_attention_mask\n\n    # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        do_classifier_free_guidance: bool = True,\n        num_videos_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 128,\n        
device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):\n                Whether to use classifier free guidance or not.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            device: (`torch.device`, *optional*):\n                torch device\n            dtype: (`torch.dtype`, *optional*):\n                torch dtype\n        \"\"\"\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(\n                prompt=prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n\n            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(\n                prompt=negative_prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask\n\n    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_on_step_end_tensor_inputs=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        prompt_attention_mask=None,\n        negative_prompt_attention_mask=None,\n    ):\n        if height % 32 != 0 or width % 32 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 32 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if prompt_embeds is not None and prompt_attention_mask is None:\n            raise ValueError(\"Must provide `prompt_attention_mask` when specifying `prompt_embeds`.\")\n\n        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:\n            raise ValueError(\"Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.\")\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:\n                raise ValueError(\n                    \"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`\"\n                    f\" {negative_prompt_attention_mask.shape}.\"\n                )\n\n    @staticmethod\n    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents\n    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:\n        # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].\n        # The patch dimensions are 
then permuted and collapsed into the channel dimension of shape:\n        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).\n        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features\n        batch_size, num_channels, num_frames, height, width = latents.shape\n        post_patch_num_frames = num_frames // patch_size_t\n        post_patch_height = height // patch_size\n        post_patch_width = width // patch_size\n        latents = latents.reshape(\n            batch_size,\n            -1,\n            post_patch_num_frames,\n            patch_size_t,\n            post_patch_height,\n            patch_size,\n            post_patch_width,\n            patch_size,\n        )\n        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)\n        return latents\n\n    @staticmethod\n    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents\n    def _unpack_latents(\n        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1\n    ) -> torch.Tensor:\n        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)\n        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. 
This is the inverse operation of\n        # what happens in the `_pack_latents` method.\n        batch_size = latents.size(0)\n        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)\n        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)\n        return latents\n\n    @staticmethod\n    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents\n    def _normalize_latents(\n        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0\n    ) -> torch.Tensor:\n        # Normalize latents across the channel dimension [B, C, F, H, W]\n        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents = (latents - latents_mean) * scaling_factor / latents_std\n        return latents\n\n    @staticmethod\n    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents\n    def _denormalize_latents(\n        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0\n    ) -> torch.Tensor:\n        # Denormalize latents across the channel dimension [B, C, F, H, W]\n        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)\n        latents = latents * latents_std / scaling_factor + latents_mean\n        return latents\n\n    def prepare_latents(\n        self,\n        image: Optional[torch.Tensor] = None,\n        batch_size: int = 1,\n        num_channels_latents: int = 128,\n        height: int = 512,\n        width: int = 704,\n        num_frames: int = 161,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n        generator: 
torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        height = height // self.vae_spatial_compression_ratio\n        width = width // self.vae_spatial_compression_ratio\n        num_frames = (\n            (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)\n        )\n\n        shape = (batch_size, num_channels_latents, num_frames, height, width)\n        mask_shape = (batch_size, 1, num_frames, height, width)\n\n        if latents is not None:\n            conditioning_mask = latents.new_zeros(shape)\n            conditioning_mask[:, :, 0] = 1.0\n            conditioning_mask = self._pack_latents(\n                conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size\n            )\n            return latents.to(device=device, dtype=dtype), conditioning_mask\n\n        if isinstance(generator, list):\n            if len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            init_latents = [\n                retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i])\n                for i in range(batch_size)\n            ]\n        else:\n            init_latents = [\n                retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image\n            ]\n\n        init_latents = torch.cat(init_latents, dim=0).to(dtype)\n        init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)\n        init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)\n        conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)\n        conditioning_mask[:, :, 0] = 1.0\n\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)\n\n        conditioning_mask = self._pack_latents(\n            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size\n        ).squeeze(-1)\n        latents = self._pack_latents(\n            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size\n        )\n\n        return latents, conditioning_mask\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1.0\n\n    @property\n    def do_spatio_temporal_guidance(self):\n        return self._stg_scale > 0.0\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def attention_kwargs(self):\n        return self._attention_kwargs\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        
self,\n        image: PipelineImageInput = None,\n        prompt: Union[str, List[str]] = None,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        height: int = 512,\n        width: int = 704,\n        num_frames: int = 161,\n        frame_rate: int = 25,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 3,\n        num_videos_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        decode_timestep: Union[float, List[float]] = 0.0,\n        decode_noise_scale: Optional[Union[float, List[float]]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 128,\n        stg_applied_layers_idx: Optional[List[int]] = [19],\n        stg_scale: Optional[float] = 1.0,\n        do_rescaling: Optional[bool] = False,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            image (`PipelineImageInput`):\n                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`\n                instead.\n            height (`int`, defaults to `512`):\n                The height in pixels of the generated video. Must be divisible by 32.\n            width (`int`, defaults to `704`):\n                The width in pixels of the generated video. Must be divisible by 32.\n            num_frames (`int`, defaults to `161`):\n                The number of video frames to generate.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality video at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, defaults to `3`):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
A higher guidance scale encourages the model to generate videos that are closely linked to the text `prompt`,\n                usually at the expense of lower video quality.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                The number of videos to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            prompt_attention_mask (`torch.Tensor`, *optional*):\n                Pre-generated attention mask for text embeddings.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. 
If not\n                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.\n            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):\n                Pre-generated attention mask for negative text embeddings.\n            decode_timestep (`float`, defaults to `0.0`):\n                The timestep at which generated video is decoded.\n            decode_noise_scale (`float`, defaults to `None`):\n                The interpolation factor between random noise and denoised latents at the decode timestep.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generate image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.\n            attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that calls at the end of each denoising steps during the inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. 
The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int`, defaults to `128`):\n                Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is\n                returned where the first element is a list with the generated video frames.\n        \"\"\"\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt=prompt,\n            height=height,\n            width=width,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n        )\n\n        self._stg_scale = stg_scale\n        self._guidance_scale = guidance_scale\n        self._attention_kwargs = attention_kwargs\n        self._interrupt = False\n\n        if self.do_spatio_temporal_guidance:\n            for i in stg_applied_layers_idx:\n                self.transformer.transformer_blocks[i].forward = types.MethodType(\n                    forward_with_stg, self.transformer.transformer_blocks[i]\n                )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. Prepare text embeddings\n        (\n            prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_embeds,\n            negative_prompt_attention_mask,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            num_videos_per_prompt=num_videos_per_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            max_sequence_length=max_sequence_length,\n            device=device,\n        )\n        if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)\n        elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)\n            prompt_attention_mask = torch.cat(\n                [negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0\n            )\n\n        # 4. 
Prepare latent variables\n        if latents is None:\n            image = self.video_processor.preprocess(image, height=height, width=width)\n            image = image.to(device=device, dtype=prompt_embeds.dtype)\n\n        num_channels_latents = self.transformer.config.in_channels\n        latents, conditioning_mask = self.prepare_latents(\n            image,\n            batch_size * num_videos_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            num_frames,\n            torch.float32,\n            device,\n            generator,\n            latents,\n        )\n\n        if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])\n        elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask, conditioning_mask])\n\n        # 5. Prepare timesteps\n        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1\n        latent_height = height // self.vae_spatial_compression_ratio\n        latent_width = width // self.vae_spatial_compression_ratio\n        video_sequence_length = latent_num_frames * latent_height * latent_width\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)\n        mu = calculate_shift(\n            video_sequence_length,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.16),\n        )\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            timesteps,\n            sigmas=sigmas,\n            mu=mu,\n        )\n        num_warmup_steps = 
max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # 6. Prepare micro-conditions\n        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio\n        rope_interpolation_scale = (\n            1 / latent_frame_rate,\n            self.vae_spatial_compression_ratio,\n            self.vae_spatial_compression_ratio,\n        )\n\n        # 7. Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n                    latent_model_input = torch.cat([latents] * 2)\n                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n                    latent_model_input = torch.cat([latents] * 3)\n                else:\n                    latent_model_input = latents\n\n                latent_model_input = latent_model_input.to(prompt_embeds.dtype)\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latent_model_input.shape[0])\n                timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)\n\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep=timestep,\n                    encoder_attention_mask=prompt_attention_mask,\n                    num_frames=latent_num_frames,\n                    height=latent_height,\n                    width=latent_width,\n                    rope_interpolation_scale=rope_interpolation_scale,\n                    attention_kwargs=attention_kwargs,\n                    return_dict=False,\n                )[0]\n                noise_pred = 
noise_pred.float()\n\n                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                    timestep, _ = timestep.chunk(2)\n                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n                    noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)\n                    noise_pred = (\n                        noise_pred_uncond\n                        + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                        + self._stg_scale * (noise_pred_text - noise_pred_perturb)\n                    )\n                    timestep, _, _ = timestep.chunk(3)\n\n                if do_rescaling:\n                    rescaling_scale = 0.7\n                    factor = noise_pred_text.std() / noise_pred.std()\n                    factor = rescaling_scale * factor + (1 - rescaling_scale)\n                    noise_pred = noise_pred * factor\n\n                # compute the previous noisy sample x_t -> x_t-1\n                noise_pred = self._unpack_latents(\n                    noise_pred,\n                    latent_num_frames,\n                    latent_height,\n                    latent_width,\n                    self.transformer_spatial_patch_size,\n                    self.transformer_temporal_patch_size,\n                )\n                latents = self._unpack_latents(\n                    latents,\n                    latent_num_frames,\n                    latent_height,\n                    latent_width,\n                    self.transformer_spatial_patch_size,\n                    self.transformer_temporal_patch_size,\n                )\n\n                noise_pred = noise_pred[:, :, 1:]\n                noise_latents = latents[:, :, 1:]\n     
           pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]\n\n                latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)\n                latents = self._pack_latents(\n                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size\n                )\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            video = latents\n        else:\n            latents = self._unpack_latents(\n                latents,\n                latent_num_frames,\n                latent_height,\n                latent_width,\n                self.transformer_spatial_patch_size,\n                self.transformer_temporal_patch_size,\n            )\n            latents = self._denormalize_latents(\n                latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor\n            )\n            latents = latents.to(prompt_embeds.dtype)\n\n            if not self.vae.config.timestep_conditioning:\n                timestep = None\n            else:\n                noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)\n                if not isinstance(decode_timestep, list):\n   
                 decode_timestep = [decode_timestep] * batch_size\n                if decode_noise_scale is None:\n                    decode_noise_scale = decode_timestep\n                elif not isinstance(decode_noise_scale, list):\n                    decode_noise_scale = [decode_noise_scale] * batch_size\n\n                timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)\n                decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[\n                    :, None, None, None, None\n                ]\n                latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise\n\n            video = self.vae.decode(latents, timestep, return_dict=False)[0]\n            video = self.video_processor.postprocess_video(video, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return LTXPipelineOutput(frames=video)\n"
  },
  {
    "path": "examples/community/pipeline_stg_mochi.py",
    "content": "# Copyright 2025 Genmo and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport types\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom transformers import T5EncoderModel, T5TokenizerFast\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.loaders import Mochi1LoraLoaderMixin\nfrom diffusers.models import AutoencoderKLMochi, MochiTransformer3DModel\nfrom diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.video_processor import VideoProcessor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers.utils import export_to_video\n        >>> from examples.community.pipeline_stg_mochi import MochiSTGPipeline\n\n        >>> pipe = MochiSTGPipeline.from_pretrained(\"genmo/mochi-1-preview\", 
torch_dtype=torch.bfloat16)\n        >>> pipe.enable_model_cpu_offload()\n        >>> pipe.enable_vae_tiling()\n        >>> prompt = \"A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style.\"\n\n        >>> # Configure STG mode options\n        >>> stg_applied_layers_idx = [34]  # Layer indices from 0 to 41\n        >>> stg_scale = 1.0 # Set 0.0 for CFG\n        >>> do_rescaling = False\n\n        >>> frames = pipe(\n        ...     prompt=prompt,\n        ...     num_inference_steps=28,\n        ...     guidance_scale=3.5,\n        ...     stg_applied_layers_idx=stg_applied_layers_idx,\n        ...     stg_scale=stg_scale,\n        ...     do_rescaling=do_rescaling).frames[0]\n        >>> export_to_video(frames, \"mochi.mp4\")\n        ```\n\"\"\"\n\n\ndef forward_with_stg(\n    self,\n    hidden_states: torch.Tensor,\n    encoder_hidden_states: torch.Tensor,\n    temb: torch.Tensor,\n    encoder_attention_mask: torch.Tensor,\n    image_rotary_emb: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    hidden_states_ptb = hidden_states[2:]\n    encoder_hidden_states_ptb = encoder_hidden_states[2:]\n    norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)\n\n    if not self.context_pre_only:\n        norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(\n            encoder_hidden_states, temb\n        )\n    else:\n        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)\n\n    attn_hidden_states, context_attn_hidden_states = self.attn1(\n        hidden_states=norm_hidden_states,\n        encoder_hidden_states=norm_encoder_hidden_states,\n        image_rotary_emb=image_rotary_emb,\n        attention_mask=encoder_attention_mask,\n    )\n\n    hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))\n    norm_hidden_states 
= self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))\n    ff_output = self.ff(norm_hidden_states)\n    hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))\n\n    if not self.context_pre_only:\n        encoder_hidden_states = encoder_hidden_states + self.norm2_context(\n            context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)\n        )\n        norm_encoder_hidden_states = self.norm3_context(\n            encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))\n        )\n        context_ff_output = self.ff_context(norm_encoder_hidden_states)\n        encoder_hidden_states = encoder_hidden_states + self.norm4_context(\n            context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)\n        )\n\n        hidden_states[2:] = hidden_states_ptb\n        encoder_hidden_states[2:] = encoder_hidden_states_ptb\n\n    return hidden_states, encoder_hidden_states\n\n\n# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77\ndef linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):\n    if linear_steps is None:\n        linear_steps = num_steps // 2\n    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]\n    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps\n    quadratic_steps = num_steps - linear_steps\n    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)\n    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)\n    const = quadratic_coef * (linear_steps**2)\n    quadratic_sigma_schedule = [\n        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)\n    ]\n    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule\n    sigma_schedule = [1.0 - x for x in sigma_schedule]\n    return sigma_schedule\n\n\n# 
Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom value\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):\n    r\"\"\"\n    The mochi pipeline for text-to-video generation.\n\n    Reference: https://github.com/genmoai/models\n\n    Args:\n        transformer ([`MochiTransformer3DModel`]):\n            Conditional Transformer architecture to denoise the encoded video latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKLMochi`]):\n            Variational Auto-Encoder (VAE) 
Model to encode and decode videos to and from latent representations.\n        text_encoder ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`T5TokenizerFast`):\n            Tokenizer of class\n            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKLMochi,\n        text_encoder: T5EncoderModel,\n        tokenizer: T5TokenizerFast,\n        transformer: MochiTransformer3DModel,\n        force_zeros_for_empty_prompt: bool = False,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n        # TODO: determine these scaling factors from model parameters\n        self.vae_spatial_scale_factor = 8\n        self.vae_temporal_scale_factor = 6\n        self.patch_size = 2\n\n        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 256\n        )\n        self.default_height = 480\n        self.default_width = 848\n        
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)\n\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_videos_per_prompt: int = 1,\n        max_sequence_length: int = 256,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        prompt_attention_mask = text_inputs.attention_mask\n        prompt_attention_mask = prompt_attention_mask.bool().to(device)\n\n        # The original Mochi implementation zeros out empty negative prompts\n        # but this can lead to overflow when placing the entire pipeline under the autocast context\n        # adding this here so that we can enable zeroing prompts if necessary\n        if self.config.force_zeros_for_empty_prompt and (prompt == \"\" or prompt[-1] == \"\"):\n            text_input_ids = torch.zeros_like(text_input_ids, device=device)\n            prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)\n\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because 
`max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        _, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)\n\n        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)\n        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)\n\n        return prompt_embeds, prompt_attention_mask\n\n    # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        do_classifier_free_guidance: bool = True,\n        num_videos_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 256,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):\n                Whether to use classifier free guidance or not.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                Number of videos that should be generated per prompt.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            device: (`torch.device`, *optional*):\n                The torch device to place the resulting embeddings on.\n            dtype: (`torch.dtype`, *optional*):\n                torch dtype\n        \"\"\"\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(\n                prompt=prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt = batch_size * 
[negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n\n            negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(\n                prompt=negative_prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_on_step_end_tensor_inputs=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        prompt_attention_mask=None,\n        negative_prompt_attention_mask=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k 
for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if prompt_embeds is not None and prompt_attention_mask is None:\n            raise ValueError(\"Must provide `prompt_attention_mask` when specifying `prompt_embeds`.\")\n\n        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:\n            raise ValueError(\"Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.\")\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:\n                raise ValueError(\n                    \"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_attention_mask` 
{prompt_attention_mask.shape} != `negative_prompt_attention_mask`\"\n                    f\" {negative_prompt_attention_mask.shape}.\"\n                )\n\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`.\"\n        deprecate(\n            \"enable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_slicing()\n\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`.\"\n        deprecate(\n            \"disable_vae_slicing\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_slicing()\n\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. 
Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        num_frames,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        height = height // self.vae_spatial_scale_factor\n        width = width // self.vae_spatial_scale_factor\n        num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1\n\n        shape = (batch_size, num_channels_latents, num_frames, height, width)\n\n        if latents is not None:\n            return latents.to(device=device, dtype=dtype)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)\n        latents = latents.to(dtype)\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1.0\n\n    @property\n    def do_spatio_temporal_guidance(self):\n        return self._stg_scale > 0.0\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def attention_kwargs(self):\n        return self._attention_kwargs\n\n    @property\n    def current_timestep(self):\n        return self._current_timestep\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_frames: int = 19,\n        num_inference_steps: int = 64,\n        timesteps: List[int] = None,\n        guidance_scale: float = 4.5,\n        num_videos_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        
callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 256,\n        stg_applied_layers_idx: Optional[List[int]] = [34],\n        stg_scale: Optional[float] = 0.0,\n        do_rescaling: Optional[bool] = False,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            height (`int`, *optional*, defaults to `self.default_height`):\n                The height in pixels of the generated image. This is set to 480 by default for the best results.\n            width (`int`, *optional*, defaults to `self.default_width`):\n                The width in pixels of the generated image. This is set to 848 by default for the best results.\n            num_frames (`int`, defaults to `19`):\n                The number of video frames to generate.\n            num_inference_steps (`int`, *optional*, defaults to 64):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, defaults to `4.5`):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). 
Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                The number of videos to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            prompt_attention_mask (`torch.Tensor`, *optional*):\n                Pre-generated attention mask for text embeddings.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be \"\". If not\n                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.\n            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):\n                Pre-generated attention mask for negative text embeddings.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.\n            attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int` defaults to `256`):\n                Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`\n                is returned where the first element is a list with the generated images.\n        \"\"\"\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        height = height or self.default_height\n        width = width or self.default_width\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt=prompt,\n            height=height,\n            width=width,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._stg_scale = stg_scale\n        self._attention_kwargs = attention_kwargs\n        self._current_timestep = None\n        self._interrupt = False\n\n        if self.do_spatio_temporal_guidance:\n            for i in stg_applied_layers_idx:\n                self.transformer.transformer_blocks[i].forward = types.MethodType(\n                    forward_with_stg, self.transformer.transformer_blocks[i]\n                )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # 3. Prepare text embeddings\n        (\n            prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_embeds,\n            negative_prompt_attention_mask,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            num_videos_per_prompt=num_videos_per_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            max_sequence_length=max_sequence_length,\n            device=device,\n        )\n        # 4. 
Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            num_frames,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)\n        elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)\n            prompt_attention_mask = torch.cat(\n                [negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0\n            )\n\n        # 5. Prepare timestep\n        # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77\n        threshold_noise = 0.025\n        sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)\n        sigmas = np.array(sigmas)\n\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            timesteps,\n            sigmas,\n        )\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # 6. Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # Note: Mochi uses reversed timesteps. 
To ensure compatibility with methods like FasterCache, we need\n                # to make sure we're using the correct non-reversed timestep value.\n                self._current_timestep = 1000 - t\n                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n                    latent_model_input = torch.cat([latents] * 2)\n                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n                    latent_model_input = torch.cat([latents] * 3)\n                else:\n                    latent_model_input = latents\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)\n\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep=timestep,\n                    encoder_attention_mask=prompt_attention_mask,\n                    attention_kwargs=attention_kwargs,\n                    return_dict=False,\n                )[0]\n                # Mochi CFG + Sampling runs in FP32\n                noise_pred = noise_pred.to(torch.float32)\n\n                if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:\n                    noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)\n                    noise_pred = (\n                        noise_pred_uncond\n                        + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n                        + self._stg_scale * (noise_pred_text - 
noise_pred_perturb)\n                    )\n\n                if do_rescaling:\n                    rescaling_scale = 0.7\n                    factor = noise_pred_text.std() / noise_pred.std()\n                    factor = rescaling_scale * factor + (1 - rescaling_scale)\n                    noise_pred = noise_pred * factor\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]\n                latents = latents.to(latents_dtype)\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                # call the callback, if provided\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # update the progress bar\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        self._current_timestep = None\n\n        if output_type == \"latent\":\n            video = latents\n        else:\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, 
\"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            video = self.vae.decode(latents, return_dict=False)[0]\n            video = self.video_processor.postprocess_video(video, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return MochiPipelineOutput(frames=video)\n"
  },
  {
    "path": "examples/community/pipeline_stg_wan.py",
    "content": "# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport html\nimport types\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport ftfy\nimport regex as re\nimport torch\nfrom transformers import AutoTokenizer, UMT5EncoderModel\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.loaders import WanLoraLoaderMixin\nfrom diffusers.models import AutoencoderKLWan, WanTransformer3DModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.wan.pipeline_output import WanPipelineOutput\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.video_processor import VideoProcessor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```python\n        >>> import torch\n        >>> from diffusers.utils import export_to_video\n        >>> from diffusers import AutoencoderKLWan\n        >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler\n        >>> from 
examples.community.pipeline_stg_wan import WanSTGPipeline\n\n        >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers\n        >>> model_id = \"Wan-AI/Wan2.1-T2V-14B-Diffusers\"\n        >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder=\"vae\", torch_dtype=torch.float32)\n        >>> pipe = WanSTGPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)\n        >>> flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)\n        >>> pipe.to(\"cuda\")\n\n        >>> prompt = \"A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.\"\n        >>> negative_prompt = \"Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards\"\n\n        >>> # Configure STG mode options\n        >>> stg_applied_layers_idx = [8] # Layer indices from 0 to 39 for 14b or 0 to 29 for 1.3b\n        >>> stg_scale = 1.0 # Set 0.0 for CFG\n\n        >>> output = pipe(\n        ...     prompt=prompt,\n        ...     negative_prompt=negative_prompt,\n        ...     height=720,\n        ...     width=1280,\n        ...     num_frames=81,\n        ...     guidance_scale=5.0,\n        ...     stg_applied_layers_idx=stg_applied_layers_idx,\n        ...     stg_scale=stg_scale,\n        ... 
).frames[0]\n        >>> export_to_video(output, \"output.mp4\", fps=16)\n        ```\n\"\"\"\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r\"\\s+\", \" \", text)\n    text = text.strip()\n    return text\n\n\ndef prompt_clean(text):\n    text = whitespace_clean(basic_clean(text))\n    return text\n\n\ndef forward_with_stg(\n    self,\n    hidden_states: torch.Tensor,\n    encoder_hidden_states: torch.Tensor,\n    temb: torch.Tensor,\n    rotary_emb: torch.Tensor,\n) -> torch.Tensor:\n    return hidden_states\n\n\ndef forward_without_stg(\n    self,\n    hidden_states: torch.Tensor,\n    encoder_hidden_states: torch.Tensor,\n    temb: torch.Tensor,\n    rotary_emb: torch.Tensor,\n) -> torch.Tensor:\n    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (\n        self.scale_shift_table + temb.float()\n    ).chunk(6, dim=1)\n\n    # 1. Self-attention\n    norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)\n    attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)\n    hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)\n\n    # 2. Cross-attention\n    norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)\n    attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)\n    hidden_states = hidden_states + attn_output\n\n    # 3. 
Feed-forward\n    norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)\n    ff_output = self.ffn(norm_hidden_states)\n    hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)\n\n    return hidden_states\n\n\nclass WanSTGPipeline(DiffusionPipeline, WanLoraLoaderMixin):\n    r\"\"\"\n    Pipeline for text-to-video generation using Wan.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    Args:\n        tokenizer ([`T5Tokenizer`]):\n            Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),\n            specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.\n        text_encoder ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.\n        transformer ([`WanTransformer3DModel`]):\n            Conditional Transformer to denoise the input latents.\n        scheduler ([`UniPCMultistepScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKLWan`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->transformer->vae\"\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        tokenizer: AutoTokenizer,\n        text_encoder: UMT5EncoderModel,\n        transformer: WanTransformer3DModel,\n        vae: AutoencoderKLWan,\n        scheduler: 
FlowMatchEulerDiscreteScheduler,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n\n        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, \"vae\", None) else 4\n        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, \"vae\", None) else 8\n        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)\n\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_videos_per_prompt: int = 1,\n        max_sequence_length: int = 226,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        prompt = [prompt_clean(u) for u in prompt]\n        batch_size = len(prompt)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_attention_mask=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask\n        seq_lens = mask.gt(0).sum(dim=1).long()\n\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]\n        prompt_embeds = torch.stack(\n            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0\n        
)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        _, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        do_classifier_free_guidance: bool = True,\n        num_videos_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        max_sequence_length: int = 226,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be encoded.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the video generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):\n                Whether to use classifier free guidance or not.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                Number of videos that should be generated per prompt.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            device: (`torch.device`, *optional*):\n                torch device\n            dtype: (`torch.dtype`, *optional*):\n                torch dtype\n        \"\"\"\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        if prompt is not None:\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            negative_prompt = negative_prompt or \"\"\n            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n\n            if prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n\n            negative_prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=negative_prompt,\n                num_videos_per_prompt=num_videos_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n                dtype=dtype,\n            )\n\n        return prompt_embeds, negative_prompt_embeds\n\n    def check_inputs(\n        self,\n        prompt,\n        negative_prompt,\n        height,\n        width,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if height % 16 != 0 or width % 16 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 16 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. 
Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif negative_prompt is not None and (\n            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)\n        ):\n            raise ValueError(f\"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}\")\n\n    def prepare_latents(\n        self,\n        batch_size: int,\n        num_channels_latents: int = 16,\n        height: int = 480,\n        width: int = 832,\n        num_frames: int = 81,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if latents is not None:\n            return latents.to(device=device, dtype=dtype)\n\n        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1\n        shape = (\n            batch_size,\n            num_channels_latents,\n            num_latent_frames,\n            int(height) // self.vae_scale_factor_spatial,\n            int(width) // self.vae_scale_factor_spatial,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        return latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1.0\n\n    @property\n    def do_spatio_temporal_guidance(self):\n        return self._stg_scale > 0.0\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def current_timestep(self):\n        return self._current_timestep\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @property\n    def attention_kwargs(self):\n        return self._attention_kwargs\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        negative_prompt: Union[str, List[str]] = None,\n        height: int = 480,\n        width: int = 832,\n        num_frames: int = 81,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 5.0,\n        num_videos_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"np\",\n        return_dict: bool = True,\n        attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 512,\n        stg_applied_layers_idx: Optional[List[int]] = [3, 8, 16],\n        stg_scale: 
Optional[float] = 0.0,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            height (`int`, defaults to `480`):\n                The height in pixels of the generated video.\n            width (`int`, defaults to `832`):\n                The width in pixels of the generated video.\n            num_frames (`int`, defaults to `81`):\n                The number of frames in the generated video.\n            num_inference_steps (`int`, defaults to `50`):\n                The number of denoising steps. More denoising steps usually lead to a higher quality video at the\n                expense of slower inference.\n            guidance_scale (`float`, defaults to `5.0`):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
A higher guidance scale encourages the model to generate videos that are closely linked to the text `prompt`,\n                usually at the expense of lower video quality.\n            num_videos_per_prompt (`int`, *optional*, defaults to 1):\n                The number of videos to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"np\"`):\n                The output format of the generated video. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.\n            attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during inference, with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            stg_applied_layers_idx (`List[int]`, *optional*, defaults to `[3, 8, 16]`):\n                Indices of the transformer blocks whose forward pass is perturbed for spatio-temporal guidance.\n            stg_scale (`float`, *optional*, defaults to `0.0`):\n                Scale of the spatio-temporal guidance term. Enabled by setting `stg_scale > 0`, and only applied when\n                classifier-free guidance is also enabled.\n\n        Examples:\n\n        Returns:\n            [`~WanPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where\n                the first element is the generated video frames.\n        \"\"\"\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            negative_prompt,\n            height,\n            width,\n            prompt_embeds,\n            negative_prompt_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._stg_scale = stg_scale\n        self._attention_kwargs = attention_kwargs\n        self._current_timestep = None\n        self._interrupt = False\n\n        device = self._execution_device\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # 3. 
Encode input prompt\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            num_videos_per_prompt=num_videos_per_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            max_sequence_length=max_sequence_length,\n            device=device,\n        )\n\n        transformer_dtype = self.transformer.dtype\n        prompt_embeds = prompt_embeds.to(transformer_dtype)\n        if negative_prompt_embeds is not None:\n            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            num_frames,\n            torch.float32,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                self._current_timestep = t\n                latent_model_input = latents.to(transformer_dtype)\n                timestep = t.expand(latents.shape[0])\n\n                if self.do_spatio_temporal_guidance:\n                    for idx, block in enumerate(self.transformer.blocks):\n                        block.forward = types.MethodType(forward_without_stg, block)\n\n                noise_pred = self.transformer(\n                    hidden_states=latent_model_input,\n                    timestep=timestep,\n                    encoder_hidden_states=prompt_embeds,\n                    attention_kwargs=attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                if self.do_classifier_free_guidance:\n                    noise_uncond = self.transformer(\n                        hidden_states=latent_model_input,\n                        timestep=timestep,\n                        encoder_hidden_states=negative_prompt_embeds,\n                        attention_kwargs=attention_kwargs,\n                        return_dict=False,\n                    )[0]\n                    if self.do_spatio_temporal_guidance:\n                        for idx, block in enumerate(self.transformer.blocks):\n                            if idx in stg_applied_layers_idx:\n                                block.forward = types.MethodType(forward_with_stg, block)\n                        noise_perturb = self.transformer(\n                            hidden_states=latent_model_input,\n                            timestep=timestep,\n                            encoder_hidden_states=prompt_embeds,\n           
                 attention_kwargs=attention_kwargs,\n                            return_dict=False,\n                        )[0]\n                        noise_pred = (\n                            noise_uncond\n                            + guidance_scale * (noise_pred - noise_uncond)\n                            + self._stg_scale * (noise_pred - noise_perturb)\n                        )\n                    else:\n                        noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        self._current_timestep = None\n\n        if not output_type == \"latent\":\n            latents = latents.to(self.vae.dtype)\n            latents_mean = (\n                torch.tensor(self.vae.config.latents_mean)\n                .view(1, self.vae.config.z_dim, 1, 1, 1)\n                .to(latents.device, latents.dtype)\n            )\n            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, 
self.vae.config.z_dim, 1, 1, 1).to(\n                latents.device, latents.dtype\n            )\n            latents = latents / latents_std + latents_mean\n            video = self.vae.decode(latents, return_dict=False)[0]\n            video = self.video_processor.postprocess_video(video, output_type=output_type)\n        else:\n            video = latents\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (video,)\n\n        return WanPipelineOutput(frames=video)\n"
  },
  {
    "path": "examples/community/pipeline_z_image_differential_img2img.py",
    "content": "# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\nfrom transformers import AutoTokenizer, PreTrainedModel\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin\nfrom diffusers.models.autoencoders import AutoencoderKL\nfrom diffusers.models.transformers import ZImageTransformer2DModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from pipeline_z_image_differential_img2img import ZImageDifferentialImg2ImgPipeline\n        >>> from diffusers.utils import load_image\n\n        >>> pipe = ZImageDifferentialImg2ImgPipeline.from_pretrained(\"Z-a-o/Z-Image-Turbo\", torch_dtype=torch.bfloat16)\n        >>> pipe.to(\"cuda\")\n\n        >>> init_image = load_image(\n        >>>     
\"https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true\",\n        >>> )\n\n        >>> mask = load_image(\n        >>>     \"https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true\",\n        >>> )\n\n        >>> prompt = \"painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art\"\n\n        >>> image = pipe(\n        ...     prompt,\n        ...     image=init_image,\n        ...     mask_image=mask,\n        ...     strength=0.75,\n        ...     num_inference_steps=9,\n        ...     guidance_scale=0.0,\n        ...     generator=torch.Generator(\"cuda\").manual_seed(41),\n        ... 
).images[0]\n        >>> image.save(\"image.png\")\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.15,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. 
If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass ZImageDifferentialImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):\n    r\"\"\"\n    The ZImage pipeline for image-to-image generation.\n\n    Args:\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`PreTrainedModel`]):\n            A text encoder model to encode text prompts.\n        tokenizer ([`AutoTokenizer`]):\n            A tokenizer to tokenize text prompts.\n        transformer ([`ZImageTransformer2DModel`]):\n            A ZImage transformer model to denoise the encoded image latents.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", 
\"prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: PreTrainedModel,\n        tokenizer: AutoTokenizer,\n        transformer: ZImageTransformer2DModel,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            scheduler=scheduler,\n            transformer=transformer,\n        )\n        self.vae_scale_factor = (\n            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, \"vae\") and self.vae is not None else 8\n        )\n        latent_channels = self.vae.config.latent_channels if getattr(self, \"vae\", None) else 16\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)\n\n        self.mask_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor,\n            vae_latent_channels=latent_channels,\n            do_normalize=False,\n            do_binarize=False,\n            do_convert_grayscale=True,\n        )\n\n    # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        prompt_embeds: Optional[List[torch.FloatTensor]] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        max_sequence_length: int = 512,\n    ):\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        prompt_embeds = self._encode_prompt(\n            prompt=prompt,\n            device=device,\n            prompt_embeds=prompt_embeds,\n            max_sequence_length=max_sequence_length,\n        )\n\n        if do_classifier_free_guidance:\n            if 
negative_prompt is None:\n                negative_prompt = [\"\" for _ in prompt]\n            else:\n                negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt\n            assert len(prompt) == len(negative_prompt)\n            negative_prompt_embeds = self._encode_prompt(\n                prompt=negative_prompt,\n                device=device,\n                prompt_embeds=negative_prompt_embeds,\n                max_sequence_length=max_sequence_length,\n            )\n        else:\n            negative_prompt_embeds = []\n        return prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        prompt_embeds: Optional[List[torch.FloatTensor]] = None,\n        max_sequence_length: int = 512,\n    ) -> List[torch.FloatTensor]:\n        device = device or self._execution_device\n\n        if prompt_embeds is not None:\n            return prompt_embeds\n\n        if isinstance(prompt, str):\n            prompt = [prompt]\n\n        for i, prompt_item in enumerate(prompt):\n            messages = [\n                {\"role\": \"user\", \"content\": prompt_item},\n            ]\n            prompt_item = self.tokenizer.apply_chat_template(\n                messages,\n                tokenize=False,\n                add_generation_prompt=True,\n                enable_thinking=True,\n            )\n            prompt[i] = prompt_item\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids.to(device)\n        prompt_masks = text_inputs.attention_mask.to(device).bool()\n\n        prompt_embeds = 
self.text_encoder(\n            input_ids=text_input_ids,\n            attention_mask=prompt_masks,\n            output_hidden_states=True,\n        ).hidden_states[-2]\n\n        embeddings_list = []\n\n        for i in range(len(prompt_embeds)):\n            embeddings_list.append(prompt_embeds[i][prompt_masks[i]])\n\n        return embeddings_list\n\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(num_inference_steps * strength, num_inference_steps)\n\n        t_start = int(max(num_inference_steps - init_timestep, 0))\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, \"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, num_inference_steps - t_start\n\n    @staticmethod\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\n        latent_image_ids = torch.zeros(height // 2, width // 2, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids.reshape(\n            latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n\n        return latent_image_ids.to(device=device, dtype=dtype)\n\n    def prepare_latents(\n        self,\n        image,\n        timestep,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        height = 
2 * (int(height) // (self.vae_scale_factor * 2))\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\n\n        shape = (batch_size, num_channels_latents, height, width)\n        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)\n\n        if latents is not None:\n            return latents.to(device=device, dtype=dtype)\n\n        # Encode the input image\n        image = image.to(device=device, dtype=dtype)\n        if image.shape[1] != num_channels_latents:\n            if isinstance(generator, list):\n                image_latents = [\n                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                    for i in range(image.shape[0])\n                ]\n                image_latents = torch.cat(image_latents, dim=0)\n            else:\n                image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n            # Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)\n            image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n        else:\n            image_latents = image\n\n        # Handle batch size expansion\n        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:\n            additional_image_per_prompt = batch_size // image_latents.shape[0]\n            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n\n        # Add noise using flow matching scale_noise\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        latents = self.scheduler.scale_noise(image_latents, timestep, noise)\n\n  
      return latents, noise, image_latents, latent_image_ids\n\n    def prepare_mask_latents(\n        self,\n        mask,\n        masked_image,\n        batch_size,\n        num_images_per_prompt,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n    ):\n        height = 2 * (int(height) // (self.vae_scale_factor * 2))\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(mask, size=(height, width))\n        mask = mask.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        masked_image = masked_image.to(device=device, dtype=dtype)\n\n        if masked_image.shape[1] == 16:\n            masked_image_latents = masked_image\n        else:\n            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)\n\n        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. 
Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n        if masked_image_latents.shape[0] < batch_size:\n            if not batch_size % masked_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)\n\n        # aligning device to prevent device errors when concating it with the latent model input\n        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n\n        return mask, masked_image_latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1\n\n    @property\n    def joint_attention_kwargs(self):\n        return self._joint_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        mask_image: PipelineImageInput = None,\n        strength: float = 0.6,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        sigmas: Optional[List[float]] = None,\n        
guidance_scale: float = 5.0,\n        cfg_normalization: bool = False,\n        cfg_truncation: float = 1.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[List[torch.FloatTensor]] = None,\n        negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 512,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for image-to-image generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both\n                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a\n                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or\n                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.\n            mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image`, numpy array or tensor representing an image batch to mask `image`. 
Black pixels in the mask\n                are repainted while white pixels are preserved. If `mask_image` is a PIL image, it is converted to a\n                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one\n                color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,\n                H, W)`, `(1, H, W)`, `(H, W)`. For a numpy array, the expected shape is `(B, H, W, 1)`, `(B, H, W)`,\n                `(H, W, 1)`, or `(H, W)`.\n            strength (`float`, *optional*, defaults to 0.6):\n                Indicates the extent to which the reference `image` is transformed. Must be between 0 and 1. `image`\n                is used as a starting point and more noise is added the higher the `strength`. The number of denoising\n                steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and\n                the denoising process runs for the full number of iterations specified in `num_inference_steps`. A\n                value of 1 essentially ignores `image`.\n            height (`int`, *optional*):\n                The height in pixels of the generated image. If not provided, uses the input image height.\n            width (`int`, *optional*):\n                The width in pixels of the generated image. If not provided, uses the input image width.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            cfg_normalization (`bool`, *optional*, defaults to False):\n                Whether to renormalize the classifier-free guidance (CFG) prediction. When set, the norm of the guided\n                prediction is capped at `float(cfg_normalization)` times the norm of the conditional prediction.\n            cfg_truncation (`float`, *optional*, defaults to 1.0):\n                Classifier-free guidance (CFG) truncation threshold on the normalized time (0 at the start of\n                denoising, 1 at the end). Guidance is skipped for steps whose normalized time exceeds this value.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`List[torch.FloatTensor]`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain\n                tuple.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. 
`callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int`, *optional*, defaults to 512):\n                Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if\n            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the\n            generated images.\n        \"\"\"\n        # 1. Check inputs and validate strength\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\n\n        # 2. Preprocess image\n        init_image = self.image_processor.preprocess(image)\n        init_image = init_image.to(dtype=torch.float32)\n\n        # Get dimensions from the preprocessed image if not specified\n        if height is None:\n            height = init_image.shape[-2]\n        if width is None:\n            width = init_image.shape[-1]\n\n        vae_scale = self.vae_scale_factor * 2\n        if height % vae_scale != 0:\n            raise ValueError(\n                f\"Height must be divisible by {vae_scale} (got {height}). \"\n                f\"Please adjust the height to a multiple of {vae_scale}.\"\n            )\n        if width % vae_scale != 0:\n            raise ValueError(\n                f\"Width must be divisible by {vae_scale} (got {width}). 
\"\n                f\"Please adjust the width to a multiple of {vae_scale}.\"\n            )\n\n        device = self._execution_device\n\n        self._guidance_scale = guidance_scale\n        self._joint_attention_kwargs = joint_attention_kwargs\n        self._interrupt = False\n        self._cfg_normalization = cfg_normalization\n        self._cfg_truncation = cfg_truncation\n\n        # 3. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = len(prompt_embeds)\n\n        # If prompt_embeds is provided and prompt is None, skip encoding\n        if prompt_embeds is not None and prompt is None:\n            if self.do_classifier_free_guidance and negative_prompt_embeds is None:\n                raise ValueError(\n                    \"When `prompt_embeds` is provided without `prompt`, \"\n                    \"`negative_prompt_embeds` must also be provided for classifier-free guidance.\"\n                )\n        else:\n            (\n                prompt_embeds,\n                negative_prompt_embeds,\n            ) = self.encode_prompt(\n                prompt=prompt,\n                negative_prompt=negative_prompt,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                prompt_embeds=prompt_embeds,\n                negative_prompt_embeds=negative_prompt_embeds,\n                device=device,\n                max_sequence_length=max_sequence_length,\n            )\n\n        # 4. 
Prepare latent variables\n        num_channels_latents = self.transformer.in_channels\n\n        # Repeat prompt_embeds for num_images_per_prompt\n        if num_images_per_prompt > 1:\n            prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]\n            if self.do_classifier_free_guidance and negative_prompt_embeds:\n                negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]\n\n        actual_batch_size = batch_size * num_images_per_prompt\n\n        # Calculate latent dimensions for image_seq_len\n        latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))\n        latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))\n        image_seq_len = (latent_height // 2) * (latent_width // 2)\n\n        # 5. Prepare timesteps\n        mu = calculate_shift(\n            image_seq_len,\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\n            self.scheduler.config.get(\"base_shift\", 0.5),\n            self.scheduler.config.get(\"max_shift\", 1.15),\n        )\n        self.scheduler.sigma_min = 0.0\n        scheduler_kwargs = {\"mu\": mu}\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            sigmas=sigmas,\n            **scheduler_kwargs,\n        )\n\n        # 6. 
Adjust timesteps based on strength\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        if num_inference_steps < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline \"\n                f\"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\n            )\n        latent_timestep = timesteps[:1].repeat(actual_batch_size)\n\n        # 7. Prepare latents from image\n        latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(\n            init_image,\n            latent_timestep,\n            actual_batch_size,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds[0].dtype,\n            device,\n            generator,\n            latents,\n        )\n        resize_mode = \"default\"\n        crops_coords = None\n\n        # start diff diff preparation\n        original_mask = self.mask_processor.preprocess(\n            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords\n        )\n\n        masked_image = init_image * original_mask\n        original_mask, _ = self.prepare_mask_latents(\n            original_mask,\n            masked_image,\n            batch_size,\n            num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds[0].dtype,\n            device,\n            generator,\n        )\n        mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps\n        mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device)\n        masks = original_mask > mask_thresholds\n        # end diff diff preparation\n\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # 8. 
Denoising loop\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latents.shape[0])\n                timestep = (1000 - timestep) / 1000\n                # Normalized time for time-aware config (0 at start, 1 at end)\n                t_norm = timestep[0].item()\n\n                # Handle cfg truncation\n                current_guidance_scale = self.guidance_scale\n                if (\n                    self.do_classifier_free_guidance\n                    and self._cfg_truncation is not None\n                    and float(self._cfg_truncation) <= 1\n                ):\n                    if t_norm > self._cfg_truncation:\n                        current_guidance_scale = 0.0\n\n                # Run CFG only if configured AND scale is non-zero\n                apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0\n\n                if apply_cfg:\n                    latents_typed = latents.to(self.transformer.dtype)\n                    latent_model_input = latents_typed.repeat(2, 1, 1, 1)\n                    prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds\n                    timestep_model_input = timestep.repeat(2)\n                else:\n                    latent_model_input = latents.to(self.transformer.dtype)\n                    prompt_embeds_model_input = prompt_embeds\n                    timestep_model_input = timestep\n\n                latent_model_input = latent_model_input.unsqueeze(2)\n                latent_model_input_list = list(latent_model_input.unbind(dim=0))\n\n                model_out_list = self.transformer(\n                    latent_model_input_list,\n                    timestep_model_input,\n                    
prompt_embeds_model_input,\n                )[0]\n\n                if apply_cfg:\n                    # Perform CFG\n                    pos_out = model_out_list[:actual_batch_size]\n                    neg_out = model_out_list[actual_batch_size:]\n\n                    noise_pred = []\n                    for j in range(actual_batch_size):\n                        pos = pos_out[j].float()\n                        neg = neg_out[j].float()\n\n                        pred = pos + current_guidance_scale * (pos - neg)\n\n                        # Renormalization\n                        if self._cfg_normalization and float(self._cfg_normalization) > 0.0:\n                            ori_pos_norm = torch.linalg.vector_norm(pos)\n                            new_pos_norm = torch.linalg.vector_norm(pred)\n                            max_new_norm = ori_pos_norm * float(self._cfg_normalization)\n                            if new_pos_norm > max_new_norm:\n                                pred = pred * (max_new_norm / new_pos_norm)\n\n                        noise_pred.append(pred)\n\n                    noise_pred = torch.stack(noise_pred, dim=0)\n                else:\n                    noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)\n\n                noise_pred = noise_pred.squeeze(2)\n                noise_pred = -noise_pred\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]\n                assert latents.dtype == torch.float32\n\n                # start diff diff\n                image_latent = original_image_latents\n                latents_dtype = latents.dtype\n                if i < len(timesteps) - 1:\n                    noise_timestep = timesteps[i + 1]\n                    image_latent = self.scheduler.scale_noise(\n                        original_image_latents, torch.tensor([noise_timestep]), noise\n            
        )\n\n                    mask = masks[i].to(latents_dtype)\n                    latents = image_latent * mask + latents * (1 - mask)\n                # end diff diff\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n        if output_type == \"latent\":\n            image = latents\n\n        else:\n            latents = latents.to(self.vae.dtype)\n            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return ZImagePipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/pipeline_zero1to3.py",
    "content": "# A diffuser version implementation of Zero1to3 (https://github.com/cvlab-columbia/zero123), ICCV 2023\n# by Xin Kong\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport kornia\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\n# from ...configuration_utils import FrozenDict\n# from ...models import AutoencoderKL, UNet2DConditionModel\n# from ...schedulers import KarrasDiffusionSchedulers\n# from ...utils import (\n#     deprecate,\n#     is_accelerate_available,\n#     is_accelerate_version,\n#     logging,\n#     randn_tensor,\n#     replace_example_docstring,\n# )\n# from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin\n# from . import StableDiffusionPipelineOutput\n# from .safety_checker import StableDiffusionSafetyChecker\nfrom diffusers import AutoencoderKL, DiffusionPipeline, StableDiffusionMixin, UNet2DConditionModel\nfrom diffusers.configuration_utils import ConfigMixin, FrozenDict\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n# todo\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import StableDiffusionPipeline\n\n        >>> pipe = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", torch_dtype=torch.float16)\n        >>> pipe = pipe.to(\"cuda\")\n\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\n        >>> image = pipe(prompt).images[0]\n        
```\n\"\"\"\n\n\nclass CCProjection(ModelMixin, ConfigMixin):\n    def __init__(self, in_channel=772, out_channel=768):\n        super().__init__()\n        self.in_channel = in_channel\n        self.out_channel = out_channel\n        self.projection = torch.nn.Linear(in_channel, out_channel)\n\n    def forward(self, x):\n        return self.projection(x)\n\n\nclass Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for single view conditioned novel view generation using Zero1to3.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        image_encoder ([`CLIPVisionModelWithProjection`]):\n            Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),\n            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n        cc_projection ([`CCProjection`]):\n            Projection layer to project the concatenated CLIP features and pose embeddings to the original CLIP feature size.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        image_encoder: CLIPVisionModelWithProjection,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        cc_projection: CCProjection,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            image_encoder=image_encoder,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            cc_projection=cc_projection,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n        # self.model_mode = None\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or 
`List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n           
     logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} 
has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def CLIP_preprocess(self, x):\n        dtype = x.dtype\n        # following openai's implementation\n       
 # TODO HF OpenAI CLIP preprocessing issue https://github.com/huggingface/transformers/issues/22505#issuecomment-1650170741\n        # follow openai preprocessing to keep exact same, input tensor [-1, 1], otherwise the preprocessing will be different, https://github.com/huggingface/transformers/pull/22608\n        if isinstance(x, torch.Tensor):\n            if x.min() < -1.0 or x.max() > 1.0:\n                raise ValueError(\"Expected input tensor to have values in the range [-1, 1]\")\n        x = kornia.geometry.resize(\n            x.to(torch.float32), (224, 224), interpolation=\"bicubic\", align_corners=True, antialias=False\n        ).to(dtype=dtype)\n        x = (x + 1.0) / 2.0\n        # renormalize according to clip\n        x = kornia.enhance.normalize(\n            x, torch.Tensor([0.48145466, 0.4578275, 0.40821073]), torch.Tensor([0.26862954, 0.26130258, 0.27577711])\n        )\n        return x\n\n    # from image_variation\n    def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):\n        dtype = next(self.image_encoder.parameters()).dtype\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        if isinstance(image, torch.Tensor):\n            # Batch single image\n            if image.ndim == 3:\n                assert image.shape[0] == 3, \"Image outside a batch should be of shape (3, H, W)\"\n                image = image.unsqueeze(0)\n\n            assert image.ndim == 4, \"Image must have 4 dimensions\"\n\n            # Check image is in [-1, 1]\n            if image.min() < -1 or image.max() > 1:\n                raise ValueError(\"Image should be in [-1, 1] range\")\n        else:\n            # preprocess image\n            if isinstance(image, (PIL.Image.Image, np.ndarray)):\n                image = [image]\n\n           
 if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n                image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n                image = np.concatenate(image, axis=0)\n            elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n                image = np.concatenate([i[None, :] for i in image], axis=0)\n\n            image = image.transpose(0, 3, 1, 2)\n            image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n        image = image.to(device=device, dtype=dtype)\n\n        image = self.CLIP_preprocess(image)\n        # if not isinstance(image, torch.Tensor):\n        #     # 0-255\n        #     print(\"Warning: image is processed by hf's preprocess, which is different from openai original's.\")\n        #     image = self.feature_extractor(images=image, return_tensors=\"pt\").pixel_values\n        image_embeddings = self.image_encoder(image).image_embeds.to(dtype=dtype)\n        image_embeddings = image_embeddings.unsqueeze(1)\n\n        # duplicate image embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = image_embeddings.shape\n        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)\n        image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        if do_classifier_free_guidance:\n            negative_prompt_embeds = torch.zeros_like(image_embeddings)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])\n\n        return image_embeddings\n\n    def _encode_pose(self, pose, device, num_images_per_prompt, do_classifier_free_guidance):\n        dtype = next(self.cc_projection.parameters()).dtype\n        if 
isinstance(pose, torch.Tensor):\n            pose_embeddings = pose.unsqueeze(1).to(device=device, dtype=dtype)\n        else:\n            if isinstance(pose[0], list):\n                pose = torch.Tensor(pose)\n            else:\n                pose = torch.Tensor([pose])\n            x, y, z = pose[:, 0].unsqueeze(1), pose[:, 1].unsqueeze(1), pose[:, 2].unsqueeze(1)\n            pose_embeddings = (\n                torch.cat([torch.deg2rad(x), torch.sin(torch.deg2rad(y)), torch.cos(torch.deg2rad(y)), z], dim=-1)\n                .unsqueeze(1)\n                .to(device=device, dtype=dtype)\n            )  # B, 1, 4\n        # duplicate pose embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = pose_embeddings.shape\n        pose_embeddings = pose_embeddings.repeat(1, num_images_per_prompt, 1)\n        pose_embeddings = pose_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n        if do_classifier_free_guidance:\n            negative_prompt_embeds = torch.zeros_like(pose_embeddings)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            pose_embeddings = torch.cat([negative_prompt_embeds, pose_embeddings])\n        return pose_embeddings\n\n    def _encode_image_with_pose(self, image, pose, device, num_images_per_prompt, do_classifier_free_guidance):\n        img_prompt_embeds = self._encode_image(image, device, num_images_per_prompt, False)\n        pose_prompt_embeds = self._encode_pose(pose, device, num_images_per_prompt, False)\n        prompt_embeds = torch.cat([img_prompt_embeds, pose_prompt_embeds], dim=-1)\n        prompt_embeds = self.cc_projection(prompt_embeds)\n        # prompt_embeds = img_prompt_embeds\n        # follow 0123, add negative prompt, after projection\n        if do_classifier_free_guidance:\n    
        negative_prompt = torch.zeros_like(prompt_embeds)\n            prompt_embeds = torch.cat([negative_prompt, prompt_embeds])\n        return prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(self, image, height, width, callback_steps):\n        if (\n 
           not isinstance(image, torch.Tensor)\n            and not isinstance(image, PIL.Image.Image)\n            and not isinstance(image, list)\n        ):\n            raise ValueError(\n                \"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is\"\n                f\" {type(image)}\"\n            )\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def prepare_img_latents(self, image, batch_size, dtype, device, generator=None, do_classifier_free_guidance=False):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        if isinstance(image, torch.Tensor):\n            # Batch single image\n            if image.ndim == 3:\n                assert image.shape[0] == 3, \"Image outside a batch should be of shape (3, H, W)\"\n                image = image.unsqueeze(0)\n\n            assert image.ndim == 4, \"Image must have 4 dimensions\"\n\n            # Check image is in [-1, 1]\n            if image.min() < -1 or image.max() > 1:\n                raise ValueError(\"Image should be in [-1, 1] range\")\n        else:\n            # preprocess image\n            if isinstance(image, (PIL.Image.Image, np.ndarray)):\n                image = [image]\n\n            if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n                image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n                image = np.concatenate(image, axis=0)\n            elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n                image = np.concatenate([i[None, :] for i in image], axis=0)\n\n            image = image.transpose(0, 3, 1, 2)\n            image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n        image = image.to(device=device, 
dtype=dtype)\n\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if isinstance(generator, list):\n            init_latents = [\n                self.vae.encode(image[i : i + 1]).latent_dist.mode()\n                for i in range(batch_size)  # mode() is deterministic and takes no generator\n            ]\n            init_latents = torch.cat(init_latents, dim=0)\n        else:\n            init_latents = self.vae.encode(image).latent_dist.mode()\n\n        # init_latents = self.vae.config.scaling_factor * init_latents  # todo in original zero123's inference gradio_new.py, model.encode_first_stage() is not scaled by scaling_factor\n        if batch_size > init_latents.shape[0]:\n            # init_latents = init_latents.repeat(batch_size // init_latents.shape[0], 1, 1, 1)\n            num_images_per_prompt = batch_size // init_latents.shape[0]\n            # duplicate image latents for each generation per prompt, using mps friendly method\n            bs_embed, emb_c, emb_h, emb_w = init_latents.shape\n            init_latents = init_latents.unsqueeze(1)\n            init_latents = init_latents.repeat(1, num_images_per_prompt, 1, 1, 1)\n            init_latents = init_latents.view(bs_embed * num_images_per_prompt, emb_c, emb_h, emb_w)\n\n        # init_latents = torch.cat([init_latents]*2) if do_classifier_free_guidance else init_latents   # follow zero123\n        init_latents = (\n            torch.cat([torch.zeros_like(init_latents), init_latents]) if do_classifier_free_guidance else init_latents\n        )\n\n        init_latents = init_latents.to(device=device, dtype=dtype)\n        return init_latents\n\n    # def load_cc_projection(self, pretrained_weights=None):\n    #     
self.cc_projection = torch.nn.Linear(772, 768)\n    #     torch.nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])\n    #     torch.nn.init.zeros_(list(self.cc_projection.parameters())[1])\n    #     if pretrained_weights is not None:\n    #         self.cc_projection.load_state_dict(pretrained_weights)\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        input_imgs: Union[torch.Tensor, PIL.Image.Image] = None,\n        prompt_imgs: Union[torch.Tensor, PIL.Image.Image] = None,\n        poses: Union[List[float], List[List[float]]] = None,\n        torch_dtype=torch.float32,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 3.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: float = 1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            input_imgs (`PIL` or `List[PIL]`, *optional*):\n                The single input image for each 3D object\n            prompt_imgs (`PIL` or `List[PIL]`, *optional*):\n                Same as input_imgs, but will be used later as an image prompt condition, encoded by CLIP feature\n            height (`int`, *optional*, defaults to 
self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 3.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the\n                conditioning input, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under\n                `self.processor` in\n                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        # input_image = hint_imgs\n        self.check_inputs(input_imgs, height, width, callback_steps)\n\n        # 2. 
Define call parameters\n        if isinstance(input_imgs, PIL.Image.Image):\n            batch_size = 1\n        elif isinstance(input_imgs, list):\n            batch_size = len(input_imgs)\n        else:\n            batch_size = input_imgs.shape[0]\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input image with pose as prompt\n        prompt_embeds = self._encode_image_with_pose(\n            prompt_imgs, poses, device, num_images_per_prompt, do_classifier_free_guidance\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            4,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare image latents\n        img_latents = self.prepare_img_latents(\n            input_imgs,\n            batch_size * num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance,\n        )\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n                latent_model_input = torch.cat([latent_model_input, img_latents], dim=1)\n\n                # predict the noise residual\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # 8. Post-processing\n        has_nsfw_concept = None\n        if output_type == \"latent\":\n            image = latents\n        elif output_type == \"pil\":\n            # 8. 
Post-processing\n            image = self.decode_latents(latents)\n            # 9. Convert to PIL\n            image = self.numpy_to_pil(image)\n        else:\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/pipline_flux_fill_controlnet_Inpaint.py",
    "content": "import inspect\r\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\r\n\r\nimport numpy as np\r\nimport PIL\r\nimport torch\r\nfrom transformers import (\r\n    CLIPTextModel,\r\n    CLIPTokenizer,\r\n    T5EncoderModel,\r\n    T5TokenizerFast,\r\n)\r\n\r\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\r\nfrom diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin\r\nfrom diffusers.models.autoencoders import AutoencoderKL\r\nfrom diffusers.models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel\r\nfrom diffusers.models.transformers import FluxTransformer2DModel\r\nfrom diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput\r\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\r\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\r\nfrom diffusers.utils import (\r\n    USE_PEFT_BACKEND,\r\n    is_torch_xla_available,\r\n    logging,\r\n    replace_example_docstring,\r\n    scale_lora_layers,\r\n    unscale_lora_layers,\r\n)\r\nfrom diffusers.utils.torch_utils import randn_tensor\r\n\r\n\r\nif is_torch_xla_available():\r\n    import torch_xla.core.xla_model as xm\r\n\r\n    XLA_AVAILABLE = True\r\nelse:\r\n    XLA_AVAILABLE = False\r\n\r\nlogger = logging.get_logger(__name__)\r\n\r\nEXAMPLE_DOC_STRING = \"\"\"\r\n    Examples:\r\n        ```py\r\n        >>> import torch\r\n        >>> from diffusers import FluxControlNetInpaintPipeline\r\n        >>> from diffusers.models import FluxControlNetModel\r\n        >>> from diffusers.utils import load_image\r\n\r\n        >>> controlnet = FluxControlNetModel.from_pretrained(\r\n        ...     \"InstantX/FLUX.1-dev-controlnet-canny\", torch_dtype=torch.float16\r\n        ... )\r\n        >>> pipe = FluxControlNetInpaintPipeline.from_pretrained(\r\n        ...     
\"black-forest-labs/FLUX.1-schnell\", controlnet=controlnet, torch_dtype=torch.float16\r\n        ... )\r\n        >>> pipe.to(\"cuda\")\r\n\r\n        >>> control_image = load_image(\r\n        ...     \"https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg\"\r\n        ... )\r\n        >>> init_image = load_image(\r\n        ...     \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\r\n        ... )\r\n        >>> mask_image = load_image(\r\n        ...     \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\r\n        ... )\r\n\r\n        >>> prompt = \"A girl holding a sign that says InstantX\"\r\n        >>> image = pipe(\r\n        ...     prompt,\r\n        ...     image=init_image,\r\n        ...     mask_image=mask_image,\r\n        ...     control_image=control_image,\r\n        ...     control_guidance_start=0.2,\r\n        ...     control_guidance_end=0.8,\r\n        ...     controlnet_conditioning_scale=0.7,\r\n        ...     strength=0.7,\r\n        ...     num_inference_steps=28,\r\n        ...     guidance_scale=3.5,\r\n        ... 
).images[0]\r\n        >>> image.save(\"flux_controlnet_inpaint.png\")\r\n        ```\r\n\"\"\"\r\n\r\n\r\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\r\ndef calculate_shift(\r\n    image_seq_len,\r\n    base_seq_len: int = 256,\r\n    max_seq_len: int = 4096,\r\n    base_shift: float = 0.5,\r\n    max_shift: float = 1.15,\r\n):\r\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\r\n    b = base_shift - m * base_seq_len\r\n    mu = image_seq_len * m + b\r\n    return mu\r\n\r\n\r\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\r\ndef retrieve_latents(\r\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\r\n):\r\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\r\n        return encoder_output.latent_dist.sample(generator)\r\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\r\n        return encoder_output.latent_dist.mode()\r\n    elif hasattr(encoder_output, \"latents\"):\r\n        return encoder_output.latents\r\n    else:\r\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\r\n\r\n\r\ndef retrieve_latents_fill(\r\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\r\n):\r\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\r\n        return encoder_output.latent_dist.sample(generator)\r\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\r\n        return encoder_output.latent_dist.mode()\r\n    elif hasattr(encoder_output, \"latents\"):\r\n        return encoder_output.latents\r\n    else:\r\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\r\n\r\n\r\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\r\ndef 
retrieve_timesteps(\r\n    scheduler,\r\n    num_inference_steps: Optional[int] = None,\r\n    device: Optional[Union[str, torch.device]] = None,\r\n    timesteps: Optional[List[int]] = None,\r\n    sigmas: Optional[List[float]] = None,\r\n    **kwargs,\r\n):\r\n    r\"\"\"\r\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\r\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\r\n\r\n    Args:\r\n        scheduler (`SchedulerMixin`):\r\n            The scheduler to get timesteps from.\r\n        num_inference_steps (`int`):\r\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\r\n            must be `None`.\r\n        device (`str` or `torch.device`, *optional*):\r\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\r\n        timesteps (`List[int]`, *optional*):\r\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\r\n            `num_inference_steps` and `sigmas` must be `None`.\r\n        sigmas (`List[float]`, *optional*):\r\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\r\n            `num_inference_steps` and `timesteps` must be `None`.\r\n\r\n    Returns:\r\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\r\n        second element is the number of inference steps.\r\n    \"\"\"\r\n    if timesteps is not None and sigmas is not None:\r\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values\")\r\n    if timesteps is not None:\r\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\r\n        if not accepts_timesteps:\r\n            raise ValueError(\r\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\r\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\r\n            )\r\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n        num_inference_steps = len(timesteps)\r\n    elif sigmas is not None:\r\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\r\n        if not accept_sigmas:\r\n            raise ValueError(\r\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\r\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\r\n            )\r\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n        num_inference_steps = len(timesteps)\r\n    else:\r\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n    return timesteps, num_inference_steps\r\n\r\n\r\nclass FluxControlNetFillInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):\r\n    r\"\"\"\r\n    The Flux controlnet pipeline for inpainting.\r\n\r\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\r\n\r\n    Args:\r\n        transformer ([`FluxTransformer2DModel`]):\r\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\r\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\r\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\r\n        vae ([`AutoencoderKL`]):\r\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\r\n        text_encoder ([`CLIPTextModel`]):\r\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\r\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\r\n        text_encoder_2 ([`T5EncoderModel`]):\r\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\r\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\r\n        tokenizer (`CLIPTokenizer`):\r\n            Tokenizer of class\r\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).\r\n        tokenizer_2 (`T5TokenizerFast`):\r\n            Second Tokenizer of class\r\n            
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\r\n    \"\"\"\r\n\r\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->transformer->vae\"\r\n    _optional_components = []\r\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"control_image\", \"mask\", \"masked_image_latents\"]\r\n\r\n    def __init__(\r\n        self,\r\n        scheduler: FlowMatchEulerDiscreteScheduler,\r\n        vae: AutoencoderKL,\r\n        text_encoder: CLIPTextModel,\r\n        tokenizer: CLIPTokenizer,\r\n        text_encoder_2: T5EncoderModel,\r\n        tokenizer_2: T5TokenizerFast,\r\n        transformer: FluxTransformer2DModel,\r\n        controlnet: Union[\r\n            FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel\r\n        ],\r\n    ):\r\n        super().__init__()\r\n        if isinstance(controlnet, (list, tuple)):\r\n            controlnet = FluxMultiControlNetModel(controlnet)\r\n\r\n        self.register_modules(\r\n            scheduler=scheduler,\r\n            vae=vae,\r\n            text_encoder=text_encoder,\r\n            tokenizer=tokenizer,\r\n            text_encoder_2=text_encoder_2,\r\n            tokenizer_2=tokenizer_2,\r\n            transformer=transformer,\r\n            controlnet=controlnet,\r\n        )\r\n\r\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\r\n        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible\r\n        # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this\r\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)\r\n        latent_channels = self.vae.config.latent_channels if getattr(self, \"vae\", None) else 16\r\n        self.mask_processor = VaeImageProcessor(\r\n            vae_scale_factor=self.vae_scale_factor * 2,\r\n            vae_latent_channels=latent_channels,\r\n            do_normalize=False,\r\n            do_binarize=True,\r\n            do_convert_grayscale=True,\r\n        )\r\n        self.tokenizer_max_length = (\r\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 77\r\n        )\r\n        self.default_sample_size = 128\r\n\r\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds\r\n    def _get_t5_prompt_embeds(\r\n        self,\r\n        prompt: Union[str, List[str]] = None,\r\n        num_images_per_prompt: int = 1,\r\n        max_sequence_length: int = 512,\r\n        device: Optional[torch.device] = None,\r\n        dtype: Optional[torch.dtype] = None,\r\n    ):\r\n        device = device or self._execution_device\r\n        dtype = dtype or self.text_encoder.dtype\r\n\r\n        prompt = [prompt] if isinstance(prompt, str) else prompt\r\n        batch_size = len(prompt)\r\n\r\n        if isinstance(self, TextualInversionLoaderMixin):\r\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)\r\n\r\n        text_inputs = self.tokenizer_2(\r\n            prompt,\r\n            padding=\"max_length\",\r\n            max_length=max_sequence_length,\r\n            truncation=True,\r\n            return_length=False,\r\n            return_overflowing_tokens=False,\r\n            return_tensors=\"pt\",\r\n        )\r\n        text_input_ids = text_inputs.input_ids\r\n        untruncated_ids = self.tokenizer_2(prompt, padding=\"longest\", 
return_tensors=\"pt\").input_ids\r\n\r\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\r\n            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\r\n            logger.warning(\r\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\r\n                f\" {max_sequence_length} tokens: {removed_text}\"\r\n            )\r\n\r\n        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]\r\n\r\n        dtype = self.text_encoder_2.dtype\r\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\r\n\r\n        _, seq_len, _ = prompt_embeds.shape\r\n\r\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\r\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\r\n\r\n        return prompt_embeds\r\n\r\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds\r\n    def _get_clip_prompt_embeds(\r\n        self,\r\n        prompt: Union[str, List[str]],\r\n        num_images_per_prompt: int = 1,\r\n        device: Optional[torch.device] = None,\r\n    ):\r\n        device = device or self._execution_device\r\n\r\n        prompt = [prompt] if isinstance(prompt, str) else prompt\r\n        batch_size = len(prompt)\r\n\r\n        if isinstance(self, TextualInversionLoaderMixin):\r\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\r\n\r\n        text_inputs = self.tokenizer(\r\n            prompt,\r\n            padding=\"max_length\",\r\n            max_length=self.tokenizer_max_length,\r\n            truncation=True,\r\n            return_overflowing_tokens=False,\r\n            return_length=False,\r\n         
   return_tensors=\"pt\",\r\n        )\r\n\r\n        text_input_ids = text_inputs.input_ids\r\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\r\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\r\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\r\n            logger.warning(\r\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\r\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\r\n            )\r\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)\r\n\r\n        # Use pooled output of CLIPTextModel\r\n        prompt_embeds = prompt_embeds.pooler_output\r\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\r\n\r\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\r\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)\r\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\r\n\r\n        return prompt_embeds\r\n\r\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt\r\n    def encode_prompt(\r\n        self,\r\n        prompt: Union[str, List[str]],\r\n        prompt_2: Union[str, List[str]],\r\n        device: Optional[torch.device] = None,\r\n        num_images_per_prompt: int = 1,\r\n        prompt_embeds: Optional[torch.FloatTensor] = None,\r\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\r\n        max_sequence_length: int = 512,\r\n        lora_scale: Optional[float] = None,\r\n    ):\r\n        r\"\"\"\r\n\r\n        Args:\r\n            prompt (`str` or `List[str]`, *optional*):\r\n                prompt to be encoded\r\n            prompt_2 (`str` or `List[str]`, 
*optional*):\r\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\r\n                used in all text-encoders\r\n            device: (`torch.device`):\r\n                torch device\r\n            num_images_per_prompt (`int`):\r\n                number of images that should be generated per prompt\r\n            prompt_embeds (`torch.FloatTensor`, *optional*):\r\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\r\n                provided, text embeddings will be generated from `prompt` input argument.\r\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\r\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\r\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\r\n            lora_scale (`float`, *optional*):\r\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\r\n        \"\"\"\r\n        device = device or self._execution_device\r\n\r\n        # set lora scale so that monkey patched LoRA\r\n        # function of text encoder can correctly access it\r\n        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):\r\n            self._lora_scale = lora_scale\r\n\r\n            # dynamically adjust the LoRA scale\r\n            if self.text_encoder is not None and USE_PEFT_BACKEND:\r\n                scale_lora_layers(self.text_encoder, lora_scale)\r\n            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:\r\n                scale_lora_layers(self.text_encoder_2, lora_scale)\r\n\r\n        prompt = [prompt] if isinstance(prompt, str) else prompt\r\n\r\n        if prompt_embeds is None:\r\n            prompt_2 = prompt_2 or prompt\r\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) 
else prompt_2\r\n\r\n            # We only use the pooled prompt output from the CLIPTextModel\r\n            pooled_prompt_embeds = self._get_clip_prompt_embeds(\r\n                prompt=prompt,\r\n                device=device,\r\n                num_images_per_prompt=num_images_per_prompt,\r\n            )\r\n            prompt_embeds = self._get_t5_prompt_embeds(\r\n                prompt=prompt_2,\r\n                num_images_per_prompt=num_images_per_prompt,\r\n                max_sequence_length=max_sequence_length,\r\n                device=device,\r\n            )\r\n\r\n        if self.text_encoder is not None:\r\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\r\n                # Retrieve the original scale by scaling back the LoRA layers\r\n                unscale_lora_layers(self.text_encoder, lora_scale)\r\n\r\n        if self.text_encoder_2 is not None:\r\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\r\n                # Retrieve the original scale by scaling back the LoRA layers\r\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\r\n\r\n        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype\r\n        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\r\n\r\n        return prompt_embeds, pooled_prompt_embeds, text_ids\r\n\r\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image\r\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\r\n        if isinstance(generator, list):\r\n            image_latents = [\r\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\r\n                for i in range(image.shape[0])\r\n            ]\r\n            image_latents = torch.cat(image_latents, dim=0)\r\n        else:\r\n            image_latents = 
retrieve_latents(self.vae.encode(image), generator=generator)\r\n\r\n        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\r\n\r\n        return image_latents\r\n\r\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps\r\n    def get_timesteps(self, num_inference_steps, strength, device):\r\n        # get the original timestep using init_timestep\r\n        init_timestep = min(num_inference_steps * strength, num_inference_steps)\r\n\r\n        t_start = int(max(num_inference_steps - init_timestep, 0))\r\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\r\n        if hasattr(self.scheduler, \"set_begin_index\"):\r\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\r\n\r\n        return timesteps, num_inference_steps - t_start\r\n\r\n    def check_inputs(\r\n        self,\r\n        prompt,\r\n        prompt_2,\r\n        image,\r\n        mask_image,\r\n        strength,\r\n        height,\r\n        width,\r\n        output_type,\r\n        prompt_embeds=None,\r\n        pooled_prompt_embeds=None,\r\n        callback_on_step_end_tensor_inputs=None,\r\n        padding_mask_crop=None,\r\n        max_sequence_length=None,\r\n    ):\r\n        if strength < 0 or strength > 1:\r\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is {strength}\")\r\n\r\n        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:\r\n            logger.warning(\r\n                f\"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. 
Dimensions will be resized accordingly\"\r\n            )\r\n\r\n        if callback_on_step_end_tensor_inputs is not None and not all(\r\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\r\n        ):\r\n            raise ValueError(\r\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\r\n            )\r\n\r\n        if prompt is not None and prompt_embeds is not None:\r\n            raise ValueError(\r\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\r\n                \" only forward one of the two.\"\r\n            )\r\n        elif prompt_2 is not None and prompt_embeds is not None:\r\n            raise ValueError(\r\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\r\n                \" only forward one of the two.\"\r\n            )\r\n        elif prompt is None and prompt_embeds is None:\r\n            raise ValueError(\r\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\r\n            )\r\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\r\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\r\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\r\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\r\n\r\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\r\n            raise ValueError(\r\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. 
Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\r\n            )\r\n\r\n        if padding_mask_crop is not None:\r\n            if not isinstance(image, PIL.Image.Image):\r\n                raise ValueError(\r\n                    f\"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}.\"\r\n                )\r\n            if not isinstance(mask_image, PIL.Image.Image):\r\n                raise ValueError(\r\n                    f\"The mask image should be a PIL image when inpainting mask crop, but is of type\"\r\n                    f\" {type(mask_image)}.\"\r\n                )\r\n            if output_type != \"pil\":\r\n                raise ValueError(f\"The output type should be PIL when inpainting mask crop, but is {output_type}.\")\r\n\r\n        if max_sequence_length is not None and max_sequence_length > 512:\r\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\r\n\r\n    @staticmethod\r\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids\r\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\r\n        latent_image_ids = torch.zeros(height, width, 3)\r\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]\r\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]\r\n\r\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\r\n\r\n        latent_image_ids = latent_image_ids.reshape(\r\n            latent_image_id_height * latent_image_id_width, latent_image_id_channels\r\n        )\r\n\r\n        return latent_image_ids.to(device=device, dtype=dtype)\r\n\r\n    @staticmethod\r\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents\r\n    def _pack_latents(latents, 
batch_size, num_channels_latents, height, width):\r\n        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)\r\n        latents = latents.permute(0, 2, 4, 1, 3, 5)\r\n        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)\r\n\r\n        return latents\r\n\r\n    @staticmethod\r\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents\r\n    def _unpack_latents(latents, height, width, vae_scale_factor):\r\n        batch_size, num_patches, channels = latents.shape\r\n\r\n        # VAE applies 8x compression on images but we must also account for packing which requires\r\n        # latent height and width to be divisible by 2.\r\n        height = 2 * (int(height) // (vae_scale_factor * 2))\r\n        width = 2 * (int(width) // (vae_scale_factor * 2))\r\n\r\n        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)\r\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\r\n\r\n        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)\r\n\r\n        return latents\r\n\r\n    def prepare_latents(\r\n        self,\r\n        image,\r\n        timestep,\r\n        batch_size,\r\n        num_channels_latents,\r\n        height,\r\n        width,\r\n        dtype,\r\n        device,\r\n        generator,\r\n        latents=None,\r\n    ):\r\n        if isinstance(generator, list) and len(generator) != batch_size:\r\n            raise ValueError(\r\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\r\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\r\n            )\r\n\r\n        # VAE applies 8x compression on images but we must also account for packing which requires\r\n        # latent height and width to be divisible by 2.\r\n        height = 2 * (int(height) // (self.vae_scale_factor * 2))\r\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\r\n        shape = (batch_size, num_channels_latents, height, width)\r\n        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)\r\n\r\n        image = image.to(device=device, dtype=dtype)\r\n        image_latents = self._encode_vae_image(image=image, generator=generator)\r\n\r\n        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:\r\n            # expand init_latents for batch_size\r\n            additional_image_per_prompt = batch_size // image_latents.shape[0]\r\n            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)\r\n        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:\r\n            raise ValueError(\r\n                f\"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.\"\r\n            )\r\n        else:\r\n            image_latents = torch.cat([image_latents], dim=0)\r\n\r\n        if latents is None:\r\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\r\n            latents = self.scheduler.scale_noise(image_latents, timestep, noise)\r\n        else:\r\n            noise = latents.to(device)\r\n            latents = noise\r\n\r\n        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)\r\n        image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)\r\n        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, 
width)\r\n\r\n        return latents, noise, image_latents, latent_image_ids\r\n\r\n    def prepare_mask_latents(\r\n        self,\r\n        mask,\r\n        masked_image,\r\n        batch_size,\r\n        num_channels_latents,\r\n        num_images_per_prompt,\r\n        height,\r\n        width,\r\n        dtype,\r\n        device,\r\n        generator,\r\n    ):\r\n        # VAE applies 8x compression on images but we must also account for packing which requires\r\n        # latent height and width to be divisible by 2.\r\n        height = 2 * (int(height) // (self.vae_scale_factor * 2))\r\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\r\n        # resize the mask to latents shape as we concatenate the mask to the latents\r\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\r\n        # and half precision\r\n        mask = torch.nn.functional.interpolate(mask, size=(height, width))\r\n        mask = mask.to(device=device, dtype=dtype)\r\n\r\n        batch_size = batch_size * num_images_per_prompt\r\n\r\n        masked_image = masked_image.to(device=device, dtype=dtype)\r\n\r\n        if masked_image.shape[1] == 16:\r\n            masked_image_latents = masked_image\r\n        else:\r\n            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)\r\n\r\n        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\r\n\r\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\r\n        if mask.shape[0] < batch_size:\r\n            if not batch_size % mask.shape[0] == 0:\r\n                raise ValueError(\r\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\r\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. 
Make sure the number\"\r\n                    \" of masks that you pass is divisible by the total requested batch size.\"\r\n                )\r\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\r\n        if masked_image_latents.shape[0] < batch_size:\r\n            if not batch_size % masked_image_latents.shape[0] == 0:\r\n                raise ValueError(\r\n                    \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\r\n                    f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\r\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\r\n                )\r\n            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)\r\n\r\n        # aligning device to prevent device errors when concating it with the latent model input\r\n        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\r\n        masked_image_latents = self._pack_latents(\r\n            masked_image_latents,\r\n            batch_size,\r\n            num_channels_latents,\r\n            height,\r\n            width,\r\n        )\r\n\r\n        mask = self._pack_latents(\r\n            mask.repeat(1, num_channels_latents, 1, 1),\r\n            batch_size,\r\n            num_channels_latents,\r\n            height,\r\n            width,\r\n        )\r\n        return mask, masked_image_latents\r\n\r\n    # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image\r\n    def prepare_image(\r\n        self,\r\n        image,\r\n        width,\r\n        height,\r\n        batch_size,\r\n        num_images_per_prompt,\r\n        device,\r\n        dtype,\r\n        do_classifier_free_guidance=False,\r\n        guess_mode=False,\r\n    ):\r\n        if 
isinstance(image, torch.Tensor):\r\n            pass\r\n        else:\r\n            image = self.image_processor.preprocess(image, height=height, width=width)\r\n\r\n        image_batch_size = image.shape[0]\r\n\r\n        if image_batch_size == 1:\r\n            repeat_by = batch_size\r\n        else:\r\n            # image batch size is the same as prompt batch size\r\n            repeat_by = num_images_per_prompt\r\n\r\n        image = image.repeat_interleave(repeat_by, dim=0)\r\n\r\n        image = image.to(device=device, dtype=dtype)\r\n\r\n        if do_classifier_free_guidance and not guess_mode:\r\n            image = torch.cat([image] * 2)\r\n\r\n        return image\r\n\r\n    def prepare_mask_latents_fill(\r\n        self,\r\n        mask,\r\n        masked_image,\r\n        batch_size,\r\n        num_channels_latents,\r\n        num_images_per_prompt,\r\n        height,\r\n        width,\r\n        dtype,\r\n        device,\r\n        generator,\r\n    ):\r\n        # 1. calculate the height and width of the latents\r\n        # VAE applies 8x compression on images but we must also account for packing which requires\r\n        # latent height and width to be divisible by 2.\r\n        height = 2 * (int(height) // (self.vae_scale_factor * 2))\r\n        width = 2 * (int(width) // (self.vae_scale_factor * 2))\r\n\r\n        # 2. encode the masked image\r\n        if masked_image.shape[1] == num_channels_latents:\r\n            masked_image_latents = masked_image\r\n        else:\r\n            masked_image_latents = retrieve_latents_fill(self.vae.encode(masked_image), generator=generator)\r\n\r\n        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\r\n        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\r\n\r\n        # 3. 
duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\r\n        batch_size = batch_size * num_images_per_prompt\r\n        if mask.shape[0] < batch_size:\r\n            if not batch_size % mask.shape[0] == 0:\r\n                raise ValueError(\r\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\r\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number\"\r\n                    \" of masks that you pass is divisible by the total requested batch size.\"\r\n                )\r\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\r\n        if masked_image_latents.shape[0] < batch_size:\r\n            if not batch_size % masked_image_latents.shape[0] == 0:\r\n                raise ValueError(\r\n                    \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\r\n                    f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\r\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\r\n                )\r\n            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)\r\n\r\n        # 4. 
pack the masked_image_latents\r\n        # batch_size, num_channels_latents, height, width -> batch_size, height//2 * width//2 , num_channels_latents*4\r\n        masked_image_latents = self._pack_latents(\r\n            masked_image_latents,\r\n            batch_size,\r\n            num_channels_latents,\r\n            height,\r\n            width,\r\n        )\r\n\r\n        # 5. resize mask to latents shape as we concatenate the mask to the latents\r\n        mask = mask[:, 0, :, :]  # batch_size, 8 * height, 8 * width (mask has not been 8x compressed)\r\n        mask = mask.view(\r\n            batch_size, height, self.vae_scale_factor, width, self.vae_scale_factor\r\n        )  # batch_size, height, 8, width, 8\r\n        mask = mask.permute(0, 2, 4, 1, 3)  # batch_size, 8, 8, height, width\r\n        mask = mask.reshape(\r\n            batch_size, self.vae_scale_factor * self.vae_scale_factor, height, width\r\n        )  # batch_size, 8*8, height, width\r\n\r\n        # 6. pack the mask:\r\n        # batch_size, 64, height, width -> batch_size, height//2 * width//2 , 64*2*2\r\n        mask = self._pack_latents(\r\n            mask,\r\n            batch_size,\r\n            self.vae_scale_factor * self.vae_scale_factor,\r\n            height,\r\n            width,\r\n        )\r\n        mask = mask.to(device=device, dtype=dtype)\r\n\r\n        return mask, masked_image_latents\r\n\r\n    @property\r\n    def guidance_scale(self):\r\n        return self._guidance_scale\r\n\r\n    @property\r\n    def joint_attention_kwargs(self):\r\n        return self._joint_attention_kwargs\r\n\r\n    @property\r\n    def num_timesteps(self):\r\n        return self._num_timesteps\r\n\r\n    @property\r\n    def interrupt(self):\r\n        return self._interrupt\r\n\r\n    @torch.no_grad()\r\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\r\n    def __call__(\r\n        self,\r\n        prompt: Union[str, List[str]] = None,\r\n        prompt_2: Optional[Union[str, 
List[str]]] = None,\r\n        image: PipelineImageInput = None,\r\n        mask_image: PipelineImageInput = None,\r\n        masked_image_latents: PipelineImageInput = None,\r\n        control_image: PipelineImageInput = None,\r\n        height: Optional[int] = None,\r\n        width: Optional[int] = None,\r\n        strength: float = 0.6,\r\n        padding_mask_crop: Optional[int] = None,\r\n        sigmas: Optional[List[float]] = None,\r\n        num_inference_steps: int = 28,\r\n        guidance_scale: float = 7.0,\r\n        control_guidance_start: Union[float, List[float]] = 0.0,\r\n        control_guidance_end: Union[float, List[float]] = 1.0,\r\n        control_mode: Optional[Union[int, List[int]]] = None,\r\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\r\n        num_images_per_prompt: Optional[int] = 1,\r\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\r\n        latents: Optional[torch.FloatTensor] = None,\r\n        prompt_embeds: Optional[torch.FloatTensor] = None,\r\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\r\n        output_type: str | None = \"pil\",\r\n        return_dict: bool = True,\r\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\r\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\r\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\r\n        max_sequence_length: int = 512,\r\n    ):\r\n        \"\"\"\r\n        Function invoked when calling the pipeline for generation.\r\n\r\n        Args:\r\n            prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts to guide the image generation.\r\n            prompt_2 (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.\r\n            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):\r\n                The 
image(s) to inpaint.\r\n            mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):\r\n                The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels\r\n                will be preserved.\r\n            masked_image_latents (`torch.FloatTensor`, *optional*):\r\n                Pre-generated masked image latents.\r\n            control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):\r\n                The ControlNet input condition. Image to control the generation.\r\n            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\r\n                The height in pixels of the generated image.\r\n            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\r\n                The width in pixels of the generated image.\r\n            strength (`float`, *optional*, defaults to 0.6):\r\n                Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1.\r\n            padding_mask_crop (`int`, *optional*):\r\n                The size of the padding to use when cropping the mask.\r\n            num_inference_steps (`int`, *optional*, defaults to 28):\r\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\r\n                expense of slower inference.\r\n            sigmas (`List[float]`, *optional*):\r\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\r\n                their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed\r\n                will be used.\r\n            guidance_scale (`float`, *optional*, defaults to 7.0):\r\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).\r\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\r\n                The percentage of total steps at which the ControlNet starts applying.\r\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\r\n                The percentage of total steps at which the ControlNet stops applying.\r\n            control_mode (`int` or `List[int]`, *optional*):\r\n                The mode for the ControlNet. If multiple ControlNets are used, this should be a list.\r\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\r\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\r\n                to the residual in the original transformer.\r\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\r\n                The number of images to generate per prompt.\r\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\r\n                One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to\r\n                make generation deterministic.\r\n            latents (`torch.FloatTensor`, *optional*):\r\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\r\n                generation. Can be used to tweak the same generation with different prompts.\r\n            prompt_embeds (`torch.FloatTensor`, *optional*):\r\n                Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\r\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\r\n                Pre-generated pooled text embeddings.\r\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\r\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\r\n            return_dict (`bool`, *optional*, defaults to `True`):\r\n                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.\r\n            joint_attention_kwargs (`dict`, *optional*):\r\n                Additional keyword arguments to be passed to the joint attention mechanism.\r\n            callback_on_step_end (`Callable`, *optional*):\r\n                A function that is called at the end of each denoising step during inference.\r\n            callback_on_step_end_tensor_inputs (`List[str]`, *optional*):\r\n                The list of tensor inputs for the `callback_on_step_end` function.\r\n            max_sequence_length (`int`, *optional*, defaults to 512):\r\n                The maximum sequence length to use for the prompt. Cannot be greater than 512.\r\n\r\n        Examples:\r\n\r\n        Returns:\r\n            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`\r\n            is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated\r\n            images.\r\n        \"\"\"\r\n        height = height or self.default_sample_size * self.vae_scale_factor\r\n        width = width or self.default_sample_size * self.vae_scale_factor\r\n\r\n        global_height = height\r\n        global_width = width\r\n\r\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\r\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\r\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\r\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\r\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\r\n            mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1\r\n            control_guidance_start, control_guidance_end = (\r\n                mult * [control_guidance_start],\r\n                mult * [control_guidance_end],\r\n            )\r\n\r\n        # 1. Check inputs\r\n        self.check_inputs(\r\n            prompt,\r\n            prompt_2,\r\n            image,\r\n            mask_image,\r\n            strength,\r\n            height,\r\n            width,\r\n            output_type=output_type,\r\n            prompt_embeds=prompt_embeds,\r\n            pooled_prompt_embeds=pooled_prompt_embeds,\r\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\r\n            padding_mask_crop=padding_mask_crop,\r\n            max_sequence_length=max_sequence_length,\r\n        )\r\n\r\n        self._guidance_scale = guidance_scale\r\n        self._joint_attention_kwargs = joint_attention_kwargs\r\n        self._interrupt = False\r\n\r\n        # 2. 
Define call parameters\r\n        if prompt is not None and isinstance(prompt, str):\r\n            batch_size = 1\r\n        elif prompt is not None and isinstance(prompt, list):\r\n            batch_size = len(prompt)\r\n        else:\r\n            batch_size = prompt_embeds.shape[0]\r\n\r\n        device = self._execution_device\r\n        dtype = self.transformer.dtype\r\n\r\n        # 3. Encode input prompt\r\n        lora_scale = (\r\n            self.joint_attention_kwargs.get(\"scale\", None) if self.joint_attention_kwargs is not None else None\r\n        )\r\n        prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(\r\n            prompt=prompt,\r\n            prompt_2=prompt_2,\r\n            prompt_embeds=prompt_embeds,\r\n            pooled_prompt_embeds=pooled_prompt_embeds,\r\n            device=device,\r\n            num_images_per_prompt=num_images_per_prompt,\r\n            max_sequence_length=max_sequence_length,\r\n            lora_scale=lora_scale,\r\n        )\r\n\r\n        # 4. Preprocess mask and image\r\n        if padding_mask_crop is not None:\r\n            crops_coords = self.mask_processor.get_crop_region(\r\n                mask_image, global_width, global_height, pad=padding_mask_crop\r\n            )\r\n            resize_mode = \"fill\"\r\n        else:\r\n            crops_coords = None\r\n            resize_mode = \"default\"\r\n\r\n        original_image = image\r\n        init_image = self.image_processor.preprocess(\r\n            image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode\r\n        )\r\n        init_image = init_image.to(dtype=torch.float32)\r\n\r\n        # 5. 
Prepare control image\r\n        # num_channels_latents = self.transformer.config.in_channels // 4\r\n        num_channels_latents = self.vae.config.latent_channels\r\n\r\n        if isinstance(self.controlnet, FluxControlNetModel):\r\n            control_image = self.prepare_image(\r\n                image=control_image,\r\n                width=width,\r\n                height=height,\r\n                batch_size=batch_size * num_images_per_prompt,\r\n                num_images_per_prompt=num_images_per_prompt,\r\n                device=device,\r\n                dtype=self.vae.dtype,\r\n            )\r\n            height, width = control_image.shape[-2:]\r\n\r\n            # xlab controlnet has an input_hint_block and instantx controlnet does not\r\n            controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True\r\n            if self.controlnet.input_hint_block is None:\r\n                # vae encode\r\n                control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)\r\n                control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor\r\n\r\n                # pack\r\n                height_control_image, width_control_image = control_image.shape[2:]\r\n                control_image = self._pack_latents(\r\n                    control_image,\r\n                    batch_size * num_images_per_prompt,\r\n                    num_channels_latents,\r\n                    height_control_image,\r\n                    width_control_image,\r\n                )\r\n\r\n            # set control mode\r\n            if control_mode is not None:\r\n                control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)\r\n                control_mode = control_mode.reshape([-1, 1])\r\n\r\n        elif isinstance(self.controlnet, FluxMultiControlNetModel):\r\n            control_images = []\r\n\r\n            # xlab controlnet has an 
input_hint_block and instantx controlnet does not\r\n            controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True\r\n            for i, control_image_ in enumerate(control_image):\r\n                control_image_ = self.prepare_image(\r\n                    image=control_image_,\r\n                    width=width,\r\n                    height=height,\r\n                    batch_size=batch_size * num_images_per_prompt,\r\n                    num_images_per_prompt=num_images_per_prompt,\r\n                    device=device,\r\n                    dtype=self.vae.dtype,\r\n                )\r\n                height, width = control_image_.shape[-2:]\r\n\r\n                if self.controlnet.nets[0].input_hint_block is None:\r\n                    # vae encode\r\n                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)\r\n                    control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor\r\n\r\n                    # pack\r\n                    height_control_image, width_control_image = control_image_.shape[2:]\r\n                    control_image_ = self._pack_latents(\r\n                        control_image_,\r\n                        batch_size * num_images_per_prompt,\r\n                        num_channels_latents,\r\n                        height_control_image,\r\n                        width_control_image,\r\n                    )\r\n\r\n                control_images.append(control_image_)\r\n\r\n            control_image = control_images\r\n\r\n            # set control mode\r\n            control_mode_ = []\r\n            if isinstance(control_mode, list):\r\n                for cmode in control_mode:\r\n                    if cmode is None:\r\n                        control_mode_.append(-1)\r\n                    else:\r\n                        control_mode_.append(cmode)\r\n            control_mode = 
torch.tensor(control_mode_).to(device, dtype=torch.long)\r\n            control_mode = control_mode.reshape([-1, 1])\r\n\r\n        # 6. Prepare timesteps\r\n\r\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas\r\n        image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (\r\n            int(global_width) // self.vae_scale_factor // 2\r\n        )\r\n        mu = calculate_shift(\r\n            image_seq_len,\r\n            self.scheduler.config.get(\"base_image_seq_len\", 256),\r\n            self.scheduler.config.get(\"max_image_seq_len\", 4096),\r\n            self.scheduler.config.get(\"base_shift\", 0.5),\r\n            self.scheduler.config.get(\"max_shift\", 1.15),\r\n        )\r\n        timesteps, num_inference_steps = retrieve_timesteps(\r\n            self.scheduler,\r\n            num_inference_steps,\r\n            device,\r\n            sigmas=sigmas,\r\n            mu=mu,\r\n        )\r\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\r\n\r\n        if num_inference_steps < 1:\r\n            raise ValueError(\r\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\r\n                f\" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline.\"\r\n            )\r\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\r\n\r\n        # 7. Prepare latent variables\r\n\r\n        latents, noise, image_latents, latent_image_ids = self.prepare_latents(\r\n            init_image,\r\n            latent_timestep,\r\n            batch_size * num_images_per_prompt,\r\n            num_channels_latents,\r\n            global_height,\r\n            global_width,\r\n            prompt_embeds.dtype,\r\n            device,\r\n            generator,\r\n            latents,\r\n        )\r\n\r\n        # 8. 
Prepare mask latents\r\n        mask_condition = self.mask_processor.preprocess(\r\n            mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords\r\n        )\r\n        if masked_image_latents is None:\r\n            masked_image = init_image * (mask_condition < 0.5)\r\n        else:\r\n            masked_image = masked_image_latents\r\n\r\n        mask, masked_image_latents = self.prepare_mask_latents(\r\n            mask_condition,\r\n            masked_image,\r\n            batch_size,\r\n            num_channels_latents,\r\n            num_images_per_prompt,\r\n            global_height,\r\n            global_width,\r\n            prompt_embeds.dtype,\r\n            device,\r\n            generator,\r\n        )\r\n\r\n        mask_image_fill = self.mask_processor.preprocess(mask_image, height=height, width=width)\r\n        masked_image_fill = init_image * (1 - mask_image_fill)\r\n        masked_image_fill = masked_image_fill.to(dtype=self.vae.dtype, device=device)\r\n        mask_fill, masked_latents_fill = self.prepare_mask_latents_fill(\r\n            mask_image_fill,\r\n            masked_image_fill,\r\n            batch_size,\r\n            num_channels_latents,\r\n            num_images_per_prompt,\r\n            height,\r\n            width,\r\n            prompt_embeds.dtype,\r\n            device,\r\n            generator,\r\n        )\r\n\r\n        controlnet_keep = []\r\n        for i in range(len(timesteps)):\r\n            keeps = [\r\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\r\n                for s, e in zip(control_guidance_start, control_guidance_end)\r\n            ]\r\n            controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)\r\n\r\n        # 9. 
Denoising loop\r\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\r\n        self._num_timesteps = len(timesteps)\r\n\r\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\r\n            for i, t in enumerate(timesteps):\r\n                if self.interrupt:\r\n                    continue\r\n\r\n                timestep = t.expand(latents.shape[0]).to(latents.dtype)\r\n\r\n                # predict the noise residual\r\n                if isinstance(self.controlnet, FluxMultiControlNetModel):\r\n                    use_guidance = self.controlnet.nets[0].config.guidance_embeds\r\n                else:\r\n                    use_guidance = self.controlnet.config.guidance_embeds\r\n                if use_guidance:\r\n                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)\r\n                    guidance = guidance.expand(latents.shape[0])\r\n                else:\r\n                    guidance = None\r\n\r\n                if isinstance(controlnet_keep[i], list):\r\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\r\n                else:\r\n                    controlnet_cond_scale = controlnet_conditioning_scale\r\n                    if isinstance(controlnet_cond_scale, list):\r\n                        controlnet_cond_scale = controlnet_cond_scale[0]\r\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\r\n\r\n                controlnet_block_samples, controlnet_single_block_samples = self.controlnet(\r\n                    hidden_states=latents,\r\n                    controlnet_cond=control_image,\r\n                    controlnet_mode=control_mode,\r\n                    conditioning_scale=cond_scale,\r\n                    timestep=timestep / 1000,\r\n                    guidance=guidance,\r\n                    pooled_projections=pooled_prompt_embeds,\r\n         
           encoder_hidden_states=prompt_embeds,\r\n                    txt_ids=text_ids,\r\n                    img_ids=latent_image_ids,\r\n                    joint_attention_kwargs=self.joint_attention_kwargs,\r\n                    return_dict=False,\r\n                )\r\n\r\n                if self.transformer.config.guidance_embeds:\r\n                    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)\r\n                    guidance = guidance.expand(latents.shape[0])\r\n                else:\r\n                    guidance = None\r\n\r\n                masked_image_latents_fill = torch.cat((masked_latents_fill, mask_fill), dim=-1)\r\n                latent_model_input = torch.cat([latents, masked_image_latents_fill], dim=2)\r\n\r\n                noise_pred = self.transformer(\r\n                    hidden_states=latent_model_input,\r\n                    timestep=timestep / 1000,\r\n                    guidance=guidance,\r\n                    pooled_projections=pooled_prompt_embeds,\r\n                    encoder_hidden_states=prompt_embeds,\r\n                    controlnet_block_samples=controlnet_block_samples,\r\n                    controlnet_single_block_samples=controlnet_single_block_samples,\r\n                    txt_ids=text_ids,\r\n                    img_ids=latent_image_ids,\r\n                    joint_attention_kwargs=self.joint_attention_kwargs,\r\n                    return_dict=False,\r\n                    controlnet_blocks_repeat=controlnet_blocks_repeat,\r\n                )[0]\r\n\r\n                # compute the previous noisy sample x_t -> x_t-1\r\n                latents_dtype = latents.dtype\r\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\r\n\r\n                # For inpainting, we need to apply the mask and add the masked image latents\r\n                init_latents_proper = image_latents\r\n                init_mask = mask\r\n\r\n                
if i < len(timesteps) - 1:\r\n                    noise_timestep = timesteps[i + 1]\r\n                    init_latents_proper = self.scheduler.scale_noise(\r\n                        init_latents_proper, torch.tensor([noise_timestep]), noise\r\n                    )\r\n\r\n                latents = (1 - init_mask) * init_latents_proper + init_mask * latents\r\n\r\n                if latents.dtype != latents_dtype:\r\n                    if torch.backends.mps.is_available():\r\n                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\r\n                        latents = latents.to(latents_dtype)\r\n\r\n                # call the callback, if provided\r\n                if callback_on_step_end is not None:\r\n                    callback_kwargs = {}\r\n                    for k in callback_on_step_end_tensor_inputs:\r\n                        callback_kwargs[k] = locals()[k]\r\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\r\n\r\n                    latents = callback_outputs.pop(\"latents\", latents)\r\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\r\n                    control_image = callback_outputs.pop(\"control_image\", control_image)\r\n                    mask = callback_outputs.pop(\"mask\", mask)\r\n                    masked_image_latents = callback_outputs.pop(\"masked_image_latents\", masked_image_latents)\r\n\r\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\r\n                    progress_bar.update()\r\n\r\n                if XLA_AVAILABLE:\r\n                    xm.mark_step()\r\n\r\n        # Post-processing\r\n        if output_type == \"latent\":\r\n            image = latents\r\n        else:\r\n            latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor)\r\n            
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\r\n            image = self.vae.decode(latents, return_dict=False)[0]\r\n            image = self.image_processor.postprocess(image, output_type=output_type)\r\n\r\n        # Offload all models\r\n        self.maybe_free_model_hooks()\r\n\r\n        if not return_dict:\r\n            return (image,)\r\n\r\n        return FluxPipelineOutput(images=image)\r\n"
  },
  {
    "path": "examples/community/regional_prompting_stable_diffusion.py",
    "content": "import inspect\r\nimport math\r\nfrom typing import Any, Callable, Dict, List, Optional, Union\r\n\r\nimport torch\r\nimport torchvision.transforms.functional as FF\r\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\r\n\r\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\r\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\r\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\r\nfrom diffusers.loaders.ip_adapter import IPAdapterMixin\r\nfrom diffusers.loaders.lora_pipeline import LoraLoaderMixin\r\nfrom diffusers.loaders.single_file import FromSingleFileMixin\r\nfrom diffusers.loaders.textual_inversion import TextualInversionLoaderMixin\r\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\r\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\r\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\r\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\r\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\r\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\r\nfrom diffusers.utils import (\r\n    USE_PEFT_BACKEND,\r\n    deprecate,\r\n    is_torch_xla_available,\r\n    logging,\r\n    scale_lora_layers,\r\n    unscale_lora_layers,\r\n)\r\nfrom diffusers.utils.torch_utils import randn_tensor\r\n\r\n\r\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\r\n\r\nif is_torch_xla_available():\r\n    import torch_xla.core.xla_model as xm\r\n\r\n    XLA_AVAILABLE = True\r\nelse:\r\n    XLA_AVAILABLE = False\r\n\r\n\r\ntry:\r\n    from compel import Compel\r\nexcept ImportError:\r\n    Compel = None\r\n\r\nKBASE = \"ADDBASE\"\r\nKCOMM = \"ADDCOMM\"\r\nKBRK = \"BREAK\"\r\n\r\n\r\nclass RegionalPromptingStableDiffusionPipeline(\r\n    DiffusionPipeline,\r\n    TextualInversionLoaderMixin,\r\n    
LoraLoaderMixin,\r\n    IPAdapterMixin,\r\n    FromSingleFileMixin,\r\n    StableDiffusionLoraLoaderMixin,\r\n):\r\n    r\"\"\"\r\n    Args for Regional Prompting Pipeline:\r\n        rp_args:dict\r\n        Required\r\n            rp_args[\"mode\"]: cols, rows, prompt, prompt-ex\r\n        for cols, rows mode\r\n            rp_args[\"div\"]: ex) 1;1;1(Divide into 3 regions)\r\n        for prompt, prompt-ex mode\r\n            rp_args[\"th\"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)\r\n\r\n        Optional\r\n            rp_args[\"save_mask\"]: True/False (save masks in prompt mode)\r\n            rp_args[\"power\"]: int (power for attention maps in prompt mode)\r\n            rp_args[\"base_ratio\"]:\r\n                float (Sets the ratio of the base prompt)\r\n                ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)\r\n                [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)\r\n\r\n    Pipeline for text-to-image generation using Stable Diffusion.\r\n\r\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\r\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\r\n\r\n    Args:\r\n        vae ([`AutoencoderKL`]):\r\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\r\n        text_encoder ([`CLIPTextModel`]):\r\n            Frozen text-encoder. 
Stable Diffusion uses the text portion of\r\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\r\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\r\n        tokenizer (`CLIPTokenizer`):\r\n            Tokenizer of class\r\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\r\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\r\n        scheduler ([`SchedulerMixin`]):\r\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\r\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\r\n        safety_checker ([`StableDiffusionSafetyChecker`]):\r\n            Classification module that estimates whether generated images could be considered offensive or harmful.\r\n            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.\r\n        feature_extractor ([`CLIPImageProcessor`]):\r\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        vae: AutoencoderKL,\r\n        text_encoder: CLIPTextModel,\r\n        tokenizer: CLIPTokenizer,\r\n        unet: UNet2DConditionModel,\r\n        scheduler: KarrasDiffusionSchedulers,\r\n        safety_checker: StableDiffusionSafetyChecker,\r\n        feature_extractor: CLIPImageProcessor,\r\n        image_encoder: CLIPVisionModelWithProjection = None,\r\n        requires_safety_checker: bool = True,\r\n    ):\r\n        super().__init__()\r\n        self.register_modules(\r\n            vae=vae,\r\n            text_encoder=text_encoder,\r\n            tokenizer=tokenizer,\r\n            unet=unet,\r\n            
scheduler=scheduler,\r\n            safety_checker=safety_checker,\r\n            feature_extractor=feature_extractor,\r\n            image_encoder=image_encoder,\r\n        )\r\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\r\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\r\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\r\n\r\n        # Initialize additional properties needed for DiffusionPipeline\r\n        self._num_timesteps = None\r\n        self._interrupt = False\r\n        self._guidance_scale = 7.5\r\n        self._guidance_rescale = 0.0\r\n        self._clip_skip = None\r\n        self._cross_attention_kwargs = None\r\n\r\n    @torch.no_grad()\r\n    def __call__(\r\n        self,\r\n        prompt: str,\r\n        height: int = 512,\r\n        width: int = 512,\r\n        num_inference_steps: int = 50,\r\n        guidance_scale: float = 7.5,\r\n        negative_prompt: str = None,\r\n        num_images_per_prompt: Optional[int] = 1,\r\n        eta: float = 0.0,\r\n        generator: torch.Generator | None = None,\r\n        latents: Optional[torch.Tensor] = None,\r\n        output_type: str | None = \"pil\",\r\n        return_dict: bool = True,\r\n        rp_args: Dict[str, str] = None,\r\n    ):\r\n        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt\r\n        use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt\r\n        if negative_prompt is None:\r\n            negative_prompt = \"\" if isinstance(prompt, str) else [\"\"] * len(prompt)\r\n\r\n        device = self._execution_device\r\n        regions = 0\r\n\r\n        self.base_ratio = float(rp_args[\"base_ratio\"]) if \"base_ratio\" in rp_args else 0.0\r\n        self.power = int(rp_args[\"power\"]) if \"power\" in rp_args else 1\r\n\r\n        prompts = prompt if isinstance(prompt, list) 
else [prompt]\r\n        n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]\r\n        self.batch = batch = num_images_per_prompt * len(prompts)\r\n\r\n        if use_base:\r\n            bases = prompts.copy()\r\n            n_bases = n_prompts.copy()\r\n\r\n            for i, prompt in enumerate(prompts):\r\n                parts = prompt.split(KBASE)\r\n                if len(parts) == 2:\r\n                    bases[i], prompts[i] = parts\r\n                elif len(parts) > 2:\r\n                    raise ValueError(f\"Multiple instances of {KBASE} found in prompt: {prompt}\")\r\n            for i, prompt in enumerate(n_prompts):\r\n                n_parts = prompt.split(KBASE)\r\n                if len(n_parts) == 2:\r\n                    n_bases[i], n_prompts[i] = n_parts\r\n                elif len(n_parts) > 2:\r\n                    raise ValueError(f\"Multiple instances of {KBASE} found in negative prompt: {prompt}\")\r\n\r\n            all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)\r\n            all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)\r\n\r\n        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)\r\n        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)\r\n\r\n        equal = len(all_prompts_cn) == len(all_n_prompts_cn)\r\n\r\n        if Compel:\r\n            compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)\r\n\r\n            def getcompelembs(prps):\r\n                embl = []\r\n                for prp in prps:\r\n                    embl.append(compel.build_conditioning_tensor(prp))\r\n                return torch.cat(embl)\r\n\r\n            conds = getcompelembs(all_prompts_cn)\r\n            unconds = getcompelembs(all_n_prompts_cn)\r\n            base_embs = getcompelembs(all_bases_cn) if use_base else None\r\n            base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None\r\n    
        # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts\r\n            embs = getcompelembs(prompts) if not use_base else base_embs\r\n            n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs\r\n\r\n            if use_base and self.base_ratio > 0:\r\n                conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds\r\n                unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds\r\n\r\n            prompt = negative_prompt = None\r\n        else:\r\n            conds = self.encode_prompt(prompts, device, 1, True)[0]\r\n            unconds = (\r\n                self.encode_prompt(n_prompts, device, 1, True)[0]\r\n                if equal\r\n                else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]\r\n            )\r\n\r\n            if use_base and self.base_ratio > 0:\r\n                base_embs = self.encode_prompt(bases, device, 1, True)[0]\r\n                base_n_embs = (\r\n                    self.encode_prompt(n_bases, device, 1, True)[0]\r\n                    if equal\r\n                    else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]\r\n                )\r\n\r\n                conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds\r\n                unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds\r\n\r\n            embs = n_embs = None\r\n\r\n        if not active:\r\n            pcallback = None\r\n            mode = None\r\n        else:\r\n            if any(x in rp_args[\"mode\"].upper() for x in [\"COL\", \"ROW\"]):\r\n                mode = \"COL\" if \"COL\" in rp_args[\"mode\"].upper() else \"ROW\"\r\n                ocells, icells, regions = make_cells(rp_args[\"div\"])\r\n\r\n            elif \"PRO\" in rp_args[\"mode\"].upper():\r\n                regions = len(all_prompts_p[0])\r\n                mode = \"PROMPT\"\r\n                
reset_attnmaps(self)\r\n                self.ex = \"EX\" in rp_args[\"mode\"].upper()\r\n                self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)\r\n                thresholds = [float(x) for x in rp_args[\"th\"].split(\",\")]\r\n\r\n            orig_hw = (height, width)\r\n            revers = True\r\n\r\n            def pcallback(s_self, step: int, timestep: int, latents: torch.Tensor, selfs=None):\r\n                if \"PRO\" in mode:  # in Prompt mode, make masks from sum of attention maps\r\n                    self.step = step\r\n\r\n                    if len(self.attnmaps_sizes) > 3:\r\n                        self.history[step] = self.attnmaps.copy()\r\n                        for hw in self.attnmaps_sizes:\r\n                            allmasks = []\r\n                            basemasks = [None] * batch\r\n                            for tt, th in zip(target_tokens, thresholds):\r\n                                for b in range(batch):\r\n                                    key = f\"{tt}-{b}\"\r\n                                    _, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)\r\n                                    mask = mask.unsqueeze(0).unsqueeze(-1)\r\n                                    if self.ex:\r\n                                        allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]\r\n                                        allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]\r\n                                    allmasks.append(mask)\r\n                                    basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask\r\n                            basemasks = [1 - mask for mask in basemasks]\r\n                            basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]\r\n                            allmasks = basemasks + allmasks\r\n\r\n                            self.attnmasks[hw] = torch.cat(allmasks)\r\n        
                self.maskready = True\r\n                return latents\r\n\r\n            def hook_forward(module):\r\n                # diffusers==0.23.2\r\n                def forward(\r\n                    hidden_states: torch.Tensor,\r\n                    encoder_hidden_states: Optional[torch.Tensor] = None,\r\n                    attention_mask: Optional[torch.Tensor] = None,\r\n                    temb: Optional[torch.Tensor] = None,\r\n                    scale: float = 1.0,\r\n                ) -> torch.Tensor:\r\n                    attn = module\r\n                    xshape = hidden_states.shape\r\n                    self.hw = (h, w) = split_dims(xshape[1], *orig_hw)\r\n\r\n                    if revers:\r\n                        nx, px = hidden_states.chunk(2)\r\n                    else:\r\n                        px, nx = hidden_states.chunk(2)\r\n\r\n                    if equal:\r\n                        hidden_states = torch.cat(\r\n                            [px for i in range(regions)] + [nx for i in range(regions)],\r\n                            0,\r\n                        )\r\n                        encoder_hidden_states = torch.cat([conds] + [unconds])\r\n                    else:\r\n                        hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)\r\n                        encoder_hidden_states = torch.cat([conds] + [unconds])\r\n\r\n                    residual = hidden_states\r\n\r\n                    if attn.spatial_norm is not None:\r\n                        hidden_states = attn.spatial_norm(hidden_states, temb)\r\n\r\n                    input_ndim = hidden_states.ndim\r\n\r\n                    if input_ndim == 4:\r\n                        batch_size, channel, height, width = hidden_states.shape\r\n                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\r\n\r\n                    batch_size, sequence_length, _ = (\r\n                        
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\r\n                    )\r\n\r\n                    if attention_mask is not None:\r\n                        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\r\n                        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\r\n\r\n                    if attn.group_norm is not None:\r\n                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\r\n\r\n                    query = attn.to_q(hidden_states)\r\n\r\n                    if encoder_hidden_states is None:\r\n                        encoder_hidden_states = hidden_states\r\n                    elif attn.norm_cross:\r\n                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\r\n\r\n                    key = attn.to_k(encoder_hidden_states)\r\n                    value = attn.to_v(encoder_hidden_states)\r\n\r\n                    inner_dim = key.shape[-1]\r\n                    head_dim = inner_dim // attn.heads\r\n\r\n                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\r\n\r\n                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\r\n                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\r\n\r\n                    # the output of sdp = (batch, num_heads, seq_len, head_dim)\r\n                    # TODO: add support for attn.scale when we move to Torch 2.1\r\n                    hidden_states = scaled_dot_product_attention(\r\n                        self,\r\n                        query,\r\n                        key,\r\n                        value,\r\n                        attn_mask=attention_mask,\r\n                        dropout_p=0.0,\r\n                        is_causal=False,\r\n                        getattn=\"PRO\" in 
mode,\r\n                    )\r\n\r\n                    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\r\n                    hidden_states = hidden_states.to(query.dtype)\r\n\r\n                    # linear proj\r\n                    hidden_states = attn.to_out[0](hidden_states)\r\n                    # dropout\r\n                    hidden_states = attn.to_out[1](hidden_states)\r\n\r\n                    if input_ndim == 4:\r\n                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\r\n\r\n                    if attn.residual_connection:\r\n                        hidden_states = hidden_states + residual\r\n\r\n                    hidden_states = hidden_states / attn.rescale_output_factor\r\n\r\n                    #### Regional Prompting Col/Row mode\r\n                    if any(x in mode for x in [\"COL\", \"ROW\"]):\r\n                        reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])\r\n                        center = reshaped.shape[0] // 2\r\n                        px = reshaped[0:center] if equal else reshaped[0:-batch]\r\n                        nx = reshaped[center:] if equal else reshaped[-batch:]\r\n                        outs = [px, nx] if equal else [px]\r\n                        for out in outs:\r\n                            c = 0\r\n                            for i, ocell in enumerate(ocells):\r\n                                for icell in icells[i]:\r\n                                    if \"ROW\" in mode:\r\n                                        out[\r\n                                            0:batch,\r\n                                            int(h * ocell[0]) : int(h * ocell[1]),\r\n                                            int(w * icell[0]) : int(w * icell[1]),\r\n                                            :,\r\n                                        ] = out[\r\n         
                                   c * batch : (c + 1) * batch,\r\n                                            int(h * ocell[0]) : int(h * ocell[1]),\r\n                                            int(w * icell[0]) : int(w * icell[1]),\r\n                                            :,\r\n                                        ]\r\n                                    else:\r\n                                        out[\r\n                                            0:batch,\r\n                                            int(h * icell[0]) : int(h * icell[1]),\r\n                                            int(w * ocell[0]) : int(w * ocell[1]),\r\n                                            :,\r\n                                        ] = out[\r\n                                            c * batch : (c + 1) * batch,\r\n                                            int(h * icell[0]) : int(h * icell[1]),\r\n                                            int(w * ocell[0]) : int(w * ocell[1]),\r\n                                            :,\r\n                                        ]\r\n                                    c += 1\r\n                        px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)\r\n                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)\r\n                        hidden_states = hidden_states.reshape(xshape)\r\n\r\n                    #### Regional Prompting Prompt mode\r\n                    elif \"PRO\" in mode:\r\n                        px, nx = (\r\n                            hidden_states.chunk(2)\r\n                            if equal\r\n                            else (hidden_states[0:-batch], hidden_states[-batch:])\r\n                        )\r\n\r\n                        if (h, w) in self.attnmasks and self.maskready:\r\n\r\n                            def mask(input):\r\n                                out = torch.multiply(input, self.attnmasks[(h, w)])\r\n                                for 
b in range(batch):\r\n                                    for r in range(1, regions):\r\n                                        out[b] = out[b] + out[r * batch + b]\r\n                                return out\r\n\r\n                            px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)\r\n                        px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)\r\n                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)\r\n                    return hidden_states\r\n\r\n                return forward\r\n\r\n            def hook_forwards(root_module: torch.nn.Module):\r\n                for name, module in root_module.named_modules():\r\n                    if \"attn2\" in name and module.__class__.__name__ == \"Attention\":\r\n                        module.forward = hook_forward(module)\r\n\r\n            hook_forwards(self.unet)\r\n\r\n        output = self.stable_diffusion_call(\r\n            prompt=prompt,\r\n            prompt_embeds=embs,\r\n            negative_prompt=negative_prompt,\r\n            negative_prompt_embeds=n_embs,\r\n            height=height,\r\n            width=width,\r\n            num_inference_steps=num_inference_steps,\r\n            guidance_scale=guidance_scale,\r\n            num_images_per_prompt=num_images_per_prompt,\r\n            eta=eta,\r\n            generator=generator,\r\n            latents=latents,\r\n            output_type=output_type,\r\n            return_dict=return_dict,\r\n            callback_on_step_end=pcallback,\r\n        )\r\n\r\n        if \"save_mask\" in rp_args:\r\n            save_mask = rp_args[\"save_mask\"]\r\n        else:\r\n            save_mask = False\r\n\r\n        if mode == \"PROMPT\" and save_mask:\r\n            saveattnmaps(\r\n                self,\r\n                output,\r\n                height,\r\n                width,\r\n                thresholds,\r\n                num_inference_steps // 2,\r\n    
            regions,\r\n            )\r\n\r\n        return output\r\n\r\n    # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion\r\n    def prepare_extra_step_kwargs(self, generator, eta):\r\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\r\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\r\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\r\n        # and should be between [0, 1]\r\n\r\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\r\n        extra_step_kwargs = {}\r\n        if accepts_eta:\r\n            extra_step_kwargs[\"eta\"] = eta\r\n\r\n        # check if the scheduler accepts generator\r\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\r\n        if accepts_generator:\r\n            extra_step_kwargs[\"generator\"] = generator\r\n        return extra_step_kwargs\r\n\r\n    # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion\r\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\r\n        shape = (\r\n            batch_size,\r\n            num_channels_latents,\r\n            int(height) // self.vae_scale_factor,\r\n            int(width) // self.vae_scale_factor,\r\n        )\r\n        if isinstance(generator, list) and len(generator) != batch_size:\r\n            raise ValueError(\r\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\r\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\r\n            )\r\n\r\n        if latents is None:\r\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\r\n        else:\r\n            latents = latents.to(device)\r\n\r\n        # scale the initial noise by the standard deviation required by the scheduler\r\n        latents = latents * self.scheduler.init_noise_sigma\r\n        return latents\r\n\r\n    # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion\r\n    def encode_prompt(\r\n        self,\r\n        prompt,\r\n        device,\r\n        num_images_per_prompt,\r\n        do_classifier_free_guidance,\r\n        negative_prompt=None,\r\n        prompt_embeds: Optional[torch.Tensor] = None,\r\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\r\n        lora_scale: Optional[float] = None,\r\n        clip_skip: Optional[int] = None,\r\n    ):\r\n        r\"\"\"\r\n        Encodes the prompt into text encoder hidden states.\r\n\r\n        Args:\r\n            prompt (`str` or `List[str]`, *optional*):\r\n                prompt to be encoded\r\n            device: (`torch.device`):\r\n                torch device\r\n            num_images_per_prompt (`int`):\r\n                number of images that should be generated per prompt\r\n            do_classifier_free_guidance (`bool`):\r\n                whether to use classifier free guidance or not\r\n            negative_prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\r\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\r\n                less than `1`).\r\n            prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\r\n                provided, text embeddings will be generated from `prompt` input argument.\r\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\r\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\r\n                argument.\r\n            lora_scale (`float`, *optional*):\r\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\r\n            clip_skip (`int`, *optional*):\r\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\r\n                the output of the pre-final layer will be used for computing the prompt embeddings.\r\n        \"\"\"\r\n        # set lora scale so that monkey patched LoRA\r\n        # function of text encoder can correctly access it\r\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\r\n            self._lora_scale = lora_scale\r\n\r\n            # dynamically adjust the LoRA scale\r\n            if not USE_PEFT_BACKEND:\r\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\r\n            else:\r\n                scale_lora_layers(self.text_encoder, lora_scale)\r\n\r\n        if prompt is not None and isinstance(prompt, str):\r\n            batch_size = 1\r\n        elif prompt is not None and isinstance(prompt, list):\r\n            batch_size = len(prompt)\r\n        else:\r\n            batch_size = prompt_embeds.shape[0]\r\n\r\n        if prompt_embeds is None:\r\n            # textual inversion: process multi-vector tokens if necessary\r\n            if isinstance(self, TextualInversionLoaderMixin):\r\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\r\n\r\n            text_inputs = self.tokenizer(\r\n             
   prompt,\r\n                padding=\"max_length\",\r\n                max_length=self.tokenizer.model_max_length,\r\n                truncation=True,\r\n                return_tensors=\"pt\",\r\n            )\r\n            text_input_ids = text_inputs.input_ids\r\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\r\n\r\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\r\n                text_input_ids, untruncated_ids\r\n            ):\r\n                removed_text = self.tokenizer.batch_decode(\r\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\r\n                )\r\n                logger.warning(\r\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\r\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\r\n                )\r\n\r\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\r\n                attention_mask = text_inputs.attention_mask.to(device)\r\n            else:\r\n                attention_mask = None\r\n\r\n            if clip_skip is None:\r\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\r\n                prompt_embeds = prompt_embeds[0]\r\n            else:\r\n                prompt_embeds = self.text_encoder(\r\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\r\n                )\r\n                # Access the `hidden_states` first, that contains a tuple of\r\n                # all the hidden states from the encoder layers. 
Then index into\r\n                # the tuple to access the hidden states from the desired layer.\r\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\r\n                # We also need to apply the final LayerNorm here to not mess with the\r\n                # representations. The `last_hidden_states` that we typically use for\r\n                # obtaining the final prompt representations passes through the LayerNorm\r\n                # layer.\r\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\r\n\r\n        if self.text_encoder is not None:\r\n            prompt_embeds_dtype = self.text_encoder.dtype\r\n        elif self.unet is not None:\r\n            prompt_embeds_dtype = self.unet.dtype\r\n        else:\r\n            prompt_embeds_dtype = prompt_embeds.dtype\r\n\r\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\r\n\r\n        bs_embed, seq_len, _ = prompt_embeds.shape\r\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\r\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\r\n\r\n        # get unconditional embeddings for classifier free guidance\r\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\r\n            uncond_tokens: List[str]\r\n            if negative_prompt is None:\r\n                uncond_tokens = [\"\"] * batch_size\r\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\r\n                raise TypeError(\r\n                    f\"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=\"\r\n                    f\" {type(prompt)}.\"\r\n                )\r\n            elif isinstance(negative_prompt, str):\r\n                uncond_tokens = [negative_prompt]\r\n            elif batch_size != 
len(negative_prompt):\r\n                raise ValueError(\r\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\r\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\r\n                    \" the batch size of `prompt`.\"\r\n                )\r\n            else:\r\n                uncond_tokens = negative_prompt\r\n\r\n            # textual inversion: process multi-vector tokens if necessary\r\n            if isinstance(self, TextualInversionLoaderMixin):\r\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\r\n\r\n            max_length = prompt_embeds.shape[1]\r\n            uncond_input = self.tokenizer(\r\n                uncond_tokens,\r\n                padding=\"max_length\",\r\n                max_length=max_length,\r\n                truncation=True,\r\n                return_tensors=\"pt\",\r\n            )\r\n\r\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\r\n                attention_mask = uncond_input.attention_mask.to(device)\r\n            else:\r\n                attention_mask = None\r\n\r\n            negative_prompt_embeds = self.text_encoder(\r\n                uncond_input.input_ids.to(device),\r\n                attention_mask=attention_mask,\r\n            )\r\n            negative_prompt_embeds = negative_prompt_embeds[0]\r\n\r\n        if do_classifier_free_guidance:\r\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\r\n            seq_len = negative_prompt_embeds.shape[1]\r\n\r\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\r\n\r\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n            negative_prompt_embeds = 
negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\r\n\r\n        if self.text_encoder is not None:\r\n            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\r\n                # Retrieve the original scale by scaling back the LoRA layers\r\n                unscale_lora_layers(self.text_encoder, lora_scale)\r\n\r\n        return prompt_embeds, negative_prompt_embeds\r\n\r\n    # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion\r\n    def check_inputs(\r\n        self,\r\n        prompt,\r\n        height,\r\n        width,\r\n        callback_steps,\r\n        negative_prompt=None,\r\n        prompt_embeds=None,\r\n        negative_prompt_embeds=None,\r\n        ip_adapter_image=None,\r\n        ip_adapter_image_embeds=None,\r\n        callback_on_step_end_tensor_inputs=None,\r\n    ):\r\n        if height % 8 != 0 or width % 8 != 0:\r\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\r\n\r\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\r\n            raise ValueError(\r\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\r\n                f\" {type(callback_steps)}.\"\r\n            )\r\n        if callback_on_step_end_tensor_inputs is not None and not all(\r\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\r\n        ):\r\n            raise ValueError(\r\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\r\n            )\r\n\r\n        if prompt is not None and prompt_embeds is not None:\r\n            raise ValueError(\r\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to\"\r\n                \" only forward one of the two.\"\r\n            )\r\n        elif prompt is None and prompt_embeds is None:\r\n            raise ValueError(\r\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\r\n            )\r\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\r\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\r\n\r\n        if negative_prompt is not None and negative_prompt_embeds is not None:\r\n            raise ValueError(\r\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\r\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\r\n            )\r\n\r\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\r\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\r\n                raise ValueError(\r\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\r\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\r\n                    f\" {negative_prompt_embeds.shape}.\"\r\n                )\r\n\r\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\r\n            raise ValueError(\r\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. 
Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\r\n            )\r\n\r\n        if ip_adapter_image_embeds is not None:\r\n            if not isinstance(ip_adapter_image_embeds, list):\r\n                raise ValueError(\r\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\r\n                )\r\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\r\n                raise ValueError(\r\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\r\n                )\r\n\r\n    # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion\r\n    @torch.no_grad()\r\n    def stable_diffusion_call(\r\n        self,\r\n        prompt: Union[str, List[str]] = None,\r\n        height: Optional[int] = None,\r\n        width: Optional[int] = None,\r\n        num_inference_steps: int = 50,\r\n        timesteps: List[int] = None,\r\n        sigmas: List[float] = None,\r\n        guidance_scale: float = 7.5,\r\n        negative_prompt: Optional[Union[str, List[str]]] = None,\r\n        num_images_per_prompt: Optional[int] = 1,\r\n        eta: float = 0.0,\r\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\r\n        latents: Optional[torch.Tensor] = None,\r\n        prompt_embeds: Optional[torch.Tensor] = None,\r\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\r\n        ip_adapter_image: Optional[PipelineImageInput] = None,\r\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\r\n        output_type: str | None = \"pil\",\r\n        return_dict: bool = True,\r\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\r\n        guidance_rescale: float = 0.0,\r\n        clip_skip: Optional[int] = None,\r\n        callback_on_step_end: Optional[\r\n            Union[Callable[[int, int, Dict], None], 
PipelineCallback, MultiPipelineCallbacks]\r\n        ] = None,\r\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\r\n        **kwargs,\r\n    ):\r\n        r\"\"\"\r\n        The call function to the pipeline for generation.\r\n\r\n        Args:\r\n            prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\r\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\r\n                The height in pixels of the generated image.\r\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\r\n                The width in pixels of the generated image.\r\n            num_inference_steps (`int`, *optional*, defaults to 50):\r\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\r\n                expense of slower inference.\r\n            timesteps (`List[int]`, *optional*):\r\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\r\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\r\n                passed will be used. Must be in descending order.\r\n            sigmas (`List[float]`, *optional*):\r\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\r\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\r\n                will be used.\r\n            guidance_scale (`float`, *optional*, defaults to 7.5):\r\n                A higher guidance scale value encourages the model to generate images closely linked to the text\r\n                `prompt` at the expense of lower image quality. 
Guidance scale is enabled when `guidance_scale > 1`.\r\n            negative_prompt (`str` or `List[str]`, *optional*):\r\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\r\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\r\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\r\n                The number of images to generate per prompt.\r\n            eta (`float`, *optional*, defaults to 0.0):\r\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\r\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\r\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\r\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\r\n                generation deterministic.\r\n            latents (`torch.Tensor`, *optional*):\r\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\r\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\r\n                tensor is generated by sampling using the supplied random `generator`.\r\n            prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\r\n                provided, text embeddings are generated from the `prompt` input argument.\r\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\r\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If\r\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\r\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\r\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\r\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\r\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\r\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\r\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\r\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\r\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\r\n            return_dict (`bool`, *optional*, defaults to `True`):\r\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\r\n                plain tuple.\r\n            cross_attention_kwargs (`dict`, *optional*):\r\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\r\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\r\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\r\n                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are\r\n                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when\r\n                using zero terminal SNR.\r\n            clip_skip (`int`, *optional*):\r\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\r\n                the output of the pre-final layer will be used for computing the prompt embeddings.\r\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\r\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\r\n                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:\r\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\r\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\r\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\r\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\r\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\r\n                `._callback_tensor_inputs` attribute of your pipeline class.\r\n\r\n        Examples:\r\n\r\n        Returns:\r\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\r\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\r\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\r\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\r\n                \"not-safe-for-work\" (nsfw) content.\r\n        \"\"\"\r\n\r\n        callback = kwargs.pop(\"callback\", None)\r\n        callback_steps = kwargs.pop(\"callback_steps\", None)\r\n        self.model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\r\n        self._optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\r\n        
self._exclude_from_cpu_offload = [\"safety_checker\"]\r\n        self._callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\r\n\r\n        if callback is not None:\r\n            deprecate(\r\n                \"callback\",\r\n                \"1.0.0\",\r\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\r\n            )\r\n        if callback_steps is not None:\r\n            deprecate(\r\n                \"callback_steps\",\r\n                \"1.0.0\",\r\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\r\n            )\r\n\r\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\r\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\r\n\r\n        # 0. Default height and width to unet\r\n        if not height or not width:\r\n            height = (\r\n                self.unet.config.sample_size\r\n                if self._is_unet_config_sample_size_int\r\n                else self.unet.config.sample_size[0]\r\n            )\r\n            width = (\r\n                self.unet.config.sample_size\r\n                if self._is_unet_config_sample_size_int\r\n                else self.unet.config.sample_size[1]\r\n            )\r\n            height, width = height * self.vae_scale_factor, width * self.vae_scale_factor\r\n        # to deal with lora scaling and other possible forward hooks\r\n\r\n        # 1. Check inputs. 
Raise error if not correct\r\n        self.check_inputs(\r\n            prompt,\r\n            height,\r\n            width,\r\n            callback_steps,\r\n            negative_prompt,\r\n            prompt_embeds,\r\n            negative_prompt_embeds,\r\n            ip_adapter_image,\r\n            ip_adapter_image_embeds,\r\n            callback_on_step_end_tensor_inputs,\r\n        )\r\n\r\n        self._guidance_scale = guidance_scale\r\n        self._guidance_rescale = guidance_rescale\r\n        self._clip_skip = clip_skip\r\n        self._cross_attention_kwargs = cross_attention_kwargs\r\n        self._interrupt = False\r\n\r\n        # 2. Define call parameters\r\n        if prompt is not None and isinstance(prompt, str):\r\n            batch_size = 1\r\n        elif prompt is not None and isinstance(prompt, list):\r\n            batch_size = len(prompt)\r\n        else:\r\n            batch_size = prompt_embeds.shape[0]\r\n\r\n        device = self._execution_device\r\n\r\n        # 3. 
Encode input prompt\r\n        lora_scale = (\r\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\r\n        )\r\n\r\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\r\n            prompt,\r\n            device,\r\n            num_images_per_prompt,\r\n            self.do_classifier_free_guidance,\r\n            negative_prompt,\r\n            prompt_embeds=prompt_embeds,\r\n            negative_prompt_embeds=negative_prompt_embeds,\r\n            lora_scale=lora_scale,\r\n            clip_skip=self.clip_skip,\r\n        )\r\n\r\n        # For classifier free guidance, we need to do two forward passes.\r\n        # Here we concatenate the unconditional and text embeddings into a single batch\r\n        # to avoid doing two forward passes\r\n        if self.do_classifier_free_guidance:\r\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\r\n\r\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\r\n            image_embeds = self.prepare_ip_adapter_image_embeds(\r\n                ip_adapter_image,\r\n                ip_adapter_image_embeds,\r\n                device,\r\n                batch_size * num_images_per_prompt,\r\n                self.do_classifier_free_guidance,\r\n            )\r\n\r\n        # 4. Prepare timesteps\r\n        timesteps, num_inference_steps = retrieve_timesteps(\r\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\r\n        )\r\n\r\n        # 5. Prepare latent variables\r\n        num_channels_latents = self.unet.config.in_channels\r\n        latents = self.prepare_latents(\r\n            batch_size * num_images_per_prompt,\r\n            num_channels_latents,\r\n            height,\r\n            width,\r\n            prompt_embeds.dtype,\r\n            device,\r\n            generator,\r\n            latents,\r\n        )\r\n\r\n        # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\r\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\r\n\r\n        # 6.1 Add image embeds for IP-Adapter\r\n        added_cond_kwargs = (\r\n            {\"image_embeds\": image_embeds}\r\n            if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)\r\n            else None\r\n        )\r\n\r\n        # 6.2 Optionally get Guidance Scale Embedding\r\n        timestep_cond = None\r\n        if self.unet.config.time_cond_proj_dim is not None:\r\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\r\n            timestep_cond = self.get_guidance_scale_embedding(\r\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\r\n            ).to(device=device, dtype=latents.dtype)\r\n\r\n        # 7. Denoising loop\r\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\r\n        self._num_timesteps = len(timesteps)\r\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\r\n            for i, t in enumerate(timesteps):\r\n                if self.interrupt:\r\n                    continue\r\n\r\n                # expand the latents if we are doing classifier free guidance\r\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\r\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\r\n\r\n                # predict the noise residual\r\n                noise_pred = self.unet(\r\n                    latent_model_input,\r\n                    t,\r\n                    encoder_hidden_states=prompt_embeds,\r\n                    timestep_cond=timestep_cond,\r\n                    cross_attention_kwargs=self.cross_attention_kwargs,\r\n                    added_cond_kwargs=added_cond_kwargs,\r\n                    
return_dict=False,\r\n                )[0]\r\n\r\n                # perform guidance\r\n                if self.do_classifier_free_guidance:\r\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\r\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\r\n\r\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\r\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\r\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\r\n\r\n                # compute the previous noisy sample x_t -> x_t-1\r\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\r\n\r\n                if callback_on_step_end is not None:\r\n                    callback_kwargs = {}\r\n                    for k in callback_on_step_end_tensor_inputs:\r\n                        callback_kwargs[k] = locals()[k]\r\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\r\n\r\n                    latents = callback_outputs.pop(\"latents\", latents)\r\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\r\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\r\n\r\n                # call the callback, if provided\r\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\r\n                    progress_bar.update()\r\n                    if callback is not None and i % callback_steps == 0:\r\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\r\n                        callback(step_idx, t, latents)\r\n\r\n                if XLA_AVAILABLE:\r\n                    xm.mark_step()\r\n\r\n        if not output_type == 
\"latent\":\r\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\r\n                0\r\n            ]\r\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\r\n        else:\r\n            image = latents\r\n            has_nsfw_concept = None\r\n\r\n        if has_nsfw_concept is None:\r\n            do_denormalize = [True] * image.shape[0]\r\n        else:\r\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\r\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\r\n\r\n        # Offload all models\r\n        self.maybe_free_model_hooks()\r\n\r\n        if not return_dict:\r\n            return (image, has_nsfw_concept)\r\n\r\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\r\n\r\n    # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion\r\n    def _encode_prompt(\r\n        self,\r\n        prompt,\r\n        device,\r\n        num_images_per_prompt,\r\n        do_classifier_free_guidance,\r\n        negative_prompt=None,\r\n        prompt_embeds: Optional[torch.Tensor] = None,\r\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\r\n        lora_scale: Optional[float] = None,\r\n        **kwargs,\r\n    ):\r\n        r\"\"\"Encodes the prompt into text encoder hidden states.\"\"\"\r\n        batch_size = len(prompt) if isinstance(prompt, list) else 1\r\n\r\n        # get prompt text embeddings\r\n        text_inputs = self.tokenizer(\r\n            prompt,\r\n            padding=\"max_length\",\r\n            max_length=self.tokenizer.model_max_length,\r\n            truncation=True,\r\n            return_tensors=\"pt\",\r\n        )\r\n        text_input_ids = text_inputs.input_ids\r\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", 
return_tensors=\"pt\").input_ids\r\n\r\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\r\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])\r\n            logger.warning(\r\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\r\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\r\n            )\r\n\r\n        if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\r\n            attention_mask = text_inputs.attention_mask.to(device)\r\n        else:\r\n            attention_mask = None\r\n\r\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\r\n            # cast text_encoder.dtype to prevent overflow when using bf16\r\n            text_input_ids = text_input_ids.to(device=device, dtype=self.text_encoder.dtype)\r\n            prompt_embeds = self.text_encoder(\r\n                text_input_ids,\r\n                attention_mask=attention_mask,\r\n            )\r\n            prompt_embeds = prompt_embeds[0]\r\n        else:\r\n            text_encoder_lora_scale = None\r\n            if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\r\n                text_encoder_lora_scale = lora_scale\r\n            if text_encoder_lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\r\n                # dynamically adjust the LoRA scale\r\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\r\n\r\n            prompt_embeds = self.text_encoder(\r\n                text_input_ids.to(device),\r\n                attention_mask=attention_mask,\r\n            )\r\n            prompt_embeds = prompt_embeds[0]\r\n\r\n        # duplicate text embeddings for each generation per prompt\r\n        
bs_embed, seq_len, _ = prompt_embeds.shape\r\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\r\n\r\n        # get unconditional embeddings for classifier free guidance\r\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\r\n            uncond_tokens: List[str]\r\n            if negative_prompt is None:\r\n                uncond_tokens = [\"\"]\r\n            elif type(prompt) is not type(negative_prompt):\r\n                raise TypeError(\r\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\r\n                    f\" {type(prompt)}.\"\r\n                )\r\n            elif isinstance(negative_prompt, str):\r\n                uncond_tokens = [negative_prompt]\r\n            elif batch_size != len(negative_prompt):\r\n                raise ValueError(\r\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\r\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\r\n                    \" the batch size of `prompt`.\"\r\n                )\r\n            else:\r\n                uncond_tokens = negative_prompt\r\n\r\n            # textual inversion: process multi-vector tokens if necessary\r\n            if isinstance(self, TextualInversionLoaderMixin):\r\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\r\n\r\n            max_length = prompt_embeds.shape[1]\r\n            uncond_input = self.tokenizer(\r\n                uncond_tokens,\r\n                padding=\"max_length\",\r\n                max_length=max_length,\r\n                truncation=True,\r\n                return_tensors=\"pt\",\r\n            )\r\n\r\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\r\n                attention_mask = uncond_input.attention_mask.to(device)\r\n            else:\r\n                attention_mask = None\r\n\r\n            negative_prompt_embeds = self.text_encoder(\r\n                uncond_input.input_ids.to(device),\r\n                attention_mask=attention_mask,\r\n            )\r\n            negative_prompt_embeds = negative_prompt_embeds[0]\r\n\r\n        if do_classifier_free_guidance:\r\n            # duplicate unconditional embeddings for each generation per prompt\r\n            seq_len = negative_prompt_embeds.shape[1]\r\n\r\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\r\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\r\n\r\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\r\n            # Unscale LoRA weights to avoid overfitting. 
This is a hack\r\n            unscale_lora_layers(self.text_encoder, lora_scale)\r\n\r\n        return prompt_embeds, negative_prompt_embeds\r\n\r\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\r\n        \"\"\"Encodes the image into image encoder hidden states.\"\"\"\r\n        dtype = next(self.image_encoder.parameters()).dtype\r\n\r\n        if not isinstance(image, torch.Tensor):\r\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\r\n\r\n        image = image.to(device=device, dtype=dtype)\r\n        if output_hidden_states:\r\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\r\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\r\n            uncond_image_enc_hidden_states = self.image_encoder(\r\n                torch.zeros_like(image), output_hidden_states=True\r\n            ).hidden_states[-2]\r\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\r\n                num_images_per_prompt, dim=0\r\n            )\r\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\r\n        else:\r\n            image_embeds = self.image_encoder(image).image_embeds\r\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\r\n            uncond_image_embeds = torch.zeros_like(image_embeds)\r\n\r\n            return image_embeds, uncond_image_embeds\r\n\r\n    def prepare_ip_adapter_image_embeds(\r\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\r\n    ):\r\n        \"\"\"Prepares and processes IP-Adapter image embeddings.\"\"\"\r\n        image_embeds = []\r\n        if do_classifier_free_guidance:\r\n            negative_image_embeds = []\r\n        if ip_adapter_image_embeds is None:\r\n            for 
image in ip_adapter_image:\r\n                if not isinstance(image, torch.Tensor):\r\n                    image = self.image_processor.preprocess(image)\r\n                    image = image.to(device=device)\r\n                if len(image.shape) == 3:\r\n                    image = image.unsqueeze(0)\r\n                image_emb, neg_image_emb = self.encode_image(image, device, num_images_per_prompt, True)\r\n                image_embeds.append(image_emb)\r\n                if do_classifier_free_guidance:\r\n                    negative_image_embeds.append(neg_image_emb)\r\n\r\n            if len(image_embeds) == 1:\r\n                image_embeds = image_embeds[0]\r\n                if do_classifier_free_guidance:\r\n                    negative_image_embeds = negative_image_embeds[0]\r\n            else:\r\n                image_embeds = torch.cat(image_embeds, dim=0)\r\n                if do_classifier_free_guidance:\r\n                    negative_image_embeds = torch.cat(negative_image_embeds, dim=0)\r\n        else:\r\n            repeat_dim = 2 if do_classifier_free_guidance else 1\r\n            image_embeds = ip_adapter_image_embeds.repeat_interleave(repeat_dim, dim=0)\r\n            if do_classifier_free_guidance:\r\n                negative_image_embeds = torch.zeros_like(image_embeds)\r\n\r\n        if do_classifier_free_guidance:\r\n            image_embeds = torch.cat([negative_image_embeds, image_embeds])\r\n\r\n        return image_embeds\r\n\r\n    def run_safety_checker(self, image, device, dtype):\r\n        \"\"\"Runs the safety checker on the generated image.\"\"\"\r\n        if self.safety_checker is None:\r\n            has_nsfw_concept = None\r\n            return image, has_nsfw_concept\r\n\r\n        if isinstance(self.safety_checker, StableDiffusionSafetyChecker):\r\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\r\n            image, has_nsfw_concept = 
self.safety_checker(\r\n                images=image,\r\n                clip_input=safety_checker_input.pixel_values.to(dtype),\r\n            )\r\n        else:\r\n            images_np = self.numpy_to_pil(image)\r\n            safety_checker_input = self.safety_checker.feature_extractor(images_np, return_tensors=\"pt\").to(device)\r\n            has_nsfw_concept = self.safety_checker(\r\n                images=image,\r\n                clip_input=safety_checker_input.pixel_values.to(dtype),\r\n            )[1]\r\n\r\n        return image, has_nsfw_concept\r\n\r\n    # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion\r\n    def decode_latents(self, latents):\r\n        \"\"\"Decodes the latents to images.\"\"\"\r\n        latents = 1 / self.vae.config.scaling_factor * latents\r\n        image = self.vae.decode(latents, return_dict=False)[0]\r\n        image = (image / 2 + 0.5).clamp(0, 1)\r\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\r\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\r\n        return image\r\n\r\n    @property\r\n    def guidance_scale(self):\r\n        return self._guidance_scale\r\n\r\n    @property\r\n    def guidance_rescale(self):\r\n        return self._guidance_rescale\r\n\r\n    # copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion\r\n    def get_guidance_scale_embedding(\r\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\r\n    ):\r\n        \"\"\"Gets the guidance scale embedding for classifier free guidance conditioning.\r\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\r\n\r\n        Args:\r\n            w (`torch.Tensor`):\r\n                The guidance scale tensor used for classifier free guidance conditioning.\r\n            embedding_dim (`int`, defaults to 512):\r\n                The 
dimensionality of the guidance scale embedding.\r\n            dtype (`torch.dtype`, defaults to torch.float32):\r\n                The dtype of the embedding.\r\n\r\n        Returns:\r\n            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\r\n        \"\"\"\r\n        assert len(w.shape) == 1\r\n        w = w * 1000.0\r\n\r\n        half_dim = embedding_dim // 2\r\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\r\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\r\n        emb = w.to(dtype)[:, None] * emb[None, :]\r\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\r\n        if embedding_dim % 2 == 1:  # zero pad\r\n            emb = torch.nn.functional.pad(emb, (0, 1))\r\n        assert emb.shape == (w.shape[0], embedding_dim)\r\n        return emb\r\n\r\n    @property\r\n    def clip_skip(self):\r\n        return self._clip_skip\r\n\r\n    @property\r\n    def num_timesteps(self):\r\n        return self._num_timesteps\r\n\r\n    @property\r\n    def interrupt(self):\r\n        return self._interrupt\r\n\r\n    @property\r\n    def cross_attention_kwargs(self):\r\n        return self._cross_attention_kwargs\r\n\r\n    @property\r\n    def do_classifier_free_guidance(self):\r\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\r\n\r\n\r\n### Make prompt list for each regions\r\ndef promptsmaker(prompts, batch):\r\n    out_p = []\r\n    plen = len(prompts)\r\n    for prompt in prompts:\r\n        add = \"\"\r\n        if KCOMM in prompt:\r\n            add, prompt = prompt.split(KCOMM)\r\n            add = add.strip() + \" \"\r\n        prompts = [p.strip() for p in prompt.split(KBRK)]\r\n        out_p.append([add + p for i, p in enumerate(prompts)])\r\n    out = [None] * batch * len(out_p[0]) * len(out_p)\r\n    for p, prs in enumerate(out_p):  # inputs prompts\r\n        for r, pr in enumerate(prs):  # prompts for regions\r\n            start = 
(p + r * plen) * batch\r\n            out[start : start + batch] = [pr] * batch  # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...\r\n    return out, out_p\r\n\r\n\r\n### make regions from ratios\r\n### \";\" makes outercells, \",\" makes inner cells\r\ndef make_cells(ratios):\r\n    if \";\" not in ratios and \",\" in ratios:\r\n        ratios = ratios.replace(\",\", \";\")\r\n    ratios = ratios.split(\";\")\r\n    ratios = [inratios.split(\",\") for inratios in ratios]\r\n\r\n    icells = []\r\n    ocells = []\r\n\r\n    def startend(cells, array):\r\n        current_start = 0\r\n        array = [float(x) for x in array]\r\n        for value in array:\r\n            end = current_start + (value / sum(array))\r\n            cells.append([current_start, end])\r\n            current_start = end\r\n\r\n    startend(ocells, [r[0] for r in ratios])\r\n\r\n    for inratios in ratios:\r\n        if 2 > len(inratios):\r\n            icells.append([[0, 1]])\r\n        else:\r\n            add = []\r\n            startend(add, inratios[1:])\r\n            icells.append(add)\r\n    return ocells, icells, sum(len(cell) for cell in icells)\r\n\r\n\r\ndef make_emblist(self, prompts):\r\n    with torch.no_grad():\r\n        tokens = self.tokenizer(\r\n            prompts,\r\n            max_length=self.tokenizer.model_max_length,\r\n            padding=True,\r\n            truncation=True,\r\n            return_tensors=\"pt\",\r\n        ).input_ids.to(self.device)\r\n        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)\r\n    return embs\r\n\r\n\r\ndef split_dims(xs, height, width):\r\n    def repeat_div(x, y):\r\n        while y > 0:\r\n            x = math.ceil(x / 2)\r\n            y = y - 1\r\n        return x\r\n\r\n    scale = math.ceil(math.log2(math.sqrt(height * width / xs)))\r\n    dsh = repeat_div(height, scale)\r\n    dsw = repeat_div(width, scale)\r\n    return dsh, dsw\r\n\r\n\r\n##### for prompt 
mode\r\ndef get_attn_maps(self, attn):\r\n    height, width = self.hw\r\n    target_tokens = self.target_tokens\r\n    if (height, width) not in self.attnmaps_sizes:\r\n        self.attnmaps_sizes.append((height, width))\r\n\r\n    for b in range(self.batch):\r\n        for t in target_tokens:\r\n            power = self.power\r\n            add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)\r\n            add = torch.sum(add, dim=2)\r\n            key = f\"{t}-{b}\"\r\n            if key not in self.attnmaps:\r\n                self.attnmaps[key] = add\r\n            else:\r\n                if self.attnmaps[key].shape[1] != add.shape[1]:\r\n                    add = add.view(8, height, width)\r\n                    add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)\r\n                    add = add.reshape_as(self.attnmaps[key])\r\n\r\n                self.attnmaps[key] = self.attnmaps[key] + add\r\n\r\n\r\ndef reset_attnmaps(self):  # init parameters in every batch\r\n    self.step = 0\r\n    self.attnmaps = {}  # made from attention maps\r\n    self.attnmaps_sizes = []  # height,width set of u-net blocks\r\n    self.attnmasks = {}  # made from attnmaps for regions\r\n    self.maskready = False\r\n    self.history = {}\r\n\r\n\r\ndef saveattnmaps(self, output, h, w, th, step, regions):\r\n    masks = []\r\n    for i, mask in enumerate(self.history[step].values()):\r\n        img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)\r\n        if self.ex:\r\n            masks = [x - mask for x in masks]\r\n            masks.append(mask)\r\n            if len(masks) == regions - 1:\r\n                output.images.extend([FF.to_pil_image(mask) for mask in masks])\r\n                masks = []\r\n        else:\r\n            output.images.append(img)\r\n\r\n\r\ndef makepmask(\r\n    self, mask, h, w, th, step\r\n):  # make masks from attention cache return [for preview, for attention, for 
Latent]\r\n    th = th - step * 0.005\r\n    if 0.05 >= th:\r\n        th = 0.05\r\n    mask = torch.mean(mask, dim=0)\r\n    mask = mask / mask.max().item()\r\n    mask = torch.where(mask > th, 1, 0)\r\n    mask = mask.float()\r\n    mask = mask.view(1, *self.attnmaps_sizes[0])\r\n    img = FF.to_pil_image(mask)\r\n    img = img.resize((w, h))\r\n    mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)\r\n    lmask = mask\r\n    mask = mask.reshape(h * w)\r\n    mask = torch.where(mask > 0.1, 1, 0)\r\n    return img, mask, lmask\r\n\r\n\r\ndef tokendealer(self, all_prompts):\r\n    for prompts in all_prompts:\r\n        targets = [p.split(\",\")[-1] for p in prompts[1:]]\r\n        tt = []\r\n\r\n        for target in targets:\r\n            ptokens = (\r\n                self.tokenizer(\r\n                    prompts,\r\n                    max_length=self.tokenizer.model_max_length,\r\n                    padding=True,\r\n                    truncation=True,\r\n                    return_tensors=\"pt\",\r\n                ).input_ids\r\n            )[0]\r\n            ttokens = (\r\n                self.tokenizer(\r\n                    target,\r\n                    max_length=self.tokenizer.model_max_length,\r\n                    padding=True,\r\n                    truncation=True,\r\n                    return_tensors=\"pt\",\r\n                ).input_ids\r\n            )[0]\r\n\r\n            tlist = []\r\n\r\n            for t in range(ttokens.shape[0] - 2):\r\n                for p in range(ptokens.shape[0]):\r\n                    if ttokens[t + 1] == ptokens[p]:\r\n                        tlist.append(p)\r\n            if tlist != []:\r\n                tt.append(tlist)\r\n\r\n    return tt\r\n\r\n\r\ndef scaled_dot_product_attention(\r\n    self,\r\n    query,\r\n    key,\r\n    value,\r\n    attn_mask=None,\r\n    dropout_p=0.0,\r\n    is_causal=False,\r\n    scale=None,\r\n    getattn=False,\r\n) -> 
torch.Tensor:\r\n    # Explicit (non-fused) attention so the attention weights can be captured when `getattn` is set\r\n    L, S = query.size(-2), key.size(-2)\r\n    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\r\n    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)\r\n    if is_causal:\r\n        assert attn_mask is None\r\n        temp_mask = torch.ones(L, S, dtype=torch.bool, device=attn_bias.device).tril(diagonal=0)\r\n        attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\r\n\r\n    if attn_mask is not None:\r\n        if attn_mask.dtype == torch.bool:\r\n            # a boolean mask marks the positions to keep; masked-out positions get -inf in the bias\r\n            attn_bias.masked_fill_(attn_mask.logical_not(), float(\"-inf\"))\r\n        else:\r\n            attn_bias += attn_mask\r\n    attn_weight = query @ key.transpose(-2, -1) * scale_factor\r\n    attn_weight += attn_bias\r\n    attn_weight = torch.softmax(attn_weight, dim=-1)\r\n    if getattn:\r\n        get_attn_maps(self, attn_weight)\r\n    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)\r\n    return attn_weight @ value\r\n\r\n\r\ndef retrieve_timesteps(\r\n    scheduler,\r\n    num_inference_steps: Optional[int] = None,\r\n    device: Optional[Union[str, torch.device]] = None,\r\n    timesteps: Optional[List[int]] = None,\r\n    sigmas: Optional[List[float]] = None,\r\n    **kwargs,\r\n):\r\n    r\"\"\"\r\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\r\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\r\n\r\n    Args:\r\n        scheduler (`SchedulerMixin`):\r\n            The scheduler to get timesteps from.\r\n        num_inference_steps (`int`):\r\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\r\n            must be `None`.\r\n        device (`str` or `torch.device`, *optional*):\r\n            The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\r\n        timesteps (`List[int]`, *optional*):\r\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\r\n            `num_inference_steps` and `sigmas` must be `None`.\r\n        sigmas (`List[float]`, *optional*):\r\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\r\n            `num_inference_steps` and `timesteps` must be `None`.\r\n\r\n    Returns:\r\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\r\n        second element is the number of inference steps.\r\n    \"\"\"\r\n    if timesteps is not None and sigmas is not None:\r\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\r\n    if timesteps is not None:\r\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\r\n        if not accepts_timesteps:\r\n            raise ValueError(\r\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\r\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\r\n            )\r\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n        num_inference_steps = len(timesteps)\r\n    elif sigmas is not None:\r\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\r\n        if not accept_sigmas:\r\n            raise ValueError(\r\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\r\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\r\n            )\r\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n        num_inference_steps = len(timesteps)\r\n    else:\r\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\r\n        timesteps = scheduler.timesteps\r\n    return timesteps, num_inference_steps\r\n\r\n\r\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\r\n    r\"\"\"\r\n    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on\r\n    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are\r\n    Flawed](https://huggingface.co/papers/2305.08891).\r\n\r\n    Args:\r\n        noise_cfg (`torch.Tensor`):\r\n            The predicted noise tensor for the guided diffusion process.\r\n        noise_pred_text (`torch.Tensor`):\r\n            The predicted noise tensor for the text-guided diffusion process.\r\n        guidance_rescale (`float`, *optional*, defaults to 0.0):\r\n            A rescale factor applied to the noise predictions.\r\n\r\n    Returns:\r\n        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.\r\n    \"\"\"\r\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\r\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\r\n    # rescale the results from guidance (fixes overexposure)\r\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\r\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\r\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\r\n    return noise_cfg\r\n"
  },
  {
    "path": "examples/community/rerender_a_video.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nimport torchvision.transforms as T\nfrom gmflow.gmflow import GMFlow\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel\nfrom diffusers.models.attention_processor import Attention, AttnProcessor\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef coords_grid(b, h, w, homogeneous=False, device=None):\n    y, x = 
torch.meshgrid(torch.arange(h), torch.arange(w))  # [H, W]\n\n    stacks = [x, y]\n\n    if homogeneous:\n        ones = torch.ones_like(x)  # [H, W]\n        stacks.append(ones)\n\n    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]\n\n    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]\n\n    if device is not None:\n        grid = grid.to(device)\n\n    return grid\n\n\ndef bilinear_sample(img, sample_coords, mode=\"bilinear\", padding_mode=\"zeros\", return_mask=False):\n    # img: [B, C, H, W]\n    # sample_coords: [B, 2, H, W] in image scale\n    if sample_coords.size(1) != 2:  # [B, H, W, 2]\n        sample_coords = sample_coords.permute(0, 3, 1, 2)\n\n    b, _, h, w = sample_coords.shape\n\n    # Normalize to [-1, 1]\n    x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1\n    y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1\n\n    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]\n\n    img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)\n\n    if return_mask:\n        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1)  # [B, H, W]\n\n        return img, mask\n\n    return img\n\n\ndef flow_warp(feature, flow, mask=False, mode=\"bilinear\", padding_mode=\"zeros\"):\n    b, c, h, w = feature.size()\n    assert flow.size(1) == 2\n\n    grid = coords_grid(b, h, w).to(flow.device) + flow  # [B, 2, H, W]\n    grid = grid.to(feature.dtype)\n    return bilinear_sample(feature, grid, mode=mode, padding_mode=padding_mode, return_mask=mask)\n\n\ndef forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5):\n    # fwd_flow, bwd_flow: [B, 2, H, W]\n    # alpha and beta values are following UnFlow\n    # (https://huggingface.co/papers/1711.07837)\n    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4\n    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2\n    flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)  # [B, H, W]\n\n 
   warped_bwd_flow = flow_warp(bwd_flow, fwd_flow)  # [B, 2, H, W]\n    warped_fwd_flow = flow_warp(fwd_flow, bwd_flow)  # [B, 2, H, W]\n\n    diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1)  # [B, H, W]\n    diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)\n\n    threshold = alpha * flow_mag + beta\n\n    fwd_occ = (diff_fwd > threshold).float()  # [B, H, W]\n    bwd_occ = (diff_bwd > threshold).float()\n\n    return fwd_occ, bwd_occ\n\n\n@torch.no_grad()\ndef get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False, device=None):\n    if image3 is None:\n        image3 = image1\n    padder = InputPadder(image1.shape, padding_factor=8)\n    image1, image2 = padder.pad(image1[None].to(device), image2[None].to(device))\n    results_dict = flow_model(\n        image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True\n    )\n    flow_pr = results_dict[\"flow_preds\"][-1]  # [B, 2, H, W]\n    fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]\n    bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]\n    fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow)  # [1, H, W] float\n    if pixel_consistency:\n        warped_image1 = flow_warp(image1, bwd_flow)\n        bwd_occ = torch.clamp(\n            bwd_occ + (abs(image2 - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0, 1\n        ).unsqueeze(0)\n    warped_results = flow_warp(image3, bwd_flow)\n    return warped_results, bwd_occ, bwd_flow\n\n\nblur = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))\n\n\n@dataclass\nclass TextToVideoSDPipelineOutput(BaseOutput):\n    \"\"\"\n    Output class for text-to-video pipelines.\n\n    Args:\n        frames (`List[np.ndarray]` or `torch.Tensor`)\n            List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as\n            a `torch` tensor. 
The length of the list denotes the video length (the number of frames).\n    \"\"\"\n\n    frames: Union[List[np.ndarray], torch.Tensor]\n\n\n@torch.no_grad()\ndef find_flat_region(mask):\n    device = mask.device\n    kernel_x = torch.Tensor([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).to(device)\n    kernel_y = torch.Tensor([[-1, -1, -1], [0, 0, 0], [1, 1, 1]]).unsqueeze(0).unsqueeze(0).to(device)\n    mask_ = F.pad(mask.unsqueeze(0), (1, 1, 1, 1), mode=\"replicate\")\n\n    grad_x = torch.nn.functional.conv2d(mask_, kernel_x)\n    grad_y = torch.nn.functional.conv2d(mask_, kernel_y)\n    return ((abs(grad_x) + abs(grad_y)) == 0).float()[0]\n\n\nclass AttnState:\n    STORE = 0\n    LOAD = 1\n    LOAD_AND_STORE_PREV = 2\n\n    def __init__(self):\n        self.reset()\n\n    @property\n    def state(self):\n        return self.__state\n\n    @property\n    def timestep(self):\n        return self.__timestep\n\n    def set_timestep(self, t):\n        self.__timestep = t\n\n    def reset(self):\n        self.__state = AttnState.STORE\n        self.__timestep = 0\n\n    def to_load(self):\n        self.__state = AttnState.LOAD\n\n    def to_load_and_store_prev(self):\n        self.__state = AttnState.LOAD_AND_STORE_PREV\n\n\nclass CrossFrameAttnProcessor(AttnProcessor):\n    \"\"\"\n    Cross frame attention processor. 
Each frame attends to the first frame and the previous frame.\n\n    Args:\n        attn_state: Shared `AttnState` that tracks the current timestep and whether the model is processing the\n            first frame or an intermediate frame\n    \"\"\"\n\n    def __init__(self, attn_state: AttnState):\n        super().__init__()\n        self.attn_state = attn_state\n        self.first_maps = {}\n        self.prev_maps = {}\n\n    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):\n        # Self-attention: no encoder hidden states are passed\n        if encoder_hidden_states is None:\n            t = self.attn_state.timestep\n            if self.attn_state.state == AttnState.STORE:\n                self.first_maps[t] = hidden_states.detach()\n                self.prev_maps[t] = hidden_states.detach()\n                res = super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)\n            else:\n                if self.attn_state.state == AttnState.LOAD_AND_STORE_PREV:\n                    tmp = hidden_states.detach()\n                cross_map = torch.cat((self.first_maps[t], self.prev_maps[t]), dim=1)\n                res = super().__call__(attn, hidden_states, cross_map, attention_mask, temb)\n                if self.attn_state.state == AttnState.LOAD_AND_STORE_PREV:\n                    self.prev_maps[t] = tmp\n        else:\n            res = super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)\n\n        return res\n\n\ndef prepare_image(image):\n    if isinstance(image, torch.Tensor):\n        # Batch single image\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        image = image.to(dtype=torch.float32)\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = 
np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    return image\n\n\nclass RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):\n    r\"\"\"\n    Pipeline for video-to-video translation using Stable Diffusion with Rerender Algorithm.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    In addition the pipeline inherits the following loading methods:\n        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the unet during the denoising process. 
If you set multiple ControlNets\n            as a list, the outputs from each ControlNet are added together to create one combined additional\n            conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder=None,\n        requires_safety_checker: bool = True,\n        device=None,\n    ):\n        super().__init__(\n            vae,\n            text_encoder,\n            tokenizer,\n            unet,\n            controlnet,\n            scheduler,\n            safety_checker,\n            feature_extractor,\n            image_encoder,\n            requires_safety_checker,\n        )\n        self.to(device)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for 
{self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                \"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n        self.attn_state = AttnState()\n        
attn_processor_dict = {}\n        for k in unet.attn_processors.keys():\n            if k.startswith(\"up\"):\n                attn_processor_dict[k] = CrossFrameAttnProcessor(self.attn_state)\n            else:\n                attn_processor_dict[k] = AttnProcessor()\n\n        self.unet.set_attn_processor(attn_processor_dict)\n\n        flow_model = GMFlow(\n            feature_channels=128,\n            num_scales=1,\n            upsample_factor=8,\n            num_head=1,\n            attention_type=\"swin\",\n            ffn_dim_expansion=4,\n            num_transformer_layers=6,\n        ).to(self.device)\n\n        checkpoint = torch.utils.model_zoo.load_url(\n            \"https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth\",\n            map_location=lambda storage, loc: storage,\n        )\n        weights = checkpoint[\"model\"] if \"model\" in checkpoint else checkpoint\n        flow_model.load_state_dict(weights, strict=False)\n        flow_model.eval()\n        self.flow_model = flow_model\n\n    # Modified from src/diffusers/pipelines/controlnet/pipeline_controlnet.StableDiffusionControlNetImg2ImgPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n    ):\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: 
{prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. 
The conditionings will be fixed across the prompts.\"\n                )\n\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n                )\n        else:\n            assert False\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. 
Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 
2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        return timesteps, num_inference_steps - t_start\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents\n    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                init_latents = [\n                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = self.vae.encode(image).latent_dist.sample(generator)\n\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            deprecation_message = (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        frames: Union[List[np.ndarray], torch.Tensor] = None,\n        control_frames: Union[List[np.ndarray], torch.Tensor] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n  
      controlnet_conditioning_scale: Union[float, List[float]] = 0.8,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        warp_start: Union[float, List[float]] = 0.0,\n        warp_end: Union[float, List[float]] = 0.3,\n        mask_start: Union[float, List[float]] = 0.5,\n        mask_end: Union[float, List[float]] = 0.8,\n        smooth_boundary: bool = True,\n        mask_strength: Union[float, List[float]] = 0.5,\n        inner_strength: Union[float, List[float]] = 0.9,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.\n                instead.\n            frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.\n            control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.\n            strength ('float'): SDEdit strength.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list. 
Note that by default, we use a smaller conditioning scale for inpainting\n                than for [`~StableDiffusionControlNetPipeline.__call__`].\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if\n                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the controlnet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the controlnet stops applying.\n            warp_start (`float`): Shape-aware fusion start timestep.\n            warp_end (`float`): Shape-aware fusion end timestep.\n            mask_start (`float`): Pixel-aware fusion start timestep.\n            mask_end (`float`): Pixel-aware fusion end timestep.\n            smooth_boundary (`bool`): Smooth fusion boundary. 
Set `True` to prevent artifacts at boundary.\n            mask_strength (`float`): Pixel-aware fusion strength.\n            inner_strength (`float`): Pixel-aware fusion detail level.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n        )\n\n        # 2. Define call parameters\n        # Currently we only support 1 prompt\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            assert False\n        else:\n            assert False\n        num_images_per_prompt = 1\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. Process the first frame\n        height, width = None, None\n        output_frames = []\n        self.attn_state.reset()\n\n        # 4.1 prepare frames\n        image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype)\n        first_image = image[0]  # C, H, W\n\n        # 4.2 Prepare controlnet_conditioning_image\n        # Currently we only support single control\n        if isinstance(controlnet, ControlNetModel):\n            control_image = self.prepare_control_image(\n                image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],\n                width=width,\n                height=height,\n                batch_size=batch_size,\n                num_images_per_prompt=1,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n        else:\n            assert False\n\n        # 4.3 Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, cur_num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size)\n\n        # 4.4 Prepare latent variables\n        latents = self.prepare_latents(\n            image,\n            latent_timestep,\n            batch_size,\n            num_images_per_prompt,\n   
         prompt_embeds.dtype,\n            device,\n            generator,\n        )\n\n        # 4.5 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 4.6 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        first_x0_list = []\n\n        # 4.7 Denoising loop\n        num_warmup_steps = len(timesteps) - cur_num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=cur_num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                self.attn_state.set_timestep(t.item())\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # controlnet(s) inference\n                if guess_mode and do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in 
zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=control_image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                )\n\n                if guess_mode and do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = 
noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                alpha_prod_t = self.scheduler.alphas_cumprod[t]\n                beta_prod_t = 1 - alpha_prod_t\n                pred_x0 = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)\n                first_x0 = pred_x0.detach()\n                first_x0_list.append(first_x0)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n        else:\n            image = latents\n\n        first_result = image\n        prev_result = image\n        do_denormalize = [True] * image.shape[0]\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        output_frames.append(image[0])\n\n        # 5. 
Process each frame\n        for idx in range(1, len(frames)):\n            image = frames[idx]\n            prev_image = frames[idx - 1]\n            control_image = control_frames(image) if callable(control_frames) else control_frames[idx]\n            # 5.1 prepare frames\n            image = self.image_processor.preprocess(image).to(dtype=self.dtype)\n            prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)\n\n            warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(\n                self.flow_model, first_image, image[0], first_result, False, self.device\n            )\n            blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))\n            blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)\n\n            warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(\n                self.flow_model, prev_image[0], image[0], prev_result, False, self.device\n            )\n            blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))\n            blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)\n\n            warp_mask = 1 - F.max_pool2d(blend_mask_0, kernel_size=8)\n            warp_flow = F.interpolate(bwd_flow_0 / 8.0, scale_factor=1.0 / 8, mode=\"bilinear\")\n\n            # 5.2 Prepare controlnet_conditioning_image\n            # Currently we only support single control\n            if isinstance(controlnet, ControlNetModel):\n                control_image = self.prepare_control_image(\n                    image=control_image,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size,\n                    num_images_per_prompt=1,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n            else:\n      
          assert False\n\n            # 5.3 Prepare timesteps\n            self.scheduler.set_timesteps(num_inference_steps, device=device)\n            timesteps, cur_num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n            latent_timestep = timesteps[:1].repeat(batch_size)\n\n            skip_t = int(num_inference_steps * (1 - strength))\n            warp_start_t = int(warp_start * num_inference_steps)\n            warp_end_t = int(warp_end * num_inference_steps)\n            mask_start_t = int(mask_start * num_inference_steps)\n            mask_end_t = int(mask_end * num_inference_steps)\n\n            # 5.4 Prepare latent variables\n            init_latents = self.prepare_latents(\n                image,\n                latent_timestep,\n                batch_size,\n                num_images_per_prompt,\n                prompt_embeds.dtype,\n                device,\n                generator,\n            )\n\n            # 5.5 Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n            extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n            # 5.6 Create tensor stating which controlnets to keep\n            controlnet_keep = []\n            for i in range(len(timesteps)):\n                keeps = [\n                    1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                    for s, e in zip(control_guidance_start, control_guidance_end)\n                ]\n                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n            # 5.7 Denoising loop\n            num_warmup_steps = len(timesteps) - cur_num_inference_steps * self.scheduler.order\n\n            def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):\n                dir_xt = 0\n                latents_dtype = latents.dtype\n                with self.progress_bar(total=cur_num_inference_steps) as progress_bar:\n                    for i, t in enumerate(timesteps):\n                        self.attn_state.set_timestep(t.item())\n                        if i + skip_t >= mask_start_t and i + skip_t <= mask_end_t and xtrg is not None:\n                            rescale = torch.maximum(1.0 - mask, (1 - mask**2) ** 0.5 * inner_strength)\n                            if noise_rescale is not None:\n                                rescale = (1.0 - mask) * (1 - noise_rescale) + rescale * noise_rescale\n                            noise = randn_tensor(xtrg.shape, generator=generator, device=device, dtype=xtrg.dtype)\n                            latents_ref = self.scheduler.add_noise(xtrg, noise, t)\n                            latents = latents_ref * mask + (1.0 - mask) * (latents - dir_xt) + rescale * dir_xt\n                            latents = latents.to(latents_dtype)\n\n                        # expand the latents if we are doing classifier free guidance\n                        latent_model_input = 
torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                        # controlnet(s) inference\n                        if guess_mode and do_classifier_free_guidance:\n                            # Infer ControlNet only for the conditional batch.\n                            control_model_input = latents\n                            control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                            controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                        else:\n                            control_model_input = latent_model_input\n                            controlnet_prompt_embeds = prompt_embeds\n\n                        if isinstance(controlnet_keep[i], list):\n                            cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                        else:\n                            controlnet_cond_scale = controlnet_conditioning_scale\n                            if isinstance(controlnet_cond_scale, list):\n                                controlnet_cond_scale = controlnet_cond_scale[0]\n                            cond_scale = controlnet_cond_scale * controlnet_keep[i]\n                        down_block_res_samples, mid_block_res_sample = self.controlnet(\n                            control_model_input,\n                            t,\n                            encoder_hidden_states=controlnet_prompt_embeds,\n                            controlnet_cond=control_image,\n                            conditioning_scale=cond_scale,\n                            guess_mode=guess_mode,\n                            return_dict=False,\n                        )\n\n                        if guess_mode and do_classifier_free_guidance:\n                            # Inferred ControlNet only for the conditional batch.\n            
                # To apply the output of ControlNet to both the unconditional and conditional batches,\n                            # add 0 to the unconditional batch to keep it unchanged.\n                            down_block_res_samples = [\n                                torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples\n                            ]\n                            mid_block_res_sample = torch.cat(\n                                [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]\n                            )\n\n                        # predict the noise residual\n                        noise_pred = self.unet(\n                            latent_model_input,\n                            t,\n                            encoder_hidden_states=prompt_embeds,\n                            cross_attention_kwargs=cross_attention_kwargs,\n                            down_block_additional_residuals=down_block_res_samples,\n                            mid_block_additional_residual=mid_block_res_sample,\n                            return_dict=False,\n                        )[0]\n\n                        # perform guidance\n                        if do_classifier_free_guidance:\n                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                        # Get pred_x0 from scheduler\n                        alpha_prod_t = self.scheduler.alphas_cumprod[t]\n                        beta_prod_t = 1 - alpha_prod_t\n                        pred_x0 = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)\n\n                        if i + skip_t >= warp_start_t and i + skip_t <= warp_end_t:\n                            # warp x_0\n                            pred_x0 = (\n                                flow_warp(first_x0_list[i], warp_flow, mode=\"nearest\") * 
warp_mask\n                                + (1 - warp_mask) * pred_x0\n                            )\n\n                            # get x_t from x_0\n                            latents = self.scheduler.add_noise(pred_x0, noise_pred, t).to(latents_dtype)\n\n                        prev_t = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps\n                        if i == len(timesteps) - 1:\n                            alpha_t_prev = 1.0\n                        else:\n                            alpha_t_prev = self.scheduler.alphas_cumprod[prev_t]\n\n                        dir_xt = (1.0 - alpha_t_prev) ** 0.5 * noise_pred\n\n                        # compute the previous noisy sample x_t -> x_t-1\n                        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[\n                            0\n                        ]\n\n                        # call the callback, if provided\n                        if i == len(timesteps) - 1 or (\n                            (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0\n                        ):\n                            progress_bar.update()\n                            if callback is not None and i % callback_steps == 0:\n                                callback(i, t, latents)\n\n                        if XLA_AVAILABLE:\n                            xm.mark_step()\n\n                    return latents\n\n            if mask_start_t <= mask_end_t:\n                self.attn_state.to_load()\n            else:\n                self.attn_state.to_load_and_store_prev()\n            latents = denoising_loop(init_latents)\n\n            if mask_start_t <= mask_end_t:\n                direct_result = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n\n                blend_results = (1 - blend_mask_pre) * warped_pre + blend_mask_pre * direct_result\n                blend_results = (1 - 
blend_mask_0) * warped_0 + blend_mask_0 * blend_results\n\n                bwd_occ = 1 - torch.clamp(1 - bwd_occ_pre + 1 - bwd_occ_0, 0, 1)\n                blend_mask = blur(F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4))\n                blend_mask = 1 - torch.clamp(blend_mask + bwd_occ, 0, 1)\n\n                blend_results = blend_results.to(latents.dtype)\n                xtrg = self.vae.encode(blend_results).latent_dist.sample(generator)\n                xtrg = self.vae.config.scaling_factor * xtrg\n                blend_results_rec = self.vae.decode(xtrg / self.vae.config.scaling_factor, return_dict=False)[0]\n                xtrg_rec = self.vae.encode(blend_results_rec).latent_dist.sample(generator)\n                xtrg_rec = self.vae.config.scaling_factor * xtrg_rec\n                xtrg_ = xtrg + (xtrg - xtrg_rec)\n                blend_results_rec_new = self.vae.decode(xtrg_ / self.vae.config.scaling_factor, return_dict=False)[0]\n                tmp = (abs(blend_results_rec_new - blend_results).mean(dim=1, keepdims=True) > 0.25).float()\n\n                mask_x = F.max_pool2d(\n                    (F.interpolate(tmp, scale_factor=1 / 8.0, mode=\"bilinear\") > 0).float(),\n                    kernel_size=3,\n                    stride=1,\n                    padding=1,\n                )\n\n                mask = 1 - F.max_pool2d(1 - blend_mask, kernel_size=8)  # * (1-mask_x)\n\n                if smooth_boundary:\n                    noise_rescale = find_flat_region(mask)\n                else:\n                    noise_rescale = torch.ones_like(mask)\n\n                xtrg = (xtrg + (1 - mask_x) * (xtrg - xtrg_rec)) * mask\n                xtrg = xtrg.to(latents.dtype)\n\n                self.scheduler.set_timesteps(num_inference_steps, device=device)\n                timesteps, cur_num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n\n                self.attn_state.to_load_and_store_prev()\n                
latents = denoising_loop(init_latents, mask * mask_strength, xtrg, noise_rescale)\n\n            if not output_type == \"latent\":\n                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            else:\n                image = latents\n\n            prev_result = image\n\n            do_denormalize = [True] * image.shape[0]\n            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n            output_frames.append(image[0])\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return output_frames\n\n        return TextToVideoSDPipelineOutput(frames=output_frames)\n\n\nclass InputPadder:\n    \"\"\"Pads images such that dimensions are divisible by 8\"\"\"\n\n    def __init__(self, dims, mode=\"sintel\", padding_factor=8):\n        self.ht, self.wd = dims[-2:]\n        pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor\n        pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor\n        if mode == \"sintel\":\n            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]\n        else:\n            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]\n\n    def pad(self, *inputs):\n        return [F.pad(x, self._pad, mode=\"replicate\") for x in inputs]\n\n    def unpad(self, x):\n        ht, wd = x.shape[-2:]\n        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]\n        return x[..., c[0] : c[1], c[2] : c[3]]\n"
  },
  {
    "path": "examples/community/run_onnx_controlnet.py",
    "content": "import argparse\nimport inspect\nimport os\nimport time\nimport warnings\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom PIL import Image\nfrom transformers import CLIPTokenizer\n\nfrom diffusers import OnnxRuntimeModel, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install opencv-python transformers accelerate\n        >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n        >>> import numpy as np\n        >>> import torch\n\n        >>> import cv2\n        >>> from PIL import Image\n\n        >>> # download an image\n        >>> image = load_image(\n        ...     \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\"\n        ... 
)\n        >>> np_image = np.array(image)\n\n        >>> # get canny image\n        >>> np_image = cv2.Canny(np_image, 100, 200)\n        >>> np_image = np_image[:, :, None]\n        >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)\n        >>> canny_image = Image.fromarray(np_image)\n\n        >>> # load control net and stable diffusion v1-5\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\n        >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(\n        ...     \"stable-diffusion-v1-5/stable-diffusion-v1-5\", controlnet=controlnet, torch_dtype=torch.float16\n        ... )\n\n        >>> # speed up diffusion process with faster scheduler and memory optimization\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> # generate image\n        >>> generator = torch.manual_seed(0)\n        >>> image = pipe(\n        ...     \"futuristic-looking woman\",\n        ...     num_inference_steps=20,\n        ...     generator=generator,\n        ...     image=image,\n        ...     control_image=canny_image,\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\ndef prepare_image(image):\n    if isinstance(image, torch.Tensor):\n        # Batch single image\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        image = image.to(dtype=torch.float32)\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    return image\n\n\nclass OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):\n    vae_encoder: OnnxRuntimeModel\n    vae_decoder: OnnxRuntimeModel\n    text_encoder: OnnxRuntimeModel\n    tokenizer: CLIPTokenizer\n    unet: OnnxRuntimeModel\n    scheduler: KarrasDiffusionSchedulers\n\n    def __init__(\n        self,\n        vae_encoder: OnnxRuntimeModel,\n        vae_decoder: OnnxRuntimeModel,\n        text_encoder: OnnxRuntimeModel,\n        tokenizer: CLIPTokenizer,\n        unet: OnnxRuntimeModel,\n        scheduler: KarrasDiffusionSchedulers,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae_encoder=vae_encoder,\n            vae_decoder=vae_decoder,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (4 - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, 
do_convert_rgb=True, do_normalize=False\n        )\n\n    def _encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: Optional[int],\n        do_classifier_free_guidance: bool,\n        negative_prompt: str | None,\n        prompt_embeds: Optional[np.ndarray] = None,\n        negative_prompt_embeds: Optional[np.ndarray] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                prompt to be encoded\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            prompt_embeds (`np.ndarray`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`np.ndarray`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # get prompt text embeddings\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"np\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"max_length\", return_tensors=\"np\").input_ids\n\n            if not np.array_equal(text_input_ids, untruncated_ids):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]\n\n        prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to 
`prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt] * batch_size\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"np\",\n            )\n            negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]\n\n        if do_classifier_free_guidance:\n            negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        warnings.warn(\n            \"The decode_latents method is deprecated and will be removed in a future version. 
Please\"\n            \" use VaeImageProcessor instead\",\n            FutureWarning,\n        )\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        num_controlnet,\n        prompt,\n        image,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n    ):\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            
raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # Check `image`\n        if num_controlnet == 1:\n            self.check_image(image, prompt, prompt_embeds)\n        elif num_controlnet > 1:\n            if not isinstance(image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in image):\n                raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif len(image) != num_controlnet:\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {num_controlnet} ControlNets.\"\n                )\n\n            for image_ in image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if num_controlnet == 1:\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif num_controlnet > 1:\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in 
controlnet_conditioning_scale):\n                    raise ValueError(\"Only a single batch of multiple conditionings is supported at the moment.\")\n            if (\n                isinstance(controlnet_conditioning_scale, list)\n                and len(controlnet_conditioning_scale) != num_controlnet\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n                )\n        else:\n            assert False\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if num_controlnet > 1:\n            if len(control_guidance_start) != num_controlnet:\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_controlnet} controlnets available. 
Make sure to provide {num_controlnet}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size 
= len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(self, image, 
timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            _image = image.cpu().detach().numpy()\n            init_latents = self.vae_encoder(sample=_image)[0]\n            init_latents = torch.from_numpy(init_latents).to(device=device, dtype=dtype)\n            init_latents = 0.18215 * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            deprecation_message = (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        num_controlnet: int,\n        fp16: bool = True,\n        prompt: Union[str, List[str]] = None,\n        image: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        control_image: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        
num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The initial image to be used as the starting point for the image generation process. Can also accept\n                image latents as `image`; if latents are passed directly, they are not encoded again.\n            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. 
If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can\n                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If\n                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are\n                specified in init, images must be passed as a list such that each element of the list can be correctly\n                batched for input to a single controlnet.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list. Note that by default, this img2img pipeline uses a smaller conditioning\n                scale than [`~StableDiffusionControlNetPipeline.__call__`].\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if\n                you remove all prompts. 
A `guidance_scale` between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The fraction of total steps (between 0.0 and 1.0) at which the controlnet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The fraction of total steps (between 0.0 and 1.0) at which the controlnet stops applying.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images. The second element is\n            always `None`, because this pipeline does not run a `safety_checker`.\n        \"\"\"\n        if fp16:\n            torch_dtype = torch.float16\n            np_dtype = np.float16\n        else:\n            torch_dtype = torch.float32\n            np_dtype = np.float32\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = num_controlnet\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            num_controlnet,\n            prompt,\n            control_image,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        if num_controlnet > 1 and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnet\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n        # 4. Prepare image\n        image = self.image_processor.preprocess(image).to(dtype=torch.float32)\n\n        # 5. 
Prepare controlnet_conditioning_image\n        if num_controlnet == 1:\n            control_image = self.prepare_control_image(\n                image=control_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=torch_dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n        elif num_controlnet > 1:\n            control_images = []\n\n            for control_image_ in control_image:\n                control_image_ = self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=torch_dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            control_image = control_images\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. Prepare latent variables\n        latents = self.prepare_latents(\n            image,\n            latent_timestep,\n            batch_size,\n            num_images_per_prompt,\n            torch_dtype,\n            device,\n            generator,\n        )\n\n        # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if num_controlnet == 1 else keeps)\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                # predict the noise residual\n                _latent_model_input = latent_model_input.cpu().detach().numpy()\n                _prompt_embeds = np.array(prompt_embeds, dtype=np_dtype)\n                _t = np.array([t.cpu().detach().numpy()], dtype=np_dtype)\n\n                if num_controlnet == 1:\n                    control_images = np.array([control_image], dtype=np_dtype)\n                else:\n               
     control_images = []\n                    for _control_img in control_image:\n                        _control_img = _control_img.cpu().detach().numpy()\n                        control_images.append(_control_img)\n                    control_images = np.array(control_images, dtype=np_dtype)\n\n                control_scales = np.array(cond_scale, dtype=np_dtype)\n                control_scales = np.resize(control_scales, (num_controlnet, 1))\n\n                noise_pred = self.unet(\n                    sample=_latent_model_input,\n                    timestep=_t,\n                    encoder_hidden_states=_prompt_embeds,\n                    controlnet_conds=control_images,\n                    conditioning_scales=control_scales,\n                )[0]\n                noise_pred = torch.from_numpy(noise_pred).to(device)\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            _latents = latents.cpu().detach().numpy() / 0.18215\n            _latents = np.array(_latents, dtype=np_dtype)\n            image = self.vae_decoder(latent_sample=_latents)[0]\n            image = 
torch.from_numpy(image).to(device, dtype=torch.float32)\n            has_nsfw_concept = None\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--sd_model\",\n        type=str,\n        required=True,\n        help=\"Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).\",\n    )\n\n    parser.add_argument(\n        \"--onnx_model_dir\",\n        type=str,\n        required=True,\n        help=\"Path to the ONNX directory\",\n    )\n\n    parser.add_argument(\"--qr_img_path\", type=str, required=True, help=\"Path to the qr code image\")\n\n    args = parser.parse_args()\n\n    qr_image = Image.open(args.qr_img_path)\n    qr_image = qr_image.resize((512, 512))\n\n    # init stable diffusion pipeline\n    pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(args.sd_model)\n    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n\n    provider = [\"CUDAExecutionProvider\", \"CPUExecutionProvider\"]\n    onnx_pipeline = OnnxStableDiffusionControlNetImg2ImgPipeline(\n        vae_encoder=OnnxRuntimeModel.from_pretrained(\n            os.path.join(args.onnx_model_dir, \"vae_encoder\"), provider=provider\n        ),\n        vae_decoder=OnnxRuntimeModel.from_pretrained(\n            os.path.join(args.onnx_model_dir, \"vae_decoder\"), provider=provider\n        ),\n        
text_encoder=OnnxRuntimeModel.from_pretrained(\n            os.path.join(args.onnx_model_dir, \"text_encoder\"), provider=provider\n        ),\n        tokenizer=pipeline.tokenizer,\n        unet=OnnxRuntimeModel.from_pretrained(os.path.join(args.onnx_model_dir, \"unet\"), provider=provider),\n        scheduler=pipeline.scheduler,\n    )\n    onnx_pipeline = onnx_pipeline.to(\"cuda\")\n\n    prompt = \"a cute cat fly to the moon\"\n    negative_prompt = \"paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect\"\n\n    for i in range(10):\n        start_time = time.time()\n        image = onnx_pipeline(\n            num_controlnet=2,\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            image=qr_image,\n            control_image=[qr_image, qr_image],\n            width=512,\n            height=512,\n            strength=0.75,\n            num_inference_steps=20,\n            num_images_per_prompt=1,\n            controlnet_conditioning_scale=[0.8, 0.8],\n            control_guidance_start=[0.3, 0.3],\n            control_guidance_end=[0.9, 0.9],\n        ).images[0]\n        print(time.time() - start_time)\n        
image.save(\"output_qr_code.png\")\n"
  },
  {
    "path": "examples/community/run_tensorrt_controlnet.py",
    "content": "import argparse\nimport atexit\nimport inspect\nimport os\nimport time\nimport warnings\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport pycuda.driver as cuda\nimport tensorrt as trt\nimport torch\nfrom PIL import Image\nfrom pycuda.tools import make_default_context\nfrom transformers import CLIPTokenizer\n\nfrom diffusers import OnnxRuntimeModel, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\n# Initialize CUDA\ncuda.init()\ncontext = make_default_context()\ndevice = context.get_device()\natexit.register(context.pop)\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef load_engine(trt_runtime, engine_path):\n    with open(engine_path, \"rb\") as f:\n        engine_data = f.read()\n    engine = trt_runtime.deserialize_cuda_engine(engine_data)\n    return engine\n\n\nclass TensorRTModel:\n    def __init__(\n        self,\n        trt_engine_path,\n        **kwargs,\n    ):\n        cuda.init()\n        stream = cuda.Stream()\n        TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)\n        trt.init_libnvinfer_plugins(TRT_LOGGER, \"\")\n        trt_runtime = trt.Runtime(TRT_LOGGER)\n        engine = load_engine(trt_runtime, trt_engine_path)\n        context = engine.create_execution_context()\n\n        # allocates memory for network inputs/outputs on both CPU and GPU\n        host_inputs = []\n        cuda_inputs = []\n        host_outputs = []\n        cuda_outputs = []\n        bindings = []\n        input_names = []\n        output_names = 
[]\n\n        for binding in engine:\n            datatype = engine.get_binding_dtype(binding)\n            if datatype == trt.DataType.HALF:\n                dtype = np.float16\n            else:\n                dtype = np.float32\n\n            shape = tuple(engine.get_binding_shape(binding))\n            host_mem = cuda.pagelocked_empty(shape, dtype)\n            cuda_mem = cuda.mem_alloc(host_mem.nbytes)\n            bindings.append(int(cuda_mem))\n\n            if engine.binding_is_input(binding):\n                host_inputs.append(host_mem)\n                cuda_inputs.append(cuda_mem)\n                input_names.append(binding)\n            else:\n                host_outputs.append(host_mem)\n                cuda_outputs.append(cuda_mem)\n                output_names.append(binding)\n\n        self.stream = stream\n        self.context = context\n        self.engine = engine\n\n        self.host_inputs = host_inputs\n        self.cuda_inputs = cuda_inputs\n        self.host_outputs = host_outputs\n        self.cuda_outputs = cuda_outputs\n        self.bindings = bindings\n        self.batch_size = engine.max_batch_size\n\n        self.input_names = input_names\n        self.output_names = output_names\n\n    def __call__(self, **kwargs):\n        context = self.context\n        stream = self.stream\n        bindings = self.bindings\n\n        host_inputs = self.host_inputs\n        cuda_inputs = self.cuda_inputs\n        host_outputs = self.host_outputs\n        cuda_outputs = self.cuda_outputs\n\n        for idx, input_name in enumerate(self.input_names):\n            _input = kwargs[input_name]\n            np.copyto(host_inputs[idx], _input)\n            # transfer input data to the GPU\n            cuda.memcpy_htod_async(cuda_inputs[idx], host_inputs[idx], stream)\n\n        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)\n\n        result = {}\n        for idx, output_name in enumerate(self.output_names):\n            # 
transfer predictions back from the GPU\n            cuda.memcpy_dtoh_async(host_outputs[idx], cuda_outputs[idx], stream)\n            result[output_name] = host_outputs[idx]\n\n        stream.synchronize()\n\n        return result\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install opencv-python transformers accelerate\n        >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n        >>> import numpy as np\n        >>> import torch\n\n        >>> import cv2\n        >>> from PIL import Image\n\n        >>> # download an image\n        >>> image = load_image(\n        ...     \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\"\n        ... )\n        >>> np_image = np.array(image)\n\n        >>> # get canny image\n        >>> np_image = cv2.Canny(np_image, 100, 200)\n        >>> np_image = np_image[:, :, None]\n        >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)\n        >>> canny_image = Image.fromarray(np_image)\n\n        >>> # load control net and stable diffusion v1-5\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\n        >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(\n        ...     \"stable-diffusion-v1-5/stable-diffusion-v1-5\", controlnet=controlnet, torch_dtype=torch.float16\n        ... )\n\n        >>> # speed up diffusion process with faster scheduler and memory optimization\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> # generate image\n        >>> generator = torch.manual_seed(0)\n        >>> image = pipe(\n        ...     \"futuristic-looking woman\",\n        ...     num_inference_steps=20,\n        ...     
generator=generator,\n        ...     image=image,\n        ...     control_image=canny_image,\n        ... ).images[0]\n        ```\n\"\"\"\n\n\ndef prepare_image(image):\n    if isinstance(image, torch.Tensor):\n        # Batch single image\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        image = image.to(dtype=torch.float32)\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    return image\n\n\nclass TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):\n    vae_encoder: OnnxRuntimeModel\n    vae_decoder: OnnxRuntimeModel\n    text_encoder: OnnxRuntimeModel\n    tokenizer: CLIPTokenizer\n    unet: TensorRTModel\n    scheduler: KarrasDiffusionSchedulers\n\n    def __init__(\n        self,\n        vae_encoder: OnnxRuntimeModel,\n        vae_decoder: OnnxRuntimeModel,\n        text_encoder: OnnxRuntimeModel,\n        tokenizer: CLIPTokenizer,\n        unet: TensorRTModel,\n        scheduler: KarrasDiffusionSchedulers,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae_encoder=vae_encoder,\n            vae_decoder=vae_decoder,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = 2 ** (4 - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n       
 self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n\n    def _encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: Optional[int],\n        do_classifier_free_guidance: bool,\n        negative_prompt: str | None,\n        prompt_embeds: Optional[np.ndarray] = None,\n        negative_prompt_embeds: Optional[np.ndarray] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                prompt to be encoded\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            prompt_embeds (`np.ndarray`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`np.ndarray`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # get prompt text embeddings\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"np\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"max_length\", return_tensors=\"np\").input_ids\n\n            if not np.array_equal(text_input_ids, untruncated_ids):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]\n\n        prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to 
`prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt] * batch_size\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"np\",\n            )\n            negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]\n\n        if do_classifier_free_guidance:\n            negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        warnings.warn(\n            \"The decode_latents method is deprecated and will be removed in a future version. 
Please\"\n            \" use VaeImageProcessor instead\",\n            FutureWarning,\n        )\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        num_controlnet,\n        prompt,\n        image,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n    ):\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            
raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # Check `image`\n        if num_controlnet == 1:\n            self.check_image(image, prompt, prompt_embeds)\n        elif num_controlnet > 1:\n            if not isinstance(image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in image):\n                raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif len(image) != num_controlnet:\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {num_controlnet} ControlNets.\"\n                )\n\n            for image_ in image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if num_controlnet == 1:\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif num_controlnet > 1:\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in 
controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n                # the length check must sit inside the `list` branch, otherwise it can never be reached\n                elif len(controlnet_conditioning_scale) != num_controlnet:\n                    raise ValueError(\n                        \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                        \" the same length as the number of controlnets\"\n                    )\n        else:\n            assert False\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if num_controlnet > 1:\n            if len(control_guidance_start) != num_controlnet:\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_controlnet} controlnets available. 
Make sure to provide {num_controlnet}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size 
= len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image\n    def prepare_control_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(self, image, 
timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            _image = image.cpu().detach().numpy()\n            init_latents = self.vae_encoder(sample=_image)[0]\n            init_latents = torch.from_numpy(init_latents).to(device=device, dtype=dtype)\n            init_latents = 0.18215 * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            deprecation_message = (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. 
Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        num_controlnet: int,\n        fp16: bool = True,\n        prompt: Union[str, List[str]] = None,\n        image: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        control_image: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        
num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The initial image to be used as the starting point for the image generation process. Can also accept\n                image latents as `image`; latents passed directly are not encoded again.\n            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. 
If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can\n                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If\n                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are\n                specified in init, images must be passed as a list such that each element of the list can be correctly\n                batched for input to a single controlnet.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list. Note that by default, this pipeline uses a smaller conditioning scale\n                than [`~StableDiffusionControlNetPipeline.__call__`].\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if\n                you remove all prompts. 
A `guidance_scale` between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the controlnet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the controlnet stops applying.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        if fp16:\n            torch_dtype = torch.float16\n            np_dtype = np.float16\n        else:\n            torch_dtype = torch.float32\n            np_dtype = np.float32\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = num_controlnet\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            num_controlnet,\n            prompt,\n            control_image,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        if num_controlnet > 1 and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnet\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n        # 4. Prepare image\n        image = self.image_processor.preprocess(image).to(dtype=torch.float32)\n\n        # 5. 
Prepare controlnet_conditioning_image\n        if num_controlnet == 1:\n            control_image = self.prepare_control_image(\n                image=control_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=torch_dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n        elif num_controlnet > 1:\n            control_images = []\n\n            for control_image_ in control_image:\n                control_image_ = self.prepare_control_image(\n                    image=control_image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=torch_dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                control_images.append(control_image_)\n\n            control_image = control_images\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. Prepare latent variables\n        latents = self.prepare_latents(\n            image,\n            latent_timestep,\n            batch_size,\n            num_images_per_prompt,\n            torch_dtype,\n            device,\n            generator,\n        )\n\n        # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if num_controlnet == 1 else keeps)\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                # predict the noise residual\n                _latent_model_input = latent_model_input.cpu().detach().numpy()\n                _prompt_embeds = np.array(prompt_embeds, dtype=np_dtype)\n                _t = np.array([t.cpu().detach().numpy()], dtype=np_dtype)\n\n                if num_controlnet == 1:\n                    control_images = np.array([control_image], dtype=np_dtype)\n                else:\n               
     control_images = []\n                    for _control_img in control_image:\n                        _control_img = _control_img.cpu().detach().numpy()\n                        control_images.append(_control_img)\n                    control_images = np.array(control_images, dtype=np_dtype)\n\n                control_scales = np.array(cond_scale, dtype=np_dtype)\n                control_scales = np.resize(control_scales, (num_controlnet, 1))\n\n                noise_pred = self.unet(\n                    sample=_latent_model_input,\n                    timestep=_t,\n                    encoder_hidden_states=_prompt_embeds,\n                    controlnet_conds=control_images,\n                    conditioning_scales=control_scales,\n                )[\"noise_pred\"]\n                noise_pred = torch.from_numpy(noise_pred).to(device)\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            _latents = latents.cpu().detach().numpy() / 0.18215\n            _latents = np.array(_latents, dtype=np_dtype)\n            image = self.vae_decoder(latent_sample=_latents)[0]\n            image = 
torch.from_numpy(image).to(device, dtype=torch.float32)\n            has_nsfw_concept = None\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--sd_model\",\n        type=str,\n        required=True,\n        help=\"Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).\",\n    )\n\n    parser.add_argument(\n        \"--onnx_model_dir\",\n        type=str,\n        required=True,\n        help=\"Path to the ONNX directory\",\n    )\n\n    parser.add_argument(\n        \"--unet_engine_path\",\n        type=str,\n        required=True,\n        help=\"Path to the unet + controlnet tensorrt model\",\n    )\n\n    parser.add_argument(\"--qr_img_path\", type=str, required=True, help=\"Path to the qr code image\")\n\n    args = parser.parse_args()\n\n    qr_image = Image.open(args.qr_img_path)\n    qr_image = qr_image.resize((512, 512))\n\n    # init stable diffusion pipeline\n    pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(args.sd_model)\n    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n\n    provider = [\"CUDAExecutionProvider\", \"CPUExecutionProvider\"]\n    onnx_pipeline = TensorRTStableDiffusionControlNetImg2ImgPipeline(\n        vae_encoder=OnnxRuntimeModel.from_pretrained(\n            os.path.join(args.onnx_model_dir, \"vae_encoder\"), provider=provider\n        
),\n        vae_decoder=OnnxRuntimeModel.from_pretrained(\n            os.path.join(args.onnx_model_dir, \"vae_decoder\"), provider=provider\n        ),\n        text_encoder=OnnxRuntimeModel.from_pretrained(\n            os.path.join(args.onnx_model_dir, \"text_encoder\"), provider=provider\n        ),\n        tokenizer=pipeline.tokenizer,\n        unet=TensorRTModel(args.unet_engine_path),\n        scheduler=pipeline.scheduler,\n    )\n    onnx_pipeline = onnx_pipeline.to(\"cuda\")\n\n    prompt = \"a cute cat fly to the moon\"\n    negative_prompt = \"paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect\"\n\n    for i in range(10):\n        start_time = time.time()\n        image = onnx_pipeline(\n            num_controlnet=2,\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            image=qr_image,\n            control_image=[qr_image, qr_image],\n            width=512,\n            height=512,\n            strength=0.75,\n            num_inference_steps=20,\n            num_images_per_prompt=1,\n            controlnet_conditioning_scale=[0.8, 0.8],\n            control_guidance_start=[0.3, 0.3],\n            control_guidance_end=[0.9, 0.9],\n   
     ).images[0]\n        print(time.time() - start_time)\n        image.save(\"output_qr_code.png\")\n"
  },
  {
    "path": "examples/community/scheduling_ufogen.py",
    "content": "# Copyright 2025 UC Berkeley Team and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.schedulers.scheduling_utils import SchedulerMixin\nfrom diffusers.utils import BaseOutput\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\n@dataclass\n# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UFOGen\nclass UFOGenSchedulerOutput(BaseOutput):\n    \"\"\"\n    Output class for the scheduler's `step` function output.\n\n    Args:\n        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):\n            Computed sample `(x_{t-1})` of previous timestep. 
`prev_sample` should be used as next model input in the\n            denoising loop.\n        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):\n            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.\n            `pred_original_sample` can be used to preview progress or for guidance.\n    \"\"\"\n\n    prev_sample: torch.Tensor\n    pred_original_sample: Optional[torch.Tensor] = None\n\n\n# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar\ndef betas_for_alpha_bar(\n    num_diffusion_timesteps,\n    max_beta=0.999,\n    alpha_transform_type=\"cosine\",\n):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of\n    (1-beta) over time from t = [0,1].\n\n    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up\n    to that part of the diffusion process.\n\n\n    Args:\n        num_diffusion_timesteps (`int`): the number of betas to produce.\n        max_beta (`float`): the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.\n                     Choose from `cosine` or `exp`\n\n    Returns:\n        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs\n    \"\"\"\n    if alpha_transform_type == \"cosine\":\n\n        def alpha_bar_fn(t):\n            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2\n\n    elif alpha_transform_type == \"exp\":\n\n        def alpha_bar_fn(t):\n            return math.exp(t * -12.0)\n\n    else:\n        raise ValueError(f\"Unsupported alpha_transform_type: {alpha_transform_type}\")\n\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 
= (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))\n    return torch.tensor(betas, dtype=torch.float32)\n\n\n# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr\ndef rescale_zero_terminal_snr(betas):\n    \"\"\"\n    Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)\n\n\n    Args:\n        betas (`torch.Tensor`):\n            the betas that the scheduler is being initialized with.\n\n    Returns:\n        `torch.Tensor`: rescaled betas with zero terminal SNR\n    \"\"\"\n    # Convert betas to alphas_bar_sqrt\n    alphas = 1.0 - betas\n    alphas_cumprod = torch.cumprod(alphas, dim=0)\n    alphas_bar_sqrt = alphas_cumprod.sqrt()\n\n    # Store old values.\n    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()\n    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()\n\n    # Shift so the last timestep is zero.\n    alphas_bar_sqrt -= alphas_bar_sqrt_T\n\n    # Scale so the first timestep is back to the old value.\n    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)\n\n    # Convert alphas_bar_sqrt to betas\n    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt\n    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod\n    alphas = torch.cat([alphas_bar[0:1], alphas])\n    betas = 1 - alphas\n\n    return betas\n\n\nclass UFOGenScheduler(SchedulerMixin, ConfigMixin):\n    \"\"\"\n    `UFOGenScheduler` implements multistep and onestep sampling for a UFOGen model, introduced in\n    [UFOGen: You Forward Once Large Scale Text-to-Image Generation via Diffusion GANs](https://huggingface.co/papers/2311.09257)\n    by Yanwu Xu, Yang Zhao, Zhisheng Xiao, and Tingbo Hou. UFOGen is a variant of the denoising diffusion GAN (DDGAN)\n    model designed for one-step sampling.\n\n    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. 
Check the superclass documentation for the generic\n    methods the library implements for all schedulers such as loading and saving.\n\n    Args:\n        num_train_timesteps (`int`, defaults to 1000):\n            The number of diffusion steps to train the model.\n        beta_start (`float`, defaults to 0.0001):\n            The starting `beta` value of inference.\n        beta_end (`float`, defaults to 0.02):\n            The final `beta` value.\n        beta_schedule (`str`, defaults to `\"linear\"`):\n            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from\n            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.\n        clip_sample (`bool`, defaults to `True`):\n            Clip the predicted sample for numerical stability.\n        clip_sample_range (`float`, defaults to 1.0):\n            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.\n        set_alpha_to_one (`bool`, defaults to `True`):\n            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step\n            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,\n            otherwise it uses the alpha value at step 0.\n        prediction_type (`str`, defaults to `epsilon`, *optional*):\n            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),\n            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen\n            Video](https://imagen.research.google/video/paper.pdf) paper).\n        thresholding (`bool`, defaults to `False`):\n            Whether to use the \"dynamic thresholding\" method. 
This is unsuitable for latent-space diffusion models such\n            as Stable Diffusion.\n        dynamic_thresholding_ratio (`float`, defaults to 0.995):\n            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.\n        sample_max_value (`float`, defaults to 1.0):\n            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.\n        timestep_spacing (`str`, defaults to `\"leading\"`):\n            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and\n            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.\n        steps_offset (`int`, defaults to 0):\n            An offset added to the inference steps, as required by some model families.\n        rescale_betas_zero_snr (`bool`, defaults to `False`):\n            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and\n            dark samples instead of limiting it to samples with medium brightness. Loosely related to\n            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).\n        denoising_step_size (`int`, defaults to 250):\n            The denoising step size parameter from the UFOGen paper. 
The number of steps used for training is roughly\n            `math.ceil(num_train_timesteps / denoising_step_size)`.\n    \"\"\"\n\n    order = 1\n\n    @register_to_config\n    def __init__(\n        self,\n        num_train_timesteps: int = 1000,\n        beta_start: float = 0.0001,\n        beta_end: float = 0.02,\n        beta_schedule: str = \"linear\",\n        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,\n        clip_sample: bool = True,\n        set_alpha_to_one: bool = True,\n        prediction_type: str = \"epsilon\",\n        thresholding: bool = False,\n        dynamic_thresholding_ratio: float = 0.995,\n        clip_sample_range: float = 1.0,\n        sample_max_value: float = 1.0,\n        timestep_spacing: str = \"leading\",\n        steps_offset: int = 0,\n        rescale_betas_zero_snr: bool = False,\n        denoising_step_size: int = 250,\n    ):\n        if trained_betas is not None:\n            self.betas = torch.tensor(trained_betas, dtype=torch.float32)\n        elif beta_schedule == \"linear\":\n            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)\n        elif beta_schedule == \"scaled_linear\":\n            # this schedule is very specific to the latent diffusion model.\n            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2\n        elif beta_schedule == \"squaredcos_cap_v2\":\n            # Glide cosine schedule\n            self.betas = betas_for_alpha_bar(num_train_timesteps)\n        elif beta_schedule == \"sigmoid\":\n            # GeoDiff sigmoid schedule\n            betas = torch.linspace(-6, 6, num_train_timesteps)\n            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start\n        else:\n            raise NotImplementedError(f\"{beta_schedule} is not implemented for {self.__class__}\")\n\n        # Rescale for zero SNR\n        if rescale_betas_zero_snr:\n            
self.betas = rescale_zero_terminal_snr(self.betas)\n\n        self.alphas = 1.0 - self.betas\n        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)\n\n        # For the final step, there is no previous alphas_cumprod because we are already at 0\n        # `set_alpha_to_one` decides whether we set this parameter simply to one or\n        # whether we use the final alpha of the \"non-previous\" one.\n        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]\n\n        # standard deviation of the initial noise distribution\n        self.init_noise_sigma = 1.0\n\n        # setable values\n        self.custom_timesteps = False\n        self.num_inference_steps = None\n        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())\n\n    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:\n        \"\"\"\n        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the\n        current timestep.\n\n        Args:\n            sample (`torch.Tensor`):\n                The input sample.\n            timestep (`int`, *optional*):\n                The current timestep in the diffusion chain.\n\n        Returns:\n            `torch.Tensor`:\n                A scaled input sample.\n        \"\"\"\n        return sample\n\n    def set_timesteps(\n        self,\n        num_inference_steps: Optional[int] = None,\n        device: Union[str, torch.device] = None,\n        timesteps: Optional[List[int]] = None,\n    ):\n        \"\"\"\n        Sets the discrete timesteps used for the diffusion chain (to be run before inference).\n\n        Args:\n            num_inference_steps (`int`):\n                The number of diffusion steps used when generating samples with a pre-trained model. 
If used,\n                `timesteps` must be `None`.\n            device (`str` or `torch.device`, *optional*):\n                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,\n                `num_inference_steps` must be `None`.\n\n        \"\"\"\n        if num_inference_steps is not None and timesteps is not None:\n            raise ValueError(\"Can only pass one of `num_inference_steps` or `custom_timesteps`.\")\n\n        if timesteps is not None:\n            for i in range(1, len(timesteps)):\n                if timesteps[i] >= timesteps[i - 1]:\n                    raise ValueError(\"`custom_timesteps` must be in descending order.\")\n\n            if timesteps[0] >= self.config.num_train_timesteps:\n                raise ValueError(\n                    f\"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}.\"\n                )\n\n            timesteps = np.array(timesteps, dtype=np.int64)\n            self.custom_timesteps = True\n        else:\n            if num_inference_steps > self.config.num_train_timesteps:\n                raise ValueError(\n                    f\"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:\"\n                    f\" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle\"\n                    f\" maximal {self.config.num_train_timesteps} timesteps.\"\n                )\n\n            self.num_inference_steps = num_inference_steps\n            self.custom_timesteps = False\n\n            # TODO: For now, handle special case when num_inference_steps == 1 separately\n         
   if num_inference_steps == 1:\n                # Set the timestep schedule to num_train_timesteps - 1 rather than 0\n                # (that is, the one-step timestep schedule is always trailing rather than leading or linspace)\n                timesteps = np.array([self.config.num_train_timesteps - 1], dtype=np.int64)\n            else:\n                # TODO: For now, retain the DDPM timestep spacing logic\n                # \"linspace\", \"leading\", \"trailing\" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891\n                if self.config.timestep_spacing == \"linspace\":\n                    timesteps = (\n                        np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)\n                        .round()[::-1]\n                        .copy()\n                        .astype(np.int64)\n                    )\n                elif self.config.timestep_spacing == \"leading\":\n                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps\n                    # creates integer timesteps by multiplying by ratio\n                    # casting to int to avoid issues when num_inference_step is power of 3\n                    timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)\n                    timesteps += self.config.steps_offset\n                elif self.config.timestep_spacing == \"trailing\":\n                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps\n                    # creates integer timesteps by multiplying by ratio\n                    # casting to int to avoid issues when num_inference_step is power of 3\n                    timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)\n                    timesteps -= 1\n                else:\n                    raise ValueError(\n                        
f\"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'.\"\n                    )\n\n        self.timesteps = torch.from_numpy(timesteps).to(device)\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample\n    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        \"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the\n        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by\n        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing\n        pixels from saturation at each step. We find that dynamic thresholding results in significantly better\n        photorealism as well as better image-text alignment, especially when using very large guidance weights.\"\n\n        https://huggingface.co/papers/2205.11487\n        \"\"\"\n        dtype = sample.dtype\n        batch_size, channels, *remaining_dims = sample.shape\n\n        if dtype not in (torch.float32, torch.float64):\n            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half\n\n        # Flatten sample for doing quantile calculation along each image\n        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))\n\n        abs_sample = sample.abs()  # \"a certain percentile absolute pixel value\"\n\n        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)\n        s = torch.clamp(\n            s, min=1, max=self.config.sample_max_value\n        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]\n        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0\n        sample = torch.clamp(sample, -s, s) / s  # \"we threshold xt0 to the range [-s, s] and then divide by 
s\"\n\n        sample = sample.reshape(batch_size, channels, *remaining_dims)\n        sample = sample.to(dtype)\n\n        return sample\n\n    def step(\n        self,\n        model_output: torch.Tensor,\n        timestep: int,\n        sample: torch.Tensor,\n        generator: torch.Generator | None = None,\n        return_dict: bool = True,\n    ) -> Union[UFOGenSchedulerOutput, Tuple]:\n        \"\"\"\n        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion\n        process from the learned model outputs (most often the predicted noise).\n\n        Args:\n            model_output (`torch.Tensor`):\n                The direct output from learned diffusion model.\n            timestep (`int`):\n                The current discrete timestep in the diffusion chain.\n            sample (`torch.Tensor`):\n                A current instance of a sample created by the diffusion process.\n            generator (`torch.Generator`, *optional*):\n                A random number generator.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`.\n\n        Returns:\n            [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`:\n                If return_dict is `True`, [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] is returned, otherwise a\n                tuple is returned where the first element is the sample tensor.\n\n        \"\"\"\n        # 0. Resolve timesteps\n        t = timestep\n        prev_t = self.previous_timestep(t)\n\n        # 1. 
compute alphas, betas\n        alpha_prod_t = self.alphas_cumprod[t]\n        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.final_alpha_cumprod\n        beta_prod_t = 1 - alpha_prod_t\n        # beta_prod_t_prev = 1 - alpha_prod_t_prev\n        # current_alpha_t = alpha_prod_t / alpha_prod_t_prev\n        # current_beta_t = 1 - current_alpha_t\n\n        # 2. compute predicted original sample from predicted noise also called\n        # \"predicted x_0\" of formula (15) from https://huggingface.co/papers/2006.11239\n        if self.config.prediction_type == \"epsilon\":\n            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n        elif self.config.prediction_type == \"sample\":\n            pred_original_sample = model_output\n        elif self.config.prediction_type == \"v_prediction\":\n            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n        else:\n            raise ValueError(\n                f\"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or\"\n                \" `v_prediction`  for UFOGenScheduler.\"\n            )\n\n        # 3. Clip or threshold \"predicted x_0\"\n        if self.config.thresholding:\n            pred_original_sample = self._threshold_sample(pred_original_sample)\n        elif self.config.clip_sample:\n            pred_original_sample = pred_original_sample.clamp(\n                -self.config.clip_sample_range, self.config.clip_sample_range\n            )\n\n        # 4. 
Single-step or multi-step sampling\n        # Noise is not used on the final timestep of the timestep schedule.\n        # This also means that noise is not used for one-step sampling.\n        if t != self.timesteps[-1]:\n            # TODO: is this correct?\n            # Sample prev sample x_{t - 1} ~ q(x_{t - 1} | x_0 =  G(x_t, t))\n            device = model_output.device\n            noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)\n            sqrt_alpha_prod_t_prev = alpha_prod_t_prev**0.5\n            sqrt_one_minus_alpha_prod_t_prev = (1 - alpha_prod_t_prev) ** 0.5\n            pred_prev_sample = sqrt_alpha_prod_t_prev * pred_original_sample + sqrt_one_minus_alpha_prod_t_prev * noise\n        else:\n            # Simply return the pred_original_sample. If `prediction_type == \"sample\"`, this is equivalent to returning\n            # the output of the GAN generator U-Net on the initial noisy latents x_T ~ N(0, I).\n            pred_prev_sample = pred_original_sample\n\n        if not return_dict:\n            return (pred_prev_sample,)\n\n        return UFOGenSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise\n    def add_noise(\n        self,\n        original_samples: torch.Tensor,\n        noise: torch.Tensor,\n        timesteps: torch.IntTensor,\n    ) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples\n        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)\n        timesteps = timesteps.to(original_samples.device)\n\n        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n        sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):\n            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n    
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):\n            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise\n        return noisy_samples\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity\n    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:\n        # Make sure alphas_cumprod and timestep have same device and dtype as sample\n        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)\n        timesteps = timesteps.to(sample.device)\n\n        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n        sqrt_alpha_prod = sqrt_alpha_prod.flatten()\n        while len(sqrt_alpha_prod.shape) < len(sample.shape):\n            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)\n\n        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()\n        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):\n            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)\n\n        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample\n        return velocity\n\n    def __len__(self):\n        return self.config.num_train_timesteps\n\n    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep\n    def previous_timestep(self, timestep):\n        if self.custom_timesteps:\n            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]\n            if index == self.timesteps.shape[0] - 1:\n                prev_t = torch.tensor(-1)\n            else:\n                prev_t = self.timesteps[index 
+ 1]\n        else:\n            num_inference_steps = (\n                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps\n            )\n            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps\n\n        return prev_t\n"
  },
  {
    "path": "examples/community/sd_text2img_k_diffusion.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport warnings\nfrom typing import Callable, List, Optional, Union\n\nimport torch\nfrom k_diffusion.external import CompVisDenoiser, CompVisVDenoiser\n\nfrom diffusers import DiffusionPipeline, LMSDiscreteScheduler, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.utils import logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass ModelWrapper:\n    def __init__(self, model, alphas_cumprod):\n        self.model = model\n        self.alphas_cumprod = alphas_cumprod\n\n    def apply_model(self, *args, **kwargs):\n        if len(args) == 3:\n            encoder_hidden_states = args[-1]\n            args = args[:2]\n        if kwargs.get(\"cond\", None) is not None:\n            encoder_hidden_states = kwargs.pop(\"cond\")\n        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample\n\n\nclass StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae,\n        text_encoder,\n        tokenizer,\n        unet,\n        scheduler,\n        safety_checker,\n        feature_extractor,\n    ):\n        super().__init__()\n\n        if safety_checker is None:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. 
For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        # get correct sigmas from LMS\n        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n        model = ModelWrapper(unet, scheduler.alphas_cumprod)\n        if scheduler.config.prediction_type == \"v_prediction\":\n            self.k_diffusion_model = CompVisVDenoiser(model)\n        else:\n            self.k_diffusion_model = CompVisDenoiser(model)\n\n    def set_sampler(self, scheduler_type: str):\n        warnings.warn(\"The `set_sampler` method is deprecated, please use `set_scheduler` instead.\")\n        return self.set_scheduler(scheduler_type)\n\n    def set_scheduler(self, scheduler_type: str):\n        library = importlib.import_module(\"k_diffusion\")\n        sampling = getattr(library, \"sampling\")\n        self.sampler = getattr(sampling, scheduler_type)\n\n    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `list(int)`):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`):\n                The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n        \"\"\"\n        batch_size = len(prompt) if isinstance(prompt, list) else 1\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"max_length\", return_tensors=\"pt\").input_ids\n\n        if not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n\n        if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n            attention_mask = text_inputs.attention_mask.to(device)\n        else:\n            attention_mask = None\n\n        text_embeddings = self.text_encoder(\n            text_input_ids.to(device),\n            attention_mask=attention_mask,\n        )\n        text_embeddings = text_embeddings[0]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n       
     elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            uncond_embeddings = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            uncond_embeddings = uncond_embeddings[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier 
free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        return text_embeddings\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def check_inputs(self, prompt, height, width, callback_steps):\n        if not isinstance(prompt, str) and not isinstance(prompt, list):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, 
width, dtype, device, generator, latents=None):\n        shape = (batch_size, num_channels_latents, height // 8, width // 8)\n        if latents is None:\n            if device.type == \"mps\":\n                # randn does not work reproducibly on mps\n                latents = torch.randn(shape, generator=generator, device=\"cpu\", dtype=dtype).to(device)\n            else:\n                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            if latents.shape != shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {shape}\")\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 
50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(prompt, height, width, callback_steps)\n\n        # 2. 
Define call parameters\n        batch_size = 1 if isinstance(prompt, str) else len(prompt)\n        device = self._execution_device\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance, which this pipeline does not support.\n        do_classifier_free_guidance = True\n        if guidance_scale <= 1.0:\n            raise ValueError(f\"`guidance_scale` has to be greater than 1 but is {guidance_scale}.\")\n\n        # 3. Encode input prompt\n        text_embeddings = self._encode_prompt(\n            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)\n        sigmas = self.scheduler.sigmas\n        sigmas = sigmas.to(text_embeddings.dtype)\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            text_embeddings.dtype,\n            device,\n            generator,\n            latents,\n        )\n        latents = latents * sigmas[0]\n        self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)\n        self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)\n\n        def model_fn(x, t):\n            latent_model_input = torch.cat([x] * 2)\n\n            noise_pred = self.k_diffusion_model(latent_model_input, t, cond=text_embeddings)\n\n            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n            return noise_pred\n\n        latents = self.sampler(model_fn, latents, sigmas)\n\n        # 8. 
Post-processing\n        image = self.decode_latents(latents)\n\n        # 9. Run safety checker\n        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)\n\n        # 10. Convert to PIL\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/sde_drag.py",
    "content": "import math\nimport tempfile\nfrom typing import List, Optional\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom accelerate import Accelerator\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel\nfrom diffusers.loaders import AttnProcsLayers, StableDiffusionLoraLoaderMixin\nfrom diffusers.models.attention_processor import (\n    AttnAddedKVProcessor,\n    AttnAddedKVProcessor2_0,\n    LoRAAttnAddedKVProcessor,\n    LoRAAttnProcessor,\n    LoRAAttnProcessor2_0,\n    SlicedAttnAddedKVProcessor,\n)\nfrom diffusers.optimization import get_scheduler\n\n\nclass SdeDragPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for image drag-and-drop editing using stochastic differential equations: https://huggingface.co/papers/2311.01410.\n    Please refer to the [official repository](https://github.com/ML-GSAI/SDE-Drag) for more information.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Please use\n            [`DDIMScheduler`].\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: DPMSolverMultistepScheduler,\n    ):\n        super().__init__()\n\n        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: str,\n        image: PIL.Image.Image,\n        mask_image: PIL.Image.Image,\n        source_points: List[List[int]],\n        target_points: List[List[int]],\n        t0: Optional[float] = 0.6,\n        steps: Optional[int] = 200,\n        step_size: Optional[int] = 2,\n        image_scale: Optional[float] = 0.3,\n        adapt_radius: Optional[int] = 5,\n        min_lora_scale: Optional[float] = 0.5,\n        generator: torch.Generator | None = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for image editing.\n        Args:\n            prompt (`str`, *required*):\n                The prompt to guide the image editing.\n            image (`PIL.Image.Image`, *required*):\n                Which will be edited, parts of the image 
will be masked out with `mask_image` and edited\n                according to `prompt`.\n            mask_image (`PIL.Image.Image`, *required*):\n                To mask `image`. White pixels in the mask will be edited, while black pixels will be preserved.\n            source_points (`List[List[int]]`, *required*):\n                Used to mark the starting positions of drag editing in the image, with each pixel represented as a\n                `List[int]` of length 2.\n            target_points (`List[List[int]]`, *required*):\n                Used to mark the target positions of drag editing in the image, with each pixel represented as a\n                `List[int]` of length 2.\n            t0 (`float`, *optional*, defaults to 0.6):\n                The time parameter. Higher t0 improves the fidelity while lowering the faithfulness of the edited images\n                and vice versa.\n            steps (`int`, *optional*, defaults to 200):\n                The number of sampling iterations.\n            step_size (`int`, *optional*, defaults to 2):\n                The drag distance of each drag step.\n            image_scale (`float`, *optional*, defaults to 0.3):\n                To avoid duplicating the content, use `image_scale` to perturb the source.\n            adapt_radius (`int`, *optional*, defaults to 5):\n                The size of the region for copy and paste operations during each step of the drag process.\n            min_lora_scale (`float`, *optional*, defaults to 0.5):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n                min_lora_scale specifies the minimum LoRA scale during the image drag-editing process.\n            generator (`torch.Generator`, *optional*, defaults to None):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n        Examples:\n        ```py\n        >>> import PIL\n        >>> 
import torch\n        >>> from diffusers import DDIMScheduler, DiffusionPipeline\n\n        >>> # Load the pipeline\n        >>> model_path = \"stable-diffusion-v1-5/stable-diffusion-v1-5\"\n        >>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder=\"scheduler\")\n        >>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline=\"sde_drag\")\n        >>> pipe.to('cuda')\n\n        >>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.\n        >>> # If not training LoRA, please avoid using torch.float16\n        >>> # pipe.to(torch.float16)\n\n        >>> # Provide prompt, image, mask image, and the starting and target points for drag editing.\n        >>> prompt = \"prompt of the image\"\n        >>> image = PIL.Image.open('/path/to/image')\n        >>> mask_image = PIL.Image.open('/path/to/mask_image')\n        >>> source_points = [[123, 456]]\n        >>> target_points = [[234, 567]]\n\n        >>> # train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image.\n        >>> pipe.train_lora(prompt, image)\n\n        >>> output = pipe(prompt, image, mask_image, source_points, target_points)\n        >>> output_image = PIL.Image.fromarray(output)\n        >>> output_image.save(\"./output.png\")\n        ```\n        \"\"\"\n\n        self.scheduler.set_timesteps(steps)\n\n        noise_scale = (1 - image_scale**2) ** (0.5)\n\n        text_embeddings = self._get_text_embed(prompt)\n        uncond_embeddings = self._get_text_embed([\"\"])\n        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        latent = self._get_img_latent(image)\n\n        mask = mask_image.resize((latent.shape[3], latent.shape[2]))\n        mask = torch.tensor(np.array(mask))\n        mask = mask.unsqueeze(0).expand_as(latent).to(self.device)\n\n        source_points = torch.tensor(source_points).div(torch.tensor([8]), 
rounding_mode=\"trunc\")\n        target_points = torch.tensor(target_points).div(torch.tensor([8]), rounding_mode=\"trunc\")\n\n        distance = target_points - source_points\n        distance_norm_max = torch.norm(distance.float(), dim=1, keepdim=True).max()\n\n        if distance_norm_max <= step_size:\n            drag_num = 1\n        else:\n            drag_num = distance_norm_max.div(torch.tensor([step_size]), rounding_mode=\"trunc\")\n            if (distance_norm_max / drag_num - step_size).abs() > (\n                distance_norm_max / (drag_num + 1) - step_size\n            ).abs():\n                drag_num += 1\n\n        latents = []\n        for i in tqdm(range(int(drag_num)), desc=\"SDE Drag\"):\n            source_new = source_points + (i / drag_num * distance).to(torch.int)\n            target_new = source_points + ((i + 1) / drag_num * distance).to(torch.int)\n\n            latent, noises, hook_latents, lora_scales, cfg_scales = self._forward(\n                latent, steps, t0, min_lora_scale, text_embeddings, generator\n            )\n            latent = self._copy_and_paste(\n                latent,\n                source_new,\n                target_new,\n                adapt_radius,\n                latent.shape[2] - 1,\n                latent.shape[3] - 1,\n                image_scale,\n                noise_scale,\n                generator,\n            )\n            latent = self._backward(\n                latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator\n            )\n\n            latents.append(latent)\n\n        result_image = 1 / 0.18215 * latents[-1]\n\n        with torch.no_grad():\n            result_image = self.vae.decode(result_image).sample\n\n        result_image = (result_image / 2 + 0.5).clamp(0, 1)\n        result_image = result_image.cpu().permute(0, 2, 3, 1).numpy()[0]\n        result_image = (result_image * 255).astype(np.uint8)\n\n        return 
result_image\n\n    def train_lora(self, prompt, image, lora_step=100, lora_rank=16, generator=None):\n        accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=\"fp16\")\n\n        self.vae.requires_grad_(False)\n        self.text_encoder.requires_grad_(False)\n        self.unet.requires_grad_(False)\n\n        unet_lora_attn_procs = {}\n        for name, attn_processor in self.unet.attn_processors.items():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else self.unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = self.unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = self.unet.config.block_out_channels[block_id]\n            else:\n                raise NotImplementedError(\"name must start with up_blocks, mid_blocks, or down_blocks\")\n\n            if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):\n                lora_attn_processor_class = LoRAAttnAddedKVProcessor\n            else:\n                lora_attn_processor_class = (\n                    LoRAAttnProcessor2_0\n                    if hasattr(torch.nn.functional, \"scaled_dot_product_attention\")\n                    else LoRAAttnProcessor\n                )\n            unet_lora_attn_procs[name] = lora_attn_processor_class(\n                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank\n            )\n\n        self.unet.set_attn_processor(unet_lora_attn_procs)\n        unet_lora_layers = AttnProcsLayers(self.unet.attn_processors)\n        params_to_optimize = 
unet_lora_layers.parameters()\n\n        optimizer = torch.optim.AdamW(\n            params_to_optimize,\n            lr=2e-4,\n            betas=(0.9, 0.999),\n            weight_decay=1e-2,\n            eps=1e-08,\n        )\n\n        lr_scheduler = get_scheduler(\n            \"constant\",\n            optimizer=optimizer,\n            num_warmup_steps=0,\n            num_training_steps=lora_step,\n            num_cycles=1,\n            power=1.0,\n        )\n\n        unet_lora_layers = accelerator.prepare_model(unet_lora_layers)\n        optimizer = accelerator.prepare_optimizer(optimizer)\n        lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)\n\n        with torch.no_grad():\n            text_inputs = self._tokenize_prompt(prompt, tokenizer_max_length=None)\n            text_embedding = self._encode_prompt(\n                text_inputs.input_ids, text_inputs.attention_mask, text_encoder_use_attention_mask=False\n            )\n\n        image_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n        image = image_transforms(image).to(self.device, dtype=self.vae.dtype)\n        image = image.unsqueeze(dim=0)\n        latents_dist = self.vae.encode(image).latent_dist\n\n        for _ in tqdm(range(lora_step), desc=\"Train LoRA\"):\n            self.unet.train()\n            model_input = latents_dist.sample() * self.vae.config.scaling_factor\n\n            # Sample noise that we'll add to the latents\n            noise = torch.randn(\n                model_input.size(),\n                dtype=model_input.dtype,\n                layout=model_input.layout,\n                device=model_input.device,\n                generator=generator,\n            )\n            bsz, channels, height, width = model_input.shape\n\n            # Sample a random timestep for each image\n            timesteps = torch.randint(\n                0, 
self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, generator=generator\n            )\n            timesteps = timesteps.long()\n\n            # Add noise to the model input according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            noisy_model_input = self.scheduler.add_noise(model_input, noise, timesteps)\n\n            # Predict the noise residual\n            model_pred = self.unet(noisy_model_input, timesteps, text_embedding).sample\n\n            # Get the target for loss depending on the prediction type\n            if self.scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif self.scheduler.config.prediction_type == \"v_prediction\":\n                target = self.scheduler.get_velocity(model_input, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {self.scheduler.config.prediction_type}\")\n\n            loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n            accelerator.backward(loss)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad()\n\n        with tempfile.TemporaryDirectory() as save_lora_dir:\n            StableDiffusionLoraLoaderMixin.save_lora_weights(\n                save_directory=save_lora_dir,\n                unet_lora_layers=unet_lora_layers,\n                text_encoder_lora_layers=None,\n            )\n\n            self.unet.load_attn_procs(save_lora_dir)\n\n    def _tokenize_prompt(self, prompt, tokenizer_max_length=None):\n        if tokenizer_max_length is not None:\n            max_length = tokenizer_max_length\n        else:\n            max_length = self.tokenizer.model_max_length\n\n        text_inputs = self.tokenizer(\n            prompt,\n            truncation=True,\n            padding=\"max_length\",\n            max_length=max_length,\n            
return_tensors=\"pt\",\n        )\n\n        return text_inputs\n\n    def _encode_prompt(self, input_ids, attention_mask, text_encoder_use_attention_mask=False):\n        text_input_ids = input_ids.to(self.device)\n\n        if text_encoder_use_attention_mask:\n            attention_mask = attention_mask.to(self.device)\n        else:\n            attention_mask = None\n\n        prompt_embeds = self.text_encoder(\n            text_input_ids,\n            attention_mask=attention_mask,\n        )\n        prompt_embeds = prompt_embeds[0]\n\n        return prompt_embeds\n\n    @torch.no_grad()\n    def _get_text_embed(self, prompt):\n        text_input = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]\n        return text_embeddings\n\n    def _copy_and_paste(\n        self, latent, source_new, target_new, adapt_radius, max_height, max_width, image_scale, noise_scale, generator\n    ):\n        def adaption_r(source, target, adapt_radius, max_height, max_width):\n            r_x_lower = min(adapt_radius, source[0], target[0])\n            r_x_upper = min(adapt_radius, max_width - source[0], max_width - target[0])\n            r_y_lower = min(adapt_radius, source[1], target[1])\n            r_y_upper = min(adapt_radius, max_height - source[1], max_height - target[1])\n            return r_x_lower, r_x_upper, r_y_lower, r_y_upper\n\n        for source_, target_ in zip(source_new, target_new):\n            r_x_lower, r_x_upper, r_y_lower, r_y_upper = adaption_r(\n                source_, target_, adapt_radius, max_height, max_width\n            )\n\n            source_feature = latent[\n                :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper\n            
].clone()\n\n            latent[\n                :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper\n            ] = image_scale * source_feature + noise_scale * torch.randn(\n                latent.shape[0],\n                4,\n                r_y_lower + r_y_upper,\n                r_x_lower + r_x_upper,\n                device=self.device,\n                generator=generator,\n            )\n\n            latent[\n                :, :, target_[1] - r_y_lower : target_[1] + r_y_upper, target_[0] - r_x_lower : target_[0] + r_x_upper\n            ] = source_feature * 1.1\n        return latent\n\n    @torch.no_grad()\n    def _get_img_latent(self, image, height=None, weight=None):\n        data = image.convert(\"RGB\")\n        if height is not None:\n            data = data.resize((weight, height))\n        transform = transforms.ToTensor()\n        data = transform(data).unsqueeze(0)\n        data = (data * 2.0) - 1.0\n        data = data.to(self.device, dtype=self.vae.dtype)\n        latent = self.vae.encode(data).latent_dist.sample()\n        latent = 0.18215 * latent\n        return latent\n\n    @torch.no_grad()\n    def _get_eps(self, latent, timestep, guidance_scale, text_embeddings, lora_scale=None):\n        latent_model_input = torch.cat([latent] * 2) if guidance_scale > 1.0 else latent\n        text_embeddings = text_embeddings if guidance_scale > 1.0 else text_embeddings.chunk(2)[1]\n\n        cross_attention_kwargs = None if lora_scale is None else {\"scale\": lora_scale}\n\n        with torch.no_grad():\n            noise_pred = self.unet(\n                latent_model_input,\n                timestep,\n                encoder_hidden_states=text_embeddings,\n                cross_attention_kwargs=cross_attention_kwargs,\n            ).sample\n\n        if guidance_scale > 1.0:\n            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n        elif guidance_scale == 1.0:\n            
noise_pred_text = noise_pred\n            noise_pred_uncond = 0.0\n        else:\n            raise NotImplementedError(guidance_scale)\n        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n        return noise_pred\n\n    def _forward_sde(\n        self, timestep, sample, guidance_scale, text_embeddings, steps, eta=1.0, lora_scale=None, generator=None\n    ):\n        num_train_timesteps = len(self.scheduler)\n        alphas_cumprod = self.scheduler.alphas_cumprod\n        initial_alpha_cumprod = torch.tensor(1.0)\n\n        prev_timestep = timestep + num_train_timesteps // steps\n\n        alpha_prod_t = alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod\n        alpha_prod_t_prev = alphas_cumprod[prev_timestep]\n\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n\n        x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** (\n            0.5\n        ) * torch.randn(\n            sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator\n        )\n        eps = self._get_eps(x_prev, prev_timestep, guidance_scale, text_embeddings, lora_scale)\n\n        sigma_t_prev = (\n            eta\n            * (1 - alpha_prod_t) ** (0.5)\n            * (1 - alpha_prod_t_prev / (1 - alpha_prod_t_prev) * (1 - alpha_prod_t) / alpha_prod_t) ** (0.5)\n        )\n\n        pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5)\n        pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev**2) ** (0.5)\n\n        noise = (\n            sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps\n        ) / sigma_t_prev\n\n        return x_prev, noise\n\n    def _sample(\n        self,\n        timestep,\n        sample,\n        guidance_scale,\n        text_embeddings,\n        steps,\n        sde=False,\n        noise=None,\n        
eta=1.0,\n        lora_scale=None,\n        generator=None,\n    ):\n        num_train_timesteps = len(self.scheduler)\n        alphas_cumprod = self.scheduler.alphas_cumprod\n        final_alpha_cumprod = torch.tensor(1.0)\n\n        eps = self._get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale)\n\n        prev_timestep = timestep - num_train_timesteps // steps\n\n        alpha_prod_t = alphas_cumprod[timestep]\n        alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod\n\n        beta_prod_t = 1 - alpha_prod_t\n\n        sigma_t = (\n            eta\n            * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5)\n            * (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5)\n            if sde\n            else 0\n        )\n\n        pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)\n        pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5)\n\n        noise = (\n            torch.randn(\n                sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator\n            )\n            if noise is None\n            else noise\n        )\n        latent = (\n            alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise\n        )\n\n        return latent\n\n    def _forward(self, latent, steps, t0, lora_scale_min, text_embeddings, generator):\n        def scale_schedule(begin, end, n, length, type=\"linear\"):\n            if type == \"constant\":\n                return end\n            elif type == \"linear\":\n                return begin + (end - begin) * n / length\n            elif type == \"cos\":\n                factor = (1 - math.cos(n * math.pi / length)) / 2\n                return (1 - factor) * begin + factor * end\n            else:\n                raise NotImplementedError(type)\n\n        noises = []\n        latents = 
[]\n        lora_scales = []\n        cfg_scales = []\n        latents.append(latent)\n        t0 = int(t0 * steps)\n        t_begin = steps - t0\n\n        length = len(self.scheduler.timesteps[t_begin - 1 : -1]) - 1\n        index = 1\n        for t in self.scheduler.timesteps[t_begin:].flip(dims=[0]):\n            lora_scale = scale_schedule(1, lora_scale_min, index, length, type=\"cos\")\n            cfg_scale = scale_schedule(1, 3.0, index, length, type=\"linear\")\n            latent, noise = self._forward_sde(\n                t, latent, cfg_scale, text_embeddings, steps, lora_scale=lora_scale, generator=generator\n            )\n\n            noises.append(noise)\n            latents.append(latent)\n            lora_scales.append(lora_scale)\n            cfg_scales.append(cfg_scale)\n            index += 1\n        return latent, noises, latents, lora_scales, cfg_scales\n\n    def _backward(\n        self, latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator\n    ):\n        t0 = int(t0 * steps)\n        t_begin = steps - t0\n\n        hook_latent = hook_latents.pop()\n        latent = torch.where(mask > 128, latent, hook_latent)\n        for t in self.scheduler.timesteps[t_begin - 1 : -1]:\n            latent = self._sample(\n                t,\n                latent,\n                cfg_scales.pop(),\n                text_embeddings,\n                steps,\n                sde=True,\n                noise=noises.pop(),\n                lora_scale=lora_scales.pop(),\n                generator=generator,\n            )\n            hook_latent = hook_latents.pop()\n            latent = torch.where(mask > 128, latent, hook_latent)\n        return latent\n"
  },
  {
    "path": "examples/community/seed_resize_stable_diffusion.py",
    "content": "\"\"\"\nmodified based on diffusion library from Huggingface: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py\n\"\"\"\n\nimport inspect\nfrom typing import Callable, List, Optional, Union\n\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\nfrom diffusers.utils import logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass SeedResizeStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    @torch.no_grad()\n    def 
__call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        text_embeddings: Optional[torch.Tensor] = None,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        # get prompt text embeddings\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n\n        if 
text_input_ids.shape[-1] > self.tokenizer.model_max_length:\n            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n\n        if text_embeddings is None:\n            text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"]\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)\n        latents_shape_reference = (batch_size * num_images_per_prompt, self.unet.config.in_channels, 64, 64)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not exist on mps\n                latents_reference = torch.randn(\n    
                latents_shape_reference, generator=generator, device=\"cpu\", dtype=latents_dtype\n                ).to(self.device)\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                    self.device\n                )\n            else:\n                latents_reference = torch.randn(\n                    latents_shape_reference, generator=generator, device=self.device, dtype=latents_dtype\n                )\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents_reference.shape != latents_shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents_reference = latents_reference.to(self.device)\n            latents = latents.to(self.device)\n\n        # This is the key part of the pipeline where we\n        # try to ensure that the generated images w/ the same seed\n        # but different sizes actually result in similar images\n        dx = (latents_shape[3] - latents_shape_reference[3]) // 2\n        dy = (latents_shape[2] - latents_shape_reference[2]) // 2\n        w = latents_shape_reference[3] if dx >= 0 else latents_shape_reference[3] + 2 * dx\n        h = latents_shape_reference[2] if dy >= 0 else latents_shape_reference[2] + 2 * dy\n        tx = 0 if dx < 0 else dx\n        ty = 0 if dy < 0 else dy\n        dx = max(-dx, 0)\n        dy = max(-dy, 0)\n        # import pdb\n        # pdb.set_trace()\n        latents[:, :, ty : ty + h, tx : tx + w] = latents_reference[:, :, dy : dy + h, dx : dx + w]\n\n        # set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n     
   # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        for i, t in enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n            # call the callback, if provided\n            if callback is not None and i % callback_steps == 0:\n                step_idx = i // getattr(self.scheduler, \"order\", 1)\n                callback(step_idx, t, latents)\n\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        # we always cast to float32 as 
this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(\n                self.device\n            )\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)\n            )\n        else:\n            has_nsfw_concept = None\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/speech_to_image_diffusion.py",
    "content": "import inspect\nfrom typing import Callable, List, Optional, Union\n\nimport torch\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTokenizer,\n    WhisperForConditionalGeneration,\n    WhisperProcessor,\n)\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DiffusionPipeline,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.utils import logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass SpeechToImagePipeline(DiffusionPipeline, StableDiffusionMixin):\n    def __init__(\n        self,\n        speech_model: WhisperForConditionalGeneration,\n        speech_processor: WhisperProcessor,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n\n        if safety_checker is None:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        self.register_modules(\n            speech_model=speech_model,\n            speech_processor=speech_processor,\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        audio,\n        sampling_rate=16_000,\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        inputs = self.speech_processor.feature_extractor(\n            audio, return_tensors=\"pt\", sampling_rate=sampling_rate\n        ).input_features.to(self.device)\n        predicted_ids = self.speech_model.generate(inputs, max_length=480_000)\n\n        prompt = self.speech_processor.tokenizer.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)[\n            0\n        ]\n\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        
else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        # get prompt text embeddings\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n\n        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:\n            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not exist on mps\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                    self.device\n                )\n     
       else:\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents.shape != latents_shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents = latents.to(self.device)\n\n        # set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        for i, t in enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n            # call the callback, if provided\n            if callback is not None and i % callback_steps == 0:\n                step_idx = i // getattr(self.scheduler, \"order\", 1)\n                callback(step_idx, t, latents)\n\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return image\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)\n"
  },
  {
    "path": "examples/community/stable_diffusion_comparison.py",
    "content": "from typing import Any, Callable, List, Optional, Union\n\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DiffusionPipeline,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\n\n\npipe1_model_id = \"CompVis/stable-diffusion-v1-1\"\npipe2_model_id = \"CompVis/stable-diffusion-v1-2\"\npipe3_model_id = \"CompVis/stable-diffusion-v1-3\"\npipe4_model_id = \"CompVis/stable-diffusion-v1-4\"\n\n\nclass StableDiffusionComparisonPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for parallel comparison of Stable Diffusion v1-v4\n    This pipeline inherits from DiffusionPipeline and depends on the use of an Auth Token for\n    downloading pre-trained checkpoints from Hugging Face Hub.\n    If using Hugging Face Hub, pass the Model ID for Stable Diffusion v1.4 as the previous 3 checkpoints will be loaded\n    automatically.\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        self.pipe1 = StableDiffusionPipeline.from_pretrained(pipe1_model_id)\n        self.pipe2 = StableDiffusionPipeline.from_pretrained(pipe2_model_id)\n        self.pipe3 = StableDiffusionPipeline.from_pretrained(pipe3_model_id)\n        self.pipe4 = 
StableDiffusionPipeline(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            requires_safety_checker=requires_safety_checker,\n        )\n\n        self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)\n\n    @property\n    def layers(self) -> dict[str, Any]:\n        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith(\"_\")}\n\n    @torch.no_grad()\n    def text2img_sd1_1(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        return self.pipe1(\n            prompt=prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n    @torch.no_grad()\n    def text2img_sd1_2(\n        self,\n        prompt: Union[str, 
List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        return self.pipe2(\n            prompt=prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n    @torch.no_grad()\n    def text2img_sd1_3(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        return self.pipe3(\n            prompt=prompt,\n            height=height,\n            
width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n    @torch.no_grad()\n    def text2img_sd1_4(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        return self.pipe4(\n            prompt=prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 
7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation. This function will generate 4 results as part\n        of running all the 4 pipelines for SD1.1-1.4 together in a serial-processing, parallel-invocation fashion.\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            height (`int`, optional, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, optional, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, optional, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, optional, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            eta (`float`, optional, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, optional):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, optional):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, optional, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, optional, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n        self.to(device)\n\n        # Check that the height and width are divisible by 8\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` must be divisible by 8 but are {height} and {width}.\")\n\n        # Get first result from Stable Diffusion Checkpoint v1.1\n        res1 = self.text2img_sd1_1(\n            prompt=prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n        # Get second result from Stable Diffusion Checkpoint v1.2\n        res2 = self.text2img_sd1_2(\n            prompt=prompt,\n        
    height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n        # Get third result from Stable Diffusion Checkpoint v1.3\n        res3 = self.text2img_sd1_3(\n            prompt=prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n        # Get fourth result from Stable Diffusion Checkpoint v1.4\n        res4 = self.text2img_sd1_4(\n            prompt=prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n            **kwargs,\n        )\n\n        # Get all result images into a single list and pass it via StableDiffusionPipelineOutput for final result\n        return 
StableDiffusionPipelineOutput([res1[0], res2[0], res3[0], res4[0]])\n"
  },
  {
    "path": "examples/community/stable_diffusion_controlnet_img2img.py",
    "content": "# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel, logging\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import numpy as np\n        >>> import torch\n        >>> from PIL import Image\n        >>> from diffusers import ControlNetModel, UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n\n        >>> input_image = load_image(\"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\")\n\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\n\n        >>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(\n                \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n                controlnet=controlnet,\n                safety_checker=None,\n                torch_dtype=torch.float16\n                )\n\n        >>> pipe_controlnet.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)\n        >>> pipe_controlnet.enable_xformers_memory_efficient_attention()\n        >>> 
pipe_controlnet.enable_model_cpu_offload()\n\n        # using image with edges for our canny controlnet\n        >>> control_image = load_image(\n            \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png\")\n\n\n        >>> result_img = pipe_controlnet(controlnet_conditioning_image=control_image,\n                        image=input_image,\n                        prompt=\"an android robot, cyberpunk, digital art masterpiece\",\n                        num_inference_steps=20).images[0]\n\n        >>> result_img.show()\n        ```\n\"\"\"\n\n\ndef prepare_image(image):\n    if isinstance(image, torch.Tensor):\n        # Batch single image\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        image = image.to(dtype=torch.float32)\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    return image\n\n\ndef prepare_controlnet_conditioning_image(\n    controlnet_conditioning_image,\n    width,\n    height,\n    batch_size,\n    num_images_per_prompt,\n    device,\n    dtype,\n    do_classifier_free_guidance,\n):\n    if not isinstance(controlnet_conditioning_image, torch.Tensor):\n        if isinstance(controlnet_conditioning_image, PIL.Image.Image):\n            controlnet_conditioning_image = [controlnet_conditioning_image]\n\n        if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):\n            
controlnet_conditioning_image = [\n                np.array(i.resize((width, height), resample=PIL_INTERPOLATION[\"lanczos\"]))[None, :]\n                for i in controlnet_conditioning_image\n            ]\n            controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)\n            controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0\n            controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)\n            controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)\n        elif isinstance(controlnet_conditioning_image[0], torch.Tensor):\n            controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)\n\n    image_batch_size = controlnet_conditioning_image.shape[0]\n\n    if image_batch_size == 1:\n        repeat_by = batch_size\n    else:\n        # image batch size is the same as prompt batch size\n        repeat_by = num_images_per_prompt\n\n    controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)\n\n    controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)\n\n    if do_classifier_free_guidance:\n        controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)\n\n    return controlnet_conditioning_image\n\n\nclass StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin):\n    \"\"\"\n    Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: 
KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. 
If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n 
               attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = 
self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], 
torch.Tensor)\n\n        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:\n            raise TypeError(\n                \"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        elif image_is_tensor:\n            image_batch_size = image.shape[0]\n        elif image_is_pil_list:\n            image_batch_size = len(image)\n        elif image_is_tensor_list:\n            image_batch_size = len(image)\n        else:\n            raise ValueError(\"controlnet condition image is not valid\")\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n        else:\n            raise ValueError(\"prompt or prompt_embeds are not valid\")\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        controlnet_conditioning_image,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        strength=None,\n        controlnet_guidance_start=None,\n        controlnet_guidance_end=None,\n        controlnet_conditioning_scale=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # check controlnet condition image\n\n        if isinstance(self.controlnet, ControlNetModel):\n            self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)\n        elif isinstance(self.controlnet, MultiControlNetModel):\n            if not isinstance(controlnet_conditioning_image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            if len(controlnet_conditioning_image) != len(self.controlnet.nets):\n                raise ValueError(\n                    \"For multiple controlnets: `image` must have the same length as the number of controlnets.\"\n                )\n\n            for image_ in controlnet_conditioning_image:\n                self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check 
`controlnet_conditioning_scale`\n\n        if isinstance(self.controlnet, ControlNetModel):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n                )\n        else:\n            assert False\n\n        if isinstance(image, torch.Tensor):\n            if image.ndim != 3 and image.ndim != 4:\n                raise ValueError(\"`image` must have 3 or 4 dimensions\")\n\n            if image.ndim == 3:\n                image_batch_size = 1\n                image_channels, image_height, image_width = image.shape\n            elif image.ndim == 4:\n                image_batch_size, image_channels, image_height, image_width = image.shape\n            else:\n                assert False\n\n            if image_channels != 3:\n                raise ValueError(\"`image` must have 3 channels\")\n\n            if image.min() < -1 or image.max() > 1:\n                raise ValueError(\"`image` should be in range [-1, 1]\")\n\n        if self.vae.config.latent_channels != self.unet.config.in_channels:\n            raise ValueError(\n                f\"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received\"\n                f\" latent channels: {self.vae.config.latent_channels},\"\n                f\" Please verify the config of `pipeline.unet` and the `pipeline.vae`\"\n            )\n\n        if strength < 0 or strength > 1:\n            raise 
ValueError(f\"The value of `strength` should be in [0.0, 1.0] but is {strength}\")\n\n        if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:\n            raise ValueError(\n                f\"The value of `controlnet_guidance_start` should be in [0.0, 1.0] but is {controlnet_guidance_start}\"\n            )\n\n        if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:\n            raise ValueError(\n                f\"The value of `controlnet_guidance_end` should be in [0.0, 1.0] but is {controlnet_guidance_end}\"\n            )\n\n        if controlnet_guidance_start > controlnet_guidance_end:\n            raise ValueError(\n                \"The value of `controlnet_guidance_start` should not be greater than `controlnet_guidance_end`, but got\"\n                f\" `controlnet_guidance_start` {controlnet_guidance_start} > `controlnet_guidance_end` {controlnet_guidance_end}\"\n            )\n\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start:]\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective 
batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if isinstance(generator, list):\n            init_latents = [\n                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)\n            ]\n            init_latents = torch.cat(init_latents, dim=0)\n        else:\n            init_latents = self.vae.encode(image).latent_dist.sample(generator)\n\n        init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    def _default_height_width(self, height, width, image):\n        if isinstance(image, list):\n            image = image[0]\n\n        if height is None:\n            if isinstance(image, PIL.Image.Image):\n                height = image.height\n            elif isinstance(image, torch.Tensor):\n                # tensors are laid out as (batch, channels, height, width)\n                height = image.shape[2]\n\n            height = (height // 8) * 8  # round down to nearest multiple of 8\n\n        if width is None:\n            if isinstance(image, PIL.Image.Image):\n                width = image.width\n            elif isinstance(image, torch.Tensor):\n                # tensors are laid out as (batch, channels, height, width)\n                width = image.shape[3]\n\n            width = (width // 8) * 8  # round down to nearest multiple of 8\n\n        return height, width\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    
def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        controlnet_conditioning_image: Union[\n            torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]\n        ] = None,\n        strength: float = 0.8,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        controlnet_guidance_start: float = 0.0,\n        controlnet_guidance_end: float = 1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, that will be used as the starting point for the\n                image-to-image process and transformed according to `prompt`.\n            controlnet_conditioning_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can\n                also be accepted as an image. The control image is automatically resized to fit the output image.\n            strength (`float`, *optional*):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`\n                will be used as a starting point, adding more noise to it the larger the `strength`. The number of\n                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will\n                be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generate image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet.\n            controlnet_guidance_start (`float`, *optional*, defaults to 0.0):\n                The fraction of total steps at which the controlnet starts applying. Must be between 0 and 1.\n            controlnet_guidance_end (`float`, *optional*, defaults to 1.0):\n                The fraction of total steps at which the controlnet stops applying. Must be between 0 and 1. Must not\n                be less than `controlnet_guidance_start`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to the size of the conditioning image\n        height, width = self._default_height_width(height, width, controlnet_conditioning_image)\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            image,\n            controlnet_conditioning_image,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            strength,\n            controlnet_guidance_start,\n            controlnet_guidance_end,\n            controlnet_conditioning_scale,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n\n        # 4. 
Prepare image, and controlnet_conditioning_image\n        image = prepare_image(image)\n\n        # condition image(s)\n        if isinstance(self.controlnet, ControlNetModel):\n            controlnet_conditioning_image = prepare_controlnet_conditioning_image(\n                controlnet_conditioning_image=controlnet_conditioning_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=self.controlnet.dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n            )\n        elif isinstance(self.controlnet, MultiControlNetModel):\n            controlnet_conditioning_images = []\n\n            for image_ in controlnet_conditioning_image:\n                image_ = prepare_controlnet_conditioning_image(\n                    controlnet_conditioning_image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=self.controlnet.dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                )\n\n                controlnet_conditioning_images.append(image_)\n\n            controlnet_conditioning_image = controlnet_conditioning_images\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. 
Prepare latent variables\n        if latents is None:\n            latents = self.prepare_latents(\n                image,\n                latent_timestep,\n                batch_size,\n                num_images_per_prompt,\n                prompt_embeds.dtype,\n                device,\n                generator,\n            )\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # compute the percentage of total steps we are at\n                current_sampling_percent = i / len(timesteps)\n\n                if (\n                    current_sampling_percent < controlnet_guidance_start\n                    or current_sampling_percent > controlnet_guidance_end\n                ):\n                    # do not apply the controlnet\n                    down_block_res_samples = None\n                    mid_block_res_sample = None\n                else:\n                    # apply the controlnet\n                    down_block_res_samples, mid_block_res_sample = self.controlnet(\n                        latent_model_input,\n                        t,\n                        encoder_hidden_states=prompt_embeds,\n                        controlnet_cond=controlnet_conditioning_image,\n                        conditioning_scale=controlnet_conditioning_scale,\n                        return_dict=False,\n         
           )\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                ).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if output_type == \"latent\":\n            image = latents\n            has_nsfw_concept = None\n        elif output_type == \"pil\":\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n            # 9. 
Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n            # 10. Convert to PIL\n            image = self.numpy_to_pil(image)\n        else:\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n            # 9. Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_controlnet_inpaint.py",
    "content": "# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel, logging\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import numpy as np\n        >>> import torch\n        >>> from PIL import Image\n        >>> from stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline\n\n        >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation\n        >>> from diffusers import ControlNetModel, UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n\n        >>> def ade_palette():\n                return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],\n                        [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],\n                        [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],\n                        [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],\n                        [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],\n                        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 
255],\n                        [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],\n                        [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],\n                        [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],\n                        [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],\n                        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],\n                        [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],\n                        [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],\n                        [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],\n                        [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],\n                        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],\n                        [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],\n                        [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],\n                        [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],\n                        [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],\n                        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],\n                        [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],\n                        [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],\n                        [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],\n                        [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],\n                        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],\n                        [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],\n                        [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],\n                        [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],\n                        [163, 0, 255], [153, 0, 255], [71, 
255, 0], [255, 0, 163],\n                        [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],\n                        [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],\n                        [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],\n                        [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],\n                        [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],\n                        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],\n                        [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],\n                        [102, 255, 0], [92, 0, 255]]\n\n        >>> image_processor = AutoImageProcessor.from_pretrained(\"openmmlab/upernet-convnext-small\")\n        >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained(\"openmmlab/upernet-convnext-small\")\n\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-seg\", torch_dtype=torch.float16)\n\n        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(\n                \"stable-diffusion-v1-5/stable-diffusion-inpainting\", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16\n            )\n\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>> pipe.enable_xformers_memory_efficient_attention()\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> def image_to_seg(image):\n                pixel_values = image_processor(image, return_tensors=\"pt\").pixel_values\n                with torch.no_grad():\n                    outputs = image_segmentor(pixel_values)\n                seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]\n                color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3\n                palette = np.array(ade_palette())\n                for label, color in 
enumerate(palette):\n                    color_seg[seg == label, :] = color\n                color_seg = color_seg.astype(np.uint8)\n                seg_image = Image.fromarray(color_seg)\n                return seg_image\n\n        >>> image = load_image(\n                \"https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\n            )\n\n        >>> mask_image = load_image(\n                \"https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n            )\n\n        >>> controlnet_conditioning_image = image_to_seg(image)\n\n        >>> image = pipe(\n                \"Face of a yellow cat, high resolution, sitting on a park bench\",\n                image,\n                mask_image,\n                controlnet_conditioning_image,\n                num_inference_steps=20,\n            ).images[0]\n\n        >>> image.save(\"out.png\")\n        ```\n\"\"\"\n\n\ndef prepare_image(image):\n    if isinstance(image, torch.Tensor):\n        # Batch single image\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        image = image.to(dtype=torch.float32)\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    return image\n\n\ndef prepare_mask_image(mask_image):\n    if isinstance(mask_image, torch.Tensor):\n        if mask_image.ndim == 2:\n            # 
Batch and add channel dim for single mask\n            mask_image = mask_image.unsqueeze(0).unsqueeze(0)\n        elif mask_image.ndim == 3 and mask_image.shape[0] == 1:\n            # Single mask, the 0'th dimension is considered to be\n            # the existing batch size of 1\n            mask_image = mask_image.unsqueeze(0)\n        elif mask_image.ndim == 3 and mask_image.shape[0] != 1:\n            # Batch of mask, the 0'th dimension is considered to be\n            # the batching dimension\n            mask_image = mask_image.unsqueeze(1)\n\n        # Binarize mask\n        mask_image[mask_image < 0.5] = 0\n        mask_image[mask_image >= 0.5] = 1\n    else:\n        # preprocess mask\n        if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):\n            mask_image = [mask_image]\n\n        if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):\n            mask_image = np.concatenate([np.array(m.convert(\"L\"))[None, None, :] for m in mask_image], axis=0)\n            mask_image = mask_image.astype(np.float32) / 255.0\n        elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):\n            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)\n\n        mask_image[mask_image < 0.5] = 0\n        mask_image[mask_image >= 0.5] = 1\n        mask_image = torch.from_numpy(mask_image)\n\n    return mask_image\n\n\ndef prepare_controlnet_conditioning_image(\n    controlnet_conditioning_image,\n    width,\n    height,\n    batch_size,\n    num_images_per_prompt,\n    device,\n    dtype,\n    do_classifier_free_guidance,\n):\n    if not isinstance(controlnet_conditioning_image, torch.Tensor):\n        if isinstance(controlnet_conditioning_image, PIL.Image.Image):\n            controlnet_conditioning_image = [controlnet_conditioning_image]\n\n        if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):\n            controlnet_conditioning_image = [\n                
np.array(i.resize((width, height), resample=PIL_INTERPOLATION[\"lanczos\"]))[None, :]\n                for i in controlnet_conditioning_image\n            ]\n            controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)\n            controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0\n            controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)\n            controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)\n        elif isinstance(controlnet_conditioning_image[0], torch.Tensor):\n            controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)\n\n    image_batch_size = controlnet_conditioning_image.shape[0]\n\n    if image_batch_size == 1:\n        repeat_by = batch_size\n    else:\n        # image batch size is the same as prompt batch size\n        repeat_by = num_images_per_prompt\n\n    controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)\n\n    controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)\n\n    if do_classifier_free_guidance:\n        controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)\n\n    return controlnet_conditioning_image\n\n\nclass StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, StableDiffusionMixin):\n    \"\"\"\n    Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: 
StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. 
If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n 
               attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = 
self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], 
torch.Tensor)\n\n        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:\n            raise TypeError(\n                \"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        elif image_is_tensor:\n            image_batch_size = image.shape[0]\n        elif image_is_pil_list:\n            image_batch_size = len(image)\n        elif image_is_tensor_list:\n            image_batch_size = len(image)\n        else:\n            raise ValueError(\"controlnet condition image is not valid\")\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n        else:\n            raise ValueError(\"prompt or prompt_embeds are not valid\")\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        mask_image,\n        controlnet_conditioning_image,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        controlnet_conditioning_scale=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # check controlnet condition image\n        if isinstance(self.controlnet, ControlNetModel):\n            self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)\n        elif isinstance(self.controlnet, MultiControlNetModel):\n            if not isinstance(controlnet_conditioning_image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n            if len(controlnet_conditioning_image) != len(self.controlnet.nets):\n                raise ValueError(\n                    \"For multiple controlnets: `image` must have the same length as the number of controlnets.\"\n                )\n            for image_ in controlnet_conditioning_image:\n                self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if isinstance(self.controlnet, ControlNetModel):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                 
   \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n                )\n        else:\n            assert False\n\n        if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):\n            raise TypeError(\"if `image` is a tensor, `mask_image` must also be a tensor\")\n\n        if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):\n            raise TypeError(\"if `image` is a PIL image, `mask_image` must also be a PIL image\")\n\n        if isinstance(image, torch.Tensor):\n            if image.ndim != 3 and image.ndim != 4:\n                raise ValueError(\"`image` must have 3 or 4 dimensions\")\n\n            if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:\n                raise ValueError(\"`mask_image` must have 2, 3, or 4 dimensions\")\n\n            if image.ndim == 3:\n                image_batch_size = 1\n                image_channels, image_height, image_width = image.shape\n            elif image.ndim == 4:\n                image_batch_size, image_channels, image_height, image_width = image.shape\n            else:\n                assert False\n\n            if mask_image.ndim == 2:\n                mask_image_batch_size = 1\n                mask_image_channels = 1\n                mask_image_height, mask_image_width = mask_image.shape\n            elif mask_image.ndim == 3:\n                mask_image_channels = 1\n                mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape\n            elif mask_image.ndim == 4:\n                mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape\n\n            if image_channels != 3:\n                raise ValueError(\"`image` must have 3 channels\")\n\n            if mask_image_channels != 1:\n                raise 
ValueError(\"`mask_image` must have 1 channel\")\n\n            if image_batch_size != mask_image_batch_size:\n                raise ValueError(\"`image` and `mask_image` must have the same batch size\")\n\n            if image_height != mask_image_height or image_width != mask_image_width:\n                raise ValueError(\"`image` and `mask_image` must have the same height and width dimensions\")\n\n            if image.min() < -1 or image.max() > 1:\n                raise ValueError(\"`image` should be in range [-1, 1]\")\n\n            if mask_image.min() < 0 or mask_image.max() > 1:\n                raise ValueError(\"`mask_image` should be in range [0, 1]\")\n        else:\n            mask_image_channels = 1\n            image_channels = 3\n\n        single_image_latent_channels = self.vae.config.latent_channels\n\n        total_latent_channels = single_image_latent_channels * 2 + mask_image_channels\n\n        if total_latent_channels != self.unet.config.in_channels:\n            raise ValueError(\n                f\"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received\"\n                f\" non inpainting latent channels: {single_image_latent_channels},\"\n                f\" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}.\"\n                f\" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs.\"\n            )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n   
             f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        return latents\n\n    def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))\n        mask_image = mask_image.to(device=device, dtype=dtype)\n\n        # duplicate mask for each generation per prompt, using mps friendly method\n        if mask_image.shape[0] < batch_size:\n            if not batch_size % mask_image.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. 
Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)\n\n        mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image\n\n        mask_image_latents = mask_image\n\n        return mask_image_latents\n\n    def prepare_masked_image_latents(\n        self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance\n    ):\n        masked_image = masked_image.to(device=device, dtype=dtype)\n\n        # encode the mask image into latents space so we can concatenate it to the latents\n        if isinstance(generator, list):\n            masked_image_latents = [\n                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(batch_size)\n            ]\n            masked_image_latents = torch.cat(masked_image_latents, dim=0)\n        else:\n            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)\n        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents\n\n        # duplicate masked_image_latents for each generation per prompt, using mps friendly method\n        if masked_image_latents.shape[0] < batch_size:\n            if not batch_size % masked_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)\n\n        masked_image_latents = (\n            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n        )\n\n        # aligning device to prevent device errors when concatenating it with the latent model input\n        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n        return masked_image_latents\n\n    def _default_height_width(self, height, width, image):\n        if isinstance(image, list):\n            image = image[0]\n\n        if height is None:\n            if isinstance(image, PIL.Image.Image):\n                height = image.height\n            elif isinstance(image, torch.Tensor):\n                height = image.shape[2]  # tensors are (B, C, H, W)\n\n            height = (height // 8) * 8  # round down to nearest multiple of 8\n\n        if width is None:\n            if isinstance(image, PIL.Image.Image):\n                width = image.width\n            elif isinstance(image, torch.Tensor):\n                width = image.shape[3]  # tensors are (B, C, H, W)\n\n            width = (width // 8) * 8  # round down to nearest multiple of 8\n\n        return height, width\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        controlnet_conditioning_image: Union[\n            torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]\n        ] = None,\n        height: 
Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            mask_image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. 
If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, 1, H, W)`.\n            controlnet_conditioning_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can\n                also be accepted as an image. The control image is automatically resized to fit the output image.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass `negative_prompt_embeds` instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the 
`safety_checker`.\n        \"\"\"\n        # 0. Default height and width to unet\n        height, width = self._default_height_width(height, width, controlnet_conditioning_image)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            image,\n            mask_image,\n            controlnet_conditioning_image,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            controlnet_conditioning_scale,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n\n        # 4. 
Prepare mask, image, and controlnet_conditioning_image\n        image = prepare_image(image)\n\n        mask_image = prepare_mask_image(mask_image)\n\n        # condition image(s)\n        if isinstance(self.controlnet, ControlNetModel):\n            controlnet_conditioning_image = prepare_controlnet_conditioning_image(\n                controlnet_conditioning_image=controlnet_conditioning_image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=self.controlnet.dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n            )\n        elif isinstance(self.controlnet, MultiControlNetModel):\n            controlnet_conditioning_images = []\n\n            for image_ in controlnet_conditioning_image:\n                image_ = prepare_controlnet_conditioning_image(\n                    controlnet_conditioning_image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=self.controlnet.dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                )\n                controlnet_conditioning_images.append(image_)\n\n            controlnet_conditioning_image = controlnet_conditioning_images\n        else:\n            assert False\n\n        masked_image = image * (mask_image < 0.5)\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 6. 
Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        mask_image_latents = self.prepare_mask_latents(\n            mask_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            do_classifier_free_guidance,\n        )\n\n        masked_image_latents = self.prepare_masked_image_latents(\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance,\n        )\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                non_inpainting_latent_model_input = (\n                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                )\n\n                non_inpainting_latent_model_input = self.scheduler.scale_model_input(\n                    non_inpainting_latent_model_input, t\n                )\n\n                inpainting_latent_model_input = torch.cat(\n                    [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1\n                )\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    non_inpainting_latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    controlnet_cond=controlnet_conditioning_image,\n                    conditioning_scale=controlnet_conditioning_scale,\n                    return_dict=False,\n                )\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    inpainting_latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                ).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n               
 # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if output_type == \"latent\":\n            image = latents\n            has_nsfw_concept = None\n        elif output_type == \"pil\":\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n            # 9. Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n            # 10. Convert to PIL\n            image = self.numpy_to_pil(image)\n        else:\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n            # 9. Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_controlnet_inpaint_img2img.py",
    "content": "# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, ControlNetModel, UNet2DConditionModel, logging\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import numpy as np\n        >>> import torch\n        >>> from PIL import Image\n        >>> from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline\n\n        >>> from transformers import AutoImageProcessor, UperNetForSemanticSegmentation\n        >>> from diffusers import ControlNetModel, UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n\n        >>> def ade_palette():\n                return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],\n                        [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],\n                        [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],\n                        [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],\n                        [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],\n                        [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],\n                        [255, 7, 71], [255, 9, 224], [9, 7, 230], 
[220, 220, 220],\n                        [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],\n                        [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],\n                        [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],\n                        [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],\n                        [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],\n                        [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],\n                        [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],\n                        [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],\n                        [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],\n                        [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],\n                        [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],\n                        [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],\n                        [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],\n                        [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],\n                        [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],\n                        [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],\n                        [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],\n                        [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],\n                        [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],\n                        [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],\n                        [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],\n                        [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],\n                        [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],\n                        [255, 204, 0], [255, 0, 
143], [0, 255, 235], [133, 255, 0],\n                        [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],\n                        [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],\n                        [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],\n                        [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],\n                        [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],\n                        [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],\n                        [102, 255, 0], [92, 0, 255]]\n\n        >>> image_processor = AutoImageProcessor.from_pretrained(\"openmmlab/upernet-convnext-small\")\n        >>> image_segmentor = UperNetForSemanticSegmentation.from_pretrained(\"openmmlab/upernet-convnext-small\")\n\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-seg\", torch_dtype=torch.float16)\n\n        >>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(\n                \"stable-diffusion-v1-5/stable-diffusion-inpainting\", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16\n            )\n\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>> pipe.enable_xformers_memory_efficient_attention()\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> def image_to_seg(image):\n                pixel_values = image_processor(image, return_tensors=\"pt\").pixel_values\n                with torch.no_grad():\n                    outputs = image_segmentor(pixel_values)\n                seg = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]\n                color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3\n                palette = np.array(ade_palette())\n                for label, color in enumerate(palette):\n                    color_seg[seg == label, :] = 
color\n                color_seg = color_seg.astype(np.uint8)\n                seg_image = Image.fromarray(color_seg)\n                return seg_image\n\n        >>> image = load_image(\n                \"https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\n            )\n\n        >>> mask_image = load_image(\n                \"https://github.com/CompVis/latent-diffusion/raw/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n            )\n\n        >>> controlnet_conditioning_image = image_to_seg(image)\n\n        >>> image = pipe(\n                \"Face of a yellow cat, high resolution, sitting on a park bench\",\n                image,\n                mask_image,\n                controlnet_conditioning_image,\n                num_inference_steps=20,\n            ).images[0]\n\n        >>> image.save(\"out.png\")\n        ```\n\"\"\"\n\n\ndef prepare_image(image):\n    if isinstance(image, torch.Tensor):\n        # Batch single image\n        if image.ndim == 3:\n            image = image.unsqueeze(0)\n\n        image = image.to(dtype=torch.float32)\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    return image\n\n\ndef prepare_mask_image(mask_image):\n    if isinstance(mask_image, torch.Tensor):\n        if mask_image.ndim == 2:\n            # Batch and add channel dim for single mask\n            mask_image = 
mask_image.unsqueeze(0).unsqueeze(0)\n        elif mask_image.ndim == 3 and mask_image.shape[0] == 1:\n            # Single mask, the 0'th dimension is considered to be\n            # the existing batch size of 1\n            mask_image = mask_image.unsqueeze(0)\n        elif mask_image.ndim == 3 and mask_image.shape[0] != 1:\n            # Batch of mask, the 0'th dimension is considered to be\n            # the batching dimension\n            mask_image = mask_image.unsqueeze(1)\n\n        # Binarize mask\n        mask_image[mask_image < 0.5] = 0\n        mask_image[mask_image >= 0.5] = 1\n    else:\n        # preprocess mask\n        if isinstance(mask_image, (PIL.Image.Image, np.ndarray)):\n            mask_image = [mask_image]\n\n        if isinstance(mask_image, list) and isinstance(mask_image[0], PIL.Image.Image):\n            mask_image = np.concatenate([np.array(m.convert(\"L\"))[None, None, :] for m in mask_image], axis=0)\n            mask_image = mask_image.astype(np.float32) / 255.0\n        elif isinstance(mask_image, list) and isinstance(mask_image[0], np.ndarray):\n            mask_image = np.concatenate([m[None, None, :] for m in mask_image], axis=0)\n\n        mask_image[mask_image < 0.5] = 0\n        mask_image[mask_image >= 0.5] = 1\n        mask_image = torch.from_numpy(mask_image)\n\n    return mask_image\n\n\ndef prepare_controlnet_conditioning_image(\n    controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype\n):\n    if not isinstance(controlnet_conditioning_image, torch.Tensor):\n        if isinstance(controlnet_conditioning_image, PIL.Image.Image):\n            controlnet_conditioning_image = [controlnet_conditioning_image]\n\n        if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):\n            controlnet_conditioning_image = [\n                np.array(i.resize((width, height), resample=PIL_INTERPOLATION[\"lanczos\"]))[None, :]\n                for i in 
controlnet_conditioning_image\n            ]\n            controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)\n            controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0\n            controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)\n            controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)\n        elif isinstance(controlnet_conditioning_image[0], torch.Tensor):\n            controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)\n\n    image_batch_size = controlnet_conditioning_image.shape[0]\n\n    if image_batch_size == 1:\n        repeat_by = batch_size\n    else:\n        # image batch size is the same as prompt batch size\n        repeat_by = num_images_per_prompt\n\n    controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)\n\n    controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)\n\n    return controlnet_conditioning_image\n\n\nclass StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin):\n    \"\"\"\n    Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: ControlNetModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} 
by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n      
          prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        bs_embed, seq_len, _ = 
prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = 
self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        mask_image,\n        controlnet_conditioning_image,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        strength=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            
raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)\n        controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)\n        controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(\n            controlnet_conditioning_image[0], PIL.Image.Image\n        )\n        controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(\n            controlnet_conditioning_image[0], torch.Tensor\n        )\n\n        if (\n            not controlnet_cond_image_is_pil\n            and not controlnet_cond_image_is_tensor\n            and not controlnet_cond_image_is_pil_list\n            and not controlnet_cond_image_is_tensor_list\n        ):\n            raise TypeError(\n                \"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors\"\n            )\n\n        if controlnet_cond_image_is_pil:\n            controlnet_cond_image_batch_size = 1\n        elif controlnet_cond_image_is_tensor:\n            controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]\n        elif controlnet_cond_image_is_pil_list:\n            controlnet_cond_image_batch_size = len(controlnet_conditioning_image)\n        elif controlnet_cond_image_is_tensor_list:\n            controlnet_cond_image_batch_size = 
len(controlnet_conditioning_image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n        if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):\n            raise TypeError(\"if `image` is a tensor, `mask_image` must also be a tensor\")\n\n        if isinstance(image, PIL.Image.Image) and not isinstance(mask_image, PIL.Image.Image):\n            raise TypeError(\"if `image` is a PIL image, `mask_image` must also be a PIL image\")\n\n        if isinstance(image, torch.Tensor):\n            if image.ndim != 3 and image.ndim != 4:\n                raise ValueError(\"`image` must have 3 or 4 dimensions\")\n\n            if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:\n                raise ValueError(\"`mask_image` must have 2, 3, or 4 dimensions\")\n\n            if image.ndim == 3:\n                image_batch_size = 1\n                image_channels, image_height, image_width = image.shape\n            elif image.ndim == 4:\n                image_batch_size, image_channels, image_height, image_width = image.shape\n\n            if mask_image.ndim == 2:\n                mask_image_batch_size = 1\n                mask_image_channels = 1\n                mask_image_height, mask_image_width = mask_image.shape\n            elif mask_image.ndim == 3:\n                mask_image_channels = 1\n                
mask_image_batch_size, mask_image_height, mask_image_width = mask_image.shape\n            elif mask_image.ndim == 4:\n                mask_image_batch_size, mask_image_channels, mask_image_height, mask_image_width = mask_image.shape\n\n            if image_channels != 3:\n                raise ValueError(\"`image` must have 3 channels\")\n\n            if mask_image_channels != 1:\n                raise ValueError(\"`mask_image` must have 1 channel\")\n\n            if image_batch_size != mask_image_batch_size:\n                raise ValueError(\"`image` and `mask_image` must have the same batch size\")\n\n            if image_height != mask_image_height or image_width != mask_image_width:\n                raise ValueError(\"`image` and `mask_image` must have the same height and width dimensions\")\n\n            if image.min() < -1 or image.max() > 1:\n                raise ValueError(\"`image` should be in range [-1, 1]\")\n\n            if mask_image.min() < 0 or mask_image.max() > 1:\n                raise ValueError(\"`mask_image` should be in range [0, 1]\")\n        else:\n            mask_image_channels = 1\n            image_channels = 3\n\n        single_image_latent_channels = self.vae.config.latent_channels\n\n        total_latent_channels = single_image_latent_channels * 2 + mask_image_channels\n\n        if total_latent_channels != self.unet.config.in_channels:\n            raise ValueError(\n                f\"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received\"\n                f\" non inpainting latent channels: {single_image_latent_channels},\"\n                f\" mask channels: {mask_image_channels}, and masked image channels: {single_image_latent_channels}.\"\n                f\" Please verify the config of `pipeline.unet` and the `mask_image` and `image` inputs.\"\n            )\n\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should be in [0.0, 1.0] but is 
{strength}\")\n\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start:]\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if isinstance(generator, list):\n            init_latents = [\n                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)\n            ]\n            init_latents = torch.cat(init_latents, dim=0)\n        else:\n            init_latents = self.vae.encode(image).latent_dist.sample(generator)\n\n        init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents\n\n    def prepare_mask_latents(self, mask_image, batch_size, height, width, dtype, device, do_classifier_free_guidance):\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask_image = F.interpolate(mask_image, size=(height // self.vae_scale_factor, width // self.vae_scale_factor))\n        mask_image = mask_image.to(device=device, dtype=dtype)\n\n        # duplicate mask for each generation per prompt, using mps friendly method\n        if mask_image.shape[0] < batch_size:\n            if not batch_size % mask_image.shape[0] == 0:\n                raise ValueError(\n                    \"The passed mask and the required batch size don't match. 
Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask_image.shape[0]} masks were passed. Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)\n\n        mask_image = torch.cat([mask_image] * 2) if do_classifier_free_guidance else mask_image\n\n        mask_image_latents = mask_image\n\n        return mask_image_latents\n\n    def prepare_masked_image_latents(\n        self, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance\n    ):\n        masked_image = masked_image.to(device=device, dtype=dtype)\n\n        # encode the mask image into latents space so we can concatenate it to the latents\n        if isinstance(generator, list):\n            masked_image_latents = [\n                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(batch_size)\n            ]\n            masked_image_latents = torch.cat(masked_image_latents, dim=0)\n        else:\n            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)\n        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents\n\n        # duplicate masked_image_latents for each generation per prompt, using mps friendly method\n        if masked_image_latents.shape[0] < batch_size:\n            if not batch_size % masked_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)\n\n        masked_image_latents = (\n            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n        )\n\n        # aligning device to prevent device errors when concatenating it with the latent model input\n        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n        return masked_image_latents\n\n    def _default_height_width(self, height, width, image):\n        if isinstance(image, list):\n            image = image[0]\n\n        if height is None:\n            if isinstance(image, PIL.Image.Image):\n                height = image.height\n            elif isinstance(image, torch.Tensor):\n                height = image.shape[2]\n\n            height = (height // 8) * 8  # round down to nearest multiple of 8\n\n        if width is None:\n            if isinstance(image, PIL.Image.Image):\n                width = image.width\n            elif isinstance(image, torch.Tensor):\n                width = image.shape[3]\n\n            width = (width // 8) * 8  # round down to nearest multiple of 8\n\n        return height, width\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        controlnet_conditioning_image: Union[\n            torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]\n        ] = None,\n        strength: float = 
0.8,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: float = 1.0,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            mask_image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. 
If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            controlnet_conditioning_image (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]` or `List[PIL.Image.Image]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can\n                also be accepted as an image. The control image is automatically resized to fit the output image.\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`\n                will be used as a starting point, adding more noise to it the larger the `strength`. The number of\n                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will\n                be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.\n            height (`int`, *optional*, defaults to the height of `controlnet_conditioning_image`, rounded down to a multiple of 8):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to the width of `controlnet_conditioning_image`, rounded down to a multiple of 8):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to those of the controlnet conditioning image\n        height, width = self._default_height_width(height, width, controlnet_conditioning_image)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            image,\n            mask_image,\n            controlnet_conditioning_image,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            strength,\n        )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n\n        # 4. Prepare mask, image, and controlnet_conditioning_image\n        image = prepare_image(image)\n\n        mask_image = prepare_mask_image(mask_image)\n\n        controlnet_conditioning_image = prepare_controlnet_conditioning_image(\n            controlnet_conditioning_image,\n            width,\n            height,\n            batch_size * num_images_per_prompt,\n            num_images_per_prompt,\n            device,\n            self.controlnet.dtype,\n        )\n\n        masked_image = image * (mask_image < 0.5)\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. 
Prepare latent variables\n        if latents is None:\n            latents = self.prepare_latents(\n                image,\n                latent_timestep,\n                batch_size,\n                num_images_per_prompt,\n                prompt_embeds.dtype,\n                device,\n                generator,\n            )\n\n        mask_image_latents = self.prepare_mask_latents(\n            mask_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            do_classifier_free_guidance,\n        )\n\n        masked_image_latents = self.prepare_masked_image_latents(\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance,\n        )\n\n        if do_classifier_free_guidance:\n            controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                non_inpainting_latent_model_input = (\n                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                )\n\n                non_inpainting_latent_model_input = self.scheduler.scale_model_input(\n                    non_inpainting_latent_model_input, t\n                )\n\n                inpainting_latent_model_input = torch.cat(\n                    [non_inpainting_latent_model_input, mask_image_latents, masked_image_latents], dim=1\n                )\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    non_inpainting_latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    controlnet_cond=controlnet_conditioning_image,\n                    return_dict=False,\n                )\n\n                down_block_res_samples = [\n                    down_block_res_sample * controlnet_conditioning_scale\n                    for down_block_res_sample in down_block_res_samples\n                ]\n                mid_block_res_sample *= controlnet_conditioning_scale\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    inpainting_latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                ).sample\n\n                # perform guidance\n                if 
do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if output_type == \"latent\":\n            image = latents\n            has_nsfw_concept = None\n        elif output_type == \"pil\":\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n            # 9. Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n            # 10. Convert to PIL\n            image = self.numpy_to_pil(image)\n        else:\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n            # 9. 
Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_controlnet_reference.py",
    "content": "# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\n\nfrom diffusers import StableDiffusionControlNetPipeline\nfrom diffusers.models import ControlNetModel\nfrom diffusers.models.attention import BasicTransformerBlock\nfrom diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import cv2\n        >>> import torch\n        >>> import numpy as np\n        >>> from PIL import Image\n        >>> from diffusers import UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n\n        >>> input_image = load_image(\"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\")\n\n        >>> # get canny image\n        >>> image = cv2.Canny(np.array(input_image), 100, 200)\n        >>> image = image[:, :, None]\n        >>> image = np.concatenate([image, image, image], axis=2)\n        >>> canny_image = Image.fromarray(image)\n\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\n        >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(\n                \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n                controlnet=controlnet,\n                safety_checker=None,\n                
torch_dtype=torch.float16\n                ).to('cuda:0')\n\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\n        >>> result_img = pipe(ref_image=input_image,\n                        prompt=\"1girl\",\n                        image=canny_image,\n                        num_inference_steps=20,\n                        reference_attn=True,\n                        reference_adain=True).images[0]\n\n        >>> result_img.show()\n        ```\n\"\"\"\n\n\ndef torch_dfs(model: torch.nn.Module):\n    result = [model]\n    for child in model.children():\n        result += torch_dfs(child)\n    return result\n\n\nclass StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline):\n    def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):\n        refimage = refimage.to(device=device, dtype=dtype)\n\n        # encode the reference image into latent space\n        if isinstance(generator, list):\n            ref_image_latents = [\n                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(batch_size)\n            ]\n            ref_image_latents = torch.cat(ref_image_latents, dim=0)\n        else:\n            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)\n        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents\n\n        # duplicate ref_image_latents for each generation per prompt, using mps friendly method\n        if ref_image_latents.shape[0] < batch_size:\n            if not batch_size % ref_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)\n\n        ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents\n\n        # aligning device to prevent device errors when concating it with the latent model input\n        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)\n        return ref_image_latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[\n            torch.Tensor,\n            PIL.Image.Image,\n            np.ndarray,\n            List[torch.Tensor],\n            List[PIL.Image.Image],\n            List[np.ndarray],\n        ] = None,\n        ref_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, 
List[float]] = 1.0,\n        guess_mode: bool = False,\n        attention_auto_machine_weight: float = 1.0,\n        gn_auto_machine_weight: float = 1.0,\n        style_fidelity: float = 0.5,\n        reference_attn: bool = True,\n        reference_adain: bool = True,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.\n                instead.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If\n                the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can\n                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If\n                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are\n                specified in init, images must be passed as a list such that each element of the list can be correctly\n                batched for input to a single controlnet.\n            ref_image (`torch.Tensor`, `PIL.Image.Image`):\n                The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If\n                the type is specified as `torch.Tensor`, it is passed to Reference Control as is. 
`PIL.Image.Image` can\n                also be accepted as an image.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generate image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the\n                corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                In this mode, the ControlNet encoder will try best to recognize the content of the input image even if\n                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.\n            attention_auto_machine_weight (`float`):\n                Weight of using reference query for self attention's context.\n                If attention_auto_machine_weight=1.0, use reference query for all self attention's context.\n            gn_auto_machine_weight (`float`):\n                Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.\n            style_fidelity (`float`):\n                style fidelity of ref_uncond_xt. 
If `style_fidelity=1.0`, the reference control is more important;\n                if `style_fidelity=0.0`, the prompt is more important; otherwise the two are balanced.\n            reference_attn (`bool`):\n                Whether to use reference query for self attention's context.\n            reference_adain (`bool`):\n                Whether to use reference adain.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        assert reference_attn or reference_adain, \"`reference_attn` or `reference_adain` must be True.\"\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            image,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            controlnet_conditioning_scale,\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. 
Prepare image\n        if isinstance(controlnet, ControlNetModel):\n            image = self.prepare_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            images = []\n\n            for image_ in image:\n                image_ = self.prepare_image(\n                    image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                images.append(image_)\n\n            image = images\n            height, width = image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. Preprocess reference image\n        ref_image = self.prepare_image(\n            image=ref_image,\n            width=width,\n            height=height,\n            batch_size=batch_size * num_images_per_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            device=device,\n            dtype=prompt_embeds.dtype,\n        )\n\n        # 6. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 7. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 8. Prepare reference latent variables\n        ref_image_latents = self.prepare_ref_latents(\n            ref_image,\n            batch_size * num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance,\n        )\n\n        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 10. Modify self attention and group norm\n        MODE = \"write\"\n        uc_mask = (\n            torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)\n            .type_as(ref_image_latents)\n            .bool()\n        )\n\n        def hacked_basic_transformer_inner_forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n            timestep: Optional[torch.LongTensor] = None,\n            cross_attention_kwargs: Dict[str, Any] = None,\n            class_labels: Optional[torch.LongTensor] = None,\n        ):\n            if self.use_ada_layer_norm:\n                norm_hidden_states = self.norm1(hidden_states, timestep)\n            elif self.use_ada_layer_norm_zero:\n                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype\n         
       )\n            else:\n                norm_hidden_states = self.norm1(hidden_states)\n\n            # 1. Self-Attention\n            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n            if self.only_cross_attention:\n                attn_output = self.attn1(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                    attention_mask=attention_mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                if MODE == \"write\":\n                    self.bank.append(norm_hidden_states.detach().clone())\n                    attn_output = self.attn1(\n                        norm_hidden_states,\n                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                        attention_mask=attention_mask,\n                        **cross_attention_kwargs,\n                    )\n                if MODE == \"read\":\n                    if attention_auto_machine_weight > self.attn_weight:\n                        attn_output_uc = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),\n                            # attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n                        attn_output_c = attn_output_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            attn_output_c[uc_mask] = self.attn1(\n                                norm_hidden_states[uc_mask],\n                                encoder_hidden_states=norm_hidden_states[uc_mask],\n                                **cross_attention_kwargs,\n                            )\n                        attn_output = 
style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc\n                        self.bank.clear()\n                    else:\n                        attn_output = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                            attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n            if self.use_ada_layer_norm_zero:\n                attn_output = gate_msa.unsqueeze(1) * attn_output\n            hidden_states = attn_output + hidden_states\n\n            if self.attn2 is not None:\n                norm_hidden_states = (\n                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n                )\n\n                # 2. Cross-Attention\n                attn_output = self.attn2(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=encoder_attention_mask,\n                    **cross_attention_kwargs,\n                )\n                hidden_states = attn_output + hidden_states\n\n            # 3. 
Feed-forward\n            norm_hidden_states = self.norm3(hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n            ff_output = self.ff(norm_hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n            hidden_states = ff_output + hidden_states\n\n            return hidden_states\n\n        def hacked_mid_forward(self, *args, **kwargs):\n            eps = 1e-6\n            x = self.original_forward(*args, **kwargs)\n            if MODE == \"write\":\n                if gn_auto_machine_weight >= self.gn_weight:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    self.mean_bank.append(mean)\n                    self.var_bank.append(var)\n            if MODE == \"read\":\n                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))\n                    var_acc = sum(self.var_bank) / float(len(self.var_bank))\n                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                    x_uc = (((x - mean) / std) * std_acc) + mean_acc\n                    x_c = x_uc.clone()\n                    if do_classifier_free_guidance and style_fidelity > 0:\n                        x_c[uc_mask] = x[uc_mask]\n                    x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc\n                self.mean_bank = []\n                self.var_bank = []\n            return x\n\n        def hack_CrossAttnDownBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            temb: 
Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n        ):\n            eps = 1e-6\n\n            # TODO(Patrick, William) - attention mask is not used\n            output_states = ()\n\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        
hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):\n            eps = 1e-6\n\n            output_states = ()\n\n            for i, resnet in enumerate(self.resnets):\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc 
= (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_CrossAttnUpBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            res_hidden_states_tuple: Tuple[torch.Tensor, ...],\n            temb: Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            upsample_size: Optional[int] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n        ):\n            eps = 1e-6\n            # TODO(Patrick, William) - attention mask is not used\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n      
              encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        def 
hacked_UpBlock2D_forward(\n            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs\n        ):\n            eps = 1e-6\n            for i, resnet in enumerate(self.resnets):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                
self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        if reference_attn:\n            attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]\n            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])\n\n            for i, module in enumerate(attn_modules):\n                module._original_inner_forward = module.forward\n                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)\n                module.bank = []\n                module.attn_weight = float(i) / float(len(attn_modules))\n\n        if reference_adain:\n            gn_modules = [self.unet.mid_block]\n            self.unet.mid_block.gn_weight = 0\n\n            down_blocks = self.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))\n                gn_modules.append(module)\n\n            up_blocks = self.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                module.gn_weight = float(w) / float(len(up_blocks))\n                gn_modules.append(module)\n\n            for i, module in enumerate(gn_modules):\n                if getattr(module, \"original_forward\", None) is None:\n                    module.original_forward = module.forward\n                if i == 0:\n                    # mid_block\n                    module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)\n                elif isinstance(module, CrossAttnDownBlock2D):\n                    module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)\n                elif isinstance(module, DownBlock2D):\n                    module.forward = 
hacked_DownBlock2D_forward.__get__(module, DownBlock2D)\n                elif isinstance(module, CrossAttnUpBlock2D):\n                    module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)\n                elif isinstance(module, UpBlock2D):\n                    module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)\n                module.mean_bank = []\n                module.var_bank = []\n                module.gn_weight *= 2\n\n        # 11. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # controlnet(s) inference\n                if guess_mode and do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=image,\n                    conditioning_scale=controlnet_conditioning_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                
)\n\n                if guess_mode and do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # ref only part\n                noise = randn_tensor(\n                    ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype\n                )\n                ref_xt = self.scheduler.add_noise(\n                    ref_image_latents,\n                    noise,\n                    t.reshape(\n                        1,\n                    ),\n                )\n                ref_xt = self.scheduler.scale_model_input(ref_xt, t)\n\n                MODE = \"write\"\n                self.unet(\n                    ref_xt,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )\n\n                # predict the noise residual\n                MODE = \"read\"\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_ipex.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport intel_extension_for_pytorch as ipex\nimport torch\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    deprecate,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import DiffusionPipeline\n\n        >>> pipe = DiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", custom_pipeline=\"stable_diffusion_ipex\")\n\n        >>> prompt = \"a photo of an astronaut riding a horse on mars\"\n        >>> # For Float32\n        >>> pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value of image height/width should be consistent with the pipeline inference\n        >>> # For BFloat16\n        >>> pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) #value of image height/width should be consistent with the pipeline inference\n\n        >>> # For Float32\n        >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'\n        >>> # For BFloat16\n        >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n        >>>     image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()'\n        ```\n\"\"\"\n\n\nclass StableDiffusionIPEXPipeline(\n    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion on IPEX.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. 
`steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):\n        prompt_embeds = None\n        negative_prompt_embeds = None\n        negative_prompt = None\n        callback_steps = 1\n        generator = None\n        latents = None\n\n        # 0. 
Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n\n        device = \"cpu\"\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n\n        # 5. 
Prepare latent variables\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            self.unet.config.in_channels,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n        dummy = torch.ones(1, dtype=torch.int32)\n        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n        latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy)\n\n        unet_input_example = (latent_model_input, dummy, prompt_embeds)\n        vae_decoder_input_example = latents\n\n        return unet_input_example, vae_decoder_input_example\n\n    def prepare_for_ipex(self, prompt, dtype=torch.float32, height=None, width=None, guidance_scale=7.5):\n        self.unet = self.unet.to(memory_format=torch.channels_last)\n        self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last)\n        self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last)\n        if self.safety_checker is not None:\n            self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last)\n\n        unet_input_example, vae_decoder_input_example = self.get_input_example(prompt, height, width, guidance_scale)\n\n        # optimize with ipex\n        if dtype == torch.bfloat16:\n            self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True)\n            self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)\n            self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)\n            if self.safety_checker is not None:\n                self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)\n        elif dtype == torch.float32:\n            self.unet = ipex.optimize(\n                self.unet.eval(),\n     
           dtype=torch.float32,\n                inplace=True,\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n            self.vae.decoder = ipex.optimize(\n                self.vae.decoder.eval(),\n                dtype=torch.float32,\n                inplace=True,\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n            self.text_encoder = ipex.optimize(\n                self.text_encoder.eval(),\n                dtype=torch.float32,\n                inplace=True,\n                weights_prepack=True,\n                auto_kernel_selection=False,\n            )\n            if self.safety_checker is not None:\n                self.safety_checker = ipex.optimize(\n                    self.safety_checker.eval(),\n                    dtype=torch.float32,\n                    inplace=True,\n                    weights_prepack=True,\n                    auto_kernel_selection=False,\n                )\n        else:\n            raise ValueError(\" The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32' !\")\n\n        # trace unet model to get better performance on IPEX\n        with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():\n            unet_trace_model = torch.jit.trace(self.unet, unet_input_example, check_trace=False, strict=False)\n            unet_trace_model = torch.jit.freeze(unet_trace_model)\n        self.unet.forward = unet_trace_model.forward\n\n        # trace vae.decoder model to get better performance on IPEX\n        with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad():\n            ave_decoder_trace_model = torch.jit.trace(\n                self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False\n            )\n            ave_decoder_trace_model = torch.jit.freeze(ave_decoder_trace_model)\n        self.vae.decoder.forward = 
ave_decoder_trace_model.forward\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                
attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        
negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`\n                instead.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` in equation 2 of the [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the\n                text `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds\n        )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)[\"sample\"]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if output_type == \"latent\":\n            image = latents\n            has_nsfw_concept = None\n        elif output_type == \"pil\":\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n            # 9. Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n            # 10. 
Convert to PIL\n            image = self.numpy_to_pil(image)\n        else:\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n\n            # 9. Run safety checker\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_mega.py",
    "content": "from typing import Any, Callable, List, Optional, Union\n\nimport PIL.Image\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DiffusionPipeline,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    StableDiffusionImg2ImgPipeline,\n    StableDiffusionInpaintPipelineLegacy,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.utils import deprecate, logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass StableDiffusionMegaPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: 
{scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    @property\n    def components(self) -> dict[str, Any]:\n        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith(\"_\")}\n\n    @torch.no_grad()\n    def inpaint(\n        self,\n        prompt: Union[str, List[str]],\n        image: Union[torch.Tensor, PIL.Image.Image],\n        mask_image: Union[torch.Tensor, PIL.Image.Image],\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: torch.Generator | None = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], 
None]] = None,\n        callback_steps: int = 1,\n    ):\n        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline\n        return StableDiffusionInpaintPipelineLegacy(**self.components)(\n            prompt=prompt,\n            image=image,\n            mask_image=mask_image,\n            strength=strength,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n        )\n\n    @torch.no_grad()\n    def img2img(\n        self,\n        prompt: Union[str, List[str]],\n        image: Union[torch.Tensor, PIL.Image.Image],\n        strength: float = 0.8,\n        num_inference_steps: Optional[int] = 50,\n        guidance_scale: Optional[float] = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: torch.Generator | None = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline\n        return StableDiffusionImg2ImgPipeline(**self.components)(\n            prompt=prompt,\n            image=image,\n            strength=strength,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n     
       num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n        )\n\n    @torch.no_grad()\n    def text2img(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n    ):\n        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline\n        return StableDiffusionPipeline(**self.components)(\n            prompt=prompt,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n        )\n"
  },
  {
    "path": "examples/community/stable_diffusion_reference.py",
    "content": "# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel\nfrom diffusers.configuration_utils import FrozenDict, deprecate\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models.attention import BasicTransformerBlock\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    USE_PEFT_BACKEND,\n    logging,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n\n        >>> input_image = load_image(\"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\")\n\n        >>> pipe = 
StableDiffusionReferencePipeline.from_pretrained(\n                \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n                safety_checker=None,\n                torch_dtype=torch.float16\n                ).to('cuda:0')\n\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\n        >>> result_img = pipe(ref_image=input_image,\n                        prompt=\"1girl\",\n                        num_inference_steps=20,\n                        reference_attn=True,\n                        reference_adain=True).images[0]\n\n        >>> result_img.show()\n        ```\n\"\"\"\n\n\ndef torch_dfs(model: torch.nn.Module):\n    r\"\"\"\n    Performs a depth-first search on the given PyTorch model and returns a list of all its child modules.\n\n    Args:\n        model (torch.nn.Module): The PyTorch model to perform the depth-first search on.\n\n    Returns:\n        list: A list of all child modules of the given model.\n    \"\"\"\n    result = [model]\n    for child in model.children():\n        result += torch_dfs(child)\n    return result\n\n\nclass StableDiffusionReferencePipeline(\n    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin\n):\n    r\"\"\"\n    Pipeline for Stable Diffusion Reference.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n    - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n    - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n    - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n    - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n    - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might led to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"skip_prk_steps\", True) is False:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration\"\n                \" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make\"\n                \" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to\"\n                \" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face\"\n                \" Hub, it would be very nice if you could open a Pull request for the\"\n                \" `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\n                \"skip_prk_steps not set\",\n                \"1.0.0\",\n                deprecation_message,\n                standard_warn=False,\n            )\n            new_config = dict(scheduler.config)\n            new_config[\"skip_prk_steps\"] = True\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                \"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- 
stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n        # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4\n        if unet is not None and unet.config.in_channels != 4:\n            logger.warning(\n                f\"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,\"\n                f\" {self.__class__} assumes that `pipeline.unet` has 4 input channels (4 for `num_channels_latents`).\"\n                \" 
If you did not intend to modify\"\n                \" this behavior, please check whether you have loaded the right checkpoint.\"\n            )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def _default_height_width(\n        self,\n        height: Optional[int],\n        width: Optional[int],\n        image: Union[PIL.Image.Image, torch.Tensor, List[PIL.Image.Image]],\n    ) -> Tuple[int, int]:\n        r\"\"\"\n        Calculate the default height and width for the given image.\n\n        Args:\n            height (int or None): The desired height of the image. If None, the height will be determined based on the input image.\n            width (int or None): The desired width of the image. 
If None, the width will be determined based on the input image.\n            image (PIL.Image.Image or torch.Tensor or list[PIL.Image.Image]): The input image or a list of images.\n\n        Returns:\n            Tuple[int, int]: A tuple containing the calculated height and width.\n\n        \"\"\"\n        # NOTE: It is possible that a list of images have different\n        # dimensions for each image, so just checking the first image\n        # is not _exactly_ correct, but it is simple.\n        while isinstance(image, list):\n            image = image[0]\n\n        if height is None:\n            if isinstance(image, PIL.Image.Image):\n                height = image.height\n            elif isinstance(image, torch.Tensor):\n                height = image.shape[2]\n\n            height = (height // 8) * 8  # round down to nearest multiple of 8\n\n        if width is None:\n            if isinstance(image, PIL.Image.Image):\n                width = image.width\n            elif isinstance(image, torch.Tensor):\n                width = image.shape[3]\n\n            width = (width // 8) * 8  # round down to nearest multiple of 8\n\n        return height, width\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt: Optional[Union[str, List[str]]],\n        height: int,\n        width: int,\n        callback_steps: Optional[int],\n        negative_prompt: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[torch.Tensor] = None,\n        ip_adapter_image_embeds: Optional[torch.Tensor] = None,\n        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,\n    ) -> None:\n        \"\"\"\n        Check the validity of the input arguments for the diffusion model.\n\n        Args:\n            prompt (Optional[Union[str, 
List[str]]]): The prompt text or list of prompt texts.\n            height (int): The height of the input image.\n            width (int): The width of the input image.\n            callback_steps (Optional[int]): The number of steps to perform the callback on.\n            negative_prompt (str | None): The negative prompt text.\n            prompt_embeds (Optional[torch.Tensor]): The prompt embeddings.\n            negative_prompt_embeds (Optional[torch.Tensor]): The negative prompt embeddings.\n            ip_adapter_image (Optional[torch.Tensor]): The input adapter image.\n            ip_adapter_image_embeds (Optional[torch.Tensor]): The input adapter image embeddings.\n            callback_on_step_end_tensor_inputs (Optional[List[str]]): The list of tensor inputs to perform the callback on.\n\n        Raises:\n            ValueError: If `height` or `width` is not divisible by 8.\n            ValueError: If `callback_steps` is not a positive integer.\n            ValueError: If `callback_on_step_end_tensor_inputs` contains invalid tensor inputs.\n            ValueError: If both `prompt` and `prompt_embeds` are provided.\n            ValueError: If neither `prompt` nor `prompt_embeds` are provided.\n            ValueError: If `prompt` is not of type `str` or `list`.\n            ValueError: If both `negative_prompt` and `negative_prompt_embeds` are provided.\n            ValueError: If both `prompt_embeds` and `negative_prompt_embeds` are provided and have different shapes.\n            ValueError: If both `ip_adapter_image` and `ip_adapter_image_embeds` are provided.\n\n        Returns:\n            None\n        \"\"\"\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a 
positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        device: torch.device,\n        num_images_per_prompt: int,\n        do_classifier_free_guidance: bool,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Encodes the prompt into embeddings.\n\n        Args:\n            prompt (Union[str, List[str]]): The prompt text or a list of prompt texts.\n            device (torch.device): The device to use for encoding.\n            num_images_per_prompt (int): The number of images per prompt.\n            do_classifier_free_guidance (bool): Whether to use classifier-free guidance.\n            negative_prompt (Optional[Union[str, List[str]]], optional): The negative prompt text or a list of negative prompt texts. 
Defaults to None.\n            prompt_embeds (Optional[torch.Tensor], optional): The prompt embeddings. Defaults to None.\n            negative_prompt_embeds (Optional[torch.Tensor], optional): The negative prompt embeddings. Defaults to None.\n            lora_scale (Optional[float], optional): The LoRA scale. Defaults to None.\n            **kwargs: Additional keyword arguments.\n\n        Returns:\n            torch.Tensor: The encoded prompt embeddings.\n        \"\"\"\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: str | None,\n        device: torch.device,\n        num_images_per_prompt: int,\n        do_classifier_free_guidance: bool,\n        negative_prompt: str | None = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ) -> torch.Tensor:\n      
  r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size: int,\n        num_channels_latents: int,\n        height: int,\n        width: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        generator: Union[torch.Generator, List[torch.Generator]],\n        latents: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Prepare the latent vectors for diffusion.\n\n        Args:\n            batch_size (int): The number of samples in the batch.\n            num_channels_latents (int): The number of channels in the latent vectors.\n            height (int): The height of the latent vectors.\n            width (int): The width of the latent vectors.\n            dtype (torch.dtype): The data type of the latent vectors.\n            device (torch.device): The device to place the latent vectors on.\n            generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation.\n            latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated.\n\n        Returns:\n            torch.Tensor: The prepared latent vectors.\n        \"\"\"\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(\n        self, generator: Union[torch.Generator, List[torch.Generator]], eta: float\n    ) -> dict[str, Any]:\n        r\"\"\"\n        Prepare extra keyword arguments for the scheduler step.\n\n        Args:\n            generator (Union[torch.Generator, List[torch.Generator]]): The generator used for sampling.\n            eta (float): The value of eta (η) used with the DDIMScheduler. Should be between 0 and 1.\n\n        Returns:\n            Dict[str, Any]: A dictionary containing the extra keyword arguments for the scheduler step.\n        \"\"\"\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def 
prepare_image(\n        self,\n        image: Union[torch.Tensor, PIL.Image.Image, List[Union[torch.Tensor, PIL.Image.Image]]],\n        width: int,\n        height: int,\n        batch_size: int,\n        num_images_per_prompt: int,\n        device: torch.device,\n        dtype: torch.dtype,\n        do_classifier_free_guidance: bool = False,\n        guess_mode: bool = False,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Prepares the input image for processing.\n\n        Args:\n            image (torch.Tensor or PIL.Image.Image or list): The input image(s).\n            width (int): The desired width of the image.\n            height (int): The desired height of the image.\n            batch_size (int): The batch size for processing.\n            num_images_per_prompt (int): The number of images per prompt.\n            device (torch.device): The device to use for processing.\n            dtype (torch.dtype): The data type of the image.\n            do_classifier_free_guidance (bool, optional): Whether to perform classifier-free guidance. Defaults to False.\n            guess_mode (bool, optional): Whether to use guess mode. 
Defaults to False.\n\n        Returns:\n            torch.Tensor: The prepared image for processing.\n        \"\"\"\n        if not isinstance(image, torch.Tensor):\n            if isinstance(image, PIL.Image.Image):\n                image = [image]\n\n            if isinstance(image[0], PIL.Image.Image):\n                images = []\n\n                for image_ in image:\n                    image_ = image_.convert(\"RGB\")\n                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION[\"lanczos\"])\n                    image_ = np.array(image_)\n                    image_ = image_[None, :]\n                    images.append(image_)\n\n                image = images\n\n                image = np.concatenate(image, axis=0)\n                image = np.array(image).astype(np.float32) / 255.0\n                image = (image - 0.5) / 0.5\n                image = image.transpose(0, 3, 1, 2)\n                image = torch.from_numpy(image)\n            elif isinstance(image[0], torch.Tensor):\n                image = torch.cat(image, dim=0)\n\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    def prepare_ref_latents(\n        self,\n        refimage: torch.Tensor,\n        batch_size: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        generator: Union[int, List[int]],\n        do_classifier_free_guidance: bool,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Prepares reference latents for generating images.\n\n        Args:\n            refimage (torch.Tensor): The reference 
image.\n            batch_size (int): The desired batch size.\n            dtype (torch.dtype): The data type of the tensors.\n            device (torch.device): The device to perform computations on.\n            generator (torch.Generator or list): The random generator, or a list with one generator per batch element, passed to the VAE latent sampling.\n            do_classifier_free_guidance (bool): Whether to use classifier-free guidance.\n\n        Returns:\n            torch.Tensor: The prepared reference latents.\n        \"\"\"\n        refimage = refimage.to(device=device, dtype=dtype)\n\n        # encode the reference image into latent space\n        if isinstance(generator, list):\n            ref_image_latents = [\n                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(batch_size)\n            ]\n            ref_image_latents = torch.cat(ref_image_latents, dim=0)\n        else:\n            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)\n        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents\n\n        # duplicate ref_image_latents for each generation per prompt, using mps friendly method\n        if ref_image_latents.shape[0] < batch_size:\n            if not batch_size % ref_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)\n\n        # aligning device to prevent device errors when concating it with the latent model input\n        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)\n        return ref_image_latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(\n        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype\n    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:\n        r\"\"\"\n        Runs the safety checker on the given image.\n\n        Args:\n            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.\n            device (torch.device): The device to run the safety checker on.\n            dtype (torch.dtype): The data type of the input image.\n\n        Returns:\n            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and\n            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.\n        \"\"\"\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n   
         image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        ref_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        attention_auto_machine_weight: float = 1.0,\n        gn_auto_machine_weight: float = 1.0,\n        style_fidelity: float = 0.5,\n        reference_attn: bool = True,\n        reference_adain: bool = True,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.\n                instead.\n            ref_image (`torch.Tensor`, `PIL.Image.Image`):\n                The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. 
If\n                the type is specified as `torch.Tensor`, it is passed to Reference Control as is. `PIL.Image.Image` can\n                also be accepted as an image.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. 
The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16 of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            attention_auto_machine_weight (`float`):\n                Weight of using reference query for self attention's context.\n                If attention_auto_machine_weight=1.0, use reference query for all self attention's context.\n            gn_auto_machine_weight (`float`):\n                Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.\n            style_fidelity (`float`):\n                style fidelity of ref_uncond_xt. 
If style_fidelity=1.0, the reference (control) is more important;\n                if style_fidelity=0.0, the prompt is more important; values in between balance the two.\n            reference_attn (`bool`):\n                Whether to use reference query for self attention's context.\n            reference_adain (`bool`):\n                Whether to use reference adain.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        assert reference_attn or reference_adain, \"`reference_attn` or `reference_adain` must be True.\"\n\n        # 0. Default height and width to unet\n        height, width = self._default_height_width(height, width, ref_image)\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds\n        )\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            cross_attention_kwargs.get(\"scale\", None) if cross_attention_kwargs is not None else None\n        )\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n        )\n\n        # 4. Preprocess reference image\n        ref_image = self.prepare_image(\n            image=ref_image,\n            width=width,\n            height=height,\n            batch_size=batch_size * num_images_per_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            device=device,\n            dtype=prompt_embeds.dtype,\n        )\n\n        # 5. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 7. Prepare reference latent variables\n        ref_image_latents = self.prepare_ref_latents(\n            ref_image,\n            batch_size * num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance,\n        )\n\n        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 9. 
Modify self attention and group norm\n        MODE = \"write\"\n        uc_mask = (\n            torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)\n            .type_as(ref_image_latents)\n            .bool()\n        )\n\n        def hacked_basic_transformer_inner_forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n            timestep: Optional[torch.LongTensor] = None,\n            cross_attention_kwargs: Dict[str, Any] = None,\n            class_labels: Optional[torch.LongTensor] = None,\n        ):\n            if self.use_ada_layer_norm:\n                norm_hidden_states = self.norm1(hidden_states, timestep)\n            elif self.use_ada_layer_norm_zero:\n                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                    hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype\n                )\n            else:\n                norm_hidden_states = self.norm1(hidden_states)\n\n            # 1. 
Self-Attention\n            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n            if self.only_cross_attention:\n                attn_output = self.attn1(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                    attention_mask=attention_mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                if MODE == \"write\":\n                    self.bank.append(norm_hidden_states.detach().clone())\n                    attn_output = self.attn1(\n                        norm_hidden_states,\n                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                        attention_mask=attention_mask,\n                        **cross_attention_kwargs,\n                    )\n                if MODE == \"read\":\n                    if attention_auto_machine_weight > self.attn_weight:\n                        attn_output_uc = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),\n                            # attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n                        attn_output_c = attn_output_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            attn_output_c[uc_mask] = self.attn1(\n                                norm_hidden_states[uc_mask],\n                                encoder_hidden_states=norm_hidden_states[uc_mask],\n                                **cross_attention_kwargs,\n                            )\n                        attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc\n                        
self.bank.clear()\n                    else:\n                        attn_output = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                            attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n            if self.use_ada_layer_norm_zero:\n                attn_output = gate_msa.unsqueeze(1) * attn_output\n            hidden_states = attn_output + hidden_states\n\n            if self.attn2 is not None:\n                norm_hidden_states = (\n                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n                )\n\n                # 2. Cross-Attention\n                attn_output = self.attn2(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=encoder_attention_mask,\n                    **cross_attention_kwargs,\n                )\n                hidden_states = attn_output + hidden_states\n\n            # 3. 
Feed-forward\n            norm_hidden_states = self.norm3(hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n            ff_output = self.ff(norm_hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n            hidden_states = ff_output + hidden_states\n\n            return hidden_states\n\n        def hacked_mid_forward(self, *args, **kwargs):\n            eps = 1e-6\n            x = self.original_forward(*args, **kwargs)\n            if MODE == \"write\":\n                if gn_auto_machine_weight >= self.gn_weight:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    self.mean_bank.append(mean)\n                    self.var_bank.append(var)\n            if MODE == \"read\":\n                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))\n                    var_acc = sum(self.var_bank) / float(len(self.var_bank))\n                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                    x_uc = (((x - mean) / std) * std_acc) + mean_acc\n                    x_c = x_uc.clone()\n                    if do_classifier_free_guidance and style_fidelity > 0:\n                        x_c[uc_mask] = x[uc_mask]\n                    x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc\n                self.mean_bank = []\n                self.var_bank = []\n            return x\n\n        def hack_CrossAttnDownBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            temb: 
Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n        ):\n            eps = 1e-6\n\n            # TODO(Patrick, William) - attention mask is not used\n            output_states = ()\n\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        
hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_DownBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            temb: Optional[torch.Tensor] = None,\n            **kwargs: Any,\n        ) -> Tuple[torch.Tensor, ...]:\n            eps = 1e-6\n\n            output_states = ()\n\n            for i, resnet in enumerate(self.resnets):\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n         
               std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_CrossAttnUpBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            res_hidden_states_tuple: Tuple[torch.Tensor, ...],\n            temb: Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            upsample_size: Optional[int] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n        ) -> torch.Tensor:\n            eps = 1e-6\n            # TODO(Patrick, William) - attention mask is not used\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], 
dim=1)\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in 
self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        def hacked_UpBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            res_hidden_states_tuple: Tuple[torch.Tensor, ...],\n            temb: Optional[torch.Tensor] = None,\n            upsample_size: Optional[int] = None,\n            **kwargs: Any,\n        ) -> torch.Tensor:\n            eps = 1e-6\n            for i, resnet in enumerate(self.resnets):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and 
style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        if reference_attn:\n            attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]\n            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])\n\n            for i, module in enumerate(attn_modules):\n                module._original_inner_forward = module.forward\n                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)\n                module.bank = []\n                module.attn_weight = float(i) / float(len(attn_modules))\n\n        if reference_adain:\n            gn_modules = [self.unet.mid_block]\n            self.unet.mid_block.gn_weight = 0\n\n            down_blocks = self.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))\n                gn_modules.append(module)\n\n            up_blocks = self.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                module.gn_weight = float(w) / float(len(up_blocks))\n                gn_modules.append(module)\n\n            for i, module in enumerate(gn_modules):\n                if getattr(module, \"original_forward\", None) is None:\n                    module.original_forward = module.forward\n                if i == 0:\n                    # mid_block\n                    module.forward = 
hacked_mid_forward.__get__(module, torch.nn.Module)\n                elif isinstance(module, CrossAttnDownBlock2D):\n                    module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)\n                elif isinstance(module, DownBlock2D):\n                    module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D)\n                elif isinstance(module, CrossAttnUpBlock2D):\n                    module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)\n                elif isinstance(module, UpBlock2D):\n                    module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)\n                module.mean_bank = []\n                module.var_bank = []\n                module.gn_weight *= 2\n\n        # 10. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # ref only part\n                noise = randn_tensor(\n                    ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype\n                )\n                ref_xt = self.scheduler.add_noise(\n                    ref_image_latents,\n                    noise,\n                    t.reshape(\n                        1,\n                    ),\n                )\n                ref_xt = torch.cat([ref_xt] * 2) if do_classifier_free_guidance else ref_xt\n                ref_xt = self.scheduler.scale_model_input(ref_xt, t)\n\n                MODE = \"write\"\n                self.unet(\n                 
   ref_xt,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )\n\n                # predict the noise residual\n                MODE = \"read\"\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if do_classifier_free_guidance and guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n       
 else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_repaint.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel\nfrom diffusers.configuration_utils import FrozenDict, deprecate\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import (\n    StableDiffusionSafetyChecker,\n)\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    logging,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef prepare_mask_and_masked_image(image, mask):\n    \"\"\"\n    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. 
This means that those inputs will be\n    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the\n    ``image`` and ``1`` for the ``mask``.\n    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be\n    binarized (``mask >= 0.5``) and cast to ``torch.float32`` too.\n    Args:\n        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``\n            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.\n        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.\n            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``\n            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.\n    Raises:\n        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask\n        should be in the ``[0, 1]`` range. 
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.\n        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not\n            (or the other way around).\n    Returns:\n        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4\n            dimensions: ``batch x channels x height x width``.\n    \"\"\"\n    if isinstance(image, torch.Tensor):\n        if not isinstance(mask, torch.Tensor):\n            raise TypeError(f\"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not\")\n\n        # Batch single image\n        if image.ndim == 3:\n            assert image.shape[0] == 3, \"Image outside a batch should be of shape (3, H, W)\"\n            image = image.unsqueeze(0)\n\n        # Batch and add channel dim for single mask\n        if mask.ndim == 2:\n            mask = mask.unsqueeze(0).unsqueeze(0)\n\n        # Batch single mask or add channel dim\n        if mask.ndim == 3:\n            # Single batched mask, no channel dim or single mask not batched but channel dim\n            if mask.shape[0] == 1:\n                mask = mask.unsqueeze(0)\n\n            # Batched masks no channel dim\n            else:\n                mask = mask.unsqueeze(1)\n\n        assert image.ndim == 4 and mask.ndim == 4, \"Image and Mask must have 4 dimensions\"\n        assert image.shape[-2:] == mask.shape[-2:], \"Image and Mask must have the same spatial dimensions\"\n        assert image.shape[0] == mask.shape[0], \"Image and Mask must have the same batch size\"\n\n        # Check image is in [-1, 1]\n        if image.min() < -1 or image.max() > 1:\n            raise ValueError(\"Image should be in [-1, 1] range\")\n\n        # Check mask is in [0, 1]\n        if mask.min() < 0 or mask.max() > 1:\n            raise ValueError(\"Mask should be in [0, 1] range\")\n\n        # Binarize mask\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n\n        # Image as float32\n        image = 
image.to(dtype=torch.float32)\n    elif isinstance(mask, torch.Tensor):\n        raise TypeError(f\"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not\")\n    else:\n        # preprocess image\n        if isinstance(image, (PIL.Image.Image, np.ndarray)):\n            image = [image]\n\n        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):\n            image = [np.array(i.convert(\"RGB\"))[None, :] for i in image]\n            image = np.concatenate(image, axis=0)\n        elif isinstance(image, list) and isinstance(image[0], np.ndarray):\n            image = np.concatenate([i[None, :] for i in image], axis=0)\n\n        image = image.transpose(0, 3, 1, 2)\n        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n        # preprocess mask\n        if isinstance(mask, (PIL.Image.Image, np.ndarray)):\n            mask = [mask]\n\n        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):\n            mask = np.concatenate([np.array(m.convert(\"L\"))[None, None, :] for m in mask], axis=0)\n            mask = mask.astype(np.float32) / 255.0\n        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):\n            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)\n\n        mask[mask < 0.5] = 0\n        mask[mask >= 0.5] = 1\n        mask = torch.from_numpy(mask)\n\n    # masked_image = image * (mask >= 0.5)\n    masked_image = image\n\n    return mask, masked_image\n\n\nclass StableDiffusionRepaintPipeline(\n    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin\n):\n    r\"\"\"\n    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n    In addition the pipeline inherits the following loading methods:\n        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]\n        - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]\n    as well as the following saving methods:\n        - *LoRA*: [`loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`]\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"skip_prk_steps\", True) is False:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration\"\n                \" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make\"\n                \" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to\"\n                \" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face\"\n                \" Hub, it would be very nice if you could open a Pull request for the\"\n                \" `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\n                \"skip_prk_steps not set\",\n                \"1.0.0\",\n                deprecation_message,\n                standard_warn=False,\n            )\n            new_config = dict(scheduler.config)\n            new_config[\"skip_prk_steps\"] = True\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend keeping the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- 
stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n        # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4\n        if unet is not None and unet.config.in_channels != 4:\n            logger.warning(\n                f\"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,\"\n                f\" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`\"\n                \". 
If you did not intend to modify\"\n                \" this behavior, please check whether you have loaded the right checkpoint.\"\n            )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                
)\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        else:\n            has_nsfw_concept = None\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents).sample\n        image 
= (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def prepare_mask_latents(\n        self,\n        mask,\n        masked_image,\n        batch_size,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        do_classifier_free_guidance,\n    ):\n        # resize the mask to latents shape as we concatenate the mask to the latents\n        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload\n        # and half precision\n        mask = torch.nn.functional.interpolate(\n            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)\n        )\n        mask = mask.to(device=device, dtype=dtype)\n\n        masked_image = masked_image.to(device=device, dtype=dtype)\n\n        # encode the mask image into latents space so we can concatenate it to the latents\n        if isinstance(generator, list):\n            masked_image_latents = [\n                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(batch_size)\n            ]\n            masked_image_latents = torch.cat(masked_image_latents, dim=0)\n        else:\n            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)\n        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents\n\n        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method\n        if mask.shape[0] < batch_size:\n            if not batch_size % mask.shape[0] == 0:\n                
raise ValueError(\n                    \"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to\"\n                    f\" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number\"\n                    \" of masks that you pass is divisible by the total requested batch size.\"\n                )\n            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)\n        if masked_image_latents.shape[0] < batch_size:\n            if not batch_size % masked_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)\n\n        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask\n        masked_image_latents = (\n            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents\n        )\n\n        # aligning device to prevent device errors when concating it with the latent model input\n        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)\n        return mask, masked_image_latents\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        jump_length: Optional[int] = 10,\n        jump_n_sample: Optional[int] = 
10,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            mask_image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. 
If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            jump_length (`int`, *optional*, defaults to 10):\n                The number of steps taken forward in time before going backward in time for a single jump (\"j\" in\n                RePaint paper). Take a look at Figure 9 and 10 in https://huggingface.co/papers/2201.09865.\n            jump_n_sample (`int`, *optional*, defaults to 10):\n                The number of times we will make forward time jump for a given chosen time sample. Take a look at\n                Figure 9 and 10 in https://huggingface.co/papers/2201.09865.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`\n                is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n        Examples:\n        ```py\n        >>> import PIL\n        >>> import requests\n        >>> import torch\n        >>> from io import BytesIO\n        >>> from diffusers import DiffusionPipeline, RePaintScheduler\n        >>> def download_image(url):\n        ...     response = requests.get(url)\n        ...     return PIL.Image.open(BytesIO(response.content)).convert(\"RGB\")\n        >>> base_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/\"\n        >>> img_url = base_url + \"overture-creations-5sI6fQgYIuo.png\"\n        >>> mask_url = base_url + \"overture-creations-5sI6fQgYIuo_mask.png\"\n        >>> init_image = download_image(img_url).resize((512, 512))\n        >>> mask_image = download_image(mask_url).resize((512, 512))\n        >>> pipe = DiffusionPipeline.from_pretrained(\n        ...     \"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16, custom_pipeline=\"stable_diffusion_repaint\",\n        ... 
)\n        >>> pipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config)\n        >>> pipe = pipe.to(\"cuda\")\n        >>> prompt = \"Face of a yellow cat, high resolution, sitting on a park bench\"\n        >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]\n        ```\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n        )\n\n        if image is None:\n            raise ValueError(\"`image` input cannot be undefined.\")\n\n        if mask_image is None:\n            raise ValueError(\"`mask_image` input cannot be undefined.\")\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n\n        # 4. Preprocess mask and image\n        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)\n\n        # 5. set timesteps\n        self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, device)\n        self.scheduler.eta = eta\n\n        timesteps = self.scheduler.timesteps\n        # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.vae.config.latent_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 7. 
Prepare mask latent variables\n        mask, masked_image_latents = self.prepare_mask_latents(\n            mask,\n            masked_image,\n            batch_size * num_images_per_prompt,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            do_classifier_free_guidance=False,  # We do not need duplicate mask and image\n        )\n\n        # 8. Check that sizes of mask, masked image and latents match\n        # num_channels_mask = mask.shape[1]\n        # num_channels_masked_image = masked_image_latents.shape[1]\n        if num_channels_latents != self.unet.config.in_channels:\n            raise ValueError(\n                f\"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects\"\n                f\" {self.unet.config.in_channels} input channels but received `num_channels_latents`:\"\n                f\" {num_channels_latents}. Please verify the config of `pipeline.unet` or your `mask_image` or\"\n                \" `image` input.\"\n            )\n\n        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        t_last = timesteps[0] + 1\n\n        # 10. 
Denoising loop\n        with self.progress_bar(total=len(timesteps)) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if t >= t_last:\n                    # compute the reverse: x_t-1 -> x_t\n                    latents = self.scheduler.undo_step(latents, t_last, generator)\n                    progress_bar.update()\n                    t_last = t\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n\n                # concat latents, mask, masked_image_latents in the channel dimension\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n                # latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n                # predict the noise residual\n                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(\n                    noise_pred,\n                    t,\n                    latents,\n                    masked_image_latents,\n                    mask,\n                    **extra_step_kwargs,\n                ).prev_sample\n\n                # call the callback, if provided\n                progress_bar.update()\n                if callback is not None and i % callback_steps == 0:\n                    step_idx = i // getattr(self.scheduler, \"order\", 1)\n                    callback(step_idx, t, latents)\n\n                t_last = t\n\n        # 11. 
Post-processing\n        image = self.decode_latents(latents)\n\n        # 12. Run safety checker\n        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n\n        # 13. Convert to PIL\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_tensorrt_img2img.py",
    "content": "#\n# Copyright 2025 The HuggingFace Inc. team.\n# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport os\nfrom collections import OrderedDict\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport onnx\nimport onnx_graphsurgeon as gs\nimport PIL.Image\nimport tensorrt as trt\nimport torch\nfrom cuda import cudart\nfrom huggingface_hub import snapshot_download\nfrom huggingface_hub.utils import validate_hf_hub_args\nfrom onnx import shape_inference\nfrom packaging import version\nfrom polygraphy import cuda\nfrom polygraphy.backend.common import bytes_from_path\nfrom polygraphy.backend.onnx.loader import fold_constants\nfrom polygraphy.backend.trt import (\n    CreateConfig,\n    Profile,\n    engine_from_bytes,\n    engine_from_network,\n    network_from_onnx_path,\n    save_engine,\n)\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.configuration_utils import FrozenDict, deprecate\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.stable_diffusion import (\n    StableDiffusionPipelineOutput,\n    StableDiffusionSafetyChecker,\n)\nfrom 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents\nfrom diffusers.schedulers import DDIMScheduler\nfrom diffusers.utils import logging\n\n\n\"\"\"\nInstallation instructions\npython3 -m pip install --upgrade transformers diffusers>=0.16.0\npython3 -m pip install --upgrade tensorrt~=10.2.0\npython3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com\npython3 -m pip install onnxruntime\n\"\"\"\n\nTRT_LOGGER = trt.Logger(trt.Logger.ERROR)\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n# Map of numpy dtype -> torch dtype\nnumpy_to_torch_dtype_dict = {\n    np.uint8: torch.uint8,\n    np.int8: torch.int8,\n    np.int16: torch.int16,\n    np.int32: torch.int32,\n    np.int64: torch.int64,\n    np.float16: torch.float16,\n    np.float32: torch.float32,\n    np.float64: torch.float64,\n    np.complex64: torch.complex64,\n    np.complex128: torch.complex128,\n}\n# Compare parsed versions: a plain string comparison misorders e.g. \"1.9.0\" and \"1.24.0\"\nif version.parse(np.version.full_version) >= version.parse(\"1.24.0\"):\n    numpy_to_torch_dtype_dict[np.bool_] = torch.bool\nelse:\n    numpy_to_torch_dtype_dict[np.bool] = torch.bool\n\n# Map of torch dtype -> numpy dtype\ntorch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}\n\n\ndef preprocess_image(image):\n    \"\"\"\n    image: PIL.Image.Image\n    \"\"\"\n    w, h = image.size\n    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32\n    image = image.resize((w, h))\n    image = np.array(image).astype(np.float32) / 255.0\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image).contiguous()\n    return 2.0 * image - 1.0\n\n\nclass Engine:\n    def __init__(self, engine_path):\n        self.engine_path = engine_path\n        self.engine = None\n        self.context = None\n        self.buffers = OrderedDict()\n        self.tensors = OrderedDict()\n\n    def __del__(self):\n        [buf.free() for buf in self.buffers.values() if 
isinstance(buf, cuda.DeviceArray)]\n        del self.engine\n        del self.context\n        del self.buffers\n        del self.tensors\n\n    def build(\n        self,\n        onnx_path,\n        fp16,\n        input_profile=None,\n        enable_all_tactics=False,\n        timing_cache=None,\n    ):\n        logger.warning(f\"Building TensorRT engine for {onnx_path}: {self.engine_path}\")\n        p = Profile()\n        if input_profile:\n            for name, dims in input_profile.items():\n                assert len(dims) == 3\n                p.add(name, min=dims[0], opt=dims[1], max=dims[2])\n\n        extra_build_args = {}\n        if not enable_all_tactics:\n            extra_build_args[\"tactic_sources\"] = []\n\n        engine = engine_from_network(\n            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),\n            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),\n            save_timing_cache=timing_cache,\n        )\n        save_engine(engine, path=self.engine_path)\n\n    def load(self):\n        logger.warning(f\"Loading TensorRT engine: {self.engine_path}\")\n        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))\n\n    def activate(self):\n        self.context = self.engine.create_execution_context()\n\n    def allocate_buffers(self, shape_dict=None, device=\"cuda\"):\n        for binding in range(self.engine.num_io_tensors):\n            name = self.engine.get_tensor_name(binding)\n            if shape_dict and name in shape_dict:\n                shape = shape_dict[name]\n            else:\n                shape = self.engine.get_tensor_shape(name)\n            dtype = trt.nptype(self.engine.get_tensor_dtype(name))\n            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:\n                self.context.set_input_shape(name, shape)\n            tensor = torch.empty(tuple(shape), 
dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)\n            self.tensors[name] = tensor\n\n    def infer(self, feed_dict, stream):\n        for name, buf in feed_dict.items():\n            self.tensors[name].copy_(buf)\n        for name, tensor in self.tensors.items():\n            self.context.set_tensor_address(name, tensor.data_ptr())\n        noerror = self.context.execute_async_v3(stream)\n        if not noerror:\n            raise ValueError(\"ERROR: inference failed.\")\n\n        return self.tensors\n\n\nclass Optimizer:\n    def __init__(self, onnx_graph):\n        self.graph = gs.import_onnx(onnx_graph)\n\n    def cleanup(self, return_onnx=False):\n        self.graph.cleanup().toposort()\n        if return_onnx:\n            return gs.export_onnx(self.graph)\n\n    def select_outputs(self, keep, names=None):\n        self.graph.outputs = [self.graph.outputs[o] for o in keep]\n        if names:\n            for i, name in enumerate(names):\n                self.graph.outputs[i].name = name\n\n    def fold_constants(self, return_onnx=False):\n        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)\n        self.graph = gs.import_onnx(onnx_graph)\n        if return_onnx:\n            return onnx_graph\n\n    def infer_shapes(self, return_onnx=False):\n        onnx_graph = gs.export_onnx(self.graph)\n        if onnx_graph.ByteSize() > 2147483648:\n            raise TypeError(\"ERROR: model size exceeds supported 2GB limit\")\n        else:\n            onnx_graph = shape_inference.infer_shapes(onnx_graph)\n\n        self.graph = gs.import_onnx(onnx_graph)\n        if return_onnx:\n            return onnx_graph\n\n\nclass BaseModel:\n    def __init__(self, model, fp16=False, device=\"cuda\", max_batch_size=16, embedding_dim=768, text_maxlen=77):\n        self.model = model\n        self.name = \"SD Model\"\n        self.fp16 = fp16\n        self.device = device\n\n        self.min_batch = 1\n        
self.max_batch = max_batch_size\n        self.min_image_shape = 256  # min image resolution: 256x256\n        self.max_image_shape = 1024  # max image resolution: 1024x1024\n        self.min_latent_shape = self.min_image_shape // 8\n        self.max_latent_shape = self.max_image_shape // 8\n\n        self.embedding_dim = embedding_dim\n        self.text_maxlen = text_maxlen\n\n    def get_model(self):\n        return self.model\n\n    def get_input_names(self):\n        pass\n\n    def get_output_names(self):\n        pass\n\n    def get_dynamic_axes(self):\n        return None\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        pass\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        return None\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        return None\n\n    def optimize(self, onnx_graph):\n        opt = Optimizer(onnx_graph)\n        opt.cleanup()\n        opt.fold_constants()\n        opt.infer_shapes()\n        onnx_opt_graph = opt.cleanup(return_onnx=True)\n        return onnx_opt_graph\n\n    def check_dims(self, batch_size, image_height, image_width):\n        assert batch_size >= self.min_batch and batch_size <= self.max_batch\n        # both dimensions must be divisible by the VAE downsampling factor of 8\n        assert image_height % 8 == 0 and image_width % 8 == 0\n        latent_height = image_height // 8\n        latent_width = image_width // 8\n        assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape\n        assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape\n        return (latent_height, latent_width)\n\n    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):\n        min_batch = batch_size if static_batch else self.min_batch\n        max_batch = batch_size if static_batch else self.max_batch\n        latent_height = image_height // 8\n        latent_width = image_width // 8\n        
min_image_height = image_height if static_shape else self.min_image_shape\n        max_image_height = image_height if static_shape else self.max_image_shape\n        min_image_width = image_width if static_shape else self.min_image_shape\n        max_image_width = image_width if static_shape else self.max_image_shape\n        min_latent_height = latent_height if static_shape else self.min_latent_shape\n        max_latent_height = latent_height if static_shape else self.max_latent_shape\n        min_latent_width = latent_width if static_shape else self.min_latent_shape\n        max_latent_width = latent_width if static_shape else self.max_latent_shape\n        return (\n            min_batch,\n            max_batch,\n            min_image_height,\n            max_image_height,\n            min_image_width,\n            max_image_width,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n            max_latent_width,\n        )\n\n\ndef getOnnxPath(model_name, onnx_dir, opt=True):\n    return os.path.join(onnx_dir, model_name + (\".opt\" if opt else \"\") + \".onnx\")\n\n\ndef getEnginePath(model_name, engine_dir):\n    return os.path.join(engine_dir, model_name + \".plan\")\n\n\ndef build_engines(\n    models: dict,\n    engine_dir,\n    onnx_dir,\n    onnx_opset,\n    opt_image_height,\n    opt_image_width,\n    opt_batch_size=1,\n    force_engine_rebuild=False,\n    static_batch=False,\n    static_shape=True,\n    enable_all_tactics=False,\n    timing_cache=None,\n):\n    built_engines = {}\n    if not os.path.isdir(onnx_dir):\n        os.makedirs(onnx_dir)\n    if not os.path.isdir(engine_dir):\n        os.makedirs(engine_dir)\n\n    # Export models to ONNX\n    for model_name, model_obj in models.items():\n        engine_path = getEnginePath(model_name, engine_dir)\n        if force_engine_rebuild or not os.path.exists(engine_path):\n            logger.warning(\"Building Engines...\")\n            
logger.warning(\"Engine build can take a while to complete\")\n            onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)\n            onnx_opt_path = getOnnxPath(model_name, onnx_dir)\n            if force_engine_rebuild or not os.path.exists(onnx_opt_path):\n                if force_engine_rebuild or not os.path.exists(onnx_path):\n                    logger.warning(f\"Exporting model: {onnx_path}\")\n                    model = model_obj.get_model()\n                    with torch.inference_mode(), torch.autocast(\"cuda\"):\n                        inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)\n                        torch.onnx.export(\n                            model,\n                            inputs,\n                            onnx_path,\n                            export_params=True,\n                            opset_version=onnx_opset,\n                            do_constant_folding=True,\n                            input_names=model_obj.get_input_names(),\n                            output_names=model_obj.get_output_names(),\n                            dynamic_axes=model_obj.get_dynamic_axes(),\n                        )\n                    del model\n                    torch.cuda.empty_cache()\n                    gc.collect()\n                else:\n                    logger.warning(f\"Found cached model: {onnx_path}\")\n\n                # Optimize onnx\n                if force_engine_rebuild or not os.path.exists(onnx_opt_path):\n                    logger.warning(f\"Generating optimizing model: {onnx_opt_path}\")\n                    onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))\n                    onnx.save(onnx_opt_graph, onnx_opt_path)\n                else:\n                    logger.warning(f\"Found cached optimized model: {onnx_opt_path} \")\n\n    # Build TensorRT engines\n    for model_name, model_obj in models.items():\n        engine_path = 
getEnginePath(model_name, engine_dir)\n        engine = Engine(engine_path)\n        onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)\n        onnx_opt_path = getOnnxPath(model_name, onnx_dir)\n\n        if force_engine_rebuild or not os.path.exists(engine.engine_path):\n            engine.build(\n                onnx_opt_path,\n                fp16=True,\n                input_profile=model_obj.get_input_profile(\n                    opt_batch_size,\n                    opt_image_height,\n                    opt_image_width,\n                    static_batch=static_batch,\n                    static_shape=static_shape,\n                ),\n                timing_cache=timing_cache,\n            )\n        built_engines[model_name] = engine\n\n    # Load and activate TensorRT engines\n    for model_name, model_obj in models.items():\n        engine = built_engines[model_name]\n        engine.load()\n        engine.activate()\n\n    return built_engines\n\n\ndef runEngine(engine, feed_dict, stream):\n    return engine.infer(feed_dict, stream)\n\n\nclass CLIP(BaseModel):\n    def __init__(self, model, device, max_batch_size, embedding_dim):\n        super(CLIP, self).__init__(\n            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim\n        )\n        self.name = \"CLIP\"\n\n    def get_input_names(self):\n        return [\"input_ids\"]\n\n    def get_output_names(self):\n        return [\"text_embeddings\", \"pooler_output\"]\n\n    def get_dynamic_axes(self):\n        return {\"input_ids\": {0: \"B\"}, \"text_embeddings\": {0: \"B\"}}\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        self.check_dims(batch_size, image_height, image_width)\n        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(\n            batch_size, image_height, image_width, static_batch, static_shape\n        )\n        return {\n            \"input_ids\": 
[(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"input_ids\": (batch_size, self.text_maxlen),\n            \"text_embeddings\": (batch_size, self.text_maxlen, self.embedding_dim),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        self.check_dims(batch_size, image_height, image_width)\n        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)\n\n    def optimize(self, onnx_graph):\n        opt = Optimizer(onnx_graph)\n        opt.select_outputs([0])  # delete graph output#1\n        opt.cleanup()\n        opt.fold_constants()\n        opt.infer_shapes()\n        opt.select_outputs([0], names=[\"text_embeddings\"])  # rename network output\n        opt_onnx_graph = opt.cleanup(return_onnx=True)\n        return opt_onnx_graph\n\n\ndef make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)\n\n\nclass UNet(BaseModel):\n    def __init__(\n        self, model, fp16=False, device=\"cuda\", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4\n    ):\n        super(UNet, self).__init__(\n            model=model,\n            fp16=fp16,\n            device=device,\n            max_batch_size=max_batch_size,\n            embedding_dim=embedding_dim,\n            text_maxlen=text_maxlen,\n        )\n        self.unet_dim = unet_dim\n        self.name = \"UNet\"\n\n    def get_input_names(self):\n        return [\"sample\", \"timestep\", \"encoder_hidden_states\"]\n\n    def get_output_names(self):\n        return [\"latent\"]\n\n    def get_dynamic_axes(self):\n        return {\n            \"sample\": {0: \"2B\", 2: \"H\", 3: \"W\"},\n            
\"encoder_hidden_states\": {0: \"2B\"},\n            \"latent\": {0: \"2B\", 2: \"H\", 3: \"W\"},\n        }\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        (\n            min_batch,\n            max_batch,\n            _,\n            _,\n            _,\n            _,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n            max_latent_width,\n        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)\n        return {\n            \"sample\": [\n                (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),\n                (2 * batch_size, self.unet_dim, latent_height, latent_width),\n                (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),\n            ],\n            \"encoder_hidden_states\": [\n                (2 * min_batch, self.text_maxlen, self.embedding_dim),\n                (2 * batch_size, self.text_maxlen, self.embedding_dim),\n                (2 * max_batch, self.text_maxlen, self.embedding_dim),\n            ],\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"sample\": (2 * batch_size, self.unet_dim, latent_height, latent_width),\n            \"encoder_hidden_states\": (2 * batch_size, self.text_maxlen, self.embedding_dim),\n            \"latent\": (2 * batch_size, 4, latent_height, latent_width),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        dtype = torch.float16 if self.fp16 else torch.float32\n        return (\n            torch.randn(\n                
2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device\n            ),\n            torch.tensor([1.0], dtype=torch.float32, device=self.device),\n            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),\n        )\n\n\ndef make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return UNet(\n        model,\n        fp16=True,\n        device=device,\n        max_batch_size=max_batch_size,\n        embedding_dim=embedding_dim,\n        unet_dim=(9 if inpaint else 4),\n    )\n\n\nclass VAE(BaseModel):\n    def __init__(self, model, device, max_batch_size, embedding_dim):\n        super(VAE, self).__init__(\n            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim\n        )\n        self.name = \"VAE decoder\"\n\n    def get_input_names(self):\n        return [\"latent\"]\n\n    def get_output_names(self):\n        return [\"images\"]\n\n    def get_dynamic_axes(self):\n        return {\"latent\": {0: \"B\", 2: \"H\", 3: \"W\"}, \"images\": {0: \"B\", 2: \"8H\", 3: \"8W\"}}\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        (\n            min_batch,\n            max_batch,\n            _,\n            _,\n            _,\n            _,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n            max_latent_width,\n        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)\n        return {\n            \"latent\": [\n                (min_batch, 4, min_latent_height, min_latent_width),\n                (batch_size, 4, latent_height, latent_width),\n                (max_batch, 4, max_latent_height, max_latent_width),\n            ]\n        }\n\n    def 
get_shape_dict(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"latent\": (batch_size, 4, latent_height, latent_width),\n            \"images\": (batch_size, 3, image_height, image_width),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)\n\n\ndef make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)\n\n\nclass TorchVAEEncoder(torch.nn.Module):\n    def __init__(self, model):\n        super().__init__()\n        self.vae_encoder = model\n\n    def forward(self, x):\n        return retrieve_latents(self.vae_encoder.encode(x))\n\n\nclass VAEEncoder(BaseModel):\n    def __init__(self, model, device, max_batch_size, embedding_dim):\n        super(VAEEncoder, self).__init__(\n            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim\n        )\n        self.name = \"VAE encoder\"\n\n    def get_model(self):\n        vae_encoder = TorchVAEEncoder(self.model)\n        return vae_encoder\n\n    def get_input_names(self):\n        return [\"images\"]\n\n    def get_output_names(self):\n        return [\"latent\"]\n\n    def get_dynamic_axes(self):\n        return {\"images\": {0: \"B\", 2: \"8H\", 3: \"8W\"}, \"latent\": {0: \"B\", 2: \"H\", 3: \"W\"}}\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        assert batch_size >= self.min_batch and batch_size <= self.max_batch\n        min_batch = batch_size if static_batch else self.min_batch\n        max_batch = batch_size if static_batch else 
self.max_batch\n        self.check_dims(batch_size, image_height, image_width)\n        (\n            min_batch,\n            max_batch,\n            min_image_height,\n            max_image_height,\n            min_image_width,\n            max_image_width,\n            _,\n            _,\n            _,\n            _,\n        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)\n\n        return {\n            \"images\": [\n                (min_batch, 3, min_image_height, min_image_width),\n                (batch_size, 3, image_height, image_width),\n                (max_batch, 3, max_image_height, max_image_width),\n            ]\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"images\": (batch_size, 3, image_height, image_width),\n            \"latent\": (batch_size, 4, latent_height, latent_width),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        self.check_dims(batch_size, image_height, image_width)\n        return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)\n\n\ndef make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)\n\n\nclass TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: DDIMScheduler,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n        stages=[\"clip\", \"unet\", \"vae\", \"vae_encoder\"],\n        image_height: int = 512,\n        image_width: int = 512,\n        max_batch_size: int = 16,\n        # ONNX export parameters\n        onnx_opset: int = 17,\n        onnx_dir: str = \"onnx\",\n        # TensorRT engine build parameters\n        engine_dir: str = \"engine\",\n        force_engine_rebuild: bool = False,\n        timing_cache: str = \"timing_cache\",\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n\n        self.stages = stages\n        self.image_height, self.image_width = image_height, image_width\n        self.inpaint = False\n        self.onnx_opset = onnx_opset\n        self.onnx_dir = onnx_dir\n        self.engine_dir = engine_dir\n        self.force_engine_rebuild = force_engine_rebuild\n        self.timing_cache = timing_cache\n        self.build_static_batch = False\n        self.build_dynamic_shape = False\n\n        self.max_batch_size = max_batch_size\n        # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT 
limitation.\n        if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:\n            self.max_batch_size = 4\n\n        self.stream = None  # loaded in loadResources()\n        self.models = {}  # loaded in __loadModels()\n        self.engine = {}  # loaded in build_engines()\n\n        self.vae.forward = self.vae.decode\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def __loadModels(self):\n        # Load pipeline models\n        self.embedding_dim = self.text_encoder.config.hidden_size\n        models_args = {\n            \"device\": self.torch_device,\n            \"max_batch_size\": self.max_batch_size,\n            \"embedding_dim\": self.embedding_dim,\n            \"inpaint\": self.inpaint,\n        }\n        if \"clip\" in self.stages:\n            self.models[\"clip\"] = make_CLIP(self.text_encoder, **models_args)\n        if \"unet\" in self.stages:\n            self.models[\"unet\"] = make_UNet(self.unet, **models_args)\n        if \"vae\" in self.stages:\n            self.models[\"vae\"] = make_VAE(self.vae, **models_args)\n        if \"vae_encoder\" in self.stages:\n            self.models[\"vae_encoder\"] = make_VAEEncoder(self.vae, **models_args)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(\n        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype\n    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:\n        r\"\"\"\n        Runs the safety checker on the given image.\n        Args:\n            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.\n            device 
(torch.device): The device to run the safety checker on.\n            dtype (torch.dtype): The data type of the input image.\n        Returns:\n            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and\n            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.\n        \"\"\"\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    @classmethod\n    @validate_hf_hub_args\n    def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", False)\n        token = kwargs.pop(\"token\", None)\n        revision = kwargs.pop(\"revision\", None)\n\n        cls.cached_folder = (\n            pretrained_model_name_or_path\n            if os.path.isdir(pretrained_model_name_or_path)\n            else snapshot_download(\n                pretrained_model_name_or_path,\n                cache_dir=cache_dir,\n                proxies=proxies,\n                local_files_only=local_files_only,\n                token=token,\n                revision=revision,\n            )\n        )\n\n    def to(self, torch_device: Optional[Union[str, torch.device]] = None, 
silence_dtype_warnings: bool = False):\n        super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)\n\n        self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)\n        self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)\n        self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)\n\n        # set device\n        self.torch_device = self._execution_device\n        logger.warning(f\"Running inference on device: {self.torch_device}\")\n\n        # load models\n        self.__loadModels()\n\n        # build engines\n        self.engine = build_engines(\n            self.models,\n            self.engine_dir,\n            self.onnx_dir,\n            self.onnx_opset,\n            opt_image_height=self.image_height,\n            opt_image_width=self.image_width,\n            force_engine_rebuild=self.force_engine_rebuild,\n            static_batch=self.build_static_batch,\n            static_shape=not self.build_dynamic_shape,\n            timing_cache=self.timing_cache,\n        )\n\n        return self\n\n    def __initialize_timesteps(self, timesteps, strength):\n        self.scheduler.set_timesteps(timesteps)\n        offset = self.scheduler.steps_offset if hasattr(self.scheduler, \"steps_offset\") else 0\n        init_timestep = int(timesteps * strength) + offset\n        init_timestep = min(init_timestep, timesteps)\n        t_start = max(timesteps - init_timestep + offset, 0)\n        timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device)\n        return timesteps, t_start\n\n    def __preprocess_images(self, batch_size, images=()):\n        init_images = []\n        for image in images:\n            image = image.to(self.torch_device).float()\n            image = image.repeat(batch_size, 1, 1, 1)\n            init_images.append(image)\n        return tuple(init_images)\n\n    def __encode_image(self, init_image):\n        init_latents = runEngine(self.engine[\"vae_encoder\"], 
{\"images\": init_image}, self.stream)[\"latent\"]\n        init_latents = 0.18215 * init_latents\n        return init_latents\n\n    def __encode_prompt(self, prompt, negative_prompt):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to be encoded.\n            negative_prompt (`str` or `List[str]`):\n                The prompt or prompts not to guide the image generation.\n        \"\"\"\n        # Tokenize prompt\n        text_input_ids = (\n            self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            .input_ids.type(torch.int32)\n            .to(self.torch_device)\n        )\n\n        # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt\n        text_embeddings = runEngine(self.engine[\"clip\"], {\"input_ids\": text_input_ids}, self.stream)[\n            \"text_embeddings\"\n        ].clone()\n\n        # Tokenize negative prompt\n        uncond_input_ids = (\n            self.tokenizer(\n                negative_prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            .input_ids.type(torch.int32)\n            .to(self.torch_device)\n        )\n        uncond_embeddings = runEngine(self.engine[\"clip\"], {\"input_ids\": uncond_input_ids}, self.stream)[\n            
\"text_embeddings\"\n        ]\n\n        # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance\n        text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)\n\n        return text_embeddings\n\n    def __denoise_latent(\n        self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None\n    ):\n        if not isinstance(timesteps, torch.Tensor):\n            timesteps = self.scheduler.timesteps\n        for step_index, timestep in enumerate(timesteps):\n            # Expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2)\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)\n            if isinstance(mask, torch.Tensor):\n                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n            # Predict the noise residual\n            timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep\n\n            noise_pred = runEngine(\n                self.engine[\"unet\"],\n                {\"sample\": latent_model_input, \"timestep\": timestep_float, \"encoder_hidden_states\": text_embeddings},\n                self.stream,\n            )[\"latent\"]\n\n            # Perform guidance\n            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample\n\n        latents = 1.0 / 0.18215 * latents\n        return latents\n\n    def __decode_latent(self, latents):\n        images = runEngine(self.engine[\"vae\"], {\"latent\": latents}, self.stream)[\"images\"]\n        images = (images / 2 + 0.5).clamp(0, 1)\n        return 
images.cpu().permute(0, 2, 3, 1).float().numpy()\n\n    def __loadResources(self, image_height, image_width, batch_size):\n        self.stream = cudart.cudaStreamCreate()[1]\n\n        # Allocate buffers for TensorRT engine bindings\n        for model_name, obj in self.models.items():\n            self.engine[model_name].allocate_buffers(\n                shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device\n            )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        strength: float = 0.8,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation. A `ValueError` is raised if it is neither a\n                `str` nor a `list`.\n            image (`torch.Tensor` or `PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, that is encoded by the VAE and used as the starting\n                point for the denoising process.\n            strength (`float`, *optional*, defaults to 0.8):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`\n                will be used as a starting point, adding more noise to it the larger the `strength`. The number of\n                denoising steps depends on the amount of noise initially added. 
When `strength` is 1, added noise will\n                be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, an empty string is used 
instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n\n        \"\"\"\n        self.generator = generator\n        self.denoising_steps = num_inference_steps\n        self._guidance_scale = guidance_scale\n\n        # Pre-compute latent input scales and linear multistep coefficients\n        self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)\n\n        # Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n            prompt = [prompt]\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"Expected prompt to be of type list or str but got {type(prompt)}\")\n\n        if negative_prompt is None:\n            negative_prompt = [\"\"] * batch_size\n\n        if negative_prompt is not None and isinstance(negative_prompt, str):\n            negative_prompt = [negative_prompt]\n\n        assert len(prompt) == len(negative_prompt)\n\n        if batch_size > self.max_batch_size:\n            raise ValueError(\n                f\"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. 
If dynamic shape is used, then maximum batch size is 4\"\n            )\n\n        # load resources\n        self.__loadResources(self.image_height, self.image_width, batch_size)\n\n        with torch.inference_mode(), torch.autocast(\"cuda\"), trt.Runtime(TRT_LOGGER):\n            # Initialize timesteps\n            timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)\n            latent_timestep = timesteps[:1].repeat(batch_size)\n\n            # Pre-process input image\n            if isinstance(image, PIL.Image.Image):\n                image = preprocess_image(image)\n            init_image = self.__preprocess_images(batch_size, (image,))[0]\n\n            # VAE encode init image\n            init_latents = self.__encode_image(init_image)\n\n            # Add noise to latents using timesteps\n            noise = torch.randn(\n                init_latents.shape, generator=self.generator, device=self.torch_device, dtype=torch.float32\n            )\n            latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)\n\n            # CLIP text encoder\n            text_embeddings = self.__encode_prompt(prompt, negative_prompt)\n\n            # UNet denoiser\n            latents = self.__denoise_latent(latents, text_embeddings, timesteps=timesteps, step_offset=t_start)\n\n            # VAE decode latent\n            images = self.__decode_latent(latents)\n\n        images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)\n        images = self.numpy_to_pil(images)\n        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_tensorrt_inpaint.py",
    "content": "#\n# Copyright 2025 The HuggingFace Inc. team.\n# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport os\nfrom collections import OrderedDict\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport onnx\nimport onnx_graphsurgeon as gs\nimport PIL.Image\nimport tensorrt as trt\nimport torch\nfrom cuda import cudart\nfrom huggingface_hub import snapshot_download\nfrom huggingface_hub.utils import validate_hf_hub_args\nfrom onnx import shape_inference\nfrom packaging import version\nfrom polygraphy import cuda\nfrom polygraphy.backend.common import bytes_from_path\nfrom polygraphy.backend.onnx.loader import fold_constants\nfrom polygraphy.backend.trt import (\n    CreateConfig,\n    Profile,\n    engine_from_bytes,\n    engine_from_network,\n    network_from_onnx_path,\n    save_engine,\n)\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.configuration_utils import FrozenDict, deprecate\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.stable_diffusion import (\n    StableDiffusionPipelineOutput,\n    StableDiffusionSafetyChecker,\n)\nfrom 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (\n    prepare_mask_and_masked_image,\n    retrieve_latents,\n)\nfrom diffusers.schedulers import DDIMScheduler\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\n\"\"\"\nInstallation instructions\npython3 -m pip install --upgrade transformers 'diffusers>=0.16.0'\npython3 -m pip install --upgrade 'tensorrt~=10.2.0'\npython3 -m pip install --upgrade 'polygraphy>=0.47.0' onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com\npython3 -m pip install onnxruntime\n\"\"\"\n\nTRT_LOGGER = trt.Logger(trt.Logger.ERROR)\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n# Map of numpy dtype -> torch dtype\nnumpy_to_torch_dtype_dict = {\n    np.uint8: torch.uint8,\n    np.int8: torch.int8,\n    np.int16: torch.int16,\n    np.int32: torch.int32,\n    np.int64: torch.int64,\n    np.float16: torch.float16,\n    np.float32: torch.float32,\n    np.float64: torch.float64,\n    np.complex64: torch.complex64,\n    np.complex128: torch.complex128,\n}\n# Compare parsed versions: a plain string comparison orders \"1.9.0\" after \"1.24.0\"\nif version.parse(np.version.full_version) >= version.parse(\"1.24.0\"):\n    numpy_to_torch_dtype_dict[np.bool_] = torch.bool\nelse:\n    numpy_to_torch_dtype_dict[np.bool] = torch.bool\n\n# Map of torch dtype -> numpy dtype\ntorch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}\n\n\ndef preprocess_image(image):\n    \"\"\"\n    image: PIL.Image.Image\n    \"\"\"\n    w, h = image.size\n    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32\n    image = image.resize((w, h))\n    image = np.array(image).astype(np.float32) / 255.0\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image).contiguous()\n    return 2.0 * image - 1.0\n\n\nclass Engine:\n    def __init__(self, engine_path):\n        self.engine_path = engine_path\n        self.engine = None\n        self.context = None\n        self.buffers = OrderedDict()\n        self.tensors 
= OrderedDict()\n\n    def __del__(self):\n        [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]\n        del self.engine\n        del self.context\n        del self.buffers\n        del self.tensors\n\n    def build(\n        self,\n        onnx_path,\n        fp16,\n        input_profile=None,\n        enable_all_tactics=False,\n        timing_cache=None,\n    ):\n        logger.warning(f\"Building TensorRT engine for {onnx_path}: {self.engine_path}\")\n        p = Profile()\n        if input_profile:\n            for name, dims in input_profile.items():\n                assert len(dims) == 3\n                p.add(name, min=dims[0], opt=dims[1], max=dims[2])\n\n        extra_build_args = {}\n        if not enable_all_tactics:\n            extra_build_args[\"tactic_sources\"] = []\n\n        engine = engine_from_network(\n            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),\n            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),\n            save_timing_cache=timing_cache,\n        )\n        save_engine(engine, path=self.engine_path)\n\n    def load(self):\n        logger.warning(f\"Loading TensorRT engine: {self.engine_path}\")\n        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))\n\n    def activate(self):\n        self.context = self.engine.create_execution_context()\n\n    def allocate_buffers(self, shape_dict=None, device=\"cuda\"):\n        for binding in range(self.engine.num_io_tensors):\n            name = self.engine.get_tensor_name(binding)\n            if shape_dict and name in shape_dict:\n                shape = shape_dict[name]\n            else:\n                shape = self.engine.get_tensor_shape(name)\n            dtype = trt.nptype(self.engine.get_tensor_dtype(name))\n            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:\n                
self.context.set_input_shape(name, shape)\n            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)\n            self.tensors[name] = tensor\n\n    def infer(self, feed_dict, stream):\n        for name, buf in feed_dict.items():\n            self.tensors[name].copy_(buf)\n        for name, tensor in self.tensors.items():\n            self.context.set_tensor_address(name, tensor.data_ptr())\n        noerror = self.context.execute_async_v3(stream)\n        if not noerror:\n            raise ValueError(\"ERROR: inference failed.\")\n\n        return self.tensors\n\n\nclass Optimizer:\n    def __init__(self, onnx_graph):\n        self.graph = gs.import_onnx(onnx_graph)\n\n    def cleanup(self, return_onnx=False):\n        self.graph.cleanup().toposort()\n        if return_onnx:\n            return gs.export_onnx(self.graph)\n\n    def select_outputs(self, keep, names=None):\n        self.graph.outputs = [self.graph.outputs[o] for o in keep]\n        if names:\n            for i, name in enumerate(names):\n                self.graph.outputs[i].name = name\n\n    def fold_constants(self, return_onnx=False):\n        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)\n        self.graph = gs.import_onnx(onnx_graph)\n        if return_onnx:\n            return onnx_graph\n\n    def infer_shapes(self, return_onnx=False):\n        onnx_graph = gs.export_onnx(self.graph)\n        if onnx_graph.ByteSize() > 2147483648:\n            raise TypeError(\"ERROR: model size exceeds supported 2GB limit\")\n        else:\n            onnx_graph = shape_inference.infer_shapes(onnx_graph)\n\n        self.graph = gs.import_onnx(onnx_graph)\n        if return_onnx:\n            return onnx_graph\n\n\nclass BaseModel:\n    def __init__(self, model, fp16=False, device=\"cuda\", max_batch_size=16, embedding_dim=768, text_maxlen=77):\n        self.model = model\n        self.name = \"SD Model\"\n       
 self.fp16 = fp16\n        self.device = device\n\n        self.min_batch = 1\n        self.max_batch = max_batch_size\n        self.min_image_shape = 256  # min image resolution: 256x256\n        self.max_image_shape = 1024  # max image resolution: 1024x1024\n        self.min_latent_shape = self.min_image_shape // 8\n        self.max_latent_shape = self.max_image_shape // 8\n\n        self.embedding_dim = embedding_dim\n        self.text_maxlen = text_maxlen\n\n    def get_model(self):\n        return self.model\n\n    def get_input_names(self):\n        pass\n\n    def get_output_names(self):\n        pass\n\n    def get_dynamic_axes(self):\n        return None\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        pass\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        return None\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        return None\n\n    def optimize(self, onnx_graph):\n        opt = Optimizer(onnx_graph)\n        opt.cleanup()\n        opt.fold_constants()\n        opt.infer_shapes()\n        onnx_opt_graph = opt.cleanup(return_onnx=True)\n        return onnx_opt_graph\n\n    def check_dims(self, batch_size, image_height, image_width):\n        assert batch_size >= self.min_batch and batch_size <= self.max_batch\n        # Both dimensions must be divisible by 8 so that the latent shape is exact\n        assert image_height % 8 == 0 and image_width % 8 == 0\n        latent_height = image_height // 8\n        latent_width = image_width // 8\n        assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape\n        assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape\n        return (latent_height, latent_width)\n\n    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):\n        min_batch = batch_size if static_batch else self.min_batch\n        max_batch = batch_size if static_batch else self.max_batch\n        
latent_height = image_height // 8\n        latent_width = image_width // 8\n        min_image_height = image_height if static_shape else self.min_image_shape\n        max_image_height = image_height if static_shape else self.max_image_shape\n        min_image_width = image_width if static_shape else self.min_image_shape\n        max_image_width = image_width if static_shape else self.max_image_shape\n        min_latent_height = latent_height if static_shape else self.min_latent_shape\n        max_latent_height = latent_height if static_shape else self.max_latent_shape\n        min_latent_width = latent_width if static_shape else self.min_latent_shape\n        max_latent_width = latent_width if static_shape else self.max_latent_shape\n        return (\n            min_batch,\n            max_batch,\n            min_image_height,\n            max_image_height,\n            min_image_width,\n            max_image_width,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n            max_latent_width,\n        )\n\n\ndef getOnnxPath(model_name, onnx_dir, opt=True):\n    return os.path.join(onnx_dir, model_name + (\".opt\" if opt else \"\") + \".onnx\")\n\n\ndef getEnginePath(model_name, engine_dir):\n    return os.path.join(engine_dir, model_name + \".plan\")\n\n\ndef build_engines(\n    models: dict,\n    engine_dir,\n    onnx_dir,\n    onnx_opset,\n    opt_image_height,\n    opt_image_width,\n    opt_batch_size=1,\n    force_engine_rebuild=False,\n    static_batch=False,\n    static_shape=True,\n    enable_all_tactics=False,\n    timing_cache=None,\n):\n    built_engines = {}\n    if not os.path.isdir(onnx_dir):\n        os.makedirs(onnx_dir)\n    if not os.path.isdir(engine_dir):\n        os.makedirs(engine_dir)\n\n    # Export models to ONNX\n    for model_name, model_obj in models.items():\n        engine_path = getEnginePath(model_name, engine_dir)\n        if force_engine_rebuild or not os.path.exists(engine_path):\n   
         logger.warning(\"Building Engines...\")\n            logger.warning(\"Engine build can take a while to complete\")\n            onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)\n            onnx_opt_path = getOnnxPath(model_name, onnx_dir)\n            if force_engine_rebuild or not os.path.exists(onnx_opt_path):\n                if force_engine_rebuild or not os.path.exists(onnx_path):\n                    logger.warning(f\"Exporting model: {onnx_path}\")\n                    model = model_obj.get_model()\n                    with torch.inference_mode(), torch.autocast(\"cuda\"):\n                        inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)\n                        torch.onnx.export(\n                            model,\n                            inputs,\n                            onnx_path,\n                            export_params=True,\n                            opset_version=onnx_opset,\n                            do_constant_folding=True,\n                            input_names=model_obj.get_input_names(),\n                            output_names=model_obj.get_output_names(),\n                            dynamic_axes=model_obj.get_dynamic_axes(),\n                        )\n                    del model\n                    torch.cuda.empty_cache()\n                    gc.collect()\n                else:\n                    logger.warning(f\"Found cached model: {onnx_path}\")\n\n                # Optimize onnx\n                if force_engine_rebuild or not os.path.exists(onnx_opt_path):\n                    logger.warning(f\"Generating optimizing model: {onnx_opt_path}\")\n                    onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))\n                    onnx.save(onnx_opt_graph, onnx_opt_path)\n                else:\n                    logger.warning(f\"Found cached optimized model: {onnx_opt_path} \")\n\n    # Build TensorRT engines\n    for model_name, model_obj 
in models.items():\n        engine_path = getEnginePath(model_name, engine_dir)\n        engine = Engine(engine_path)\n        onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)\n        onnx_opt_path = getOnnxPath(model_name, onnx_dir)\n\n        if force_engine_rebuild or not os.path.exists(engine.engine_path):\n            engine.build(\n                onnx_opt_path,\n                fp16=True,\n                input_profile=model_obj.get_input_profile(\n                    opt_batch_size,\n                    opt_image_height,\n                    opt_image_width,\n                    static_batch=static_batch,\n                    static_shape=static_shape,\n                ),\n                timing_cache=timing_cache,\n            )\n        built_engines[model_name] = engine\n\n    # Load and activate TensorRT engines\n    for model_name, model_obj in models.items():\n        engine = built_engines[model_name]\n        engine.load()\n        engine.activate()\n\n    return built_engines\n\n\ndef runEngine(engine, feed_dict, stream):\n    return engine.infer(feed_dict, stream)\n\n\nclass CLIP(BaseModel):\n    def __init__(self, model, device, max_batch_size, embedding_dim):\n        super(CLIP, self).__init__(\n            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim\n        )\n        self.name = \"CLIP\"\n\n    def get_input_names(self):\n        return [\"input_ids\"]\n\n    def get_output_names(self):\n        return [\"text_embeddings\", \"pooler_output\"]\n\n    def get_dynamic_axes(self):\n        return {\"input_ids\": {0: \"B\"}, \"text_embeddings\": {0: \"B\"}}\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        self.check_dims(batch_size, image_height, image_width)\n        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(\n            batch_size, image_height, image_width, static_batch, static_shape\n        )\n        
return {\n            \"input_ids\": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"input_ids\": (batch_size, self.text_maxlen),\n            \"text_embeddings\": (batch_size, self.text_maxlen, self.embedding_dim),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        self.check_dims(batch_size, image_height, image_width)\n        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)\n\n    def optimize(self, onnx_graph):\n        opt = Optimizer(onnx_graph)\n        opt.select_outputs([0])  # delete graph output#1\n        opt.cleanup()\n        opt.fold_constants()\n        opt.infer_shapes()\n        opt.select_outputs([0], names=[\"text_embeddings\"])  # rename network output\n        opt_onnx_graph = opt.cleanup(return_onnx=True)\n        return opt_onnx_graph\n\n\ndef make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)\n\n\nclass UNet(BaseModel):\n    def __init__(\n        self, model, fp16=False, device=\"cuda\", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4\n    ):\n        super(UNet, self).__init__(\n            model=model,\n            fp16=fp16,\n            device=device,\n            max_batch_size=max_batch_size,\n            embedding_dim=embedding_dim,\n            text_maxlen=text_maxlen,\n        )\n        self.unet_dim = unet_dim\n        self.name = \"UNet\"\n\n    def get_input_names(self):\n        return [\"sample\", \"timestep\", \"encoder_hidden_states\"]\n\n    def get_output_names(self):\n        return [\"latent\"]\n\n    def get_dynamic_axes(self):\n        return {\n            \"sample\": {0: 
\"2B\", 2: \"H\", 3: \"W\"},\n            \"encoder_hidden_states\": {0: \"2B\"},\n            \"latent\": {0: \"2B\", 2: \"H\", 3: \"W\"},\n        }\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        (\n            min_batch,\n            max_batch,\n            _,\n            _,\n            _,\n            _,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n            max_latent_width,\n        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)\n        return {\n            \"sample\": [\n                (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),\n                (2 * batch_size, self.unet_dim, latent_height, latent_width),\n                (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),\n            ],\n            \"encoder_hidden_states\": [\n                (2 * min_batch, self.text_maxlen, self.embedding_dim),\n                (2 * batch_size, self.text_maxlen, self.embedding_dim),\n                (2 * max_batch, self.text_maxlen, self.embedding_dim),\n            ],\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"sample\": (2 * batch_size, self.unet_dim, latent_height, latent_width),\n            \"encoder_hidden_states\": (2 * batch_size, self.text_maxlen, self.embedding_dim),\n            \"latent\": (2 * batch_size, 4, latent_height, latent_width),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        dtype = torch.float16 if self.fp16 else torch.float32\n        return (\n 
           torch.randn(\n                2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device\n            ),\n            torch.tensor([1.0], dtype=torch.float32, device=self.device),\n            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),\n        )\n\n\ndef make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False, unet_dim=4):\n    return UNet(\n        model,\n        fp16=True,\n        device=device,\n        max_batch_size=max_batch_size,\n        embedding_dim=embedding_dim,\n        unet_dim=unet_dim,\n    )\n\n\nclass VAE(BaseModel):\n    def __init__(self, model, device, max_batch_size, embedding_dim):\n        super(VAE, self).__init__(\n            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim\n        )\n        self.name = \"VAE decoder\"\n\n    def get_input_names(self):\n        return [\"latent\"]\n\n    def get_output_names(self):\n        return [\"images\"]\n\n    def get_dynamic_axes(self):\n        return {\"latent\": {0: \"B\", 2: \"H\", 3: \"W\"}, \"images\": {0: \"B\", 2: \"8H\", 3: \"8W\"}}\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        (\n            min_batch,\n            max_batch,\n            _,\n            _,\n            _,\n            _,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n            max_latent_width,\n        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)\n        return {\n            \"latent\": [\n                (min_batch, 4, min_latent_height, min_latent_width),\n                (batch_size, 4, latent_height, latent_width),\n                (max_batch, 4, max_latent_height, max_latent_width),\n            
]\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"latent\": (batch_size, 4, latent_height, latent_width),\n            \"images\": (batch_size, 3, image_height, image_width),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)\n\n\ndef make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)\n\n\nclass TorchVAEEncoder(torch.nn.Module):\n    def __init__(self, model):\n        super().__init__()\n        self.vae_encoder = model\n\n    def forward(self, x):\n        return self.vae_encoder.encode(x).latent_dist.sample()\n\n\nclass VAEEncoder(BaseModel):\n    def __init__(self, model, device, max_batch_size, embedding_dim):\n        super(VAEEncoder, self).__init__(\n            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim\n        )\n        self.name = \"VAE encoder\"\n\n    def get_model(self):\n        vae_encoder = TorchVAEEncoder(self.model)\n        return vae_encoder\n\n    def get_input_names(self):\n        return [\"images\"]\n\n    def get_output_names(self):\n        return [\"latent\"]\n\n    def get_dynamic_axes(self):\n        return {\"images\": {0: \"B\", 2: \"8H\", 3: \"8W\"}, \"latent\": {0: \"B\", 2: \"H\", 3: \"W\"}}\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        assert batch_size >= self.min_batch and batch_size <= self.max_batch\n        min_batch = batch_size if static_batch else self.min_batch\n        max_batch = batch_size if 
static_batch else self.max_batch\n        self.check_dims(batch_size, image_height, image_width)\n        (\n            min_batch,\n            max_batch,\n            min_image_height,\n            max_image_height,\n            min_image_width,\n            max_image_width,\n            _,\n            _,\n            _,\n            _,\n        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)\n\n        return {\n            \"images\": [\n                (min_batch, 3, min_image_height, min_image_width),\n                (batch_size, 3, image_height, image_width),\n                (max_batch, 3, max_image_height, max_image_width),\n            ]\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"images\": (batch_size, 3, image_height, image_width),\n            \"latent\": (batch_size, 4, latent_height, latent_width),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        self.check_dims(batch_size, image_height, image_width)\n        return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)\n\n\ndef make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)\n\n\nclass TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for inpainting using TensorRT accelerated Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: DDIMScheduler,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n        stages=[\"clip\", \"unet\", \"vae\", \"vae_encoder\"],\n        image_height: int = 512,\n        image_width: int = 512,\n        max_batch_size: int = 16,\n        # ONNX export parameters\n        onnx_opset: int = 17,\n        onnx_dir: str = \"onnx\",\n        # TensorRT engine build parameters\n        engine_dir: str = \"engine\",\n        force_engine_rebuild: bool = False,\n        timing_cache: str = \"timing_cache\",\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n\n        self.stages = stages\n        self.image_height, self.image_width = image_height, image_width\n        self.inpaint = True\n        self.onnx_opset = onnx_opset\n        self.onnx_dir = onnx_dir\n        self.engine_dir = engine_dir\n        self.force_engine_rebuild = force_engine_rebuild\n        self.timing_cache = timing_cache\n        self.build_static_batch = False\n        self.build_dynamic_shape = False\n\n        self.max_batch_size = max_batch_size\n        # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT 
limitation.\n        if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:\n            self.max_batch_size = 4\n\n        self.stream = None  # loaded in loadResources()\n        self.models = {}  # loaded in __loadModels()\n        self.engine = {}  # loaded in build_engines()\n\n        self.vae.forward = self.vae.decode\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def __loadModels(self):\n        # Load pipeline models\n        self.embedding_dim = self.text_encoder.config.hidden_size\n        models_args = {\n            \"device\": self.torch_device,\n            \"max_batch_size\": self.max_batch_size,\n            \"embedding_dim\": self.embedding_dim,\n            \"inpaint\": self.inpaint,\n        }\n        if \"clip\" in self.stages:\n            self.models[\"clip\"] = make_CLIP(self.text_encoder, **models_args)\n        if \"unet\" in self.stages:\n            self.models[\"unet\"] = make_UNet(self.unet, **models_args, unet_dim=self.unet.config.in_channels)\n        if \"vae\" in self.stages:\n            self.models[\"vae\"] = make_VAE(self.vae, **models_args)\n        if \"vae_encoder\" in self.stages:\n            self.models[\"vae_encoder\"] = make_VAEEncoder(self.vae, **models_args)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline\n\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        if isinstance(generator, list):\n            image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, 
dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        image_latents = self.vae.config.scaling_factor * image_latents\n\n        return image_latents\n\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n        image=None,\n        timestep=None,\n        is_strength_max=True,\n        return_noise=False,\n        return_image_latents=False,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if (image is None or timestep is None) and not is_strength_max:\n            raise ValueError(\n                \"Since strength < 1. 
initial latents are to be initialised as a combination of Image + Noise. \"\n                \"However, either the image or the noise timestep has not been provided.\"\n            )\n\n        if return_image_latents or (latents is None and not is_strength_max):\n            image = image.to(device=device, dtype=dtype)\n\n            if image.shape[1] == 4:\n                image_latents = image\n            else:\n                image_latents = self._encode_vae_image(image=image, generator=generator)\n            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)\n\n        if latents is None:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # if strength is 1, initialise the latents to pure noise; otherwise initialise them to image + noise\n            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)\n            # if pure noise, scale the initial latents by the scheduler's init sigma\n            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents\n        else:\n            noise = latents.to(device)\n            latents = noise * self.scheduler.init_noise_sigma\n\n        outputs = (latents,)\n\n        if return_noise:\n            outputs += (noise,)\n\n        if return_image_latents:\n            outputs += (image_latents,)\n\n        return outputs\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(\n        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype\n    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:\n        r\"\"\"\n        Runs the safety checker on the given image.\n        Args:\n            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.\n            device (torch.device): The device to run the 
safety checker on.\n            dtype (torch.dtype): The data type of the input image.\n        Returns:\n            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and\n            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.\n        \"\"\"\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    @classmethod\n    @validate_hf_hub_args\n    def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", False)\n        token = kwargs.pop(\"token\", None)\n        revision = kwargs.pop(\"revision\", None)\n\n        cls.cached_folder = (\n            pretrained_model_name_or_path\n            if os.path.isdir(pretrained_model_name_or_path)\n            else snapshot_download(\n                pretrained_model_name_or_path,\n                cache_dir=cache_dir,\n                proxies=proxies,\n                local_files_only=local_files_only,\n                token=token,\n                revision=revision,\n            )\n        )\n\n    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):\n        
super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)\n\n        self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)\n        self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)\n        self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)\n\n        # set device\n        self.torch_device = self._execution_device\n        logger.warning(f\"Running inference on device: {self.torch_device}\")\n\n        # load models\n        self.__loadModels()\n\n        # build engines\n        self.engine = build_engines(\n            self.models,\n            self.engine_dir,\n            self.onnx_dir,\n            self.onnx_opset,\n            opt_image_height=self.image_height,\n            opt_image_width=self.image_width,\n            force_engine_rebuild=self.force_engine_rebuild,\n            static_batch=self.build_static_batch,\n            static_shape=not self.build_dynamic_shape,\n            timing_cache=self.timing_cache,\n        )\n\n        return self\n\n    def __initialize_timesteps(self, num_inference_steps, strength):\n        self.scheduler.set_timesteps(num_inference_steps)\n        offset = self.scheduler.config.steps_offset if hasattr(self.scheduler, \"steps_offset\") else 0\n        init_timestep = int(num_inference_steps * strength) + offset\n        init_timestep = min(init_timestep, num_inference_steps)\n        t_start = max(num_inference_steps - init_timestep + offset, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].to(self.torch_device)\n        return timesteps, num_inference_steps - t_start\n\n    def __preprocess_images(self, batch_size, images=()):\n        init_images = []\n        for image in images:\n            image = image.to(self.torch_device).float()\n            image = image.repeat(batch_size, 1, 1, 1)\n            init_images.append(image)\n        return tuple(init_images)\n\n    def __encode_image(self, init_image):\n        
init_latents = runEngine(self.engine[\"vae_encoder\"], {\"images\": init_image}, self.stream)[\"latent\"]\n        init_latents = 0.18215 * init_latents\n        return init_latents\n\n    def __encode_prompt(self, prompt, negative_prompt):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt to be encoded.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n        \"\"\"\n        # Tokenize prompt\n        text_input_ids = (\n            self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            .input_ids.type(torch.int32)\n            .to(self.torch_device)\n        )\n\n        # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt\n        text_embeddings = runEngine(self.engine[\"clip\"], {\"input_ids\": text_input_ids}, self.stream)[\n            \"text_embeddings\"\n        ].clone()\n\n        # Tokenize negative prompt\n        uncond_input_ids = (\n            self.tokenizer(\n                negative_prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            .input_ids.type(torch.int32)\n            .to(self.torch_device)\n        )\n        uncond_embeddings = runEngine(self.engine[\"clip\"], 
{\"input_ids\": uncond_input_ids}, self.stream)[\n            \"text_embeddings\"\n        ]\n\n        # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance\n        text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)\n\n        return text_embeddings\n\n    def __denoise_latent(\n        self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None\n    ):\n        if not isinstance(timesteps, torch.Tensor):\n            timesteps = self.scheduler.timesteps\n        for step_index, timestep in enumerate(timesteps):\n            # Expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2)\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)\n            if isinstance(mask, torch.Tensor):\n                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n            # Predict the noise residual\n            timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep\n\n            noise_pred = runEngine(\n                self.engine[\"unet\"],\n                {\"sample\": latent_model_input, \"timestep\": timestep_float, \"encoder_hidden_states\": text_embeddings},\n                self.stream,\n            )[\"latent\"]\n\n            # Perform guidance\n            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample\n\n        latents = 1.0 / 0.18215 * latents\n        return latents\n\n    def __decode_latent(self, latents):\n        images = runEngine(self.engine[\"vae\"], {\"latent\": latents}, self.stream)[\"images\"]\n        images = (images 
/ 2 + 0.5).clamp(0, 1)\n        return images.cpu().permute(0, 2, 3, 1).float().numpy()\n\n    def __loadResources(self, image_height, image_width, batch_size):\n        self.stream = cudart.cudaStreamCreate()[1]\n\n        # Allocate buffers for TensorRT engine bindings\n        for model_name, obj in self.models.items():\n            self.engine[model_name].allocate_buffers(\n                shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device\n            )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: Union[torch.Tensor, PIL.Image.Image] = None,\n        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        strength: float = 1.0,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation. A `ValueError` is raised if it is neither a `str`\n                nor a `list`.\n            image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with `mask_image` and repainted according to `prompt`.\n            mask_image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be\n                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted\n                to a single channel (luminance) before use. 
If it's a tensor, it should contain one color channel (L)\n                instead of 3, so the expected shape would be `(B, H, W, 1)`.\n            strength (`float`, *optional*, defaults to 1.0):\n                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`\n                will be used as a starting point, adding more noise to it the larger the `strength`. The number of\n                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will\n                be maximum and the denoising process will run for the full number of iterations specified in\n                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, an empty prompt is used for each entry of\n                `prompt`.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n\n        \"\"\"\n        self.generator = generator\n        self.denoising_steps = num_inference_steps\n        self._guidance_scale = guidance_scale\n\n        # Pre-compute latent input scales and linear multistep coefficients\n        self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)\n\n        # Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n            prompt = [prompt]\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"Expected prompt to be of type list or str but got {type(prompt)}\")\n\n        if negative_prompt is None:\n            negative_prompt = [\"\"] * batch_size\n\n        if negative_prompt is not None and isinstance(negative_prompt, str):\n            negative_prompt = [negative_prompt]\n\n        assert len(prompt) == len(negative_prompt)\n\n        if batch_size > self.max_batch_size:\n            raise ValueError(\n                f\"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. 
If dynamic shape is used, then maximum batch size is 4\"\n            )\n\n        # Validate image dimensions\n        mask_width, mask_height = mask_image.size\n        if mask_height != self.image_height or mask_width != self.image_width:\n            raise ValueError(\n                f\"Input image height and width {self.image_height} and {self.image_width} are not equal to \"\n                f\"the respective dimensions of the mask image {mask_height} and {mask_width}\"\n            )\n\n        # load resources\n        self.__loadResources(self.image_height, self.image_width, batch_size)\n\n        with torch.inference_mode(), torch.autocast(\"cuda\"), trt.Runtime(TRT_LOGGER):\n            # Spatial dimensions of latent tensor\n            latent_height = self.image_height // 8\n            latent_width = self.image_width // 8\n\n            # Pre-process input images\n            mask, masked_image, init_image = self.__preprocess_images(\n                batch_size,\n                prepare_mask_and_masked_image(\n                    image,\n                    mask_image,\n                    self.image_height,\n                    self.image_width,\n                    return_image=True,\n                ),\n            )\n\n            mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width))\n            mask = torch.cat([mask] * 2)\n\n            # Initialize timesteps\n            timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength)\n\n            # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)\n            latent_timestep = timesteps[:1].repeat(batch_size)\n            # create a boolean to check if the strength is set to 1. 
if so then initialise the latents with pure noise\n            is_strength_max = strength == 1.0\n\n            # Pre-initialize latents\n            num_channels_latents = self.vae.config.latent_channels\n            latents_outputs = self.prepare_latents(\n                batch_size,\n                num_channels_latents,\n                self.image_height,\n                self.image_width,\n                torch.float32,\n                self.torch_device,\n                generator,\n                image=init_image,\n                timestep=latent_timestep,\n                is_strength_max=is_strength_max,\n            )\n\n            latents = latents_outputs[0]\n\n            # VAE encode masked image\n            masked_latents = self.__encode_image(masked_image)\n            masked_latents = torch.cat([masked_latents] * 2)\n\n            # CLIP text encoder\n            text_embeddings = self.__encode_prompt(prompt, negative_prompt)\n\n            # UNet denoiser\n            latents = self.__denoise_latent(\n                latents,\n                text_embeddings,\n                timesteps=timesteps,\n                step_offset=t_start,\n                mask=mask,\n                masked_image_latents=masked_latents,\n            )\n\n            # VAE decode latent\n            images = self.__decode_latent(latents)\n\n        images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)\n        images = self.numpy_to_pil(images)\n        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_tensorrt_txt2img.py",
    "content": "#\n# Copyright 2025 The HuggingFace Inc. team.\n# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License-Identifier: Apache-2.0\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport os\nfrom collections import OrderedDict\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport onnx\nimport onnx_graphsurgeon as gs\nimport PIL.Image\nimport tensorrt as trt\nimport torch\nfrom cuda import cudart\nfrom huggingface_hub import snapshot_download\nfrom huggingface_hub.utils import validate_hf_hub_args\nfrom onnx import shape_inference\nfrom packaging import version\nfrom polygraphy import cuda\nfrom polygraphy.backend.common import bytes_from_path\nfrom polygraphy.backend.onnx.loader import fold_constants\nfrom polygraphy.backend.trt import (\n    CreateConfig,\n    Profile,\n    engine_from_bytes,\n    engine_from_network,\n    network_from_onnx_path,\n    save_engine,\n)\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.configuration_utils import FrozenDict, deprecate\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.stable_diffusion import (\n    StableDiffusionPipelineOutput,\n    StableDiffusionSafetyChecker,\n)\nfrom diffusers.schedulers 
import DDIMScheduler\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\n\"\"\"\nInstallation instructions\npython3 -m pip install --upgrade transformers diffusers>=0.16.0\npython3 -m pip install --upgrade tensorrt~=10.2.0\npython3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com\npython3 -m pip install onnxruntime\n\"\"\"\n\nTRT_LOGGER = trt.Logger(trt.Logger.ERROR)\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n# Map of numpy dtype -> torch dtype\nnumpy_to_torch_dtype_dict = {\n    np.uint8: torch.uint8,\n    np.int8: torch.int8,\n    np.int16: torch.int16,\n    np.int32: torch.int32,\n    np.int64: torch.int64,\n    np.float16: torch.float16,\n    np.float32: torch.float32,\n    np.float64: torch.float64,\n    np.complex64: torch.complex64,\n    np.complex128: torch.complex128,\n}\nif np.version.full_version >= \"1.24.0\":\n    numpy_to_torch_dtype_dict[np.bool_] = torch.bool\nelse:\n    numpy_to_torch_dtype_dict[np.bool] = torch.bool\n\n# Map of torch dtype -> numpy dtype\ntorch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}\n\n\nclass Engine:\n    def __init__(self, engine_path):\n        self.engine_path = engine_path\n        self.engine = None\n        self.context = None\n        self.buffers = OrderedDict()\n        self.tensors = OrderedDict()\n\n    def __del__(self):\n        [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]\n        del self.engine\n        del self.context\n        del self.buffers\n        del self.tensors\n\n    def build(\n        self,\n        onnx_path,\n        fp16,\n        input_profile=None,\n        enable_all_tactics=False,\n        timing_cache=None,\n    ):\n        logger.warning(f\"Building TensorRT engine for {onnx_path}: {self.engine_path}\")\n        p = Profile()\n        if input_profile:\n            for name, 
dims in input_profile.items():\n                assert len(dims) == 3\n                p.add(name, min=dims[0], opt=dims[1], max=dims[2])\n\n        extra_build_args = {}\n        if not enable_all_tactics:\n            extra_build_args[\"tactic_sources\"] = []\n\n        engine = engine_from_network(\n            network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]),\n            config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **extra_build_args),\n            save_timing_cache=timing_cache,\n        )\n        save_engine(engine, path=self.engine_path)\n\n    def load(self):\n        logger.warning(f\"Loading TensorRT engine: {self.engine_path}\")\n        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))\n\n    def activate(self):\n        self.context = self.engine.create_execution_context()\n\n    def allocate_buffers(self, shape_dict=None, device=\"cuda\"):\n        for binding in range(self.engine.num_io_tensors):\n            name = self.engine.get_tensor_name(binding)\n            if shape_dict and name in shape_dict:\n                shape = shape_dict[name]\n            else:\n                shape = self.engine.get_tensor_shape(name)\n            dtype = trt.nptype(self.engine.get_tensor_dtype(name))\n            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:\n                self.context.set_input_shape(name, shape)\n            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)\n            self.tensors[name] = tensor\n\n    def infer(self, feed_dict, stream):\n        for name, buf in feed_dict.items():\n            self.tensors[name].copy_(buf)\n        for name, tensor in self.tensors.items():\n            self.context.set_tensor_address(name, tensor.data_ptr())\n        noerror = self.context.execute_async_v3(stream)\n        if not noerror:\n            raise ValueError(\"ERROR: inference failed.\")\n\n        
return self.tensors\n\n\nclass Optimizer:\n    def __init__(self, onnx_graph):\n        self.graph = gs.import_onnx(onnx_graph)\n\n    def cleanup(self, return_onnx=False):\n        self.graph.cleanup().toposort()\n        if return_onnx:\n            return gs.export_onnx(self.graph)\n\n    def select_outputs(self, keep, names=None):\n        self.graph.outputs = [self.graph.outputs[o] for o in keep]\n        if names:\n            for i, name in enumerate(names):\n                self.graph.outputs[i].name = name\n\n    def fold_constants(self, return_onnx=False):\n        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)\n        self.graph = gs.import_onnx(onnx_graph)\n        if return_onnx:\n            return onnx_graph\n\n    def infer_shapes(self, return_onnx=False):\n        onnx_graph = gs.export_onnx(self.graph)\n        if onnx_graph.ByteSize() > 2147483648:\n            raise TypeError(\"ERROR: model size exceeds supported 2GB limit\")\n        else:\n            onnx_graph = shape_inference.infer_shapes(onnx_graph)\n\n        self.graph = gs.import_onnx(onnx_graph)\n        if return_onnx:\n            return onnx_graph\n\n\nclass BaseModel:\n    def __init__(self, model, fp16=False, device=\"cuda\", max_batch_size=16, embedding_dim=768, text_maxlen=77):\n        self.model = model\n        self.name = \"SD Model\"\n        self.fp16 = fp16\n        self.device = device\n\n        self.min_batch = 1\n        self.max_batch = max_batch_size\n        self.min_image_shape = 256  # min image resolution: 256x256\n        self.max_image_shape = 1024  # max image resolution: 1024x1024\n        self.min_latent_shape = self.min_image_shape // 8\n        self.max_latent_shape = self.max_image_shape // 8\n\n        self.embedding_dim = embedding_dim\n        self.text_maxlen = text_maxlen\n\n    def get_model(self):\n        return self.model\n\n    def get_input_names(self):\n        pass\n\n    def 
get_output_names(self):\n        pass\n\n    def get_dynamic_axes(self):\n        return None\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        pass\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        return None\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        return None\n\n    def optimize(self, onnx_graph):\n        opt = Optimizer(onnx_graph)\n        opt.cleanup()\n        opt.fold_constants()\n        opt.infer_shapes()\n        onnx_opt_graph = opt.cleanup(return_onnx=True)\n        return onnx_opt_graph\n\n    def check_dims(self, batch_size, image_height, image_width):\n        assert batch_size >= self.min_batch and batch_size <= self.max_batch\n        assert image_height % 8 == 0 and image_width % 8 == 0\n        latent_height = image_height // 8\n        latent_width = image_width // 8\n        assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape\n        assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape\n        return (latent_height, latent_width)\n\n    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):\n        min_batch = batch_size if static_batch else self.min_batch\n        max_batch = batch_size if static_batch else self.max_batch\n        latent_height = image_height // 8\n        latent_width = image_width // 8\n        min_image_height = image_height if static_shape else self.min_image_shape\n        max_image_height = image_height if static_shape else self.max_image_shape\n        min_image_width = image_width if static_shape else self.min_image_shape\n        max_image_width = image_width if static_shape else self.max_image_shape\n        min_latent_height = latent_height if static_shape else self.min_latent_shape\n        max_latent_height = latent_height if static_shape else self.max_latent_shape\n   
     min_latent_width = latent_width if static_shape else self.min_latent_shape\n        max_latent_width = latent_width if static_shape else self.max_latent_shape\n        return (\n            min_batch,\n            max_batch,\n            min_image_height,\n            max_image_height,\n            min_image_width,\n            max_image_width,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n            max_latent_width,\n        )\n\n\ndef getOnnxPath(model_name, onnx_dir, opt=True):\n    return os.path.join(onnx_dir, model_name + (\".opt\" if opt else \"\") + \".onnx\")\n\n\ndef getEnginePath(model_name, engine_dir):\n    return os.path.join(engine_dir, model_name + \".plan\")\n\n\ndef build_engines(\n    models: dict,\n    engine_dir,\n    onnx_dir,\n    onnx_opset,\n    opt_image_height,\n    opt_image_width,\n    opt_batch_size=1,\n    force_engine_rebuild=False,\n    static_batch=False,\n    static_shape=True,\n    enable_all_tactics=False,\n    timing_cache=None,\n):\n    built_engines = {}\n    if not os.path.isdir(onnx_dir):\n        os.makedirs(onnx_dir)\n    if not os.path.isdir(engine_dir):\n        os.makedirs(engine_dir)\n\n    # Export models to ONNX\n    for model_name, model_obj in models.items():\n        engine_path = getEnginePath(model_name, engine_dir)\n        if force_engine_rebuild or not os.path.exists(engine_path):\n            logger.warning(\"Building Engines...\")\n            logger.warning(\"Engine build can take a while to complete\")\n            onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)\n            onnx_opt_path = getOnnxPath(model_name, onnx_dir)\n            if force_engine_rebuild or not os.path.exists(onnx_opt_path):\n                if force_engine_rebuild or not os.path.exists(onnx_path):\n                    logger.warning(f\"Exporting model: {onnx_path}\")\n                    model = model_obj.get_model()\n                    with torch.inference_mode(), 
torch.autocast(\"cuda\"):\n                        inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)\n                        torch.onnx.export(\n                            model,\n                            inputs,\n                            onnx_path,\n                            export_params=True,\n                            opset_version=onnx_opset,\n                            do_constant_folding=True,\n                            input_names=model_obj.get_input_names(),\n                            output_names=model_obj.get_output_names(),\n                            dynamic_axes=model_obj.get_dynamic_axes(),\n                        )\n                    del model\n                    torch.cuda.empty_cache()\n                    gc.collect()\n                else:\n                    logger.warning(f\"Found cached model: {onnx_path}\")\n\n                # Optimize onnx\n                if force_engine_rebuild or not os.path.exists(onnx_opt_path):\n                    logger.warning(f\"Generating optimizing model: {onnx_opt_path}\")\n                    onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))\n                    onnx.save(onnx_opt_graph, onnx_opt_path)\n                else:\n                    logger.warning(f\"Found cached optimized model: {onnx_opt_path} \")\n\n    # Build TensorRT engines\n    for model_name, model_obj in models.items():\n        engine_path = getEnginePath(model_name, engine_dir)\n        engine = Engine(engine_path)\n        onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)\n        onnx_opt_path = getOnnxPath(model_name, onnx_dir)\n\n        if force_engine_rebuild or not os.path.exists(engine.engine_path):\n            engine.build(\n                onnx_opt_path,\n                fp16=True,\n                input_profile=model_obj.get_input_profile(\n                    opt_batch_size,\n                    opt_image_height,\n                    
opt_image_width,\n                    static_batch=static_batch,\n                    static_shape=static_shape,\n                ),\n                timing_cache=timing_cache,\n            )\n        built_engines[model_name] = engine\n\n    # Load and activate TensorRT engines\n    for model_name, model_obj in models.items():\n        engine = built_engines[model_name]\n        engine.load()\n        engine.activate()\n\n    return built_engines\n\n\ndef runEngine(engine, feed_dict, stream):\n    return engine.infer(feed_dict, stream)\n\n\nclass CLIP(BaseModel):\n    def __init__(self, model, device, max_batch_size, embedding_dim):\n        super(CLIP, self).__init__(\n            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim\n        )\n        self.name = \"CLIP\"\n\n    def get_input_names(self):\n        return [\"input_ids\"]\n\n    def get_output_names(self):\n        return [\"text_embeddings\", \"pooler_output\"]\n\n    def get_dynamic_axes(self):\n        return {\"input_ids\": {0: \"B\"}, \"text_embeddings\": {0: \"B\"}}\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        self.check_dims(batch_size, image_height, image_width)\n        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(\n            batch_size, image_height, image_width, static_batch, static_shape\n        )\n        return {\n            \"input_ids\": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"input_ids\": (batch_size, self.text_maxlen),\n            \"text_embeddings\": (batch_size, self.text_maxlen, self.embedding_dim),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        self.check_dims(batch_size, 
image_height, image_width)\n        return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)\n\n    def optimize(self, onnx_graph):\n        opt = Optimizer(onnx_graph)\n        opt.select_outputs([0])  # delete graph output#1\n        opt.cleanup()\n        opt.fold_constants()\n        opt.infer_shapes()\n        opt.select_outputs([0], names=[\"text_embeddings\"])  # rename network output\n        opt_onnx_graph = opt.cleanup(return_onnx=True)\n        return opt_onnx_graph\n\n\ndef make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)\n\n\nclass UNet(BaseModel):\n    def __init__(\n        self, model, fp16=False, device=\"cuda\", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4\n    ):\n        super(UNet, self).__init__(\n            model=model,\n            fp16=fp16,\n            device=device,\n            max_batch_size=max_batch_size,\n            embedding_dim=embedding_dim,\n            text_maxlen=text_maxlen,\n        )\n        self.unet_dim = unet_dim\n        self.name = \"UNet\"\n\n    def get_input_names(self):\n        return [\"sample\", \"timestep\", \"encoder_hidden_states\"]\n\n    def get_output_names(self):\n        return [\"latent\"]\n\n    def get_dynamic_axes(self):\n        return {\n            \"sample\": {0: \"2B\", 2: \"H\", 3: \"W\"},\n            \"encoder_hidden_states\": {0: \"2B\"},\n            \"latent\": {0: \"2B\", 2: \"H\", 3: \"W\"},\n        }\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        (\n            min_batch,\n            max_batch,\n            _,\n            _,\n            _,\n            _,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n     
       max_latent_width,\n        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)\n        return {\n            \"sample\": [\n                (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),\n                (2 * batch_size, self.unet_dim, latent_height, latent_width),\n                (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),\n            ],\n            \"encoder_hidden_states\": [\n                (2 * min_batch, self.text_maxlen, self.embedding_dim),\n                (2 * batch_size, self.text_maxlen, self.embedding_dim),\n                (2 * max_batch, self.text_maxlen, self.embedding_dim),\n            ],\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"sample\": (2 * batch_size, self.unet_dim, latent_height, latent_width),\n            \"encoder_hidden_states\": (2 * batch_size, self.text_maxlen, self.embedding_dim),\n            \"latent\": (2 * batch_size, 4, latent_height, latent_width),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        dtype = torch.float16 if self.fp16 else torch.float32\n        return (\n            torch.randn(\n                2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device\n            ),\n            torch.tensor([1.0], dtype=torch.float32, device=self.device),\n            torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),\n        )\n\n\ndef make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return UNet(\n        model,\n        fp16=True,\n        device=device,\n        max_batch_size=max_batch_size,\n        
embedding_dim=embedding_dim,\n        unet_dim=(9 if inpaint else 4),\n    )\n\n\nclass VAE(BaseModel):\n    def __init__(self, model, device, max_batch_size, embedding_dim):\n        super(VAE, self).__init__(\n            model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim\n        )\n        self.name = \"VAE decoder\"\n\n    def get_input_names(self):\n        return [\"latent\"]\n\n    def get_output_names(self):\n        return [\"images\"]\n\n    def get_dynamic_axes(self):\n        return {\"latent\": {0: \"B\", 2: \"H\", 3: \"W\"}, \"images\": {0: \"B\", 2: \"8H\", 3: \"8W\"}}\n\n    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        (\n            min_batch,\n            max_batch,\n            _,\n            _,\n            _,\n            _,\n            min_latent_height,\n            max_latent_height,\n            min_latent_width,\n            max_latent_width,\n        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)\n        return {\n            \"latent\": [\n                (min_batch, 4, min_latent_height, min_latent_width),\n                (batch_size, 4, latent_height, latent_width),\n                (max_batch, 4, max_latent_height, max_latent_width),\n            ]\n        }\n\n    def get_shape_dict(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return {\n            \"latent\": (batch_size, 4, latent_height, latent_width),\n            \"images\": (batch_size, 3, image_height, image_width),\n        }\n\n    def get_sample_input(self, batch_size, image_height, image_width):\n        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)\n        return torch.randn(batch_size, 4, 
latent_height, latent_width, dtype=torch.float32, device=self.device)\n\n\ndef make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):\n    return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)\n\n\nclass TensorRTStableDiffusionPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: DDIMScheduler,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n        stages=[\"clip\", \"unet\", \"vae\"],\n        image_height: int = 768,\n        image_width: int = 768,\n        max_batch_size: int = 16,\n        # ONNX export parameters\n        onnx_opset: int = 18,\n        onnx_dir: str = \"onnx\",\n        # TensorRT engine build parameters\n        engine_dir: str = \"engine\",\n        force_engine_rebuild: bool = False,\n        timing_cache: str = \"timing_cache\",\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"clip_sample\", False) is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = (\n            unet is not None\n            and hasattr(unet.config, \"_diffusers_version\")\n            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse(\"0.9.0.dev0\")\n        )\n        is_unet_sample_size_less_64 = (\n            unet is not None and hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- stable-diffusion-v1-5/stable-diffusion-v1-5\"\n                \" \\n- stable-diffusion-v1-5/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n\n        self.stages = stages\n        self.image_height, self.image_width = image_height, image_width\n        self.inpaint = False\n        self.onnx_opset = onnx_opset\n        self.onnx_dir = onnx_dir\n        self.engine_dir = engine_dir\n        self.force_engine_rebuild = force_engine_rebuild\n        self.timing_cache = timing_cache\n        self.build_static_batch = False\n        self.build_dynamic_shape = False\n\n        self.max_batch_size = max_batch_size\n        # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT 
limitation.\n        if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:\n            self.max_batch_size = 4\n\n        self.stream = None  # loaded in loadResources()\n        self.models = {}  # loaded in __loadModels()\n        self.engine = {}  # loaded in build_engines()\n\n        self.vae.forward = self.vae.decode\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def __loadModels(self):\n        # Load pipeline models\n        self.embedding_dim = self.text_encoder.config.hidden_size\n        models_args = {\n            \"device\": self.torch_device,\n            \"max_batch_size\": self.max_batch_size,\n            \"embedding_dim\": self.embedding_dim,\n            \"inpaint\": self.inpaint,\n        }\n        if \"clip\" in self.stages:\n            self.models[\"clip\"] = make_CLIP(self.text_encoder, **models_args)\n        if \"unet\" in self.stages:\n            self.models[\"unet\"] = make_UNet(self.unet, **models_args)\n        if \"vae\" in self.stages:\n            self.models[\"vae\"] = make_VAE(self.vae, **models_args)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size: int,\n        num_channels_latents: int,\n        height: int,\n        width: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        generator: Union[torch.Generator, List[torch.Generator]],\n        latents: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Prepare the latent vectors for diffusion.\n        Args:\n            batch_size (int): The number of samples in the batch.\n            num_channels_latents (int): 
The number of channels in the latent vectors.\n            height (int): The height of the latent vectors.\n            width (int): The width of the latent vectors.\n            dtype (torch.dtype): The data type of the latent vectors.\n            device (torch.device): The device to place the latent vectors on.\n            generator (Union[torch.Generator, List[torch.Generator]]): The generator(s) to use for random number generation.\n            latents (Optional[torch.Tensor]): The pre-existing latent vectors. If None, new latent vectors will be generated.\n        Returns:\n            torch.Tensor: The prepared latent vectors.\n        \"\"\"\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(\n        self, image: Union[torch.Tensor, PIL.Image.Image], device: torch.device, dtype: torch.dtype\n    ) -> Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]:\n        r\"\"\"\n        Runs the safety checker on the given image.\n        Args:\n            image (Union[torch.Tensor, PIL.Image.Image]): The input image to be checked.\n            device (torch.device): The device to run the safety checker on.\n            dtype (torch.dtype): The data type of the input image.\n        Returns:\n            (image, has_nsfw_concept) Tuple[Union[torch.Tensor, PIL.Image.Image], Optional[bool]]: A tuple containing the processed image and\n            a boolean indicating whether the image has a NSFW (Not Safe for Work) concept.\n        \"\"\"\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, 
has_nsfw_concept\n\n    @classmethod\n    @validate_hf_hub_args\n    def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):\n        cache_dir = kwargs.pop(\"cache_dir\", None)\n        proxies = kwargs.pop(\"proxies\", None)\n        local_files_only = kwargs.pop(\"local_files_only\", False)\n        token = kwargs.pop(\"token\", None)\n        revision = kwargs.pop(\"revision\", None)\n\n        cls.cached_folder = (\n            pretrained_model_name_or_path\n            if os.path.isdir(pretrained_model_name_or_path)\n            else snapshot_download(\n                pretrained_model_name_or_path,\n                cache_dir=cache_dir,\n                proxies=proxies,\n                local_files_only=local_files_only,\n                token=token,\n                revision=revision,\n            )\n        )\n\n    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):\n        super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings)\n\n        self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)\n        self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)\n        self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)\n\n        # set device\n        self.torch_device = self._execution_device\n        logger.warning(f\"Running inference on device: {self.torch_device}\")\n\n        # load models\n        self.__loadModels()\n\n        # build engines\n        self.engine = build_engines(\n            self.models,\n            self.engine_dir,\n            self.onnx_dir,\n            self.onnx_opset,\n            opt_image_height=self.image_height,\n            opt_image_width=self.image_width,\n            force_engine_rebuild=self.force_engine_rebuild,\n            static_batch=self.build_static_batch,\n            static_shape=not self.build_dynamic_shape,\n            
timing_cache=self.timing_cache,\n        )\n\n        return self\n\n    def __encode_prompt(self, prompt, negative_prompt):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be encoded.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n        \"\"\"\n        # Tokenize prompt\n        text_input_ids = (\n            self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            .input_ids.type(torch.int32)\n            .to(self.torch_device)\n        )\n\n        # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt\n        text_embeddings = runEngine(self.engine[\"clip\"], {\"input_ids\": text_input_ids}, self.stream)[\n            \"text_embeddings\"\n        ].clone()\n\n        # Tokenize negative prompt\n        uncond_input_ids = (\n            self.tokenizer(\n                negative_prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            .input_ids.type(torch.int32)\n            .to(self.torch_device)\n        )\n        uncond_embeddings = runEngine(self.engine[\"clip\"], {\"input_ids\": uncond_input_ids}, self.stream)[\n            \"text_embeddings\"\n        ]\n\n        # Concatenate the 
unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance\n        text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)\n\n        return text_embeddings\n\n    def __denoise_latent(\n        self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None\n    ):\n        if not isinstance(timesteps, torch.Tensor):\n            timesteps = self.scheduler.timesteps\n        for step_index, timestep in enumerate(timesteps):\n            # Expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2)\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)\n            if isinstance(mask, torch.Tensor):\n                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)\n\n            # Predict the noise residual\n            timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep\n\n            noise_pred = runEngine(\n                self.engine[\"unet\"],\n                {\"sample\": latent_model_input, \"timestep\": timestep_float, \"encoder_hidden_states\": text_embeddings},\n                self.stream,\n            )[\"latent\"]\n\n            # Perform guidance\n            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + self._guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample\n\n        latents = 1.0 / 0.18215 * latents\n        return latents\n\n    def __decode_latent(self, latents):\n        images = runEngine(self.engine[\"vae\"], {\"latent\": latents}, self.stream)[\"images\"]\n        images = (images / 2 + 0.5).clamp(0, 1)\n        return images.cpu().permute(0, 2, 3, 1).float().numpy()\n\n    def __loadResources(self, 
image_height, image_width, batch_size):\n        self.stream = cudart.cudaStreamCreate()[1]\n\n        # Allocate buffers for TensorRT engine bindings\n        for model_name, obj in self.models.items():\n            self.engine[model_name].allocate_buffers(\n                shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device\n            )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.\n                instead.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. 
If not defined, an empty string is\n                used for each prompt.\n                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n\n        \"\"\"\n        self.generator = generator\n        self.denoising_steps = num_inference_steps\n        self._guidance_scale = guidance_scale\n\n        # Pre-compute latent input scales and linear multistep coefficients\n        self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)\n\n        # Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n            prompt = [prompt]\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"Expected prompt to be of type list or str but got {type(prompt)}\")\n\n        if negative_prompt is None:\n            negative_prompt = [\"\"] * batch_size\n\n        if negative_prompt is not None and isinstance(negative_prompt, str):\n            negative_prompt = [negative_prompt]\n\n        assert len(prompt) == len(negative_prompt)\n\n        if batch_size > self.max_batch_size:\n            raise ValueError(\n                f\"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. 
If dynamic shape is used, then maximum batch size is 4\"\n            )\n\n        # load resources\n        self.__loadResources(self.image_height, self.image_width, batch_size)\n\n        with torch.inference_mode(), torch.autocast(\"cuda\"), trt.Runtime(TRT_LOGGER):\n            # CLIP text encoder\n            text_embeddings = self.__encode_prompt(prompt, negative_prompt)\n\n            # Pre-initialize latents\n            num_channels_latents = self.unet.config.in_channels\n            latents = self.prepare_latents(\n                batch_size,\n                num_channels_latents,\n                self.image_height,\n                self.image_width,\n                torch.float32,\n                self.torch_device,\n                generator,\n            )\n\n            # UNet denoiser\n            latents = self.__denoise_latent(latents, text_embeddings)\n\n            # VAE decode latent\n            images = self.__decode_latent(latents)\n\n        images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)\n        images = self.numpy_to_pil(images)\n        return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/community/stable_diffusion_xl_controlnet_reference.py",
    "content": "# Based on stable_diffusion_xl_reference.py and stable_diffusion_controlnet_reference.py\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\n\nfrom diffusers import StableDiffusionXLControlNetPipeline\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.models import ControlNetModel\nfrom diffusers.models.attention import BasicTransformerBlock\nfrom diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install opencv-python transformers accelerate\n        >>> from diffusers import ControlNetModel, AutoencoderKL\n        >>> from diffusers.schedulers import UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n        >>> import numpy as np\n        >>> import torch\n\n        >>> import cv2\n        >>> from PIL import Image\n\n        >>> # download an image for the Canny controlnet\n        >>> canny_image = load_image(\n        ...     \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg\"\n        ... )\n\n        >>> # download an image for the Reference controlnet\n        >>> ref_image = load_image(\n        ...     
\"https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png\"\n        ... )\n\n        >>> # initialize the models and pipeline\n        >>> controlnet_conditioning_scale = 0.5  # recommended for good generalization\n        >>> controlnet = ControlNetModel.from_pretrained(\n        ...     \"diffusers/controlnet-canny-sdxl-1.0\", torch_dtype=torch.float16\n        ... )\n        >>> vae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\n        >>> pipe = StableDiffusionXLControlNetReferencePipeline.from_pretrained(\n        ...     \"stabilityai/stable-diffusion-xl-base-1.0\", controlnet=controlnet, vae=vae, torch_dtype=torch.float16\n        ... ).to(\"cuda:0\")\n\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n\n        >>> # get canny image\n        >>> image = np.array(canny_image)\n        >>> image = cv2.Canny(image, 100, 200)\n        >>> image = image[:, :, None]\n        >>> image = np.concatenate([image, image, image], axis=2)\n        >>> canny_image = Image.fromarray(image)\n\n        >>> # generate image\n        >>> image = pipe(\n        ...     prompt=\"a cat\",\n        ...     num_inference_steps=20,\n        ...     controlnet_conditioning_scale=controlnet_conditioning_scale,\n        ...     image=canny_image,\n        ...     ref_image=ref_image,\n        ...     reference_attn=True,\n        ...     reference_adain=True,\n        ...     style_fidelity=1.0,\n        ...     generator=torch.Generator(\"cuda\").manual_seed(42)\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\ndef torch_dfs(model: torch.nn.Module):\n    result = [model]\n    for child in model.children():\n        result += torch_dfs(child)\n    return result\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusionXLControlNetReferencePipeline(StableDiffusionXLControlNetPipeline):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):\n            Second frozen text-encoder\n            ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n        
    A `CLIPTokenizer` to tokenize text.\n        tokenizer_2 ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the `unet` during the denoising process. If you set multiple\n            ControlNets as a list, the outputs from each ControlNet are added together to create one combined\n            additional conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `\"True\"`):\n            Whether the negative prompt embeddings should always be set to 0. Also see the config of\n            `stabilityai/stable-diffusion-xl-base-1-0`.\n        add_watermarker (`bool`, *optional*):\n            Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to\n            watermark output images. 
If not defined, it defaults to `True` if the package is installed; otherwise no\n            watermarker is used.\n    \"\"\"\n\n    def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):\n        refimage = refimage.to(device=device)\n        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n        if needs_upcasting:\n            self.upcast_vae()\n            refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n        if refimage.dtype != self.vae.dtype:\n            refimage = refimage.to(dtype=self.vae.dtype)\n        # encode the reference image into latent space\n        if isinstance(generator, list):\n            ref_image_latents = [\n                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(batch_size)\n            ]\n            ref_image_latents = torch.cat(ref_image_latents, dim=0)\n        else:\n            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)\n        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents\n\n        # duplicate ref_image_latents for each generation per prompt, using mps friendly method\n        if ref_image_latents.shape[0] < batch_size:\n            if batch_size % ref_image_latents.shape[0] != 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. 
Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the total requested batch size is divisible by the number of images that you pass.\"\n                )\n            ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)\n\n        ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents\n\n        # aligning device to prevent device errors when concatenating it with the latent model input\n        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)\n\n        # cast back to fp16 if needed\n        if needs_upcasting:\n            self.vae.to(dtype=torch.float16)\n\n        return ref_image_latents\n\n    def prepare_ref_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        if not isinstance(image, torch.Tensor):\n            if isinstance(image, PIL.Image.Image):\n                image = [image]\n\n            if isinstance(image[0], PIL.Image.Image):\n                images = []\n\n                for image_ in image:\n                    image_ = image_.convert(\"RGB\")\n                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION[\"lanczos\"])\n                    image_ = np.array(image_)\n                    image_ = image_[None, :]\n                    images.append(image_)\n\n                image = images\n\n                image = np.concatenate(image, axis=0)\n                image = np.array(image).astype(np.float32) / 255.0\n                image = (image - 0.5) / 0.5\n                image = image.transpose(0, 3, 1, 2)\n                image = torch.from_numpy(image)\n\n            elif 
isinstance(image[0], torch.Tensor):\n                image = torch.stack(image, dim=0)\n\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    def check_ref_inputs(\n        self,\n        ref_image,\n        reference_guidance_start,\n        reference_guidance_end,\n        style_fidelity,\n        reference_attn,\n        reference_adain,\n    ):\n        ref_image_is_pil = isinstance(ref_image, PIL.Image.Image)\n        ref_image_is_tensor = isinstance(ref_image, torch.Tensor)\n\n        if not ref_image_is_pil and not ref_image_is_tensor:\n            raise TypeError(\n                f\"ref image must be passed and be one of PIL image or torch tensor, but is {type(ref_image)}\"\n            )\n\n        if not reference_attn and not reference_adain:\n            raise ValueError(\"`reference_attn` or `reference_adain` must be True.\")\n\n        if style_fidelity < 0.0:\n            raise ValueError(f\"style fidelity: {style_fidelity} can't be smaller than 0.\")\n        if style_fidelity > 1.0:\n            raise ValueError(f\"style fidelity: {style_fidelity} can't be larger than 1.0.\")\n\n        if reference_guidance_start >= reference_guidance_end:\n            raise ValueError(\n                f\"reference guidance start: {reference_guidance_start} cannot be larger or equal to reference guidance end: {reference_guidance_end}.\"\n            )\n        if reference_guidance_start < 0.0:\n            raise ValueError(f\"reference guidance start: {reference_guidance_start} can't be smaller than 0.\")\n        if reference_guidance_end > 1.0:\n            raise ValueError(f\"reference 
guidance end: {reference_guidance_end} can't be larger than 1.0.\")\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: PipelineImageInput = None,\n        ref_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        original_size: Tuple[int, int] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Tuple[int, int] = None,\n        negative_original_size: Optional[Tuple[int, int]] = 
None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        attention_auto_machine_weight: float = 1.0,\n        gn_auto_machine_weight: float = 1.0,\n        reference_guidance_start: float = 0.0,\n        reference_guidance_end: float = 1.0,\n        style_fidelity: float = 0.5,\n        reference_attn: bool = True,\n        reference_adain: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted\n                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or\n                width are passed, `image` is resized accordingly. 
If multiple ControlNets are specified in `init`,\n                images must be passed as a list such that each element of the list can be correctly batched for input\n                to a single ControlNet.\n            ref_image (`torch.Tensor`, `PIL.Image.Image`):\n                The Reference Control input condition. Reference Control uses this input condition to generate guidance for the `unet`. If\n                the type is specified as `torch.Tensor`, it is passed to Reference Control as is. `PIL.Image.Image` can\n                also be accepted as an image.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image. Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. 
Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. The `denoising_end` parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`\n                and `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If\n                not provided, pooled text embeddings are generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt\n                weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input\n                argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. 
Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. 
Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be the same\n                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during inference with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            attention_auto_machine_weight (`float`):\n                Weight of using reference query for self attention's context.\n                If `attention_auto_machine_weight=1.0`, use reference query for all self attention's context.\n            gn_auto_machine_weight (`float`):\n                Weight of using reference adain. If `gn_auto_machine_weight=2.0`, use all reference adain plugins.\n            reference_guidance_start (`float`, *optional*, defaults to 0.0):\n                The percentage of total steps at which Reference Control starts applying.\n            reference_guidance_end (`float`, *optional*, defaults to 1.0):\n                The percentage of total steps at which Reference Control stops applying.\n            style_fidelity (`float`):\n                Style fidelity of `ref_uncond_xt`. If `style_fidelity=1.0`, the reference control is more important;\n                if `style_fidelity=0.0`, the prompt is more important; otherwise the two are balanced.\n            reference_attn (`bool`):\n                Whether to use reference query for self attention's context.\n            reference_adain (`bool`):\n                Whether to use reference adain.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned containing the output images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using 
`callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            image,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            negative_pooled_prompt_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self.check_ref_inputs(\n            ref_image,\n            reference_guidance_start,\n            reference_guidance_end,\n            style_fidelity,\n            reference_attn,\n            reference_adain,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._interrupt = False\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3.1 Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt,\n            prompt_2,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 3.2 Encode ip_adapter_image\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                
ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. Prepare image\n        if isinstance(controlnet, ControlNetModel):\n            image = self.prepare_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            images = []\n\n            for image_ in image:\n                image_ = self.prepare_image(\n                    image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                images.append(image_)\n\n            image = images\n            height, width = image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. Preprocess reference image\n        ref_image = self.prepare_ref_image(\n            image=ref_image,\n            width=width,\n            height=height,\n            batch_size=batch_size * num_images_per_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            device=device,\n            dtype=prompt_embeds.dtype,\n        )\n\n        # 6. 
Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n        self._num_timesteps = len(timesteps)\n\n        # 7. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 7.5 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 8. Prepare reference latent variables\n        ref_image_latents = self.prepare_ref_latents(\n            ref_image,\n            batch_size * num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            self.do_classifier_free_guidance,\n        )\n\n        # 9. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 9.1 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        reference_keeps = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n            reference_keep = 1.0 - float(\n                i / len(timesteps) < reference_guidance_start or (i + 1) / len(timesteps) > reference_guidance_end\n            )\n            reference_keeps.append(reference_keep)\n\n        # 9.2 Modify self attention and group norm\n        MODE = \"write\"\n        uc_mask = (\n            torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)\n            .type_as(ref_image_latents)\n            .bool()\n        )\n\n        do_classifier_free_guidance = self.do_classifier_free_guidance\n\n        def hacked_basic_transformer_inner_forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n            timestep: Optional[torch.LongTensor] = None,\n            cross_attention_kwargs: Dict[str, Any] = None,\n            class_labels: Optional[torch.LongTensor] = None,\n        ):\n            if self.use_ada_layer_norm:\n                norm_hidden_states = self.norm1(hidden_states, timestep)\n            elif self.use_ada_layer_norm_zero:\n                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                    hidden_states, timestep, 
class_labels, hidden_dtype=hidden_states.dtype\n                )\n            else:\n                norm_hidden_states = self.norm1(hidden_states)\n\n            # 1. Self-Attention\n            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n            if self.only_cross_attention:\n                attn_output = self.attn1(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                    attention_mask=attention_mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                if MODE == \"write\":\n                    self.bank.append(norm_hidden_states.detach().clone())\n                    attn_output = self.attn1(\n                        norm_hidden_states,\n                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                        attention_mask=attention_mask,\n                        **cross_attention_kwargs,\n                    )\n                if MODE == \"read\":\n                    if attention_auto_machine_weight > self.attn_weight:\n                        attn_output_uc = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),\n                            # attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n                        attn_output_c = attn_output_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            attn_output_c[uc_mask] = self.attn1(\n                                norm_hidden_states[uc_mask],\n                                encoder_hidden_states=norm_hidden_states[uc_mask],\n                                **cross_attention_kwargs,\n                 
           )\n                        attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc\n                        self.bank.clear()\n                    else:\n                        attn_output = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                            attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n            if self.use_ada_layer_norm_zero:\n                attn_output = gate_msa.unsqueeze(1) * attn_output\n            hidden_states = attn_output + hidden_states\n\n            if self.attn2 is not None:\n                norm_hidden_states = (\n                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n                )\n\n                # 2. Cross-Attention\n                attn_output = self.attn2(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=encoder_attention_mask,\n                    **cross_attention_kwargs,\n                )\n                hidden_states = attn_output + hidden_states\n\n            # 3. 
Feed-forward\n            norm_hidden_states = self.norm3(hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n            ff_output = self.ff(norm_hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n            hidden_states = ff_output + hidden_states\n\n            return hidden_states\n\n        def hacked_mid_forward(self, *args, **kwargs):\n            eps = 1e-6\n            x = self.original_forward(*args, **kwargs)\n            if MODE == \"write\":\n                if gn_auto_machine_weight >= self.gn_weight:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    self.mean_bank.append(mean)\n                    self.var_bank.append(var)\n            if MODE == \"read\":\n                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))\n                    var_acc = sum(self.var_bank) / float(len(self.var_bank))\n                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                    x_uc = (((x - mean) / std) * std_acc) + mean_acc\n                    x_c = x_uc.clone()\n                    if do_classifier_free_guidance and style_fidelity > 0:\n                        x_c[uc_mask] = x[uc_mask]\n                    x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc\n                self.mean_bank = []\n                self.var_bank = []\n            return x\n\n        def hack_CrossAttnDownBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            temb: 
Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n        ):\n            eps = 1e-6\n\n            # TODO(Patrick, William) - attention mask is not used\n            output_states = ()\n\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        
hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):\n            eps = 1e-6\n\n            output_states = ()\n\n            for i, resnet in enumerate(self.resnets):\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc 
= (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_CrossAttnUpBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            res_hidden_states_tuple: Tuple[torch.Tensor, ...],\n            temb: Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            upsample_size: Optional[int] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n        ):\n            eps = 1e-6\n            # TODO(Patrick, William) - attention mask is not used\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n      
              encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        def 
hacked_UpBlock2D_forward(\n            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs\n        ):\n            eps = 1e-6\n            for i, resnet in enumerate(self.resnets):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                
self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        if reference_attn:\n            attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]\n            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])\n\n            for i, module in enumerate(attn_modules):\n                module._original_inner_forward = module.forward\n                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)\n                module.bank = []\n                module.attn_weight = float(i) / float(len(attn_modules))\n\n        if reference_adain:\n            gn_modules = [self.unet.mid_block]\n            self.unet.mid_block.gn_weight = 0\n\n            down_blocks = self.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))\n                gn_modules.append(module)\n\n            up_blocks = self.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                module.gn_weight = float(w) / float(len(up_blocks))\n                gn_modules.append(module)\n\n            for i, module in enumerate(gn_modules):\n                if getattr(module, \"original_forward\", None) is None:\n                    module.original_forward = module.forward\n                if i == 0:\n                    # mid_block\n                    module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)\n                elif isinstance(module, CrossAttnDownBlock2D):\n                    module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)\n                elif isinstance(module, DownBlock2D):\n                    module.forward = 
hacked_DownBlock2D_forward.__get__(module, DownBlock2D)\n                elif isinstance(module, CrossAttnUpBlock2D):\n                    module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)\n                elif isinstance(module, UpBlock2D):\n                    module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)\n                module.mean_bank = []\n                module.var_bank = []\n                module.gn_weight *= 2\n\n        # 9.3 Prepare added time ids & embeddings\n        if isinstance(image, list):\n            original_size = original_size or image[0].shape[-2:]\n        else:\n            original_size = original_size or image.shape[-2:]\n        target_size = target_size or (height, width)\n\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, 
add_text_embeds], dim=0)\n            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        # 10. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n\n        # 10.1 Apply denoising_end\n        if (\n            self.denoising_end is not None\n            and isinstance(self.denoising_end, float)\n            and self.denoising_end > 0\n            and self.denoising_end < 1\n        ):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        is_unet_compiled = is_compiled_module(self.unet)\n        is_controlnet_compiled = is_compiled_module(self.controlnet)\n        is_torch_higher_equal_2_1 = is_torch_version(\">=\", \"2.1\")\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # Relevant thread:\n                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428\n                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:\n                    torch._inductor.cudagraph_mark_step_begin()\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = 
self.scheduler.scale_model_input(latent_model_input, t)\n\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n                # controlnet(s) inference\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                    controlnet_added_cond_kwargs = {\n                        \"text_embeds\": add_text_embeds.chunk(2)[1],\n                        \"time_ids\": add_time_ids.chunk(2)[1],\n                    }\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n                    controlnet_added_cond_kwargs = added_cond_kwargs\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=image,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    added_cond_kwargs=controlnet_added_cond_kwargs,\n                    return_dict=False,\n                )\n\n                if 
guess_mode and self.do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                    added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                # ref only part\n                if reference_keeps[i] > 0:\n                    noise = randn_tensor(\n                        ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype\n                    )\n                    ref_xt = self.scheduler.add_noise(\n                        ref_image_latents,\n                        noise,\n                        t.reshape(\n                            1,\n                        ),\n                    )\n                    ref_xt = self.scheduler.scale_model_input(ref_xt, t)\n\n                    MODE = \"write\"\n                    self.unet(\n                        ref_xt,\n                        t,\n                        encoder_hidden_states=prompt_embeds,\n                        cross_attention_kwargs=cross_attention_kwargs,\n                        added_cond_kwargs=added_cond_kwargs,\n                        return_dict=False,\n                    )\n\n                # predict the noise residual\n                MODE = \"read\"\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n               
     cross_attention_kwargs=self.cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                    negative_add_time_ids = callback_outputs.pop(\"negative_add_time_ids\", negative_add_time_ids)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > 
num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n                self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            # apply watermark if available\n            if self.watermark is not 
None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/stable_diffusion_xl_reference.py",
    "content": "# Based on stable_diffusion_reference.py\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\n\nfrom diffusers import StableDiffusionXLPipeline\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.models.attention import BasicTransformerBlock\nfrom diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput\nfrom diffusers.utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm  # type: ignore\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers.schedulers import UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n\n        >>> input_image = load_image(\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg\")\n\n        >>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(\n            \"stabilityai/stable-diffusion-xl-base-1.0\",\n            torch_dtype=torch.float16,\n            use_safetensors=True,\n            variant=\"fp16\").to('cuda:0')\n\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>> result_img = pipe(ref_image=input_image,\n                        prompt=\"a dog\",\n                        num_inference_steps=20,\n                        reference_attn=True,\n                        
reference_adain=True).images[0]\n\n        >>> result_img.show()\n        ```\n\"\"\"\n\n\ndef torch_dfs(model: torch.nn.Module):\n    result = [model]\n    for child in model.children():\n        result += torch_dfs(child)\n    return result\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg\ndef rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):\n    \"\"\"\n    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and\n    Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4\n    \"\"\"\n    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)\n    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)\n    # rescale the results from guidance (fixes overexposure)\n    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)\n    # mix with the original results from guidance by factor guidance_rescale to avoid \"plain looking\" images\n    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg\n    return noise_cfg\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. 
If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):\n    def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):\n        refimage = refimage.to(device=device)\n        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n        if needs_upcasting:\n            self.upcast_vae()\n            refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n        if refimage.dtype != self.vae.dtype:\n            refimage = refimage.to(dtype=self.vae.dtype)\n        # encode the mask image into latents space so we can concatenate it to the latents\n        if isinstance(generator, list):\n            ref_image_latents = [\n                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])\n                for i in range(batch_size)\n            ]\n            ref_image_latents = torch.cat(ref_image_latents, dim=0)\n        else:\n        
    ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)\n        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents\n\n        # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method\n        if ref_image_latents.shape[0] < batch_size:\n            if not batch_size % ref_image_latents.shape[0] == 0:\n                raise ValueError(\n                    \"The passed images and the required batch size don't match. Images are supposed to be duplicated\"\n                    f\" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed.\"\n                    \" Make sure the number of images that you pass is divisible by the total requested batch size.\"\n                )\n            ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)\n\n        ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents\n\n        # aligning device to prevent device errors when concating it with the latent model input\n        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)\n\n        # cast back to fp16 if needed\n        if needs_upcasting:\n            self.vae.to(dtype=torch.float16)\n\n        return ref_image_latents\n\n    def prepare_ref_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        if not isinstance(image, torch.Tensor):\n            if isinstance(image, PIL.Image.Image):\n                image = [image]\n\n            if isinstance(image[0], PIL.Image.Image):\n                images = []\n\n                for image_ in image:\n                    image_ = image_.convert(\"RGB\")\n                    image_ = image_.resize((width, 
height), resample=PIL_INTERPOLATION[\"lanczos\"])\n                    image_ = np.array(image_)\n                    image_ = image_[None, :]\n                    images.append(image_)\n\n                image = images\n\n                image = np.concatenate(image, axis=0)\n                image = np.array(image).astype(np.float32) / 255.0\n                image = (image - 0.5) / 0.5\n                image = image.transpose(0, 3, 1, 2)\n                image = torch.from_numpy(image)\n\n            elif isinstance(image[0], torch.Tensor):\n                image = torch.stack(image, dim=0)\n\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    def check_ref_inputs(\n        self,\n        ref_image,\n        reference_guidance_start,\n        reference_guidance_end,\n        style_fidelity,\n        reference_attn,\n        reference_adain,\n    ):\n        ref_image_is_pil = isinstance(ref_image, PIL.Image.Image)\n        ref_image_is_tensor = isinstance(ref_image, torch.Tensor)\n\n        if not ref_image_is_pil and not ref_image_is_tensor:\n            raise TypeError(\n                f\"ref image must be passed and be one of PIL image or torch tensor, but is {type(ref_image)}\"\n            )\n\n        if not reference_attn and not reference_adain:\n            raise ValueError(\"`reference_attn` or `reference_adain` must be True.\")\n\n        if style_fidelity < 0.0:\n            raise ValueError(f\"style fidelity: {style_fidelity} can't be smaller than 0.\")\n        if style_fidelity > 1.0:\n            raise ValueError(f\"style fidelity: {style_fidelity} can't be larger than 
1.0.\")\n\n        if reference_guidance_start >= reference_guidance_end:\n            raise ValueError(\n                f\"reference guidance start: {reference_guidance_start} cannot be larger or equal to reference guidance end: {reference_guidance_end}.\"\n            )\n        if reference_guidance_start < 0.0:\n            raise ValueError(f\"reference guidance start: {reference_guidance_start} can't be smaller than 0.\")\n        if reference_guidance_end > 1.0:\n            raise ValueError(f\"reference guidance end: {reference_guidance_end} can't be larger than 1.0.\")\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        ref_image: Union[torch.Tensor, PIL.Image.Image] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        denoising_end: Optional[float] = None,\n        guidance_scale: float = 5.0,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        negative_prompt_2: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: 
Optional[Dict[str, Any]] = None,\n        guidance_rescale: float = 0.0,\n        original_size: Optional[Tuple[int, int]] = None,\n        crops_coords_top_left: Tuple[int, int] = (0, 0),\n        target_size: Optional[Tuple[int, int]] = None,\n        negative_original_size: Optional[Tuple[int, int]] = None,\n        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),\n        negative_target_size: Optional[Tuple[int, int]] = None,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        attention_auto_machine_weight: float = 1.0,\n        gn_auto_machine_weight: float = 1.0,\n        reference_guidance_start: float = 0.0,\n        reference_guidance_end: float = 1.0,\n        style_fidelity: float = 0.5,\n        reference_attn: bool = True,\n        reference_adain: bool = True,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in both text-encoders.\n            ref_image (`torch.Tensor`, `PIL.Image.Image`):\n                The Reference Control input condition. Reference Control uses this input condition to generate guidance for the UNet. If\n                the type is specified as `torch.Tensor`, it is passed to Reference Control as is. 
`PIL.Image.Image` can\n                also be accepted as an image.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n                Anything below 512 pixels won't work well for\n                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)\n                and checkpoints that are not specifically fine-tuned on low resolutions.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            denoising_end (`float`, *optional*):\n                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be\n                completed before it is intentionally prematurely terminated. As a result, the returned sample will\n                still retain a substantial amount of noise as determined by the discrete timesteps selected by the\n                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a\n                \"Mixture of Denoisers\" multi-pipeline setup, as elaborated in [**Refining the Image\n                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)\n            guidance_scale (`float`, *optional*, defaults to 5.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            negative_prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and\n                `text_encoder_2`. 
If not defined, `negative_prompt` is used in both text-encoders\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`\n                input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generate image. 
Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead\n                of a plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            guidance_rescale (`float`, *optional*, defaults to 0.0):\n                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are\n                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16 of\n                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).\n                Guidance rescale factor should fix overexposure when using zero terminal SNR.\n            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.\n                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as\n                explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                `crops_coords_top_left` can be used to generate an image that appears to be \"cropped\" from the position\n                `crops_coords_top_left` downwards. 
Favorable, well-centered images are usually achieved by setting\n                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                For most cases, `target_size` should be set to the desired height and width of the generated image. If\n                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in\n                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).\n            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a specific image resolution. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):\n                To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's\n                micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):\n                To negatively condition the generation process based on a target image resolution. It should be as same\n                as the `target_size` for most cases. 
Part of SDXL's micro-conditioning as explained in section 2.2 of\n                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more\n                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during inference, with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            attention_auto_machine_weight (`float`, *optional*, defaults to 1.0):\n                Weight of using reference query for self attention's context.\n                If attention_auto_machine_weight=1.0, use reference query for all self attention's context.\n            gn_auto_machine_weight (`float`, *optional*, defaults to 1.0):\n                Weight of using reference adain. 
If gn_auto_machine_weight=2.0, use all reference adain plugins.\n            reference_guidance_start (`float`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the reference guidance starts applying.\n            reference_guidance_end (`float`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the reference guidance stops applying.\n            style_fidelity (`float`, *optional*, defaults to 0.5):\n                Style fidelity of ref_uncond_xt, between 0.0 and 1.0. At 1.0 the reference control matters more,\n                at 0.0 the prompt matters more, and values in between balance the two.\n            reference_attn (`bool`, *optional*, defaults to `True`):\n                Whether to use reference query for self attention's context.\n            reference_adain (`bool`, *optional*, defaults to `True`):\n                Whether to use reference adain.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a\n            `tuple`. 
When returning a tuple, the first element is a list with the generated images.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 0. Default height and width to unet\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        original_size = original_size or (height, width)\n        target_size = target_size or (height, width)\n\n        # 1. Check inputs. 
Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            height,\n            width,\n            callback_steps,\n            negative_prompt,\n            negative_prompt_2,\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self.check_ref_inputs(\n            ref_image,\n            reference_guidance_start,\n            reference_guidance_end,\n            style_fidelity,\n            reference_attn,\n            reference_adain,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._guidance_rescale = guidance_rescale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._denoising_end = denoising_end\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=self.do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            negative_prompt_2=negative_prompt_2,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,\n            lora_scale=lora_scale,\n            clip_skip=self.clip_skip,\n        )\n\n        # 4. Preprocess reference image\n        ref_image = self.prepare_ref_image(\n            image=ref_image,\n            width=width,\n            height=height,\n            batch_size=batch_size * num_images_per_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            device=device,\n            dtype=prompt_embeds.dtype,\n        )\n\n        # 5. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 7. 
Prepare reference latent variables\n        ref_image_latents = self.prepare_ref_latents(\n            ref_image,\n            batch_size * num_images_per_prompt,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            self.do_classifier_free_guidance,\n        )\n\n        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 8.1 Create tensor stating which reference controlnets to keep\n        reference_keeps = []\n        for i in range(len(timesteps)):\n            reference_keep = 1.0 - float(\n                i / len(timesteps) < reference_guidance_start or (i + 1) / len(timesteps) > reference_guidance_end\n            )\n            reference_keeps.append(reference_keep)\n\n        # 8.2 Modify self attention and group norm\n        MODE = \"write\"\n        uc_mask = (\n            torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)\n            .type_as(ref_image_latents)\n            .bool()\n        )\n\n        do_classifier_free_guidance = self.do_classifier_free_guidance\n\n        def hacked_basic_transformer_inner_forward(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n            timestep: Optional[torch.LongTensor] = None,\n            cross_attention_kwargs: Dict[str, Any] = None,\n            class_labels: Optional[torch.LongTensor] = None,\n        ):\n            if self.use_ada_layer_norm:\n                norm_hidden_states = self.norm1(hidden_states, timestep)\n            elif self.use_ada_layer_norm_zero:\n                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(\n                    
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype\n                )\n            else:\n                norm_hidden_states = self.norm1(hidden_states)\n\n            # 1. Self-Attention\n            cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n            if self.only_cross_attention:\n                attn_output = self.attn1(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                    attention_mask=attention_mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                if MODE == \"write\":\n                    self.bank.append(norm_hidden_states.detach().clone())\n                    attn_output = self.attn1(\n                        norm_hidden_states,\n                        encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                        attention_mask=attention_mask,\n                        **cross_attention_kwargs,\n                    )\n                if MODE == \"read\":\n                    if attention_auto_machine_weight > self.attn_weight:\n                        attn_output_uc = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1),\n                            # attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n                        attn_output_c = attn_output_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            attn_output_c[uc_mask] = self.attn1(\n                                norm_hidden_states[uc_mask],\n                                encoder_hidden_states=norm_hidden_states[uc_mask],\n                                
**cross_attention_kwargs,\n                            )\n                        attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc\n                        self.bank.clear()\n                    else:\n                        attn_output = self.attn1(\n                            norm_hidden_states,\n                            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,\n                            attention_mask=attention_mask,\n                            **cross_attention_kwargs,\n                        )\n            if self.use_ada_layer_norm_zero:\n                attn_output = gate_msa.unsqueeze(1) * attn_output\n            hidden_states = attn_output + hidden_states\n\n            if self.attn2 is not None:\n                norm_hidden_states = (\n                    self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)\n                )\n\n                # 2. Cross-Attention\n                attn_output = self.attn2(\n                    norm_hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=encoder_attention_mask,\n                    **cross_attention_kwargs,\n                )\n                hidden_states = attn_output + hidden_states\n\n            # 3. 
Feed-forward\n            norm_hidden_states = self.norm3(hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n            ff_output = self.ff(norm_hidden_states)\n\n            if self.use_ada_layer_norm_zero:\n                ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n            hidden_states = ff_output + hidden_states\n\n            return hidden_states\n\n        def hacked_mid_forward(self, *args, **kwargs):\n            eps = 1e-6\n            x = self.original_forward(*args, **kwargs)\n            if MODE == \"write\":\n                if gn_auto_machine_weight >= self.gn_weight:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    self.mean_bank.append(mean)\n                    self.var_bank.append(var)\n            if MODE == \"read\":\n                if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                    var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)\n                    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                    mean_acc = sum(self.mean_bank) / float(len(self.mean_bank))\n                    var_acc = sum(self.var_bank) / float(len(self.var_bank))\n                    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                    x_uc = (((x - mean) / std) * std_acc) + mean_acc\n                    x_c = x_uc.clone()\n                    if do_classifier_free_guidance and style_fidelity > 0:\n                        x_c[uc_mask] = x[uc_mask]\n                    x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc\n                self.mean_bank = []\n                self.var_bank = []\n            return x\n\n        def hack_CrossAttnDownBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            temb: 
Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n        ):\n            eps = 1e-6\n\n            # TODO(Patrick, William) - attention mask is not used\n            output_states = ()\n\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        
hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):\n            eps = 1e-6\n\n            output_states = ()\n\n            for i, resnet in enumerate(self.resnets):\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc 
= (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n                output_states = output_states + (hidden_states,)\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.downsamplers is not None:\n                for downsampler in self.downsamplers:\n                    hidden_states = downsampler(hidden_states)\n\n                output_states = output_states + (hidden_states,)\n\n            return hidden_states, output_states\n\n        def hacked_CrossAttnUpBlock2D_forward(\n            self,\n            hidden_states: torch.Tensor,\n            res_hidden_states_tuple: Tuple[torch.Tensor, ...],\n            temb: Optional[torch.Tensor] = None,\n            encoder_hidden_states: Optional[torch.Tensor] = None,\n            cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n            upsample_size: Optional[int] = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            encoder_attention_mask: Optional[torch.Tensor] = None,\n        ):\n            eps = 1e-6\n            # TODO(Patrick, William) - attention mask is not used\n            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n      
              encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        def 
hacked_UpBlock2D_forward(\n            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs\n        ):\n            eps = 1e-6\n            for i, resnet in enumerate(self.resnets):\n                # pop res hidden states\n                res_hidden_states = res_hidden_states_tuple[-1]\n                res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n                hidden_states = resnet(hidden_states, temb)\n\n                if MODE == \"write\":\n                    if gn_auto_machine_weight >= self.gn_weight:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        self.mean_bank.append([mean])\n                        self.var_bank.append([var])\n                if MODE == \"read\":\n                    if len(self.mean_bank) > 0 and len(self.var_bank) > 0:\n                        var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0)\n                        std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5\n                        mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i]))\n                        var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i]))\n                        std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5\n                        hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc\n                        hidden_states_c = hidden_states_uc.clone()\n                        if do_classifier_free_guidance and style_fidelity > 0:\n                            hidden_states_c[uc_mask] = hidden_states[uc_mask]\n                        hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc\n\n            if MODE == \"read\":\n                self.mean_bank = []\n                
self.var_bank = []\n\n            if self.upsamplers is not None:\n                for upsampler in self.upsamplers:\n                    hidden_states = upsampler(hidden_states, upsample_size)\n\n            return hidden_states\n\n        if reference_attn:\n            attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)]\n            attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])\n\n            for i, module in enumerate(attn_modules):\n                module._original_inner_forward = module.forward\n                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)\n                module.bank = []\n                module.attn_weight = float(i) / float(len(attn_modules))\n\n        if reference_adain:\n            gn_modules = [self.unet.mid_block]\n            self.unet.mid_block.gn_weight = 0\n\n            down_blocks = self.unet.down_blocks\n            for w, module in enumerate(down_blocks):\n                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))\n                gn_modules.append(module)\n\n            up_blocks = self.unet.up_blocks\n            for w, module in enumerate(up_blocks):\n                module.gn_weight = float(w) / float(len(up_blocks))\n                gn_modules.append(module)\n\n            for i, module in enumerate(gn_modules):\n                if getattr(module, \"original_forward\", None) is None:\n                    module.original_forward = module.forward\n                if i == 0:\n                    # mid_block\n                    module.forward = hacked_mid_forward.__get__(module, torch.nn.Module)\n                elif isinstance(module, CrossAttnDownBlock2D):\n                    module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D)\n                elif isinstance(module, DownBlock2D):\n                    module.forward = 
hacked_DownBlock2D_forward.__get__(module, DownBlock2D)\n                elif isinstance(module, CrossAttnUpBlock2D):\n                    module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D)\n                elif isinstance(module, UpBlock2D):\n                    module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D)\n                module.mean_bank = []\n                module.var_bank = []\n                module.gn_weight *= 2\n\n        # 9. Prepare added time ids & embeddings\n        add_text_embeds = pooled_prompt_embeds\n        if self.text_encoder_2 is None:\n            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])\n        else:\n            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim\n\n        add_time_ids = self._get_add_time_ids(\n            original_size,\n            crops_coords_top_left,\n            target_size,\n            dtype=prompt_embeds.dtype,\n            text_encoder_projection_dim=text_encoder_projection_dim,\n        )\n        if negative_original_size is not None and negative_target_size is not None:\n            negative_add_time_ids = self._get_add_time_ids(\n                negative_original_size,\n                negative_crops_coords_top_left,\n                negative_target_size,\n                dtype=prompt_embeds.dtype,\n                text_encoder_projection_dim=text_encoder_projection_dim,\n            )\n        else:\n            negative_add_time_ids = add_time_ids\n\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)\n            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)\n\n        prompt_embeds = prompt_embeds.to(device)\n        add_text_embeds = add_text_embeds.to(device)\n        add_time_ids = 
add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 10. Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        # 10.1 Apply denoising_end\n        if (\n            self.denoising_end is not None\n            and isinstance(self.denoising_end, float)\n            and self.denoising_end > 0\n            and self.denoising_end < 1\n        ):\n            discrete_timestep_cutoff = int(\n                round(\n                    self.scheduler.config.num_train_timesteps\n                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)\n                )\n            )\n            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))\n            timesteps = timesteps[:num_inference_steps]\n\n        # 11. 
Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n                    added_cond_kwargs[\"image_embeds\"] = image_embeds\n\n                # ref only part\n                if reference_keeps[i] > 0:\n                    noise = randn_tensor(\n                        ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype\n                    )\n                    ref_xt = self.scheduler.add_noise(\n                        ref_image_latents,\n                        noise,\n                        t.reshape(\n                            1,\n                        ),\n                    )\n                    ref_xt = self.scheduler.scale_model_input(ref_xt, t)\n\n                    MODE = \"write\"\n                    self.unet(\n                        ref_xt,\n                
        t,\n                        encoder_hidden_states=prompt_embeds,\n                        cross_attention_kwargs=cross_attention_kwargs,\n                        added_cond_kwargs=added_cond_kwargs,\n                        return_dict=False,\n                    )\n\n                # predict the noise residual\n                MODE = \"read\"\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:\n                    # Based on 3.4. in https://huggingface.co/papers/2305.08891\n                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n                    add_text_embeds = callback_outputs.pop(\"add_text_embeds\", add_text_embeds)\n                    negative_pooled_prompt_embeds = callback_outputs.pop(\n                        \"negative_pooled_prompt_embeds\", negative_pooled_prompt_embeds\n                    )\n                    add_time_ids = callback_outputs.pop(\"add_time_ids\", add_time_ids)\n                    negative_add_time_ids = callback_outputs.pop(\"negative_add_time_ids\", negative_add_time_ids)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if not output_type == \"latent\":\n            # make sure the VAE is in float32 mode, as it overflows in float16\n            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast\n\n            if needs_upcasting:\n         
       self.upcast_vae()\n                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)\n            elif latents.dtype != self.vae.dtype:\n                if torch.backends.mps.is_available():\n                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                    self.vae = self.vae.to(latents.dtype)\n\n            # unscale/denormalize the latents\n            # denormalize with the mean and std if available and not None\n            has_latents_mean = hasattr(self.vae.config, \"latents_mean\") and self.vae.config.latents_mean is not None\n            has_latents_std = hasattr(self.vae.config, \"latents_std\") and self.vae.config.latents_std is not None\n            if has_latents_mean and has_latents_std:\n                latents_mean = (\n                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents_std = (\n                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)\n                )\n                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean\n            else:\n                latents = latents / self.vae.config.scaling_factor\n\n            image = self.vae.decode(latents, return_dict=False)[0]\n\n            # cast back to fp16 if needed\n            if needs_upcasting:\n                self.vae.to(dtype=torch.float16)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            # apply watermark if available\n            if self.watermark is not None:\n                image = self.watermark.apply_watermark(image)\n\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return 
(image,)\n\n        return StableDiffusionXLPipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/stable_unclip.py",
    "content": "import types\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom transformers import CLIPTextModelWithProjection, CLIPTokenizer\nfrom transformers.models.clip.modeling_clip import CLIPTextModelOutput\n\nfrom diffusers.models import PriorTransformer\nfrom diffusers.pipelines import DiffusionPipeline, StableDiffusionImageVariationPipeline\nfrom diffusers.schedulers import UnCLIPScheduler\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):\n    image = image.to(device=device)\n    image_embeddings = image  # take image as image_embeddings\n    image_embeddings = image_embeddings.unsqueeze(1)\n\n    # duplicate image embeddings for each generation per prompt, using mps friendly method\n    bs_embed, seq_len, _ = image_embeddings.shape\n    image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)\n    image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n    if do_classifier_free_guidance:\n        uncond_embeddings = torch.zeros_like(image_embeddings)\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        image_embeddings = torch.cat([uncond_embeddings, image_embeddings])\n\n    return image_embeddings\n\n\nclass StableUnCLIPPipeline(DiffusionPipeline):\n    def __init__(\n        self,\n        prior: PriorTransformer,\n        tokenizer: CLIPTokenizer,\n        text_encoder: CLIPTextModelWithProjection,\n        prior_scheduler: UnCLIPScheduler,\n        decoder_pipe_kwargs: Optional[dict] = None,\n    ):\n        super().__init__()\n\n        decoder_pipe_kwargs = {\"image_encoder\": None} if decoder_pipe_kwargs 
is None else decoder_pipe_kwargs\n\n        decoder_pipe_kwargs[\"torch_dtype\"] = decoder_pipe_kwargs.get(\"torch_dtype\", None) or prior.dtype\n\n        self.decoder_pipe = StableDiffusionImageVariationPipeline.from_pretrained(\n            \"lambdalabs/sd-image-variations-diffusers\", **decoder_pipe_kwargs\n        )\n\n        # replace `_encode_image` method\n        self.decoder_pipe._encode_image = types.MethodType(_encode_image, self.decoder_pipe)\n\n        self.register_modules(\n            prior=prior,\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            prior_scheduler=prior_scheduler,\n        )\n\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,\n        text_attention_mask: Optional[torch.Tensor] = None,\n    ):\n        if text_model_output is None:\n            batch_size = len(prompt) if isinstance(prompt, list) else 1\n            # get prompt text embeddings\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            text_mask = text_inputs.attention_mask.bool().to(device)\n\n            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:\n                removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n\n            
text_encoder_output = self.text_encoder(text_input_ids.to(device))\n\n            text_embeddings = text_encoder_output.text_embeds\n            text_encoder_hidden_states = text_encoder_output.last_hidden_state\n\n        else:\n            batch_size = text_model_output[0].shape[0]\n            text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]\n            text_mask = text_attention_mask\n\n        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)\n        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)\n\n        if do_classifier_free_guidance:\n            uncond_tokens = [\"\"] * batch_size\n\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_text_mask = uncond_input.attention_mask.bool().to(device)\n            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))\n\n            uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds\n            uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)\n\n            seq_len = uncond_text_encoder_hidden_states.shape[1]\n            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)\n            
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(\n                batch_size * num_images_per_prompt, seq_len, -1\n            )\n            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)\n\n            # done duplicates\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])\n\n            text_mask = torch.cat([uncond_text_mask, text_mask])\n\n        return text_embeddings, text_encoder_hidden_states, text_mask\n\n    @property\n    def _execution_device(self):\n        r\"\"\"\n        Returns the device on which the pipeline's models will be executed. After calling\n        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module\n        hooks.\n        \"\"\"\n        if self.device != torch.device(\"meta\") or not hasattr(self.prior, \"_hf_hook\"):\n            return self.device\n        for module in self.prior.modules():\n            if (\n                hasattr(module, \"_hf_hook\")\n                and hasattr(module._hf_hook, \"execution_device\")\n                and module._hf_hook.execution_device is not None\n            ):\n                return torch.device(module._hf_hook.execution_device)\n        return self.device\n\n    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            if latents.shape != shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected 
{shape}\")\n            latents = latents.to(device)\n\n        latents = latents * scheduler.init_noise_sigma\n        return latents\n\n    def to(self, torch_device: Optional[Union[str, torch.device]] = None):\n        self.decoder_pipe.to(torch_device)\n        super().to(torch_device)\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_images_per_prompt: int = 1,\n        prior_num_inference_steps: int = 25,\n        generator: torch.Generator | None = None,\n        prior_latents: Optional[torch.Tensor] = None,\n        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,\n        text_attention_mask: Optional[torch.Tensor] = None,\n        prior_guidance_scale: float = 4.0,\n        decoder_guidance_scale: float = 8.0,\n        decoder_num_inference_steps: int = 50,\n        decoder_num_images_per_prompt: Optional[int] = 1,\n        decoder_eta: float = 0.0,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n    ):\n        if prompt is not None:\n            if isinstance(prompt, str):\n                batch_size = 1\n            elif isinstance(prompt, list):\n                batch_size = len(prompt)\n            else:\n                raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        else:\n            batch_size = text_model_output[0].shape[0]\n\n        device = self._execution_device\n\n        batch_size = batch_size * num_images_per_prompt\n\n        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0\n\n        text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(\n            prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask\n        )\n\n        # prior\n\n        
self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)\n        prior_timesteps_tensor = self.prior_scheduler.timesteps\n\n        embedding_dim = self.prior.config.embedding_dim\n\n        prior_latents = self.prepare_latents(\n            (batch_size, embedding_dim),\n            text_embeddings.dtype,\n            device,\n            generator,\n            prior_latents,\n            self.prior_scheduler,\n        )\n\n        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents\n\n            predicted_image_embedding = self.prior(\n                latent_model_input,\n                timestep=t,\n                proj_embedding=text_embeddings,\n                encoder_hidden_states=text_encoder_hidden_states,\n                attention_mask=text_mask,\n            ).predicted_image_embedding\n\n            if do_classifier_free_guidance:\n                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)\n                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (\n                    predicted_image_embedding_text - predicted_image_embedding_uncond\n                )\n\n            if i + 1 == prior_timesteps_tensor.shape[0]:\n                prev_timestep = None\n            else:\n                prev_timestep = prior_timesteps_tensor[i + 1]\n\n            prior_latents = self.prior_scheduler.step(\n                predicted_image_embedding,\n                timestep=t,\n                sample=prior_latents,\n                generator=generator,\n                prev_timestep=prev_timestep,\n            ).prev_sample\n\n        prior_latents = self.prior.post_process_latents(prior_latents)\n\n        image_embeddings = prior_latents\n\n      
  output = self.decoder_pipe(\n            image=image_embeddings,\n            height=height,\n            width=width,\n            num_inference_steps=decoder_num_inference_steps,\n            guidance_scale=decoder_guidance_scale,\n            generator=generator,\n            output_type=output_type,\n            return_dict=return_dict,\n            num_images_per_prompt=decoder_num_images_per_prompt,\n            eta=decoder_eta,\n        )\n        return output\n"
  },
  {
    "path": "examples/community/text_inpainting.py",
    "content": "from typing import Callable, List, Optional, Union\n\nimport PIL.Image\nimport torch\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPSegForImageSegmentation,\n    CLIPSegProcessor,\n    CLIPTextModel,\n    CLIPTokenizer,\n)\n\nfrom diffusers import DiffusionPipeline\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\nfrom diffusers.utils import deprecate, logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass TextInpainting(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for text based inpainting using Stable Diffusion.\n    Uses CLIPSeg to get a mask from the given text, then calls the Inpainting pipeline with the generated mask\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        segmentation_model ([`CLIPSegForImageSegmentation`]):\n            CLIPSeg Model to generate mask from the given text. Please refer to the [model card]() for details.\n        segmentation_processor ([`CLIPSegProcessor`]):\n            CLIPSeg processor to get image, text features to translate prompt to English, if necessary. 
Please refer to the\n            [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        segmentation_model: CLIPSegForImageSegmentation,\n        segmentation_processor: CLIPSegProcessor,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if scheduler is not None and getattr(scheduler.config, \"skip_prk_steps\", True) is False:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration\"\n                \" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make\"\n                \" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to\"\n                \" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face\"\n                \" Hub, it would be very nice if you could open a Pull request for the\"\n                \" `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"skip_prk_steps not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"skip_prk_steps\"] = True\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        self.register_modules(\n            segmentation_model=segmentation_model,\n            segmentation_processor=segmentation_processor,\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        image: Union[torch.Tensor, PIL.Image.Image],\n        text: str,\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            image (`PIL.Image.Image`):\n                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will\n                be masked out with 
`mask_image` and repainted according to `prompt`.\n            text (`str`):\n                The text to use to generate the mask.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n\n        # We use the input text to generate the mask\n        inputs = self.segmentation_processor(\n            text=[text], images=[image], padding=\"max_length\", return_tensors=\"pt\"\n        ).to(self.device)\n        outputs = self.segmentation_model(**inputs)\n        mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()\n        mask_pil = self.numpy_to_pil(mask)[0].resize(image.size)\n\n        # Run inpainting pipeline with the generated mask\n        inpainting_pipeline = StableDiffusionInpaintPipeline(\n            vae=self.vae,\n            text_encoder=self.text_encoder,\n            tokenizer=self.tokenizer,\n            unet=self.unet,\n            scheduler=self.scheduler,\n            safety_checker=self.safety_checker,\n            feature_extractor=self.feature_extractor,\n        )\n        return inpainting_pipeline(\n            prompt=prompt,\n            image=image,\n            mask_image=mask_pil,\n            height=height,\n            width=width,\n            num_inference_steps=num_inference_steps,\n            guidance_scale=guidance_scale,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            eta=eta,\n            generator=generator,\n            latents=latents,\n            output_type=output_type,\n            
return_dict=return_dict,\n            callback=callback,\n            callback_steps=callback_steps,\n        )\n"
  },
  {
    "path": "examples/community/tiled_upscaling.py",
    "content": "# Copyright 2025 Peter Willemsen <peter@codebuffet.co>. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom PIL import Image\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline\nfrom diffusers.schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler\n\n\ndef make_transparency_mask(size, overlap_pixels, remove_borders=[]):\n    size_x = size[0] - overlap_pixels * 2\n    size_y = size[1] - overlap_pixels * 2\n    for letter in [\"l\", \"r\"]:\n        if letter in remove_borders:\n            size_x += overlap_pixels\n    for letter in [\"t\", \"b\"]:\n        if letter in remove_borders:\n            size_y += overlap_pixels\n    mask = np.ones((size_y, size_x), dtype=np.uint8) * 255\n    mask = np.pad(mask, mode=\"linear_ramp\", pad_width=overlap_pixels, end_values=0)\n\n    if \"l\" in remove_borders:\n        mask = mask[:, overlap_pixels : mask.shape[1]]\n    if \"r\" in remove_borders:\n        mask = mask[:, 0 : mask.shape[1] - overlap_pixels]\n    if \"t\" in remove_borders:\n        mask = mask[overlap_pixels : mask.shape[0], :]\n    if \"b\" in remove_borders:\n        mask = mask[0 : 
mask.shape[0] - overlap_pixels, :]\n    return mask\n\n\ndef clamp(n, smallest, largest):\n    return max(smallest, min(n, largest))\n\n\ndef clamp_rect(rect: [int], min: [int], max: [int]):\n    return (\n        clamp(rect[0], min[0], max[0]),\n        clamp(rect[1], min[1], max[1]),\n        clamp(rect[2], min[0], max[0]),\n        clamp(rect[3], min[1], max[1]),\n    )\n\n\ndef add_overlap_rect(rect: [int], overlap: int, image_size: [int]):\n    rect = list(rect)\n    rect[0] -= overlap\n    rect[1] -= overlap\n    rect[2] += overlap\n    rect[3] += overlap\n    rect = clamp_rect(rect, [0, 0], [image_size[0], image_size[1]])\n    return rect\n\n\ndef squeeze_tile(tile, original_image, original_slice, slice_x):\n    result = Image.new(\"RGB\", (tile.size[0] + original_slice, tile.size[1]))\n    result.paste(\n        original_image.resize((tile.size[0], tile.size[1]), Image.BICUBIC).crop(\n            (slice_x, 0, slice_x + original_slice, tile.size[1])\n        ),\n        (0, 0),\n    )\n    result.paste(tile, (original_slice, 0))\n    return result\n\n\ndef unsqueeze_tile(tile, original_image_slice):\n    crop_rect = (original_image_slice * 4, 0, tile.size[0], tile.size[1])\n    tile = tile.crop(crop_rect)\n    return tile\n\n\ndef next_divisible(n, d):\n    divisor = n % d\n    return n - divisor\n\n\nclass StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):\n    r\"\"\"\n    Pipeline for tile-based text-guided image super-resolution using Stable Diffusion 2, trading memory for compute\n    to create gigantic images.\n\n    This model inherits from [`StableDiffusionUpscalePipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        low_res_scheduler ([`SchedulerMixin`]):\n            A scheduler used to add initial noise to the low res conditioning image. It must be an instance of\n            [`DDPMScheduler`].\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        low_res_scheduler: DDPMScheduler,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        max_noise_level: int = 350,\n    ):\n        super().__init__(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            low_res_scheduler=low_res_scheduler,\n            scheduler=scheduler,\n            max_noise_level=max_noise_level,\n        )\n\n    def _process_tile(self, original_image_slice, x, y, tile_size, tile_border, image, final_image, **kwargs):\n        torch.manual_seed(0)\n        crop_rect = (\n            min(image.size[0] - (tile_size + original_image_slice), x * tile_size),\n            min(image.size[1] - (tile_size + original_image_slice), y * tile_size),\n            min(image.size[0], (x + 1) * tile_size),\n            min(image.size[1], (y + 1) * tile_size),\n        )\n        crop_rect_with_overlap = add_overlap_rect(crop_rect, tile_border, image.size)\n        tile = image.crop(crop_rect_with_overlap)\n        translated_slice_x = ((crop_rect[0] + ((crop_rect[2] - crop_rect[0]) / 2)) / image.size[0]) * tile.size[0]\n        translated_slice_x = translated_slice_x - (original_image_slice / 2)\n        translated_slice_x = max(0, translated_slice_x)\n        to_input = squeeze_tile(tile, image, original_image_slice, translated_slice_x)\n        orig_input_size = to_input.size\n        to_input = to_input.resize((tile_size, tile_size), Image.BICUBIC)\n        upscaled_tile = super(StableDiffusionTiledUpscalePipeline, self).__call__(image=to_input, **kwargs).images[0]\n        upscaled_tile = upscaled_tile.resize((orig_input_size[0] * 4, orig_input_size[1] * 4), 
Image.BICUBIC)\n        upscaled_tile = unsqueeze_tile(upscaled_tile, original_image_slice)\n        upscaled_tile = upscaled_tile.resize((tile.size[0] * 4, tile.size[1] * 4), Image.BICUBIC)\n        remove_borders = []\n        if x == 0:\n            remove_borders.append(\"l\")\n        elif crop_rect[2] == image.size[0]:\n            remove_borders.append(\"r\")\n        if y == 0:\n            remove_borders.append(\"t\")\n        elif crop_rect[3] == image.size[1]:\n            remove_borders.append(\"b\")\n        transparency_mask = Image.fromarray(\n            make_transparency_mask(\n                (upscaled_tile.size[0], upscaled_tile.size[1]), tile_border * 4, remove_borders=remove_borders\n            ),\n            mode=\"L\",\n        )\n        final_image.paste(\n            upscaled_tile, (crop_rect_with_overlap[0] * 4, crop_rect_with_overlap[1] * 4), transparency_mask\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        image: Union[PIL.Image.Image, List[PIL.Image.Image]],\n        num_inference_steps: int = 75,\n        guidance_scale: float = 9.0,\n        noise_level: int = 50,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        tile_size: int = 128,\n        tile_border: int = 32,\n        original_image_slice: int = 32,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.Tensor`):\n                `Image`, or tensor 
representing an image batch which will be upscaled.\n            num_inference_steps (`int`, *optional*, defaults to 75):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 9.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            tile_size (`int`, *optional*):\n                The size of the tiles. Too big can result in an OOM-error.\n            tile_border (`int`, *optional*):\n                The number of pixels around a tile to consider (bigger means fewer seams, too big can lead to an OOM-error).\n            original_image_slice (`int`, *optional*):\n                The number of pixels of the original image to calculate with the current tile (bigger means more depth\n                is preserved, less blur occurs in the final image, too big can lead to an OOM-error or loss in detail).\n            callback (`Callable`, *optional*):\n                A function that is called after each tile is processed with a single argument, a dict,\n                that contains the (partially) processed image under \"image\",\n                as well as the progress (0 to 1, where 1 is completed) under \"progress\".\n\n        Returns: A PIL.Image that is 4 times larger than the original input image.\n\n        \"\"\"\n\n        final_image = Image.new(\"RGB\", (image.size[0] * 4, image.size[1] * 4))\n        tcx = math.ceil(image.size[0] / tile_size)\n        tcy = math.ceil(image.size[1] / tile_size)\n        total_tile_count = tcx * tcy\n        current_count = 0\n        for y in range(tcy):\n            for x in range(tcx):\n                self._process_tile(\n                    original_image_slice,\n                    x,\n                    y,\n                    tile_size,\n                    tile_border,\n                    image,\n                    final_image,\n                    prompt=prompt,\n                    num_inference_steps=num_inference_steps,\n                    guidance_scale=guidance_scale,\n                    noise_level=noise_level,\n                    negative_prompt=negative_prompt,\n                    
num_images_per_prompt=num_images_per_prompt,\n                    eta=eta,\n                    generator=generator,\n                    latents=latents,\n                )\n                current_count += 1\n                if callback is not None:\n                    callback({\"progress\": current_count / total_tile_count, \"image\": final_image})\n        return final_image\n\n\ndef main():\n    # Run a demo\n    model_id = \"stabilityai/stable-diffusion-x4-upscaler\"\n    pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, variant=\"fp16\", torch_dtype=torch.float16)\n    pipe = pipe.to(\"cuda\")\n    image = Image.open(\"../../docs/source/imgs/diffusers_library.jpg\")\n\n    def callback(obj):\n        print(f\"progress: {obj['progress']:.4f}\")\n        obj[\"image\"].save(\"diffusers_library_progress.jpg\")\n\n    final_image = pipe(image=image, prompt=\"Black font, white background, vector\", noise_level=40, callback=callback)\n    final_image.save(\"diffusers_library.jpg\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/community/unclip_image_interpolation.py",
    "content": "import inspect\nfrom typing import List, Optional, Union\n\nimport PIL.Image\nimport torch\nfrom torch.nn import functional as F\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers import (\n    DiffusionPipeline,\n    ImagePipelineOutput,\n    UnCLIPScheduler,\n    UNet2DConditionModel,\n    UNet2DModel,\n)\nfrom diffusers.pipelines.unclip import UnCLIPTextProjModel\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef slerp(val, low, high):\n    \"\"\"\n    Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.\n    \"\"\"\n    low_norm = low / torch.norm(low)\n    high_norm = high / torch.norm(high)\n    omega = torch.acos((low_norm * high_norm))\n    so = torch.sin(omega)\n    res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high\n    return res\n\n\nclass UnCLIPImageInterpolationPipeline(DiffusionPipeline):\n    \"\"\"\n    Pipeline to generate variations from an input image using unCLIP\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        text_encoder ([`CLIPTextModelWithProjection`]):\n            Frozen text-encoder.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `image_encoder`.\n        image_encoder ([`CLIPVisionModelWithProjection`]):\n            Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),\n            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_proj ([`UnCLIPTextProjModel`]):\n            Utility class to prepare and combine the embeddings before they are passed to the decoder.\n        decoder ([`UNet2DConditionModel`]):\n            The decoder to invert the image embedding into an image.\n        super_res_first ([`UNet2DModel`]):\n            Super resolution unet. Used in all but the last step of the super resolution diffusion process.\n        super_res_last ([`UNet2DModel`]):\n            Super resolution unet. Used in the last step of the super resolution diffusion process.\n        decoder_scheduler ([`UnCLIPScheduler`]):\n            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.\n        super_res_scheduler ([`UnCLIPScheduler`]):\n            Scheduler used in the super resolution denoising process. 
Just a modified DDPMScheduler.\n\n    \"\"\"\n\n    decoder: UNet2DConditionModel\n    text_proj: UnCLIPTextProjModel\n    text_encoder: CLIPTextModelWithProjection\n    tokenizer: CLIPTokenizer\n    feature_extractor: CLIPImageProcessor\n    image_encoder: CLIPVisionModelWithProjection\n    super_res_first: UNet2DModel\n    super_res_last: UNet2DModel\n\n    decoder_scheduler: UnCLIPScheduler\n    super_res_scheduler: UnCLIPScheduler\n\n    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.__init__\n    def __init__(\n        self,\n        decoder: UNet2DConditionModel,\n        text_encoder: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        text_proj: UnCLIPTextProjModel,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection,\n        super_res_first: UNet2DModel,\n        super_res_last: UNet2DModel,\n        decoder_scheduler: UnCLIPScheduler,\n        super_res_scheduler: UnCLIPScheduler,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            decoder=decoder,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            text_proj=text_proj,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n            super_res_first=super_res_first,\n            super_res_last=super_res_last,\n            decoder_scheduler=decoder_scheduler,\n            super_res_scheduler=super_res_scheduler,\n        )\n\n    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents\n    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            if latents.shape != shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {shape}\")\n   
         latents = latents.to(device)\n\n        latents = latents * scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_prompt\n    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):\n        batch_size = len(prompt) if isinstance(prompt, list) else 1\n\n        # get prompt text embeddings\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        text_mask = text_inputs.attention_mask.bool().to(device)\n        text_encoder_output = self.text_encoder(text_input_ids.to(device))\n\n        prompt_embeds = text_encoder_output.text_embeds\n        text_encoder_hidden_states = text_encoder_output.last_hidden_state\n\n        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)\n\n        if do_classifier_free_guidance:\n            uncond_tokens = [\"\"] * batch_size\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_text_mask = uncond_input.attention_mask.bool().to(device)\n            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))\n\n            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds\n            uncond_text_encoder_hidden_states = 
negative_prompt_embeds_text_encoder_output.last_hidden_state\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n\n            seq_len = negative_prompt_embeds.shape[1]\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)\n\n            seq_len = uncond_text_encoder_hidden_states.shape[1]\n            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)\n            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(\n                batch_size * num_images_per_prompt, seq_len, -1\n            )\n            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)\n\n            # done duplicates\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])\n\n            text_mask = torch.cat([uncond_text_mask, text_mask])\n\n        return prompt_embeds, text_encoder_hidden_states, text_mask\n\n    # Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_image\n    def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if image_embeddings is None:\n            if not isinstance(image, torch.Tensor):\n                image = self.feature_extractor(images=image, return_tensors=\"pt\").pixel_values\n\n            image = 
image.to(device=device, dtype=dtype)\n            image_embeddings = self.image_encoder(image).image_embeds\n\n        image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)\n\n        return image_embeddings\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,\n        steps: int = 5,\n        decoder_num_inference_steps: int = 25,\n        super_res_num_inference_steps: int = 7,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        image_embeddings: Optional[torch.Tensor] = None,\n        decoder_latents: Optional[torch.Tensor] = None,\n        super_res_latents: Optional[torch.Tensor] = None,\n        decoder_guidance_scale: float = 8.0,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n    ):\n        \"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            image (`List[PIL.Image.Image]` or `torch.Tensor`):\n                The images to use for the image interpolation. Accepts either a list of exactly two PIL images or a tensor. A tensor needs to comply with the\n                configuration of\n                [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)\n                `CLIPImageProcessor` and have a size of two in the 0th dimension. Can be left as `None` only when `image_embeddings` is passed.\n            steps (`int`, *optional*, defaults to 5):\n                The number of interpolation images to generate.\n            decoder_num_inference_steps (`int`, *optional*, defaults to 25):\n                The number of denoising steps for the decoder. 
More denoising steps usually lead to a higher quality\n                image at the expense of slower inference.\n            super_res_num_inference_steps (`int`, *optional*, defaults to 7):\n                The number of denoising steps for super resolution. More denoising steps usually lead to a higher\n                quality image at the expense of slower inference.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            image_embeddings (`torch.Tensor`, *optional*):\n                Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings\n                can be passed for tasks like image interpolations. `image` can then be left as `None`.\n            decoder_latents (`torch.Tensor` of shape (batch size, channels, height, width), *optional*):\n                Pre-generated noisy latents to be used as inputs for the decoder.\n            super_res_latents (`torch.Tensor` of shape (batch size, channels, super res height, super res width), *optional*):\n                Pre-generated noisy latents to be used as inputs for the super resolution UNets.\n            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.\n        \"\"\"\n\n        batch_size = steps\n\n        device = self._execution_device\n\n        if isinstance(image, List):\n            if len(image) != 2:\n                raise AssertionError(\n                    f\"Expected 'image' List to be of size 2, but passed 'image' length is {len(image)}\"\n                )\n            elif not (isinstance(image[0], PIL.Image.Image) and isinstance(image[1], PIL.Image.Image)):\n                raise AssertionError(\n                    f\"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}\"\n                )\n        elif isinstance(image, torch.Tensor):\n            if image.shape[0] != 2:\n                raise AssertionError(\n                    f\"Expected 'image' to be torch.Tensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}\"\n                )\n        elif isinstance(image_embeddings, torch.Tensor):\n            if image_embeddings.shape[0] != 2:\n                raise AssertionError(\n                    f\"Expected 'image_embeddings' to be torch.Tensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}\"\n                )\n        else:\n            raise AssertionError(\n                f\"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or torch.Tensor 
respectively. Received {type(image)} and {type(image_embeddings)} respectively\"\n            )\n\n        original_image_embeddings = self._encode_image(\n            image=image, device=device, num_images_per_prompt=1, image_embeddings=image_embeddings\n        )\n\n        image_embeddings = []\n\n        for interp_step in torch.linspace(0, 1, steps):\n            temp_image_embeddings = slerp(\n                interp_step, original_image_embeddings[0], original_image_embeddings[1]\n            ).unsqueeze(0)\n            image_embeddings.append(temp_image_embeddings)\n\n        image_embeddings = torch.cat(image_embeddings).to(device)\n\n        do_classifier_free_guidance = decoder_guidance_scale > 1.0\n\n        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(\n            prompt=[\"\" for i in range(steps)],\n            device=device,\n            num_images_per_prompt=1,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n        )\n\n        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(\n            image_embeddings=image_embeddings,\n            prompt_embeds=prompt_embeds,\n            text_encoder_hidden_states=text_encoder_hidden_states,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n        )\n\n        if device.type == \"mps\":\n            # HACK: MPS: There is a panic when padding bool tensors,\n            # so cast to int tensor for the pad and back to bool afterwards\n            text_mask = text_mask.type(torch.int)\n            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)\n            decoder_text_mask = decoder_text_mask.type(torch.bool)\n        else:\n            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)\n\n        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)\n        decoder_timesteps_tensor = 
self.decoder_scheduler.timesteps\n\n        num_channels_latents = self.decoder.config.in_channels\n        height = self.decoder.config.sample_size\n        width = self.decoder.config.sample_size\n\n        # Get the decoder latents for 1 step and then repeat the same tensor for the entire batch to keep same noise across all interpolation steps.\n        decoder_latents = self.prepare_latents(\n            (1, num_channels_latents, height, width),\n            text_encoder_hidden_states.dtype,\n            device,\n            generator,\n            decoder_latents,\n            self.decoder_scheduler,\n        )\n        decoder_latents = decoder_latents.repeat((batch_size, 1, 1, 1))\n\n        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents\n\n            noise_pred = self.decoder(\n                sample=latent_model_input,\n                timestep=t,\n                encoder_hidden_states=text_encoder_hidden_states,\n                class_labels=additive_clip_time_embeddings,\n                attention_mask=decoder_text_mask,\n            ).sample\n\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)\n                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)\n                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)\n                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n\n            if i + 1 == decoder_timesteps_tensor.shape[0]:\n                prev_timestep = None\n            else:\n                prev_timestep = decoder_timesteps_tensor[i + 
1]\n\n            # compute the previous noisy sample x_t -> x_t-1\n            decoder_latents = self.decoder_scheduler.step(\n                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator\n            ).prev_sample\n\n        decoder_latents = decoder_latents.clamp(-1, 1)\n\n        image_small = decoder_latents\n\n        # done decoder\n\n        # super res\n\n        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)\n        super_res_timesteps_tensor = self.super_res_scheduler.timesteps\n\n        channels = self.super_res_first.config.in_channels // 2\n        height = self.super_res_first.config.sample_size\n        width = self.super_res_first.config.sample_size\n\n        super_res_latents = self.prepare_latents(\n            (batch_size, channels, height, width),\n            image_small.dtype,\n            device,\n            generator,\n            super_res_latents,\n            self.super_res_scheduler,\n        )\n\n        if device.type == \"mps\":\n            # MPS does not support many interpolations\n            image_upscaled = F.interpolate(image_small, size=[height, width])\n        else:\n            interpolate_antialias = {}\n            if \"antialias\" in inspect.signature(F.interpolate).parameters:\n                interpolate_antialias[\"antialias\"] = True\n\n            image_upscaled = F.interpolate(\n                image_small, size=[height, width], mode=\"bicubic\", align_corners=False, **interpolate_antialias\n            )\n\n        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):\n            # no classifier free guidance\n\n            if i == super_res_timesteps_tensor.shape[0] - 1:\n                unet = self.super_res_last\n            else:\n                unet = self.super_res_first\n\n            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)\n\n            noise_pred = unet(\n                
sample=latent_model_input,\n                timestep=t,\n            ).sample\n\n            if i + 1 == super_res_timesteps_tensor.shape[0]:\n                prev_timestep = None\n            else:\n                prev_timestep = super_res_timesteps_tensor[i + 1]\n\n            # compute the previous noisy sample x_t -> x_t-1\n            super_res_latents = self.super_res_scheduler.step(\n                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator\n            ).prev_sample\n\n        image = super_res_latents\n        # done super res\n\n        # post processing\n\n        image = image * 0.5 + 0.5\n        image = image.clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image,)\n\n        return ImagePipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/unclip_text_interpolation.py",
    "content": "import inspect\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom torch.nn import functional as F\nfrom transformers import CLIPTextModelWithProjection, CLIPTokenizer\nfrom transformers.models.clip.modeling_clip import CLIPTextModelOutput\n\nfrom diffusers import (\n    DiffusionPipeline,\n    ImagePipelineOutput,\n    PriorTransformer,\n    UnCLIPScheduler,\n    UNet2DConditionModel,\n    UNet2DModel,\n)\nfrom diffusers.pipelines.unclip import UnCLIPTextProjModel\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef slerp(val, low, high):\n    \"\"\"\n    Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.\n    \"\"\"\n    low_norm = low / torch.norm(low)\n    high_norm = high / torch.norm(high)\n    omega = torch.acos((low_norm * high_norm))\n    so = torch.sin(omega)\n    res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high\n    return res\n\n\nclass UnCLIPTextInterpolationPipeline(DiffusionPipeline):\n    \"\"\"\n    Pipeline for prompt-to-prompt interpolation on CLIP text embeddings and using the UnCLIP / Dall-E to decode them to images.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        text_encoder ([`CLIPTextModelWithProjection`]):\n            Frozen text-encoder.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        prior ([`PriorTransformer`]):\n            The canonical unCLIP prior to approximate the image embedding from the text embedding.\n        text_proj ([`UnCLIPTextProjModel`]):\n            Utility class to prepare and combine the embeddings before they are passed to the decoder.\n        decoder ([`UNet2DConditionModel`]):\n            The decoder to invert the image embedding into an image.\n        super_res_first ([`UNet2DModel`]):\n            Super resolution unet. Used in all but the last step of the super resolution diffusion process.\n        super_res_last ([`UNet2DModel`]):\n            Super resolution unet. Used in the last step of the super resolution diffusion process.\n        prior_scheduler ([`UnCLIPScheduler`]):\n            Scheduler used in the prior denoising process. Just a modified DDPMScheduler.\n        decoder_scheduler ([`UnCLIPScheduler`]):\n            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.\n        super_res_scheduler ([`UnCLIPScheduler`]):\n            Scheduler used in the super resolution denoising process. 
Just a modified DDPMScheduler.\n\n    \"\"\"\n\n    prior: PriorTransformer\n    decoder: UNet2DConditionModel\n    text_proj: UnCLIPTextProjModel\n    text_encoder: CLIPTextModelWithProjection\n    tokenizer: CLIPTokenizer\n    super_res_first: UNet2DModel\n    super_res_last: UNet2DModel\n\n    prior_scheduler: UnCLIPScheduler\n    decoder_scheduler: UnCLIPScheduler\n    super_res_scheduler: UnCLIPScheduler\n\n    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.__init__\n    def __init__(\n        self,\n        prior: PriorTransformer,\n        decoder: UNet2DConditionModel,\n        text_encoder: CLIPTextModelWithProjection,\n        tokenizer: CLIPTokenizer,\n        text_proj: UnCLIPTextProjModel,\n        super_res_first: UNet2DModel,\n        super_res_last: UNet2DModel,\n        prior_scheduler: UnCLIPScheduler,\n        decoder_scheduler: UnCLIPScheduler,\n        super_res_scheduler: UnCLIPScheduler,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            prior=prior,\n            decoder=decoder,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            text_proj=text_proj,\n            super_res_first=super_res_first,\n            super_res_last=super_res_last,\n            prior_scheduler=prior_scheduler,\n            decoder_scheduler=decoder_scheduler,\n            super_res_scheduler=super_res_scheduler,\n        )\n\n    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents\n    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            if latents.shape != shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {shape}\")\n            latents = latents.to(device)\n\n        latents = latents * scheduler.init_noise_sigma\n     
   return latents\n\n    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,\n        text_attention_mask: Optional[torch.Tensor] = None,\n    ):\n        if text_model_output is None:\n            batch_size = len(prompt) if isinstance(prompt, list) else 1\n            # get prompt text embeddings\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            text_mask = text_inputs.attention_mask.bool().to(device)\n\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n\n            text_encoder_output = self.text_encoder(text_input_ids.to(device))\n\n            prompt_embeds = text_encoder_output.text_embeds\n            text_encoder_hidden_states = text_encoder_output.last_hidden_state\n\n        else:\n            batch_size = 
text_model_output[0].shape[0]\n            prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]\n            text_mask = text_attention_mask\n\n        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)\n\n        if do_classifier_free_guidance:\n            uncond_tokens = [\"\"] * batch_size\n\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_text_mask = uncond_input.attention_mask.bool().to(device)\n            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))\n\n            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds\n            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n\n            seq_len = negative_prompt_embeds.shape[1]\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)\n\n            seq_len = uncond_text_encoder_hidden_states.shape[1]\n            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)\n            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(\n                batch_size * num_images_per_prompt, seq_len, -1\n            )\n            uncond_text_mask = 
uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)\n\n            # done duplicates\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])\n\n            text_mask = torch.cat([uncond_text_mask, text_mask])\n\n        return prompt_embeds, text_encoder_hidden_states, text_mask\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        start_prompt: str,\n        end_prompt: str,\n        steps: int = 5,\n        prior_num_inference_steps: int = 25,\n        decoder_num_inference_steps: int = 25,\n        super_res_num_inference_steps: int = 7,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        prior_guidance_scale: float = 4.0,\n        decoder_guidance_scale: float = 8.0,\n        enable_sequential_cpu_offload=True,\n        gpu_id=0,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n    ):\n        \"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            start_prompt (`str`):\n                The prompt to start the image generation interpolation from.\n            end_prompt (`str`):\n                The prompt to end the image generation interpolation at.\n            steps (`int`, *optional*, defaults to 5):\n                The number of steps over which to interpolate from start_prompt to end_prompt. The pipeline returns\n                the same number of images as this value.\n            prior_num_inference_steps (`int`, *optional*, defaults to 25):\n                The number of denoising steps for the prior. 
More denoising steps usually lead to a higher quality\n                image at the expense of slower inference.\n            decoder_num_inference_steps (`int`, *optional*, defaults to 25):\n                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality\n                image at the expense of slower inference.\n            super_res_num_inference_steps (`int`, *optional*, defaults to 7):\n                The number of denoising steps for super resolution. More denoising steps usually lead to a higher\n                quality image at the expense of slower inference.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            prior_guidance_scale (`float`, *optional*, defaults to 4.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            enable_sequential_cpu_offload (`bool`, *optional*, defaults to `True`):\n                If True, offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's\n                models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only\n                when their specific submodule has its `forward` method called.\n            gpu_id (`int`, *optional*, defaults to `0`):\n                The gpu_id to be passed to enable_sequential_cpu_offload. Only works when enable_sequential_cpu_offload is set to True.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.\n        \"\"\"\n\n        if not isinstance(start_prompt, str) or not isinstance(end_prompt, str):\n            raise ValueError(\n                f\"`start_prompt` and `end_prompt` should be of type `str` but got {type(start_prompt)} and\"\n                f\" {type(end_prompt)} instead\"\n            )\n\n        if enable_sequential_cpu_offload:\n            self.enable_sequential_cpu_offload(gpu_id=gpu_id)\n\n        device = self._execution_device\n\n        # Turn the prompts into embeddings.\n        inputs = self.tokenizer(\n            [start_prompt, end_prompt],\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        )\n        inputs.to(device)\n        
text_model_output = self.text_encoder(**inputs)\n\n        text_attention_mask = torch.max(inputs.attention_mask[0], inputs.attention_mask[1])\n        text_attention_mask = torch.cat([text_attention_mask.unsqueeze(0)] * steps).to(device)\n\n        # Interpolate from the start to end prompt using slerp and add the generated images to an image output pipeline\n        batch_text_embeds = []\n        batch_last_hidden_state = []\n\n        for interp_val in torch.linspace(0, 1, steps):\n            text_embeds = slerp(interp_val, text_model_output.text_embeds[0], text_model_output.text_embeds[1])\n            last_hidden_state = slerp(\n                interp_val, text_model_output.last_hidden_state[0], text_model_output.last_hidden_state[1]\n            )\n            batch_text_embeds.append(text_embeds.unsqueeze(0))\n            batch_last_hidden_state.append(last_hidden_state.unsqueeze(0))\n\n        batch_text_embeds = torch.cat(batch_text_embeds)\n        batch_last_hidden_state = torch.cat(batch_last_hidden_state)\n\n        text_model_output = CLIPTextModelOutput(\n            text_embeds=batch_text_embeds, last_hidden_state=batch_last_hidden_state\n        )\n\n        batch_size = text_model_output[0].shape[0]\n\n        do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0\n\n        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(\n            prompt=None,\n            device=device,\n            num_images_per_prompt=1,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            text_model_output=text_model_output,\n            text_attention_mask=text_attention_mask,\n        )\n\n        # prior\n\n        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)\n        prior_timesteps_tensor = self.prior_scheduler.timesteps\n\n        embedding_dim = self.prior.config.embedding_dim\n\n        prior_latents = self.prepare_latents(\n            
(batch_size, embedding_dim),\n            prompt_embeds.dtype,\n            device,\n            generator,\n            None,\n            self.prior_scheduler,\n        )\n\n        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents\n\n            predicted_image_embedding = self.prior(\n                latent_model_input,\n                timestep=t,\n                proj_embedding=prompt_embeds,\n                encoder_hidden_states=text_encoder_hidden_states,\n                attention_mask=text_mask,\n            ).predicted_image_embedding\n\n            if do_classifier_free_guidance:\n                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)\n                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (\n                    predicted_image_embedding_text - predicted_image_embedding_uncond\n                )\n\n            if i + 1 == prior_timesteps_tensor.shape[0]:\n                prev_timestep = None\n            else:\n                prev_timestep = prior_timesteps_tensor[i + 1]\n\n            prior_latents = self.prior_scheduler.step(\n                predicted_image_embedding,\n                timestep=t,\n                sample=prior_latents,\n                generator=generator,\n                prev_timestep=prev_timestep,\n            ).prev_sample\n\n        prior_latents = self.prior.post_process_latents(prior_latents)\n\n        image_embeddings = prior_latents\n\n        # done prior\n\n        # decoder\n\n        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(\n            image_embeddings=image_embeddings,\n            prompt_embeds=prompt_embeds,\n            
text_encoder_hidden_states=text_encoder_hidden_states,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n        )\n\n        if device.type == \"mps\":\n            # HACK: MPS: There is a panic when padding bool tensors,\n            # so cast to int tensor for the pad and back to bool afterwards\n            text_mask = text_mask.type(torch.int)\n            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)\n            decoder_text_mask = decoder_text_mask.type(torch.bool)\n        else:\n            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)\n\n        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)\n        decoder_timesteps_tensor = self.decoder_scheduler.timesteps\n\n        num_channels_latents = self.decoder.config.in_channels\n        height = self.decoder.config.sample_size\n        width = self.decoder.config.sample_size\n\n        decoder_latents = self.prepare_latents(\n            (batch_size, num_channels_latents, height, width),\n            text_encoder_hidden_states.dtype,\n            device,\n            generator,\n            None,\n            self.decoder_scheduler,\n        )\n\n        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents\n\n            noise_pred = self.decoder(\n                sample=latent_model_input,\n                timestep=t,\n                encoder_hidden_states=text_encoder_hidden_states,\n                class_labels=additive_clip_time_embeddings,\n                attention_mask=decoder_text_mask,\n            ).sample\n\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                
noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)\n                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)\n                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)\n                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)\n\n            if i + 1 == decoder_timesteps_tensor.shape[0]:\n                prev_timestep = None\n            else:\n                prev_timestep = decoder_timesteps_tensor[i + 1]\n\n            # compute the previous noisy sample x_t -> x_t-1\n            decoder_latents = self.decoder_scheduler.step(\n                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator\n            ).prev_sample\n\n        decoder_latents = decoder_latents.clamp(-1, 1)\n\n        image_small = decoder_latents\n\n        # done decoder\n\n        # super res\n\n        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)\n        super_res_timesteps_tensor = self.super_res_scheduler.timesteps\n\n        channels = self.super_res_first.config.in_channels // 2\n        height = self.super_res_first.config.sample_size\n        width = self.super_res_first.config.sample_size\n\n        super_res_latents = self.prepare_latents(\n            (batch_size, channels, height, width),\n            image_small.dtype,\n            device,\n            generator,\n            None,\n            self.super_res_scheduler,\n        )\n\n        if device.type == \"mps\":\n            # MPS does not support many interpolations\n            image_upscaled = F.interpolate(image_small, size=[height, width])\n        else:\n            interpolate_antialias = {}\n            if \"antialias\" in inspect.signature(F.interpolate).parameters:\n                interpolate_antialias[\"antialias\"] = True\n\n            image_upscaled = F.interpolate(\n            
    image_small, size=[height, width], mode=\"bicubic\", align_corners=False, **interpolate_antialias\n            )\n\n        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):\n            # no classifier free guidance\n\n            if i == super_res_timesteps_tensor.shape[0] - 1:\n                unet = self.super_res_last\n            else:\n                unet = self.super_res_first\n\n            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)\n\n            noise_pred = unet(\n                sample=latent_model_input,\n                timestep=t,\n            ).sample\n\n            if i + 1 == super_res_timesteps_tensor.shape[0]:\n                prev_timestep = None\n            else:\n                prev_timestep = super_res_timesteps_tensor[i + 1]\n\n            # compute the previous noisy sample x_t -> x_t-1\n            super_res_latents = self.super_res_scheduler.step(\n                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator\n            ).prev_sample\n\n        image = super_res_latents\n        # done super res\n\n        # post processing\n\n        image = image * 0.5 + 0.5\n        image = image.clamp(0, 1)\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image,)\n\n        return ImagePipelineOutput(images=image)\n"
  },
  {
    "path": "examples/community/wildcard_stable_diffusion.py",
    "content": "import inspect\nimport os\nimport random\nimport re\nfrom dataclasses import dataclass\nfrom typing import Callable, Dict, List, Optional, Union\n\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler\nfrom diffusers.utils import deprecate, logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nglobal_re_wildcard = re.compile(r\"__([^_]*)__\")\n\n\ndef get_filename(path: str):\n    # this doesn't work on Windows\n    return os.path.basename(path).split(\".txt\")[0]\n\n\ndef read_wildcard_values(path: str):\n    with open(path, encoding=\"utf8\") as f:\n        return f.read().splitlines()\n\n\ndef grab_wildcard_values(wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []):\n    for wildcard_file in wildcard_files:\n        filename = get_filename(wildcard_file)\n        read_values = read_wildcard_values(wildcard_file)\n        if filename not in wildcard_option_dict:\n            wildcard_option_dict[filename] = []\n        wildcard_option_dict[filename].extend(read_values)\n    return wildcard_option_dict\n\n\ndef replace_prompt_with_wildcards(\n    prompt: str, wildcard_option_dict: Dict[str, List[str]] = {}, wildcard_files: List[str] = []\n):\n    new_prompt = prompt\n\n    # get wildcard options\n    wildcard_option_dict = grab_wildcard_values(wildcard_option_dict, wildcard_files)\n\n    for m in global_re_wildcard.finditer(new_prompt):\n        wildcard_value = 
m.group()\n        replace_value = random.choice(wildcard_option_dict[wildcard_value.strip(\"__\")])\n        new_prompt = new_prompt.replace(wildcard_value, replace_value, 1)\n\n    return new_prompt\n\n\n@dataclass\nclass WildcardStableDiffusionOutput(StableDiffusionPipelineOutput):\n    prompts: List[str]\n\n\nclass WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Example Usage:\n        pipe = WildcardStableDiffusionPipeline.from_pretrained(\n            \"CompVis/stable-diffusion-v1-4\",\n\n            torch_dtype=torch.float16,\n        )\n        prompt = \"__animal__ sitting on a __object__ wearing a __clothing__\"\n        out = pipe(\n            prompt,\n            wildcard_option_dict={\n                \"clothing\":[\"hat\", \"shirt\", \"scarf\", \"beret\"]\n            },\n            wildcard_files=[\"object.txt\", \"animal.txt\"],\n            num_prompt_samples=1\n        )\n\n\n    Pipeline for text-to-image generation with wild cards using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. 
Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n\n        if scheduler is not None and getattr(scheduler.config, \"steps_offset\", 1) != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might lead to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. 
For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        height: int = 512,\n        width: int = 512,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        wildcard_option_dict: Dict[str, List[str]] = {},\n        wildcard_files: List[str] = [],\n        num_prompt_samples: Optional[int] = 1,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            height (`int`, *optional*, defaults to 512):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 512):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n                if `guidance_scale` is less than `1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            wildcard_option_dict (`Dict[str, List[str]]`):\n                Dict mapping each wildcard name to a list of possible replacements. For example, for the prompt \"A __animal__ sitting on a chair\", a `wildcard_option_dict` can provide possible values for \"animal\" like this: {\"animal\":[\"dog\", \"cat\", \"fox\"]}\n            wildcard_files (`List[str]`):\n                List of filenames of txt files for wildcard replacements. For example if a prompt, \"A __animal__ sitting on a chair\". 
A file can be provided [\"animal.txt\"]\n            num_prompt_samples (`int`, *optional*, defaults to 1):\n                Number of times to sample wildcards for each prompt provided.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n\n        if isinstance(prompt, str):\n            prompt = [\n                replace_prompt_with_wildcards(prompt, wildcard_option_dict, wildcard_files)\n                for i in range(num_prompt_samples)\n            ]\n            batch_size = len(prompt)\n        elif isinstance(prompt, list):\n            prompt_list = []\n            for p in prompt:\n                for i in range(num_prompt_samples):\n                    prompt_list.append(replace_prompt_with_wildcards(p, wildcard_option_dict, wildcard_files))\n            prompt = prompt_list\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        # get prompt text embeddings\n        text_inputs = 
self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n\n        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:\n            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        bs_embed, seq_len, _ = text_embeddings.shape\n        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)\n        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . 
`guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = text_input_ids.shape[-1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]\n\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = uncond_embeddings.shape[1]\n            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)\n            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n\n        # get the initial random noise unless the user supplied it\n\n        # Unlike in other pipelines, latents need to be generated in the target device\n        # for 1-to-1 results reproducibility with the CompVis implementation.\n        # However this currently doesn't work in `mps`.\n        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)\n        latents_dtype = text_embeddings.dtype\n        if latents is None:\n            if self.device.type == \"mps\":\n                # randn does not exist on mps\n                latents = torch.randn(latents_shape, generator=generator, device=\"cpu\", dtype=latents_dtype).to(\n                    self.device\n                )\n     
       else:\n                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)\n        else:\n            if latents.shape != latents_shape:\n                raise ValueError(f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\")\n            latents = latents.to(self.device)\n\n        # set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        for i, t in enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n            # call the callback, if provided\n            if callback is not None and i % callback_steps == 0:\n                step_idx = i // getattr(self.scheduler, \"order\", 1)\n                callback(step_idx, t, latents)\n\n        latents = 1 / 0.18215 * latents\n        image = self.vae.decode(latents).sample\n\n        image = (image / 2 + 0.5).clamp(0, 1)\n\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n\n        if self.safety_checker is not None:\n            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors=\"pt\").to(\n                self.device\n            )\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)\n            )\n        else:\n            has_nsfw_concept = None\n\n        if output_type == \"pil\":\n            image = self.numpy_to_pil(image)\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return WildcardStableDiffusionOutput(images=image, nsfw_content_detected=has_nsfw_concept, prompts=prompt)\n"
  },
  {
    "path": "examples/conftest.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# tests directory-specific settings - this file is run automatically\n# by pytest before any tests are run\n\nimport sys\nimport warnings\nfrom os.path import abspath, dirname, join\n\n\n# allow having multiple repository checkouts and not needing to remember to rerun\n# 'pip install -e .[dev]' when switching between checkouts and running tests.\ngit_repo_path = abspath(join(dirname(dirname(dirname(__file__))), \"src\"))\nsys.path.insert(1, git_repo_path)\n\n# Add parent directory to path so we can import from tests\nrepo_root = abspath(dirname(dirname(__file__)))\nif repo_root not in sys.path:\n    sys.path.insert(0, repo_root)\n\n\n# silence FutureWarning warnings in tests since often we can't act on them until\n# they become normal warnings - i.e. the tests still need to test the current functionality\nwarnings.simplefilter(action=\"ignore\", category=FutureWarning)\n\n\ndef pytest_addoption(parser):\n    from tests.testing_utils import pytest_addoption_shared\n\n    pytest_addoption_shared(parser)\n\n\ndef pytest_terminal_summary(terminalreporter):\n    from tests.testing_utils import pytest_terminal_summary_main\n\n    make_reports = terminalreporter.config.getoption(\"--make-reports\")\n    if make_reports:\n        pytest_terminal_summary_main(terminalreporter, id=make_reports)\n"
  },
  {
    "path": "examples/consistency_distillation/README.md",
    "content": "# Latent Consistency Distillation Example:\n\n[Latent Consistency Models (LCMs)](https://huggingface.co/papers/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with few timesteps.\n\n## Full model distillation\n\n### Running locally with PyTorch\n\n#### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the example folder and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\n\n\n#### Example\n\nThe following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). 
You may also need to search the hyperparameter space according to the dataset you use.\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path/to/saved/model\"\n\naccelerate launch train_lcm_distill_sd_wds.py \\\n    --pretrained_teacher_model=$MODEL_NAME \\\n    --output_dir=$OUTPUT_DIR \\\n    --mixed_precision=fp16 \\\n    --resolution=512 \\\n    --learning_rate=1e-6 --loss_type=\"huber\" --ema_decay=0.95 --adam_weight_decay=0.0 \\\n    --max_train_steps=1000 \\\n    --max_train_samples=4000000 \\\n    --dataloader_num_workers=8 \\\n    --train_shards_path_or_url=\"pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true\" \\\n    --validation_steps=200 \\\n    --checkpointing_steps=200 --checkpoints_total_limit=10 \\\n    --train_batch_size=12 \\\n    --gradient_checkpointing --enable_xformers_memory_efficient_attention \\\n    --gradient_accumulation_steps=1 \\\n    --use_8bit_adam \\\n    --resume_from_checkpoint=latest \\\n    --report_to=wandb \\\n    --seed=453645634 \\\n    --push_to_hub\n```\n\n## LCM-LoRA\n\nInstead of fine-tuning the full model, we can also just train a LoRA that can be injected into any Stable Diffusion v1.5-based model.\n\n### Example\n\nThe following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. 
For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path/to/saved/model\"\n\naccelerate launch train_lcm_distill_lora_sd_wds.py \\\n    --pretrained_teacher_model=$MODEL_NAME \\\n    --output_dir=$OUTPUT_DIR \\\n    --mixed_precision=fp16 \\\n    --resolution=512 \\\n    --lora_rank=64 \\\n    --learning_rate=1e-4 --loss_type=\"huber\" --adam_weight_decay=0.0 \\\n    --max_train_steps=1000 \\\n    --max_train_samples=4000000 \\\n    --dataloader_num_workers=8 \\\n    --train_shards_path_or_url=\"pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true\" \\\n    --validation_steps=200 \\\n    --checkpointing_steps=200 --checkpoints_total_limit=10 \\\n    --train_batch_size=12 \\\n    --gradient_checkpointing --enable_xformers_memory_efficient_attention \\\n    --gradient_accumulation_steps=1 \\\n    --use_8bit_adam \\\n    --resume_from_checkpoint=latest \\\n    --report_to=wandb \\\n    --seed=453645634 \\\n    --push_to_hub\n```"
  },
  {
    "path": "examples/consistency_distillation/README_sdxl.md",
    "content": "# Latent Consistency Distillation Example:\n\n[Latent Consistency Models (LCMs)](https://huggingface.co/papers/2310.04378) is a method to distill a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with few timesteps.\n\n## Full model distillation\n\n### Running locally with PyTorch\n\n#### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the example folder and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\n\n\n#### Example\n\nThe following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example, and for illustrative purposes only. For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/). 
You may also need to search the hyperparameter space according to the dataset you use.\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport OUTPUT_DIR=\"path/to/saved/model\"\n\naccelerate launch train_lcm_distill_sdxl_wds.py \\\n    --pretrained_teacher_model=$MODEL_NAME \\\n    --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \\\n    --output_dir=$OUTPUT_DIR \\\n    --mixed_precision=fp16 \\\n    --resolution=1024 \\\n    --learning_rate=1e-6 --loss_type=\"huber\" --use_fix_crop_and_size --ema_decay=0.95 --adam_weight_decay=0.0 \\\n    --max_train_steps=1000 \\\n    --max_train_samples=4000000 \\\n    --dataloader_num_workers=8 \\\n    --train_shards_path_or_url=\"pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true\" \\\n    --validation_steps=200 \\\n    --checkpointing_steps=200 --checkpoints_total_limit=10 \\\n    --train_batch_size=12 \\\n    --gradient_checkpointing --enable_xformers_memory_efficient_attention \\\n    --gradient_accumulation_steps=1 \\\n    --use_8bit_adam \\\n    --resume_from_checkpoint=latest \\\n    --report_to=wandb \\\n    --seed=453645634 \\\n    --push_to_hub\n```\n\n## LCM-LoRA\n\nInstead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.\n\n### Example\n\nThe following uses the [Conceptual Captions 12M (CC12M) dataset](https://github.com/google-research-datasets/conceptual-12m) as an example. 
For best results you may consider large and high-quality text-image datasets such as [LAION](https://laion.ai/blog/laion-400-open-dataset/).\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport OUTPUT_DIR=\"path/to/saved/model\"\n\naccelerate launch train_lcm_distill_lora_sdxl_wds.py \\\n    --pretrained_teacher_model=$MODEL_NAME \\\n    --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \\\n    --output_dir=$OUTPUT_DIR \\\n    --mixed_precision=fp16 \\\n    --resolution=1024 \\\n    --lora_rank=64 \\\n    --learning_rate=1e-4 --loss_type=\"huber\" --use_fix_crop_and_size --adam_weight_decay=0.0 \\\n    --max_train_steps=1000 \\\n    --max_train_samples=4000000 \\\n    --dataloader_num_workers=8 \\\n    --train_shards_path_or_url=\"pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true\" \\\n    --validation_steps=200 \\\n    --checkpointing_steps=200 --checkpoints_total_limit=10 \\\n    --train_batch_size=12 \\\n    --gradient_checkpointing --enable_xformers_memory_efficient_attention \\\n    --gradient_accumulation_steps=1 \\\n    --use_8bit_adam \\\n    --resume_from_checkpoint=latest \\\n    --report_to=wandb \\\n    --seed=453645634 \\\n    --push_to_hub\n```\n\nWe provide another version for LCM LoRA SDXL that follows best practices of `peft` and leverages the `datasets` library for quick experimentation. 
Unlike `train_lcm_distill_lora_sdxl_wds.py`, this script doesn't load two UNets, which reduces the memory requirements quite a bit.\n\nBelow is an example training command that trains an LCM LoRA on the [Narutos dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions):\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\nexport VAE_PATH=\"madebyollin/sdxl-vae-fp16-fix\"\n\naccelerate launch train_lcm_distill_lora_sdxl.py \\\n  --pretrained_teacher_model=${MODEL_NAME} \\\n  --pretrained_vae_model_name_or_path=${VAE_PATH} \\\n  --output_dir=\"narutos-lora-lcm-sdxl\" \\\n  --mixed_precision=\"fp16\" \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=1024 \\\n  --train_batch_size=24 \\\n  --gradient_accumulation_steps=1 \\\n  --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --lora_rank=64 \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=3000 \\\n  --checkpointing_steps=500 \\\n  --validation_steps=50 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n"
  },
  {
    "path": "examples/consistency_distillation/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\nwebdataset"
  },
  {
    "path": "examples/consistency_distillation/test_lcm_lora.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass TextToImageLCM(ExamplesTestsAccelerate):\n    def test_text_to_image_lcm_lora_sdxl(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py\n                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --lora_rank 4\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, 
\"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n    def test_text_to_image_lcm_lora_sdxl_checkpointing(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py\n                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --lora_rank 4\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 7\n                --checkpointing_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n            test_args = f\"\"\"\n                examples/consistency_distillation/train_lcm_distill_lora_sdxl.py\n                --pretrained_teacher_model hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --lora_rank 4\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 9\n                --checkpointing_steps 2\n                
--resume_from_checkpoint latest\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\", \"checkpoint-6\", \"checkpoint-8\"},\n            )\n"
  },
  {
    "path": "examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport functools\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import List, Union\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision.transforms.functional as TF\nimport transformers\nimport webdataset as wds\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom braceexpand import braceexpand\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, get_peft_model, get_peft_model_state_dict\nfrom torch.utils.data import default_collate\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig\nfrom webdataset.tariterators import (\n    base_plus_ext,\n    tar_file_expander,\n    url_opener,\n    valid_sample,\n)\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    LCMScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom 
diffusers.training_utils import resolve_interpolation_mode\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nMAX_SEQ_LENGTH = 77\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = \"default\"):\n    kohya_ss_state_dict = {}\n    for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():\n        kohya_key = peft_key.replace(\"base_model.model\", prefix)\n        kohya_key = kohya_key.replace(\"lora_A\", \"lora_down\")\n        kohya_key = kohya_key.replace(\"lora_B\", \"lora_up\")\n        kohya_key = kohya_key.replace(\".\", \"_\", kohya_key.count(\".\") - 2)\n        kohya_ss_state_dict[kohya_key] = weight.to(dtype)\n\n        # Set alpha parameter\n        if \"lora_down\" in kohya_key:\n            alpha_key = f\"{kohya_key.split('.')[0]}.alpha\"\n            kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)\n\n    return kohya_ss_state_dict\n\n\ndef filter_keys(key_set):\n    def _f(dictionary):\n        return {k: v for k, v in dictionary.items() if k in key_set}\n\n    return _f\n\n\ndef group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):\n    \"\"\"Return function over iterator that groups key, value pairs into samples.\n\n    :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to\n    lower case (Default value = True)\n    \"\"\"\n    current_sample = None\n    for filesample in data:\n        assert isinstance(filesample, dict)\n        fname, value = filesample[\"fname\"], filesample[\"data\"]\n        prefix, suffix = keys(fname)\n     
   if prefix is None:\n            continue\n        if lcase:\n            suffix = suffix.lower()\n        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for\n        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next\n        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset\n        if current_sample is None or prefix != current_sample[\"__key__\"] or suffix in current_sample:\n            if valid_sample(current_sample):\n                yield current_sample\n            current_sample = {\"__key__\": prefix, \"__url__\": filesample[\"__url__\"]}\n        if suffixes is None or suffix in suffixes:\n            current_sample[suffix] = value\n    if valid_sample(current_sample):\n        yield current_sample\n\n\ndef tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):\n    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw\n    streams = url_opener(src, handler=handler)\n    files = tar_file_expander(streams, handler=handler)\n    samples = group_by_keys_nothrow(files, handler=handler)\n    return samples\n\n\nclass WebdatasetFilter:\n    def __init__(self, min_size=1024, max_pwatermark=0.5):\n        self.min_size = min_size\n        self.max_pwatermark = max_pwatermark\n\n    def __call__(self, x):\n        try:\n            if \"json\" in x:\n                x_json = json.loads(x[\"json\"])\n                filter_size = (x_json.get(\"original_width\", 0.0) or 0.0) >= self.min_size and x_json.get(\n                    \"original_height\", 0\n                ) >= self.min_size\n                filter_watermark = (x_json.get(\"pwatermark\", 1.0) or 1.0) <= self.max_pwatermark\n                return filter_size and filter_watermark\n            else:\n                return False\n        except Exception:\n            return False\n\n\nclass SDText2ImageDataset:\n    def __init__(\n        
self,\n        train_shards_path_or_url: Union[str, List[str]],\n        num_train_examples: int,\n        per_gpu_batch_size: int,\n        global_batch_size: int,\n        num_workers: int,\n        resolution: int = 512,\n        interpolation_type: str = \"bilinear\",\n        shuffle_buffer_size: int = 1000,\n        pin_memory: bool = False,\n        persistent_workers: bool = False,\n    ):\n        if not isinstance(train_shards_path_or_url, str):\n            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]\n            # flatten list using itertools\n            train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))\n\n        interpolation_mode = resolve_interpolation_mode(interpolation_type)\n\n        def transform(example):\n            # resize image\n            image = example[\"image\"]\n            image = TF.resize(image, resolution, interpolation=interpolation_mode)\n\n            # get crop coordinates and crop image\n            c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))\n            image = TF.crop(image, c_top, c_left, resolution, resolution)\n            image = TF.to_tensor(image)\n            image = TF.normalize(image, [0.5], [0.5])\n\n            example[\"image\"] = image\n            return example\n\n        processing_pipeline = [\n            wds.decode(\"pil\", handler=wds.ignore_and_continue),\n            wds.rename(image=\"jpg;png;jpeg;webp\", text=\"text;txt;caption\", handler=wds.warn_and_continue),\n            wds.map(filter_keys({\"image\", \"text\"})),\n            wds.map(transform),\n            wds.to_tuple(\"image\", \"text\"),\n        ]\n\n        # Create train dataset and loader\n        pipeline = [\n            wds.ResampledShards(train_shards_path_or_url),\n            tarfile_to_samples_nothrow,\n            wds.shuffle(shuffle_buffer_size),\n            *processing_pipeline,\n   
         wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),\n        ]\n\n        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker\n        num_batches = num_worker_batches * num_workers\n        num_samples = num_batches * global_batch_size\n\n        # each worker is iterating over this\n        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)\n        self._train_dataloader = wds.WebLoader(\n            self._train_dataset,\n            batch_size=None,\n            shuffle=False,\n            num_workers=num_workers,\n            pin_memory=pin_memory,\n            persistent_workers=persistent_workers,\n        )\n        # add meta-data to dataloader instance for convenience\n        self._train_dataloader.num_batches = num_batches\n        self._train_dataloader.num_samples = num_samples\n\n    @property\n    def train_dataset(self):\n        return self._train_dataset\n\n    @property\n    def train_dataloader(self):\n        return self._train_dataloader\n\n\ndef log_validation(vae, unet, args, accelerator, weight_dtype, step):\n    logger.info(\"Running validation... 
\")\n    if torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)\n\n    unet = accelerator.unwrap_model(unet)\n    pipeline = StableDiffusionPipeline.from_pretrained(\n        args.pretrained_teacher_model,\n        vae=vae,\n        scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder=\"scheduler\"),\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n        safety_checker=None,\n    )\n    pipeline.set_progress_bar_config(disable=True)\n\n    lora_state_dict = get_module_kohya_state_dict(unet, \"lora_unet\", weight_dtype)\n    pipeline.load_lora_weights(lora_state_dict)\n    pipeline.fuse_lora()\n\n    pipeline = pipeline.to(accelerator.device, dtype=weight_dtype)\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    validation_prompts = [\n        \"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography\",\n        \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\",\n        \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n        \"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece\",\n    ]\n\n    image_logs = []\n\n    for _, prompt in enumerate(validation_prompts):\n        images = []\n        with autocast_ctx:\n            images = pipeline(\n                prompt=prompt,\n                num_inference_steps=4,\n                num_images_per_prompt=4,\n                generator=generator,\n                
guidance_scale=1.0,\n            ).images\n        image_logs.append({\"validation_prompt\": prompt, \"images\": images})\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                formatted_images = []\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({\"validation\": formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n        del pipeline\n        gc.collect()\n        torch.cuda.empty_cache()\n\n        return image_logs\n\n\n# From LatentConsistencyModel.get_guidance_scale_embedding\ndef guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):\n    \"\"\"\n    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n    Args:\n        timesteps (`torch.Tensor`):\n            generate embedding vectors at these timesteps\n        embedding_dim (`int`, *optional*, defaults to 512):\n            dimension of the embeddings to generate\n        dtype:\n            data type of the generated embeddings\n\n    Returns:\n        `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n    \"\"\"\n    
assert len(w.shape) == 1\n    w = w * 1000.0\n\n    half_dim = embedding_dim // 2\n    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n    emb = w.to(dtype)[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:  # zero pad\n        emb = torch.nn.functional.pad(emb, (0, 1))\n    assert emb.shape == (w.shape[0], embedding_dim)\n    return emb\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\n# From LCMScheduler.get_scalings_for_boundary_condition_discrete\ndef scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):\n    scaled_timestep = timestep_scaling * timestep\n    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)\n    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5\n    return c_skip, c_out\n\n\n# Compare LCMScheduler.step, Step 4\ndef get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_x_0 = (sample - sigmas * model_output) / alphas\n    elif prediction_type == \"sample\":\n        pred_x_0 = model_output\n    elif prediction_type == \"v_prediction\":\n        pred_x_0 = alphas * sample - sigmas * model_output\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return 
pred_x_0\n\n\n# Based on step 4 in DDIMScheduler.step\ndef get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_epsilon = model_output\n    elif prediction_type == \"sample\":\n        pred_epsilon = (sample - alphas * model_output) / sigmas\n    elif prediction_type == \"v_prediction\":\n        pred_epsilon = alphas * model_output + sigmas * sample\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_epsilon\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\nclass DDIMSolver:\n    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):\n        # DDIM sampling parameters\n        step_ratio = timesteps // ddim_timesteps\n        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1\n        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]\n        self.ddim_alpha_cumprods_prev = np.asarray(\n            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()\n        )\n        # convert to torch tensors\n        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()\n        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)\n        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)\n\n    def to(self, device):\n        self.ddim_timesteps = self.ddim_timesteps.to(device)\n        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)\n        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)\n    
    return self\n\n    def ddim_step(self, pred_x0, pred_noise, timestep_index):\n        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)\n        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise\n        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt\n        return x_prev\n\n\n@torch.no_grad()\ndef update_ema(target_params, source_params, rate=0.99):\n    \"\"\"\n    Update target parameters to be closer to those of source parameters using\n    an exponential moving average.\n\n    :param target_params: the target parameter sequence.\n    :param source_params: the source parameter sequence.\n    :param rate: the EMA rate (closer to 1 means slower).\n    \"\"\"\n    for targ, src in zip(target_params, source_params):\n        targ.detach().mul_(rate).add_(src, alpha=1 - rate)\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    # ----------Model Checkpoint Loading Arguments----------\n    parser.add_argument(\n        \"--pretrained_teacher_model\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained LDM teacher model or model identifier from huggingface.co/models.\",\n    )\n    
parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--teacher_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM teacher model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM model identifier from huggingface.co/models.\",\n    )\n    # ----------Training Arguments----------\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lcm-xl-distilled\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    # ----Logging----\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    # ----Checkpointing----\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    # ----Image Processing----\n    parser.add_argument(\n        \"--train_shards_path_or_url\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--interpolation_type\",\n        type=str,\n        default=\"bilinear\",\n        help=(\n            \"The interpolation function used when resizing images to the desired resolution. 
Choose between `bilinear`,\"\n            \" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    # ----Dataloader----\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    # ----Batch Size and Training Steps----\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    # ----Learning Rate----\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    # ----Optimizer (Adam)----\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    
parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    # ----Diffusion Training Arguments----\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    # ----Latent Consistency Distillation (LCD) Specific Arguments----\n    parser.add_argument(\n        \"--w_min\",\n        type=float,\n        default=5.0,\n        required=False,\n        help=(\n            \"The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--w_max\",\n        type=float,\n        default=15.0,\n        required=False,\n        help=(\n            \"The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_ddim_timesteps\",\n        type=int,\n        default=50,\n        help=\"The number of timesteps to use for DDIM sampling.\",\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\"],\n        help=\"The type of loss to use for the LCD loss.\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.001,\n        help=\"The huber loss parameter. 
Only used if `--loss_type=huber`.\",\n    )\n    parser.add_argument(\n        \"--lora_rank\",\n        type=int,\n        default=64,\n        help=\"The rank of the LoRA projection matrix.\",\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=64,\n        help=(\n            \"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight\"\n            \" update delta_W. No scaling will be performed if this value is equal to `lora_rank`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--lora_dropout\",\n        type=float,\n        default=0.0,\n        help=\"The dropout probability for the dropout layer added before applying the LoRA to each layer input.\",\n    )\n    parser.add_argument(\n        \"--lora_target_modules\",\n        type=str,\n        default=None,\n        help=(\n            \"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will\"\n            \" be used. By default, LoRA will be applied to all conv and linear layers.\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=32,\n        required=False,\n        help=(\n            \"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE.\"\n            \" Encoding or decoding the whole batch at once may run into OOM issues.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_scaling_factor\",\n        type=float,\n        default=10.0,\n        help=(\n            \"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. 
The\"\n            \" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically\"\n            \" suffice.\"\n        ),\n    )\n    # ----Mixed Precision----\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cast_teacher_unet\",\n        action=\"store_true\",\n        help=\"Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.\",\n    )\n    # ----Training Optimizations----\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    # ----Distributed Training----\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    # ----------Validation Arguments----------\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        
default=200,\n        help=\"Run validation every X steps.\",\n    )\n    # ----------Huggingface Hub Arguments-----------\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    # ----------Accelerate Arguments----------\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    return args\n\n\n# Adapted from pipelines.StableDiffusionPipeline.encode_prompt\ndef encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with 
torch.no_grad():\n        text_inputs = tokenizer(\n            captions,\n            padding=\"max_length\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]\n\n    return prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. 
If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n                token=args.hub_token,\n                private=True,\n            ).repo_id\n\n    # 1. 
Create the noise scheduler and the desired noise schedule.\n    noise_scheduler = DDPMScheduler.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"scheduler\", revision=args.teacher_revision\n    )\n\n    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us\n    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)\n    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)\n    # Initialize the DDIM ODE solver for distillation.\n    solver = DDIMSolver(\n        noise_scheduler.alphas_cumprod.numpy(),\n        timesteps=noise_scheduler.config.num_train_timesteps,\n        ddim_timesteps=args.num_ddim_timesteps,\n    )\n\n    # 2. Load the tokenizer from SD 1.X/2.X checkpoint.\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"tokenizer\", revision=args.teacher_revision, use_fast=False\n    )\n\n    # 3. Load the text encoder from SD 1.X/2.X checkpoint.\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"text_encoder\", revision=args.teacher_revision\n    )\n\n    # 4. Load VAE from SD 1.X/2.X checkpoint\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_teacher_model,\n        subfolder=\"vae\",\n        revision=args.teacher_revision,\n    )\n\n    # 5. Load teacher U-Net from SD 1.X/2.X checkpoint\n    teacher_unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"unet\", revision=args.teacher_revision\n    )\n\n    # 6. Freeze teacher vae, text_encoder, and teacher_unet\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    teacher_unet.requires_grad_(False)\n\n    # 7. 
Create online student U-Net.\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"unet\", revision=args.teacher_revision\n    )\n    unet.train()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, a copy of the weights should still be float32.\"\n    )\n\n    if accelerator.unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(\n            f\"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}\"\n        )\n\n    # 8. Add LoRA to the student U-Net; only the LoRA projection matrices will be updated by the optimizer.\n    if args.lora_target_modules is not None:\n        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(\",\")]\n    else:\n        lora_target_modules = [\n            \"to_q\",\n            \"to_k\",\n            \"to_v\",\n            \"to_out.0\",\n            \"proj_in\",\n            \"proj_out\",\n            \"ff.net.0.proj\",\n            \"ff.net.2\",\n            \"conv1\",\n            \"conv2\",\n            \"conv_shortcut\",\n            \"downsamplers.0.conv\",\n            \"upsamplers.0.conv\",\n            \"time_emb_proj\",\n        ]\n    lora_config = LoraConfig(\n        r=args.lora_rank,\n        target_modules=lora_target_modules,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n    )\n    unet = get_peft_model(unet, lora_config)\n\n    # 9. 
Handle mixed precision and device placement\n    # For mixed precision training we cast all non-trainable weights to half precision.\n    # These weights are only used for inference, so keeping them in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and text_encoder to device and cast text_encoder to weight_dtype.\n    # The VAE stays in float32 to avoid NaN losses unless `--pretrained_vae_model_name_or_path` is set.\n    vae.to(accelerator.device)\n    if args.pretrained_vae_model_name_or_path is not None:\n        vae.to(dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Move teacher_unet to device, optionally cast to weight_dtype\n    teacher_unet.to(accelerator.device)\n    if args.cast_teacher_unet:\n        teacher_unet.to(dtype=weight_dtype)\n\n    # Also move the alpha and sigma noise schedules to accelerator.device.\n    alpha_schedule = alpha_schedule.to(accelerator.device)\n    sigma_schedule = sigma_schedule.to(accelerator.device)\n    # Move the ODE solver to accelerator.device.\n    solver = solver.to(accelerator.device)\n\n    # 10. 
Handle saving and loading of checkpoints\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                unet_ = accelerator.unwrap_model(unet)\n                lora_state_dict = get_peft_model_state_dict(unet_, adapter_name=\"default\")\n                StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, \"unet_lora\"), lora_state_dict)\n                # save weights in peft format to be able to load them back\n                unet_.save_pretrained(output_dir)\n\n                for _, model in enumerate(models):\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            # load the LoRA into the model\n            unet_ = accelerator.unwrap_model(unet)\n            unet_.load_adapter(input_dir, \"default\", is_trainable=True)\n\n            for _ in range(len(models)):\n                # pop models so that they are not loaded again\n                models.pop()\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # 11. Enable optimizations\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. 
See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            teacher_unet.enable_xformers_memory_efficient_attention()\n            # target_unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # 12. Optimizer creation\n    optimizer = optimizer_class(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # 13. 
Dataset creation and data processing\n    # Here, we compute the text embeddings that condition the SD 1.X/2.X U-Net.\n    def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):\n        prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)\n        return {\"prompt_embeds\": prompt_embeds}\n\n    dataset = SDText2ImageDataset(\n        train_shards_path_or_url=args.train_shards_path_or_url,\n        num_train_examples=args.max_train_samples,\n        per_gpu_batch_size=args.train_batch_size,\n        global_batch_size=args.train_batch_size * accelerator.num_processes,\n        num_workers=args.dataloader_num_workers,\n        resolution=args.resolution,\n        interpolation_type=args.interpolation_type,\n        shuffle_buffer_size=1000,\n        pin_memory=True,\n        persistent_workers=True,\n    )\n    train_dataloader = dataset.train_dataloader\n\n    # Pass the user-supplied `--proportion_empty_prompts` (default 0) so that the flag takes effect.\n    compute_embeddings_fn = functools.partial(\n        compute_embeddings,\n        proportion_empty_prompts=args.proportion_empty_prompts,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n    )\n\n    # 14. LR Scheduler creation\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps,\n        num_training_steps=args.max_train_steps,\n    )\n\n    # 15. 
Prepare for training\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    uncond_input_ids = tokenizer(\n        [\"\"] * args.train_batch_size, return_tensors=\"pt\", padding=\"max_length\", max_length=77\n    ).input_ids.to(accelerator.device)\n    uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]\n\n    if torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    # 16. Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {train_dataloader.num_batches}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # 1. 
Load and process the image and text conditioning\n                image, text = batch\n\n                image = image.to(accelerator.device, non_blocking=True)\n                encoded_text = compute_embeddings_fn(text)\n\n                pixel_values = image.to(dtype=weight_dtype)\n                if vae.dtype != weight_dtype:\n                    vae.to(dtype=weight_dtype)\n\n                # encode pixel values with batch size of at most args.vae_encode_batch_size\n                latents = []\n                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())\n                latents = torch.cat(latents, dim=0)\n\n                latents = latents * vae.config.scaling_factor\n                latents = latents.to(weight_dtype)\n                bsz = latents.shape[0]\n\n                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.\n                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]\n                topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps\n                index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()\n                start_timesteps = solver.ddim_timesteps[index]\n                timesteps = start_timesteps - topk\n                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)\n\n                # 3. 
Get boundary scalings for start_timesteps and (end) timesteps.\n                c_skip_start, c_out_start = scalings_for_boundary_conditions(\n                    start_timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]\n                c_skip, c_out = scalings_for_boundary_conditions(\n                    timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]\n\n                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each\n                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]\n                noise = torch.randn_like(latents)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)\n\n                # 5. Sample a random guidance scale w from U[w_min, w_max]\n                # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding\n                w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min\n                w = w.reshape(bsz, 1, 1, 1)\n                w = w.to(device=latents.device, dtype=latents.dtype)\n\n                # 6. Prepare prompt embeds and unet_added_conditions\n                prompt_embeds = encoded_text.pop(\"prompt_embeds\")\n\n                # 7. 
Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)\n                noise_pred = unet(\n                    noisy_model_input,\n                    start_timesteps,\n                    timestep_cond=None,\n                    encoder_hidden_states=prompt_embeds.float(),\n                    added_cond_kwargs=encoded_text,\n                ).sample\n\n                pred_x_0 = get_predicted_original_sample(\n                    noise_pred,\n                    start_timesteps,\n                    noisy_model_input,\n                    noise_scheduler.config.prediction_type,\n                    alpha_schedule,\n                    sigma_schedule,\n                )\n\n                model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0\n\n                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the\n                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these\n                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE\n                # solver timestep.\n                with torch.no_grad():\n                    with autocast_ctx:\n                        # 1. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c\n                        cond_teacher_output = teacher_unet(\n                            noisy_model_input.to(weight_dtype),\n                            start_timesteps,\n                            encoder_hidden_states=prompt_embeds.to(weight_dtype),\n                        ).sample\n                        cond_pred_x0 = get_predicted_original_sample(\n                            cond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n                        cond_pred_noise = get_predicted_noise(\n                            cond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n\n                        # 2. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0\n                        uncond_teacher_output = teacher_unet(\n                            noisy_model_input.to(weight_dtype),\n                            start_timesteps,\n                            encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),\n                        ).sample\n                        uncond_pred_x0 = get_predicted_original_sample(\n                            uncond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n                        uncond_pred_noise = get_predicted_noise(\n                            uncond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n\n                        # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)\n                        # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation\n                        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)\n                        pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)\n                        # 4. Run one step of the ODE solver to estimate the next point x_prev on the\n                        # augmented PF-ODE trajectory (solving backward in time)\n                        # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.\n                        x_prev = solver.ddim_step(pred_x0, pred_noise, index)\n\n                # 9. 
Get target LCM prediction on x_prev, w, c, t_n (timesteps)\n                # Note that we do not use a separate target network for LCM-LoRA distillation.\n                with torch.no_grad():\n                    with autocast_ctx:\n                        target_noise_pred = unet(\n                            x_prev.float(),\n                            timesteps,\n                            timestep_cond=None,\n                            encoder_hidden_states=prompt_embeds.float(),\n                        ).sample\n                    pred_x_0 = get_predicted_original_sample(\n                        target_noise_pred,\n                        timesteps,\n                        x_prev,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n                    target = c_skip * x_prev + c_out * pred_x_0\n\n                # 10. Calculate loss\n                if args.loss_type == \"l2\":\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                elif args.loss_type == \"huber\":\n                    loss = torch.mean(\n                        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c\n                    )\n\n                # 11. 
Backpropagate on the online student model (`unet`)\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                  
  removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if global_step % args.validation_steps == 0:\n                        log_validation(vae, unet, args, accelerator, weight_dtype, global_step)\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet.save_pretrained(args.output_dir)\n        lora_state_dict = get_peft_model_state_dict(unet, adapter_name=\"default\")\n        StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, \"unet_lora\"), lora_state_dict)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/consistency_distillation/train_lcm_distill_lora_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The LCM team and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport functools\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    LCMScheduler,\n    StableDiffusionXLPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import cast_training_params, resolve_interpolation_mode\nfrom diffusers.utils import (\n    check_min_version,\n    convert_state_dict_to_diffusers,\n    convert_unet_state_dict_to_peft,\n    
is_wandb_available,\n)\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\nclass DDIMSolver:\n    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):\n        # DDIM sampling parameters\n        step_ratio = timesteps // ddim_timesteps\n\n        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1\n        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]\n        self.ddim_alpha_cumprods_prev = np.asarray(\n            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()\n        )\n        # convert to torch tensors\n        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()\n        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)\n        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)\n\n    def to(self, device):\n        self.ddim_timesteps = self.ddim_timesteps.to(device)\n        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)\n        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)\n        return self\n\n    def ddim_step(self, pred_x0, pred_noise, timestep_index):\n        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)\n        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise\n        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt\n        return x_prev\n\n\ndef log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_final_validation=False):\n    logger.info(\"Running validation... 
\")\n\n    pipeline = StableDiffusionXLPipeline.from_pretrained(\n        args.pretrained_teacher_model,\n        vae=vae,\n        scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder=\"scheduler\"),\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    ).to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    to_load = None\n    if not is_final_validation:\n        if unet is None:\n            raise ValueError(\"Must provide a `unet` when doing intermediate validation.\")\n        unet = accelerator.unwrap_model(unet)\n        state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n        to_load = state_dict\n    else:\n        to_load = args.output_dir\n\n    pipeline.load_lora_weights(to_load)\n    pipeline.fuse_lora()\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    validation_prompts = [\n        \"cute sundar pichai character\",\n        \"robotic cat with wings\",\n        \"a photo of yoda\",\n        \"a cute creature with blue eyes\",\n    ]\n\n    image_logs = []\n\n    for _, prompt in enumerate(validation_prompts):\n        images = []\n        if torch.backends.mps.is_available():\n            autocast_ctx = nullcontext()\n        else:\n            autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)\n\n        with autocast_ctx:\n            images = pipeline(\n                prompt=prompt,\n                num_inference_steps=4,\n                num_images_per_prompt=4,\n                generator=generator,\n                guidance_scale=0.0,\n            ).images\n        image_logs.append({\"validation_prompt\": prompt, \"images\": images})\n\n    for tracker in accelerator.trackers:\n        if 
tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                formatted_images = []\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n            logger_name = \"test\" if is_final_validation else \"validation\"\n            tracker.log({logger_name: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    # Free the pipeline and return only after every tracker has been handled.\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\n# From LCMScheduler.get_scalings_for_boundary_condition_discrete\ndef scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):\n    scaled_timestep = timestep_scaling * timestep\n    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)\n    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5\n    return c_skip, c_out\n\n\n# Compare LCMScheduler.step, Step 
4\ndef get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_x_0 = (sample - sigmas * model_output) / alphas\n    elif prediction_type == \"sample\":\n        pred_x_0 = model_output\n    elif prediction_type == \"v_prediction\":\n        pred_x_0 = alphas * sample - sigmas * model_output\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_x_0\n\n\n# Based on step 4 in DDIMScheduler.step\ndef get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_epsilon = model_output\n    elif prediction_type == \"sample\":\n        pred_epsilon = (sample - alphas * model_output) / sigmas\n    elif prediction_type == \"v_prediction\":\n        pred_epsilon = alphas * model_output + sigmas * sample\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_epsilon\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    
model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    # ----------Model Checkpoint Loading Arguments----------\n    parser.add_argument(\n        \"--pretrained_teacher_model\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained LDM teacher model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--teacher_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM teacher model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM model identifier from huggingface.co/models.\",\n    )\n    # ----------Training Arguments----------\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lcm-xl-distilled\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    # ----Logging----\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    # ----Checkpointing----\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    # ----Image Processing----\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. 
In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--interpolation_type\",\n        type=str,\n        default=\"bilinear\",\n        help=(\n            \"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,\"\n            \" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--encode_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size to use for VAE encoding of the images for efficient processing.\",\n    )\n    # ----Dataloader----\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    # ----Batch Size and Training Steps----\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    # ----Learning Rate----\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    # ----Optimizer (Adam)----\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    
parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    # ----Diffusion Training Arguments----\n    # ----Latent Consistency Distillation (LCD) Specific Arguments----\n    parser.add_argument(\n        \"--w_min\",\n        type=float,\n        default=3.0,\n        required=False,\n        help=(\n            \"The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--w_max\",\n        type=float,\n        default=15.0,\n        required=False,\n        help=(\n            \"The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_ddim_timesteps\",\n        type=int,\n        default=50,\n        help=\"The number of timesteps to use for DDIM sampling.\",\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\"],\n        help=\"The type of loss to use for the LCD loss.\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.001,\n        help=\"The huber loss parameter. 
Only used if `--loss_type=huber`.\",\n    )\n    parser.add_argument(\n        \"--lora_rank\",\n        type=int,\n        default=64,\n        help=\"The rank of the LoRA projection matrix.\",\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=64,\n        help=(\n            \"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight\"\n            \" update delta_W. No scaling will be performed if this value is equal to `lora_rank`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--lora_dropout\",\n        type=float,\n        default=0.0,\n        help=\"The dropout probability for the dropout layer added before applying the LoRA to each layer input.\",\n    )\n    parser.add_argument(\n        \"--lora_target_modules\",\n        type=str,\n        default=None,\n        help=(\n            \"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will\"\n            \" be used. By default, LoRA will be applied to all conv and linear layers.\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=8,\n        required=False,\n        help=(\n            \"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE.\"\n            \" Encoding or decoding the whole batch at once may run into OOM issues.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_scaling_factor\",\n        type=float,\n        default=10.0,\n        help=(\n            \"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. 
The\"\n            \" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically\"\n            \" suffice.\"\n        ),\n    )\n    # ----Mixed Precision----\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    # ----Training Optimizations----\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    # ----Distributed Training----\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    # ----------Validation Arguments----------\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    # ----------Huggingface Hub Arguments-----------\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or 
not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    # ----------Accelerate Arguments----------\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(prompt_batch, text_encoders, tokenizers, is_train=True):\n    prompt_embeds_list = []\n\n    captions = []\n    for caption in prompt_batch:\n        if isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                
text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n            )\n\n            # We are always interested only in the pooled output of the final text encoder\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. 
If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n                token=args.hub_token,\n                private=True,\n            ).repo_id\n\n    # 1. 
Create the noise scheduler and the desired noise schedule.\n    noise_scheduler = DDPMScheduler.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"scheduler\", revision=args.teacher_revision\n    )\n\n    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us\n    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)\n    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)\n    # Initialize the DDIM ODE solver for distillation.\n    solver = DDIMSolver(\n        noise_scheduler.alphas_cumprod.numpy(),\n        timesteps=noise_scheduler.config.num_train_timesteps,\n        ddim_timesteps=args.num_ddim_timesteps,\n    )\n\n    # 2. Load tokenizers from SDXL checkpoint.\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"tokenizer\", revision=args.teacher_revision, use_fast=False\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"tokenizer_2\", revision=args.teacher_revision, use_fast=False\n    )\n\n    # 3. Load text encoders from SDXL checkpoint.\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_teacher_model, args.teacher_revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_teacher_model, args.teacher_revision, subfolder=\"text_encoder_2\"\n    )\n\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"text_encoder\", revision=args.teacher_revision\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"text_encoder_2\", revision=args.teacher_revision\n    )\n\n    # 4. 
Load VAE from SDXL checkpoint (or more stable VAE)\n    vae_path = (\n        args.pretrained_teacher_model\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.teacher_revision,\n    )\n\n    # 6. Freeze teacher vae, text_encoders.\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n\n    # 7. Create online student U-Net.\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"unet\", revision=args.teacher_revision\n    )\n    unet.requires_grad_(False)\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, a copy of the weights should still be float32.\"\n    )\n\n    if accelerator.unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(\n            f\"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}\"\n        )\n\n    # 8. 
Handle mixed precision and device placement\n    # For mixed precision training we cast all non-trainable weights to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    unet.to(accelerator.device, dtype=weight_dtype)\n    if args.pretrained_vae_model_name_or_path is None:\n        vae.to(accelerator.device, dtype=torch.float32)\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.\n    if args.lora_target_modules is not None:\n        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(\",\")]\n    else:\n        lora_target_modules = [\n            \"to_q\",\n            \"to_k\",\n            \"to_v\",\n            \"to_out.0\",\n            \"proj_in\",\n            \"proj_out\",\n            \"ff.net.0.proj\",\n            \"ff.net.2\",\n            \"conv1\",\n            \"conv2\",\n            \"conv_shortcut\",\n            \"downsamplers.0.conv\",\n            \"upsamplers.0.conv\",\n            \"time_emb_proj\",\n        ]\n    lora_config = LoraConfig(\n        r=args.lora_rank,\n        target_modules=lora_target_modules,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n    )\n    unet.add_adapter(lora_config)\n\n    # Also move the alpha and sigma noise schedules to accelerator.device.\n    alpha_schedule = 
alpha_schedule.to(accelerator.device)\n    sigma_schedule = sigma_schedule.to(accelerator.device)\n    solver = solver.to(accelerator.device)\n\n    # 10. Handle saving and loading of checkpoints\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                unet_ = accelerator.unwrap_model(unet)\n                # also save the checkpoints in native `diffusers` format so that it can be easily\n                # be independently loaded via `load_lora_weights()`.\n                state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))\n                StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)\n\n                for _, model in enumerate(models):\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            # load the LoRA into the model\n            unet_ = accelerator.unwrap_model(unet)\n            lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)\n            unet_state_dict = {\n                f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")\n            }\n            unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n            incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n            if incompatible_keys is not None:\n                # check only for unexpected keys\n                unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n                if unexpected_keys:\n                    
logger.warning(\n                        f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                        f\" {unexpected_keys}. \"\n                    )\n\n            for _ in range(len(models)):\n                # pop models so that they are not loaded again\n                models.pop()\n\n            # Make sure the trainable params are in float32. This is again needed since the base models\n            # are in `weight_dtype`. More details:\n            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n            if args.mixed_precision == \"fp16\":\n                cast_training_params(unet_, dtype=torch.float32)\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # 11. Enable optimizations\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # 12. Optimizer creation\n    params_to_optimize = filter(lambda p: p.requires_grad, unet.parameters())\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # 13. 
Dataset creation and data processing\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    column_names = dataset[\"train\"].column_names\n\n    # Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    interpolation_mode = resolve_interpolation_mode(args.interpolation_type)\n    
train_resize = transforms.Resize(args.resolution, interpolation=interpolation_mode)\n    train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        # image aug\n        original_sizes = []\n        all_images = []\n        crop_top_lefts = []\n        for image in images:\n            original_sizes.append((image.height, image.width))\n            image = train_resize(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                x1 = image.width - x1\n                image = train_flip(image)\n            crop_top_left = (y1, x1)\n            crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            all_images.append(image)\n\n        examples[\"original_sizes\"] = original_sizes\n        examples[\"crop_top_lefts\"] = crop_top_lefts\n        examples[\"pixel_values\"] = all_images\n        examples[\"captions\"] = list(examples[caption_column])\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = 
dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        original_sizes = [example[\"original_sizes\"] for example in examples]\n        crop_top_lefts = [example[\"crop_top_lefts\"] for example in examples]\n        captions = [example[\"captions\"] for example in examples]\n\n        return {\n            \"pixel_values\": pixel_values,\n            \"captions\": captions,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # 14. Embeddings for the UNet.\n    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n    def compute_embeddings(prompt_batch, original_sizes, crop_coords, text_encoders, tokenizers, is_train=True):\n        def compute_time_ids(original_size, crops_coords_top_left):\n            target_size = (args.resolution, args.resolution)\n            add_time_ids = list(original_size + crops_coords_top_left + target_size)\n            add_time_ids = torch.tensor([add_time_ids])\n            add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n            return add_time_ids\n\n        prompt_embeds, pooled_prompt_embeds = encode_prompt(prompt_batch, text_encoders, tokenizers, is_train)\n        add_text_embeds = pooled_prompt_embeds\n\n        add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)])\n\n        prompt_embeds = prompt_embeds.to(accelerator.device)\n        add_text_embeds = add_text_embeds.to(accelerator.device)\n        
unet_added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n        return {\"prompt_embeds\": prompt_embeds, **unet_added_cond_kwargs}\n\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n\n    compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers)\n\n    # 15. LR Scheduler creation\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(unet, dtype=torch.float32)\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n    )\n\n    # 16. 
Prepare for training\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # 17. Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    unet.train()\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # 1. 
Load and process the image and text conditioning\n                pixel_values, text, orig_size, crop_coords = (\n                    batch[\"pixel_values\"],\n                    batch[\"captions\"],\n                    batch[\"original_sizes\"],\n                    batch[\"crop_top_lefts\"],\n                )\n\n                encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)\n\n                # encode pixel values with batch size of at most args.vae_encode_batch_size\n                pixel_values = pixel_values.to(dtype=vae.dtype)\n                latents = []\n                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())\n                latents = torch.cat(latents, dim=0)\n\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n\n                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.\n                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]\n                bsz = latents.shape[0]\n                topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps\n                index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()\n                start_timesteps = solver.ddim_timesteps[index]\n                timesteps = start_timesteps - topk\n                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)\n\n                # 3. 
Get boundary scalings for start_timesteps and (end) timesteps.\n                c_skip_start, c_out_start = scalings_for_boundary_conditions(\n                    start_timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]\n                c_skip, c_out = scalings_for_boundary_conditions(\n                    timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]\n\n                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each\n                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]\n                noise = torch.randn_like(latents)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)\n\n                # 5. Sample a random guidance scale w from U[w_min, w_max]\n                # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding\n                w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min\n                w = w.reshape(bsz, 1, 1, 1)\n                w = w.to(device=latents.device, dtype=latents.dtype)\n\n                # 6. Prepare prompt embeds and unet_added_conditions\n                prompt_embeds = encoded_text.pop(\"prompt_embeds\")\n\n                # 7. 
Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)\n                noise_pred = unet(\n                    noisy_model_input,\n                    start_timesteps,\n                    encoder_hidden_states=prompt_embeds,\n                    added_cond_kwargs=encoded_text,\n                ).sample\n                pred_x_0 = get_predicted_original_sample(\n                    noise_pred,\n                    start_timesteps,\n                    noisy_model_input,\n                    noise_scheduler.config.prediction_type,\n                    alpha_schedule,\n                    sigma_schedule,\n                )\n                model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0\n\n                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the\n                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these\n                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE\n                # solver timestep.\n\n                # With the adapters disabled, the `unet` is the regular teacher model.\n                accelerator.unwrap_model(unet).disable_adapters()\n                with torch.no_grad():\n                    # 1. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c\n                    cond_teacher_output = unet(\n                        noisy_model_input,\n                        start_timesteps,\n                        encoder_hidden_states=prompt_embeds,\n                        added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},\n                    ).sample\n                    cond_pred_x0 = get_predicted_original_sample(\n                        cond_teacher_output,\n                        start_timesteps,\n                        noisy_model_input,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n                    cond_pred_noise = get_predicted_noise(\n                        cond_teacher_output,\n                        start_timesteps,\n                        noisy_model_input,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n\n                    # 2. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0\n                    uncond_prompt_embeds = torch.zeros_like(prompt_embeds)\n                    uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text[\"text_embeds\"])\n                    uncond_added_conditions = copy.deepcopy(encoded_text)\n                    uncond_added_conditions[\"text_embeds\"] = uncond_pooled_prompt_embeds\n                    uncond_teacher_output = unet(\n                        noisy_model_input,\n                        start_timesteps,\n                        encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),\n                        added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},\n                    ).sample\n                    uncond_pred_x0 = get_predicted_original_sample(\n                        uncond_teacher_output,\n                        start_timesteps,\n                        noisy_model_input,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n                    uncond_pred_noise = get_predicted_noise(\n                        uncond_teacher_output,\n                        start_timesteps,\n                        noisy_model_input,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n\n                    # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)\n                    # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation\n                    pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)\n                    pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)\n                    # 4. 
Run one step of the ODE solver to estimate the next point x_prev on the\n                    # augmented PF-ODE trajectory (solving backward in time)\n                    # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.\n                    x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)\n\n                # re-enable unet adapters to turn the `unet` into a student unet.\n                accelerator.unwrap_model(unet).enable_adapters()\n\n                # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)\n                # Note that we do not use a separate target network for LCM-LoRA distillation.\n                with torch.no_grad():\n                    target_noise_pred = unet(\n                        x_prev,\n                        timesteps,\n                        encoder_hidden_states=prompt_embeds,\n                        added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},\n                    ).sample\n                    pred_x_0 = get_predicted_original_sample(\n                        target_noise_pred,\n                        timesteps,\n                        x_prev,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n                    target = c_skip * x_prev + c_out * pred_x_0\n\n                # 10. Calculate loss\n                if args.loss_type == \"l2\":\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                elif args.loss_type == \"huber\":\n                    loss = torch.mean(\n                        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c\n                    )\n\n                # 11. 
Backpropagate on the online student model (`unet`) (only LoRA)\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                     
               removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if global_step % args.validation_steps == 0:\n                        log_validation(\n                            vae, args, accelerator, weight_dtype, global_step, unet=unet, is_final_validation=False\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n        StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        del unet\n        torch.cuda.empty_cache()\n\n        # Final inference.\n        if args.validation_steps is not None:\n            log_validation(vae, args, accelerator, weight_dtype, step=global_step, unet=None, is_final_validation=True)\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport functools\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import List, Union\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision.transforms.functional as TF\nimport transformers\nimport webdataset as wds\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom braceexpand import braceexpand\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, get_peft_model, get_peft_model_state_dict\nfrom torch.utils.data import default_collate\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\nfrom webdataset.tariterators import (\n    base_plus_ext,\n    tar_file_expander,\n    url_opener,\n    valid_sample,\n)\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    LCMScheduler,\n    StableDiffusionXLPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom 
diffusers.training_utils import resolve_interpolation_mode\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nMAX_SEQ_LENGTH = 77\n\n# Adjust for your dataset\nWDS_JSON_WIDTH = \"width\"  # original_width for LAION\nWDS_JSON_HEIGHT = \"height\"  # original_height for LAION\nMIN_SIZE = 700  # ~960 for LAION, ideal: 1024 if the dataset contains large images\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = \"default\"):\n    kohya_ss_state_dict = {}\n    for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():\n        kohya_key = peft_key.replace(\"base_model.model\", prefix)\n        kohya_key = kohya_key.replace(\"lora_A\", \"lora_down\")\n        kohya_key = kohya_key.replace(\"lora_B\", \"lora_up\")\n        kohya_key = kohya_key.replace(\".\", \"_\", kohya_key.count(\".\") - 2)\n        kohya_ss_state_dict[kohya_key] = weight.to(dtype)\n\n        # Set alpha parameter\n        if \"lora_down\" in kohya_key:\n            alpha_key = f\"{kohya_key.split('.')[0]}.alpha\"\n            kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)\n\n    return kohya_ss_state_dict\n\n\ndef filter_keys(key_set):\n    def _f(dictionary):\n        return {k: v for k, v in dictionary.items() if k in key_set}\n\n    return _f\n\n\ndef group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):\n    \"\"\"Return function over iterator that groups key, value pairs into samples.\n\n    :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to\n    lower case (Default value = 
True)\n    \"\"\"\n    current_sample = None\n    for filesample in data:\n        assert isinstance(filesample, dict)\n        fname, value = filesample[\"fname\"], filesample[\"data\"]\n        prefix, suffix = keys(fname)\n        if prefix is None:\n            continue\n        if lcase:\n            suffix = suffix.lower()\n        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for\n        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next\n        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset\n        if current_sample is None or prefix != current_sample[\"__key__\"] or suffix in current_sample:\n            if valid_sample(current_sample):\n                yield current_sample\n            current_sample = {\"__key__\": prefix, \"__url__\": filesample[\"__url__\"]}\n        if suffixes is None or suffix in suffixes:\n            current_sample[suffix] = value\n    if valid_sample(current_sample):\n        yield current_sample\n\n\ndef tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):\n    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw\n    streams = url_opener(src, handler=handler)\n    files = tar_file_expander(streams, handler=handler)\n    samples = group_by_keys_nothrow(files, handler=handler)\n    return samples\n\n\nclass WebdatasetFilter:\n    def __init__(self, min_size=1024, max_pwatermark=0.5):\n        self.min_size = min_size\n        self.max_pwatermark = max_pwatermark\n\n    def __call__(self, x):\n        try:\n            if \"json\" in x:\n                x_json = json.loads(x[\"json\"])\n                filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(\n                    WDS_JSON_HEIGHT, 0\n                ) >= self.min_size\n                filter_watermark = (x_json.get(\"pwatermark\", 0.0) or 0.0) <= 
self.max_pwatermark\n                return filter_size and filter_watermark\n            else:\n                return False\n        except Exception:\n            return False\n\n\nclass SDXLText2ImageDataset:\n    def __init__(\n        self,\n        train_shards_path_or_url: Union[str, List[str]],\n        num_train_examples: int,\n        per_gpu_batch_size: int,\n        global_batch_size: int,\n        num_workers: int,\n        resolution: int = 1024,\n        interpolation_type: str = \"bilinear\",\n        shuffle_buffer_size: int = 1000,\n        pin_memory: bool = False,\n        persistent_workers: bool = False,\n        use_fix_crop_and_size: bool = False,\n    ):\n        if not isinstance(train_shards_path_or_url, str):\n            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]\n            # flatten list using itertools\n            train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))\n\n        def get_orig_size(json):\n            if use_fix_crop_and_size:\n                return (resolution, resolution)\n            else:\n                return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))\n\n        interpolation_mode = resolve_interpolation_mode(interpolation_type)\n\n        def transform(example):\n            # resize image\n            image = example[\"image\"]\n            image = TF.resize(image, resolution, interpolation=interpolation_mode)\n\n            # get crop coordinates and crop image\n            c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))\n            image = TF.crop(image, c_top, c_left, resolution, resolution)\n            image = TF.to_tensor(image)\n            image = TF.normalize(image, [0.5], [0.5])\n\n            example[\"image\"] = image\n            example[\"crop_coords\"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)\n            return 
example\n\n        processing_pipeline = [\n            wds.decode(\"pil\", handler=wds.ignore_and_continue),\n            wds.rename(\n                image=\"jpg;png;jpeg;webp\", text=\"text;txt;caption\", orig_size=\"json\", handler=wds.warn_and_continue\n            ),\n            wds.map(filter_keys({\"image\", \"text\", \"orig_size\"})),\n            wds.map_dict(orig_size=get_orig_size),\n            wds.map(transform),\n            wds.to_tuple(\"image\", \"text\", \"orig_size\", \"crop_coords\"),\n        ]\n\n        # Create train dataset and loader\n        pipeline = [\n            wds.ResampledShards(train_shards_path_or_url),\n            tarfile_to_samples_nothrow,\n            wds.select(WebdatasetFilter(min_size=MIN_SIZE)),\n            wds.shuffle(shuffle_buffer_size),\n            *processing_pipeline,\n            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),\n        ]\n\n        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker\n        num_batches = num_worker_batches * num_workers\n        num_samples = num_batches * global_batch_size\n\n        # each worker is iterating over this\n        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)\n        self._train_dataloader = wds.WebLoader(\n            self._train_dataset,\n            batch_size=None,\n            shuffle=False,\n            num_workers=num_workers,\n            pin_memory=pin_memory,\n            persistent_workers=persistent_workers,\n        )\n        # add meta-data to dataloader instance for convenience\n        self._train_dataloader.num_batches = num_batches\n        self._train_dataloader.num_samples = num_samples\n\n    @property\n    def train_dataset(self):\n        return self._train_dataset\n\n    @property\n    def train_dataloader(self):\n        return self._train_dataloader\n\n\ndef log_validation(vae, unet, args, accelerator, 
weight_dtype, step):\n    logger.info(\"Running validation... \")\n    if torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)\n\n    unet = accelerator.unwrap_model(unet)\n    pipeline = StableDiffusionXLPipeline.from_pretrained(\n        args.pretrained_teacher_model,\n        vae=vae,\n        scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder=\"scheduler\"),\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    lora_state_dict = get_module_kohya_state_dict(unet, \"lora_unet\", weight_dtype)\n    pipeline.load_lora_weights(lora_state_dict)\n    pipeline.fuse_lora()\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    validation_prompts = [\n        \"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography\",\n        \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\",\n        \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n        \"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece\",\n    ]\n\n    image_logs = []\n\n    for _, prompt in enumerate(validation_prompts):\n        images = []\n        with autocast_ctx:\n            images = pipeline(\n                prompt=prompt,\n                num_inference_steps=4,\n                num_images_per_prompt=4,\n                generator=generator,\n                
guidance_scale=0.0,\n            ).images\n        image_logs.append({\"validation_prompt\": prompt, \"images\": images})\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                formatted_images = []\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({\"validation\": formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\n# From LCMScheduler.get_scalings_for_boundary_condition_discrete\ndef scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):\n    scaled_timestep = timestep_scaling * timestep\n    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)\n    c_out = scaled_timestep / 
(scaled_timestep**2 + sigma_data**2) ** 0.5\n    return c_skip, c_out\n\n\n# Compare LCMScheduler.step, Step 4\ndef get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_x_0 = (sample - sigmas * model_output) / alphas\n    elif prediction_type == \"sample\":\n        pred_x_0 = model_output\n    elif prediction_type == \"v_prediction\":\n        pred_x_0 = alphas * sample - sigmas * model_output\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_x_0\n\n\n# Based on step 4 in DDIMScheduler.step\ndef get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_epsilon = model_output\n    elif prediction_type == \"sample\":\n        pred_epsilon = (sample - alphas * model_output) / sigmas\n    elif prediction_type == \"v_prediction\":\n        pred_epsilon = alphas * model_output + sigmas * sample\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_epsilon\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\nclass DDIMSolver:\n    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):\n        # DDIM sampling parameters\n        step_ratio = timesteps // ddim_timesteps\n\n        
self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1\n        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]\n        self.ddim_alpha_cumprods_prev = np.asarray(\n            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()\n        )\n        # convert to torch tensors\n        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()\n        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)\n        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)\n\n    def to(self, device):\n        self.ddim_timesteps = self.ddim_timesteps.to(device)\n        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)\n        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)\n        return self\n\n    def ddim_step(self, pred_x0, pred_noise, timestep_index):\n        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)\n        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise\n        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt\n        return x_prev\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example 
of a training script.\")\n    # ----------Model Checkpoint Loading Arguments----------\n    parser.add_argument(\n        \"--pretrained_teacher_model\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained LDM teacher model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--teacher_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM teacher model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM model identifier from huggingface.co/models.\",\n    )\n    # ----------Training Arguments----------\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lcm-xl-distilled\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    # ----Logging----\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    # ----Checkpointing----\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    # ----Image Processing----\n    parser.add_argument(\n        \"--train_shards_path_or_url\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--interpolation_type\",\n        type=str,\n        default=\"bilinear\",\n        help=(\n            \"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,\"\n            \" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_fix_crop_and_size\",\n        action=\"store_true\",\n        help=\"Whether or not to use the fixed crop and size for the teacher model.\",\n        default=False,\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    # ----Dataloader----\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    # ----Batch Size and Training Steps----\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    # ----Learning Rate----\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    # ----Optimizer (Adam)----\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    # ----Diffusion Training Arguments----\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    # ----Latent Consistency Distillation (LCD) Specific Arguments----\n    parser.add_argument(\n        \"--w_min\",\n        type=float,\n        default=3.0,\n        required=False,\n        help=(\n            \"The minimum guidance scale value for guidance scale sampling. 
Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--w_max\",\n        type=float,\n        default=15.0,\n        required=False,\n        help=(\n            \"The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_ddim_timesteps\",\n        type=int,\n        default=50,\n        help=\"The number of timesteps to use for DDIM sampling.\",\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\"],\n        help=\"The type of loss to use for the LCD loss.\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.001,\n        help=\"The huber loss parameter. Only used if `--loss_type=huber`.\",\n    )\n    parser.add_argument(\n        \"--lora_rank\",\n        type=int,\n        default=64,\n        help=\"The rank of the LoRA projection matrix.\",\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=64,\n        help=(\n            \"The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight\"\n            \" update delta_W. 
No scaling will be performed if this value is equal to `lora_rank`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--lora_dropout\",\n        type=float,\n        default=0.0,\n        help=\"The dropout probability for the dropout layer added before applying the LoRA to each layer input.\",\n    )\n    parser.add_argument(\n        \"--lora_target_modules\",\n        type=str,\n        default=None,\n        help=(\n            \"A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will\"\n            \" be used. By default, LoRA will be applied to all conv and linear layers.\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=8,\n        required=False,\n        help=(\n            \"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE.\"\n            \" Encoding or decoding the whole batch at once may run into OOM issues.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_scaling_factor\",\n        type=float,\n        default=10.0,\n        help=(\n            \"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The\"\n            \" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically\"\n            \" suffice.\"\n        ),\n    )\n    # ----Mixed Precision----\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cast_teacher_unet\",\n        action=\"store_true\",\n        help=\"Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.\",\n    )\n    # ----Training Optimizations----\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    # ----Distributed Training----\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    # ----------Validation Arguments----------\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    # ----------Huggingface Hub Arguments-----------\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    # ----------Accelerate Arguments----------\n    parser.add_argument(\n        \"--tracker_project_name\",\n      
  type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    return args\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):\n    prompt_embeds_list = []\n\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n            )\n\n            # We are only ALWAYS interested in the pooled output of the final text encoder\n         
   pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. 
If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n                token=args.hub_token,\n                private=True,\n            ).repo_id\n\n    # 1. 
Create the noise scheduler and the desired noise schedule.\n    noise_scheduler = DDPMScheduler.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"scheduler\", revision=args.teacher_revision\n    )\n\n    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us\n    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)\n    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)\n    # Initialize the DDIM ODE solver for distillation.\n    solver = DDIMSolver(\n        noise_scheduler.alphas_cumprod.numpy(),\n        timesteps=noise_scheduler.config.num_train_timesteps,\n        ddim_timesteps=args.num_ddim_timesteps,\n    )\n\n    # 2. Load tokenizers from SD-XL checkpoint.\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"tokenizer\", revision=args.teacher_revision, use_fast=False\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"tokenizer_2\", revision=args.teacher_revision, use_fast=False\n    )\n\n    # 3. Load text encoders from SD-XL checkpoint.\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_teacher_model, args.teacher_revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_teacher_model, args.teacher_revision, subfolder=\"text_encoder_2\"\n    )\n\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"text_encoder\", revision=args.teacher_revision\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"text_encoder_2\", revision=args.teacher_revision\n    )\n\n    # 4. 
Load VAE from SD-XL checkpoint (or more stable VAE)\n    vae_path = (\n        args.pretrained_teacher_model\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.teacher_revision,\n    )\n\n    # 5. Load teacher U-Net from SD-XL checkpoint\n    teacher_unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"unet\", revision=args.teacher_revision\n    )\n\n    # 6. Freeze teacher vae, text_encoders, and teacher_unet\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    teacher_unet.requires_grad_(False)\n\n    # 7. Create online student U-Net.\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"unet\", revision=args.teacher_revision\n    )\n    unet.train()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if accelerator.unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(\n            f\"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}\"\n        )\n\n    # 8. 
Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.\n    if args.lora_target_modules is not None:\n        lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(\",\")]\n    else:\n        lora_target_modules = [\n            \"to_q\",\n            \"to_k\",\n            \"to_v\",\n            \"to_out.0\",\n            \"proj_in\",\n            \"proj_out\",\n            \"ff.net.0.proj\",\n            \"ff.net.2\",\n            \"conv1\",\n            \"conv2\",\n            \"conv_shortcut\",\n            \"downsamplers.0.conv\",\n            \"upsamplers.0.conv\",\n            \"time_emb_proj\",\n        ]\n    lora_config = LoraConfig(\n        r=args.lora_rank,\n        target_modules=lora_target_modules,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n    )\n    unet = get_peft_model(unet, lora_config)\n\n    # 9. Handle mixed precision and device placement\n    # For mixed precision training we cast all non-trainable weights to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    vae.to(accelerator.device)\n    if args.pretrained_vae_model_name_or_path is not None:\n        vae.to(dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # Move teacher_unet to device, optionally cast to weight_dtype\n    teacher_unet.to(accelerator.device)\n    if args.cast_teacher_unet:\n        teacher_unet.to(dtype=weight_dtype)\n\n    # Also move the alpha and 
sigma noise schedules to accelerator.device.\n    alpha_schedule = alpha_schedule.to(accelerator.device)\n    sigma_schedule = sigma_schedule.to(accelerator.device)\n    # Move the ODE solver to accelerator.device.\n    solver = solver.to(accelerator.device)\n\n    # 10. Handle saving and loading of checkpoints\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                unet_ = accelerator.unwrap_model(unet)\n                lora_state_dict = get_peft_model_state_dict(unet_, adapter_name=\"default\")\n                StableDiffusionXLPipeline.save_lora_weights(os.path.join(output_dir, \"unet_lora\"), lora_state_dict)\n                # save weights in peft format to be able to load them back\n                unet_.save_pretrained(output_dir)\n\n                for _, model in enumerate(models):\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            # load the LoRA into the model\n            unet_ = accelerator.unwrap_model(unet)\n            unet_.load_adapter(input_dir, \"default\", is_trainable=True)\n\n            for _ in range(len(models)):\n                # pop models so that they are not loaded again\n                models.pop()\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # 11. 
Enable optimizations\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            teacher_unet.enable_xformers_memory_efficient_attention()\n            # target_unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # 12. Optimizer creation\n    optimizer = optimizer_class(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # 13. 
Dataset creation and data processing\n    # Here, we compute not just the text embeddings but also the additional embeddings\n    # needed for the SD XL UNet to operate.\n    def compute_embeddings(\n        prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True\n    ):\n        target_size = (args.resolution, args.resolution)\n        original_sizes = list(map(list, zip(*original_sizes)))\n        crops_coords_top_left = list(map(list, zip(*crop_coords)))\n\n        original_sizes = torch.tensor(original_sizes, dtype=torch.long)\n        crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)\n\n        prompt_embeds, pooled_prompt_embeds = encode_prompt(\n            prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train\n        )\n        add_text_embeds = pooled_prompt_embeds\n\n        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        add_time_ids = list(target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n        add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)\n        add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)\n\n        prompt_embeds = prompt_embeds.to(accelerator.device)\n        add_text_embeds = add_text_embeds.to(accelerator.device)\n        unet_added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n        return {\"prompt_embeds\": prompt_embeds, **unet_added_cond_kwargs}\n\n    dataset = SDXLText2ImageDataset(\n        train_shards_path_or_url=args.train_shards_path_or_url,\n        num_train_examples=args.max_train_samples,\n        per_gpu_batch_size=args.train_batch_size,\n        global_batch_size=args.train_batch_size * accelerator.num_processes,\n        num_workers=args.dataloader_num_workers,\n        resolution=args.resolution,\n      
  interpolation_type=args.interpolation_type,\n        shuffle_buffer_size=1000,\n        pin_memory=True,\n        persistent_workers=True,\n        use_fix_crop_and_size=args.use_fix_crop_and_size,\n    )\n    train_dataloader = dataset.train_dataloader\n\n    # Let's first compute all the embeddings so that we can free up the text encoders\n    # from memory.\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n\n    compute_embeddings_fn = functools.partial(\n        compute_embeddings,\n        proportion_empty_prompts=0,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n    )\n\n    # 14. LR Scheduler creation\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps,\n        num_training_steps=args.max_train_steps,\n    )\n\n    # 15. 
Prepare for training\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Create uncond embeds for classifier free guidance\n    uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device)\n    uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device)\n\n    # 16. Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {train_dataloader.num_batches}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # 1. 
Load and process the image, text, and micro-conditioning (original image size, crop coordinates)\n                image, text, orig_size, crop_coords = batch\n\n                image = image.to(accelerator.device, non_blocking=True)\n                encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)\n\n                if args.pretrained_vae_model_name_or_path is not None:\n                    pixel_values = image.to(dtype=weight_dtype)\n                    if vae.dtype != weight_dtype:\n                        vae.to(dtype=weight_dtype)\n                else:\n                    pixel_values = image\n\n                # encode pixel values with batch size of at most args.vae_encode_batch_size\n                latents = []\n                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())\n                latents = torch.cat(latents, dim=0)\n\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n                bsz = latents.shape[0]\n\n                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.\n                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]\n                topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps\n                index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()\n                start_timesteps = solver.ddim_timesteps[index]\n                timesteps = start_timesteps - topk\n                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)\n\n                # 3. 
Get boundary scalings for start_timesteps and (end) timesteps.\n                c_skip_start, c_out_start = scalings_for_boundary_conditions(\n                    start_timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]\n                c_skip, c_out = scalings_for_boundary_conditions(\n                    timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]\n\n                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each\n                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]\n                noise = torch.randn_like(latents)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)\n\n                # 5. Sample a random guidance scale w from U[w_min, w_max]\n                # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding\n                w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min\n                w = w.reshape(bsz, 1, 1, 1)\n                w = w.to(device=latents.device, dtype=latents.dtype)\n\n                # 6. Prepare prompt embeds and unet_added_conditions\n                prompt_embeds = encoded_text.pop(\"prompt_embeds\")\n\n                # 7. 
Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)\n                noise_pred = unet(\n                    noisy_model_input,\n                    start_timesteps,\n                    timestep_cond=None,\n                    encoder_hidden_states=prompt_embeds.float(),\n                    added_cond_kwargs=encoded_text,\n                ).sample\n\n                pred_x_0 = get_predicted_original_sample(\n                    noise_pred,\n                    start_timesteps,\n                    noisy_model_input,\n                    noise_scheduler.config.prediction_type,\n                    alpha_schedule,\n                    sigma_schedule,\n                )\n\n                model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0\n\n                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the\n                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these\n                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE\n                # solver timestep.\n                with torch.no_grad():\n                    if torch.backends.mps.is_available() or \"playground\" in args.pretrained_teacher_model:\n                        autocast_ctx = nullcontext()\n                    else:\n                        autocast_ctx = torch.autocast(accelerator.device.type)\n\n                    with autocast_ctx:\n                        # 1. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c\n                        cond_teacher_output = teacher_unet(\n                            noisy_model_input.to(weight_dtype),\n                            start_timesteps,\n                            encoder_hidden_states=prompt_embeds.to(weight_dtype),\n                            added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},\n                        ).sample\n                        cond_pred_x0 = get_predicted_original_sample(\n                            cond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n                        cond_pred_noise = get_predicted_noise(\n                            cond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n\n                        # 2. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0\n                        uncond_added_conditions = copy.deepcopy(encoded_text)\n                        uncond_added_conditions[\"text_embeds\"] = uncond_pooled_prompt_embeds\n                        uncond_teacher_output = teacher_unet(\n                            noisy_model_input.to(weight_dtype),\n                            start_timesteps,\n                            encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),\n                            added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},\n                        ).sample\n                        uncond_pred_x0 = get_predicted_original_sample(\n                            uncond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n                        uncond_pred_noise = get_predicted_noise(\n                            uncond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n\n                        # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)\n                        # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation\n                        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)\n                        pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)\n                        # 4. 
Run one step of the ODE solver to estimate the next point x_prev on the\n                        # augmented PF-ODE trajectory (solving backward in time)\n                        # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.\n                        x_prev = solver.ddim_step(pred_x0, pred_noise, index)\n\n                # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)\n                # Note that we do not use a separate target network for LCM-LoRA distillation.\n                with torch.no_grad():\n                    if torch.backends.mps.is_available():\n                        autocast_ctx = nullcontext()\n                    else:\n                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)\n\n                    with autocast_ctx:\n                        target_noise_pred = unet(\n                            x_prev.float(),\n                            timesteps,\n                            timestep_cond=None,\n                            encoder_hidden_states=prompt_embeds.float(),\n                            added_cond_kwargs=encoded_text,\n                        ).sample\n                    pred_x_0 = get_predicted_original_sample(\n                        target_noise_pred,\n                        timesteps,\n                        x_prev,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n                    target = c_skip * x_prev + c_out * pred_x_0\n\n                # 10. 
Calculate loss\n                if args.loss_type == \"l2\":\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                elif args.loss_type == \"huber\":\n                    loss = torch.mean(\n                        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c\n                    )\n\n                # 11. Backpropagate on the online student model (`unet`)\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                
logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if global_step % args.validation_steps == 0:\n                        log_validation(vae, unet, args, accelerator, weight_dtype, global_step)\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet.save_pretrained(args.output_dir)\n        lora_state_dict = get_peft_model_state_dict(unet, adapter_name=\"default\")\n        StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, \"unet_lora\"), lora_state_dict)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == 
\"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/consistency_distillation/train_lcm_distill_sd_wds.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport functools\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import List, Union\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision.transforms.functional as TF\nimport transformers\nimport webdataset as wds\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom braceexpand import braceexpand\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torch.utils.data import default_collate\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig\nfrom webdataset.tariterators import (\n    base_plus_ext,\n    tar_file_expander,\n    url_opener,\n    valid_sample,\n)\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    LCMScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import resolve_interpolation_mode\nfrom 
diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nMAX_SEQ_LENGTH = 77\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef filter_keys(key_set):\n    def _f(dictionary):\n        return {k: v for k, v in dictionary.items() if k in key_set}\n\n    return _f\n\n\ndef group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):\n    \"\"\"Return function over iterator that groups key, value pairs into samples.\n\n    :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to\n    lower case (Default value = True)\n    \"\"\"\n    current_sample = None\n    for filesample in data:\n        assert isinstance(filesample, dict)\n        fname, value = filesample[\"fname\"], filesample[\"data\"]\n        prefix, suffix = keys(fname)\n        if prefix is None:\n            continue\n        if lcase:\n            suffix = suffix.lower()\n        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for\n        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next\n        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset\n        if current_sample is None or prefix != current_sample[\"__key__\"] or suffix in current_sample:\n            if valid_sample(current_sample):\n                yield current_sample\n            current_sample = {\"__key__\": prefix, \"__url__\": filesample[\"__url__\"]}\n        if suffixes is None or suffix in suffixes:\n            current_sample[suffix] = value\n    if valid_sample(current_sample):\n        yield current_sample\n\n\ndef tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):\n    # 
NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw\n    streams = url_opener(src, handler=handler)\n    files = tar_file_expander(streams, handler=handler)\n    samples = group_by_keys_nothrow(files, handler=handler)\n    return samples\n\n\nclass WebdatasetFilter:\n    def __init__(self, min_size=1024, max_pwatermark=0.5):\n        self.min_size = min_size\n        self.max_pwatermark = max_pwatermark\n\n    def __call__(self, x):\n        try:\n            if \"json\" in x:\n                x_json = json.loads(x[\"json\"])\n                filter_size = (x_json.get(\"original_width\", 0.0) or 0.0) >= self.min_size and x_json.get(\n                    \"original_height\", 0\n                ) >= self.min_size\n                filter_watermark = (x_json.get(\"pwatermark\", 1.0) or 1.0) <= self.max_pwatermark\n                return filter_size and filter_watermark\n            else:\n                return False\n        except Exception:\n            return False\n\n\nclass SDText2ImageDataset:\n    def __init__(\n        self,\n        train_shards_path_or_url: Union[str, List[str]],\n        num_train_examples: int,\n        per_gpu_batch_size: int,\n        global_batch_size: int,\n        num_workers: int,\n        resolution: int = 512,\n        interpolation_type: str = \"bilinear\",\n        shuffle_buffer_size: int = 1000,\n        pin_memory: bool = False,\n        persistent_workers: bool = False,\n    ):\n        if not isinstance(train_shards_path_or_url, str):\n            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]\n            # flatten list using itertools\n            train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))\n\n        interpolation_mode = resolve_interpolation_mode(interpolation_type)\n\n        def transform(example):\n            # resize image\n            image = example[\"image\"]\n            image = 
TF.resize(image, resolution, interpolation=interpolation_mode)\n\n            # get crop coordinates and crop image\n            c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))\n            image = TF.crop(image, c_top, c_left, resolution, resolution)\n            image = TF.to_tensor(image)\n            image = TF.normalize(image, [0.5], [0.5])\n\n            example[\"image\"] = image\n            return example\n\n        processing_pipeline = [\n            wds.decode(\"pil\", handler=wds.ignore_and_continue),\n            wds.rename(image=\"jpg;png;jpeg;webp\", text=\"text;txt;caption\", handler=wds.warn_and_continue),\n            wds.map(filter_keys({\"image\", \"text\"})),\n            wds.map(transform),\n            wds.to_tuple(\"image\", \"text\"),\n        ]\n\n        # Create train dataset and loader\n        pipeline = [\n            wds.ResampledShards(train_shards_path_or_url),\n            tarfile_to_samples_nothrow,\n            wds.shuffle(shuffle_buffer_size),\n            *processing_pipeline,\n            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),\n        ]\n\n        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker\n        num_batches = num_worker_batches * num_workers\n        num_samples = num_batches * global_batch_size\n\n        # each worker is iterating over this\n        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)\n        self._train_dataloader = wds.WebLoader(\n            self._train_dataset,\n            batch_size=None,\n            shuffle=False,\n            num_workers=num_workers,\n            pin_memory=pin_memory,\n            persistent_workers=persistent_workers,\n        )\n        # add meta-data to dataloader instance for convenience\n        self._train_dataloader.num_batches = num_batches\n        
self._train_dataloader.num_samples = num_samples\n\n    @property\n    def train_dataset(self):\n        return self._train_dataset\n\n    @property\n    def train_dataloader(self):\n        return self._train_dataloader\n\n\ndef log_validation(vae, unet, args, accelerator, weight_dtype, step, name=\"target\"):\n    logger.info(\"Running validation... \")\n\n    unet = accelerator.unwrap_model(unet)\n    pipeline = StableDiffusionPipeline.from_pretrained(\n        args.pretrained_teacher_model,\n        vae=vae,\n        unet=unet,\n        scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder=\"scheduler\"),\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    validation_prompts = [\n        \"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography\",\n        \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\",\n        \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n        \"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece\",\n    ]\n\n    image_logs = []\n\n    for _, prompt in enumerate(validation_prompts):\n        images = []\n        if torch.backends.mps.is_available():\n            autocast_ctx = nullcontext()\n        else:\n            autocast_ctx = torch.autocast(accelerator.device.type)\n\n        with autocast_ctx:\n            images = pipeline(\n                
prompt=prompt,\n                num_inference_steps=4,\n                num_images_per_prompt=4,\n                generator=generator,\n            ).images\n        image_logs.append({\"validation_prompt\": prompt, \"images\": images})\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                formatted_images = []\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({f\"validation/{name}\": formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    # Free the validation pipeline only after every tracker has logged its images.\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\n# From LatentConsistencyModel.get_guidance_scale_embedding\ndef guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):\n    \"\"\"\n    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n    Args:\n        w (`torch.Tensor`):\n            1-D tensor of guidance scales to generate embedding vectors for\n        embedding_dim (`int`, *optional*, defaults to 512):\n            dimension of the embeddings to generate\n        dtype:\n            data type of the generated 
embeddings\n\n    Returns:\n        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`\n    \"\"\"\n    assert len(w.shape) == 1\n    w = w * 1000.0\n\n    half_dim = embedding_dim // 2\n    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n    emb = w.to(dtype)[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:  # zero pad\n        emb = torch.nn.functional.pad(emb, (0, 1))\n    assert emb.shape == (w.shape[0], embedding_dim)\n    return emb\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\n# From LCMScheduler.get_scalings_for_boundary_condition_discrete\ndef scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):\n    scaled_timestep = timestep_scaling * timestep\n    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)\n    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5\n    return c_skip, c_out\n\n\n# Compare LCMScheduler.step, Step 4\ndef get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_x_0 = (sample - sigmas * model_output) / alphas\n    elif prediction_type == \"sample\":\n        pred_x_0 = model_output\n    elif prediction_type == \"v_prediction\":\n        pred_x_0 = alphas * sample - sigmas * model_output\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is 
not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_x_0\n\n\n# Based on step 4 in DDIMScheduler.step\ndef get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_epsilon = model_output\n    elif prediction_type == \"sample\":\n        pred_epsilon = (sample - alphas * model_output) / sigmas\n    elif prediction_type == \"v_prediction\":\n        pred_epsilon = alphas * model_output + sigmas * sample\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_epsilon\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\nclass DDIMSolver:\n    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):\n        # DDIM sampling parameters\n        step_ratio = timesteps // ddim_timesteps\n        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1\n        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]\n        self.ddim_alpha_cumprods_prev = np.asarray(\n            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()\n        )\n        # convert to torch tensors\n        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()\n        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)\n        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)\n\n    def to(self, device):\n        self.ddim_timesteps = self.ddim_timesteps.to(device)\n        
self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)\n        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)\n        return self\n\n    def ddim_step(self, pred_x0, pred_noise, timestep_index):\n        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)\n        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise\n        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt\n        return x_prev\n\n\n@torch.no_grad()\ndef update_ema(target_params, source_params, rate=0.99):\n    \"\"\"\n    Update target parameters to be closer to those of source parameters using\n    an exponential moving average.\n\n    :param target_params: the target parameter sequence.\n    :param source_params: the source parameter sequence.\n    :param rate: the EMA rate (closer to 1 means slower).\n    \"\"\"\n    for targ, src in zip(target_params, source_params):\n        targ.detach().mul_(rate).add_(src, alpha=1 - rate)\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    # ----------Model Checkpoint Loading Arguments----------\n    parser.add_argument(\n        \"--pretrained_teacher_model\",\n        type=str,\n        
default=None,\n        required=True,\n        help=\"Path to pretrained LDM teacher model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--teacher_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM teacher model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM model identifier from huggingface.co/models.\",\n    )\n    # ----------Training Arguments----------\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lcm-xl-distilled\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    # ----Logging----\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. 
Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    # ----Checkpointing----\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    # ----Image Processing----\n    parser.add_argument(\n        \"--train_shards_path_or_url\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--interpolation_type\",\n        type=str,\n        default=\"bilinear\",\n        help=(\n            \"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,\"\n            \" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    # ----Dataloader----\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    # ----Batch Size and Training Steps----\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    # ----Learning Rate----\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    # ----Optimizer (Adam)----\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    # ----Diffusion Training Arguments----\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    # ----Latent Consistency Distillation (LCD) Specific Arguments----\n    parser.add_argument(\n        \"--w_min\",\n        type=float,\n        default=5.0,\n        required=False,\n        help=(\n            \"The minimum guidance scale value for guidance scale sampling. 
Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--w_max\",\n        type=float,\n        default=15.0,\n        required=False,\n        help=(\n            \"The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_ddim_timesteps\",\n        type=int,\n        default=50,\n        help=\"The number of timesteps to use for DDIM sampling.\",\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\"],\n        help=\"The type of loss to use for the LCD loss.\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.001,\n        help=\"The huber loss parameter. 
Only used if `--loss_type=huber`.\",\n    )\n    parser.add_argument(\n        \"--unet_time_cond_proj_dim\",\n        type=int,\n        default=256,\n        help=(\n            \"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net\"\n            \" does not have `time_cond_proj_dim` set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=32,\n        required=False,\n        help=(\n            \"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE.\"\n            \" Encoding or decoding the whole batch at once may run into OOM issues.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_scaling_factor\",\n        type=float,\n        default=10.0,\n        help=(\n            \"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The\"\n            \" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically\"\n            \" suffice.\"\n        ),\n    )\n    # ----Exponential Moving Average (EMA)----\n    parser.add_argument(\n        \"--ema_decay\",\n        type=float,\n        default=0.95,\n        required=False,\n        help=\"The exponential moving average (EMA) rate or decay factor.\",\n    )\n    # ----Mixed Precision----\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cast_teacher_unet\",\n        action=\"store_true\",\n        help=\"Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.\",\n    )\n    # ----Training Optimizations----\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    # ----Distributed Training----\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    # ----------Validation Arguments----------\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    # ----------Huggingface Hub Arguments-----------\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    # ----------Accelerate Arguments----------\n    parser.add_argument(\n        \"--tracker_project_name\",\n      
  type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    return args\n\n\n# Adapted from pipelines.StableDiffusionPipeline.encode_prompt\ndef encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        text_inputs = tokenizer(\n            captions,\n            padding=\"max_length\",\n            max_length=tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]\n\n    return prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate 
with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n                token=args.hub_token,\n                private=True,\n            ).repo_id\n\n    # 1. 
Create the noise scheduler and the desired noise schedule.\n    noise_scheduler = DDPMScheduler.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"scheduler\", revision=args.teacher_revision\n    )\n\n    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us\n    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)\n    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)\n    # Initialize the DDIM ODE solver for distillation.\n    solver = DDIMSolver(\n        noise_scheduler.alphas_cumprod.numpy(),\n        timesteps=noise_scheduler.config.num_train_timesteps,\n        ddim_timesteps=args.num_ddim_timesteps,\n    )\n\n    # 2. Load tokenizers from SD 1.X/2.X checkpoint.\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"tokenizer\", revision=args.teacher_revision, use_fast=False\n    )\n\n    # 3. Load text encoders from SD 1.X/2.X checkpoint.\n    # import correct text encoder classes\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"text_encoder\", revision=args.teacher_revision\n    )\n\n    # 4. Load VAE from SD 1.X/2.X checkpoint\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_teacher_model,\n        subfolder=\"vae\",\n        revision=args.teacher_revision,\n    )\n\n    # 5. Load teacher U-Net from SD 1.X/2.X checkpoint\n    teacher_unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"unet\", revision=args.teacher_revision\n    )\n\n    # 6. Freeze teacher vae, text_encoder, and teacher_unet\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    teacher_unet.requires_grad_(False)\n\n    # 7. Create online student U-Net. This will be updated by the optimizer (e.g. 
via backpropagation.)\n    # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None\n    time_cond_proj_dim = (\n        teacher_unet.config.time_cond_proj_dim\n        if teacher_unet.config.time_cond_proj_dim is not None\n        else args.unet_time_cond_proj_dim\n    )\n    unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)\n    # load teacher_unet weights into unet\n    unet.load_state_dict(teacher_unet.state_dict(), strict=False)\n    unet.train()\n\n    # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).\n    # Initialize from (online) unet\n    target_unet = UNet2DConditionModel.from_config(unet.config)\n    target_unet.load_state_dict(unet.state_dict())\n    target_unet.train()\n    target_unet.requires_grad_(False)\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if accelerator.unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(\n            f\"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}\"\n        )\n\n    # 9. 
Handle mixed precision and device placement\n    # For mixed precision training we cast all non-trainable weights to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    vae.to(accelerator.device)\n    if args.pretrained_vae_model_name_or_path is not None:\n        vae.to(dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Move teacher_unet to device, optionally cast to weight_dtype\n    target_unet.to(accelerator.device)\n    teacher_unet.to(accelerator.device)\n    if args.cast_teacher_unet:\n        teacher_unet.to(dtype=weight_dtype)\n\n    # Also move the alpha and sigma noise schedules to accelerator.device.\n    alpha_schedule = alpha_schedule.to(accelerator.device)\n    sigma_schedule = sigma_schedule.to(accelerator.device)\n    solver = solver.to(accelerator.device)\n\n    # 10. 
Handle saving and loading of checkpoints\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                target_unet.save_pretrained(os.path.join(output_dir, \"unet_target\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, \"unet_target\"))\n            target_unet.load_state_dict(load_model.state_dict())\n            target_unet.to(accelerator.device)\n            del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # 11. 
Enable optimizations\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            teacher_unet.enable_xformers_memory_efficient_attention()\n            target_unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # 12. Optimizer creation\n    optimizer = optimizer_class(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # 13. 
Dataset creation and data processing\n    # Here, we compute the prompt embeddings that condition the SD 1.X/2.X U-Net.\n    def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):\n        prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)\n        return {\"prompt_embeds\": prompt_embeds}\n\n    dataset = SDText2ImageDataset(\n        train_shards_path_or_url=args.train_shards_path_or_url,\n        num_train_examples=args.max_train_samples,\n        per_gpu_batch_size=args.train_batch_size,\n        global_batch_size=args.train_batch_size * accelerator.num_processes,\n        num_workers=args.dataloader_num_workers,\n        resolution=args.resolution,\n        interpolation_type=args.interpolation_type,\n        shuffle_buffer_size=1000,\n        pin_memory=True,\n        persistent_workers=True,\n    )\n    train_dataloader = dataset.train_dataloader\n\n    compute_embeddings_fn = functools.partial(\n        compute_embeddings,\n        proportion_empty_prompts=0,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n    )\n\n    # 14. LR Scheduler creation\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps,\n        num_training_steps=args.max_train_steps,\n    )\n\n    # 15. 
Prepare for training\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    uncond_input_ids = tokenizer(\n        [\"\"] * args.train_batch_size, return_tensors=\"pt\", padding=\"max_length\", max_length=77\n    ).input_ids.to(accelerator.device)\n    uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]\n\n    # 16. Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {train_dataloader.num_batches}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # 1. 
Load and process the image and text conditioning\n                image, text = batch\n\n                image = image.to(accelerator.device, non_blocking=True)\n                encoded_text = compute_embeddings_fn(text)\n\n                pixel_values = image.to(dtype=weight_dtype)\n                if vae.dtype != weight_dtype:\n                    vae.to(dtype=weight_dtype)\n\n                # encode pixel values with batch size of at most args.vae_encode_batch_size\n                latents = []\n                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())\n                latents = torch.cat(latents, dim=0)\n\n                latents = latents * vae.config.scaling_factor\n                latents = latents.to(weight_dtype)\n                bsz = latents.shape[0]\n\n                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.\n                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]\n                topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps\n                index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()\n                start_timesteps = solver.ddim_timesteps[index]\n                timesteps = start_timesteps - topk\n                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)\n\n                # 3. 
Get boundary scalings for start_timesteps and (end) timesteps.\n                c_skip_start, c_out_start = scalings_for_boundary_conditions(\n                    start_timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]\n                c_skip, c_out = scalings_for_boundary_conditions(\n                    timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]\n\n                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each\n                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]\n                noise = torch.randn_like(latents)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)\n\n                # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it\n                w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min\n                w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)\n                w = w.reshape(bsz, 1, 1, 1)\n                # Move to U-Net device and dtype\n                w = w.to(device=latents.device, dtype=latents.dtype)\n                w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)\n\n                # 6. Prepare prompt embeds and unet_added_conditions\n                prompt_embeds = encoded_text.pop(\"prompt_embeds\")\n\n                # 7. 
Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)\n                noise_pred = unet(\n                    noisy_model_input,\n                    start_timesteps,\n                    timestep_cond=w_embedding,\n                    encoder_hidden_states=prompt_embeds.float(),\n                    added_cond_kwargs=encoded_text,\n                ).sample\n\n                pred_x_0 = get_predicted_original_sample(\n                    noise_pred,\n                    start_timesteps,\n                    noisy_model_input,\n                    noise_scheduler.config.prediction_type,\n                    alpha_schedule,\n                    sigma_schedule,\n                )\n\n                model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0\n\n                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the\n                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these\n                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE\n                # solver timestep.\n                with torch.no_grad():\n                    if torch.backends.mps.is_available():\n                        autocast_ctx = nullcontext()\n                    else:\n                        autocast_ctx = torch.autocast(accelerator.device.type)\n\n                    with autocast_ctx:\n                        # 1. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c\n                        cond_teacher_output = teacher_unet(\n                            noisy_model_input.to(weight_dtype),\n                            start_timesteps,\n                            encoder_hidden_states=prompt_embeds.to(weight_dtype),\n                        ).sample\n                        cond_pred_x0 = get_predicted_original_sample(\n                            cond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n                        cond_pred_noise = get_predicted_noise(\n                            cond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n\n                        # 2. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0\n                        uncond_teacher_output = teacher_unet(\n                            noisy_model_input.to(weight_dtype),\n                            start_timesteps,\n                            encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),\n                        ).sample\n                        uncond_pred_x0 = get_predicted_original_sample(\n                            uncond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n                        uncond_pred_noise = get_predicted_noise(\n                            uncond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n\n                        # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)\n                        # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation\n                        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)\n                        pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)\n                        # 4. Run one step of the ODE solver to estimate the next point x_prev on the\n                        # augmented PF-ODE trajectory (solving backward in time)\n                        # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.\n                        x_prev = solver.ddim_step(pred_x0, pred_noise, index)\n\n                # 9. 
Get target LCM prediction on x_prev, w, c, t_n (timesteps)\n                with torch.no_grad():\n                    if torch.backends.mps.is_available():\n                        autocast_ctx = nullcontext()\n                    else:\n                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)\n\n                    with autocast_ctx:\n                        target_noise_pred = target_unet(\n                            x_prev.float(),\n                            timesteps,\n                            timestep_cond=w_embedding,\n                            encoder_hidden_states=prompt_embeds.float(),\n                        ).sample\n                    pred_x_0 = get_predicted_original_sample(\n                        target_noise_pred,\n                        timesteps,\n                        x_prev,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n                    target = c_skip * x_prev + c_out * pred_x_0\n\n                # 10. Calculate loss\n                if args.loss_type == \"l2\":\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                elif args.loss_type == \"huber\":\n                    loss = torch.mean(\n                        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c\n                    )\n\n                # 11. 
Backpropagate on the online student model (`unet`)\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                # 12. Make EMA update to target student model parameters (`target_unet`)\n                update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                
logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if global_step % args.validation_steps == 0:\n                        log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, \"target\")\n                        log_validation(vae, unet, args, accelerator, weight_dtype, global_step, \"online\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Save the trained online and target U-Nets.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet.save_pretrained(os.path.join(args.output_dir, \"unet\"))\n\n        target_unet = accelerator.unwrap_model(target_unet)\n        target_unet.save_pretrained(os.path.join(args.output_dir, \"unet_target\"))\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/consistency_distillation/train_lcm_distill_sdxl_wds.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport functools\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import List, Union\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision.transforms.functional as TF\nimport transformers\nimport webdataset as wds\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom braceexpand import braceexpand\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torch.utils.data import default_collate\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\nfrom webdataset.tariterators import (\n    base_plus_ext,\n    tar_file_expander,\n    url_opener,\n    valid_sample,\n)\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    LCMScheduler,\n    StableDiffusionXLPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import resolve_interpolation_mode\nfrom 
diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nMAX_SEQ_LENGTH = 77\n\n# Adjust for your dataset\nWDS_JSON_WIDTH = \"width\"  # original_width for LAION\nWDS_JSON_HEIGHT = \"height\"  # original_height for LAION\nMIN_SIZE = 700  # ~960 for LAION, ideal: 1024 if the dataset contains large images\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef filter_keys(key_set):\n    def _f(dictionary):\n        return {k: v for k, v in dictionary.items() if k in key_set}\n\n    return _f\n\n\ndef group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):\n    \"\"\"Return function over iterator that groups key, value pairs into samples.\n\n    :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to\n    lower case (Default value = True)\n    \"\"\"\n    current_sample = None\n    for filesample in data:\n        assert isinstance(filesample, dict)\n        fname, value = filesample[\"fname\"], filesample[\"data\"]\n        prefix, suffix = keys(fname)\n        if prefix is None:\n            continue\n        if lcase:\n            suffix = suffix.lower()\n        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for\n        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next\n        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset\n        if current_sample is None or prefix != current_sample[\"__key__\"] or suffix in current_sample:\n            if valid_sample(current_sample):\n                yield current_sample\n            current_sample = {\"__key__\": prefix, \"__url__\": filesample[\"__url__\"]}\n        if 
suffixes is None or suffix in suffixes:\n            current_sample[suffix] = value\n    if valid_sample(current_sample):\n        yield current_sample\n\n\ndef tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):\n    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw\n    streams = url_opener(src, handler=handler)\n    files = tar_file_expander(streams, handler=handler)\n    samples = group_by_keys_nothrow(files, handler=handler)\n    return samples\n\n\nclass WebdatasetFilter:\n    def __init__(self, min_size=1024, max_pwatermark=0.5):\n        self.min_size = min_size\n        self.max_pwatermark = max_pwatermark\n\n    def __call__(self, x):\n        try:\n            if \"json\" in x:\n                x_json = json.loads(x[\"json\"])\n                filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get(\n                    WDS_JSON_HEIGHT, 0\n                ) >= self.min_size\n                filter_watermark = (x_json.get(\"pwatermark\", 0.0) or 0.0) <= self.max_pwatermark\n                return filter_size and filter_watermark\n            else:\n                return False\n        except Exception:\n            return False\n\n\nclass SDXLText2ImageDataset:\n    def __init__(\n        self,\n        train_shards_path_or_url: Union[str, List[str]],\n        num_train_examples: int,\n        per_gpu_batch_size: int,\n        global_batch_size: int,\n        num_workers: int,\n        resolution: int = 1024,\n        interpolation_type: str = \"bilinear\",\n        shuffle_buffer_size: int = 1000,\n        pin_memory: bool = False,\n        persistent_workers: bool = False,\n        use_fix_crop_and_size: bool = False,\n    ):\n        if not isinstance(train_shards_path_or_url, str):\n            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]\n            # flatten list using itertools\n            train_shards_path_or_url = 
list(itertools.chain.from_iterable(train_shards_path_or_url))\n\n        def get_orig_size(json):\n            if use_fix_crop_and_size:\n                return (resolution, resolution)\n            else:\n                return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0)))\n\n        interpolation_mode = resolve_interpolation_mode(interpolation_type)\n\n        def transform(example):\n            # resize image\n            image = example[\"image\"]\n            image = TF.resize(image, resolution, interpolation=interpolation_mode)\n\n            # get crop coordinates and crop image\n            c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))\n            image = TF.crop(image, c_top, c_left, resolution, resolution)\n            image = TF.to_tensor(image)\n            image = TF.normalize(image, [0.5], [0.5])\n\n            example[\"image\"] = image\n            example[\"crop_coords\"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)\n            return example\n\n        processing_pipeline = [\n            wds.decode(\"pil\", handler=wds.ignore_and_continue),\n            wds.rename(\n                image=\"jpg;png;jpeg;webp\", text=\"text;txt;caption\", orig_size=\"json\", handler=wds.warn_and_continue\n            ),\n            wds.map(filter_keys({\"image\", \"text\", \"orig_size\"})),\n            wds.map_dict(orig_size=get_orig_size),\n            wds.map(transform),\n            wds.to_tuple(\"image\", \"text\", \"orig_size\", \"crop_coords\"),\n        ]\n\n        # Create train dataset and loader\n        pipeline = [\n            wds.ResampledShards(train_shards_path_or_url),\n            tarfile_to_samples_nothrow,\n            wds.select(WebdatasetFilter(min_size=MIN_SIZE)),\n            wds.shuffle(shuffle_buffer_size),\n            *processing_pipeline,\n            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),\n        
]\n\n        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker\n        num_batches = num_worker_batches * num_workers\n        num_samples = num_batches * global_batch_size\n\n        # each worker is iterating over this\n        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)\n        self._train_dataloader = wds.WebLoader(\n            self._train_dataset,\n            batch_size=None,\n            shuffle=False,\n            num_workers=num_workers,\n            pin_memory=pin_memory,\n            persistent_workers=persistent_workers,\n        )\n        # add meta-data to dataloader instance for convenience\n        self._train_dataloader.num_batches = num_batches\n        self._train_dataloader.num_samples = num_samples\n\n    @property\n    def train_dataset(self):\n        return self._train_dataset\n\n    @property\n    def train_dataloader(self):\n        return self._train_dataloader\n\n\ndef log_validation(vae, unet, args, accelerator, weight_dtype, step, name=\"target\"):\n    logger.info(\"Running validation... 
\")\n\n    unet = accelerator.unwrap_model(unet)\n    pipeline = StableDiffusionXLPipeline.from_pretrained(\n        args.pretrained_teacher_model,\n        vae=vae,\n        unet=unet,\n        scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder=\"scheduler\"),\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    validation_prompts = [\n        \"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography\",\n        \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\",\n        \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n        \"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece\",\n    ]\n\n    image_logs = []\n\n    for _, prompt in enumerate(validation_prompts):\n        images = []\n        if torch.backends.mps.is_available():\n            autocast_ctx = nullcontext()\n        else:\n            autocast_ctx = torch.autocast(accelerator.device.type)\n\n        with autocast_ctx:\n            images = pipeline(\n                prompt=prompt,\n                num_inference_steps=4,\n                num_images_per_prompt=4,\n                generator=generator,\n            ).images\n        image_logs.append({\"validation_prompt\": prompt, \"images\": images})\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in 
image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                formatted_images = []\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({f\"validation/{name}\": formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    # Free the validation pipeline only after every tracker has logged its images.\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\n# From LCMScheduler.get_scalings_for_boundary_condition_discrete\ndef scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):\n    scaled_timestep = timestep_scaling * timestep\n    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)\n    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5\n    return c_skip, c_out\n\n\n# Compare LCMScheduler.step, Step 4\ndef get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = 
extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_x_0 = (sample - sigmas * model_output) / alphas\n    elif prediction_type == \"sample\":\n        pred_x_0 = model_output\n    elif prediction_type == \"v_prediction\":\n        pred_x_0 = alphas * sample - sigmas * model_output\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_x_0\n\n\n# Based on step 4 in DDIMScheduler.step\ndef get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):\n    alphas = extract_into_tensor(alphas, timesteps, sample.shape)\n    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)\n    if prediction_type == \"epsilon\":\n        pred_epsilon = model_output\n    elif prediction_type == \"sample\":\n        pred_epsilon = (sample - alphas * model_output) / sigmas\n    elif prediction_type == \"v_prediction\":\n        pred_epsilon = alphas * model_output + sigmas * sample\n    else:\n        raise ValueError(\n            f\"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`\"\n            f\" are supported.\"\n        )\n\n    return pred_epsilon\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\n@torch.no_grad()\ndef update_ema(target_params, source_params, rate=0.99):\n    \"\"\"\n    Update target parameters to be closer to those of source parameters using\n    an exponential moving average.\n\n    :param target_params: the target parameter sequence.\n    :param source_params: the source parameter sequence.\n    :param rate: the EMA rate (closer to 1 means slower).\n    \"\"\"\n    for targ, src in 
zip(target_params, source_params):\n        targ.detach().mul_(rate).add_(src, alpha=1 - rate)\n\n\n# From LatentConsistencyModel.get_guidance_scale_embedding\ndef guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):\n    \"\"\"\n    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n    Args:\n        w (`torch.Tensor`):\n            1-D tensor of guidance scale values to embed\n        embedding_dim (`int`, *optional*, defaults to 512):\n            dimension of the embeddings to generate\n        dtype:\n            data type of the generated embeddings\n\n    Returns:\n        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`\n    \"\"\"\n    assert len(w.shape) == 1\n    w = w * 1000.0\n\n    half_dim = embedding_dim // 2\n    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n    emb = w.to(dtype)[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:  # zero pad\n        emb = torch.nn.functional.pad(emb, (0, 1))\n    assert emb.shape == (w.shape[0], embedding_dim)\n    return emb\n\n\nclass DDIMSolver:\n    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):\n        # DDIM sampling parameters\n        step_ratio = timesteps // ddim_timesteps\n\n        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1\n        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]\n        self.ddim_alpha_cumprods_prev = np.asarray(\n            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()\n        )\n        # convert to torch tensors\n        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()\n        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)\n        self.ddim_alpha_cumprods_prev = 
torch.from_numpy(self.ddim_alpha_cumprods_prev)\n\n    def to(self, device):\n        self.ddim_timesteps = self.ddim_timesteps.to(device)\n        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)\n        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)\n        return self\n\n    def ddim_step(self, pred_x0, pred_noise, timestep_index):\n        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)\n        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise\n        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt\n        return x_prev\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    # ----------Model Checkpoint Loading Arguments----------\n    parser.add_argument(\n        \"--pretrained_teacher_model\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained LDM teacher model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--teacher_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM teacher model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained LDM model identifier from huggingface.co/models.\",\n    )\n    # ----------Training Arguments----------\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lcm-xl-distilled\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    # ----Logging----\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    # ----Checkpointing----\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    # ----Image Processing----\n    parser.add_argument(\n        \"--train_shards_path_or_url\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--interpolation_type\",\n        type=str,\n        default=\"bilinear\",\n        help=(\n            \"The interpolation function used when resizing images to the desired resolution. 
Choose between `bilinear`,\"\n            \" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_fix_crop_and_size\",\n        action=\"store_true\",\n        help=\"Whether or not to use the fixed crop and size for the teacher model.\",\n        default=False,\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    # ----Dataloader----\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    # ----Batch Size and Training Steps----\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    # ----Learning Rate----\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    # ----Optimizer (Adam)----\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    
parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    # ----Diffusion Training Arguments----\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    # ----Latent Consistency Distillation (LCD) Specific Arguments----\n    parser.add_argument(\n        \"--w_min\",\n        type=float,\n        default=3.0,\n        required=False,\n        help=(\n            \"The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--w_max\",\n        type=float,\n        default=15.0,\n        required=False,\n        help=(\n            \"The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG\"\n            \" formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as\"\n            \" compared to the original paper.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_ddim_timesteps\",\n        type=int,\n        default=50,\n        help=\"The number of timesteps to use for DDIM sampling.\",\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\"],\n        help=\"The type of loss to use for the LCD loss.\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.001,\n        help=\"The huber loss parameter. 
Only used if `--loss_type=huber`.\",\n    )\n    parser.add_argument(\n        \"--unet_time_cond_proj_dim\",\n        type=int,\n        default=256,\n        help=(\n            \"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net\"\n            \" does not have `time_cond_proj_dim` set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=8,\n        required=False,\n        help=(\n            \"The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE.\"\n            \" Encoding or decoding the whole batch at once may run into OOM issues.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_scaling_factor\",\n        type=float,\n        default=10.0,\n        help=(\n            \"The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The\"\n            \" higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically\"\n            \" suffice.\"\n        ),\n    )\n    # ----Exponential Moving Average (EMA)----\n    parser.add_argument(\n        \"--ema_decay\",\n        type=float,\n        default=0.95,\n        required=False,\n        help=\"The exponential moving average (EMA) rate or decay factor.\",\n    )\n    # ----Mixed Precision----\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cast_teacher_unet\",\n        action=\"store_true\",\n        help=\"Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.\",\n    )\n    # ----Training Optimizations----\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    # ----Distributed Training----\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    # ----------Validation Arguments----------\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    # ----------Huggingface Hub Arguments-----------\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    # ----------Accelerate Arguments----------\n    parser.add_argument(\n        \"--tracker_project_name\",\n      
  type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    return args\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):\n    prompt_embeds_list = []\n\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n            )\n\n            # We are only ALWAYS interested in the pooled output of the final text encoder\n         
   pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. 
If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n                token=args.hub_token,\n                private=True,\n            ).repo_id\n\n    # 1. 
Create the noise scheduler and the desired noise schedule.\n    noise_scheduler = DDPMScheduler.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"scheduler\", revision=args.teacher_revision\n    )\n\n    # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us\n    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)\n    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)\n    # Initialize the DDIM ODE solver for distillation.\n    solver = DDIMSolver(\n        noise_scheduler.alphas_cumprod.numpy(),\n        timesteps=noise_scheduler.config.num_train_timesteps,\n        ddim_timesteps=args.num_ddim_timesteps,\n    )\n\n    # 2. Load tokenizers from SD-XL checkpoint.\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"tokenizer\", revision=args.teacher_revision, use_fast=False\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"tokenizer_2\", revision=args.teacher_revision, use_fast=False\n    )\n\n    # 3. Load text encoders from SD-XL checkpoint.\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_teacher_model, args.teacher_revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_teacher_model, args.teacher_revision, subfolder=\"text_encoder_2\"\n    )\n\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"text_encoder\", revision=args.teacher_revision\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"text_encoder_2\", revision=args.teacher_revision\n    )\n\n    # 4. 
Load VAE from SD-XL checkpoint (or more stable VAE)\n    vae_path = (\n        args.pretrained_teacher_model\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.teacher_revision,\n    )\n\n    # 5. Load teacher U-Net from SD-XL checkpoint\n    teacher_unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_teacher_model, subfolder=\"unet\", revision=args.teacher_revision\n    )\n\n    # 6. Freeze teacher vae, text_encoders, and teacher_unet\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    teacher_unet.requires_grad_(False)\n\n    # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)\n    # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None\n    time_cond_proj_dim = (\n        teacher_unet.config.time_cond_proj_dim\n        if teacher_unet.config.time_cond_proj_dim is not None\n        else args.unet_time_cond_proj_dim\n    )\n    unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)\n    # load teacher_unet weights into unet\n    unet.load_state_dict(teacher_unet.state_dict(), strict=False)\n    unet.train()\n\n    # 8. Create target student U-Net. 
This will be updated via EMA updates (polyak averaging).\n    # Initialize from (online) unet\n    target_unet = UNet2DConditionModel.from_config(unet.config)\n    target_unet.load_state_dict(unet.state_dict())\n    target_unet.train()\n    target_unet.requires_grad_(False)\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if accelerator.unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(\n            f\"U-Net loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}\"\n        )\n\n    # 9. Handle mixed precision and device placement\n    # For mixed precision training we cast all non-trainable weights to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    vae.to(accelerator.device)\n    if args.pretrained_vae_model_name_or_path is not None:\n        vae.to(dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n    target_unet.to(accelerator.device)\n    # Move teacher_unet to device, optionally cast to weight_dtype\n    teacher_unet.to(accelerator.device)\n    if args.cast_teacher_unet:\n        teacher_unet.to(dtype=weight_dtype)\n\n    # Also move the alpha and sigma noise schedules to accelerator.device.\n    alpha_schedule = 
alpha_schedule.to(accelerator.device)\n    sigma_schedule = sigma_schedule.to(accelerator.device)\n    # Move the ODE solver to accelerator.device.\n    solver = solver.to(accelerator.device)\n\n    # 10. Handle saving and loading of checkpoints\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                target_unet.save_pretrained(os.path.join(output_dir, \"unet_target\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, \"unet_target\"))\n            target_unet.load_state_dict(load_model.state_dict())\n            target_unet.to(accelerator.device)\n            del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # 11. 
Enable optimizations\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            teacher_unet.enable_xformers_memory_efficient_attention()\n            target_unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # 12. Optimizer creation\n    optimizer = optimizer_class(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # 13. 
Dataset creation and data processing\n    # Here, we compute not just the text embeddings but also the additional embeddings\n    # needed for the SD XL UNet to operate.\n    def compute_embeddings(\n        prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True\n    ):\n        target_size = (args.resolution, args.resolution)\n        original_sizes = list(map(list, zip(*original_sizes)))\n        crops_coords_top_left = list(map(list, zip(*crop_coords)))\n\n        original_sizes = torch.tensor(original_sizes, dtype=torch.long)\n        crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)\n\n        prompt_embeds, pooled_prompt_embeds = encode_prompt(\n            prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train\n        )\n        add_text_embeds = pooled_prompt_embeds\n\n        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        add_time_ids = list(target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n        add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)\n        add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)\n\n        prompt_embeds = prompt_embeds.to(accelerator.device)\n        add_text_embeds = add_text_embeds.to(accelerator.device)\n        unet_added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n        return {\"prompt_embeds\": prompt_embeds, **unet_added_cond_kwargs}\n\n    dataset = SDXLText2ImageDataset(\n        train_shards_path_or_url=args.train_shards_path_or_url,\n        num_train_examples=args.max_train_samples,\n        per_gpu_batch_size=args.train_batch_size,\n        global_batch_size=args.train_batch_size * accelerator.num_processes,\n        num_workers=args.dataloader_num_workers,\n        resolution=args.resolution,\n      
  interpolation_type=args.interpolation_type,\n        shuffle_buffer_size=1000,\n        pin_memory=True,\n        persistent_workers=True,\n        use_fix_crop_and_size=args.use_fix_crop_and_size,\n    )\n    train_dataloader = dataset.train_dataloader\n\n    # Let's first compute all the embeddings so that we can free up the text encoders\n    # from memory.\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n\n    compute_embeddings_fn = functools.partial(\n        compute_embeddings,\n        proportion_empty_prompts=0,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n    )\n\n    # 14. LR Scheduler creation\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps,\n        num_training_steps=args.max_train_steps,\n    )\n\n    # 15. 
Prepare for training\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Create uncond embeds for classifier free guidance\n    uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device)\n    uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device)\n\n    # 16. Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {train_dataloader.num_batches}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # 1. 
Load and process the image, text, and micro-conditioning (original image size, crop coordinates)\n                image, text, orig_size, crop_coords = batch\n\n                image = image.to(accelerator.device, non_blocking=True)\n                encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)\n\n                if args.pretrained_vae_model_name_or_path is not None:\n                    pixel_values = image.to(dtype=weight_dtype)\n                    if vae.dtype != weight_dtype:\n                        vae.to(dtype=weight_dtype)\n                else:\n                    pixel_values = image\n\n                # encode pixel values with batch size of at most args.vae_encode_batch_size\n                latents = []\n                for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample())\n                latents = torch.cat(latents, dim=0)\n\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n                bsz = latents.shape[0]\n\n                # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.\n                # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]\n                topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps\n                index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()\n                start_timesteps = solver.ddim_timesteps[index]\n                timesteps = start_timesteps - topk\n                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)\n\n                # 3. 
Get boundary scalings for start_timesteps and (end) timesteps.\n                c_skip_start, c_out_start = scalings_for_boundary_conditions(\n                    start_timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]\n                c_skip, c_out = scalings_for_boundary_conditions(\n                    timesteps, timestep_scaling=args.timestep_scaling_factor\n                )\n                c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]\n\n                # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each\n                # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]\n                noise = torch.randn_like(latents)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)\n\n                # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it\n                w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min\n                w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)\n                w = w.reshape(bsz, 1, 1, 1)\n                # Move to U-Net device and dtype\n                w = w.to(device=latents.device, dtype=latents.dtype)\n                w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)\n\n                # 6. Prepare prompt embeds and unet_added_conditions\n                prompt_embeds = encoded_text.pop(\"prompt_embeds\")\n\n                # 7. 
Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)\n                noise_pred = unet(\n                    noisy_model_input,\n                    start_timesteps,\n                    timestep_cond=w_embedding,\n                    encoder_hidden_states=prompt_embeds.float(),\n                    added_cond_kwargs=encoded_text,\n                ).sample\n\n                pred_x_0 = get_predicted_original_sample(\n                    noise_pred,\n                    start_timesteps,\n                    noisy_model_input,\n                    noise_scheduler.config.prediction_type,\n                    alpha_schedule,\n                    sigma_schedule,\n                )\n\n                model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0\n\n                # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the\n                # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these\n                # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE\n                # solver timestep.\n                with torch.no_grad():\n                    if torch.backends.mps.is_available():\n                        autocast_ctx = nullcontext()\n                    else:\n                        autocast_ctx = torch.autocast(accelerator.device.type)\n\n                    with autocast_ctx:\n                        # 1. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c\n                        cond_teacher_output = teacher_unet(\n                            noisy_model_input.to(weight_dtype),\n                            start_timesteps,\n                            encoder_hidden_states=prompt_embeds.to(weight_dtype),\n                            added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},\n                        ).sample\n                        cond_pred_x0 = get_predicted_original_sample(\n                            cond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n                        cond_pred_noise = get_predicted_noise(\n                            cond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n\n                        # 2. 
Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0\n                        uncond_added_conditions = copy.deepcopy(encoded_text)\n                        uncond_added_conditions[\"text_embeds\"] = uncond_pooled_prompt_embeds\n                        uncond_teacher_output = teacher_unet(\n                            noisy_model_input.to(weight_dtype),\n                            start_timesteps,\n                            encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),\n                            added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},\n                        ).sample\n                        uncond_pred_x0 = get_predicted_original_sample(\n                            uncond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n                        uncond_pred_noise = get_predicted_noise(\n                            uncond_teacher_output,\n                            start_timesteps,\n                            noisy_model_input,\n                            noise_scheduler.config.prediction_type,\n                            alpha_schedule,\n                            sigma_schedule,\n                        )\n\n                        # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)\n                        # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation\n                        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)\n                        pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)\n                        # 4. 
Run one step of the ODE solver to estimate the next point x_prev on the\n                        # augmented PF-ODE trajectory (solving backward in time)\n                        # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.\n                        x_prev = solver.ddim_step(pred_x0, pred_noise, index)\n\n                # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)\n                with torch.no_grad():\n                    if torch.backends.mps.is_available():\n                        autocast_ctx = nullcontext()\n                    else:\n                        autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)\n\n                    with autocast_ctx:\n                        target_noise_pred = target_unet(\n                            x_prev.float(),\n                            timesteps,\n                            timestep_cond=w_embedding,\n                            encoder_hidden_states=prompt_embeds.float(),\n                            added_cond_kwargs=encoded_text,\n                        ).sample\n                    pred_x_0 = get_predicted_original_sample(\n                        target_noise_pred,\n                        timesteps,\n                        x_prev,\n                        noise_scheduler.config.prediction_type,\n                        alpha_schedule,\n                        sigma_schedule,\n                    )\n                    target = c_skip * x_prev + c_out * pred_x_0\n\n                # 10. Calculate loss\n                if args.loss_type == \"l2\":\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                elif args.loss_type == \"huber\":\n                    loss = torch.mean(\n                        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c\n                    )\n\n                # 11. 
Backpropagate on the online student model (`unet`)\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                # 12. Make EMA update to target student model parameters (`target_unet`)\n                update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                
logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if global_step % args.validation_steps == 0:\n                        log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, \"target\")\n                        log_validation(vae, unet, args, accelerator, weight_dtype, global_step, \"online\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet.save_pretrained(os.path.join(args.output_dir, \"unet\"))\n\n        target_unet = accelerator.unwrap_model(target_unet)\n        target_unet.save_pretrained(os.path.join(args.output_dir, \"unet_target\"))\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/controlnet/README.md",
    "content": "# ControlNet training example\n\n[Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala.\n\nThis example is based on the [training example in the original ControlNet repository](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md). It trains a ControlNet to fill circles using a [small synthetic dataset](https://huggingface.co/datasets/fusing/fill50k).\n\n## Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the example folder and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\n## Circle filling dataset\n\nThe original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). 
Note that `datasets` handles dataloading within the training script.\n\nOur training examples use [Stable Diffusion 1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) because the original set of ControlNet models was trained from it. However, ControlNet can be trained to augment any Stable Diffusion-compatible model (such as [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)).\n\n## Training\n\nOur training examples use two test conditioning images. They can be downloaded by running\n\n```sh\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\n\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=4\n```\n\nThis default configuration requires ~38 GB VRAM.\n\nBy default, the training script logs outputs to TensorBoard. 
Pass `--report_to wandb` to use weights and\nbiases.\n\nGradient accumulation with a smaller batch size can be used to reduce training requirements to ~20 GB VRAM.\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4\n```\n\n## Training with multiple GPUs\n\n`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)\nfor running distributed training with `accelerate`. Here is an example command:\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=4 \\\n --mixed_precision=\"fp16\" \\\n --tracker_project_name=\"controlnet-demo\" \\\n --report_to=wandb\n```\n\n## Example results\n\n#### After 300 steps with batch size 8\n\n| |  |\n|-------------------|:-------------------------:|\n| | red circle with blue background  |\n![conditioning 
image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![red circle with blue background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_300_steps.png) |\n| | cyan circle with brown floral background |\n![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![cyan circle with brown floral background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_300_steps.png) |\n\n\n#### After 6000 steps with batch size 8\n\n| |  |\n|-------------------|:-------------------------:|\n| | red circle with blue background  |\n![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![red circle with blue background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/red_circle_with_blue_background_6000_steps.png) |\n| | cyan circle with brown floral background |\n![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png) | ![cyan circle with brown floral background](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/cyan_circle_with_brown_floral_background_6000_steps.png) |\n\n## Training on a 16 GB GPU\n\nOptimizations:\n- Gradient checkpointing\n- bitsandbytes' 8-bit optimizer\n\n[bitsandbytes install instructions](https://github.com/TimDettmers/bitsandbytes#requirements--installation).\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save 
model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --gradient_checkpointing \\\n --use_8bit_adam\n```\n\n## Training on a 12 GB GPU\n\nOptimizations:\n- Gradient checkpointing\n- bitsandbytes' 8-bit optimizer\n- xformers\n- set grads to none\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --gradient_checkpointing \\\n --use_8bit_adam \\\n --enable_xformers_memory_efficient_attention \\\n --set_grads_to_none\n```\n\nWhen using `enable_xformers_memory_efficient_attention`, please make sure to install `xformers` with `pip install xformers`.\n\n## Training on an 8 GB GPU\n\nWe have not exhaustively tested DeepSpeed support for ControlNet. While the configuration does\nsave memory, we have not confirmed that training succeeds with it. 
You will very likely\nhave to make changes to the config to have a successful training run.\n\nOptimizations:\n- Gradient checkpointing\n- xformers\n- set grads to none\n- DeepSpeed stage 2 with parameter and optimizer offloading\n- fp16 mixed precision\n\n[DeepSpeed](https://www.deepspeed.ai/) can offload tensors from VRAM to either\nCPU or NVME. This requires significantly more RAM (about 25 GB).\n\nUse `accelerate config` to enable DeepSpeed stage 2.\n\nThe relevant parts of the resulting accelerate config file are\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndeepspeed_config:\n  gradient_accumulation_steps: 4\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\n```\n\nSee [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.\n\nChanging the default Adam optimizer to DeepSpeed's Adam\n`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but\nit requires CUDA toolchain with the same version as pytorch. 
8-bit optimizer\ndoes not seem to be compatible with DeepSpeed at the moment.\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --gradient_checkpointing \\\n --enable_xformers_memory_efficient_attention \\\n --set_grads_to_none \\\n --mixed_precision fp16\n```\n\n## Performing inference with the trained ControlNet\n\nThe trained model can be run the same as the original ControlNet pipeline with the newly trained ControlNet.\nSet `base_model_path` and `controlnet_path` to the values `--pretrained_model_name_or_path` and\n`--output_dir` were respectively set to in the training script.\n\n```py\nfrom diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler\nfrom diffusers.utils import load_image\nimport torch\n\nbase_model_path = \"path to model\"\ncontrolnet_path = \"path to controlnet\"\n\ncontrolnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)\npipe = StableDiffusionControlNetPipeline.from_pretrained(\n    base_model_path, controlnet=controlnet, torch_dtype=torch.float16\n)\n\n# speed up diffusion process with faster scheduler and memory optimization\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n# remove following line if xformers is not installed or when using Torch 2.0.\npipe.enable_xformers_memory_efficient_attention()\n# memory optimization.\npipe.enable_model_cpu_offload()\n\ncontrol_image = load_image(\"./conditioning_image_1.png\")\nprompt = \"pale golden rod 
circle with old lace background\"\n\n# generate image\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt, num_inference_steps=20, generator=generator, image=control_image\n).images[0]\nimage.save(\"./output.png\")\n```\n\n## Training with Flax/JAX\n\nFor faster training on TPUs and GPUs you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.\n\n### Running on Google Cloud TPU\n\nSee below for commands to set up a TPU VM (`--accelerator-type v4-8`). For more details about how to set up and use TPUs, refer to [Cloud docs for single VM setup](https://cloud.google.com/tpu/docs/run-calculation-jax).\n\nFirst create a single TPUv4-8 VM and connect to it:\n\n```\nZONE=us-central2-b\nTPU_TYPE=v4-8\nVM_NAME=hg_flax\n\ngcloud alpha compute tpus tpu-vm create $VM_NAME \\\n --zone $ZONE \\\n --accelerator-type $TPU_TYPE \\\n --version  tpu-vm-v4-base\n\ngcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \\\n```\n\nOnce connected, install JAX `0.4.5`:\n\n```sh\npip install \"jax[tpu]==0.4.5\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\n```\n\nTo verify that JAX was correctly installed, you can run the following command:\n\n```py\nimport jax\njax.device_count()\n```\n\nThis should display the number of TPU cores, which should be 4 on a TPUv4-8 VM.\n\nThen install Diffusers and the library's training dependencies:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd into the example folder and run\n\n```bash\npip install -U -r requirements_flax.txt\n```\n\nIf you want to use Weights and Biases logging, you should also install `wandb` now\n\n```bash\npip install wandb\n```\n\n\nNow let's download two conditioning images that we will use to run validation during training to track our progress\n\n```sh\nwget 
https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\nWe encourage you to store or share your model with the community. To use the Hugging Face Hub, please log in to your Hugging Face account (or [create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don’t have one already):\n\n```sh\nhf auth login\n```\n\nMake sure you have the `MODEL_DIR`, `OUTPUT_DIR`, and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"runs/fill-circle-{timestamp}\"\nexport HUB_MODEL_ID=\"controlnet-fill-circle\"\n```\n\nAnd finally start the training\n\n```bash\npython3 train_controlnet_flax.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --validation_steps=1000 \\\n --train_batch_size=2 \\\n --revision=\"non-ema\" \\\n --from_pt \\\n --report_to=\"wandb\" \\\n --tracker_project_name=$HUB_MODEL_ID \\\n --num_train_epochs=11 \\\n --push_to_hub \\\n --hub_model_id=$HUB_MODEL_ID\n ```\n\nSince we passed the `--push_to_hub` flag, it will automatically create a model repo under your Hugging Face account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the Hub. 
You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).\n\nOur training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`.  Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):\n\n```bash\nexport MODEL_DIR=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport OUTPUT_DIR=\"runs/uncanny-faces-{timestamp}\"\nexport HUB_MODEL_ID=\"controlnet-uncanny-faces\"\n\npython3 train_controlnet_flax.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=multimodalart/facesyntheticsspigacaptioned \\\n --streaming \\\n --conditioning_image_column=spiga_seg \\\n --image_column=image \\\n --caption_column=image_caption \\\n --resolution=512 \\\n --max_train_samples 100000 \\\n --learning_rate=1e-5 \\\n --train_batch_size=1 \\\n --revision=\"flax\" \\\n --report_to=\"wandb\" \\\n --tracker_project_name=$HUB_MODEL_ID\n```\n\nNote, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:\n\n* [Webdataset](https://webdataset.github.io/webdataset/)\n* [TorchData](https://github.com/pytorch/data)\n* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)\n\nWhen working with a larger dataset, you may need to run the training process for a long time, and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:\n\n```bash\n --checkpointing_steps=500\n```\nThis will save the trained model in subfolders of your output_dir. 
Each subfolder is named after the number of steps performed so far; for example, a checkpoint saved after 500 training steps would be saved in a subfolder named 500.\n\nYou can then start your training from this saved checkpoint with\n\n```bash\n --controlnet_model_name_or_path=\"./control_out/500\"\n```\n\nWe support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://huggingface.co/papers/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.\n\nWe also support gradient accumulation, a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use the `--gradient_accumulation_steps` argument to set the number of gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).\n\nYou can **profile your code** with:\n\n```bash\n --profile_steps=5\n```\n\nRefer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:\n\n```bash\npip install tensorflow tensorboard-plugin-profile\ntensorboard --logdir runs/fill-circle-100steps-20230411_165612/\n```\n\nThe profile can then be inspected at http://localhost:6006/#profile\n\nSometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. 
with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).\n\nNote that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).\n\n## Support for Stable Diffusion XL\n\nWe provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details.\n"
  },
  {
    "path": "examples/controlnet/README_flux.md",
    "content": "# ControlNet training example for FLUX\n\nThe `train_controlnet_flux.py` script shows how to implement the ControlNet training procedure and adapt it for [FLUX](https://github.com/black-forest-labs/flux).\n\nThe training script is provided by LibAI, an institution dedicated to the progress and achievement of artificial general intelligence. LibAI is the developer of [cutout.pro](https://www.cutout.pro/) and [promeai.pro](https://www.promeai.pro/).\n> [!NOTE]\n> **Memory consumption**\n>\n> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.\n\nHere is the GPU memory consumption at each stage for reference, tested on a single A100 with 80G.\n\n| Stage | GPU memory |\n| - | - |\n| load as float32 | ~70G |\n| move transformer and VAE to bf16 | ~48G |\n| precompute text embeddings | ~62G |\n| **offload text encoder to CPU** | ~30G |\n| training | ~58G |\n| validation | ~71G |\n\n\n> **Gated access**\n>\n> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: `hf auth login`\n\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd into the `examples/controlnet` folder and run\n```bash\npip install -r requirements_flux.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\n\n## Custom Datasets\n\nWe support two dataset formats: datasets loaded with `datasets` and local jsonl files.\n\nThe original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. To use our example, add `--dataset_name=fusing/fill50k \\` to the script and remove the `--jsonl_for_train` line mentioned below.\n\n\nWe also support importing data from a jsonl file (`xxx.jsonl`); use `--jsonl_for_train` to enable it. Here is a brief example of a jsonl file:\n```sh\n{\"image\": \"xxx\", \"text\": \"xxx\", \"conditioning_image\": \"xxx\"}\n{\"image\": \"xxx\", \"text\": \"xxx\", \"conditioning_image\": \"xxx\"}\n```\n\n## Training\n\nOur training examples use two test conditioning images. 
They can be downloaded by running\n\n```sh\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\nThen run `hf auth login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.\n\nWe can define `num_layers` and `num_single_layers` (set through `--num_double_layers` and `--num_single_layers` in the command below), which determine the size of the ControlNet (default values are `num_layers=4`, `num_single_layers=10`).\n\n\n```bash\naccelerate launch train_controlnet_flux.py \\\n    --pretrained_model_name_or_path=\"black-forest-labs/FLUX.1-dev\" \\\n    --dataset_name=fusing/fill50k \\\n    --conditioning_image_column=conditioning_image \\\n    --image_column=image \\\n    --caption_column=text \\\n    --output_dir=\"path to save model\" \\\n    --mixed_precision=\"bf16\" \\\n    --resolution=512 \\\n    --learning_rate=1e-5 \\\n    --max_train_steps=15000 \\\n    --validation_steps=100 \\\n    --checkpointing_steps=200 \\\n    --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n    --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n    --train_batch_size=1 \\\n    --gradient_accumulation_steps=16 \\\n    --report_to=\"wandb\" \\\n    --lr_scheduler=\"cosine\" \\\n    --num_double_layers=4 \\\n    --num_single_layers=0 \\\n    --seed=42 \\\n    --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on Weights and Biases.\n* `validation_image`, `validation_prompt`, and `validation_steps` allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\nOur experiments were conducted on a single 80GB A100 GPU.\n\n### Inference\n\nOnce training is done, we can perform inference like so:\n\n```python\nimport torch\nfrom diffusers.utils import load_image\nfrom diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline\nfrom diffusers.models.controlnet_flux import FluxControlNetModel\n\nbase_model = 'black-forest-labs/FLUX.1-dev'\ncontrolnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'\ncontrolnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)\npipe = FluxControlNetPipeline.from_pretrained(\n    base_model,\n    controlnet=controlnet,\n    torch_dtype=torch.bfloat16\n)\n# enable memory optimizations\npipe.enable_model_cpu_offload()\n\ncontrol_image = load_image(\"https://huggingface.co/promeai/FLUX.1-controlnet-lineart-promeai/resolve/main/images/example-control.jpg\").resize((1024, 1024))\nprompt = \"cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere\"\n\nimage = pipe(\n    prompt,\n    control_image=control_image,\n    controlnet_conditioning_scale=0.6,\n    num_inference_steps=28,\n    guidance_scale=3.5,\n).images[0]\nimage.save(\"./output.png\")\n```\n\n## Apply Deepspeed Zero3\n\nThis process is experimental and may not suit every setup. We used it to successfully train at 512 resolution on 8 A100 (40GB) GPUs.\nIt requires modifying some of the code in the script.\n### 1.Customize zero3 settings\n\nCopy the **accelerate_config_zero3.yaml** and modify `num_processes` according to the number of GPUs you want to 
use:\n\n```bash\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 8\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: true\n  zero3_save_16bit_model: true\n  zero_stage: 3\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 8\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n```\n\n### 2.Precompute all inputs (latent, embeddings)\n\nIn the train_controlnet_flux.py, We need to pre-calculate all parameters and put them into batches.So we first need to rewrite the `compute_embeddings` function. \n\n```python\ndef compute_embeddings(batch, proportion_empty_prompts, vae, flux_controlnet_pipeline, weight_dtype, is_train=True):\n    \n    ### compute text embeddings\n    prompt_batch = batch[args.caption_column]\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n    prompt_batch = captions\n    prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(\n        prompt_batch, prompt_2=prompt_batch\n    )\n    prompt_embeds = prompt_embeds.to(dtype=weight_dtype)\n    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)\n    text_ids = text_ids.to(dtype=weight_dtype)\n\n    # text_ids [512,3] to [bs,512,3]\n    text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)\n\n    ### compute latents\n    def _pack_latents(latents, batch_size, num_channels_latents, height, width):\n        
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)\n        latents = latents.permute(0, 2, 4, 1, 3, 5)\n        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)\n        return latents\n\n    # vae encode\n    pixel_values = batch[\"pixel_values\"]\n    pixel_values = torch.stack([image for image in pixel_values]).to(dtype=weight_dtype).to(vae.device)\n    pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()\n    pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor\n    pixel_latents = _pack_latents(\n        pixel_latents_tmp,\n        pixel_values.shape[0],\n        pixel_latents_tmp.shape[1],\n        pixel_latents_tmp.shape[2],\n        pixel_latents_tmp.shape[3],\n    ) \n\n    control_values = batch[\"conditioning_pixel_values\"]\n    control_values = torch.stack([image for image in control_values]).to(dtype=weight_dtype).to(vae.device)\n    control_latents = vae.encode(control_values).latent_dist.sample()\n    control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor\n    control_latents = _pack_latents(\n        control_latents,\n        control_values.shape[0],\n        control_latents.shape[1],\n        control_latents.shape[2],\n        control_latents.shape[3],\n    )\n\n    # copied from pipeline_flux_controlnet\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\n        latent_image_ids = torch.zeros(height // 2, width // 2, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)\n        latent_image_ids = 
latent_image_ids.reshape(\n            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n\n        return latent_image_ids.to(device=device, dtype=dtype)\n    latent_image_ids = _prepare_latent_image_ids(\n        batch_size=pixel_latents_tmp.shape[0],\n        height=pixel_latents_tmp.shape[2],\n        width=pixel_latents_tmp.shape[3],\n        device=pixel_values.device,\n        dtype=pixel_values.dtype,\n    )\n\n    # unet_added_cond_kwargs = {\"pooled_prompt_embeds\": pooled_prompt_embeds, \"text_ids\": text_ids}\n    return {\"prompt_embeds\": prompt_embeds, \"pooled_prompt_embeds\": pooled_prompt_embeds, \"text_ids\": text_ids, \"pixel_latents\": pixel_latents, \"control_latents\": control_latents, \"latent_image_ids\": latent_image_ids}\n```\n\nBecause we need images to pass through vae, we need to preprocess the images in the dataset first. At the same time, vae requires more gpu memory, so you may need to modify the `batch_size` below\n```diff\n+train_dataset = prepare_train_dataset(train_dataset, accelerator)\nwith accelerator.main_process_first():\n    from datasets.fingerprint import Hasher\n\n    # fingerprint used by the cache for the other processes to load the result\n    # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401\n    new_fingerprint = Hasher.hash(args)\n    train_dataset = train_dataset.map(\n-        compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=100\n+        compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=10\n    )\n\ndel text_encoders, tokenizers\ngc.collect()\ntorch.cuda.empty_cache()\n\n# Then get the training dataset ready to be passed to the dataloader.\n-train_dataset = prepare_train_dataset(train_dataset, accelerator)\n```\n### 3.Redefine the behavior of getting batchsize\n\nNow that we have all the preprocessing done, we need to modify the `collate_fn` function.\n\n```python\ndef 
collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    pixel_latents = torch.stack([torch.tensor(example[\"pixel_latents\"]) for example in examples])\n    pixel_latents = pixel_latents.to(memory_format=torch.contiguous_format).float()\n\n    control_latents = torch.stack([torch.tensor(example[\"control_latents\"]) for example in examples])\n    control_latents = control_latents.to(memory_format=torch.contiguous_format).float()\n    \n    latent_image_ids= torch.stack([torch.tensor(example[\"latent_image_ids\"]) for example in examples])\n    \n    prompt_ids = torch.stack([torch.tensor(example[\"prompt_embeds\"]) for example in examples])\n\n    pooled_prompt_embeds = torch.stack([torch.tensor(example[\"pooled_prompt_embeds\"]) for example in examples])\n    text_ids = torch.stack([torch.tensor(example[\"text_ids\"]) for example in examples])\n\n    return {\n        \"pixel_values\": pixel_values,\n        \"conditioning_pixel_values\": conditioning_pixel_values,\n        \"pixel_latents\": pixel_latents,\n        \"control_latents\": control_latents,\n        \"latent_image_ids\": latent_image_ids,\n        \"prompt_ids\": prompt_ids,\n        \"unet_added_conditions\": {\"pooled_prompt_embeds\": pooled_prompt_embeds, \"time_ids\": text_ids},\n    }\n```\nFinally, we just need to modify the way of obtaining various parameters during training.\n```python\nfor epoch in range(first_epoch, args.num_train_epochs):\n    for step, batch in enumerate(train_dataloader):\n        with accelerator.accumulate(flux_controlnet):\n            # Convert images to latent space\n            pixel_latents = 
batch[\"pixel_latents\"].to(dtype=weight_dtype)\n            control_image = batch[\"control_latents\"].to(dtype=weight_dtype)\n            latent_image_ids = batch[\"latent_image_ids\"].to(dtype=weight_dtype)\n\n            # Sample noise that we'll add to the latents\n            noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype)\n            bsz = pixel_latents.shape[0]\n\n            # Sample a random timestep for each image\n            t = torch.sigmoid(torch.randn((bsz,), device=accelerator.device, dtype=weight_dtype))\n\n            # apply flow matching\n            noisy_latents = (\n                1 - t.unsqueeze(1).unsqueeze(2).repeat(1, pixel_latents.shape[1], pixel_latents.shape[2])\n            ) * pixel_latents + t.unsqueeze(1).unsqueeze(2).repeat(\n                1, pixel_latents.shape[1], pixel_latents.shape[2]\n            ) * noise\n\n            guidance_vec = torch.full(\n                (noisy_latents.shape[0],), 3.5, device=noisy_latents.device, dtype=weight_dtype\n            )\n\n            controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(\n                hidden_states=noisy_latents,\n                controlnet_cond=control_image,\n                timestep=t,\n                guidance=guidance_vec,\n                pooled_projections=batch[\"unet_added_conditions\"][\"pooled_prompt_embeds\"].to(dtype=weight_dtype),\n                encoder_hidden_states=batch[\"prompt_ids\"].to(dtype=weight_dtype),\n                txt_ids=batch[\"unet_added_conditions\"][\"time_ids\"][0].to(dtype=weight_dtype),\n                img_ids=latent_image_ids[0],\n                return_dict=False,\n            )\n\n            noise_pred = flux_transformer(\n                hidden_states=noisy_latents,\n                timestep=t,\n                guidance=guidance_vec,\n                pooled_projections=batch[\"unet_added_conditions\"][\"pooled_prompt_embeds\"].to(dtype=weight_dtype),\n         
       encoder_hidden_states=batch[\"prompt_ids\"].to(dtype=weight_dtype),\n                controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples]\n                if controlnet_block_samples is not None\n                else None,\n                controlnet_single_block_samples=[\n                    sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples\n                ]\n                if controlnet_single_block_samples is not None\n                else None,\n                txt_ids=batch[\"unet_added_conditions\"][\"time_ids\"][0].to(dtype=weight_dtype),\n                img_ids=latent_image_ids[0],\n                return_dict=False,\n            )[0]\n```\nCongratulations! You have completed all the required code modifications required for deepspeedzero3.\n\n### 4.Training with deepspeedzero3\n\nStart!!!\n\n```bash\nexport pretrained_model_name_or_path='flux-dev-model-path'\nexport MODEL_TYPE='train_model_type'\nexport TRAIN_JSON_FILE=\"your_json_file\"\nexport CONTROL_TYPE='control_preprocessor_type'\nexport CAPTION_COLUMN='caption_column'\n\nexport CACHE_DIR=\"/data/train_csr/.cache/huggingface/\"\nexport OUTPUT_DIR='/data/train_csr/FLUX/MODEL_OUT/'$MODEL_TYPE\n# The first step is to use Python to precompute all caches.Replace the first line below with this line. 
(I am not sure why using accelerate would cause problems.)\n\nCUDA_VISIBLE_DEVICES=0 python3 train_controlnet_flux.py \\\n\n# The second step is to use the above accelerate config to train\naccelerate  launch  --config_file \"./accelerate_config_zero3.yaml\" train_controlnet_flux.py \\\n    --pretrained_model_name_or_path=$pretrained_model_name_or_path \\\n    --jsonl_for_train=$TRAIN_JSON_FILE \\\n    --conditioning_image_column=$CONTROL_TYPE \\\n    --image_column=image \\\n    --caption_column=$CAPTION_COLUMN\\\n    --cache_dir=$CACHE_DIR \\\n    --tracker_project_name=$MODEL_TYPE \\\n    --output_dir=$OUTPUT_DIR \\\n    --max_train_steps=500000 \\\n    --mixed_precision bf16 \\\n    --checkpointing_steps=1000 \\\n    --gradient_accumulation_steps=8 \\\n    --resolution=512 \\\n    --train_batch_size=1 \\\n    --learning_rate=1e-5 \\\n    --num_double_layers=4 \\\n    --num_single_layers=0 \\\n    --gradient_checkpointing \\\n    --resume_from_checkpoint=\"latest\" \\\n    # --use_adafactor \\ dont use\n    # --validation_steps=3 \\ not support \n    # --validation_image $VALIDATION_IMAGE \\ not support \n    # --validation_prompt \"xxx\" \\ not support \n```"
  },
  {
    "path": "examples/controlnet/README_sd3.md",
    "content": "# ControlNet training example for Stable Diffusion 3/3.5 (SD3/3.5)\n\nThe `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206) and [Stable Diffusion 3.5](https://stability.ai/news/introducing-stable-diffusion-3-5).\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/controlnet` folder and run\n```bash\npip install -r requirements_sd3.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\n\n## Circle filling dataset\n\nThe original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). 
Note that `datasets` handles dataloading within the training script.\nPlease download the dataset and unzip it in the directory `fill50k` in the `examples/controlnet` folder.\n\n## Training\n\nFirst download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or the SD3.5 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). We will use it as a base model for the ControlNet training.\n> [!NOTE]\n> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or [Stable Diffusion 3.5 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:\n\n```bash\nhf auth login\n```\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\n\nOur training examples use two test conditioning images. 
They can be downloaded by running\n\n```sh\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\n\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\nThen run the following commands to train a ControlNet model.\n\n```bash\nexport MODEL_DIR=\"stabilityai/stable-diffusion-3-medium-diffusers\"\nexport OUTPUT_DIR=\"sd3-controlnet-out\"\n\naccelerate launch train_controlnet_sd3.py \\\n    --pretrained_model_name_or_path=$MODEL_DIR \\\n    --output_dir=$OUTPUT_DIR \\\n    --train_data_dir=\"fill50k\" \\\n    --resolution=1024 \\\n    --learning_rate=1e-5 \\\n    --max_train_steps=15000 \\\n    --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n    --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n    --validation_steps=100 \\\n    --train_batch_size=1 \\\n    --gradient_accumulation_steps=4\n```\n\nTo train a ControlNet model for Stable Diffusion 3.5, replace the `MODEL_DIR` with `stabilityai/stable-diffusion-3.5-medium`.\n\nTo better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\nOur experiments were conducted on a single 40GB A100 GPU.\n\n### Inference\n\nOnce training is done, we can perform inference like so:\n\n```python\nfrom diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel\nfrom diffusers.utils import load_image\nimport torch\n\nbase_model_path = \"stabilityai/stable-diffusion-3-medium-diffusers\"\ncontrolnet_path = \"DavyMorgan/sd3-controlnet-out\"\n\ncontrolnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)\npipe = StableDiffusion3ControlNetPipeline.from_pretrained(\n    base_model_path, controlnet=controlnet\n)\npipe.to(\"cuda\", torch.float16)\n\n\ncontrol_image = load_image(\"./conditioning_image_1.png\").resize((1024, 1024))\nprompt = \"pale golden rod circle with old lace background\"\n\n# generate image\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt, num_inference_steps=20, generator=generator, control_image=control_image\n).images[0]\nimage.save(\"./output.png\")\n```\n\nSimilarly, for SD3.5, replace the `base_model_path` with `stabilityai/stable-diffusion-3.5-medium` and the `controlnet_path` with `DavyMorgan/sd35-controlnet-out`.\n\n## Notes\n\n### GPU usage\n\nSD3 is a large model and requires a lot of GPU memory. 
\nWe recommend using one GPU with at least 80GB of memory.\nMake sure to use the right GPU when configuring the [accelerator](https://huggingface.co/docs/transformers/en/accelerate).\n\n\n## Example results\n\n### SD3\n\n#### After 500 steps with batch size 8\n\n| |  |\n|-------------------|:-------------------------:|\n|| pale golden rod circle with old lace background |\n ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-500.png) |\n\n\n#### After 6500 steps with batch size 8:\n\n| |  |\n|-------------------|:-------------------------:|\n|| pale golden rod circle with old lace background |\n ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-6500.png) |\n\n### SD3.5\n\n#### After 500 steps with batch size 8\n\n| |                                                                                                                                                     |\n|-------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------:|\n||                                                   pale golden rod circle with old lace background                                                   |\n ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-500-3.5.png) 
|\n\n\n#### After 3000 steps with batch size 8:\n\n| |                                                                                                                                                      |\n|-------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------:|\n||                                                   pale golden rod circle with old lace background                                                    |\n ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-3000-3.5.png) |\n\n"
  },
  {
    "path": "examples/controlnet/README_sdxl.md",
    "content": "# ControlNet training example for Stable Diffusion XL (SDXL)\n\nThe `train_controlnet_sdxl.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/controlnet` folder and run\n```bash\npip install -r requirements_sdxl.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\n\n## Circle filling dataset\n\nThe original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.\n\n## Training\n\nOur training examples use two test conditioning images. 
They can be downloaded by running\n\n```sh\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\n\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\nThen run `hf auth login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.\n\n```bash\nexport MODEL_DIR=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_controlnet_sdxl.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --mixed_precision=\"fp16\" \\\n --resolution=1024 \\\n --learning_rate=1e-5 \\\n --max_train_steps=15000 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --validation_steps=100 \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --report_to=\"wandb\" \\\n --seed=42 \\\n --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.\n* `validation_image`, `validation_prompt`, and `validation_steps` allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\nOur experiments were conducted on a single 40GB A100 GPU.\n\n### Inference\n\nOnce training is done, we can perform inference like so:\n\n```python\nfrom diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler\nfrom diffusers.utils import load_image\nimport torch\n\nbase_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\ncontrolnet_path = \"path to controlnet\"\n\ncontrolnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)\npipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n    base_model_path, controlnet=controlnet, torch_dtype=torch.float16\n)\n\n# speed up diffusion process with faster scheduler and memory optimization\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n# remove following line if xformers is not installed or when using Torch 2.0.\npipe.enable_xformers_memory_efficient_attention()\n# memory optimization.\npipe.enable_model_cpu_offload()\n\ncontrol_image = load_image(\"./conditioning_image_1.png\").resize((1024, 1024))\nprompt = \"pale golden rod circle with old lace background\"\n\n# generate image\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt, num_inference_steps=20, generator=generator, image=control_image\n).images[0]\nimage.save(\"./output.png\")\n```\n\n## Notes\n\n### Specifying a better VAE\n\nSDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).\n\nIf you're using this VAE during training, you need to ensure you're using it during inference too. 
You do so by:\n\n```diff\n+ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)\ncontrolnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)\npipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n    base_model_path, controlnet=controlnet, torch_dtype=torch.float16,\n+   vae=vae,\n)\n```\n"
  },
  {
    "path": "examples/controlnet/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\ndatasets\n"
  },
  {
    "path": "examples/controlnet/requirements_flax.txt",
    "content": "transformers>=4.25.1\ndatasets\nflax\noptax\ntorch\ntorchvision\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/controlnet/requirements_flux.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\ndatasets\nwandb\nSentencePiece"
  },
  {
    "path": "examples/controlnet/requirements_sd3.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\ndatasets\nwandb\n"
  },
  {
    "path": "examples/controlnet/requirements_sdxl.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\ndatasets\nwandb\n"
  },
  {
    "path": "examples/controlnet/test_controlnet.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass ControlNet(ExamplesTestsAccelerate):\n    def test_controlnet_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/controlnet/train_controlnet.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --dataset_name=hf-internal-testing/fill10\n            --output_dir={tmpdir}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def 
test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/controlnet/train_controlnet.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --dataset_name=hf-internal-testing/fill10\n            --output_dir={tmpdir}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet\n            --max_train_steps=6\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n            resume_run_args = f\"\"\"\n            examples/controlnet/train_controlnet.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --dataset_name=hf-internal-testing/fill10\n            --output_dir={tmpdir}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-6\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n\nclass ControlNetSDXL(ExamplesTestsAccelerate):\n    def test_controlnet_sdxl(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            
examples/controlnet/train_controlnet_sdxl.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe\n            --dataset_name=hf-internal-testing/fill10\n            --output_dir={tmpdir}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"diffusion_pytorch_model.safetensors\")))\n\n\nclass ControlNetSD3(ExamplesTestsAccelerate):\n    def test_controlnet_sd3(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/controlnet/train_controlnet_sd3.py\n            --pretrained_model_name_or_path=DavyMorgan/tiny-sd3-pipe\n            --dataset_name=hf-internal-testing/fill10\n            --output_dir={tmpdir}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd3\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"diffusion_pytorch_model.safetensors\")))\n\n\nclass ControlNetSD35(ExamplesTestsAccelerate):\n    def test_controlnet_sd3(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/controlnet/train_controlnet_sd3.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-sd35-pipe\n            --dataset_name=hf-internal-testing/fill10\n            --output_dir={tmpdir}\n            --resolution=64\n            --train_batch_size=1\n 
           --gradient_accumulation_steps=1\n            --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd35\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"diffusion_pytorch_model.safetensors\")))\n\n\nclass ControlNetflux(ExamplesTestsAccelerate):\n    def test_controlnet_flux(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/controlnet/train_controlnet_flux.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe\n            --output_dir={tmpdir}\n            --dataset_name=hf-internal-testing/fill10\n            --conditioning_image_column=conditioning_image\n            --image_column=image\n            --caption_column=text\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            --num_double_layers=1\n            --num_single_layers=1\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"diffusion_pytorch_model.safetensors\")))\n"
  },
  {
    "path": "examples/controlnet/train_controlnet.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport contextlib\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    ControlNetModel,\n    DDPMScheduler,\n    StableDiffusionControlNetPipeline,\n    UNet2DConditionModel,\n    UniPCMultistepScheduler,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of 
diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows * cols\n\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\n\ndef log_validation(\n    vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False\n):\n    logger.info(\"Running validation... \")\n\n    if not is_final_validation:\n        controlnet = accelerator.unwrap_model(controlnet)\n    else:\n        controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)\n\n    pipeline = StableDiffusionControlNetPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n        unet=unet,\n        controlnet=controlnet,\n        safety_checker=None,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    
elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(\"cuda\")\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with inference_ctx:\n                image = pipeline(\n                    validation_prompt, validation_image, num_inference_steps=20, generator=generator\n                ).images[0]\n\n            images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = [np.asarray(validation_image)]\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = 
log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({tracker_key: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n  
          image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i})](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# controlnet-{repo_id}\n\nThese are controlnet weights trained on {base_model} with new type of conditioning.\n{img_str}\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"controlnet\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained controlnet model or model identifier from huggingface.co/models.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from 
huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"controlnet-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n       
     'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the 
repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. 
Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset; leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. 
Provide either a matching number of `--validation_prompt`s,\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` `args.num_validation_images` times\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"train_controlnet\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers. For\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--train_data_dir`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be 
set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and `--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder.\"\n        )\n\n    return args\n\n\ndef make_train_dataset(args, tokenizer, accelerator):\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        if args.train_data_dir is not None:\n            dataset = load_dataset(\n                args.train_data_dir,\n                cache_dir=args.cache_dir,\n            )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if random.random() < args.proportion_empty_prompts:\n                captions.append(\"\")\n            elif isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    conditioning_image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        images = [image_transforms(image) for image in images]\n\n        conditioning_images = [image.convert(\"RGB\") for image in examples[conditioning_image_column]]\n        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]\n\n        examples[\"pixel_values\"] = 
images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n        examples[\"input_ids\"] = tokenize_captions(examples)\n\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    return train_dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n\n    return {\n        \"pixel_values\": pixel_values,\n        \"conditioning_pixel_values\": conditioning_pixel_values,\n        \"input_ids\": input_ids,\n    }\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if 
torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, 
subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    if args.controlnet_model_name_or_path:\n        logger.info(\"Loading existing controlnet weights\")\n        controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)\n    else:\n        logger.info(\"Initializing controlnet weights from unet\")\n        controlnet = ControlNetModel.from_unet(unet)\n\n    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                i = len(weights) - 1\n\n                while len(weights) > 0:\n                    weights.pop()\n                    model = models[i]\n\n                    sub_dir = \"controlnet\"\n                    model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                    i -= 1\n\n        def load_model_hook(models, input_dir):\n            while len(models) > 0:\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = 
ControlNetModel.from_pretrained(input_dir, subfolder=\"controlnet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    controlnet.train()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            controlnet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        controlnet.enable_gradient_checkpointing()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(controlnet).dtype != torch.float32:\n        raise ValueError(\n            f\"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. 
{low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = controlnet.parameters()\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    train_dataset = make_train_dataset(args, tokenizer, accelerator)\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = 
math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        controlnet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae, unet and text_encoder to device and cast to weight_dtype\n    vae.to(accelerator.device, dtype=weight_dtype)\n    unet.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 
'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(controlnet):\n                # Convert images to latent space\n                latents = 
vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(\n                    dtype=weight_dtype\n                )\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"], return_dict=False)[0]\n\n                controlnet_image = batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\n\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=encoder_hidden_states,\n                    controlnet_cond=controlnet_image,\n                    return_dict=False,\n                )\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=encoder_hidden_states,\n                    down_block_additional_residuals=[\n                        sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                    ],\n                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n                    return_dict=False,\n                )[0]\n\n                # Get 
the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = controlnet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = 
len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            vae,\n                            text_encoder,\n                            tokenizer,\n                            unet,\n                            controlnet,\n                            args,\n                            accelerator,\n                            weight_dtype,\n                            global_step,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        controlnet = unwrap_model(controlnet)\n        
controlnet.save_pretrained(args.output_dir)\n\n        # Run a final round of validation.\n        image_logs = None\n        if args.validation_prompt is not None:\n            image_logs = log_validation(\n                vae=vae,\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                unet=unet,\n                controlnet=None,\n                args=args,\n                accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                step=global_step,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/controlnet/train_controlnet_flax.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport time\nfrom pathlib import Path\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom datasets import load_dataset, load_from_disk\nfrom flax import jax_utils\nfrom flax.core.frozen_dict import unfreeze\nfrom flax.training import train_state\nfrom flax.training.common_utils import shard\nfrom huggingface_hub import create_repo, upload_folder\nfrom PIL import Image, PngImagePlugin\nfrom torch.utils.data import IterableDataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed\n\nfrom diffusers import (\n    FlaxAutoencoderKL,\n    FlaxControlNetModel,\n    FlaxDDPMScheduler,\n    FlaxStableDiffusionControlNetPipeline,\n    FlaxUNet2DConditionModel,\n)\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\n\n\n# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image\n# see more https://github.com/python-pillow/Pillow/issues/5610\nLARGE_ENOUGH_NUMBER = 100\nPngImagePlugin.MAX_TEXT_CHUNK = 
LARGE_ENOUGH_NUMBER * (1024**2)\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = logging.getLogger(__name__)\n\n\ndef log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):\n    logger.info(\"Running validation...\")\n\n    pipeline_params = pipeline_params.copy()\n    pipeline_params[\"controlnet\"] = controlnet_params\n\n    num_samples = jax.device_count()\n    prng_seed = jax.random.split(rng, jax.device_count())\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        prompts = num_samples * [validation_prompt]\n        prompt_ids = pipeline.prepare_text_inputs(prompts)\n        prompt_ids = shard(prompt_ids)\n\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n        processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])\n        processed_image = shard(processed_image)\n        images = pipeline(\n            prompt_ids=prompt_ids,\n            image=processed_image,\n            params=pipeline_params,\n            prng_seed=prng_seed,\n            num_inference_steps=50,\n         
   jit=True,\n        ).images\n\n        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])\n        images = pipeline.numpy_to_pil(images)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    if args.report_to == \"wandb\":\n        formatted_images = []\n        for log in image_logs:\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n\n            formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n            for image in images:\n                image = wandb.Image(image, caption=validation_prompt)\n                formatted_images.append(image)\n\n        wandb.log({\"validation\": formatted_images})\n    else:\n        logger.warning(f\"image logging not implemented for {args.report_to}\")\n\n    return image_logs\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i}](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# controlnet- {repo_id}\n\nThese are controlnet weights trained on {base_model} with new type of conditioning. You can find some example images in the following. 
\\n\n{img_str}\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"controlnet\",\n        \"jax-diffusers-event\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained controlnet model or model identifier from huggingface.co/models.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--from_pt\",\n        action=\"store_true\",\n        help=\"Load the pretrained model from a PyTorch checkpoint.\",\n    )\n    parser.add_argument(\n        \"--controlnet_revision\",\n        type=str,\n        default=None,\n        help=\"Revision of controlnet model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--profile_steps\",\n        type=int,\n        default=0,\n        help=\"How many training steps to profile in the beginning.\",\n    
)\n    parser.add_argument(\n        \"--profile_validation\",\n        action=\"store_true\",\n        help=\"Whether to profile the (last) validation.\",\n    )\n    parser.add_argument(\n        \"--profile_memory\",\n        action=\"store_true\",\n        help=\"Whether to dump an initial (before training loop) and a final (at program end) memory profile.\",\n    )\n    parser.add_argument(\n        \"--ccache\",\n        type=str,\n        default=None,\n        help=\"Enables compilation cache.\",\n    )\n    parser.add_argument(\n        \"--controlnet_from_pt\",\n        action=\"store_true\",\n        help=\"Load the controlnet model from a PyTorch checkpoint.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"runs/{timestamp}\",\n        help=\"The output directory where the model predictions and checkpoints will be written. 
\"\n        \"Can contain placeholders: {timestamp}.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=1, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=5000,\n        help=(\"Save a checkpoint of the training state every X updates.\"),\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_steps\",\n        type=int,\n        default=100,\n        help=(\"Log training metrics every X steps to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"wandb\",\n        
help=('The integration to report the results and logs to. Currently the only supported platform is `\"wandb\"`.'),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\"--streaming\", action=\"store_true\", help=\"To stream a large dataset from Hub.\")\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder.\"\n            \" Folder must contain a dataset script as described at https://huggingface.co/docs/datasets/dataset_script.\"\n            \" If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--load_from_disk\",\n        action=\"store_true\",\n        help=(\n            \"If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`\"\n            \"See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set. Needed if `streaming` is set to True.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` and logging the images.\"\n        ),\n    )\n    parser.add_argument(\"--wandb_entity\", type=str, default=None, help=(\"The wandb entity to use (for teams).\"))\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"train_controlnet_flax\",\n        help=(\"The `project` argument passed to wandb\"),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\", type=int, default=1, help=\"Number of steps to accumulate gradients over\"\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    args.output_dir = args.output_dir.replace(\"{timestamp}\", time.strftime(\"%Y%m%d_%H%M%S\"))\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n    if args.dataset_name is not None and args.train_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--train_data_dir`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and 
len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and `--validation_image`s\"\n        )\n\n    # This idea comes from\n    # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370\n    if args.streaming and args.max_train_samples is None:\n        raise ValueError(\"You must specify `max_train_samples` when using dataset streaming.\")\n\n    return args\n\n\ndef make_train_dataset(args, tokenizer, batch_size=None):\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            streaming=args.streaming,\n        )\n    else:\n        if args.train_data_dir is not None:\n            if args.load_from_disk:\n                dataset = load_from_disk(\n                    args.train_data_dir,\n                )\n            else:\n                dataset = load_dataset(\n                    args.train_data_dir,\n                    cache_dir=args.cache_dir,\n                )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    if 
isinstance(dataset[\"train\"], IterableDataset):\n        # Materialize the keys as a list so the positional defaults below can index into them.\n        column_names = list(next(iter(dataset[\"train\"])).keys())\n    else:\n        column_names = dataset[\"train\"].column_names\n\n    # Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if random.random() < args.proportion_empty_prompts:\n                captions.append(\"\")\n            elif isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    conditioning_image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        images = [image_transforms(image) for image in images]\n\n        conditioning_images = [image.convert(\"RGB\") for image in examples[conditioning_image_column]]\n        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]\n\n        examples[\"pixel_values\"] = 
images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n        examples[\"input_ids\"] = tokenize_captions(examples)\n\n        return examples\n\n    if jax.process_index() == 0:\n        if args.max_train_samples is not None:\n            if args.streaming:\n                dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).take(args.max_train_samples)\n            else:\n                dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        if args.streaming:\n            train_dataset = dataset[\"train\"].map(\n                preprocess_train,\n                batched=True,\n                batch_size=batch_size,\n                remove_columns=list(dataset[\"train\"].features.keys()),\n            )\n        else:\n            train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    return train_dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n\n    batch = {\n        \"pixel_values\": pixel_values,\n        \"conditioning_pixel_values\": conditioning_pixel_values,\n        \"input_ids\": input_ids,\n    }\n    batch = {k: v.numpy() for k, v in batch.items()}\n    return batch\n\n\ndef get_params_to_save(params):\n    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both 
--report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    # Setup logging, we only want one process per machine to log things on the screen.\n    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)\n    if jax.process_index() == 0:\n        transformers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n\n    # wandb init\n    if jax.process_index() == 0 and args.report_to == \"wandb\":\n        wandb.init(\n            entity=args.wandb_entity,\n            project=args.tracker_project_name,\n            job_type=\"train\",\n            config=args,\n        )\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    rng = jax.random.PRNGKey(0)\n\n    # Handle the repository creation\n    if jax.process_index() == 0:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer and add the placeholder token as a additional special token\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n        )\n    else:\n        raise NotImplementedError(\"No tokenizer specified!\")\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    
total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps\n    train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=not args.streaming,\n        collate_fn=collate_fn,\n        batch_size=total_train_batch_size,\n        num_workers=args.dataloader_num_workers,\n        drop_last=True,\n    )\n\n    weight_dtype = jnp.float32\n    if args.mixed_precision == \"fp16\":\n        weight_dtype = jnp.float16\n    elif args.mixed_precision == \"bf16\":\n        weight_dtype = jnp.bfloat16\n\n    # Load models and create wrapper for stable diffusion\n    text_encoder = FlaxCLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        dtype=weight_dtype,\n        revision=args.revision,\n        from_pt=args.from_pt,\n    )\n    vae, vae_params = FlaxAutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        revision=args.revision,\n        subfolder=\"vae\",\n        dtype=weight_dtype,\n        from_pt=args.from_pt,\n    )\n    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"unet\",\n        dtype=weight_dtype,\n        revision=args.revision,\n        from_pt=args.from_pt,\n    )\n\n    if args.controlnet_model_name_or_path:\n        logger.info(\"Loading existing controlnet weights\")\n        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(\n            args.controlnet_model_name_or_path,\n            revision=args.controlnet_revision,\n            from_pt=args.controlnet_from_pt,\n            dtype=jnp.float32,\n        )\n    else:\n        logger.info(\"Initializing controlnet weights from unet\")\n        rng, rng_params = jax.random.split(rng)\n\n        controlnet = FlaxControlNetModel(\n            
in_channels=unet.config.in_channels,\n            down_block_types=unet.config.down_block_types,\n            only_cross_attention=unet.config.only_cross_attention,\n            block_out_channels=unet.config.block_out_channels,\n            layers_per_block=unet.config.layers_per_block,\n            attention_head_dim=unet.config.attention_head_dim,\n            cross_attention_dim=unet.config.cross_attention_dim,\n            use_linear_projection=unet.config.use_linear_projection,\n            flip_sin_to_cos=unet.config.flip_sin_to_cos,\n            freq_shift=unet.config.freq_shift,\n        )\n        controlnet_params = controlnet.init_weights(rng=rng_params)\n        controlnet_params = unfreeze(controlnet_params)\n        for key in [\n            \"conv_in\",\n            \"time_embedding\",\n            \"down_blocks_0\",\n            \"down_blocks_1\",\n            \"down_blocks_2\",\n            \"down_blocks_3\",\n            \"mid_block\",\n        ]:\n            controlnet_params[key] = unet_params[key]\n\n    pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        tokenizer=tokenizer,\n        controlnet=controlnet,\n        safety_checker=None,\n        dtype=weight_dtype,\n        revision=args.revision,\n        from_pt=args.from_pt,\n    )\n    pipeline_params = jax_utils.replicate(pipeline_params)\n\n    # Optimization\n    if args.scale_lr:\n        args.learning_rate = args.learning_rate * total_train_batch_size\n\n    constant_scheduler = optax.constant_schedule(args.learning_rate)\n\n    adamw = optax.adamw(\n        learning_rate=constant_scheduler,\n        b1=args.adam_beta1,\n        b2=args.adam_beta2,\n        eps=args.adam_epsilon,\n        weight_decay=args.adam_weight_decay,\n    )\n\n    optimizer = optax.chain(\n        optax.clip_by_global_norm(args.max_grad_norm),\n        adamw,\n    )\n\n    state = 
train_state.TrainState.create(apply_fn=controlnet.__call__, params=controlnet_params, tx=optimizer)\n\n    noise_scheduler, noise_scheduler_state = FlaxDDPMScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n\n    # Initialize our training\n    validation_rng, train_rngs = jax.random.split(rng)\n    train_rngs = jax.random.split(train_rngs, jax.local_device_count())\n\n    def compute_snr(timesteps):\n        \"\"\"\n        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849\n        \"\"\"\n        alphas_cumprod = noise_scheduler_state.common.alphas_cumprod\n        sqrt_alphas_cumprod = alphas_cumprod**0.5\n        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5\n\n        alpha = sqrt_alphas_cumprod[timesteps]\n        sigma = sqrt_one_minus_alphas_cumprod[timesteps]\n        # Compute SNR.\n        snr = (alpha / sigma) ** 2\n        return snr\n\n    def train_step(state, unet_params, text_encoder_params, vae_params, batch, train_rng):\n        # reshape batch, add grad_step_dim if gradient_accumulation_steps > 1\n        if args.gradient_accumulation_steps > 1:\n            grad_steps = args.gradient_accumulation_steps\n            batch = jax.tree_map(lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]), batch)\n\n        def compute_loss(params, minibatch, sample_rng):\n            # Convert images to latent space\n            vae_outputs = vae.apply(\n                {\"params\": vae_params}, minibatch[\"pixel_values\"], deterministic=True, method=vae.encode\n            )\n            latents = vae_outputs.latent_dist.sample(sample_rng)\n            # (NHWC) -> (NCHW)\n            latents = jnp.transpose(latents, (0, 3, 1, 2))\n            latents = latents * vae.config.scaling_factor\n\n            # Sample noise that we'll add to the latents\n     
       noise_rng, timestep_rng = jax.random.split(sample_rng)\n            noise = jax.random.normal(noise_rng, latents.shape)\n            # Sample a random timestep for each image\n            bsz = latents.shape[0]\n            timesteps = jax.random.randint(\n                timestep_rng,\n                (bsz,),\n                0,\n                noise_scheduler.config.num_train_timesteps,\n            )\n\n            # Add noise to the latents according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)\n\n            # Get the text embedding for conditioning\n            encoder_hidden_states = text_encoder(\n                minibatch[\"input_ids\"],\n                params=text_encoder_params,\n                train=False,\n            )[0]\n\n            controlnet_cond = minibatch[\"conditioning_pixel_values\"]\n\n            # Predict the noise residual and compute loss\n            down_block_res_samples, mid_block_res_sample = controlnet.apply(\n                {\"params\": params},\n                noisy_latents,\n                timesteps,\n                encoder_hidden_states,\n                controlnet_cond,\n                train=True,\n                return_dict=False,\n            )\n\n            model_pred = unet.apply(\n                {\"params\": unet_params},\n                noisy_latents,\n                timesteps,\n                encoder_hidden_states,\n                down_block_additional_residuals=down_block_res_samples,\n                mid_block_additional_residual=mid_block_res_sample,\n            ).sample\n\n            # Get the target for loss depending on the prediction type\n            if noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n            loss = (target - model_pred) ** 2\n\n            if args.snr_gamma is not None:\n                snr = jnp.array(compute_snr(timesteps))\n                snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma)\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    snr_loss_weights = snr_loss_weights / snr\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    snr_loss_weights = snr_loss_weights / (snr + 1)\n\n                loss = loss * snr_loss_weights\n\n            loss = loss.mean()\n\n            return loss\n\n        grad_fn = jax.value_and_grad(compute_loss)\n\n        # get a minibatch (one gradient accumulation slice)\n        def get_minibatch(batch, grad_idx):\n            return jax.tree_util.tree_map(\n                lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),\n                batch,\n            )\n\n        def loss_and_grad(grad_idx, train_rng):\n            # create minibatch for the grad step\n            minibatch = get_minibatch(batch, grad_idx) if grad_idx is not None else batch\n            sample_rng, train_rng = jax.random.split(train_rng, 2)\n            loss, grad = grad_fn(state.params, minibatch, sample_rng)\n            return loss, grad, train_rng\n\n        if args.gradient_accumulation_steps == 1:\n            loss, grad, new_train_rng = loss_and_grad(None, train_rng)\n        else:\n            init_loss_grad_rng = (\n                0.0,  # initial value for cumul_loss\n                jax.tree_map(jnp.zeros_like, state.params),  # initial value for cumul_grad\n                train_rng,  # initial value for train_rng\n            )\n\n            def 
cumul_grad_step(grad_idx, loss_grad_rng):\n                cumul_loss, cumul_grad, train_rng = loss_grad_rng\n                loss, grad, new_train_rng = loss_and_grad(grad_idx, train_rng)\n                cumul_loss, cumul_grad = jax.tree_map(jnp.add, (cumul_loss, cumul_grad), (loss, grad))\n                return cumul_loss, cumul_grad, new_train_rng\n\n            loss, grad, new_train_rng = jax.lax.fori_loop(\n                0,\n                args.gradient_accumulation_steps,\n                cumul_grad_step,\n                init_loss_grad_rng,\n            )\n            loss, grad = jax.tree_map(lambda x: x / args.gradient_accumulation_steps, (loss, grad))\n\n        grad = jax.lax.pmean(grad, \"batch\")\n\n        new_state = state.apply_gradients(grads=grad)\n\n        metrics = {\"loss\": loss}\n        metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n\n        def l2(xs):\n            return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))\n\n        metrics[\"l2_grads\"] = l2(jax.tree_util.tree_leaves(grad))\n\n        return new_state, metrics, new_train_rng\n\n    # Create parallel version of the train step\n    p_train_step = jax.pmap(train_step, \"batch\", donate_argnums=(0,))\n\n    # Replicate the train state on each device\n    state = jax_utils.replicate(state)\n    unet_params = jax_utils.replicate(unet_params)\n    text_encoder_params = jax_utils.replicate(text_encoder.params)\n    vae_params = jax_utils.replicate(vae_params)\n\n    # Train!\n    if args.streaming:\n        dataset_length = args.max_train_samples\n    else:\n        dataset_length = len(train_dataloader)\n    num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)\n\n    # Scheduler and math around the number of training steps.\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n\n    args.num_train_epochs = math.ceil(args.max_train_steps / 
num_update_steps_per_epoch)\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}\")\n    logger.info(f\"  Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}\")\n\n    if jax.process_index() == 0 and args.report_to == \"wandb\":\n        wandb.define_metric(\"*\", step_metric=\"train/step\")\n        wandb.define_metric(\"train/step\", step_metric=\"walltime\")\n        wandb.config.update(\n            {\n                \"num_train_examples\": args.max_train_samples if args.streaming else len(train_dataset),\n                \"total_train_batch_size\": total_train_batch_size,\n                \"total_optimization_step\": args.num_train_epochs * num_update_steps_per_epoch,\n                \"num_devices\": jax.device_count(),\n                \"controlnet_params\": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),\n            }\n        )\n\n    global_step = step0 = 0\n    epochs = tqdm(\n        range(args.num_train_epochs),\n        desc=\"Epoch ... 
\",\n        position=0,\n        disable=jax.process_index() > 0,\n    )\n    if args.profile_memory:\n        jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, \"memory_initial.prof\"))\n    t00 = t0 = time.monotonic()\n    for epoch in epochs:\n        # ======================== Training ================================\n\n        train_metrics = []\n        train_metric = None\n\n        steps_per_epoch = (\n            args.max_train_samples // total_train_batch_size\n            if args.streaming or args.max_train_samples\n            else len(train_dataset) // total_train_batch_size\n        )\n        train_step_progress_bar = tqdm(\n            total=steps_per_epoch,\n            desc=\"Training...\",\n            position=1,\n            leave=False,\n            disable=jax.process_index() > 0,\n        )\n        # train\n        for batch in train_dataloader:\n            if args.profile_steps and global_step == 1:\n                train_metric[\"loss\"].block_until_ready()\n                jax.profiler.start_trace(args.output_dir)\n            if args.profile_steps and global_step == 1 + args.profile_steps:\n                train_metric[\"loss\"].block_until_ready()\n                jax.profiler.stop_trace()\n\n            batch = shard(batch)\n            with jax.profiler.StepTraceAnnotation(\"train\", step_num=global_step):\n                state, train_metric, train_rngs = p_train_step(\n                    state, unet_params, text_encoder_params, vae_params, batch, train_rngs\n                )\n            train_metrics.append(train_metric)\n\n            train_step_progress_bar.update(1)\n\n            global_step += 1\n            if global_step >= args.max_train_steps:\n                break\n\n            if (\n                args.validation_prompt is not None\n                and global_step % args.validation_steps == 0\n                and jax.process_index() == 0\n            ):\n                _ = log_validation(\n  
                  pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype\n                )\n\n            if global_step % args.logging_steps == 0 and jax.process_index() == 0:\n                if args.report_to == \"wandb\":\n                    train_metrics = jax_utils.unreplicate(train_metrics)\n                    train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)\n                    wandb.log(\n                        {\n                            \"walltime\": time.monotonic() - t00,\n                            \"train/step\": global_step,\n                            \"train/epoch\": global_step / dataset_length,\n                            \"train/steps_per_sec\": (global_step - step0) / (time.monotonic() - t0),\n                            **{f\"train/{k}\": v for k, v in train_metrics.items()},\n                        }\n                    )\n                t0, step0 = time.monotonic(), global_step\n                train_metrics = []\n            if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:\n                controlnet.save_pretrained(\n                    f\"{args.output_dir}/{global_step}\",\n                    params=get_params_to_save(state.params),\n                )\n\n        train_metric = jax_utils.unreplicate(train_metric)\n        train_step_progress_bar.close()\n        epochs.write(f\"Epoch... 
({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})\")\n\n    # Final validation & store model.\n    if jax.process_index() == 0:\n        if args.validation_prompt is not None:\n            if args.profile_validation:\n                jax.profiler.start_trace(args.output_dir)\n            image_logs = log_validation(\n                pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype\n            )\n            if args.profile_validation:\n                jax.profiler.stop_trace()\n        else:\n            image_logs = None\n\n        controlnet.save_pretrained(\n            args.output_dir,\n            params=get_params_to_save(state.params),\n        )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    if args.profile_memory:\n        jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, \"memory_final.prof\"))\n    logger.info(\"Finished training.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/controlnet/train_controlnet_flux.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport functools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import (\n    AutoTokenizer,\n    CLIPTextModel,\n    T5EncoderModel,\n)\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    FluxTransformer2DModel,\n)\nfrom diffusers.models.controlnets.controlnet_flux import FluxControlNetModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline\nfrom diffusers.training_utils import compute_density_for_timestep_sampling, free_memory\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom 
diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\n\ndef log_validation(\n    vae, flux_transformer, flux_controlnet, args, accelerator, weight_dtype, step, is_final_validation=False\n):\n    logger.info(\"Running validation... \")\n\n    if not is_final_validation:\n        flux_controlnet = accelerator.unwrap_model(flux_controlnet)\n        pipeline = FluxControlNetPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            controlnet=flux_controlnet,\n            transformer=flux_transformer,\n            torch_dtype=torch.bfloat16,\n        )\n    else:\n        flux_controlnet = FluxControlNetModel.from_pretrained(\n            args.output_dir, torch_dtype=torch.bfloat16, variant=args.save_weight_dtype\n        )\n        pipeline = FluxControlNetPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            controlnet=flux_controlnet,\n            transformer=flux_transformer,\n            torch_dtype=torch.bfloat16,\n        )\n\n    pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = 
args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n    if is_final_validation or torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        from diffusers.utils import load_image\n\n        validation_image = load_image(validation_image)\n        # maybe need to inference on 1024 to get a good image\n        validation_image = validation_image.resize((args.resolution, args.resolution))\n\n        images = []\n\n        # pre calculate  prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast\n        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(\n            validation_prompt, prompt_2=validation_prompt\n        )\n        for _ in range(args.num_validation_images):\n            with autocast_ctx:\n                # need to fix in pipeline_flux_controlnet\n                image = pipeline(\n                    prompt_embeds=prompt_embeds,\n                    pooled_prompt_embeds=pooled_prompt_embeds,\n                    control_image=validation_image,\n                    num_inference_steps=28,\n                    controlnet_conditioning_scale=1,\n                    guidance_scale=3.5,\n                    generator=generator,\n                ).images[0]\n            image = image.resize((args.resolution, 
args.resolution))\n            images.append(image)\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = [np.asarray(validation_image)]\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({tracker_key: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    free_memory()\n    return image_logs\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n    
        validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i}](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# controlnet-{repo_id}\n\nThese are controlnet weights trained on {base_model} with a new type of conditioning.\n{img_str}\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"flux\",\n        \"flux-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"controlnet\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to an improved VAE to stabilize training. 
For more details check out: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained controlnet model or model identifier from huggingface.co/models.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"controlnet-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_h\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the height) to be included in the crop coordinate 
embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_w\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. \"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. \"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step \"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--use_adafactor\",\n        action=\"store_true\",\n        help=(\n            \"Adafactor is a stochastic optimization method based on Adam that reduces memory usage while retaining \"\n            \"the empirical benefits of adaptivity. This is achieved through maintaining a factored representation \"\n            \"of the squared gradient accumulator across training steps.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. 
Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--enable_npu_flash_attention\", action=\"store_true\", help=\"Whether or not to use npu flash attention.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning image to be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_double_layers\",\n        type=int,\n        default=4,\n        help=\"Number of double layers in the controlnet (default: 4).\",\n    )\n    parser.add_argument(\n        \"--num_single_layers\",\n        type=int,\n        default=4,\n        help=\"Number of single layers in the controlnet (default: 4).\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=2,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"flux_train_controlnet\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers. For\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    parser.add_argument(\n        \"--jsonl_for_train\",\n        type=str,\n        default=None,\n        help=\"Path to the jsonl file containing the training data.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"the guidance scale used for transformer.\",\n    )\n\n    parser.add_argument(\n        \"--save_weight_dtype\",\n        type=str,\n        default=\"fp32\",\n        choices=[\n            \"fp16\",\n            \"bf16\",\n            \"fp32\",\n        ],\n        help=(\"Preserve precision type according to selected weight\"),\n    )\n\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"logit_normal\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('Weighting scheme for timestep sampling. Defaults to \"logit_normal\"; use \"none\" for uniform sampling.'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. 
Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--enable_model_cpu_offload\",\n        action=\"store_true\",\n        help=\"Enable model cpu offload and save memory.\",\n    )\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.jsonl_for_train is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--jsonl_for_train`\")\n\n    if args.dataset_name is not None and args.jsonl_for_train is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--jsonl_for_train`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of 
`--validation_prompt`s and `--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder.\"\n        )\n\n    return args\n\n\ndef get_train_dataset(args, accelerator):\n    dataset = None\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    if args.jsonl_for_train is not None:\n        # load from json\n        dataset = load_dataset(\"json\", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)\n        dataset = dataset.flatten_indices()\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    with accelerator.main_process_first():\n        train_dataset = dataset[\"train\"].shuffle(seed=args.seed)\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.select(range(args.max_train_samples))\n    return train_dataset\n\n\ndef prepare_train_dataset(dataset, accelerator):\n    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n    if interpolation is None:\n        raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode=}.\")\n\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=interpolation),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    conditioning_image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=interpolation),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [\n            (image.convert(\"RGB\") if not isinstance(image, str) else Image.open(image).convert(\"RGB\"))\n            for image in examples[args.image_column]\n        ]\n        images = 
[image_transforms(image) for image in images]\n\n        conditioning_images = [\n            (image.convert(\"RGB\") if not isinstance(image, str) else Image.open(image).convert(\"RGB\"))\n            for image in examples[args.conditioning_image_column]\n        ]\n        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]\n        examples[\"pixel_values\"] = images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n\n        return examples\n\n    with accelerator.main_process_first():\n        dataset = dataset.with_transform(preprocess_train)\n\n    return dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    prompt_ids = torch.stack([torch.tensor(example[\"prompt_embeds\"]) for example in examples])\n\n    pooled_prompt_embeds = torch.stack([torch.tensor(example[\"pooled_prompt_embeds\"]) for example in examples])\n    text_ids = torch.stack([torch.tensor(example[\"text_ids\"]) for example in examples])\n\n    return {\n        \"pixel_values\": pixel_values,\n        \"conditioning_pixel_values\": conditioning_pixel_values,\n        \"prompt_ids\": prompt_ids,\n        \"unet_added_conditions\": {\"pooled_prompt_embeds\": pooled_prompt_embeds, \"time_ids\": text_ids},\n    }\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_out_dir = 
Path(args.output_dir, args.logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.\n    if torch.backends.mps.is_available():\n        print(\"MPS is enabled. Disabling AMP.\")\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        # DEBUG, INFO, WARNING, ERROR, CRITICAL\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    # load clip tokenizer\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    # load t5 tokenizer\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n    # load clip text encoder\n    text_encoder_one = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    # load t5 text encoder\n    text_encoder_two = T5EncoderModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    flux_transformer = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    if args.controlnet_model_name_or_path:\n        logger.info(\"Loading existing controlnet weights\")\n        flux_controlnet = FluxControlNetModel.from_pretrained(args.controlnet_model_name_or_path)\n    else:\n        logger.info(\"Initializing controlnet weights from transformer\")\n        # we can define the num_layers, num_single_layers,\n        flux_controlnet = FluxControlNetModel.from_transformer(\n            flux_transformer,\n            attention_head_dim=flux_transformer.config[\"attention_head_dim\"],\n            num_attention_heads=flux_transformer.config[\"num_attention_heads\"],\n            
num_layers=args.num_double_layers,\n            num_single_layers=args.num_single_layers,\n        )\n    logger.info(\"all models loaded successfully\")\n\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"scheduler\",\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae.requires_grad_(False)\n    flux_transformer.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    flux_controlnet.train()\n\n    # use some pipeline function\n    flux_controlnet_pipeline = FluxControlNetPipeline(\n        scheduler=noise_scheduler,\n        vae=vae,\n        text_encoder=text_encoder_one,\n        tokenizer=tokenizer_one,\n        text_encoder_2=text_encoder_two,\n        tokenizer_2=tokenizer_two,\n        transformer=flux_transformer,\n        controlnet=flux_controlnet,\n    )\n    if args.enable_model_cpu_offload:\n        flux_controlnet_pipeline.enable_model_cpu_offload()\n    else:\n        flux_controlnet_pipeline.to(accelerator.device)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                i = len(weights) - 1\n\n                while len(weights) > 0:\n                    weights.pop()\n                    model = models[i]\n\n                    sub_dir = \"flux_controlnet\"\n                    model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                    i -= 1\n\n        def 
load_model_hook(models, input_dir):\n            while len(models) > 0:\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = FluxControlNetModel.from_pretrained(input_dir, subfolder=\"flux_controlnet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            flux_transformer.enable_npu_flash_attention()\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu devices.\")\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            flux_transformer.enable_xformers_memory_efficient_attention()\n            flux_controlnet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        flux_transformer.enable_gradient_checkpointing()\n        flux_controlnet.enable_gradient_checkpointing()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(flux_controlnet).dtype != torch.float32:\n        raise ValueError(\n            f\"Controlnet loaded as datatype {unwrap_model(flux_controlnet).dtype}. {low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = flux_controlnet.parameters()\n    # use adafactor optimizer to save gpu memory\n    if args.use_adafactor:\n        from transformers import Adafactor\n\n        optimizer = Adafactor(\n            params_to_optimize,\n            lr=args.learning_rate,\n            scale_parameter=False,\n            relative_step=False,\n            
# warmup_init=True,\n            weight_decay=args.adam_weight_decay,\n        )\n    else:\n        optimizer = optimizer_class(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    vae.to(accelerator.device, dtype=weight_dtype)\n    flux_transformer.to(accelerator.device, dtype=weight_dtype)\n\n    def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline, weight_dtype, is_train=True):\n        prompt_batch = batch[args.caption_column]\n        captions = []\n        for caption in prompt_batch:\n            if random.random() < proportion_empty_prompts:\n                captions.append(\"\")\n            elif isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n        prompt_batch = captions\n        prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(\n            prompt_batch, prompt_2=prompt_batch\n        )\n        prompt_embeds = prompt_embeds.to(dtype=weight_dtype)\n        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)\n        text_ids = text_ids.to(dtype=weight_dtype)\n\n        # text_ids [512,3] to [bs,512,3]\n        text_ids = 
text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)\n        return {\"prompt_embeds\": prompt_embeds, \"pooled_prompt_embeds\": pooled_prompt_embeds, \"text_ids\": text_ids}\n\n    train_dataset = get_train_dataset(args, accelerator)\n    compute_embeddings_fn = functools.partial(\n        compute_embeddings,\n        flux_controlnet_pipeline=flux_controlnet_pipeline,\n        proportion_empty_prompts=args.proportion_empty_prompts,\n        weight_dtype=weight_dtype,\n    )\n    with accelerator.main_process_first():\n        from datasets.fingerprint import Hasher\n\n        # fingerprint used by the cache for the other processes to load the result\n        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401\n        new_fingerprint = Hasher.hash(args)\n        train_dataset = train_dataset.map(\n            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50\n        )\n\n    text_encoder_one.to(\"cpu\")\n    text_encoder_two.to(\"cpu\")\n    free_memory()\n\n    # Then get the training dataset ready to be passed to the dataloader.\n    train_dataset = prepare_train_dataset(train_dataset, accelerator)\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n   
     )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n    # Prepare everything with our `accelerator`.\n    flux_controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        flux_controlnet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. 
\"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n     
   step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(flux_controlnet):\n                # Convert images to latent space\n                # vae encode\n                pixel_values = batch[\"pixel_values\"].to(dtype=weight_dtype)\n                pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()\n                pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor\n                pixel_latents = FluxControlNetPipeline._pack_latents(\n                    pixel_latents_tmp,\n                    pixel_values.shape[0],\n                    pixel_latents_tmp.shape[1],\n                    pixel_latents_tmp.shape[2],\n                    pixel_latents_tmp.shape[3],\n                )\n\n                control_values = batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\n                control_latents = vae.encode(control_values).latent_dist.sample()\n                control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor\n                control_image = FluxControlNetPipeline._pack_latents(\n                    control_latents,\n                    control_values.shape[0],\n                    control_latents.shape[1],\n                    control_latents.shape[2],\n                    control_latents.shape[3],\n                )\n\n                latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids(\n                    batch_size=pixel_latents_tmp.shape[0],\n                    height=pixel_latents_tmp.shape[2] // 2,\n                    width=pixel_latents_tmp.shape[3] // 2,\n                 
   device=pixel_values.device,\n                    dtype=pixel_values.dtype,\n                )\n\n                bsz = pixel_latents.shape[0]\n                noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype)\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)\n\n                # Add noise according to flow matching.\n                sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)\n                noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise\n\n                # handle guidance\n                if flux_transformer.config.guidance_embeds:\n                    guidance_vec = torch.full(\n                        (noisy_model_input.shape[0],),\n                        args.guidance_scale,\n                        device=noisy_model_input.device,\n                        dtype=weight_dtype,\n                    )\n                else:\n                    guidance_vec = None\n\n                controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(\n                    hidden_states=noisy_model_input,\n                    controlnet_cond=control_image,\n                    timestep=timesteps / 1000,\n                    guidance=guidance_vec,\n                    
pooled_projections=batch[\"unet_added_conditions\"][\"pooled_prompt_embeds\"].to(dtype=weight_dtype),\n                    encoder_hidden_states=batch[\"prompt_ids\"].to(dtype=weight_dtype),\n                    txt_ids=batch[\"unet_added_conditions\"][\"time_ids\"][0].to(dtype=weight_dtype),\n                    img_ids=latent_image_ids,\n                    return_dict=False,\n                )\n\n                noise_pred = flux_transformer(\n                    hidden_states=noisy_model_input,\n                    timestep=timesteps / 1000,\n                    guidance=guidance_vec,\n                    pooled_projections=batch[\"unet_added_conditions\"][\"pooled_prompt_embeds\"].to(dtype=weight_dtype),\n                    encoder_hidden_states=batch[\"prompt_ids\"].to(dtype=weight_dtype),\n                    controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples]\n                    if controlnet_block_samples is not None\n                    else None,\n                    controlnet_single_block_samples=[\n                        sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples\n                    ]\n                    if controlnet_single_block_samples is not None\n                    else None,\n                    txt_ids=batch[\"unet_added_conditions\"][\"time_ids\"][0].to(dtype=weight_dtype),\n                    img_ids=latent_image_ids,\n                    return_dict=False,\n                )[0]\n\n                loss = F.mse_loss(noise_pred.float(), (noise - pixel_latents).float(), reduction=\"mean\")\n                accelerator.backward(loss)\n                # Check if the gradient of each model parameter contains NaN\n                for name, param in flux_controlnet.named_parameters():\n                    if param.grad is not None and torch.isnan(param.grad).any():\n                        logger.error(f\"Gradient for {name} contains NaN!\")\n\n                if 
accelerator.sync_gradients:\n                    params_to_clip = flux_controlnet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing 
checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            vae=vae,\n                            flux_transformer=flux_transformer,\n                            flux_controlnet=flux_controlnet,\n                            args=args,\n                            accelerator=accelerator,\n                            weight_dtype=weight_dtype,\n                            step=global_step,\n                        )\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n    # Create the pipeline using using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        flux_controlnet = unwrap_model(flux_controlnet)\n        save_weight_dtype = torch.float32\n        if args.save_weight_dtype == \"fp16\":\n            save_weight_dtype = torch.float16\n        elif args.save_weight_dtype == \"bf16\":\n            save_weight_dtype = torch.bfloat16\n        flux_controlnet.to(save_weight_dtype)\n        if args.save_weight_dtype != \"fp32\":\n            flux_controlnet.save_pretrained(args.output_dir, variant=args.save_weight_dtype)\n        
else:\n            flux_controlnet.save_pretrained(args.output_dir)\n        # Run a final round of validation.\n        # Setting `flux_controlnet` to None to load it automatically from `args.output_dir`.\n        image_logs = None\n        if args.validation_prompt is not None:\n            image_logs = log_validation(\n                vae=vae,\n                flux_transformer=flux_transformer,\n                flux_controlnet=None,\n                args=args,\n                accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                step=global_step,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/controlnet/train_controlnet_sd3.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport contextlib\nimport copy\nimport functools\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport shutil\n\n# Add repo root to path to import from tests\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    SD3ControlNetModel,\n    SD3Transformer2DModel,\n    StableDiffusion3ControlNetPipeline,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, 
populate_model_card\nfrom diffusers.utils.torch_utils import backend_empty_cache, is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):\n    logger.info(\"Running validation... \")\n\n    if not is_final_validation:\n        controlnet = accelerator.unwrap_model(controlnet)\n    else:\n        controlnet = SD3ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)\n\n    pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        controlnet=None,\n        safety_checker=None,\n        transformer=None,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(torch.device(accelerator.device))\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    with torch.no_grad():\n        (\n            prompt_embeds,\n            negative_prompt_embeds,\n   
         pooled_prompt_embeds,\n            negative_pooled_prompt_embeds,\n        ) = pipeline.encode_prompt(\n            validation_prompts,\n            prompt_2=None,\n            prompt_3=None,\n        )\n\n    del pipeline\n    gc.collect()\n    backend_empty_cache(accelerator.device.type)\n\n    pipeline = StableDiffusion3ControlNetPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        controlnet=controlnet,\n        safety_checker=None,\n        text_encoder=None,\n        text_encoder_2=None,\n        text_encoder_3=None,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.enable_model_cpu_offload(device=accelerator.device.type)\n    pipeline.set_progress_bar_config(disable=True)\n\n    image_logs = []\n    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type)\n\n    for i, validation_image in enumerate(validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n        validation_prompt = validation_prompts[i]\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with inference_ctx:\n                image = pipeline(\n                    prompt_embeds=prompt_embeds[i].unsqueeze(0),\n                    negative_prompt_embeds=negative_prompt_embeds[i].unsqueeze(0),\n                    pooled_prompt_embeds=pooled_prompt_embeds[i].unsqueeze(0),\n                    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds[i].unsqueeze(0),\n                    control_image=validation_image,\n                    num_inference_steps=20,\n                    generator=generator,\n                ).images[0]\n\n            images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    tracker_key = \"test\" if 
is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                tracker.writer.add_image(\n                    \"Controlnet conditioning\", np.asarray([validation_image]), step, dataformats=\"NHWC\"\n                )\n\n                formatted_images = []\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({tracker_key: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    free_memory()\n\n    if not is_final_validation:\n        controlnet.to(accelerator.device)\n\n    return image_logs\n\n\n# Copied from dreambooth sd3 example\ndef load_text_encoders(class_one, class_two, class_three):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = 
class_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_three = class_three.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_3\", revision=args.revision, variant=args.variant\n    )\n    return text_encoder_one, text_encoder_two, text_encoder_three\n\n\n# Copied from dreambooth sd3 example\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i})](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# SD3 
controlnet-{repo_id}\n\nThese are controlnet weights trained on {base_model} with new type of conditioning.\nThe weights were trained using [ControlNet](https://github.com/lllyasviel/ControlNet) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sd3.md).\n{img_str}\n\nPlease adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"openrail++\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"sd3\",\n        \"sd3-diffusers\",\n        \"controlnet\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained controlnet model or model identifier from huggingface.co/models.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--num_extra_conditioning_channels\",\n        type=int,\n        default=0,\n        help=\"Number of extra conditioning channels for controlnet.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n 
       required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"controlnet-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--upcast_vae\",\n        action=\"store_true\",\n        help=\"Whether or not to upcast vae to fp32\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and 
batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"logit_normal\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\"],\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. 
Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--precondition_outputs\",\n        type=int,\n        default=1,\n        help=\"Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how \"\n        \"model `target` is calculated.\",\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=77,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--dataset_preprocess_batch_size\", type=int, default=1000, help=\"Batch size for preprocessing dataset.\"\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"train_controlnet\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--train_data_dir`\")\n\n    if args.dataset_name is not None and args.train_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--train_data_dir`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and 
`--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder.\"\n        )\n\n    return args\n\n\ndef make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, accelerator):\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        if args.train_data_dir is not None:\n            dataset = load_dataset(\n                args.train_data_dir,\n                cache_dir=args.cache_dir,\n                trust_remote_code=True,\n            )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    def process_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if random.random() < args.proportion_empty_prompts:\n                captions.append(\"\")\n            elif isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        return captions\n\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    conditioning_image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        images = [image_transforms(image) for image in images]\n\n        conditioning_images = [image.convert(\"RGB\") for image in examples[conditioning_image_column]]\n        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]\n\n        examples[\"pixel_values\"] = images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n        examples[\"prompts\"] = process_captions(examples)\n\n        return examples\n\n    with 
accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    return train_dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    prompt_embeds = torch.stack([torch.tensor(example[\"prompt_embeds\"]) for example in examples])\n    pooled_prompt_embeds = torch.stack([torch.tensor(example[\"pooled_prompt_embeds\"]) for example in examples])\n\n    return {\n        \"pixel_values\": pixel_values,\n        \"conditioning_pixel_values\": conditioning_pixel_values,\n        \"prompt_embeds\": prompt_embeds,\n        \"pooled_prompt_embeds\": pooled_prompt_embeds,\n    }\n\n\n# Copied from dreambooth sd3 example\ndef _encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        add_special_tokens=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n\n    # duplicate 
text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\n# Copied from dreambooth sd3 example\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=77,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n\n    text_input_ids = text_inputs.input_ids\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n    pooled_prompt_embeds = prompt_embeds[0]\n    prompt_embeds = prompt_embeds.hidden_states[-2]\n    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds, pooled_prompt_embeds\n\n\n# Copied from dreambooth sd3 example\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n\n    clip_tokenizers = tokenizers[:2]\n    clip_text_encoders = text_encoders[:2]\n\n    clip_prompt_embeds_list = []\n    clip_pooled_prompt_embeds_list = []\n    for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):\n        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(\n            
text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            prompt=prompt,\n            device=device if device is not None else text_encoder.device,\n            num_images_per_prompt=num_images_per_prompt,\n        )\n        clip_prompt_embeds_list.append(prompt_embeds)\n        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)\n\n    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)\n\n    t5_prompt_embed = _encode_prompt_with_t5(\n        text_encoders[-1],\n        tokenizers[-1],\n        max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        device=device if device is not None else text_encoders[-1].device,\n    )\n\n    clip_prompt_embeds = torch.nn.functional.pad(\n        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])\n    )\n    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)\n\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the 
tokenizer\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n    tokenizer_three = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_3\",\n        revision=args.revision,\n    )\n\n    # import correct text encoder class\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n    text_encoder_cls_three = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_3\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(\n        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = SD3Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    if args.controlnet_model_name_or_path:\n        logger.info(\"Loading existing controlnet weights\")\n        controlnet = 
SD3ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)\n    else:\n        logger.info(\"Initializing controlnet weights from transformer\")\n        controlnet = SD3ControlNetModel.from_transformer(\n            transformer, num_extra_conditioning_channels=args.num_extra_conditioning_channels\n        )\n\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    text_encoder_three.requires_grad_(False)\n    controlnet.train()\n\n    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                i = len(weights) - 1\n\n                while len(weights) > 0:\n                    weights.pop()\n                    model = models[i]\n\n                    sub_dir = \"controlnet\"\n                    model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                    i -= 1\n\n        def load_model_hook(models, input_dir):\n            while len(models) > 0:\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = SD3ControlNetModel.from_pretrained(input_dir, subfolder=\"controlnet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n       
         del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        controlnet.enable_gradient_checkpointing()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(controlnet).dtype != torch.float32:\n        raise ValueError(\n            f\"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = controlnet.parameters()\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # For mixed precision 
training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae, transformer and text_encoder to device and cast to weight_dtype\n    if args.upcast_vae:\n        vae.to(accelerator.device, dtype=torch.float32)\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_three.to(accelerator.device, dtype=weight_dtype)\n\n    train_dataset = make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, accelerator)\n\n    tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]\n    text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]\n\n    def compute_text_embeddings(batch, text_encoders, tokenizers):\n        with torch.no_grad():\n            prompt = batch[\"prompts\"]\n            prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                text_encoders, tokenizers, prompt, args.max_sequence_length\n            )\n            prompt_embeds = prompt_embeds.to(accelerator.device)\n            pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n        return {\"prompt_embeds\": prompt_embeds, \"pooled_prompt_embeds\": pooled_prompt_embeds}\n\n    compute_embeddings_fn = functools.partial(\n        compute_text_embeddings,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n    )\n    with accelerator.main_process_first():\n        from datasets.fingerprint import Hasher\n\n        # fingerprint used by the cache for the other processes to load 
the result\n        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401\n        new_fingerprint = Hasher.hash(args)\n        train_dataset = train_dataset.map(\n            compute_embeddings_fn,\n            batched=True,\n            batch_size=args.dataset_preprocess_batch_size,\n            new_fingerprint=new_fingerprint,\n        )\n\n    del text_encoder_one, text_encoder_two, text_encoder_three\n    del tokenizer_one, tokenizer_two, tokenizer_three\n    free_memory()\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        controlnet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we 
recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n     
   step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(controlnet):\n                # Convert images to latent space\n                pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                # Get the text embedding for conditioning\n                prompt_embeds = 
batch[\"prompt_embeds\"].to(dtype=weight_dtype)\n                pooled_prompt_embeds = batch[\"pooled_prompt_embeds\"].to(dtype=weight_dtype)\n\n                # controlnet(s) inference\n                controlnet_image = batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\n                controlnet_image = vae.encode(controlnet_image).latent_dist.sample()\n                controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor\n\n                control_block_res_samples = controlnet(\n                    hidden_states=noisy_model_input,\n                    timestep=timesteps,\n                    encoder_hidden_states=prompt_embeds,\n                    pooled_projections=pooled_prompt_embeds,\n                    controlnet_cond=controlnet_image,\n                    return_dict=False,\n                )[0]\n                control_block_res_samples = [sample.to(dtype=weight_dtype) for sample in control_block_res_samples]\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=noisy_model_input,\n                    timestep=timesteps,\n                    encoder_hidden_states=prompt_embeds,\n                    pooled_projections=pooled_prompt_embeds,\n                    block_controlnet_hidden_states=control_block_res_samples,\n                    return_dict=False,\n                )[0]\n\n                # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                # Preconditioning of the model outputs.\n                if args.precondition_outputs:\n                    model_pred = model_pred * (-sigmas) + noisy_model_input\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                
if args.precondition_outputs:\n                    target = model_input\n                else:\n                    target = noise - model_input\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = controlnet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = 
checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            controlnet,\n                            args,\n                            accelerator,\n                            weight_dtype,\n                            global_step,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        controlnet = unwrap_model(controlnet)\n        controlnet.save_pretrained(args.output_dir)\n\n        # Run a final round of validation.\n        image_logs = None\n        if args.validation_prompt is not None:\n            image_logs = log_validation(\n                controlnet=None,\n                args=args,\n             
   accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                step=global_step,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/controlnet/train_controlnet_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport functools\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    ControlNetModel,\n    DDPMScheduler,\n    StableDiffusionXLControlNetPipeline,\n    UNet2DConditionModel,\n    UniPCMultistepScheduler,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available\nfrom diffusers.utils.torch_utils import 
is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\n\ndef log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):\n    logger.info(\"Running validation... \")\n\n    if not is_final_validation:\n        controlnet = accelerator.unwrap_model(controlnet)\n        pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=vae,\n            unet=unet,\n            controlnet=controlnet,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n    else:\n        controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)\n        if args.pretrained_vae_model_name_or_path is not None:\n            vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)\n        else:\n            vae = AutoencoderKL.from_pretrained(\n                args.pretrained_model_name_or_path, subfolder=\"vae\", torch_dtype=weight_dtype\n            )\n\n        pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=vae,\n            controlnet=controlnet,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n\n    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is 
None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n    if is_final_validation or torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n\n        try:\n            interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())\n        except (AttributeError, KeyError):\n            supported_interpolation_modes = [\n                f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n            ]\n            raise ValueError(\n                f\"Interpolation mode {args.image_interpolation_mode} is not supported. 
\"\n                f\"Please select one of the following: {', '.join(supported_interpolation_modes)}\"\n            )\n\n        transform = transforms.Compose(\n            [\n                transforms.Resize(args.resolution, interpolation=interpolation),\n                transforms.CenterCrop(args.resolution),\n            ]\n        )\n        validation_image = transform(validation_image)\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with autocast_ctx:\n                image = pipeline(\n                    prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator\n                ).images[0]\n            images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = [np.asarray(validation_image)]\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n\n              
  for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({tracker_key: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i})](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# controlnet-{repo_id}\n\nThese are controlnet weights trained on 
{base_model} with new type of conditioning.\n{img_str}\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"openrail++\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"controlnet\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained controlnet model or model identifier from huggingface.co/models.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"controlnet-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_h\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_w\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        
help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. \"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. \"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step \"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--enable_npu_flash_attention\", action=\"store_true\", help=\"Whether or not to use npu flash attention.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. 
Provide either a matching number of `--validation_prompt`s, a\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"sd_xl_train_controlnet\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--train_data_dir`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise 
ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and `--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder.\"\n        )\n\n    return args\n\n\ndef get_train_dataset(args, accelerator):\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        if args.train_data_dir is not None:\n            dataset = load_dataset(\n                args.train_data_dir,\n                
cache_dir=args.cache_dir,\n            )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    with accelerator.main_process_first():\n        train_dataset = dataset[\"train\"].shuffle(seed=args.seed)\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.select(range(args.max_train_samples))\n    return train_dataset\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):\n    prompt_embeds_list = []\n\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n            )\n\n            # We are only ALWAYS interested in the pooled output of the final text encoder\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = 
pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef prepare_train_dataset(dataset, accelerator):\n    try:\n        interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())\n    except (AttributeError, KeyError):\n        supported_interpolation_modes = [\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ]\n        raise ValueError(\n            f\"Interpolation mode {args.image_interpolation_mode} is not supported. \"\n            f\"Please select one of the following: {', '.join(supported_interpolation_modes)}\"\n        )\n\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=interpolation_mode),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    conditioning_image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=interpolation_mode),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[args.image_column]]\n        images = [image_transforms(image) for image in images]\n\n        conditioning_images = [image.convert(\"RGB\") for image in examples[args.conditioning_image_column]]\n        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]\n\n        examples[\"pixel_values\"] = images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n\n        return examples\n\n    with accelerator.main_process_first():\n        dataset = dataset.with_transform(preprocess_train)\n\n    return dataset\n\n\ndef collate_fn(examples):\n    pixel_values = 
torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    prompt_ids = torch.stack([torch.tensor(example[\"prompt_embeds\"]) for example in examples])\n\n    add_text_embeds = torch.stack([torch.tensor(example[\"text_embeds\"]) for example in examples])\n    add_time_ids = torch.stack([torch.tensor(example[\"time_ids\"]) for example in examples])\n\n    return {\n        \"pixel_values\": pixel_values,\n        \"conditioning_pixel_values\": conditioning_pixel_values,\n        \"prompt_ids\": prompt_ids,\n        \"unet_added_conditions\": {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids},\n    }\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        
revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    if args.controlnet_model_name_or_path:\n        logger.info(\"Loading existing controlnet weights\")\n        controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)\n    else:\n        logger.info(\"Initializing controlnet weights from unet\")\n        controlnet = ControlNetModel.from_unet(unet)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        
return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                i = len(weights) - 1\n\n                while len(weights) > 0:\n                    weights.pop()\n                    model = models[i]\n\n                    sub_dir = \"controlnet\"\n                    model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                    i -= 1\n\n        def load_model_hook(models, input_dir):\n            while len(models) > 0:\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = ControlNetModel.from_pretrained(input_dir, subfolder=\"controlnet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    controlnet.train()\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            unet.enable_npu_flash_attention()\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu devices.\")\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = 
version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            controlnet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        controlnet.enable_gradient_checkpointing()\n        unet.enable_gradient_checkpointing()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(controlnet).dtype != torch.float32:\n        raise ValueError(\n            f\"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. 
{low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = controlnet.parameters()\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae, unet and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    if args.pretrained_vae_model_name_or_path is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    else:\n        vae.to(accelerator.device, dtype=torch.float32)\n    unet.to(accelerator.device, 
dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # Here, we compute not just the text embeddings but also the additional embeddings\n    # needed for the SD XL UNet to operate.\n    def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):\n        original_size = (args.resolution, args.resolution)\n        target_size = (args.resolution, args.resolution)\n        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)\n        prompt_batch = batch[args.caption_column]\n\n        prompt_embeds, pooled_prompt_embeds = encode_prompt(\n            prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train\n        )\n        add_text_embeds = pooled_prompt_embeds\n\n        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n\n        prompt_embeds = prompt_embeds.to(accelerator.device)\n        add_text_embeds = add_text_embeds.to(accelerator.device)\n        add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)\n        unet_added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n        return {\"prompt_embeds\": prompt_embeds, **unet_added_cond_kwargs}\n\n    # Let's first compute all the embeddings so that we can free up the text encoders\n    # from memory.\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n    train_dataset = get_train_dataset(args, accelerator)\n    compute_embeddings_fn = functools.partial(\n        compute_embeddings,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n        
proportion_empty_prompts=args.proportion_empty_prompts,\n    )\n    with accelerator.main_process_first():\n        from datasets.fingerprint import Hasher\n\n        # fingerprint used by the cache for the other processes to load the result\n        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401\n        new_fingerprint = Hasher.hash(args)\n        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)\n\n    del text_encoders, tokenizers\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    # Then get the training dataset ready to be passed to the dataloader.\n    train_dataset = prepare_train_dataset(train_dataset, accelerator)\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        
power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        controlnet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = 
{len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(controlnet):\n                # Convert images to latent space\n                if args.pretrained_vae_model_name_or_path is not None:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=weight_dtype)\n                else:\n                    pixel_values = batch[\"pixel_values\"]\n                latents = vae.encode(pixel_values).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                
# (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(\n                    dtype=weight_dtype\n                )\n\n                # ControlNet conditioning.\n                controlnet_image = batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=batch[\"prompt_ids\"],\n                    added_cond_kwargs=batch[\"unet_added_conditions\"],\n                    controlnet_cond=controlnet_image,\n                    return_dict=False,\n                )\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=batch[\"prompt_ids\"],\n                    added_cond_kwargs=batch[\"unet_added_conditions\"],\n                    down_block_additional_residuals=[\n                        sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                    ],\n                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n                    return_dict=False,\n                )[0]\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n       
         if accelerator.sync_gradients:\n                    params_to_clip = controlnet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing 
checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            vae=vae,\n                            unet=unet,\n                            controlnet=controlnet,\n                            args=args,\n                            accelerator=accelerator,\n                            weight_dtype=weight_dtype,\n                            step=global_step,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        controlnet = unwrap_model(controlnet)\n        controlnet.save_pretrained(args.output_dir)\n\n        # Run a final round of validation.\n        # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.\n        image_logs = None\n        if args.validation_prompt is not None:\n            image_logs = log_validation(\n                vae=None,\n                unet=None,\n                controlnet=None,\n                args=args,\n                
accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                step=global_step,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/custom_diffusion/README.md",
    "content": "# Custom Diffusion training example\n\n[Custom Diffusion](https://huggingface.co/papers/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.\nThe `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the example folder and run\n\n```bash\npip install -r requirements.txt\npip install clip-retrieval\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n### Cat example 😺\n\nNow let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.\n\nWe also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. 
The following flags enable the regularization: `with_prior_preservation` and `real_prior`, with `prior_loss_weight=1.`.\nThe `class_prompt` should be the same category name as the target image. The collected real images have text captions similar to the `class_prompt`. The retrieved images are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images, run this command before training.\n\n```bash\npip install clip-retrieval\npython retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200\n```\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\nexport INSTANCE_DIR=\"./data/cat\"\n\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --class_data_dir=./real_reg/samples_cat/ \\\n  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n  --class_prompt=\"cat\" --num_class_images=200 \\\n  --instance_prompt=\"photo of a <new1> cat\"  \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=1e-5  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=250 \\\n  --scale_lr --hflip  \\\n  --modifier_token \"<new1>\"\n```\n\n**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). 
Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**\n\nTo track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps:\n\n* Install `wandb`: `pip install wandb`.\n* Authorize: `wandb login`.\n* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:\n    * `num_validation_images`\n    * `validation_steps`\n\nHere is an example command:\n\n```bash\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --class_data_dir=./real_reg/samples_cat/ \\\n  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n  --class_prompt=\"cat\" --num_class_images=200 \\\n  --instance_prompt=\"photo of a <new1> cat\"  \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=1e-5  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=250 \\\n  --scale_lr --hflip  \\\n  --modifier_token \"<new1>\" \\\n  --validation_prompt=\"<new1> cat sitting in a bucket\" \\\n  --report_to=\"wandb\"\n```\n\nHere is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.\n\nIf you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. 
Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).\n\n### Training on multiple concepts 🐱🪵\n\nProvide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).\n\nTo collect the real images run this command for each concept in the json file.\n\n```bash\npip install clip-retrieval\npython retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200\n```\n\nAnd then we're ready to start training!\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --output_dir=$OUTPUT_DIR \\\n  --concepts_list=./concept_list.json \\\n  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=1e-5  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --num_class_images=200 \\\n  --scale_lr --hflip  \\\n  --modifier_token \"<new1>+<new2>\"\n```\n\nHere is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.\n\n### Training on human faces\n\nFor fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images.\n\nTo collect the real images use this command first before training.\n\n```bash\npip install clip-retrieval\npython retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200\n```\n\nThen start training!\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\nexport 
INSTANCE_DIR=\"path-to-images\"\n\naccelerate launch train_custom_diffusion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --class_data_dir=./real_reg/samples_person/ \\\n  --with_prior_preservation --real_prior --prior_loss_weight=1.0 \\\n  --class_prompt=\"person\" --num_class_images=200 \\\n  --instance_prompt=\"photo of a <new1> person\"  \\\n  --resolution=512  \\\n  --train_batch_size=2  \\\n  --learning_rate=5e-6  \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=1000 \\\n  --scale_lr --hflip --noaug \\\n  --freeze_model crossattn \\\n  --modifier_token \"<new1>\" \\\n  --enable_xformers_memory_efficient_attention\n```\n\n## Inference\n\nOnce you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \\<new1\\> in above example) in your prompt.\n\n```python\nimport torch\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\n    \"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16\n).to(\"cuda\")\npipe.unet.load_attn_procs(\n    \"path-to-save-model\", weight_name=\"pytorch_custom_diffusion_weights.bin\"\n)\npipe.load_textual_inversion(\"path-to-save-model\", weight_name=\"<new1>.bin\")\n\nimage = pipe(\n    \"<new1> cat sitting in a bucket\",\n    num_inference_steps=100,\n    guidance_scale=6.0,\n    eta=1.0,\n).images[0]\nimage.save(\"cat.png\")\n```\n\nIt's possible to directly load these parameters from a Hub repository:\n\n```python\nimport torch\nfrom huggingface_hub.repocard import RepoCard\nfrom diffusers import DiffusionPipeline\n\nmodel_id = \"sayakpaul/custom-diffusion-cat\"\ncard = RepoCard.load(model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\npipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(\n\"cuda\")\npipe.unet.load_attn_procs(model_id, 
weight_name=\"pytorch_custom_diffusion_weights.bin\")\npipe.load_textual_inversion(model_id, weight_name=\"<new1>.bin\")\n\nimage = pipe(\n    \"<new1> cat sitting in a bucket\",\n    num_inference_steps=100,\n    guidance_scale=6.0,\n    eta=1.0,\n).images[0]\nimage.save(\"cat.png\")\n```\n\nHere is an example of performing inference with multiple concepts:\n\n```python\nimport torch\nfrom huggingface_hub.repocard import RepoCard\nfrom diffusers import DiffusionPipeline\n\nmodel_id = \"sayakpaul/custom-diffusion-cat-wooden-pot\"\ncard = RepoCard.load(model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\npipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(\n\"cuda\")\npipe.unet.load_attn_procs(model_id, weight_name=\"pytorch_custom_diffusion_weights.bin\")\npipe.load_textual_inversion(model_id, weight_name=\"<new1>.bin\")\npipe.load_textual_inversion(model_id, weight_name=\"<new2>.bin\")\n\nimage = pipe(\n    \"the <new1> cat sculpture in the style of a <new2> wooden pot\",\n    num_inference_steps=100,\n    guidance_scale=6.0,\n    eta=1.0,\n).images[0]\nimage.save(\"multi-subject.png\")\n```\n\nHere, `cat` and `wooden pot` refer to the multiple concepts.\n\n### Inference from a training checkpoint\n\nYou can also perform inference from one of the complete checkpoints saved during the training process, if you used the `--checkpointing_steps` argument.\n\nTODO.\n\n## Set grads to none\nTo save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.\n\nMore info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\n\n## Experimental results\nYou can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. 
We also released a more extensive dataset of 101 concepts for evaluating model customization methods. For more details, please refer to our [dataset webpage](https://www.cs.cmu.edu/~custom-diffusion/dataset.html)."
  },
  {
    "path": "examples/custom_diffusion/requirements.txt",
    "content": "accelerate\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/custom_diffusion/retrieve.py",
    "content": "#  Copyright 2025 Custom Diffusion authors. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nimport os\nfrom io import BytesIO\nfrom pathlib import Path\n\nimport requests\nfrom clip_retrieval.clip_client import ClipClient\nfrom PIL import Image\nfrom tqdm import tqdm\n\n\ndef retrieve(class_prompt, class_data_dir, num_class_images):\n    factor = 1.5\n    num_images = int(factor * num_class_images)\n    client = ClipClient(\n        url=\"https://knn.laion.ai/knn-service\", indice_name=\"laion_400m\", num_images=num_images, aesthetic_weight=0.1\n    )\n\n    os.makedirs(f\"{class_data_dir}/images\", exist_ok=True)\n    if len(list(Path(f\"{class_data_dir}/images\").iterdir())) >= num_class_images:\n        return\n\n    while True:\n        class_images = client.query(text=class_prompt)\n        if len(class_images) >= factor * num_class_images or num_images > 1e4:\n            break\n        else:\n            num_images = int(factor * num_images)\n            client = ClipClient(\n                url=\"https://knn.laion.ai/knn-service\",\n                indice_name=\"laion_400m\",\n                num_images=num_images,\n                aesthetic_weight=0.1,\n            )\n\n    count = 0\n    total = 0\n    pbar = tqdm(desc=\"downloading real regularization images\", total=num_class_images)\n\n    with (\n        open(f\"{class_data_dir}/caption.txt\", \"w\") as f1,\n        
open(f\"{class_data_dir}/urls.txt\", \"w\") as f2,\n        open(f\"{class_data_dir}/images.txt\", \"w\") as f3,\n    ):\n        while total < num_class_images:\n            images = class_images[count]\n            count += 1\n            try:\n                img = requests.get(images[\"url\"], timeout=30)\n                if img.status_code == 200:\n                    _ = Image.open(BytesIO(img.content))\n                    with open(f\"{class_data_dir}/images/{total}.jpg\", \"wb\") as f:\n                        f.write(img.content)\n                    f1.write(images[\"caption\"] + \"\\n\")\n                    f2.write(images[\"url\"] + \"\\n\")\n                    f3.write(f\"{class_data_dir}/images/{total}.jpg\" + \"\\n\")\n                    total += 1\n                    pbar.update(1)\n                else:\n                    continue\n            except Exception:\n                continue\n    return\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\"\", add_help=False)\n    parser.add_argument(\"--class_prompt\", help=\"text prompt to retrieve images\", required=True, type=str)\n    parser.add_argument(\"--class_data_dir\", help=\"path to save images\", required=True, type=str)\n    parser.add_argument(\"--num_class_images\", help=\"number of images to download\", default=200, type=int)\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    retrieve(args.class_prompt, args.class_data_dir, args.num_class_images)\n"
  },
  {
    "path": "examples/custom_diffusion/test_custom_diffusion.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\nimport unittest\n\nfrom diffusers.utils import is_transformers_version\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\n@unittest.skipIf(is_transformers_version(\">=\", \"4.57.5\"), \"Size mismatch\")\nclass CustomDiffusion(ExamplesTestsAccelerate):\n    def test_custom_diffusion(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/custom_diffusion/train_custom_diffusion.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt <new1>\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 1.0e-05\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --modifier_token <new1>\n                --no_safe_serialization\n                --output_dir {tmpdir}\n             
   \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_custom_diffusion_weights.bin\")))\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"<new1>.bin\")))\n\n    def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/custom_diffusion/train_custom_diffusion.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=<new1>\n            --resolution=64\n            --train_batch_size=1\n            --modifier_token=<new1>\n            --dataloader_num_workers=0\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            --no_safe_serialization\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-4\", \"checkpoint-6\"})\n\n    def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/custom_diffusion/train_custom_diffusion.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=<new1>\n            --resolution=64\n            --train_batch_size=1\n            --modifier_token=<new1>\n            --dataloader_num_workers=0\n            --max_train_steps=4\n            --checkpointing_steps=2\n            --no_safe_serialization\n            
\"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            resume_run_args = f\"\"\"\n            examples/custom_diffusion/train_custom_diffusion.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=<new1>\n            --resolution=64\n            --train_batch_size=1\n            --modifier_token=<new1>\n            --dataloader_num_workers=0\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            --no_safe_serialization\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n"
  },
  {
    "path": "examples/custom_diffusion/train_custom_diffusion.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 Custom Diffusion authors and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport safetensors\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import HfApi, create_repo\nfrom huggingface_hub.utils import insecure_hashlib\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import AttnProcsLayers\nfrom diffusers.models.attention_processor import (\n    CustomDiffusionAttnProcessor,\n    CustomDiffusionAttnProcessor2_0,\n    CustomDiffusionXFormersAttnProcessor,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom 
diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef freeze_params(params):\n    for param in params:\n        param.requires_grad = False\n\n\ndef save_model_card(repo_id: str, images=None, base_model: str = None, prompt: str = None, repo_folder=None):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# Custom Diffusion - {repo_id}\n\nThese are Custom Diffusion adaptation weights for {base_model}. The weights were trained on {prompt} using [Custom Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images below. \\n\n{img_str}\n\n\\nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        prompt=prompt,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"text-to-image\",\n        \"diffusers\",\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n        \"custom-diffusion\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        
subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef collate_fn(examples, with_prior_preservation):\n    input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    mask = [example[\"mask\"] for example in examples]\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        input_ids += [example[\"class_prompt_ids\"] for example in examples]\n        pixel_values += [example[\"class_images\"] for example in examples]\n        mask += [example[\"class_mask\"] for example in examples]\n\n    input_ids = torch.cat(input_ids, dim=0)\n    pixel_values = torch.stack(pixel_values)\n    mask = torch.stack(mask)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    mask = mask.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"input_ids\": input_ids, \"pixel_values\": pixel_values, \"mask\": mask.unsqueeze(1)}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = 
self.prompt\n        example[\"index\"] = index\n        return example\n\n\nclass CustomDiffusionDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and the tokenizes prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        concepts_list,\n        tokenizer,\n        size=512,\n        mask_size=64,\n        center_crop=False,\n        with_prior_preservation=False,\n        num_class_images=200,\n        hflip=False,\n        aug=True,\n    ):\n        self.size = size\n        self.mask_size = mask_size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n        self.interpolation = Image.BILINEAR\n        self.aug = aug\n\n        self.instance_images_path = []\n        self.class_images_path = []\n        self.with_prior_preservation = with_prior_preservation\n        for concept in concepts_list:\n            inst_img_path = [\n                (x, concept[\"instance_prompt\"]) for x in Path(concept[\"instance_data_dir\"]).iterdir() if x.is_file()\n            ]\n            self.instance_images_path.extend(inst_img_path)\n\n            if with_prior_preservation:\n                class_data_root = Path(concept[\"class_data_dir\"])\n                if os.path.isdir(class_data_root):\n                    class_images_path = list(class_data_root.iterdir())\n                    class_prompt = [concept[\"class_prompt\"] for _ in range(len(class_images_path))]\n                else:\n                    with open(class_data_root, \"r\") as f:\n                        class_images_path = f.read().splitlines()\n                    with open(concept[\"class_prompt\"], \"r\") as f:\n                        class_prompt = f.read().splitlines()\n\n                class_img_path = list(zip(class_images_path, class_prompt))\n                self.class_images_path.extend(class_img_path[:num_class_images])\n\n        
random.shuffle(self.instance_images_path)\n        self.num_instance_images = len(self.instance_images_path)\n        self.num_class_images = len(self.class_images_path)\n        self._length = max(self.num_class_images, self.num_instance_images)\n        self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)\n\n        self.image_transforms = transforms.Compose(\n            [\n                self.flip,\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def preprocess(self, image, scale, resample):\n        outer, inner = self.size, scale\n        factor = self.size // self.mask_size\n        if scale > self.size:\n            outer, inner = scale, self.size\n        top, left = np.random.randint(0, outer - inner + 1), np.random.randint(0, outer - inner + 1)\n        image = image.resize((scale, scale), resample=resample)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n        instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)\n        mask = np.zeros((self.size // factor, self.size // factor))\n        if scale > self.size:\n            instance_image = image[top : top + inner, left : left + inner, :]\n            mask = np.ones((self.size // factor, self.size // factor))\n        else:\n            instance_image[top : top + inner, left : left + inner, :] = image\n            mask[\n                top // factor + 1 : (top + scale) // factor - 1, left // factor + 1 : (left + scale) // factor - 1\n            ] = 1.0\n        return instance_image, mask\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image, instance_prompt = 
self.instance_images_path[index % self.num_instance_images]\n        instance_image = Image.open(instance_image)\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        instance_image = self.flip(instance_image)\n\n        # apply resize augmentation and create a valid image region mask\n        random_scale = self.size\n        if self.aug:\n            random_scale = (\n                np.random.randint(self.size // 3, self.size + 1)\n                if np.random.uniform() < 0.66\n                else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))\n            )\n        instance_image, mask = self.preprocess(instance_image, random_scale, self.interpolation)\n\n        if random_scale < 0.6 * self.size:\n            instance_prompt = np.random.choice([\"a far away \", \"very small \"]) + instance_prompt\n        elif random_scale > self.size:\n            instance_prompt = np.random.choice([\"zoomed in \", \"close up \"]) + instance_prompt\n\n        example[\"instance_images\"] = torch.from_numpy(instance_image).permute(2, 0, 1)\n        example[\"mask\"] = torch.from_numpy(mask)\n        example[\"instance_prompt_ids\"] = self.tokenizer(\n            instance_prompt,\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        if self.with_prior_preservation:\n            class_image, class_prompt = self.class_images_path[index % self.num_class_images]\n            class_image = Image.open(class_image)\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_mask\"] = torch.ones_like(example[\"mask\"])\n            example[\"class_prompt_ids\"] = self.tokenizer(\n                class_prompt,\n         
       truncation=True,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                return_tensors=\"pt\",\n            ).input_ids\n\n        return example\n\n\ndef save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):\n    \"\"\"Saves the new token embeddings from the text encoder.\"\"\"\n    logger.info(\"Saving embeddings\")\n    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight\n    for x, y in zip(modifier_token_id, args.modifier_token):\n        learned_embeds_dict = {}\n        learned_embeds_dict[y] = learned_embeds[x]\n\n        if safe_serialization:\n            filename = f\"{output_dir}/{y}.safetensors\"\n            safetensors.torch.save_file(learned_embeds_dict, filename, metadata={\"format\": \"pt\"})\n        else:\n            filename = f\"{output_dir}/{y}.bin\"\n            torch.save(learned_embeds_dict, filename)\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Custom Diffusion training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=2,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\n        \"--real_prior\",\n        default=False,\n        action=\"store_true\",\n        help=\"real images as prior.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=200,\n        help=(\n            \"Minimal class images for prior preservation loss. If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"custom-diffusion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=42, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=250,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-5,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=2,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--freeze_model\",\n        type=str,\n        default=\"crossattn_kv\",\n        choices=[\"crossattn_kv\", \"crossattn\"],\n        help=\"Set to crossattn to enable fine-tuning of all params in the cross attention\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  
Default to fp16 if a GPU is available, else fp32.\"\n        ),\n    )\n    parser.add_argument(\n        \"--concepts_list\",\n        type=str,\n        default=None,\n        help=\"Path to a JSON file containing multiple concepts. Overrides parameters like instance_prompt, class_prompt, etc.\",\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--modifier_token\",\n        type=str,\n        default=None,\n        help=\"A token to use as a modifier for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=\"ktn+pll+ucd\", help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--hflip\", action=\"store_true\", help=\"Apply horizontal flip data augmentation.\")\n    parser.add_argument(\n        \"--noaug\",\n        action=\"store_true\",\n        help=\"Do not apply data augmentation when this flag is enabled.\",\n    )\n    parser.add_argument(\n        \"--no_safe_serialization\",\n        action=\"store_true\",\n        help=\"If specified, save the checkpoint in the original PyTorch format instead of the `safetensors` format.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = 
int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.concepts_list is None:\n            if args.class_data_dir is None:\n                raise ValueError(\"You must specify a data directory for class images.\")\n            if args.class_prompt is None:\n                raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Currently, it's not possible to do gradient accumulation when training two 
models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"custom-diffusion\", config=vars(args))\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n    if args.concepts_list is None:\n        args.concepts_list = [\n            {\n                \"instance_prompt\": args.instance_prompt,\n                \"class_prompt\": args.class_prompt,\n                \"instance_data_dir\": args.instance_data_dir,\n                \"class_data_dir\": args.class_data_dir,\n            }\n        ]\n    else:\n        with open(args.concepts_list, \"r\") as f:\n            args.concepts_list = json.load(f)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        for i, concept in enumerate(args.concepts_list):\n            class_images_dir = Path(concept[\"class_data_dir\"])\n            if not class_images_dir.exists():\n                
class_images_dir.mkdir(parents=True, exist_ok=True)\n            if args.real_prior:\n                assert (class_images_dir / \"images\").exists(), (\n                    f'Please run: python retrieve.py --class_prompt \"{concept[\"class_prompt\"]}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'\n                )\n                assert len(list((class_images_dir / \"images\").iterdir())) == args.num_class_images, (\n                    f'Please run: python retrieve.py --class_prompt \"{concept[\"class_prompt\"]}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'\n                )\n                assert (class_images_dir / \"caption.txt\").exists(), (\n                    f'Please run: python retrieve.py --class_prompt \"{concept[\"class_prompt\"]}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'\n                )\n                assert (class_images_dir / \"images.txt\").exists(), (\n                    f'Please run: python retrieve.py --class_prompt \"{concept[\"class_prompt\"]}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}'\n                )\n                concept[\"class_prompt\"] = os.path.join(class_images_dir, \"caption.txt\")\n                concept[\"class_data_dir\"] = os.path.join(class_images_dir, \"images.txt\")\n                args.concepts_list[i] = concept\n                accelerator.wait_for_everyone()\n            else:\n                cur_class_images = len(list(class_images_dir.iterdir()))\n\n                if cur_class_images < args.num_class_images:\n                    torch_dtype = torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n                    if args.prior_generation_precision == \"fp32\":\n                        torch_dtype = torch.float32\n                    elif args.prior_generation_precision == \"fp16\":\n                        torch_dtype = torch.float16\n 
                   elif args.prior_generation_precision == \"bf16\":\n                        torch_dtype = torch.bfloat16\n                    pipeline = DiffusionPipeline.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        torch_dtype=torch_dtype,\n                        safety_checker=None,\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                    pipeline.set_progress_bar_config(disable=True)\n\n                    num_new_images = args.num_class_images - cur_class_images\n                    logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n                    sample_dataset = PromptDataset(concept[\"class_prompt\"], num_new_images)\n                    sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n                    sample_dataloader = accelerator.prepare(sample_dataloader)\n                    pipeline.to(accelerator.device)\n\n                    for example in tqdm(\n                        sample_dataloader,\n                        desc=\"Generating class images\",\n                        disable=not accelerator.is_local_main_process,\n                    ):\n                        images = pipeline(example[\"prompt\"]).images\n\n                        for i, image in enumerate(images):\n                            hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                            image_filename = (\n                                class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                            )\n                            image.save(image_filename)\n\n                    del pipeline\n                    if torch.cuda.is_available():\n                        torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        
if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.tokenizer_name,\n            revision=args.revision,\n            use_fast=False,\n        )\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # Adding a modifier token which is optimized ####\n    # Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py\n    modifier_token_id = []\n    initializer_token_id = []\n    if args.modifier_token is not None:\n        args.modifier_token = args.modifier_token.split(\"+\")\n        args.initializer_token = 
args.initializer_token.split(\"+\")\n        if len(args.modifier_token) > len(args.initializer_token):\n            raise ValueError(\"You must specify + separated initializer token for each modifier token.\")\n        for modifier_token, initializer_token in zip(\n            args.modifier_token, args.initializer_token[: len(args.modifier_token)]\n        ):\n            # Add the placeholder token in tokenizer\n            num_added_tokens = tokenizer.add_tokens(modifier_token)\n            if num_added_tokens == 0:\n                raise ValueError(\n                    f\"The tokenizer already contains the token {modifier_token}. Please pass a different\"\n                    \" `modifier_token` that is not already in the tokenizer.\"\n                )\n\n            # Convert the initializer_token, placeholder_token to ids\n            token_ids = tokenizer.encode([initializer_token], add_special_tokens=False)\n            print(token_ids)\n            # Check if initializer_token is a single token or a sequence of tokens\n            if len(token_ids) > 1:\n                raise ValueError(\"The initializer token must be a single token.\")\n\n            initializer_token_id.append(token_ids[0])\n            modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token))\n\n        # Resize the token embeddings as we are adding new special tokens to the tokenizer\n        text_encoder.resize_token_embeddings(len(tokenizer))\n\n        # Initialise the newly added placeholder token with the embeddings of the initializer token\n        token_embeds = text_encoder.get_input_embeddings().weight.data\n        for x, y in zip(modifier_token_id, initializer_token_id):\n            token_embeds[x] = token_embeds[y]\n\n        # Freeze all parameters except for the token embeddings in text encoder\n        params_to_freeze = itertools.chain(\n            text_encoder.text_model.encoder.parameters(),\n            
text_encoder.text_model.final_layer_norm.parameters(),\n            text_encoder.text_model.embeddings.position_embedding.parameters(),\n        )\n        freeze_params(params_to_freeze)\n    ########################################################\n    ########################################################\n\n    vae.requires_grad_(False)\n    if args.modifier_token is None:\n        text_encoder.requires_grad_(False)\n    unet.requires_grad_(False)\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    if accelerator.mixed_precision != \"fp16\" and args.modifier_token is not None:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    attention_class = (\n        CustomDiffusionAttnProcessor2_0 if hasattr(F, \"scaled_dot_product_attention\") else CustomDiffusionAttnProcessor\n    )\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. 
See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            attention_class = CustomDiffusionXFormersAttnProcessor\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Now we will add new Custom Diffusion weights to the attention layers\n    # It's important to realize here how many attention weights will be added and of which sizes\n    # The sizes of the attention layers consist only of two different variables:\n    # 1) - the \"hidden_size\", which is increased according to `unet.config.block_out_channels`.\n    # 2) - the \"cross attention size\", which is set to `unet.config.cross_attention_dim`.\n\n    # Let's first see how many attention processors we will have to set.\n    # For Stable Diffusion, it should be equal to:\n    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12\n    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2\n    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18\n    # => 32 layers\n\n    # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer\n    train_kv = True\n    train_q_out = False if args.freeze_model == \"crossattn_kv\" else True\n    custom_diffusion_attn_procs = {}\n\n    st = unet.state_dict()\n    for name, _ in unet.attn_processors.items():\n        cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            hidden_size = unet.config.block_out_channels[-1]\n        elif name.startswith(\"up_blocks\"):\n            block_id = int(name[len(\"up_blocks.\")])\n            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n        elif name.startswith(\"down_blocks\"):\n            block_id = 
int(name[len(\"down_blocks.\")])\n            hidden_size = unet.config.block_out_channels[block_id]\n        layer_name = name.split(\".processor\")[0]\n        weights = {\n            \"to_k_custom_diffusion.weight\": st[layer_name + \".to_k.weight\"],\n            \"to_v_custom_diffusion.weight\": st[layer_name + \".to_v.weight\"],\n        }\n        if train_q_out:\n            weights[\"to_q_custom_diffusion.weight\"] = st[layer_name + \".to_q.weight\"]\n            weights[\"to_out_custom_diffusion.0.weight\"] = st[layer_name + \".to_out.0.weight\"]\n            weights[\"to_out_custom_diffusion.0.bias\"] = st[layer_name + \".to_out.0.bias\"]\n        if cross_attention_dim is not None:\n            custom_diffusion_attn_procs[name] = attention_class(\n                train_kv=train_kv,\n                train_q_out=train_q_out,\n                hidden_size=hidden_size,\n                cross_attention_dim=cross_attention_dim,\n            ).to(unet.device)\n            custom_diffusion_attn_procs[name].load_state_dict(weights)\n        else:\n            custom_diffusion_attn_procs[name] = attention_class(\n                train_kv=False,\n                train_q_out=False,\n                hidden_size=hidden_size,\n                cross_attention_dim=cross_attention_dim,\n            )\n    del st\n    unet.set_attn_processor(custom_diffusion_attn_procs)\n    custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)\n\n    accelerator.register_for_checkpointing(custom_diffusion_layers)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.modifier_token is not None:\n            text_encoder.gradient_checkpointing_enable()\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = 
(\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n        if args.with_prior_preservation:\n            args.learning_rate = args.learning_rate * 2.0\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    optimizer = optimizer_class(\n        itertools.chain(text_encoder.get_input_embeddings().parameters(), custom_diffusion_layers.parameters())\n        if args.modifier_token is not None\n        else custom_diffusion_layers.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = CustomDiffusionDataset(\n        concepts_list=args.concepts_list,\n        tokenizer=tokenizer,\n        with_prior_preservation=args.with_prior_preservation,\n        size=args.resolution,\n        mask_size=vae.encode(\n            torch.randn(1, 3, args.resolution, args.resolution).to(dtype=weight_dtype).to(accelerator.device)\n        )\n        .latent_dist.sample()\n        .size()[-1],\n        center_crop=args.center_crop,\n        num_class_images=args.num_class_images,\n        hflip=args.hflip,\n        aug=not args.noaug,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        
num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.modifier_token is not None:\n        custom_diffusion_layers, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            custom_diffusion_layers, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            
logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.modifier_token is not None:\n            text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet), 
accelerator.accumulate(text_encoder):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Predict the noise residual\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n                    
mask = torch.chunk(batch[\"mask\"], 2, dim=0)[0]\n                    # Compute instance loss\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()\n\n                    # Compute prior loss\n                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n                else:\n                    mask = batch[\"mask\"]\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()\n                accelerator.backward(loss)\n                # Zero out the gradients for all token embeddings except the newly added\n                # embeddings for the concept, as we only want to optimize the concept embeddings\n                if args.modifier_token is not None:\n                    if accelerator.num_processes > 1:\n                        grads_text_encoder = text_encoder.module.get_input_embeddings().weight.grad\n                    else:\n                        grads_text_encoder = text_encoder.get_input_embeddings().weight.grad\n                    # Get the index for tokens that we want to zero the grads for\n                    index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0]\n                    for i in range(1, len(modifier_token_id)):\n                        index_grads_to_zero = index_grads_to_zero & (\n                            torch.arange(len(tokenizer)) != modifier_token_id[i]\n                        )\n                    grads_text_encoder.data[index_grads_to_zero, :] = grads_text_encoder.data[\n                        index_grads_to_zero, :\n                    ].fill_(0)\n\n               
 if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(text_encoder.parameters(), custom_diffusion_layers.parameters())\n                        if args.modifier_token is not None\n                        else custom_diffusion_layers.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                         
       logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n            if accelerator.is_main_process:\n                images = []\n\n                if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                    logger.info(\n                        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n                        f\" {args.validation_prompt}.\"\n                    )\n                    # create pipeline\n                    pipeline = DiffusionPipeline.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        unet=accelerator.unwrap_model(unet),\n                        text_encoder=accelerator.unwrap_model(text_encoder),\n                        tokenizer=tokenizer,\n                        revision=args.revision,\n                        variant=args.variant,\n                        torch_dtype=weight_dtype,\n                    )\n                    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n                    pipeline = pipeline.to(accelerator.device)\n                    pipeline.set_progress_bar_config(disable=True)\n\n                    # run inference\n                    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n                    images = [\n                        pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[\n                            0\n                        ]\n                        for _ in range(args.num_validation_images)\n                    ]\n\n                    for tracker in accelerator.trackers:\n                        if tracker.name == \"tensorboard\":\n                            np_images = np.stack([np.asarray(img) for img in images])\n                            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                        if tracker.name == \"wandb\":\n                            tracker.log(\n                                {\n                                    \"validation\": [\n                                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                        
for i, image in enumerate(images)\n                                    ]\n                                }\n                            )\n\n                    del pipeline\n                    torch.cuda.empty_cache()\n\n    # Save the custom diffusion layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unet.to(torch.float32)\n        unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)\n        save_new_embed(\n            text_encoder,\n            modifier_token_id,\n            accelerator,\n            args,\n            args.output_dir,\n            safe_serialization=not args.no_safe_serialization,\n        )\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype\n        )\n        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n        pipeline = pipeline.to(accelerator.device)\n\n        # load attention processors\n        weight_name = (\n            \"pytorch_custom_diffusion_weights.safetensors\"\n            if not args.no_safe_serialization\n            else \"pytorch_custom_diffusion_weights.bin\"\n        )\n        pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)\n        for token in args.modifier_token:\n            token_weight_name = f\"{token}.safetensors\" if not args.no_safe_serialization else f\"{token}.bin\"\n            pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name)\n\n        # run inference\n        if args.validation_prompt and args.num_validation_images > 0:\n            generator = (\n                torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n            )\n            images = [\n                
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]\n                for _ in range(args.num_validation_images)\n            ]\n\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    tracker.log(\n                        {\n                            \"test\": [\n                                wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n                            ]\n                        }\n                    )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                prompt=args.instance_prompt,\n                repo_folder=args.output_dir,\n            )\n            api = HfApi(token=args.hub_token)\n            api.upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/README.md",
    "content": "# DreamBooth training example\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.\nThe `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion.\n\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nInstall the requirements in the `examples/dreambooth` folder as shown below.\n```bash\ncd examples/dreambooth\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n### Dog toy example\n\nNow let's get our dataset. 
For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nAnd launch the training using:\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=400 \\\n  --push_to_hub\n```\n\n### Training with prior-preservation loss\n\nPrior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.\nAccording to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. 
You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"dog\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800 \\\n  --push_to_hub\n```\n\n\n### Training on a 16GB GPU:\n\nWith the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run train dreambooth on a 16GB GPU.\n\nTo install `bitsandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"dog\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=2 --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  
--lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800 \\\n  --push_to_hub\n```\n\n\n### Training on a 12GB GPU:\n\nIt is possible to run dreambooth on a 12GB GPU by using the following optimizations:\n- [gradient checkpointing and the 8-bit optimizer](#training-on-a-16gb-gpu)\n- [xformers](#training-with-xformers)\n- [setting grads to none](#set-grads-to-none)\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"dog\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --enable_xformers_memory_efficient_attention \\\n  --set_grads_to_none \\\n  --learning_rate=2e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800 \\\n  --push_to_hub\n```\n\n\n### Training on a 8 GB GPU:\n\nBy using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some\ntensors from VRAM to either CPU or NVME allowing to train with less VRAM.\n\nDeepSpeed needs to be enabled with `accelerate config`. During configuration\nanswer yes to \"Do you want to use DeepSpeed?\". With DeepSpeed stage 2, fp16\nmixed precision and offloading both parameters and optimizer state to cpu it's\npossible to train on under 8 GB VRAM with a drawback of requiring significantly\nmore RAM (about 25 GB). 
See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.\n\nChanging the default Adam optimizer to DeepSpeed's special version of Adam\n`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but enabling\nit requires CUDA toolchain with the same version as pytorch. 8-bit optimizer\ndoes not seem to be compatible with DeepSpeed at the moment.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"dog\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch --mixed_precision=\"fp16\" train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --sample_batch_size=1 \\\n  --gradient_accumulation_steps=1 --gradient_checkpointing \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800 \\\n  --push_to_hub\n```\n\n### Fine-tune text encoder with the UNet.\n\nThe script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces.\nPass the `--train_text_encoder` argument to the script to enable training `text_encoder`.\n\n___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. 
It needs at least 24GB VRAM.___\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"dog\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --train_text_encoder \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --use_8bit_adam \\\n  --gradient_checkpointing \\\n  --learning_rate=2e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800 \\\n  --push_to_hub\n```\n\n### Using DreamBooth for pipelines other than Stable Diffusion\n\nThe [AltDiffusion pipeline](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion) also supports dreambooth fine-tuning. The process is the same as above, all you need to do is replace the `MODEL_NAME` like this:\n\n```\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\" --> export MODEL_NAME=\"BAAI/AltDiffusion-m9\"\nor\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\" --> export MODEL_NAME=\"BAAI/AltDiffusion\"\n```\n\n### Inference\n\nOnce you have trained a model using the above command, you can run inference simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. 
`sks` in the above example) in your prompt.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_id = \"path-to-your-trained-model\"\npipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A photo of sks dog in a bucket\"\nimage = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\n\nimage.save(\"dog-bucket.png\")\n```\n\n### Inference from a training checkpoint\n\nYou can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.\n\n## Training with Low-Rank Adaptation of Large Language Models (LoRA)\n\nLow-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.\n\nIn a nutshell, LoRA makes it possible to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. 
This has a couple of advantages:\n- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)\n- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.\n\n[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in\nthe popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.\n\n### Training\n\nLet's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).\n\nFirst, you need to set-up your dreambooth training example as is explained in the [installation section](#Installing-the-dependencies).\nNext, let's download the dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.\n\nNow, you can launch the training. Here we will use [Stable Diffusion 1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5).\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. 
All you need to do is to run `pip install wandb` before training and pass `--report_to=\"wandb\"` to automatically log images.___**\n\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n```\n\nFor this example we want to directly store the trained LoRA embeddings on the Hub, so\nwe need to be logged in and add the `--push_to_hub` flag.\n\n```bash\nhf auth login\n```\n\nNow we can start training!\n\n```bash\naccelerate launch train_dreambooth_lora.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --checkpointing_steps=100 \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=50 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we\nuse *1e-4* instead of the usual *2e-6*.___**\n\nThe final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dreambooth_dog_example](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**\n\nThe training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).\nYou can use the `Step` slider to see how the model learned the features of our subject while the model trained.\n\nOptionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. 
If you're interested to know more about how we\nenable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).\n\nWith the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).\n\n\n### Inference\n\nAfter training, LoRA weights can be loaded very easily into the original pipeline. First, you need to\nload the original pipeline:\n\n```python\nfrom diffusers import DiffusionPipeline\npipe = DiffusionPipeline.from_pretrained(\"base-model-name\").to(\"cuda\")\n```\n\nNext, we can load the adapter layers into the pipeline with the [`load_lora_weights` function](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#lora).\n\n```python\npipe.load_lora_weights(\"path-to-the-lora-checkpoint\")\n```\n\nFinally, we can run the model in inference.\n\n```python\nimage = pipe(\"A picture of a sks dog in a bucket\", num_inference_steps=25).images[0]\n```\n\nIf you are loading the LoRA parameters from the Hub and if the Hub repository has\na `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)), then\nyou can do:\n\n```py\nfrom huggingface_hub.repocard import RepoCard\n\nlora_model_id = \"patrickvonplaten/lora_dreambooth_dog_example\"\ncard = RepoCard.load(lora_model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\npipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)\n...\n```\n\nIf you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA\nweights. 
For example:\n\n```python\nfrom huggingface_hub.repocard import RepoCard\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nlora_model_id = \"sayakpaul/dreambooth-text-encoder-test\"\ncard = RepoCard.load(lora_model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\npipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\npipe.load_lora_weights(lora_model_id)\nimage = pipe(\"A picture of a sks dog in a bucket\", num_inference_steps=25).images[0]\n```\n\nNote that the use of [`StableDiffusionLoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.StableDiffusionLoraLoaderMixin.load_lora_weights) is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs) for loading LoRA parameters. This is because\n`StableDiffusionLoraLoaderMixin.load_lora_weights` can handle the following situations:\n\n* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`\"patrickvonplaten/lora_dreambooth_dog_example\"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do:\n\n  ```py\n  pipe.load_lora_weights(lora_model_path)\n  ```\n\n* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`\"sayakpaul/dreambooth\"`](https://huggingface.co/sayakpaul/dreambooth).\n\n## Training with Flax/JAX\n\nFor faster training on TPUs and GPUs you can leverage the flax training example. 
Follow the instructions above to get the model and dataset before running the script.\n\n___Note: The Flax example doesn't yet support features like gradient checkpointing and gradient accumulation, so using Flax for faster training requires cards with more than 30GB of memory.___\n\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n```bash\npip install -U -r requirements_flax.txt\n```\n\n\n### Training without prior preservation loss\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\npython train_dreambooth_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=5e-6 \\\n  --max_train_steps=400\n```\n\n\n### Training with prior preservation loss\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport INSTANCE_DIR=\"dog\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\npython train_dreambooth_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=5e-6 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n\n### Fine-tune text encoder with the UNet.\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport INSTANCE_DIR=\"dog\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\npython train_dreambooth_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --train_text_encoder \\\n  --instance_data_dir=$INSTANCE_DIR 
\\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=2e-6 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n### Training with xformers:\nYou can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.\n\nYou can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).\n\n### Set grads to none\n\nTo save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.\n\nMore info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\n\n### Experimental results\nYou can refer to [this blog post](https://huggingface.co/blog/dreambooth) that discusses some DreamBooth experiments in detail. Specifically, it recommends a set of DreamBooth-specific tips and tricks that we have found to work well for a variety of subjects.\n\n## IF\n\nYou can use the lora and full dreambooth scripts to train the text to image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and the stage II upscaler\n[IF model](https://huggingface.co/DeepFloyd/IF-II-L-v1.0).\n\nNote that IF has a predicted variance, and our finetuning scripts only train the model's predicted error, so for finetuned IF models we switch to a fixed\nvariance schedule. 
The full finetuning scripts will update the scheduler config for the full saved model. However, when loading saved LoRA weights, you\nmust also update the pipeline's scheduler config.\n\n```py\nfrom diffusers import DiffusionPipeline\n\npipe = DiffusionPipeline.from_pretrained(\"DeepFloyd/IF-I-XL-v1.0\")\n\npipe.load_lora_weights(\"<lora weights path>\")\n\n# Update scheduler config to fixed variance schedule\npipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type=\"fixed_small\")\n```\n\nAdditionally, a few alternative cli flags are needed for IF.\n\n`--resolution=64`: IF is a pixel space diffusion model. In order to operate on un-compressed pixels, the input images are of a much smaller resolution.\n\n`--pre_compute_text_embeddings`: IF uses [T5](https://huggingface.co/docs/transformers/model_doc/t5) for its text encoder. In order to save GPU memory, we pre compute all text embeddings and then de-allocate\nT5.\n\n`--tokenizer_max_length=77`: T5 has a longer default text length, but the default IF encoding procedure uses a smaller number.\n\n`--text_encoder_use_attention_mask`: T5 passes the attention mask to the text encoder.\n\n### Tips and Tricks\nWe find LoRA to be sufficient for finetuning the stage I model as the low resolution of the model makes representing finegrained detail hard regardless.\n\nFor common and/or not-visually complex object concepts, you can get away with not-finetuning the upscaler. Just be sure to adjust the prompt passed to the\nupscaler to remove the new token from the instance prompt. I.e. 
if your stage I prompt is \"a sks dog\", use \"a dog\" for your stage II prompt.\n\nFor finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than\nLoRA finetuning stage II.\n\nFor finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best.\n\nFor stage II, we find that lower learning rates are also needed.\n\nWe found experimentally that the DDPM scheduler with the default larger number of denoising steps sometimes works better than the DPM Solver scheduler\nused in the training scripts.\n\n### Stage II additional validation images\n\nStage II validation requires images to upscale. We can download a downsized version of the training set:\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog_downsized\"\nsnapshot_download(\n    \"diffusers/dog-example-downsized\",\n    local_dir=local_dir,\n    repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\n### IF stage I LoRA Dreambooth\nThis training configuration requires ~28 GB VRAM.\n\n```sh\nexport MODEL_NAME=\"DeepFloyd/IF-I-XL-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_lora\"\n\naccelerate launch train_dreambooth_lora.py \\\n  --report_to wandb \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a sks dog\" \\\n  --resolution=64 \\\n  --train_batch_size=4 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --scale_lr \\\n  --max_train_steps=1200 \\\n  --validation_prompt=\"a sks dog\" \\\n  --validation_epochs=25 \\\n  --checkpointing_steps=100 \\\n  --pre_compute_text_embeddings \\\n  --tokenizer_max_length=77 \\\n  --text_encoder_use_attention_mask\n```\n\n### IF stage II LoRA Dreambooth\n\n`--validation_images`: These images are upscaled during validation 
steps.\n\n`--class_labels_conditioning=timesteps`: Pass additional conditioning to the UNet needed for stage II.\n\n`--learning_rate=1e-6`: Lower learning rate than stage I.\n\n`--resolution=256`: The upscaler expects higher resolution inputs.\n\n```sh\nexport MODEL_NAME=\"DeepFloyd/IF-II-L-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_upscale\"\nexport VALIDATION_IMAGES=\"dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png\"\n\npython train_dreambooth_lora.py \\\n    --report_to wandb \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --instance_data_dir=$INSTANCE_DIR \\\n    --output_dir=$OUTPUT_DIR \\\n    --instance_prompt=\"a sks dog\" \\\n    --resolution=256 \\\n    --train_batch_size=4 \\\n    --gradient_accumulation_steps=1 \\\n    --learning_rate=1e-6 \\\n    --max_train_steps=2000 \\\n    --validation_prompt=\"a sks dog\" \\\n    --validation_epochs=100 \\\n    --checkpointing_steps=500 \\\n    --pre_compute_text_embeddings \\\n    --tokenizer_max_length=77 \\\n    --text_encoder_use_attention_mask \\\n    --validation_images $VALIDATION_IMAGES \\\n    --class_labels_conditioning=timesteps\n```\n\n### IF Stage I Full Dreambooth\n`--skip_save_text_encoder`: When training the full model, this will skip saving the entire T5 with the finetuned model. You can still load the pipeline\nwith a T5 loaded from the original model.\n\n`--use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8-bit Adam.\n\n`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is\nlikely the learning rate can be increased with larger batch sizes.\n\nUsing 8-bit Adam and a batch size of 4, the model can be trained in ~48 GB VRAM.\n\n`--validation_scheduler`: Set a particular scheduler via a string. 
We found that it is better to use the DDPMScheduler for validation when training DeepFloyd IF.\n\n```sh\nexport MODEL_NAME=\"DeepFloyd/IF-I-XL-v1.0\"\n\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_if\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=64 \\\n  --train_batch_size=4 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-7 \\\n  --max_train_steps=150 \\\n  --validation_prompt \"a photo of sks dog\" \\\n  --validation_steps 25 \\\n  --text_encoder_use_attention_mask \\\n  --tokenizer_max_length 77 \\\n  --pre_compute_text_embeddings \\\n  --use_8bit_adam \\\n  --set_grads_to_none \\\n  --skip_save_text_encoder \\\n  --validation_scheduler DDPMScheduler \\\n  --push_to_hub\n```\n\n### IF Stage II Full Dreambooth\n\n`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as\n1e-8.\n\n`--resolution=256`: The upscaler expects higher resolution inputs\n\n`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II particularly with\nfaces required large effective batch sizes.\n\n```sh\nexport MODEL_NAME=\"DeepFloyd/IF-II-L-v1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"dreambooth_dog_upscale\"\nexport VALIDATION_IMAGES=\"dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png\"\n\naccelerate launch train_dreambooth.py \\\n  --report_to wandb \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a sks dog\" \\\n  --resolution=256 \\\n  --train_batch_size=2 \\\n  --gradient_accumulation_steps=6 \\\n  --learning_rate=5e-6 \\\n  --max_train_steps=2000 \\\n  --validation_prompt=\"a sks dog\" \\\n  
--validation_steps=150 \\\n  --checkpointing_steps=500 \\\n  --pre_compute_text_embeddings \\\n  --tokenizer_max_length=77 \\\n  --text_encoder_use_attention_mask \\\n  --validation_images $VALIDATION_IMAGES \\\n  --class_labels_conditioning timesteps \\\n  --validation_scheduler DDPMScheduler\\\n  --push_to_hub\n```\n\n## Stable Diffusion XL\n\nWe support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).\n\n## Dataset\n\nWe support 🤗 [Datasets](https://huggingface.co/docs/datasets/index), you can find a dataset on the [Hugging Face Hub](https://huggingface.co/datasets) or use your own.\n\nThe quickest way to get started with your custom dataset is 🤗 Datasets' [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder).\n\nWe need to create a file `metadata.jsonl` in the directory with our images:\n\n```\n{\"file_name\": \"01.jpg\", \"prompt\": \"prompt 01\"}\n{\"file_name\": \"02.jpg\", \"prompt\": \"prompt 02\"}\n```\n\nIf we have a directory with image-text pairs e.g. `01.jpg` and `01.txt` then `convert_to_imagefolder.py` can create `metadata.jsonl`.\n\n```sh\npython convert_to_imagefolder.py --path my_dataset/\n```\n\nWe use `--dataset_name` and `--caption_column` with training scripts.\n\n```\n--dataset_name=my_dataset/\n--caption_column=prompt\n```\n"
  },
  {
    "path": "examples/dreambooth/README_flux.md",
    "content": "# DreamBooth training example for FLUX.1 [dev]\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.\n\nThe `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.\n> [!NOTE]\n> **Memory consumption**\n>\n> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -\n> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.\n\n> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX: \n> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)\n> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training)\n\n> [!NOTE]\n> **Gated model**\n>\n> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. 
Use the command below to log in:\n\n```bash\nhf auth login\n```\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_flux.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n\n### Dog toy example\n\nNow let's get our dataset. 
For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\nNow, we can launch training using:\n\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-flux\"\n\naccelerate launch train_dreambooth_flux.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"prodigy\" \\\n  --learning_rate=1. \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\n> [!NOTE]\n> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.\n\n## LoRA + DreamBooth\n\n[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve performance comparable to full fine-tuning with only a fraction of the learnable parameters.\n\nNote that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.\n\n### Prodigy Optimizer\nProdigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence. \nBy using Prodigy we can \"eliminate\" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).\n\nTo use Prodigy, first make sure to install the prodigyopt library with `pip install prodigyopt`, and then specify:\n```bash\n--optimizer=\"prodigy\"\n```\n> [!TIP]\n> When using Prodigy it's generally good practice to set `--learning_rate=1.0`\n\nTo perform DreamBooth with LoRA, run:\n\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-flux-lora\"\n\naccelerate launch train_dreambooth_lora_flux.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"prodigy\" \\\n  --learning_rate=1. 
\\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n### LoRA Rank and Alpha\nTwo key LoRA hyperparameters are LoRA rank and LoRA alpha. \n- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).\n- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.\n- `lora_alpha` vs. `rank`: This ratio dictates the LoRA's effective strength:\n  - `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)\n  - `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)\n  - `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)\n\n> [!TIP]\n> A common starting point is to set `lora_alpha` equal to `rank`. \n> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for rank=16) \n> to give the LoRA updates more influence without increasing parameter count. \n> If you find your LoRA is \"overcooking\" or learning too aggressively, consider setting `lora_alpha` to half of `rank` \n> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.\n\n### Target Modules\nWhen LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them. \nMore recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). 
With this change, we may also want to explore \napplying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules, we added `--lora_layers`, which takes a comma-separated string\nof the exact modules to train with LoRA. Here are some examples of target modules you can provide: \n- for attention only layers: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0\"`\n- to train the same modules as in the fal trainer: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2\"`\n- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out\"`\n\n> [!NOTE]\n> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:\n> - **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. `single_transformer_blocks.i.attn.to_k`\n> - **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. `transformer_blocks.i.attn.to_k` \n\n> [!NOTE]\n> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.\n\n### Text Encoder Training\n\nAlongside the transformer, fine-tuning of the CLIP text encoder is also supported.\nTo do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:\n\n> [!NOTE]\n> This is still an experimental feature. 
\n> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).\n> By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.\n> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.\n\nTo perform DreamBooth LoRA with text-encoder training, run:\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-flux-dev-dreambooth-lora\"\n\naccelerate launch train_dreambooth_lora_flux.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --train_text_encoder \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"prodigy\" \\\n  --learning_rate=1. \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n## Memory Optimizations\nAs mentioned, Flux Dreambooth LoRA training is very memory intensive. Here are some options (some still experimental) for more memory-efficient training.\n### Image Resolution\nAn easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images; all the images in the train/validation dataset are resized to this.\nNote that by default, images are resized to a resolution of 512, which is good to keep in mind in case you're accustomed to training on higher resolutions. 
\n### Gradient Checkpointing and Accumulation\n* `--gradient_accumulation_steps` sets the number of forward/backward passes whose gradients are accumulated before each optimizer update.\nBy passing a value > 1 you can keep the same effective batch size with a smaller `--train_batch_size`, which reduces memory requirements.\n* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass.\nInstead, only a subset of these activations (the checkpoints) is stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.\n### 8-bit-Adam Optimizer\nWhen training with `AdamW` (this doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training. \nMake sure to install `bitsandbytes` if you want to do so.\n### Latent caching\nWhen training without validation runs, we can pre-encode the training images with the VAE and then delete it to free up some memory. \nTo enable latent caching, simply pass `--cache_latents`.\n### Precision of saved LoRA layers\nBy default, trained transformer layers are saved in the dtype in which training was performed. For example, when mixed-precision training is enabled with `--mixed_precision=\"bf16\"`, the final finetuned layers will be saved in `torch.bfloat16` as well. \nThis reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.\n\n## Training Kontext\n\n[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We\nprovide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. 
The optimizations discussed above apply to this script, too.\n\n**Important**\n\n> [!NOTE] \n> To make sure you can successfully run the latest version of the Kontext example script, we highly recommend installing from source.\n> To do this, execute the following steps in a new virtual environment:\n> ```\n> git clone https://github.com/huggingface/diffusers\n> cd diffusers\n> pip install -e .\n> ```\n\nBelow is an example training command:\n\n```bash\naccelerate launch train_dreambooth_lora_flux_kontext.py \\\n  --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev  \\\n  --instance_data_dir=\"dog\" \\\n  --output_dir=\"kontext-dog\" \\\n  --mixed_precision=\"bf16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --optimizer=\"adamw\" \\\n  --use_8bit_adam \\\n  --cache_latents \\\n  --learning_rate=1e-4 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --seed=\"0\" \n```\n\nFine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not\nperform as expected.\n\nImage-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:\n\n* Condition image\n* Target image\n* Instruction\n\n[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. 
If you are using such a dataset, you can use the command below to launch training:\n\n```bash\naccelerate launch train_dreambooth_lora_flux_kontext.py \\\n  --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev  \\\n  --output_dir=\"kontext-i2i\" \\\n  --dataset_name=\"kontext-community/relighting\" \\\n  --image_column=\"output\" --cond_image_column=\"file_name\" --caption_column=\"instruction\" \\\n  --mixed_precision=\"bf16\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --optimizer=\"adamw\" \\\n  --use_8bit_adam \\\n  --cache_latents \\\n  --learning_rate=1e-4 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=200 \\\n  --max_train_steps=1000 \\\n  --rank=16 \\\n  --seed=\"0\" \n```\n\nMore generally, when performing I2I fine-tuning, we expect you to:\n\n* Have a dataset structured like `kontext-community/relighting`\n* Supply `--image_column`, `--cond_image_column`, and `--caption_column` values when launching training\n\n### Misc notes\n\n* By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.\n\n### Aspect Ratio Bucketing\nWe've added aspect ratio bucketing support, which allows training on images with different aspect ratios without cropping them to a single square resolution. 
This technique helps preserve the original composition of training images and can improve training efficiency.\n\nTo enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:\n\n`--aspect_ratio_buckets=\"672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672\"`\n\nSince Flux Kontext finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗 \n\n## Other notes\nThanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️\n"
  },
  {
    "path": "examples/dreambooth/README_flux2.md",
    "content": "# DreamBooth training example for FLUX.2 [dev] and FLUX 2 [klein]\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.\n[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.\n\nThe `train_dreambooth_lora_flux2.py`, `train_dreambooth_lora_flux2_klein.py` scripts shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://huggingface.co/black-forest-labs/FLUX.2-dev) and [FLUX 2 [klein]](https://huggingface.co/black-forest-labs/FLUX.2-klein).\n\n> [!NOTE]\n> **Model Variants**\n>\n> We support two FLUX model families:\n> - **FLUX.2 [dev]**: The full-size model using Mistral Small 3.1 as the text encoder. Very capable but memory intensive.\n> - **FLUX 2 [klein]**: Available in 4B and 9B parameter variants, using Qwen VL as the text encoder. Much more memory efficient and suitable for consumer hardware.\n\n> [!NOTE]\n> **Memory consumption**\n>\n> FLUX.2 [dev] can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -\n> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. FLUX 2 [klein] models (4B and 9B) are significantly more memory efficient alternatives. 
Below we provide some tips and tricks to reduce memory consumption during training.\n\n> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX: \n> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)\n> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training)\n\n> [!NOTE]\n> **Gated model**\n>\n> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:\n\n```bash\nhf auth login\n```\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_flux.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n\n### Dog toy example\n\nNow let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\nAs mentioned, Flux2 LoRA training is *very* memory intensive (especially for FLUX.2 [dev]). Here are memory optimizations we can use (some still experimental) for a more memory efficient training:\n\n## Memory Optimizations\n> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption. 
\n> However some techniques may be mutually exclusive so be sure to check before launching a training run.\n\n### Remote Text Encoder \nFLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API. \nThis way, the text encoder model is not loaded into memory during training.\n\n> [!IMPORTANT]\n> **Remote text encoder is only supported for FLUX.2 [dev]**. FLUX 2 [klein] models use the Qwen VL text encoder and do not support remote text encoding.\n\n> [!NOTE] \n> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.\n\n### FSDP Text Encoder \nFLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings. \nThis way, it distributes the memory cost across multiple nodes.\n\n### CPU Offloading \nTo offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.\n\n### Latent Caching \nPre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.\n\n### QLoRA: Low Precision Training with Quantization\nPerform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:\n- **FP8 training** with `torchao`: \nenable FP8 training by passing `--do_fp8_training`. \n> [!IMPORTANT] Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater. 
\n> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc.\n- **NF4 training** with `bitsandbytes`: \nAlternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing\n`--bnb_quantization_config_path`, e.g. to enable 4-bit NF4 quantization.\n\n### Gradient Checkpointing and Accumulation\n* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass.\nBy passing a value > 1 you can reduce the number of backward/update passes and hence also the memory requirements.\n* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass.\nInstead, only a subset of these activations (the checkpoints) is stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.\n\n### 8-bit-Adam Optimizer\nWhen training with `AdamW` (doesn't apply to `prodigy`), you can pass `--use_8bit_adam` to reduce the memory requirements of training. \nMake sure to install `bitsandbytes` if you want to do so.\n\n### Image Resolution\nAn easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution of the input images; all the images in the train/validation dataset are resized to this.\nNote that by default, images are resized to a resolution of 512, which is good to keep in mind in case you're accustomed to training on higher resolutions.\n\n### Precision of saved LoRA layers\nBy default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision=\"bf16\"`, final finetuned layers will be saved in `torch.bfloat16` as well. \nThis reduces memory requirements significantly without a significant quality loss. 
Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.\n\n## Training Examples\n\n### FLUX.2 [dev] Training\nTo perform DreamBooth with LoRA on FLUX.2 [dev], run:\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.2-dev\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-flux2\"\n\naccelerate launch train_dreambooth_lora_flux2.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --do_fp8_training \\\n  --gradient_checkpointing \\\n  --remote_text_encoder \\\n  --cache_latents \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --use_8bit_adam \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"adamW\" \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n### FLUX 2 [klein] Training\n\nFLUX 2 [klein] models are more memory efficient alternatives available in 4B and 9B parameter variants. They use the Qwen VL text encoder instead of Mistral Small 3.1.\n\n> [!NOTE]\n> The `--remote_text_encoder` flag is **not supported** for FLUX 2 [klein] models. 
The Qwen VL text encoder must be loaded locally, but offloading is still supported.\n\n**FLUX 2 [klein] 4B:**\n\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.2-klein-4B\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-flux2-klein-4b\"\n\naccelerate launch train_dreambooth_lora_flux2_klein.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --do_fp8_training \\\n  --gradient_checkpointing \\\n  --cache_latents \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --use_8bit_adam \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"adamW\" \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n**FLUX 2 [klein] 9B:**\n\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.2-klein-9B\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-flux2-klein-9b\"\n\naccelerate launch train_dreambooth_lora_flux2_klein.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --do_fp8_training \\\n  --gradient_checkpointing \\\n  --cache_latents \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --use_8bit_adam \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"adamW\" \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the 
command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.\n\n> [!NOTE]\n> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. Note that this will use more resources and may slow down the training in some cases.\n\n### FSDP on the transformer\nBy setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:\n\n```shell\ndistributed_type: FSDP\nfsdp_config:\n  fsdp_version: 2\n  fsdp_offload_params: false\n  fsdp_sharding_strategy: HYBRID_SHARD\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock\n  fsdp_forward_prefetch: true\n  fsdp_sync_module_states: false\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_use_orig_params: false\n  fsdp_activation_checkpointing: true\n  fsdp_reshard_after_forward: true\n  fsdp_cpu_ram_efficient_loading: false\n```\n\n### Prodigy Optimizer\nProdigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence. \nBy using prodigy we can \"eliminate\" the need for manual learning rate tuning. 
read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).\n\nto use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify -\n```bash\n--optimizer=\"prodigy\"\n```\n> [!TIP]\n> When using prodigy it's generally good practice to set- `--learning_rate=1.0`\n\n```bash\nexport MODEL_NAME=\"black-forest-labs/FLUX.2-dev\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-flux2-lora\"\n\naccelerate launch train_dreambooth_lora_flux2.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --do_fp8_training \\\n  --gradient_checkpointing \\\n  --remote_text_encoder \\\n  --cache_latents \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"prodigy\" \\\n  --learning_rate=1. \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant_with_warmup\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n### LoRA Rank and Alpha\nTwo key LoRA hyperparameters are LoRA rank and LoRA alpha. \n- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).\n- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.\n- lora_alpha vs. rank:\nThis ratio dictates the LoRA's effective strength:\nlora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)\nlora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)\nlora_alpha > rank: Scaling factor > 1. 
Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)\n\n> [!TIP]\n> A common starting point is to set `lora_alpha` equal to `rank`. \n> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16) \n> to give the LoRA updates more influence without increasing parameter count. \n> If you find your LoRA is \"overcooking\" or learning too aggressively, consider setting `lora_alpha` to half of `rank` \n> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.\n\n### Target Modules\nWhen LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. \nMore recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore \napplying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string\nthe exact modules for LoRA training. 
Here are some examples of target modules you can provide: \n- for attention only layers: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0\"`\n- to train the same modules as in the fal trainer: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2\"`\n- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_layers=\"attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out\"`\n\n> [!NOTE]\n> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:\n> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`\n> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k` \n\n> [!NOTE]\n> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.\n\n\n## Training Image-to-Image\n\nFlux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image (I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.\n\n**Important**\nTo make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nTo start, you must have a dataset containing triplets:\n\n* Condition image - the input image to be transformed.\n* Target image - the desired output image after transformation.\n* Instruction - a text prompt describing the transformation from the condition image to the target image.\n\n[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:\n\n```bash\naccelerate launch train_dreambooth_lora_flux2_img2img.py \\\n  --pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev  \\\n  --output_dir=\"flux2-i2i\" \\\n  --dataset_name=\"kontext-community/relighting\" \\\n  --image_column=\"output\" --cond_image_column=\"file_name\" --caption_column=\"instruction\" \\\n  --do_fp8_training \\\n  --gradient_checkpointing \\\n  --remote_text_encoder \\\n  --cache_latents \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=1 \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"adamw\" \\\n  --use_8bit_adam \\\n  --learning_rate=1e-4 \\\n  --lr_scheduler=\"constant_with_warmup\" \\\n  --lr_warmup_steps=200 \\\n  --max_train_steps=1000 \\\n  --rank=16 \\\n  --seed=\"0\"\n```\n\nMore generally, when performing I2I fine-tuning, we expect you to:\n\n* Have a dataset structured like `kontext-community/relighting`\n* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training\n\n### Misc notes\n\n* By default, we use `mode` as the value of the `--vae_encode_mode` argument. 
This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.\n\n### Aspect Ratio Bucketing\nWe've added aspect ratio bucketing support, which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.\n\nTo enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:\n\n`--aspect_ratio_buckets=\"672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672\"`\n\n\nSince Flux.2 finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗"
  },
  {
    "path": "examples/dreambooth/README_hidream.md",
    "content": "# DreamBooth training example for HiDream Image\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.\n\nThe `train_dreambooth_lora_hidream.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/). \n\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_hidream.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.\n\n\n### 3d icon example\n\nFor this example we will use some 3d icon images: https://huggingface.co/datasets/linoyts/3d_icon.\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\nNow, we can launch training using:\n> [!NOTE]\n> The following training configuration prioritizes lower memory consumption by using gradient checkpointing, \n> 8-bit Adam optimizer, latent caching, offloading, no validation.\n> all text embeddings are pre-computed to save memory.\n```bash\nexport MODEL_NAME=\"HiDream-ai/HiDream-I1-Dev\"\nexport INSTANCE_DIR=\"linoyts/3d_icon\"\nexport OUTPUT_DIR=\"trained-hidream-lora\"\n\naccelerate launch train_dreambooth_lora_hidream.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --dataset_name=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --instance_prompt=\"3d icon\" \\\n  --caption_column=\"prompt\"\\\n  --validation_prompt=\"a 3dicon, a llama eating ramen\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 
\\\n  --gradient_accumulation_steps=4 \\\n  --use_8bit_adam \\\n  --rank=8 \\\n  --learning_rate=2e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant_with_warmup\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=1000 \\\n  --cache_latents \\\n  --gradient_checkpointing \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nFor using `push_to_hub`, make sure you're logged into your Hugging Face account:\n\n```bash\nhf auth login\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.\n\n## Notes\n\nAdditionally, we welcome you to explore the following CLI arguments:\n\n* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string. E.g. - \"to_k,to_q,to_v\" will result in LoRA training of attention layers only.\n* `--rank`: The rank of the LoRA layers. The higher the rank, the more parameters are trained. 
The default is 16.\n\nWe provide several options for reducing memory usage:\n\n* `--offload`: When enabled, we will offload the text encoder and VAE to CPU when they are not used.\n* `--cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.\n* `--use_8bit_adam`: When enabled, we will use the 8-bit version of AdamW provided by the `bitsandbytes` library.\n\nRefer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to learn more about the model.\n\n## Using quantization\n\nYou can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:\n\n```json\n{\n    \"load_in_4bit\": true,\n    \"bnb_4bit_quant_type\": \"nf4\"\n}\n```\n\nBelow, we provide some numbers with and without the use of NF4 quantization when training:\n\n```\n(with quantization)\nMemory (before device placement): 9.085089683532715 GB.\nMemory (after device placement): 34.59585428237915 GB.\nMemory (after backward): 36.90267467498779 GB.\n\n(without quantization)\nMemory (before device placement): 0.0 GB.\nMemory (after device placement): 57.6400408744812 GB.\nMemory (after backward): 59.932212829589844 GB.\n```\n\nWith quantization, we see some memory in use before device placement because bnb-quantized models are placed on the GPU first by default."
  },
  {
    "path": "examples/dreambooth/README_lumina2.md",
    "content": "# DreamBooth training example for Lumina2\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.\n\nThe `train_dreambooth_lora_lumina2.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2). \n\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_sana.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.\n\n\n### Dog toy example\n\nNow let's get our dataset. 
For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\nNow, we can launch training using:\n\n```bash\nexport MODEL_NAME=\"Alpha-VLLM/Lumina-Image-2.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-lumina2-lora\"\n\naccelerate launch train_dreambooth_lora_lumina2.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --use_8bit_adam \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nFor using `push_to_hub`, make sure you're logged into your Hugging Face account:\n\n```bash\nhf auth login\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\n## Notes\n\nAdditionally, we welcome you to explore the following CLI arguments:\n\n* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string. E.g. - \"to_k,to_q,to_v\" will result in LoRA training of attention layers only.\n* `--system_prompt`: A custom system prompt to provide additional personality to the model.\n* `--max_sequence_length`: Maximum sequence length to use for text embeddings.\n\n\nWe provide several options for reducing memory usage:\n\n* `--offload`: When enabled, we will offload the text encoder and VAE to CPU when they are not used.\n* `--cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.\n* `--use_8bit_adam`: When enabled, we will use the 8-bit version of AdamW provided by the `bitsandbytes` library.\n\nRefer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2) of the `LuminaPipeline` to learn more about the model.\n"
  },
  {
    "path": "examples/dreambooth/README_qwen.md",
    "content": "# DreamBooth training example for Qwen Image\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.\n\nThe `train_dreambooth_lora_qwen_image.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [Qwen Image](https://huggingface.co/Qwen/Qwen-Image). \n\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_sana.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.\n\n\n### Dog toy example\n\nNow let's get our dataset. 
For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\nNow, we can launch training using:\n\n```bash\nexport MODEL_NAME=\"Qwen/Qwen-Image\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-qwenimage-lora\"\n\naccelerate launch train_dreambooth_lora_qwen_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --use_8bit_adam \\\n  --learning_rate=2e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nFor using `push_to_hub`, make sure you're logged into your Hugging Face account:\n\n```bash\nhf auth login\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\n## Notes\n\nAdditionally, we welcome you to explore the following CLI arguments:\n\n* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string. E.g. - \"to_k,to_q,to_v\" will result in LoRA training of attention layers only.\n* `--max_sequence_length`: Maximum sequence length to use for text embeddings.\n\nWe provide several options for reducing memory usage:\n\n* `--offload`: When enabled, we will offload the text encoder and VAE to CPU when they are not used.\n* `--cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.\n* `--use_8bit_adam`: When enabled, we will use the 8-bit version of AdamW provided by the `bitsandbytes` library.\n\nRefer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwenimage) of the `QwenImagePipeline` to learn more about the available models and their preferred dtypes during inference.\n\n## Using quantization\n\nYou can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass a JSON file path to `--bnb_quantization_config_path`. This file should hold the configuration to initialize `BitsAndBytesConfig`. Below is an example JSON file:\n\n```json\n{\n    \"load_in_4bit\": true,\n    \"bnb_4bit_quant_type\": \"nf4\"\n}\n```\n"
  },
  {
    "path": "examples/dreambooth/README_sana.md",
    "content": "# DreamBooth training example for SANA\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.\n\nThe `train_dreambooth_lora_sana.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [SANA](https://huggingface.co/papers/2410.10629). \n\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_sana.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.\n\n\n### Dog toy example\n\nNow let's get our dataset. 
For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\nNow, we can launch training using:\n\n```bash\nexport MODEL_NAME=\"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-sana-lora\"\n\naccelerate launch train_dreambooth_lora_sana.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --use_8bit_adam \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nTo use `push_to_hub`, make sure you're logged into your Hugging Face account:\n\n```bash\nhf auth login\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\n## Notes\n\n### LoRA Rank and Alpha\nTwo key LoRA hyperparameters are LoRA rank and LoRA alpha. \n- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).\n- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / rank`.\n- `lora_alpha` vs. `rank`: This ratio dictates the LoRA's effective strength:\n  - `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)\n  - `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)\n  - `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)\n\n> [!TIP]\n> A common starting point is to set `lora_alpha` equal to `rank`. \n> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for rank=16) \n> to give the LoRA updates more influence without increasing parameter count. \n> If you find your LoRA is \"overcooking\" or learning too aggressively, consider setting `lora_alpha` to half of `rank` \n> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.\n\n### Additional CLI arguments\n\nAdditionally, we welcome you to explore the following CLI arguments:\n\n* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated list. E.g. 
- \"to_k,to_q,to_v\" will result in LoRA training of attention layers only.\n* `--complex_human_instruction`: Instructions for complex human attention as shown [here](https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55).\n* `--max_sequence_length`: Maximum sequence length to use for text embeddings.\n\n\nWe provide several options for reducing memory usage:\n\n* `--offload`: When enabled, we will offload the text encoder and VAE to CPU when they are not used.\n* `--cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.\n* `--use_8bit_adam`: When enabled, we will use the 8-bit version of AdamW provided by the `bitsandbytes` library.\n\nRefer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to learn more about the models available under the SANA family and their preferred dtypes during inference.\n"
  },
  {
    "path": "examples/dreambooth/README_sd3.md",
    "content": "# DreamBooth training example for Stable Diffusion 3 (SD3)\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.\n\nThe `train_dreambooth_sd3.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). We also provide a LoRA implementation in the `train_dreambooth_lora_sd3.py` script.\n\n> [!NOTE]\n> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:\n\n```bash\nhf auth login\n```\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_sd3.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n\n### Dog toy example\n\nNow let's get our dataset. 
For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\nNow, we can launch training using:\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-3-medium-diffusers\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-sd3\"\n\naccelerate launch train_dreambooth_sd3.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"fp16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.\n\n> [!NOTE]\n> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. 
The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.\n\n> [!TIP]\n> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.\n\n## LoRA + DreamBooth\n\n[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.\n\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\nTo perform DreamBooth with LoRA, run:\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-3-medium-diffusers\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-sd3-lora\"\n\naccelerate launch train_dreambooth_lora_sd3.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"fp16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --learning_rate=4e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n### Targeting Specific Blocks & Layers\nAs image generation models get bigger & more powerful, more fine-tuners come to find that training only part of the \ntransformer blocks (sometimes as little as two) can be enough to get great results. 
\nIn some cases, it can be even better to keep some of the blocks/layers frozen.\n\nFor **SD3.5-Large** specifically, you may find this information useful (taken from: [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93)):\n> [!NOTE]\n> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24` )* influence the overall composition/primary form more. \n> So, freezing other layers/targeting specific layers is a viable approach.\n> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.\n> **Photorealism**\n> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing detail degradation introduced by small dataset from happening.\n> **Anatomy preservation**\n> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.\n\n\nWe've added `--lora_layers` and `--lora_blocks` to make LoRA training modules configurable. \n- with `--lora_blocks` you can specify the block numbers for training. E.g. passing - \n```diff\n--lora_blocks \"12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37\"\n```\nwill trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained. \n- with `--lora_layers` you can specify the types of layers you wish to train. 
\nBy default, the trained layers are -  \n`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`\nIf you wish to have a leaner LoRA / train more blocks over layers you could pass - \n```diff\n+ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0\n```\nThis will reduce LoRA size by roughly 50% for the same rank compared to the default. \nHowever, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and\nfreezing some of the blocks is usually better. \n\n\n### Text Encoder Training\nAlongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.\nTo do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:\n\n> [!NOTE]\n> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).\nBy enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.\n\nTo perform DreamBooth LoRA with text-encoder training, run:\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-3-medium-diffusers\"\nexport OUTPUT_DIR=\"trained-sd3-lora\"\n\naccelerate launch train_dreambooth_lora_sd3.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --output_dir=$OUTPUT_DIR \\\n  --dataset_name=\"Norod78/Yarn-art-style\" \\\n  --instance_prompt=\"a photo of TOK yarn art dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --train_text_encoder \\\n  --gradient_accumulation_steps=1 \\\n  --optimizer=\"prodigy\" \\\n  --learning_rate=1.0 \\\n  --text_encoder_lr=1.0 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=1500 \\\n  --rank=32 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n## Other notes\n\n1. 
We default to the \"logit_normal\" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.\n2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917).\n3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well."
  },
  {
    "path": "examples/dreambooth/README_sdxl.md",
    "content": "# DreamBooth training example for Stable Diffusion XL (SDXL)\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.\n\nThe `train_dreambooth_lora_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).\n\n> 💡 **Note**: For now, we only allow DreamBooth fine-tuning of the SDXL UNet via LoRA. LoRA is a parameter-efficient fine-tuning technique introduced in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_sdxl.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n### Dog toy example\n\nNow let's get our dataset. 
For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\nNow, we can launch training using:\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"lora-trained-xl\"\nexport VAE_PATH=\"madebyollin/sdxl-vae-fp16-fix\"\n\naccelerate launch train_dreambooth_lora_sdxl.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --pretrained_vae_model_name_or_path=$VAE_PATH \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"fp16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\nOur experiments were conducted on a single 40GB A100 GPU.\n\n### Dog toy example with < 16GB VRAM\n\nBy making use of [`gradient_checkpointing`](https://pytorch.org/docs/stable/checkpoint.html) (which is natively supported in Diffusers), [`xformers`](https://github.com/facebookresearch/xformers), and [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) libraries, you can train SDXL LoRAs with less than 16GB of VRAM by adding the following flags to your accelerate launch command:\n\n```diff\n+  --enable_xformers_memory_efficient_attention \\\n+  --gradient_checkpointing \\\n+  --use_8bit_adam \\\n+  --mixed_precision=\"fp16\" \\\n```\n\nand making sure that you have the following libraries installed:\n\n```\nbitsandbytes>=0.40.0\nxformers>=0.0.20\n```\n\n### Inference\n\nOnce training is done, we can perform inference like so:\n\n```python\nfrom huggingface_hub.repocard import RepoCard\nfrom diffusers import DiffusionPipeline\nimport torch\n\nlora_model_id = <\"lora-sdxl-dreambooth-id\">\ncard = RepoCard.load(lora_model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\npipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\npipe.load_lora_weights(lora_model_id)\nimage = pipe(\"A picture of a sks dog in a bucket\", num_inference_steps=25).images[0]\nimage.save(\"sks_dog.png\")\n```\n\nWe can further refine the outputs with the [Refiner](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0):\n\n```python\nfrom huggingface_hub.repocard import RepoCard\nfrom diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline\nimport torch\n\nlora_model_id = <\"lora-sdxl-dreambooth-id\">\ncard = RepoCard.load(lora_model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\n# Load the base pipeline and load the LoRA parameters into it.\npipe = DiffusionPipeline.from_pretrained(base_model_id, 
torch_dtype=torch.float16)\npipe = pipe.to(\"cuda\")\npipe.load_lora_weights(lora_model_id)\n\n# Load the refiner.\nrefiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-refiner-1.0\", torch_dtype=torch.float16, use_safetensors=True, variant=\"fp16\"\n)\nrefiner.to(\"cuda\")\n\nprompt = \"A picture of a sks dog in a bucket\"\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n\n# Run inference.\nimage = pipe(prompt=prompt, output_type=\"latent\", generator=generator).images[0]\nimage = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0]\nimage.save(\"refined_sks_dog.png\")\n```\n\nHere's a side-by-side comparison of the with and without Refiner pipeline outputs:\n\n| Without Refiner | With Refiner |\n|---|---|\n| ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) |\n\n### Training with text encoder(s)\n\nAlongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:\n\n* SDXL has two text encoders. So, we fine-tune both using LoRA.\n* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.\n\n### Specifying a better VAE\n\nSDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).\n\n## Notes\n\nIn our experiments, we found that SDXL yields good initial results without extensive hyperparameter tuning. For example, without fine-tuning the text encoders and without using prior-preservation, we observed decent results. 
We didn't explore further hyper-parameter tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗\n\n## Results\n\nYou can explore the results from a couple of our internal experiments by checking out this link: [https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl](https://wandb.ai/sayakpaul/dreambooth-lora-sd-xl). Specifically, we used the same script with the exact same hyperparameters on the following datasets:\n\n* [Dogs](https://huggingface.co/datasets/diffusers/dog-example)\n* [Starbucks logo](https://huggingface.co/datasets/diffusers/starbucks-example)\n* [Mr. Potato Head](https://huggingface.co/datasets/diffusers/potato-head-example)\n* [Keramer face](https://huggingface.co/datasets/diffusers/keramer-face-example)\n\n## Running on a free-tier Colab Notebook\n\nCheck out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb).\n\n## Conducting EDM-style training\n\nIt's now possible to perform EDM-style training as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364).\n\nFor the SDXL model, simply set:\n\n```diff\n+  --do_edm_style_training \\\n```\n\nOther SDXL-like models that use the EDM formulation, such as [playgroundai/playground-v2.5-1024px-aesthetic](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic), can also be DreamBooth'd with the script. 
Below is an example command:\n\n```bash\naccelerate launch train_dreambooth_lora_sdxl.py \\\n  --pretrained_model_name_or_path=\"playgroundai/playground-v2.5-1024px-aesthetic\"  \\\n  --instance_data_dir=\"dog\" \\\n  --output_dir=\"dog-playground-lora\" \\\n  --mixed_precision=\"fp16\" \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --learning_rate=1e-4 \\\n  --use_8bit_adam \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n> [!CAUTION]\n> Min-SNR gamma is not supported with the EDM-style training yet. When training with the PlaygroundAI model, it's recommended to not pass any \"variant\".\n\n### DoRA training\nThe script now supports DoRA training too!\n> Proposed in [DoRA: Weight-Decomposed Low-Rank Adaptation](https://huggingface.co/papers/2402.09353),\n**DoRA** is very similar to LoRA, except it decomposes the pre-trained weight into two components, **magnitude** and **direction** and employs LoRA for _directional_ updates to efficiently minimize the number of trainable parameters.\nThe authors found that by using DoRA, both the learning capacity and training stability of LoRA are enhanced without any additional overhead during inference.\n\n> [!NOTE]\n> 💡DoRA training is still _experimental_\n> and is likely to require different hyperparameter values to perform best compared to a LoRA.\n> Specifically, we've noticed 2 differences to take into account in your training:\n> 1. **LoRA seems to converge faster than DoRA** (so a set of parameters that may lead to overfitting when training a LoRA may be working well for a DoRA)\n> 2. 
**DoRA quality superior to LoRA especially in lower ranks** the difference in quality of DoRA of rank 8 and LoRA of rank 8 appears to be more significant than when training ranks of 32 or 64 for example.\n> This is also aligned with some of the quantitative analysis shown in the paper.\n\n**Usage**\n1. To use DoRA you need to upgrade the installation of `peft`:\n```bash\npip install -U peft\n```\n2. Enable DoRA training by adding this flag\n```bash\n--use_dora\n```\n**Inference**\nThe inference is the same as if you train a regular LoRA 🤗\n\n## Format compatibility\n\nYou can pass `--output_kohya_format` to additionally generate a state dictionary which should be compatible with other platforms and tools such as Automatic 1111, Comfy, Kohya, etc. The `output_dir` will contain a file named \"pytorch_lora_weights_kohya.safetensors\"."
  },
  {
    "path": "examples/dreambooth/README_z_image.md",
    "content": "# DreamBooth training example for Z-Image\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.\n[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.\n\nThe `train_dreambooth_lora_z_image.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image).\n\n> [!NOTE]\n> **About Z-Image**\n>\n> Z-Image is a high-quality text-to-image generation model from Alibaba's Tongyi Lab. It uses a DiT (Diffusion Transformer) architecture with Qwen3 as the text encoder. The model excels at generating images with accurate text rendering, especially for Chinese characters.\n\n> [!NOTE]\n> **Memory consumption**\n>\n> Z-Image is relatively memory efficient compared to other large-scale diffusion models. Below we provide some tips and tricks to further reduce memory consumption during training.\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/dreambooth` folder and run\n```bash\npip install -r requirements_z_image.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n\n### Dog toy example\n\nNow let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.\n\nLet's first download it locally:\n\n```python\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./dog\"\nsnapshot_download(\n    \"diffusers/dog-example\",\n    local_dir=local_dir, repo_type=\"dataset\",\n    ignore_patterns=\".gitattributes\",\n)\n```\n\nThis will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.\n\n## Memory Optimizations\n\n> [!NOTE] \n> Many of these techniques complement each other and can be used together to further reduce memory consumption. However some techniques may be mutually exclusive so be sure to check before launching a training run.\n\n### CPU Offloading \nTo offload parts of the model to CPU memory, you can use `--offload` flag. 
This will offload the VAE and text encoder to CPU memory and only move them to GPU when needed.\n\n### Latent Caching \nPre-encode the training images with the VAE, then delete the VAE to free up some memory. To enable latent caching, simply pass `--cache_latents`.\n\n### QLoRA: Low Precision Training with Quantization\nPerform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:\n\n- **FP8 training** with `torchao`: \nEnable FP8 training by passing `--do_fp8_training`. \n> [!IMPORTANT] \n> Since we are utilizing FP8 tensor cores, we need CUDA GPUs with compute capability of at least 8.9. If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers.\n\n- **NF4 training** with `bitsandbytes`: \nAlternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing `--bnb_quantization_config_path`, for example to enable 4-bit NF4 quantization.\n\n### Gradient Checkpointing and Accumulation\n* `--gradient_accumulation_steps` refers to the number of update steps to accumulate before performing a backward/update pass. By passing a value > 1 you can reduce the number of backward/update passes and hence also memory requirements.\n* With `--gradient_checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expense of a slower backward pass.\n\n### 8-bit-Adam Optimizer\nWhen training with `AdamW` (doesn't apply to `prodigy`) you can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.\n\n### Image Resolution\nAn easy way to mitigate some of the memory requirements is through `--resolution`. 
`--resolution` refers to the resolution for input images; all the images in the train/validation dataset are resized to this.\nNote that by default, images are resized to a resolution of 1024, which is worth keeping in mind if you train at higher resolutions.\n\n### Precision of saved LoRA layers\nBy default, trained transformer layers are saved in the precision dtype in which training was performed. E.g., when mixed-precision training is enabled with `--mixed_precision=\"bf16\"`, the final fine-tuned layers will be saved in `torch.bfloat16` as well. \nThis significantly reduces memory requirements without a notable loss in quality. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.\n\n## Training Examples\n\n### Z-Image Training\n\nTo perform DreamBooth with LoRA on Z-Image, run:\n\n```bash\nexport MODEL_NAME=\"Tongyi-MAI/Z-Image\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-z-image-lora\"\n\naccelerate launch train_dreambooth_lora_z_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --gradient_checkpointing \\\n  --cache_latents \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=5.0 \\\n  --use_8bit_adam \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"adamW\" \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). 
To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.\n* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.\n\n> [!NOTE]\n> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. The default is 512. Note that this will use more resources and may slow down the training in some cases.\n\n### Training with FP8 Quantization\n\nFor reduced memory usage with FP8 training:\n\n```bash\nexport MODEL_NAME=\"Tongyi-MAI/Z-Image\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-z-image-lora-fp8\"\n\naccelerate launch train_dreambooth_lora_z_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --do_fp8_training \\\n  --gradient_checkpointing \\\n  --cache_latents \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=5.0 \\\n  --use_8bit_adam \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"adamW\" \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n### FSDP on the transformer\n\nBy setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. 
set the configuration to:\n\n```yaml\ndistributed_type: FSDP\nfsdp_config:\n  fsdp_version: 2\n  fsdp_offload_params: false\n  fsdp_sharding_strategy: HYBRID_SHARD\n  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n  fsdp_transformer_layer_cls_to_wrap: ZImageTransformerBlock\n  fsdp_forward_prefetch: true\n  fsdp_sync_module_states: false\n  fsdp_state_dict_type: FULL_STATE_DICT\n  fsdp_use_orig_params: false\n  fsdp_activation_checkpointing: true\n  fsdp_reshard_after_forward: true\n  fsdp_cpu_ram_efficient_loading: false\n```\n\n### Prodigy Optimizer\n\nProdigy is an adaptive optimizer that dynamically adjusts the learning rate of the learned parameters based on past gradients, allowing for more efficient convergence. \nBy using prodigy we can \"eliminate\" the need for manual learning rate tuning. Read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).\n\nTo use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify:\n```bash\n--optimizer=\"prodigy\"\n```\n\n> [!TIP]\n> When using prodigy it's generally good practice to set `--learning_rate=1.0`.\n\n```bash\nexport MODEL_NAME=\"Tongyi-MAI/Z-Image\"\nexport INSTANCE_DIR=\"dog\"\nexport OUTPUT_DIR=\"trained-z-image-lora-prodigy\"\n\naccelerate launch train_dreambooth_lora_z_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=\"bf16\" \\\n  --gradient_checkpointing \\\n  --cache_latents \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --guidance_scale=5.0 \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"prodigy\" \\\n  --learning_rate=1.0 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant_with_warmup\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=500 \\\n  --validation_prompt=\"A photo of sks dog in a bucket\" \\\n  --validation_epochs=25 \\\n  --seed=\"0\" 
\\\n  --push_to_hub\n```\n\n### LoRA Rank and Alpha\n\nTwo key LoRA hyperparameters are LoRA rank and LoRA alpha:\n\n- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).\n- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.\n\n**lora_alpha vs. rank:**\n\nThis ratio dictates the LoRA's effective strength:\n- `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)\n- `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)\n- `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)\n\n> [!TIP]\n> A common starting point is to set `lora_alpha` equal to `rank`. \n> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16) \n> to give the LoRA updates more influence without increasing parameter count. \n> If you find your LoRA is \"overcooking\" or learning too aggressively, consider setting `lora_alpha` to half of `rank` \n> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.\n\n### Target Modules\n\nWhen LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them. \nMore recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion Transformer (DiT). 
With this change, we may also want to explore applying LoRA training onto different types of layers and blocks.\n\nTo allow more flexibility and control over the targeted modules we added `--lora_layers`, in which you can specify in a comma separated string the exact modules for LoRA training. Here are some examples of target modules you can provide:\n\n- For attention only layers: `--lora_layers=\"to_k,to_q,to_v,to_out.0\"`\n- For attention and feed-forward layers: `--lora_layers=\"to_k,to_q,to_v,to_out.0,ff.net.0.proj,ff.net.2\"`\n\n> [!NOTE]\n> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string.\n\n> [!NOTE]\n> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.\n\n### Aspect Ratio Bucketing\n\nWe've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.\n\nTo enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:\n\n```bash\n--aspect_ratio_buckets=\"672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672\"\n```\n\n### Bilingual Prompts\n\nZ-Image has strong support for both Chinese and English prompts. When training with Chinese prompts, ensure your dataset captions are properly encoded in UTF-8:\n\n```bash\n--instance_prompt=\"一只sks狗的照片\"\n--validation_prompt=\"一只sks狗在桶里的照片\"\n```\n\n> [!TIP]\n> Z-Image excels at text rendering in generated images, especially for Chinese characters. 
If your use case involves generating images with text, consider including text-related examples in your training data.\n\n## Inference\n\nOnce you have trained a LoRA, you can load it for inference:\n\n```python\nimport torch\nfrom diffusers import ZImagePipeline\n\npipe = ZImagePipeline.from_pretrained(\"Tongyi-MAI/Z-Image\", torch_dtype=torch.bfloat16)\npipe.to(\"cuda\")\n\n# Load your trained LoRA\npipe.load_lora_weights(\"path/to/your/trained-z-image-lora\")\n\n# Generate an image\nimage = pipe(\n    prompt=\"A photo of sks dog in a bucket\",\n    height=1024,\n    width=1024,\n    num_inference_steps=50,\n    guidance_scale=5.0,\n    generator=torch.Generator(\"cuda\").manual_seed(42),\n).images[0]\n\nimage.save(\"output.png\")\n```\n\n---\n\nSince Z-Image finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗"
  },
  {
    "path": "examples/dreambooth/convert_to_imagefolder.py",
    "content": "import argparse\nimport json\nimport pathlib\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--path\",\n    type=str,\n    required=True,\n    help=\"Path to folder with image-text pairs.\",\n)\nparser.add_argument(\"--caption_column\", type=str, default=\"prompt\", help=\"Name of caption column.\")\nargs = parser.parse_args()\n\npath = pathlib.Path(args.path)\nif not path.exists():\n    raise RuntimeError(f\"`--path` '{args.path}' does not exist.\")\n\nall_files = list(path.glob(\"*\"))\ncaptions = list(path.glob(\"*.txt\"))\nimages = set(all_files) - set(captions)\nimages = {image.stem: image for image in images}\ncaption_image = {caption: images.get(caption.stem) for caption in captions if images.get(caption.stem)}\n\nmetadata = path.joinpath(\"metadata.jsonl\")\n\nwith metadata.open(\"w\", encoding=\"utf-8\") as f:\n    for caption, image in caption_image.items():\n        caption_text = caption.read_text(encoding=\"utf-8\")\n        json.dump({\"file_name\": image.name, args.caption_column: caption_text}, f)\n        f.write(\"\\n\")\n"
  },
  {
    "path": "examples/dreambooth/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\npeft==0.7.0"
  },
  {
    "path": "examples/dreambooth/requirements_flax.txt",
    "content": "transformers>=4.25.1\nflax\noptax\ntorch\ntorchvision\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/dreambooth/requirements_flux.txt",
    "content": "accelerate>=0.31.0\ntorchvision\ntransformers>=4.41.2\nftfy\ntensorboard\nJinja2\npeft>=0.11.1\nsentencepiece"
  },
  {
    "path": "examples/dreambooth/requirements_hidream.txt",
    "content": "accelerate>=1.4.0\ntorchvision\ntransformers>=4.50.0\nftfy\ntensorboard\nJinja2\npeft>=0.14.0\nsentencepiece"
  },
  {
    "path": "examples/dreambooth/requirements_sana.txt",
    "content": "accelerate>=1.0.0\ntorchvision\ntransformers>=4.47.0\nftfy\ntensorboard\nJinja2\npeft>=0.14.0\nsentencepiece"
  },
  {
    "path": "examples/dreambooth/requirements_sd3.txt",
    "content": "accelerate>=0.31.0\ntorchvision\ntransformers>=4.41.2\nftfy\ntensorboard\nJinja2\npeft==0.11.1\nsentencepiece"
  },
  {
    "path": "examples/dreambooth/requirements_sdxl.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\npeft==0.7.0"
  },
  {
    "path": "examples/dreambooth/test_dreambooth.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport shutil\nimport sys\nimport tempfile\n\nfrom diffusers import DiffusionPipeline, UNet2DConditionModel\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBooth(ExamplesTestsAccelerate):\n    def test_dreambooth(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            
self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"unet\", \"diffusion_pytorch_model.safetensors\")))\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"scheduler\", \"scheduler_config.json\")))\n\n    def test_dreambooth_if(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --pre_compute_text_embeddings\n                --tokenizer_max_length=77\n                --text_encoder_use_attention_mask\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"unet\", \"diffusion_pytorch_model.safetensors\")))\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"scheduler\", \"scheduler_config.json\")))\n\n    def test_dreambooth_checkpointing(self):\n        instance_prompt = \"photo\"\n        pretrained_model_name_or_path = \"hf-internal-testing/tiny-stable-diffusion-pipe\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                examples/dreambooth/train_dreambooth.py\n                --pretrained_model_name_or_path 
{pretrained_model_name_or_path}\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt {instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            # check can run the original fully trained output pipeline\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)\n            pipe(instance_prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-2\")))\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-4\")))\n\n            # check can run an intermediate checkpoint\n            unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder=\"checkpoint-2/unet\")\n            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)\n            pipe(instance_prompt, num_inference_steps=1)\n\n            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming\n            shutil.rmtree(os.path.join(tmpdir, \"checkpoint-2\"))\n\n            # Run training script for 6 total steps resuming from checkpoint 4\n\n            resume_run_args = f\"\"\"\n                examples/dreambooth/train_dreambooth.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt {instance_prompt}\n                --resolution 64\n           
     --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --resume_from_checkpoint=checkpoint-4\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check can run new fully trained pipeline\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)\n            pipe(instance_prompt, num_inference_steps=1)\n\n            # check old checkpoints do not exist\n            self.assertFalse(os.path.isdir(os.path.join(tmpdir, \"checkpoint-2\")))\n\n            # check new checkpoints exist\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-4\")))\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-6\")))\n\n    def test_dreambooth_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/dreambooth/train_dreambooth.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=prompt\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def 
test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/dreambooth/train_dreambooth.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=prompt\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            resume_run_args = f\"\"\"\n            examples/dreambooth/train_dreambooth.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=prompt\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_flux.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport shutil\nimport sys\nimport tempfile\n\nfrom diffusers import DiffusionPipeline, FluxTransformer2DModel\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothFlux(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"photo\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-flux-pipe\"\n    script_path = \"examples/dreambooth/train_dreambooth_flux.py\"\n\n    def test_dreambooth(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n       
         --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"transformer\", \"diffusion_pytorch_model.safetensors\")))\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"scheduler\", \"scheduler_config.json\")))\n\n    def test_dreambooth_checkpointing(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            # check can run the original fully trained output pipeline\n            pipe = DiffusionPipeline.from_pretrained(tmpdir)\n            pipe(self.instance_prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-2\")))\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-4\")))\n\n            # check can run an intermediate checkpoint\n            transformer = 
FluxTransformer2DModel.from_pretrained(tmpdir, subfolder=\"checkpoint-2/transformer\")\n            pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)\n            pipe(self.instance_prompt, num_inference_steps=1)\n\n            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming\n            shutil.rmtree(os.path.join(tmpdir, \"checkpoint-2\"))\n\n            # Run training script for 6 total steps resuming from checkpoint 4\n\n            resume_run_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --resume_from_checkpoint=checkpoint-4\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check can run new fully trained pipeline\n            pipe = DiffusionPipeline.from_pretrained(tmpdir)\n            pipe(self.instance_prompt, num_inference_steps=1)\n\n            # check old checkpoints do not exist\n            self.assertFalse(os.path.isdir(os.path.join(tmpdir, \"checkpoint-2\")))\n\n            # check new checkpoints exist\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-4\")))\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-6\")))\n\n    def test_dreambooth_checkpointing_checkpoints_total_limit(self):\n        with 
tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            
--instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\nfrom diffusers import DiffusionPipeline  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRA(ExamplesTestsAccelerate):\n    def test_dreambooth_lora(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            
self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"unet\"` in their names.\n            starts_with_unet = all(key.startswith(\"unet\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_unet)\n\n    def test_dreambooth_lora_with_text_encoder(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --train_text_encoder\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # check `text_encoder` is present at all.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            keys = lora_state_dict.keys()\n            is_text_encoder_present = 
any(k.startswith(\"text_encoder\") for k in keys)\n            self.assertTrue(is_text_encoder_present)\n\n            # the names of the keys of the state dict should either start with `unet`\n            # or `text_encoder`.\n            is_correct_naming = all(k.startswith(\"unet\") or k.startswith(\"text_encoder\") for k in keys)\n            self.assertTrue(is_correct_naming)\n\n    def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/dreambooth/train_dreambooth_lora.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=prompt\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/dreambooth/train_dreambooth_lora.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=prompt\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            
run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            examples/dreambooth/train_dreambooth_lora.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n            --instance_data_dir=docs/source/en/imgs\n            --output_dir={tmpdir}\n            --instance_prompt=prompt\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n    def test_dreambooth_lora_if_model(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --pre_compute_text_embeddings\n                --tokenizer_max_length=77\n                --text_encoder_use_attention_mask\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            
self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"unet\"` in their names.\n            starts_with_unet = all(key.startswith(\"unet\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_unet)\n\n\nclass DreamBoothLoRASDXL(ExamplesTestsAccelerate):\n    def test_dreambooth_lora_sdxl(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in 
lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"unet\"` in their names.\n            starts_with_unet = all(key.startswith(\"unet\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_unet)\n\n    def test_dreambooth_lora_sdxl_with_text_encoder(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --train_text_encoder\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when training the text encoder, all the parameters in the state dict should start\n            # with `\"unet\"` or `\"text_encoder\"` or `\"text_encoder_2\"` in their names.\n            keys = lora_state_dict.keys()\n            starts_with_unet = all(\n                
k.startswith(\"unet\") or k.startswith(\"text_encoder\") or k.startswith(\"text_encoder_2\") for k in keys\n            )\n            self.assertTrue(starts_with_unet)\n\n    def test_dreambooth_lora_sdxl_custom_captions(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --caption_column text\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n    def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --caption_column text\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --train_text_encoder\n                \"\"\".split()\n\n            
run_command(self._launch_args + test_args)\n\n    def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):\n        pipeline_path = \"hf-internal-testing/tiny-stable-diffusion-xl-pipe\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora_sdxl.py\n                --pretrained_model_name_or_path {pipeline_path}\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            pipe = DiffusionPipeline.from_pretrained(pipeline_path)\n            pipe.load_lora_weights(tmpdir)\n            pipe(\"a prompt\", num_inference_steps=1)\n\n            # check checkpoint directories exist\n            # checkpoint-2 should have been deleted\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-4\", \"checkpoint-6\"})\n\n    def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):\n        pipeline_path = \"hf-internal-testing/tiny-stable-diffusion-xl-pipe\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora_sdxl.py\n                --pretrained_model_name_or_path {pipeline_path}\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                
--train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 7\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                --train_text_encoder\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            pipe = DiffusionPipeline.from_pretrained(pipeline_path)\n            pipe.load_lora_weights(tmpdir)\n            pipe(\"a prompt\", num_inference_steps=2)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                # checkpoint-2 should have been deleted\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_edm.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRASDXLWithEDM(ExamplesTestsAccelerate):\n    def test_dreambooth_lora_sdxl_with_edm(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --do_edm_style_training\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            
self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"unet\"` in their names.\n            starts_with_unet = all(key.startswith(\"unet\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_unet)\n\n    def test_dreambooth_lora_playground(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/dreambooth/train_dreambooth_lora_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-playground-v2-5-pipe\n                --instance_data_dir docs/source/en/imgs\n                --instance_prompt photo\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n    
        # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"unet\"` in their names.\n            starts_with_unet = all(key.startswith(\"unet\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_unet)\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_flux.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\nfrom diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRAFlux(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"photo\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-flux-pipe\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_flux.py\"\n    transformer_layer_type = \"single_transformer_blocks.0.attn.to_k\"\n\n    def test_dreambooth_lora_flux(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 
5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_text_encoder_flux(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --train_text_encoder\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, 
\"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            starts_with_expected_prefix = all(\n                (key.startswith(\"transformer\") or key.startswith(\"text_encoder\")) for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_expected_prefix)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state 
dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layers(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lora_layers {self.transformer_layer_type}\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names. 
In this test, only params of\n            # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict\n            starts_with_transformer = all(\n                key.startswith(\"transformer.single_transformer_blocks.0.attn.to_k\") for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n  
          self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n    def test_dreambooth_lora_with_metadata(self):\n        # Use a `lora_alpha` that is different from `rank`.\n        lora_alpha = 8\n        rank = 4\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --lora_alpha={lora_alpha}\n                --rank={rank}\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            state_dict_file = os.path.join(tmpdir, 
\"pytorch_lora_weights.safetensors\")\n            self.assertTrue(os.path.isfile(state_dict_file))\n\n            # Check if the metadata was properly serialized.\n            with safetensors.torch.safe_open(state_dict_file, framework=\"pt\", device=\"cpu\") as f:\n                metadata = f.metadata() or {}\n\n            metadata.pop(\"format\", None)\n            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)\n            if raw:\n                raw = json.loads(raw)\n\n            loaded_lora_alpha = raw[\"transformer.lora_alpha\"]\n            self.assertTrue(loaded_lora_alpha == lora_alpha)\n            loaded_lora_rank = raw[\"transformer.r\"]\n            self.assertTrue(loaded_lora_rank == rank)\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_flux2.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\nfrom diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRAFlux2(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"dog\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-flux2\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_flux2.py\"\n    transformer_layer_type = \"single_transformer_blocks.0.attn.to_qkv_mlp_proj\"\n\n    def test_dreambooth_lora_flux2(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 
5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --max_sequence_length 8\n                --text_encoder_out_layers 1\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --max_sequence_length 8\n                --text_encoder_out_layers 1\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n       
     run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layers(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lora_layers {self.transformer_layer_type}\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --max_sequence_length 8\n                --text_encoder_out_layers 1\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the 
state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names. In this test, only params of\n            # transformer.single_transformer_blocks.0.attn.to_qkv_mlp_proj should be in the state dict\n            starts_with_transformer = all(\n                key.startswith(f\"transformer.{self.transformer_layer_type}\") for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --max_sequence_length 8\n            --checkpointing_steps=2\n            --text_encoder_out_layers 1\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            
{self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            --max_sequence_length 8\n            --text_encoder_out_layers 1\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            --max_sequence_length 8\n            --text_encoder_out_layers 1\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n    def test_dreambooth_lora_with_metadata(self):\n        # Use a `lora_alpha` that is different from `rank`.\n        lora_alpha = 8\n        rank = 4\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir 
{self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --lora_alpha={lora_alpha}\n                --rank={rank}\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --max_sequence_length 8\n                --text_encoder_out_layers 1\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            state_dict_file = os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")\n            self.assertTrue(os.path.isfile(state_dict_file))\n\n            # Check if the metadata was properly serialized.\n            with safetensors.torch.safe_open(state_dict_file, framework=\"pt\", device=\"cpu\") as f:\n                metadata = f.metadata() or {}\n\n            metadata.pop(\"format\", None)\n            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)\n            if raw:\n                raw = json.loads(raw)\n\n            loaded_lora_alpha = raw[\"transformer.lora_alpha\"]\n            self.assertTrue(loaded_lora_alpha == lora_alpha)\n            loaded_lora_rank = raw[\"transformer.r\"]\n            self.assertTrue(loaded_lora_rank == rank)\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_flux2_klein.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\nfrom diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRAFlux2Klein(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"dog\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-flux2-klein\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_flux2_klein.py\"\n    transformer_layer_type = \"single_transformer_blocks.0.attn.to_qkv_mlp_proj\"\n\n    def test_dreambooth_lora_flux2(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                
--learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --max_sequence_length 8\n                --text_encoder_out_layers 1\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --max_sequence_length 8\n                --text_encoder_out_layers 1\n                --output_dir {tmpdir}\n                
\"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layers(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lora_layers {self.transformer_layer_type}\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --max_sequence_length 8\n                --text_encoder_out_layers 1\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n     
       # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names. In this test, only params of\n            # transformer.single_transformer_blocks.0.attn.to_qkv_mlp_proj should be in the state dict\n            starts_with_transformer = all(\n                key.startswith(f\"transformer.{self.transformer_layer_type}\") for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --max_sequence_length 8\n            --checkpointing_steps=2\n            --text_encoder_out_layers 1\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = 
f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            --max_sequence_length 8\n            --text_encoder_out_layers 1\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            --max_sequence_length 8\n            --text_encoder_out_layers 1\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n    def test_dreambooth_lora_with_metadata(self):\n        # Use a `lora_alpha` that is different from `rank`.\n        lora_alpha = 8\n        rank = 4\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                
--instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --lora_alpha={lora_alpha}\n                --rank={rank}\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --max_sequence_length 8\n                --text_encoder_out_layers 1\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            state_dict_file = os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")\n            self.assertTrue(os.path.isfile(state_dict_file))\n\n            # Check if the metadata was properly serialized.\n            with safetensors.torch.safe_open(state_dict_file, framework=\"pt\", device=\"cpu\") as f:\n                metadata = f.metadata() or {}\n\n            metadata.pop(\"format\", None)\n            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)\n            if raw:\n                raw = json.loads(raw)\n\n            loaded_lora_alpha = raw[\"transformer.lora_alpha\"]\n            self.assertTrue(loaded_lora_alpha == lora_alpha)\n            loaded_lora_rank = raw[\"transformer.r\"]\n            self.assertTrue(loaded_lora_rank == rank)\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_flux_kontext.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\nfrom diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRAFluxKontext(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"photo\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-flux-kontext-pipe\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_flux_kontext.py\"\n    transformer_layer_type = \"single_transformer_blocks.0.attn.to_k\"\n\n    def test_dreambooth_lora_flux_kontext(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n          
      --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_text_encoder_flux_kontext(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --train_text_encoder\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            
self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            starts_with_expected_prefix = all(\n                (key.startswith(\"transformer\") or key.startswith(\"text_encoder\")) for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_expected_prefix)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training 
the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layers(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lora_layers {self.transformer_layer_type}\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names. 
In this test, only params of\n            # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict\n            starts_with_transformer = all(\n                key.startswith(\"transformer.single_transformer_blocks.0.attn.to_k\") for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + 
test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n    def test_dreambooth_lora_with_metadata(self):\n        # Use a `lora_alpha` that is different from `rank`.\n        lora_alpha = 8\n        rank = 4\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --lora_alpha={lora_alpha}\n                --rank={rank}\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            state_dict_file = 
os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")\n            self.assertTrue(os.path.isfile(state_dict_file))\n\n            # Check if the metadata was properly serialized.\n            with safetensors.torch.safe_open(state_dict_file, framework=\"pt\", device=\"cpu\") as f:\n                metadata = f.metadata() or {}\n\n            metadata.pop(\"format\", None)\n            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)\n            if raw:\n                raw = json.loads(raw)\n\n            loaded_lora_alpha = raw[\"transformer.lora_alpha\"]\n            self.assertEqual(loaded_lora_alpha, lora_alpha)\n            loaded_lora_rank = raw[\"transformer.r\"]\n            self.assertEqual(loaded_lora_rank, rank)\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_hidream.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-hidream-i1-pipe\"\n    text_encoder_4_path = \"hf-internal-testing/tiny-random-LlamaForCausalLM\"\n    tokenizer_4_path = \"hf-internal-testing/tiny-random-LlamaForCausalLM\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_hidream.py\"\n    transformer_layer_type = \"double_stream_blocks.0.block.attn1.to_k\"\n\n    def test_dreambooth_lora_hidream(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}\n                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}\n                --instance_data_dir {self.instance_data_dir}\n                
--resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}\n                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}\n                --instance_data_dir {self.instance_data_dir}\n                --resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                
--learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layers(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}\n                --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}\n                --instance_data_dir {self.instance_data_dir}\n                --resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lora_layers {self.transformer_layer_type}\n                --lr_scheduler constant\n           
     --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names. In this test, only params of\n            # `self.transformer_layer_type` should be in the state dict.\n            starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}\n            --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            --max_sequence_length 16\n            \"\"\".split()\n\n            
test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}\n            --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            --max_sequence_length 16\n            \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}\n            --pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            
--checkpoints_total_limit=2\n            --max_sequence_length 16\n            \"\"\".split()\n\n            resume_run_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_lumina2.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRAlumina2(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-lumina2-pipe\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_lumina2.py\"\n    transformer_layer_type = \"layers.0.attn.to_k\"\n\n    def test_dreambooth_lora_lumina2(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                
--max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, 
\"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layers(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lora_layers {self.transformer_layer_type}\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = 
all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names. In this test, only params of\n            # `self.transformer_layer_type` should be in the state dict.\n            starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_lumina2_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            --max_sequence_length 16\n            \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_lumina2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n          
  --max_train_steps=4\n            --checkpointing_steps=2\n            --max_sequence_length 16\n            \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            --max_sequence_length 16\n            \"\"\".split()\n\n            resume_run_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_qwenimage.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\nfrom diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRAQwenImage(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"photo\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-qwenimage-pipe\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_qwen_image.py\"\n    transformer_layer_type = \"transformer_blocks.0.attn.to_k\"\n\n    def test_dreambooth_lora_qwen(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 
5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, 
\"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layers(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lora_layers {self.transformer_layer_type}\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in 
lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names. In this test, only params of\n            # transformer.transformer_blocks.0.attn.to_k should be in the state dict\n            starts_with_transformer = all(\n                key.startswith(f\"transformer.{self.transformer_layer_type}\") for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            
--train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n    def test_dreambooth_lora_with_metadata(self):\n        # Use a `lora_alpha` that is different from `rank`.\n        lora_alpha = 8\n        rank = 4\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --lora_alpha={lora_alpha}\n                --rank={rank}\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 
0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            state_dict_file = os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")\n            self.assertTrue(os.path.isfile(state_dict_file))\n\n            # Check if the metadata was properly serialized.\n            with safetensors.torch.safe_open(state_dict_file, framework=\"pt\", device=\"cpu\") as f:\n                metadata = f.metadata() or {}\n\n            metadata.pop(\"format\", None)\n            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)\n            if raw:\n                raw = json.loads(raw)\n\n            loaded_lora_alpha = raw[\"transformer.lora_alpha\"]\n            self.assertTrue(loaded_lora_alpha == lora_alpha)\n            loaded_lora_rank = raw[\"transformer.r\"]\n            self.assertTrue(loaded_lora_rank == rank)\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_sana.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\nfrom diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRASANA(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-sana-pipe\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_sana.py\"\n    transformer_layer_type = \"transformer_blocks.0.attn1.to_k\"\n\n    def test_dreambooth_lora_sana(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                
--lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            
self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layers(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --resolution 32\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lora_layers {self.transformer_layer_type}\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --max_sequence_length 16\n                \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, 
\"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names. In this test, only params of\n            # `self.transformer_layer_type` should be in the state dict.\n            starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            --max_sequence_length 16\n            \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_sana_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            
--train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            --max_sequence_length 166\n            \"\"\".split()\n\n            test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            --max_sequence_length 16\n            \"\"\".split()\n\n            resume_run_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n\n    def test_dreambooth_lora_sana_with_metadata(self):\n        lora_alpha = 8\n        rank = 4\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --resolution=32\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --lora_alpha={lora_alpha}\n            --rank={rank}\n            --checkpointing_steps=2\n            --max_sequence_length 166\n            \"\"\".split()\n\n  
          test_args.extend([\"--instance_prompt\", \"\"])\n            run_command(self._launch_args + test_args)\n\n            state_dict_file = os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")\n            self.assertTrue(os.path.isfile(state_dict_file))\n\n            # Check if the metadata was properly serialized.\n            with safetensors.torch.safe_open(state_dict_file, framework=\"pt\", device=\"cpu\") as f:\n                metadata = f.metadata() or {}\n\n            metadata.pop(\"format\", None)\n            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)\n            if raw:\n                raw = json.loads(raw)\n\n            loaded_lora_alpha = raw[\"transformer.lora_alpha\"]\n            self.assertTrue(loaded_lora_alpha == lora_alpha)\n            loaded_lora_rank = raw[\"transformer.r\"]\n            self.assertTrue(loaded_lora_rank == rank)\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_lora_sd3.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothLoRASD3(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"photo\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-sd3-pipe\"\n    script_path = \"examples/dreambooth/train_dreambooth_lora_sd3.py\"\n\n    transformer_block_idx = 0\n    layer_type = \"attn.to_k\"\n\n    def test_dreambooth_lora_sd3(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                
--lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_text_encoder_sd3(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --train_text_encoder\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct 
naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            starts_with_expected_prefix = all(\n                (key.startswith(\"transformer\") or key.startswith(\"text_encoder\")) for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_expected_prefix)\n\n    def test_dreambooth_lora_latent_caching(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --cache_latents\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            
starts_with_transformer = all(key.startswith(\"transformer\") for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_block(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --lora_blocks {self.transformer_block_idx}\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when not training the text encoder, all the parameters in the state dict should start\n            # with `\"transformer\"` in their names.\n            # In this test, only params of transformer block 0 should be in the state dict\n            starts_with_transformer = all(\n                key.startswith(\"transformer.transformer_blocks.0\") for key in lora_state_dict.keys()\n            )\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_layer(self):\n       
 with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --lora_layers {self.layer_type}\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # In this test, only transformer params of attention layers `attn.to_k` should be in the state dict\n            starts_with_transformer = all(\"attn.to_k\" in key for key in lora_state_dict.keys())\n            self.assertTrue(starts_with_transformer)\n\n    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n      
      --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-2\", \"checkpoint-4\"})\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x 
in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n"
  },
  {
    "path": "examples/dreambooth/test_dreambooth_sd3.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport shutil\nimport sys\nimport tempfile\n\nfrom diffusers import DiffusionPipeline, SD3Transformer2DModel\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass DreamBoothSD3(ExamplesTestsAccelerate):\n    instance_data_dir = \"docs/source/en/imgs\"\n    instance_prompt = \"photo\"\n    pretrained_model_name_or_path = \"hf-internal-testing/tiny-sd3-pipe\"\n    script_path = \"examples/dreambooth/train_dreambooth_sd3.py\"\n\n    def test_dreambooth(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n           
     --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"transformer\", \"diffusion_pytorch_model.safetensors\")))\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"scheduler\", \"scheduler_config.json\")))\n\n    def test_dreambooth_checkpointing(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            # check can run the original fully trained output pipeline\n            pipe = DiffusionPipeline.from_pretrained(tmpdir)\n            pipe(self.instance_prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-2\")))\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-4\")))\n\n            # check can run an intermediate checkpoint\n            transformer = 
SD3Transformer2DModel.from_pretrained(tmpdir, subfolder=\"checkpoint-2/transformer\")\n            pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)\n            pipe(self.instance_prompt, num_inference_steps=1)\n\n            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming\n            shutil.rmtree(os.path.join(tmpdir, \"checkpoint-2\"))\n\n            # Run training script for 6 total steps resuming from checkpoint 4\n\n            resume_run_args = f\"\"\"\n                {self.script_path}\n                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}\n                --instance_data_dir {self.instance_data_dir}\n                --instance_prompt {self.instance_prompt}\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --resume_from_checkpoint=checkpoint-4\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check can run new fully trained pipeline\n            pipe = DiffusionPipeline.from_pretrained(tmpdir)\n            pipe(self.instance_prompt, num_inference_steps=1)\n\n            # check old checkpoints do not exist\n            self.assertFalse(os.path.isdir(os.path.join(tmpdir, \"checkpoint-2\")))\n\n            # check new checkpoints exist\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-4\")))\n            self.assertTrue(os.path.isdir(os.path.join(tmpdir, \"checkpoint-6\")))\n\n    def test_dreambooth_checkpointing_checkpoints_total_limit(self):\n        with 
tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=6\n            --checkpoints_total_limit=2\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            --instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=4\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            resume_run_args = f\"\"\"\n            {self.script_path}\n            --pretrained_model_name_or_path={self.pretrained_model_name_or_path}\n            --instance_data_dir={self.instance_data_dir}\n            --output_dir={tmpdir}\n            
--instance_prompt={self.instance_prompt}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=8\n            --checkpointing_steps=2\n            --resume_from_checkpoint=checkpoint-4\n            --checkpoints_total_limit=2\n            \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-6\", \"checkpoint-8\"})\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport gc\nimport importlib\nimport itertools\nimport logging\nimport math\nimport os\nimport shutil\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, model_info, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom packaging import version\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom 
diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    base_model: str = None,\n    train_text_encoder=False,\n    prompt: str = None,\n    repo_folder: str = None,\n    pipeline: DiffusionPipeline = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# DreamBooth - {repo_id}\n\nThis is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).\nYou can find some example images in the following. \\n\n{img_str}\n\nDreamBooth for the text encoder was enabled: {train_text_encoder}.\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        prompt=prompt,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\"text-to-image\", \"dreambooth\", \"diffusers-training\"]\n    if isinstance(pipeline, StableDiffusionPipeline):\n        tags.extend([\"stable-diffusion\", \"stable-diffusion-diffusers\"])\n    else:\n        tags.extend([\"if\", \"if-diffusers\"])\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    text_encoder,\n    tokenizer,\n    unet,\n    vae,\n    args,\n    accelerator,\n    weight_dtype,\n    global_step,\n    prompt_embeds,\n    negative_prompt_embeds,\n):\n    logger.info(\n        f\"Running 
validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n\n    pipeline_args = {}\n\n    if vae is not None:\n        pipeline_args[\"vae\"] = vae\n\n    # create pipeline (note: unet and vae are loaded again in float32)\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        unet=unet,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n        **pipeline_args,\n    )\n\n    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if \"variance_type\" in pipeline.scheduler.config:\n        variance_type = pipeline.scheduler.config.variance_type\n\n        if variance_type in [\"learned\", \"learned_range\"]:\n            variance_type = \"fixed_small\"\n\n        scheduler_args[\"variance_type\"] = variance_type\n\n    module = importlib.import_module(\"diffusers\")\n    scheduler_class = getattr(module, args.validation_scheduler)\n    pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.pre_compute_text_embeddings:\n        pipeline_args = {\n            \"prompt_embeds\": prompt_embeds,\n            \"negative_prompt_embeds\": negative_prompt_embeds,\n        }\n    else:\n        pipeline_args = {\"prompt\": args.validation_prompt}\n\n    # run inference\n    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n    images = []\n    if args.validation_images is None:\n        for _ in range(args.num_validation_images):\n            with torch.autocast(\"cuda\"):\n                image = pipeline(**pipeline_args, 
num_inference_steps=25, generator=generator).images[0]\n            images.append(image)\n    else:\n        for image in args.validation_images:\n            image = Image.open(image)\n            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, global_step, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n  
      \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        
\"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"dreambooth-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. \"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. \"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step \"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more details\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. 
More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--offset_noise\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Fine-tuning against a modified noise\"\n            \" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information.\"\n        ),\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--pre_compute_text_embeddings\",\n        action=\"store_true\",\n        help=\"Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_max_length\",\n        type=int,\n        default=None,\n        required=False,\n        help=\"The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.\",\n    )\n    parser.add_argument(\n        \"--text_encoder_use_attention_mask\",\n        action=\"store_true\",\n        required=False,\n        help=\"Whether to use attention mask for the text encoder\",\n    )\n    parser.add_argument(\n        \"--skip_save_text_encoder\", action=\"store_true\", required=False, help=\"Set to not save text encoder\"\n    )\n    parser.add_argument(\n        \"--validation_images\",\n        required=False,\n        default=None,\n        nargs=\"+\",\n        help=\"Optional set of images to use for validation. 
Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.\",\n    )\n    parser.add_argument(\n        \"--class_labels_conditioning\",\n        required=False,\n        default=None,\n        help=\"The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.\",\n    )\n    parser.add_argument(\n        \"--validation_scheduler\",\n        type=str,\n        default=\"DPMSolverMultistepScheduler\",\n        choices=[\"DPMSolverMultistepScheduler\", \"DDPMScheduler\"],\n        help=\"Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    if args.train_text_encoder and args.pre_compute_text_embeddings:\n        raise ValueError(\"`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and 
tokenizes the prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        class_num=None,\n        size=512,\n        center_crop=False,\n        encoder_hidden_states=None,\n        class_prompt_encoder_hidden_states=None,\n        tokenizer_max_length=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n        self.encoder_hidden_states = encoder_hidden_states\n        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states\n        self.tokenizer_max_length = tokenizer_max_length\n\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(f\"Instance {self.instance_data_root} images root doesn't exist.\")\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])\n        instance_image = exif_transpose(instance_image)\n\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        example[\"instance_images\"] = self.image_transforms(instance_image)\n\n        if self.encoder_hidden_states is not None:\n            example[\"instance_prompt_ids\"] = self.encoder_hidden_states\n        else:\n            text_inputs = tokenize_prompt(\n                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length\n            )\n            example[\"instance_prompt_ids\"] = text_inputs.input_ids\n            example[\"instance_attention_mask\"] = text_inputs.attention_mask\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n\n            if self.class_prompt_encoder_hidden_states is not None:\n                example[\"class_prompt_ids\"] = self.class_prompt_encoder_hidden_states\n            else:\n                class_text_inputs = tokenize_prompt(\n                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length\n                )\n                example[\"class_prompt_ids\"] = class_text_inputs.input_ids\n                example[\"class_attention_mask\"] = 
class_text_inputs.attention_mask\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    has_attention_mask = \"instance_attention_mask\" in examples[0]\n\n    input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n    pixel_values = [example[\"instance_images\"] for example in examples]\n\n    if has_attention_mask:\n        attention_mask = [example[\"instance_attention_mask\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        input_ids += [example[\"class_prompt_ids\"] for example in examples]\n        pixel_values += [example[\"class_images\"] for example in examples]\n\n        if has_attention_mask:\n            attention_mask += [example[\"class_attention_mask\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.cat(input_ids, dim=0)\n\n    batch = {\n        \"input_ids\": input_ids,\n        \"pixel_values\": pixel_values,\n    }\n\n    if has_attention_mask:\n        attention_mask = torch.cat(attention_mask, dim=0)\n        batch[\"attention_mask\"] = attention_mask\n\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef model_has_vae(args):\n    config_file_name = Path(\"vae\", AutoencoderKL.config_name).as_posix()\n    if os.path.isdir(args.pretrained_model_name_or_path):\n        config_file_name 
= os.path.join(args.pretrained_model_name_or_path, config_file_name)\n        return os.path.isfile(config_file_name)\n    else:\n        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings\n        return any(file.rfilename == config_file_name for file in files_in_repo)\n\n\ndef tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):\n    if tokenizer_max_length is not None:\n        max_length = tokenizer_max_length\n    else:\n        max_length = tokenizer.model_max_length\n\n    text_inputs = tokenizer(\n        prompt,\n        truncation=True,\n        padding=\"max_length\",\n        max_length=max_length,\n        return_tensors=\"pt\",\n    )\n\n    return text_inputs\n\n\ndef encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):\n    text_input_ids = input_ids.to(text_encoder.device)\n\n    if text_encoder_use_attention_mask:\n        attention_mask = attention_mask.to(text_encoder.device)\n    else:\n        attention_mask = None\n\n    prompt_embeds = text_encoder(\n        text_input_ids,\n        attention_mask=attention_mask,\n        return_dict=False,\n    )\n    prompt_embeds = prompt_embeds[0]\n\n    return prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # 
Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. \"\n            \"Please set gradient_accumulation_steps to 1. This feature will be supported in the future.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = 
len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                safety_checker=None,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle 
the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n\n    if model_has_vae(args):\n        vae = AutoencoderKL.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n        )\n    else:\n        vae = None\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def 
save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            for model in models:\n                sub_dir = \"unet\" if isinstance(model, type(unwrap_model(unet))) else \"text_encoder\"\n                model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n    def load_model_hook(models, input_dir):\n        while len(models) > 0:\n            # pop models so that they are not loaded again\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(text_encoder))):\n                # load transformers style into model\n                load_model = text_encoder_cls.from_pretrained(input_dir, subfolder=\"text_encoder\")\n                model.config = load_model.config\n            else:\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n            model.load_state_dict(load_model.state_dict())\n            del load_model\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if vae is not None:\n        vae.requires_grad_(False)\n\n    if not args.train_text_encoder:\n        text_encoder.requires_grad_(False)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. 
See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \"Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training. A copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(f\"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}\")\n\n    if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:\n        raise ValueError(\n            f\"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. 
{low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = (\n        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()\n    )\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    if args.pre_compute_text_embeddings:\n\n        def compute_text_embeddings(prompt):\n            with torch.no_grad():\n                text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)\n                prompt_embeds = encode_prompt(\n                    text_encoder,\n                    text_inputs.input_ids,\n                    text_inputs.attention_mask,\n                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                )\n\n            return prompt_embeds\n\n        pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)\n        
validation_prompt_negative_prompt_embeds = compute_text_embeddings(\"\")\n\n        if args.validation_prompt is not None:\n            validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)\n        else:\n            validation_prompt_encoder_hidden_states = None\n\n        if args.class_prompt is not None:\n            pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)\n        else:\n            pre_computed_class_prompt_encoder_hidden_states = None\n\n        text_encoder = None\n        tokenizer = None\n\n        gc.collect()\n        torch.cuda.empty_cache()\n    else:\n        pre_computed_encoder_hidden_states = None\n        validation_prompt_encoder_hidden_states = None\n        validation_prompt_negative_prompt_embeds = None\n        pre_computed_class_prompt_encoder_hidden_states = None\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        class_num=args.num_class_images,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n        encoder_hidden_states=pre_computed_encoder_hidden_states,\n        class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,\n        tokenizer_max_length=args.tokenizer_max_length,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for 
detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and text_encoder to device and cast to weight_dtype\n    if vae is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    if not args.train_text_encoder and 
text_encoder is not None:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = vars(copy.deepcopy(args))\n        tracker_config.pop(\"validation_images\")\n        accelerator.init_trackers(\"dreambooth\", config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                pixel_values = 
batch[\"pixel_values\"].to(dtype=weight_dtype)\n\n                if vae is not None:\n                    # Convert images to latent space\n                    model_input = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                    model_input = model_input * vae.config.scaling_factor\n                else:\n                    model_input = pixel_values\n\n                # Sample noise that we'll add to the model input\n                if args.offset_noise:\n                    noise = torch.randn_like(model_input) + 0.1 * torch.randn(\n                        model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device\n                    )\n                else:\n                    noise = torch.randn_like(model_input)\n                bsz, channels, height, width = model_input.shape\n                # Sample a random timestep for each image\n                timesteps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                )\n                timesteps = timesteps.long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                if args.pre_compute_text_embeddings:\n                    encoder_hidden_states = batch[\"input_ids\"]\n                else:\n                    encoder_hidden_states = encode_prompt(\n                        text_encoder,\n                        batch[\"input_ids\"],\n                        batch[\"attention_mask\"],\n                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                    )\n\n                if unwrap_model(unet).config.in_channels == channels * 2:\n                   
 noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)\n\n                if args.class_labels_conditioning == \"timesteps\":\n                    class_labels = timesteps\n                else:\n                    class_labels = None\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False\n                )[0]\n\n                if model_pred.shape[1] == 6:\n                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n                    # Compute prior loss\n                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                # Compute instance loss\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly 
changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n\n                    if noise_scheduler.config.prediction_type == \"v_prediction\":\n                        # Velocity objective needs to be floored to an SNR weight of one.\n                        divisor = snr + 1\n                    else:\n                        divisor = snr\n\n                    mse_loss_weights = (\n                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor\n                    )\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet.parameters(), text_encoder.parameters())\n                        if args.train_text_encoder\n                        else unet.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, 
check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    images = []\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        images = log_validation(\n                            unwrap_model(text_encoder) if text_encoder is not None else text_encoder,\n                            tokenizer,\n                            
unwrap_model(unet),\n                            vae,\n                            args,\n                            accelerator,\n                            weight_dtype,\n                            global_step,\n                            validation_prompt_encoder_hidden_states,\n                            validation_prompt_negative_prompt_embeds,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        pipeline_args = {}\n\n        if text_encoder is not None:\n            pipeline_args[\"text_encoder\"] = unwrap_model(text_encoder)\n\n        if args.skip_save_text_encoder:\n            pipeline_args[\"text_encoder\"] = None\n\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=unwrap_model(unet),\n            revision=args.revision,\n            variant=args.variant,\n            **pipeline_args,\n        )\n\n        # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it\n        scheduler_args = {}\n\n        if \"variance_type\" in pipeline.scheduler.config:\n            variance_type = pipeline.scheduler.config.variance_type\n\n            if variance_type in [\"learned\", \"learned_range\"]:\n                variance_type = \"fixed_small\"\n\n            scheduler_args[\"variance_type\"] = variance_type\n\n        pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n        pipeline.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                prompt=args.instance_prompt,\n                repo_folder=args.output_dir,\n                pipeline=pipeline,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_flax.py",
    "content": "import argparse\nimport logging\nimport math\nimport os\nfrom pathlib import Path\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom flax import jax_utils\nfrom flax.training import train_state\nfrom flax.training.common_utils import shard\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom jax.experimental.compilation_cache import compilation_cache as cc\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed\n\nfrom diffusers import (\n    FlaxAutoencoderKL,\n    FlaxDDPMScheduler,\n    FlaxPNDMScheduler,\n    FlaxStableDiffusionPipeline,\n    FlaxUNet2DConditionModel,\n)\nfrom diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker\nfrom diffusers.utils import check_min_version\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\n# Cache compiled models across invocations of this script.\ncc.initialize_cache(os.path.expanduser(\"~/.cache/jax/compilation_cache\"))\n\nlogger = logging.getLogger(__name__)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained vae or vae identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    
parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--save_steps\", type=int, default=None, help=\"Save a checkpoint every X steps.\")\n    parser.add_argument(\"--seed\", type=int, default=0, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\"--train_text_encoder\", action=\"store_true\", help=\"Whether to train the text encoder\")\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the 
Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.instance_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and tokenizes the prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        
class_prompt=None,\n        class_num=None,\n        size=512,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(\"Instance images root doesn't exist.\")\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        
example[\"instance_images\"] = self.image_transforms(instance_image)\n        example[\"instance_prompt_ids\"] = self.tokenizer(\n            self.instance_prompt,\n            padding=\"do_not_pad\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n        ).input_ids\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt_ids\"] = self.tokenizer(\n                self.class_prompt,\n                padding=\"do_not_pad\",\n                truncation=True,\n                max_length=self.tokenizer.model_max_length,\n            ).input_ids\n\n        return example\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef get_params_to_save(params):\n    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))\n\n\ndef main():\n    args = parse_args()\n\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    # Setup logging, we only want one process per machine to log things on the screen.\n    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)\n    if jax.process_index() == 0:\n        transformers.utils.logging.set_verbosity_info()\n    else:\n        
transformers.utils.logging.set_verbosity_error()\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    rng = jax.random.PRNGKey(args.seed)\n\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            total_sample_batch_size = args.sample_batch_size * jax.local_device_count()\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not jax.process_index() == 0\n            ):\n                prompt_ids = pipeline.prepare_inputs(example[\"prompt\"])\n                prompt_ids = shard(prompt_ids)\n                p_params = jax_utils.replicate(params)\n                rng = jax.random.split(rng)[0]\n                sample_rng = jax.random.split(rng, jax.device_count())\n                images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images\n                images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])\n                images = pipeline.numpy_to_pil(np.array(images))\n\n                for i, image in enumerate(images):\n                    hash_image = 
insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n\n    # Handle the repository creation\n    if jax.process_index() == 0:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer and add the placeholder token as a additional special token\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n        )\n    else:\n        raise NotImplementedError(\"No tokenizer specified!\")\n\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        class_num=args.num_class_images,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n    )\n\n    def collate_fn(examples):\n        input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n        pixel_values = [example[\"instance_images\"] for example in examples]\n\n        # Concat class and instance examples for prior preservation.\n        # We do this to avoid doing two forward passes.\n        if args.with_prior_preservation:\n            input_ids += [example[\"class_prompt_ids\"] for example in examples]\n            pixel_values += 
[example[\"class_images\"] for example in examples]\n\n        pixel_values = torch.stack(pixel_values)\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n        input_ids = tokenizer.pad(\n            {\"input_ids\": input_ids}, padding=\"max_length\", max_length=tokenizer.model_max_length, return_tensors=\"pt\"\n        ).input_ids\n\n        batch = {\n            \"input_ids\": input_ids,\n            \"pixel_values\": pixel_values,\n        }\n        batch = {k: v.numpy() for k, v in batch.items()}\n        return batch\n\n    total_train_batch_size = args.train_batch_size * jax.local_device_count()\n    if len(train_dataset) < total_train_batch_size:\n        raise ValueError(\n            f\"Training batch size is {total_train_batch_size}, but your dataset only contains\"\n            f\" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that\"\n            f\" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that.\"\n        )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True\n    )\n\n    weight_dtype = jnp.float32\n    if args.mixed_precision == \"fp16\":\n        weight_dtype = jnp.float16\n    elif args.mixed_precision == \"bf16\":\n        weight_dtype = jnp.bfloat16\n\n    if args.pretrained_vae_name_or_path:\n        # TODO(patil-suraj): Upload flax weights for the VAE\n        vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {\"from_pt\": True})\n    else:\n        vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {\"subfolder\": \"vae\", \"revision\": args.revision})\n\n    # Load models and create wrapper for stable diffusion\n    text_encoder = FlaxCLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        
dtype=weight_dtype,\n        revision=args.revision,\n    )\n    vae, vae_params = FlaxAutoencoderKL.from_pretrained(\n        vae_arg,\n        dtype=weight_dtype,\n        **vae_kwargs,\n    )\n    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"unet\",\n        dtype=weight_dtype,\n        revision=args.revision,\n    )\n\n    # Optimization\n    if args.scale_lr:\n        args.learning_rate = args.learning_rate * total_train_batch_size\n\n    constant_scheduler = optax.constant_schedule(args.learning_rate)\n\n    adamw = optax.adamw(\n        learning_rate=constant_scheduler,\n        b1=args.adam_beta1,\n        b2=args.adam_beta2,\n        eps=args.adam_epsilon,\n        weight_decay=args.adam_weight_decay,\n    )\n\n    optimizer = optax.chain(\n        optax.clip_by_global_norm(args.max_grad_norm),\n        adamw,\n    )\n\n    unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)\n    text_encoder_state = train_state.TrainState.create(\n        apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer\n    )\n\n    noise_scheduler = FlaxDDPMScheduler(\n        beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000\n    )\n    noise_scheduler_state = noise_scheduler.create_state()\n\n    # Initialize our training\n    train_rngs = jax.random.split(rng, jax.local_device_count())\n\n    def train_step(unet_state, text_encoder_state, vae_params, batch, train_rng):\n        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)\n\n        if args.train_text_encoder:\n            params = {\"text_encoder\": text_encoder_state.params, \"unet\": unet_state.params}\n        else:\n            params = {\"unet\": unet_state.params}\n\n        def compute_loss(params):\n            # Convert images to latent space\n            vae_outputs = vae.apply(\n                
{\"params\": vae_params}, batch[\"pixel_values\"], deterministic=True, method=vae.encode\n            )\n            latents = vae_outputs.latent_dist.sample(sample_rng)\n            # (NHWC) -> (NCHW)\n            latents = jnp.transpose(latents, (0, 3, 1, 2))\n            latents = latents * vae.config.scaling_factor\n\n            # Sample noise that we'll add to the latents\n            noise_rng, timestep_rng = jax.random.split(sample_rng)\n            noise = jax.random.normal(noise_rng, latents.shape)\n            # Sample a random timestep for each image\n            bsz = latents.shape[0]\n            timesteps = jax.random.randint(\n                timestep_rng,\n                (bsz,),\n                0,\n                noise_scheduler.config.num_train_timesteps,\n            )\n\n            # Add noise to the latents according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)\n\n            # Get the text embedding for conditioning\n            if args.train_text_encoder:\n                encoder_hidden_states = text_encoder_state.apply_fn(\n                    batch[\"input_ids\"], params=params[\"text_encoder\"], dropout_rng=dropout_rng, train=True\n                )[0]\n            else:\n                encoder_hidden_states = text_encoder(\n                    batch[\"input_ids\"], params=text_encoder_state.params, train=False\n                )[0]\n\n            # Predict the noise residual\n            model_pred = unet.apply(\n                {\"params\": params[\"unet\"]}, noisy_latents, timesteps, encoder_hidden_states, train=True\n            ).sample\n\n            # Get the target for loss depending on the prediction type\n            if noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == 
\"v_prediction\":\n                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n            if args.with_prior_preservation:\n                # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.\n                model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)\n                target, target_prior = jnp.split(target, 2, axis=0)\n\n                # Compute instance loss\n                loss = (target - model_pred) ** 2\n                loss = loss.mean()\n\n                # Compute prior loss\n                prior_loss = (target_prior - model_pred_prior) ** 2\n                prior_loss = prior_loss.mean()\n\n                # Add the prior loss to the instance loss.\n                loss = loss + args.prior_loss_weight * prior_loss\n            else:\n                loss = (target - model_pred) ** 2\n                loss = loss.mean()\n\n            return loss\n\n        grad_fn = jax.value_and_grad(compute_loss)\n        loss, grad = grad_fn(params)\n        grad = jax.lax.pmean(grad, \"batch\")\n\n        new_unet_state = unet_state.apply_gradients(grads=grad[\"unet\"])\n        if args.train_text_encoder:\n            new_text_encoder_state = text_encoder_state.apply_gradients(grads=grad[\"text_encoder\"])\n        else:\n            new_text_encoder_state = text_encoder_state\n\n        metrics = {\"loss\": loss}\n        metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n\n        return new_unet_state, new_text_encoder_state, metrics, new_train_rng\n\n    # Create parallel version of the train step\n    p_train_step = jax.pmap(train_step, \"batch\", donate_argnums=(0, 1))\n\n    # Replicate the train state on each device\n    unet_state = jax_utils.replicate(unet_state)\n    text_encoder_state = 
jax_utils.replicate(text_encoder_state)\n    vae_params = jax_utils.replicate(vae_params)\n\n    # Train!\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader))\n\n    # Scheduler and math around the number of training steps.\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    def checkpoint(step=None):\n        # Create the pipeline using the trained modules and save it.\n        scheduler, _ = FlaxPNDMScheduler.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"scheduler\")\n        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(\n            \"CompVis/stable-diffusion-safety-checker\", from_pt=True\n        )\n        pipeline = FlaxStableDiffusionPipeline(\n            text_encoder=text_encoder,\n            vae=vae,\n            unet=unet,\n            tokenizer=tokenizer,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=CLIPImageProcessor.from_pretrained(\"openai/clip-vit-base-patch32\"),\n        )\n\n        outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir\n        pipeline.save_pretrained(\n            outdir,\n            params={\n                \"text_encoder\": get_params_to_save(text_encoder_state.params),\n                \"vae\": get_params_to_save(vae_params),\n                \"unet\": 
get_params_to_save(unet_state.params),\n                \"safety_checker\": safety_checker.params,\n            },\n        )\n\n        if args.push_to_hub:\n            message = f\"checkpoint-{step}\" if step is not None else \"End of training\"\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=message,\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    global_step = 0\n\n    epochs = tqdm(range(args.num_train_epochs), desc=\"Epoch ... \", position=0)\n    for epoch in epochs:\n        # ======================== Training ================================\n\n        train_metrics = []\n\n        steps_per_epoch = len(train_dataset) // total_train_batch_size\n        train_step_progress_bar = tqdm(total=steps_per_epoch, desc=\"Training...\", position=1, leave=False)\n        # train\n        for batch in train_dataloader:\n            batch = shard(batch)\n            unet_state, text_encoder_state, train_metric, train_rngs = p_train_step(\n                unet_state, text_encoder_state, vae_params, batch, train_rngs\n            )\n            train_metrics.append(train_metric)\n\n            train_step_progress_bar.update(jax.local_device_count())\n\n            global_step += 1\n            if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:\n                checkpoint(global_step)\n            if global_step >= args.max_train_steps:\n                break\n\n        train_metric = jax_utils.unreplicate(train_metric)\n\n        train_step_progress_bar.close()\n        epochs.write(f\"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})\")\n\n    if jax.process_index() == 0:\n        checkpoint()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_flux.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport gc\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    
FlowMatchEulerDiscreteScheduler,\n    FluxPipeline,\n    FluxTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3\nfrom diffusers.utils import (\n    check_min_version,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nif is_torch_npu_available():\n    import torch_npu\n\n    torch.npu.config.allow_internal_format = False\n    torch.npu.set_compile_mode(jit_compile=False)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Flux [dev] DreamBooth - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).\n\nWas the text encoder fine-tuned? 
{train_text_encoder}.\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.bfloat16).to('cuda')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"flux\",\n        \"flux-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef load_text_encoders(class_one, class_two):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = class_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    return text_encoder_one, text_encoder_two\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n    autocast_ctx = nullcontext()\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n    elif is_torch_npu_available():\n        torch_npu.npu.empty_cache()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"T5EncoderModel\":\n        from transformers import 
T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=77,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-dreambooth\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"Guidance scale to use; the FLUX.1 dev variant is a guidance-distilled model.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
\"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  
Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data 
directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. 
If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n        if interpolation is None:\n            raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode=}.\")\n        train_resize = transforms.Resize(size, interpolation=interpolation)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n 
               y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = 
caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so use the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt, max_sequence_length):\n    text_inputs = 
tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        return_length=False,\n        return_overflowing_tokens=False,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\ndef _encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length=512,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    if hasattr(text_encoder, \"module\"):\n        dtype = text_encoder.module.dtype\n    else:\n        dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n\n    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    text_input_ids=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n 
   if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n    if hasattr(text_encoder, \"module\"):\n        dtype = text_encoder.module.dtype\n    else:\n        dtype = text_encoder.dtype\n    # Use pooled output of CLIPTextModel\n    prompt_embeds = prompt_embeds.pooler_output\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n    return prompt_embeds\n\n\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n    text_input_ids_list=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if hasattr(text_encoders[0], \"module\"):\n        dtype = text_encoders[0].module.dtype\n    else:\n        dtype = text_encoders[0].dtype\n\n    device = device if device is not None else text_encoders[1].device\n    pooled_prompt_embeds = _encode_prompt_with_clip(\n        text_encoder=text_encoders[0],\n        tokenizer=tokenizers[0],\n        prompt=prompt,\n        device=device,\n        num_images_per_prompt=num_images_per_prompt,\n        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,\n    )\n\n    
prompt_embeds = _encode_prompt_with_t5(\n        text_encoder=text_encoders[1],\n        tokenizer=tokenizers[1],\n        max_sequence_length=max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        device=device,\n        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,\n    )\n\n    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)\n\n    return prompt_embeds, pooled_prompt_embeds, text_ids\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            
has_supported_fp16_accelerator = (\n                torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()\n            )\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = FluxPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n 
           elif is_torch_npu_available():\n                torch_npu.npu.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    transformer.requires_grad_(True)\n    vae.requires_grad_(False)\n    if 
args.train_text_encoder:\n        text_encoder_one.requires_grad_(True)\n        text_encoder_two.requires_grad_(False)\n    else:\n        text_encoder_one.requires_grad_(False)\n        text_encoder_two.requires_grad_(False)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            transformer.set_attention_backend(\"_native_npu\")\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=weight_dtype)\n    if not args.train_text_encoder:\n        text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n        text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            for i, model in enumerate(models):\n                if isinstance(unwrap_model(model), FluxTransformer2DModel):\n                    unwrap_model(model).save_pretrained(os.path.join(output_dir, \"transformer\"))\n                elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):\n                    if isinstance(unwrap_model(model), CLIPTextModelWithProjection):\n                        unwrap_model(model).save_pretrained(os.path.join(output_dir, \"text_encoder\"))\n                    else:\n                        unwrap_model(model).save_pretrained(os.path.join(output_dir, \"text_encoder_2\"))\n                else:\n                    raise ValueError(f\"Wrong model supplied: {type(model)=}.\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n    def load_model_hook(models, input_dir):\n        for _ in range(len(models)):\n            # pop models so that they are not loaded again\n            model = models.pop()\n\n            # load diffusers style into model\n            if 
isinstance(unwrap_model(model), FluxTransformer2DModel):\n                load_model = FluxTransformer2DModel.from_pretrained(input_dir, subfolder=\"transformer\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n            elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):\n                try:\n                    load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder=\"text_encoder\")\n                    model(**load_model.config)\n                    model.load_state_dict(load_model.state_dict())\n                except Exception:\n                    try:\n                        load_model = T5EncoderModel.from_pretrained(input_dir, subfolder=\"text_encoder_2\")\n                        model(**load_model.config)\n                        model.load_state_dict(load_model.state_dict())\n                    except Exception:\n                        raise ValueError(f\"Couldn't load the model of type: ({type(model)}).\")\n            else:\n                raise ValueError(f\"Unsupported model found: {type(model)=}\")\n\n            del load_model\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer.parameters(), \"lr\": args.learning_rate}\n    if args.train_text_encoder:\n        # different learning rate for text 
encoder and transformer\n        text_parameters_one_with_lr = {\n            \"params\": text_encoder_one.parameters(),\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]\n    else:\n        params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class 
= prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. \"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        
num_workers=args.dataloader_num_workers,\n    )\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two]\n        text_encoders = [text_encoder_one, text_encoder_two]\n\n        def compute_text_embeddings(prompt, text_encoders, tokenizers):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                    text_encoders, tokenizers, prompt, args.max_sequence_length\n                )\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n                text_ids = text_ids.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if not args.train_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers\n            )\n\n    # Clear the memory here\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        del tokenizers, text_encoders\n        # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection\n        del text_encoder_one, text_encoder_two\n        gc.collect()\n        if 
torch.cuda.is_available():\n            torch.cuda.empty_cache()\n        elif is_torch_npu_available():\n            torch_npu.npu.empty_cache()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. This is so that we don't\n    # have to pass them to the dataloader.\n\n    if not train_dataset.custom_instance_prompts:\n        if not args.train_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            pooled_prompt_embeds = instance_pooled_prompt_embeds\n            text_ids = instance_text_ids\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)\n                text_ids = torch.cat([text_ids, class_text_ids], dim=0)\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the\n        # batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)\n            tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, max_sequence_length=512)\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77)\n                class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, max_sequence_length=512)\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        (\n            transformer,\n            text_encoder_one,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        ) = accelerator.prepare(\n            transformer,\n            text_encoder_one,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        )\n    else:\n        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            transformer, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 
'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux-dev-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            if args.train_text_encoder:\n                models_to_accumulate.extend([text_encoder_one])\n            with accelerator.accumulate(models_to_accumulate):\n                pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                prompts = batch[\"prompts\"]\n\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    if not args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)\n                        tokens_two = tokenize_prompt(\n                            tokenizer_two, prompts, max_sequence_length=args.max_sequence_length\n                        )\n                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                            text_encoders=[text_encoder_one, text_encoder_two],\n                            tokenizers=[None, None],\n                            text_input_ids_list=[tokens_one, tokens_two],\n                            max_sequence_length=args.max_sequence_length,\n                            prompt=prompts,\n                        )\n                else:\n                    if 
args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                            text_encoders=[text_encoder_one, text_encoder_two],\n                            tokenizers=[None, None],\n                            text_input_ids_list=[tokens_one, tokens_two],\n                            max_sequence_length=args.max_sequence_length,\n                            prompt=args.instance_prompt,\n                        )\n\n                # Convert images to latent space\n                model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)\n\n                latent_image_ids = FluxPipeline._prepare_latent_image_ids(\n                    model_input.shape[0],\n                    model_input.shape[2] // 2,\n                    model_input.shape[3] // 2,\n                    accelerator.device,\n                    weight_dtype,\n                )\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n          
      # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                packed_noisy_model_input = FluxPipeline._pack_latents(\n                    noisy_model_input,\n                    batch_size=model_input.shape[0],\n                    num_channels_latents=model_input.shape[1],\n                    height=model_input.shape[2],\n                    width=model_input.shape[3],\n                )\n\n                # handle guidance\n                if unwrap_model(transformer).config.guidance_embeds:\n                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)\n                    guidance = guidance.expand(model_input.shape[0])\n                else:\n                    guidance = None\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,\n                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    return_dict=False,\n                )[0]\n                # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042\n                model_pred = FluxPipeline._unpack_latents(\n                    model_pred,\n                    height=model_input.shape[2] * vae_scale_factor,\n                    
width=model_input.shape[3] * vae_scale_factor,\n                    vae_scale_factor=vae_scale_factor,\n                )\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(transformer.parameters(), text_encoder_one.parameters())\n                        if args.train_text_encoder\n                        else transformer.parameters()\n                    )\n          
          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                     
   save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if not args.train_text_encoder:\n                    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n                    text_encoder_one.to(weight_dtype)\n                    text_encoder_two.to(weight_dtype)\n                else:  # even when training the text encoder we're only training text encoder one\n                    text_encoder_two = text_encoder_cls_two.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"text_encoder_2\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                # the local `unwrap_model` helper takes only the model, so no extra keyword arguments are passed\n                pipeline = FluxPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=unwrap_model(text_encoder_one),\n                    text_encoder_2=unwrap_model(text_encoder_two),\n                    transformer=unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n                images 
= log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n                if not args.train_text_encoder:\n                    del text_encoder_one, text_encoder_two\n                    if torch.cuda.is_available():\n                        torch.cuda.empty_cache()\n                    elif is_torch_npu_available():\n                        torch_npu.npu.empty_cache()\n                    gc.collect()\n\n                images = None\n                del pipeline\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            pipeline = FluxPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                transformer=transformer,\n                text_encoder=text_encoder_one,\n            )\n        else:\n            pipeline = FluxPipeline.from_pretrained(args.pretrained_model_name_or_path, transformer=transformer)\n\n        # save the pipeline\n        pipeline.save_pretrained(args.output_dir)\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = FluxPipeline.from_pretrained(\n            args.output_dir,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt}\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                
pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n        del pipeline\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport gc\nimport logging\nimport math\nimport os\nimport shutil\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom packaging import version\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict, set_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _set_state_dict_into_text_encoder,\n    cast_training_params,\n    
free_memory,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_state_dict_to_diffusers,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    prompt: str = None,\n    repo_folder=None,\n    pipeline: DiffusionPipeline = None,\n):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# LoRA DreamBooth - {repo_id}\n\nThese are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. 
\\n\n{img_str}\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        prompt=prompt,\n        model_description=model_description,\n        inference=True,\n    )\n    tags = [\"text-to-image\", \"diffusers\", \"lora\", \"diffusers-training\"]\n    if isinstance(pipeline, StableDiffusionPipeline):\n        tags.extend([\"stable-diffusion\", \"stable-diffusion-diffusers\"])\n    else:\n        tags.extend([\"if\", \"if-diffusers\"])\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if \"variance_type\" in pipeline.scheduler.config:\n        variance_type = pipeline.scheduler.config.variance_type\n\n        if variance_type in [\"learned\", \"learned_range\"]:\n            variance_type = \"fixed_small\"\n\n        scheduler_args[\"variance_type\"] = variance_type\n\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n\n    if args.validation_images is None:\n        images = []\n        for _ in range(args.num_validation_images):\n            with torch.amp.autocast(accelerator.device.type):\n                image = pipeline(**pipeline_args, generator=generator).images[0]\n                images.append(image)\n    else:\n        images = []\n        for image in args.validation_images:\n            image = Image.open(image)\n            with torch.amp.autocast(accelerator.device.type):\n                image = pipeline(**pipeline_args, image=image, generator=generator).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n         
   )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lora-dreambooth-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--pre_compute_text_embeddings\",\n        action=\"store_true\",\n        help=\"Whether or not to pre-compute text embeddings. 
If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_max_length\",\n        type=int,\n        default=None,\n        required=False,\n        help=\"The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.\",\n    )\n    parser.add_argument(\n        \"--text_encoder_use_attention_mask\",\n        action=\"store_true\",\n        required=False,\n        help=\"Whether to use attention mask for the text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_images\",\n        required=False,\n        default=None,\n        nargs=\"+\",\n        help=\"Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.\",\n    )\n    parser.add_argument(\n        \"--class_labels_conditioning\",\n        required=False,\n        default=None,\n        help=\"The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.\",\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = 
parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    if args.train_text_encoder and args.pre_compute_text_embeddings:\n        raise ValueError(\"`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and tokenizes the prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        class_num=None,\n        size=512,\n        center_crop=False,\n        encoder_hidden_states=None,\n        class_prompt_encoder_hidden_states=None,\n        tokenizer_max_length=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n        self.encoder_hidden_states = encoder_hidden_states\n        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states\n        self.tokenizer_max_length = tokenizer_max_length\n\n        self.instance_data_root = Path(instance_data_root)\n        if not 
self.instance_data_root.exists():\n            raise ValueError(\"Instance images root doesn't exist.\")\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        # NOTE: relies on the module-level `args` set in the `__main__` block.\n        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n        if interpolation is None:\n            raise ValueError(f\"Unsupported interpolation mode: {args.image_interpolation_mode}.\")\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=interpolation),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])\n        instance_image = exif_transpose(instance_image)\n\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        example[\"instance_images\"] = 
self.image_transforms(instance_image)\n\n        if self.encoder_hidden_states is not None:\n            example[\"instance_prompt_ids\"] = self.encoder_hidden_states\n        else:\n            text_inputs = tokenize_prompt(\n                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length\n            )\n            example[\"instance_prompt_ids\"] = text_inputs.input_ids\n            example[\"instance_attention_mask\"] = text_inputs.attention_mask\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n\n            if self.class_prompt_encoder_hidden_states is not None:\n                example[\"class_prompt_ids\"] = self.class_prompt_encoder_hidden_states\n            else:\n                class_text_inputs = tokenize_prompt(\n                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length\n                )\n                example[\"class_prompt_ids\"] = class_text_inputs.input_ids\n                example[\"class_attention_mask\"] = class_text_inputs.attention_mask\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    has_attention_mask = \"instance_attention_mask\" in examples[0]\n\n    input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n    pixel_values = [example[\"instance_images\"] for example in examples]\n\n    if has_attention_mask:\n        attention_mask = [example[\"instance_attention_mask\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        input_ids += 
[example[\"class_prompt_ids\"] for example in examples]\n        pixel_values += [example[\"class_images\"] for example in examples]\n        if has_attention_mask:\n            attention_mask += [example[\"class_attention_mask\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.cat(input_ids, dim=0)\n\n    batch = {\n        \"input_ids\": input_ids,\n        \"pixel_values\": pixel_values,\n    }\n\n    if has_attention_mask:\n        batch[\"attention_mask\"] = attention_mask\n\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):\n    if tokenizer_max_length is not None:\n        max_length = tokenizer_max_length\n    else:\n        max_length = tokenizer.model_max_length\n\n    text_inputs = tokenizer(\n        prompt,\n        truncation=True,\n        padding=\"max_length\",\n        max_length=max_length,\n        return_tensors=\"pt\",\n    )\n\n    return text_inputs\n\n\ndef encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):\n    text_input_ids = input_ids.to(text_encoder.device)\n\n    if text_encoder_use_attention_mask:\n        attention_mask = attention_mask.to(text_encoder.device)\n    else:\n        attention_mask = None\n\n    prompt_embeds = text_encoder(\n        text_input_ids,\n        attention_mask=attention_mask,\n        return_dict=False,\n    )\n    prompt_embeds = 
prompt_embeds[0]\n\n    return prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. \"\n            \"Please set gradient_accumulation_steps to 1. 
This feature will be supported in the future.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = torch.float16 if accelerator.device.type in (\"cuda\", \"xpu\") else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                safety_checker=None,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n       
     logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n    # Load scheduler and models\n    
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    try:\n        vae = AutoencoderKL.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n        )\n    except OSError:\n        # IF does not have a VAE so let's just set it to None\n        # We don't have to error out here\n        vae = None\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    if vae is not None:\n        vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    if vae is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n           
         \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    # now we will add new LoRA weights to the attention layers\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\", \"add_k_proj\", \"add_v_proj\"],\n    )\n    unet.add_adapter(unet_lora_config)\n\n    # The text encoder comes from 🤗 transformers, we will also attach adapters to it.\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.rank,\n            lora_dropout=args.lora_dropout,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here. 
Either there are just the unet attn processor layers\n            # or there are the unet and text encoder attn layers\n            unet_lora_layers_to_save = None\n            text_encoder_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                elif isinstance(model, type(unwrap_model(text_encoder))):\n                    text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionLoraLoaderMixin.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n        text_encoder_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(unet))):\n                unet_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder))):\n                text_encoder_ = model\n            else:\n                raise ValueError(f\"unexpected load model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, 
unet_state_dict, adapter_name=\"default\")\n\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        if args.train_text_encoder:\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_)\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [unet_]\n            if args.train_text_encoder:\n                models.append(text_encoder_)\n\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models, dtype=torch.float32)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [unet]\n        if args.train_text_encoder:\n            models.append(text_encoder)\n\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, 
dtype=torch.float32)\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))\n    if args.train_text_encoder:\n        params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))\n\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    if args.pre_compute_text_embeddings:\n\n        def compute_text_embeddings(prompt):\n            with torch.no_grad():\n                text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)\n                prompt_embeds = encode_prompt(\n                    text_encoder,\n                    text_inputs.input_ids,\n                    text_inputs.attention_mask,\n                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                )\n\n            return prompt_embeds\n\n        pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)\n        validation_prompt_negative_prompt_embeds = compute_text_embeddings(\"\")\n\n        if args.validation_prompt is not None:\n            validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)\n        else:\n            validation_prompt_encoder_hidden_states = None\n\n        if args.class_prompt is not None:\n           
 pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)\n        else:\n            pre_computed_class_prompt_encoder_hidden_states = None\n\n        text_encoder = None\n        tokenizer = None\n\n        gc.collect()\n        free_memory()\n    else:\n        pre_computed_encoder_hidden_states = None\n        validation_prompt_encoder_hidden_states = None\n        validation_prompt_negative_prompt_embeds = None\n        pre_computed_class_prompt_encoder_hidden_states = None\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        class_num=args.num_class_images,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n        encoder_hidden_states=pre_computed_encoder_hidden_states,\n        class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,\n        tokenizer_max_length=args.tokenizer_max_length,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / 
args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. 
\"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = vars(copy.deepcopy(args))\n        tracker_config.pop(\"validation_images\")\n        accelerator.init_trackers(\"dreambooth-lora\", config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                pixel_values = 
batch[\"pixel_values\"].to(dtype=weight_dtype)\n\n                if vae is not None:\n                    # Convert images to latent space\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n                    model_input = model_input * vae.config.scaling_factor\n                else:\n                    model_input = pixel_values\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz, channels, height, width = model_input.shape\n                # Sample a random timestep for each image\n                timesteps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                )\n                timesteps = timesteps.long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                if args.pre_compute_text_embeddings:\n                    encoder_hidden_states = batch[\"input_ids\"]\n                else:\n                    encoder_hidden_states = encode_prompt(\n                        text_encoder,\n                        batch[\"input_ids\"],\n                        batch[\"attention_mask\"],\n                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                    )\n\n                if unwrap_model(unet).config.in_channels == channels * 2:\n                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)\n\n                if args.class_labels_conditioning == \"timesteps\":\n                    class_labels = timesteps\n                else:\n                    class_labels = None\n\n                # Predict the noise 
residual\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    encoder_hidden_states,\n                    class_labels=class_labels,\n                    return_dict=False,\n                )[0]\n\n                # if model predicts variance, throw away the prediction. we will only train on the\n                # simplified training objective. This means that all schedulers using the fine tuned\n                # model must be configured to use one of the fixed variance variance types.\n                if model_pred.shape[1] == 6:\n                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute instance loss\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                    # Compute prior loss\n                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n                else:\n                    loss = 
F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                               
     removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = DiffusionPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    unet=unwrap_model(unet),\n                    text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                if args.pre_compute_text_embeddings:\n                    pipeline_args = {\n                        \"prompt_embeds\": validation_prompt_encoder_hidden_states,\n                        \"negative_prompt_embeds\": validation_prompt_negative_prompt_embeds,\n                    }\n                else:\n                    pipeline_args = {\"prompt\": args.validation_prompt}\n\n                images = log_validation(\n                    pipeline,\n                    args,\n                    accelerator,\n                    pipeline_args,\n                    epoch,\n                    torch_dtype=weight_dtype,\n                )\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if 
accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        unet = unet.to(torch.float32)\n\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        if args.train_text_encoder:\n            text_encoder = unwrap_model(text_encoder)\n            text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))\n        else:\n            text_encoder_state_dict = None\n\n        StableDiffusionLoraLoaderMixin.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            text_encoder_lora_layers=text_encoder_state_dict,\n        )\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype\n        )\n\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir, weight_name=\"pytorch_lora_weights.safetensors\")\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt, \"num_inference_steps\": 25}\n            images = log_validation(\n                pipeline,\n                args,\n                accelerator,\n                pipeline_args,\n                epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                prompt=args.instance_prompt,\n                repo_folder=args.output_dir,\n                pipeline=pipeline,\n            )\n            upload_folder(\n    
            repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_flux.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTokenizer, 
PretrainedConfig, T5TokenizerFast\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    FluxPipeline,\n    FluxTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    _set_state_dict_into_text_encoder,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Flux DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).\n\nWas LoRA for the text encoder enabled? 
{train_text_encoder}.\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"flux\",\n        \"flux-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef load_text_encoders(class_one, class_two):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = class_two.from_pretrained(\n        args.pretrained_model_name_or_path, 
subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    return text_encoder_one, text_encoder_two\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    # pre-calculate  prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast\n    with torch.no_grad():\n        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(\n            pipeline_args[\"prompt\"], prompt_2=pipeline_args[\"prompt\"]\n        )\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    
]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. \"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"Guidance scale to use. The FLUX.1 dev variant is a guidance-distilled model.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"Coefficient for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the square root of beta2. Ignored if optimizer is AdamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for transformer params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated list. E.g. - \"to_k,to_q,to_v,to_out.0\" will result in lora training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. 
Defaults to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        
center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = 
transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else 
transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # custom prompts were provided, but length does not match size of image dataset\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": 
pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt, max_sequence_length):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        return_length=False,\n        return_overflowing_tokens=False,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\ndef _encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length=512,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    if hasattr(text_encoder, \"module\"):\n        dtype = text_encoder.module.dtype\n    else:\n        dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    
_, seq_len, _ = prompt_embeds.shape\n\n    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    text_input_ids=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n    if hasattr(text_encoder, \"module\"):\n        dtype = text_encoder.module.dtype\n    else:\n        dtype = text_encoder.dtype\n    # Use pooled output of CLIPTextModel\n    prompt_embeds = prompt_embeds.pooler_output\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n    return prompt_embeds\n\n\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n    text_input_ids_list=None,\n):\n    prompt = [prompt] if 
isinstance(prompt, str) else prompt\n\n    if hasattr(text_encoders[0], \"module\"):\n        dtype = text_encoders[0].module.dtype\n    else:\n        dtype = text_encoders[0].dtype\n\n    pooled_prompt_embeds = _encode_prompt_with_clip(\n        text_encoder=text_encoders[0],\n        tokenizer=tokenizers[0],\n        prompt=prompt,\n        device=device if device is not None else text_encoders[0].device,\n        num_images_per_prompt=num_images_per_prompt,\n        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,\n    )\n\n    prompt_embeds = _encode_prompt_with_t5(\n        text_encoder=text_encoders[1],\n        tokenizer=tokenizers[1],\n        max_sequence_length=max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        device=device if device is not None else text_encoders[1].device,\n        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,\n    )\n\n    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n    return prompt_embeds, pooled_prompt_embeds, text_ids\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n\n            pipeline = FluxPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):\n                    images = pipeline(prompt=example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # 
Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    
text_encoder_two.requires_grad_(False)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            transformer.set_attention_backend(\"_native_npu\")\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=weight_dtype)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\n            \"attn.to_k\",\n            \"attn.to_q\",\n            \"attn.to_v\",\n            \"attn.to_out.0\",\n            \"attn.add_k_proj\",\n            \"attn.add_q_proj\",\n            \"attn.add_v_proj\",\n            \"attn.to_add_out\",\n            \"ff.net.0.proj\",\n            \"ff.net.2\",\n            \"ff_context.net.0.proj\",\n            \"ff_context.net.2\",\n        ]\n\n    # Now add new LoRA weights to the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.lora_alpha,\n            lora_dropout=args.lora_dropout,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder_one.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading 
hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            modules_to_save = {}\n            for model in models:\n                if isinstance(model, type(unwrap_model(transformer))):\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                    modules_to_save[\"transformer\"] = model\n                elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)\n                    modules_to_save[\"text_encoder\"] = model\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            FluxPipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n        text_encoder_one_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(transformer))):\n                transformer_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                text_encoder_one_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict = FluxPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": 
v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n        if args.train_text_encoder:\n            # Do we need to call `scale_lora_layers()` here?\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_])\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one])\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n    if args.train_text_encoder:\n        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    if args.train_text_encoder:\n        # different learning rate for text encoder and unet\n        text_parameters_one_with_lr = {\n            \"params\": text_lora_parameters_one,\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr 
else args.learning_rate,\n        }\n        params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]\n    else:\n        params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. \"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two]\n        text_encoders = 
[text_encoder_one, text_encoder_two]\n\n        def compute_text_embeddings(prompt, text_encoders, tokenizers):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                    text_encoders, tokenizers, prompt, args.max_sequence_length\n                )\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n                text_ids = text_ids.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if not args.train_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers\n            )\n\n    # Clear the memory here\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two\n        free_memory()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n\n    if not train_dataset.custom_instance_prompts:\n        if not args.train_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            pooled_prompt_embeds = instance_pooled_prompt_embeds\n            text_ids = instance_text_ids\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)\n                text_ids = torch.cat([text_ids, class_text_ids], dim=0)\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts)\n        # we need to tokenize and encode the batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)\n            tokens_two = tokenize_prompt(\n                tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length\n            )\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77)\n                class_tokens_two = tokenize_prompt(\n                    tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length\n                )\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)\n\n    vae_config_shift_factor = vae.config.shift_factor\n    vae_config_scaling_factor = vae.config.scaling_factor\n    vae_config_block_out_channels = vae.config.block_out_channels\n    if args.cache_latents:\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = 
batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=weight_dtype\n                )\n                latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n\n        if args.validation_prompt is None:\n            del vae\n            free_memory()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        (\n            transformer,\n            text_encoder_one,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        ) = accelerator.prepare(\n            transformer,\n            text_encoder_one,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        )\n    else:\n        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            transformer, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our 
total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux-dev-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n            # set top parameter requires_grad = True for gradient checkpointing works\n            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            if args.train_text_encoder:\n                models_to_accumulate.extend([text_encoder_one])\n            with accelerator.accumulate(models_to_accumulate):\n                prompts = batch[\"prompts\"]\n\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    if not args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)\n                        tokens_two = tokenize_prompt(\n                            tokenizer_two, prompts, max_sequence_length=args.max_sequence_length\n                        )\n                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                            text_encoders=[text_encoder_one, text_encoder_two],\n                            tokenizers=[None, None],\n                            text_input_ids_list=[tokens_one, tokens_two],\n                            max_sequence_length=args.max_sequence_length,\n                     
       device=accelerator.device,\n                            prompt=prompts,\n                        )\n                else:\n                    elems_to_repeat = len(prompts)\n                    if args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                            text_encoders=[text_encoder_one, text_encoder_two],\n                            tokenizers=[None, None],\n                            text_input_ids_list=[\n                                tokens_one.repeat(elems_to_repeat, 1),\n                                tokens_two.repeat(elems_to_repeat, 1),\n                            ],\n                            max_sequence_length=args.max_sequence_length,\n                            device=accelerator.device,\n                            prompt=args.instance_prompt,\n                        )\n                    else:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].sample()\n                else:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)\n\n                latent_image_ids = FluxPipeline._prepare_latent_image_ids(\n                    model_input.shape[0],\n                    model_input.shape[2] // 2,\n                    model_input.shape[3] // 2,\n                    accelerator.device,\n                    weight_dtype,\n     
           )\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                packed_noisy_model_input = FluxPipeline._pack_latents(\n                    noisy_model_input,\n                    batch_size=model_input.shape[0],\n                    num_channels_latents=model_input.shape[1],\n                    height=model_input.shape[2],\n                    width=model_input.shape[3],\n                )\n\n                # handle guidance\n                if unwrap_model(transformer).config.guidance_embeds:\n                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)\n                    guidance = guidance.expand(model_input.shape[0])\n                else:\n                    guidance = None\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,\n                    # YiYi notes: divide it by 
1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    return_dict=False,\n                )[0]\n                model_pred = FluxPipeline._unpack_latents(\n                    model_pred,\n                    height=model_input.shape[2] * vae_scale_factor,\n                    width=model_input.shape[3] * vae_scale_factor,\n                    vae_scale_factor=vae_scale_factor,\n                )\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - 
target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(transformer.parameters(), text_encoder_one.parameters())\n                        if args.train_text_encoder\n                        else transformer.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n          
                      removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if not args.train_text_encoder:\n                    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n                    text_encoder_one.to(weight_dtype)\n                    text_encoder_two.to(weight_dtype)\n                pipeline = FluxPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=unwrap_model(text_encoder_one),\n                    text_encoder_2=unwrap_model(text_encoder_two),\n                    transformer=unwrap_model(transformer),\n                    
revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n                if not args.train_text_encoder:\n                    del text_encoder_one, text_encoder_two\n                    free_memory()\n\n                images = None\n                del pipeline\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        transformer = unwrap_model(transformer)\n        if args.upcast_before_saving:\n            transformer.to(torch.float32)\n        else:\n            transformer = transformer.to(weight_dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n        modules_to_save[\"transformer\"] = transformer\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))\n            modules_to_save[\"text_encoder\"] = text_encoder_one\n        else:\n            text_encoder_lora_layers = None\n\n        FluxPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = FluxPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            revision=args.revision,\n   
         variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt}\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n        del pipeline\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_flux2.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.sampler import BatchSampler\nfrom torchvision import 
transforms\nfrom torchvision.transforms import functional as TF\nfrom tqdm.auto import tqdm\nfrom transformers import Mistral3ForConditionalGeneration, PixtralProcessor\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKLFlux2,\n    BitsAndBytesConfig,\n    FlowMatchEulerDiscreteScheduler,\n    Flux2Pipeline,\n    Flux2Transformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    _to_cpu_contiguous,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    find_nearest_bucket,\n    free_memory,\n    get_fsdp_kwargs_from_accelerator,\n    offload_models,\n    parse_buckets_string,\n    wrap_with_fsdp,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif getattr(torch, \"distributed\", None) is not None:\n    import torch.distributed as dist\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n    quant_training=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Flux2 DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).\n\nQuant training? 
{quant_training}\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained(\"black-forest-labs/FLUX.2\", torch_dtype=torch.bfloat16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"flux2\",\n        \"flux2-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    args.num_validation_images = args.num_validation_images if args.num_validation_images else 1\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(dtype=torch_dtype)\n    pipeline.enable_model_cpu_offload()\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                prompt_embeds=pipeline_args[\"prompt_embeds\"],\n                generator=generator,\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef module_filter_fn(mod: torch.nn.Module, fqn: str):\n    # don't convert the output module\n    if fqn == \"proj_out\":\n        return False\n    # don't convert linear modules with weight dimensions not divisible by 16\n    if isinstance(mod, torch.nn.Linear):\n        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:\n            return False\n    return True\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    
parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--bnb_quantization_config_path\",\n        type=str,\n        default=None,\n        help=\"Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.\",\n    )\n    parser.add_argument(\n        \"--do_fp8_training\",\n        action=\"store_true\",\n        help=\"if we are doing FP8 training.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with the text encoder\",\n    )\n    parser.add_argument(\n        \"--text_encoder_out_layers\",\n        type=int,\n        nargs=\"+\",\n        default=[10, 20, 30],\n        help=\"Text encoder hidden layers to compute the final text embeddings.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--skip_final_inference\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.\",\n    )\n    parser.add_argument(\n        \"--final_validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--aspect_ratio_buckets\",\n        type=str,\n        default=None,\n        help=(\n            \"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. \"\n            \"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'. \"\n            \"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"the FLUX.1 dev variant is a guidance distilled model\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for transformer params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated list. E.g. - \"to_k,to_q,to_v,to_out.0\" will result in lora training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\n        \"--remote_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to use a remote text encoder. This means the text encoder will not be loaded locally and instead, the prompt embeddings will be computed remotely using the HuggingFace Inference API.\",\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  
Default to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n    parser.add_argument(\"--fsdp_text_encoder\", action=\"store_true\", help=\"Use FSDP for text encoder\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n    if args.do_fp8_training and args.bnb_quantization_config_path:\n        raise ValueError(\"`do_fp8_training` and `bnb_quantization_config_path` cannot both be passed.\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for 
fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n        buckets=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        self.buckets = buckets\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. 
Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        for i, image in enumerate(self.instance_images):\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n\n            width, height = image.size\n\n            # Find the closest bucket\n            bucket_idx = find_nearest_bucket(height, width, self.buckets)\n            target_height, target_width = self.buckets[bucket_idx]\n            self.size = (target_height, target_width)\n\n            # based on the bucket assignment, define the transformations\n            image = self.train_transform(\n                image,\n                size=self.size,\n                center_crop=args.center_crop,\n                random_flip=args.random_flip,\n            )\n            self.pixel_values.append((image, bucket_idx))\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = 
Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n        example[\"bucket_idx\"] = bucket_idx\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so use the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = 
self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n    def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False):\n        # 1. Resize (deterministic)\n        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        image = resize(image)\n\n        # 2. Crop: either center or SAME random crop\n        if center_crop:\n            crop = transforms.CenterCrop(size)\n            image = crop(image)\n        else:\n            # get_params returns (i, j, h, w)\n            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)\n            image = TF.crop(image, i, j, h, w)\n\n        # 3. Random horizontal flip with the SAME coin flip\n        if random_flip:\n            do_flip = random.random() < 0.5\n            if do_flip:\n                image = TF.hflip(image)\n\n        # 4. ToTensor + Normalize (deterministic)\n        to_tensor = transforms.ToTensor()\n        normalize = transforms.Normalize([0.5], [0.5])\n        image = normalize(to_tensor(image))\n\n        return image\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass BucketBatchSampler(BatchSampler):\n    def __init__(self, dataset: DreamBoothDataset, batch_size: int, 
drop_last: bool = False):\n        if not isinstance(batch_size, int) or batch_size <= 0:\n            raise ValueError(\"batch_size should be a positive integer value, but got batch_size={}\".format(batch_size))\n        if not isinstance(drop_last, bool):\n            raise ValueError(\"drop_last should be a boolean value, but got drop_last={}\".format(drop_last))\n\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n\n        # Group indices by bucket\n        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]\n        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):\n            self.bucket_indices[bucket_idx].append(idx)\n\n        self.sampler_len = 0\n        self.batches = []\n\n        # Pre-generate batches for each bucket\n        for indices_in_bucket in self.bucket_indices:\n            # Shuffle indices within the bucket\n            random.shuffle(indices_in_bucket)\n            # Create batches\n            for i in range(0, len(indices_in_bucket), self.batch_size):\n                batch = indices_in_bucket[i : i + self.batch_size]\n                if len(batch) < self.batch_size and self.drop_last:\n                    continue  # Skip partial batch if drop_last is True\n                self.batches.append(batch)\n                self.sampler_len += 1  # Count the number of batches\n\n    def __iter__(self):\n        # Shuffle the order of the batches each epoch\n        random.shuffle(self.batches)\n        for batch in self.batches:\n            yield batch\n\n    def __len__(self):\n        return self.sampler_len\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n       
 example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n    if args.do_fp8_training:\n        from torchao.float8 import Float8LinearConfig, convert_to_float8_training\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, 
main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n\n            pipeline = Flux2Pipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            
sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):\n                    images = pipeline(prompt=example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer = PixtralProcessor.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        
subfolder=\"scheduler\",\n        revision=args.revision,\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae = AutoencoderKLFlux2.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)\n    latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(\n        accelerator.device\n    )\n\n    quantization_config = None\n    if args.bnb_quantization_config_path is not None:\n        with open(args.bnb_quantization_config_path, \"r\") as f:\n            config_kwargs = json.load(f)\n            if \"load_in_4bit\" in config_kwargs and config_kwargs[\"load_in_4bit\"]:\n                config_kwargs[\"bnb_4bit_compute_dtype\"] = weight_dtype\n        quantization_config = BitsAndBytesConfig(**config_kwargs)\n\n    transformer = Flux2Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        quantization_config=quantization_config,\n        torch_dtype=weight_dtype,\n    )\n    if args.bnb_quantization_config_path is not None:\n        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)\n\n    if not args.remote_text_encoder:\n        text_encoder = Mistral3ForConditionalGeneration.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n        )\n        text_encoder.requires_grad_(False)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            
transformer.set_attention_backend(\"_native_npu\")\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    to_kwargs = {\"dtype\": weight_dtype, \"device\": accelerator.device} if not args.offload else {\"dtype\": weight_dtype}\n    # flux vae is stable in bf16 so load it in weight_dtype to reduce memory\n    vae.to(**to_kwargs)\n    # we never offload the transformer to CPU, so we can just use the accelerator device\n    transformer_to_kwargs = (\n        {\"device\": accelerator.device}\n        if args.bnb_quantization_config_path is not None\n        else {\"device\": accelerator.device, \"dtype\": weight_dtype}\n    )\n\n    is_fsdp = getattr(accelerator.state, \"fsdp_plugin\", None) is not None\n    if not is_fsdp:\n        transformer.to(**transformer_to_kwargs)\n\n    if args.do_fp8_training:\n        convert_to_float8_training(\n            transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)\n        )\n\n    if not args.remote_text_encoder:\n        text_encoder.to(**to_kwargs)\n        # Initialize a text encoding pipeline and keep it to CPU for now.\n        text_encoding_pipeline = Flux2Pipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=None,\n            transformer=None,\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            scheduler=None,\n            revision=args.revision,\n        )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = 
[layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n\n    # now we will add new LoRA weights the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        transformer_cls = type(unwrap_model(transformer))\n\n        # 1) Validate and pick the transformer model\n        modules_to_save: dict[str, Any] = {}\n        transformer_model = None\n\n        for model in models:\n            if isinstance(unwrap_model(model), transformer_cls):\n                transformer_model = model\n                modules_to_save[\"transformer\"] = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        if transformer_model is None:\n            raise ValueError(\"No transformer model found in 'models'\")\n\n        # 2) Optionally gather FSDP state dict once\n        state_dict = accelerator.get_state_dict(model) if is_fsdp else None\n\n        # 3) Only main process materializes the LoRA state dict\n        transformer_lora_layers_to_save = None\n        if accelerator.is_main_process:\n            peft_kwargs = {}\n            if is_fsdp:\n                peft_kwargs[\"state_dict\"] = state_dict\n\n            transformer_lora_layers_to_save = get_peft_model_state_dict(\n                unwrap_model(transformer_model) if is_fsdp else 
transformer_model,\n                **peft_kwargs,\n            )\n\n            if is_fsdp:\n                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)\n\n            # make sure to pop weight so that corresponding model is not saved again\n            if weights:\n                weights.pop()\n\n            Flux2Pipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        if not is_fsdp:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    transformer_ = unwrap_model(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n        else:\n            transformer_ = Flux2Transformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path,\n                subfolder=\"transformer\",\n            )\n            transformer_.add_adapter(transformer_lora_config)\n\n        lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from 
state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy].\"\n            \"Defaulting to adamW\"\n        )\n        
args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    if args.aspect_ratio_buckets is not None:\n        buckets = parse_buckets_string(args.aspect_ratio_buckets)\n    else:\n        buckets = [(args.resolution, args.resolution)]\n    logger.info(f\"Using parsed aspect ratio buckets: {buckets}\")\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n        buckets=buckets,\n    )\n    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_sampler=batch_sampler,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        with torch.no_grad():\n            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(\n                prompt=prompt,\n                max_sequence_length=args.max_sequence_length,\n                text_encoder_out_layers=args.text_encoder_out_layers,\n            )\n        
return prompt_embeds, text_ids\n\n    def compute_remote_text_embeddings(prompts):\n        import io\n\n        import requests\n\n        if args.hub_token is not None:\n            hf_token = args.hub_token\n        else:\n            from huggingface_hub import get_token\n\n            hf_token = get_token()\n            if hf_token is None:\n                raise ValueError(\n                    \"No Hugging Face token found. To use the remote text encoder, please log in using `hf auth login` or provide a token using --hub_token\"\n                )\n\n        def _encode_single(prompt: str):\n            response = requests.post(\n                \"https://remote-text-encoder-flux-2.huggingface.co/predict\",\n                json={\"prompt\": prompt},\n                headers={\"Authorization\": f\"Bearer {hf_token}\", \"Content-Type\": \"application/json\"},\n            )\n            assert response.status_code == 200, f\"{response.status_code=}\"\n            return torch.load(io.BytesIO(response.content))\n\n        try:\n            if isinstance(prompts, (list, tuple)):\n                embeds = [_encode_single(p) for p in prompts]\n                prompt_embeds = torch.cat(embeds, dim=0)\n            else:\n                prompt_embeds = _encode_single(prompts)\n\n            text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(accelerator.device)\n            prompt_embeds = prompt_embeds.to(accelerator.device)\n            return prompt_embeds, text_ids\n\n        except Exception as e:\n            raise RuntimeError(\"Remote text encoder inference failed.\") from e\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. 
the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        if args.remote_text_encoder:\n            instance_prompt_hidden_states, instance_text_ids = compute_remote_text_embeddings(args.instance_prompt)\n        else:\n            with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(\n                    args.instance_prompt, text_encoding_pipeline\n                )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if args.remote_text_encoder:\n            class_prompt_hidden_states, class_text_ids = compute_remote_text_embeddings(args.class_prompt)\n        else:\n            with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                class_prompt_hidden_states, class_text_ids = compute_text_embeddings(\n                    args.class_prompt, text_encoding_pipeline\n                )\n    validation_embeddings = {}\n    if args.validation_prompt is not None:\n        if args.remote_text_encoder:\n            (validation_embeddings[\"prompt_embeds\"], validation_embeddings[\"text_ids\"]) = (\n                compute_remote_text_embeddings(args.validation_prompt)\n            )\n        else:\n            with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                (validation_embeddings[\"prompt_embeds\"], validation_embeddings[\"text_ids\"]) = compute_text_embeddings(\n                    args.validation_prompt, text_encoding_pipeline\n                )\n\n    # Init FSDP for text encoder\n    if args.fsdp_text_encoder:\n        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)\n        text_encoder_fsdp = wrap_with_fsdp(\n            
model=text_encoding_pipeline.text_encoder,\n            device=accelerator.device,\n            offload=args.offload,\n            limit_all_gathers=True,\n            use_orig_params=True,\n            fsdp_kwargs=fsdp_kwargs,\n        )\n\n        text_encoding_pipeline.text_encoder = text_encoder_fsdp\n        dist.barrier()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        prompt_embeds = instance_prompt_hidden_states\n        text_ids = instance_text_ids\n        if args.with_prior_preservation:\n            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n            text_ids = torch.cat([text_ids, class_text_ids], dim=0)\n\n    # if cache_latents is set to True, we encode images to latents and store them.\n    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided\n    # we encode them in advance as well.\n    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts\n    if precompute_latents:\n        prompt_embeds_cache = []\n        text_ids_cache = []\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                if args.cache_latents:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                            accelerator.device, non_blocking=True, dtype=vae.dtype\n                        )\n                        latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n                if train_dataset.custom_instance_prompts:\n                    if args.remote_text_encoder:\n           
             prompt_embeds, text_ids = compute_remote_text_embeddings(batch[\"prompts\"])\n                    elif args.fsdp_text_encoder:\n                        prompt_embeds, text_ids = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n                    else:\n                        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                            prompt_embeds, text_ids = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n                    prompt_embeds_cache.append(prompt_embeds)\n                    text_ids_cache.append(text_ids)\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if args.cache_latents:\n        vae = vae.to(\"cpu\")\n        del vae\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if not args.remote_text_encoder:\n        text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n        del text_encoder, tokenizer\n    free_memory()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        
args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. 
\"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux2-lora\"\n        args_cp = vars(args).copy()\n        args_cp[\"text_encoder_out_layers\"] = str(args_cp[\"text_encoder_out_layers\"])\n        accelerator.init_trackers(tracker_name, config=args_cp)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the mos recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            prompts = batch[\"prompts\"]\n\n            with accelerator.accumulate(models_to_accumulate):\n                if train_dataset.custom_instance_prompts:\n                    prompt_embeds = prompt_embeds_cache[step]\n                    text_ids = text_ids_cache[step]\n                else:\n                    num_repeat_elements = len(prompts)\n                    prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)\n                    text_ids = text_ids.repeat(num_repeat_elements, 1, 1)\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].mode()\n                else:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.mode()\n\n                model_input = Flux2Pipeline._patchify_latents(model_input)\n                model_input = (model_input - latents_bn_mean) / latents_bn_std\n\n                model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = 
compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                # [B, C, H, W] -> [B, H*W, C]\n                packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)\n\n                # handle guidance\n                guidance = torch.full([1], args.guidance_scale, device=accelerator.device)\n                guidance = guidance.expand(model_input.shape[0])\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,  # (B, image_seq_len, C)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,  # B, text_seq_len, 4\n                    img_ids=model_input_ids,  # B, image_seq_len, 4\n                    return_dict=False,\n                )[0]\n                model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]\n\n                model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = 
compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or is_fsdp:\n                    
if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                
break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = Flux2Pipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    text_encoder=None,\n                    tokenizer=None,\n                    transformer=unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_embeddings,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n\n                del pipeline\n                free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n\n    if is_fsdp:\n        transformer = unwrap_model(transformer)\n        state_dict = accelerator.get_state_dict(transformer)\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        if is_fsdp:\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    state_dict = {\n                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n                else:\n                    state_dict = {\n                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n\n            transformer_lora_layers = get_peft_model_state_dict(\n                transformer,\n                state_dict=state_dict,\n            )\n            transformer_lora_layers = {\n                k: v.detach().cpu().contiguous() if isinstance(v, 
torch.Tensor) else v\n                for k, v in transformer_lora_layers.items()\n            }\n\n        else:\n            transformer = unwrap_model(transformer)\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    transformer.to(torch.float32)\n                else:\n                    transformer = transformer.to(weight_dtype)\n            transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        modules_to_save[\"transformer\"] = transformer\n\n        Flux2Pipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        images = []\n        run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)\n        should_run_final_inference = not args.skip_final_inference and run_validation\n        if should_run_final_inference:\n            pipeline = Flux2Pipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            images = []\n            if args.validation_prompt and args.num_validation_images > 0:\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_embeddings,\n                    epoch=epoch,\n                    is_final_validation=True,\n                    torch_dtype=weight_dtype,\n                )\n            images = None\n            del pipeline\n            free_memory()\n\n        validation_prompt = 
args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n        quant_training = None\n        if args.do_fp8_training:\n            quant_training = \"FP8 TorchAO\"\n        elif args.bnb_quantization_config_path:\n            quant_training = \"BitsandBytes\"\n        save_model_card(\n            (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=validation_prompt,\n            repo_folder=args.output_dir,\n            quant_training=quant_training,\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_flux2_img2img.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.sampler import BatchSampler\nfrom torchvision import transforms\nfrom torchvision.transforms import functional as TF\nfrom tqdm.auto 
import tqdm\nfrom transformers import Mistral3ForConditionalGeneration, PixtralProcessor\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKLFlux2,\n    BitsAndBytesConfig,\n    FlowMatchEulerDiscreteScheduler,\n    Flux2Pipeline,\n    Flux2Transformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    _to_cpu_contiguous,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    find_nearest_bucket,\n    free_memory,\n    get_fsdp_kwargs_from_accelerator,\n    offload_models,\n    parse_buckets_string,\n    wrap_with_fsdp,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n    load_image,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif getattr(torch, \"distributed\", None) is not None:\n    import torch.distributed as dist\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n    fp8_training=False,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Flux.2 DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).\n\nFP8 training? 
{fp8_training}\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained(\"black-forest-labs/FLUX.2\", torch_dtype=torch.bfloat16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"flux2\",\n        \"flux2-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    args.num_validation_images = args.num_validation_images if args.num_validation_images else 1\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(dtype=torch_dtype)\n    pipeline.enable_model_cpu_offload()\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                image=pipeline_args[\"image\"],\n                prompt_embeds=pipeline_args[\"prompt_embeds\"],\n                generator=generator,\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef module_filter_fn(mod: torch.nn.Module, fqn: str):\n    # don't convert the output module\n    if fqn == \"proj_out\":\n        return False\n    # don't convert linear modules with weight dimensions not divisible by 16\n    if isinstance(mod, torch.nn.Linear):\n        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:\n            return False\n    return True\n\n\ndef parse_args(input_args=None):\n    parser = 
argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--bnb_quantization_config_path\",\n        type=str,\n        default=None,\n        help=\"Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.\",\n    )\n    parser.add_argument(\n        \"--do_fp8_training\",\n        action=\"store_true\",\n        help=\"if we are doing FP8 training.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--cond_image_column\",\n        type=str,\n        default=None,\n        help=\"Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        help=\"Path to an image that is used during validation as the condition image to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--skip_final_inference\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.\",\n    )\n    parser.add_argument(\n        \"--final_validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--aspect_ratio_buckets\",\n        type=str,\n        default=None,\n        help=(\n            \"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. \"\n            \"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'\"\n            \"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"the FLUX.1 dev variant is a guidance distilled model\",\n    )\n\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - \"to_k,to_q,to_v,to_out.0\" will result in lora training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\n        \"--remote_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to use a remote text encoder. 
This means the text encoder will not be loaded locally and instead, the prompt embeddings will be computed remotely using the HuggingFace Inference API.\",\n    )\n\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n    parser.add_argument(\"--fsdp_text_encoder\", action=\"store_true\", help=\"Use FSDP for text encoder\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.cond_image_column is None:\n        raise ValueError(\n            \"you must provide --cond_image_column for image-to-image training. Otherwise please see Flux2 text-to-image training example.\"\n        )\n    else:\n        assert args.image_column is not None\n        assert args.caption_column is not None\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n        buckets=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n       
 self.custom_instance_prompts = None\n\n        self.buckets = buckets\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.cond_image_column is not None and args.cond_image_column not in column_names:\n                raise ValueError(\n                    f\"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                )\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n            cond_images = None\n            cond_image_column = args.cond_image_column\n            if cond_image_column is not None:\n                cond_images = [dataset[\"train\"][i][cond_image_column] for i in range(len(dataset[\"train\"]))]\n                assert len(instance_images) == len(cond_images)\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exists.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        self.cond_images = []\n        for i, img in enumerate(instance_images):\n            self.instance_images.extend(itertools.repeat(img, repeats))\n            if args.dataset_name is not None and cond_images is not None:\n                self.cond_images.extend(itertools.repeat(cond_images[i], repeats))\n\n        self.pixel_values = []\n        self.cond_pixel_values = []\n        for i, image in enumerate(self.instance_images):\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            dest_image = None\n            if self.cond_images:  # todo: take care of max area for buckets\n                dest_image = self.cond_images[i]\n                image_width, image_height = dest_image.size\n                if image_width * image_height > 1024 * 1024:\n                    dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)\n                    image_width, image_height = dest_image.size\n\n                multiple_of = 2 ** (4 - 1)  # 2 ** (len(vae.config.block_out_channels) - 1), temp!\n                image_width = (image_width 
// multiple_of) * multiple_of\n                image_height = (image_height // multiple_of) * multiple_of\n                image_processor = Flux2ImageProcessor()\n                dest_image = image_processor.preprocess(\n                    dest_image, height=image_height, width=image_width, resize_mode=\"crop\"\n                )\n                # Convert back to PIL\n                dest_image = dest_image.squeeze(0)\n                if dest_image.min() < 0:\n                    dest_image = (dest_image + 1) / 2\n                dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()\n\n                if dest_image.shape[0] == 1:\n                    # Gray scale image\n                    dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode=\"L\")\n                else:\n                    # RGB scale image: (C, H, W) -> (H, W, C)\n                    dest_image = TF.to_pil_image(dest_image)\n\n                dest_image = exif_transpose(dest_image)\n                if not dest_image.mode == \"RGB\":\n                    dest_image = dest_image.convert(\"RGB\")\n\n            width, height = image.size\n\n            # Find the closest bucket\n            bucket_idx = find_nearest_bucket(height, width, self.buckets)\n            target_height, target_width = self.buckets[bucket_idx]\n            self.size = (target_height, target_width)\n\n            # based on the bucket assignment, define the transformations\n            image, dest_image = self.paired_transform(\n                image,\n                dest_image=dest_image,\n                size=self.size,\n                center_crop=args.center_crop,\n                random_flip=args.random_flip,\n            )\n            self.pixel_values.append((image, bucket_idx))\n            if dest_image is not None:\n                self.cond_pixel_values.append((dest_image, bucket_idx))\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = 
self.num_instance_images\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n        example[\"bucket_idx\"] = bucket_idx\n        if self.cond_pixel_values:\n            dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]\n            example[\"cond_images\"] = dest_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # custom prompts were provided, but length does not match size of image dataset\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        return example\n\n    def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):\n        # 1. Resize (deterministic)\n        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        image = resize(image)\n        if dest_image is not None:\n            dest_image = resize(dest_image)\n\n        # 2. 
Crop: either center or SAME random crop\n        if center_crop:\n            crop = transforms.CenterCrop(size)\n            image = crop(image)\n            if dest_image is not None:\n                dest_image = crop(dest_image)\n        else:\n            # get_params returns (i, j, h, w)\n            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)\n            image = TF.crop(image, i, j, h, w)\n            if dest_image is not None:\n                dest_image = TF.crop(dest_image, i, j, h, w)\n\n        # 3. Random horizontal flip with the SAME coin flip\n        if random_flip:\n            do_flip = random.random() < 0.5\n            if do_flip:\n                image = TF.hflip(image)\n                if dest_image is not None:\n                    dest_image = TF.hflip(dest_image)\n\n        # 4. ToTensor + Normalize (deterministic)\n        to_tensor = transforms.ToTensor()\n        normalize = transforms.Normalize([0.5], [0.5])\n        image = normalize(to_tensor(image))\n        if dest_image is not None:\n            dest_image = normalize(to_tensor(dest_image))\n\n        return (image, dest_image) if dest_image is not None else (image, None)\n\n\ndef collate_fn(examples):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    if any(\"cond_images\" in example for example in examples):\n        cond_pixel_values = [example[\"cond_images\"] for example in examples]\n        cond_pixel_values = torch.stack(cond_pixel_values)\n        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()\n        batch.update({\"cond_pixel_values\": cond_pixel_values})\n    return batch\n\n\nclass 
BucketBatchSampler(BatchSampler):\n    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):\n        if not isinstance(batch_size, int) or batch_size <= 0:\n            raise ValueError(\"batch_size should be a positive integer value, but got batch_size={}\".format(batch_size))\n        if not isinstance(drop_last, bool):\n            raise ValueError(\"drop_last should be a boolean value, but got drop_last={}\".format(drop_last))\n\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n\n        # Group indices by bucket\n        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]\n        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):\n            self.bucket_indices[bucket_idx].append(idx)\n\n        self.sampler_len = 0\n        self.batches = []\n\n        # Pre-generate batches for each bucket\n        for indices_in_bucket in self.bucket_indices:\n            # Shuffle indices within the bucket\n            random.shuffle(indices_in_bucket)\n            # Create batches\n            for i in range(0, len(indices_in_bucket), self.batch_size):\n                batch = indices_in_bucket[i : i + self.batch_size]\n                if len(batch) < self.batch_size and self.drop_last:\n                    continue  # Skip partial batch if drop_last is True\n                self.batches.append(batch)\n                self.sampler_len += 1  # Count the number of batches\n\n    def __iter__(self):\n        # Shuffle the order of the batches each epoch\n        random.shuffle(self.batches)\n        for batch in self.batches:\n            yield batch\n\n    def __len__(self):\n        return self.sampler_len\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = 
num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n    if args.do_fp8_training:\n        from torchao.float8 import Float8LinearConfig, convert_to_float8_training\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        
datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer = PixtralProcessor.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"scheduler\",\n        revision=args.revision,\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae = AutoencoderKLFlux2.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        
variant=args.variant,\n    )\n    latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)\n    latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(\n        accelerator.device\n    )\n\n    quantization_config = None\n    if args.bnb_quantization_config_path is not None:\n        with open(args.bnb_quantization_config_path, \"r\") as f:\n            config_kwargs = json.load(f)\n            if \"load_in_4bit\" in config_kwargs and config_kwargs[\"load_in_4bit\"]:\n                config_kwargs[\"bnb_4bit_compute_dtype\"] = weight_dtype\n        quantization_config = BitsAndBytesConfig(**config_kwargs)\n\n    transformer = Flux2Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        quantization_config=quantization_config,\n        torch_dtype=weight_dtype,\n    )\n    if args.bnb_quantization_config_path is not None:\n        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)\n\n    if not args.remote_text_encoder:\n        text_encoder = Mistral3ForConditionalGeneration.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n        )\n        text_encoder.requires_grad_(False)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            transformer.set_attention_backend(\"_native_npu\")\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to 
pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    to_kwargs = {\"dtype\": weight_dtype, \"device\": accelerator.device} if not args.offload else {\"dtype\": weight_dtype}\n    # flux vae is stable in bf16 so load it in weight_dtype to reduce memory\n    vae.to(**to_kwargs)\n    # we never offload the transformer to CPU, so we can just use the accelerator device\n    transformer_to_kwargs = (\n        {\"device\": accelerator.device}\n        if args.bnb_quantization_config_path is not None\n        else {\"device\": accelerator.device, \"dtype\": weight_dtype}\n    )\n\n    is_fsdp = getattr(accelerator.state, \"fsdp_plugin\", None) is not None\n    if not is_fsdp:\n        transformer.to(**transformer_to_kwargs)\n\n    if args.do_fp8_training:\n        convert_to_float8_training(\n            transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)\n        )\n\n    if not args.remote_text_encoder:\n        text_encoder.to(**to_kwargs)\n        # Initialize a text encoding pipeline and keep it on CPU for now.\n        text_encoding_pipeline = Flux2Pipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=None,\n            transformer=None,\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            scheduler=None,\n            revision=args.revision,\n        )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n\n    # now we will add new LoRA weights to the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        
lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        transformer_cls = type(unwrap_model(transformer))\n\n        # 1) Validate and pick the transformer model\n        modules_to_save: dict[str, Any] = {}\n        transformer_model = None\n\n        for model in models:\n            if isinstance(unwrap_model(model), transformer_cls):\n                transformer_model = model\n                modules_to_save[\"transformer\"] = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        if transformer_model is None:\n            raise ValueError(\"No transformer model found in 'models'\")\n\n        # 2) Optionally gather FSDP state dict once\n        state_dict = accelerator.get_state_dict(model) if is_fsdp else None\n\n        # 3) Only main process materializes the LoRA state dict\n        transformer_lora_layers_to_save = None\n        if accelerator.is_main_process:\n            peft_kwargs = {}\n            if is_fsdp:\n                peft_kwargs[\"state_dict\"] = state_dict\n\n            transformer_lora_layers_to_save = get_peft_model_state_dict(\n                unwrap_model(transformer_model) if is_fsdp else transformer_model,\n                **peft_kwargs,\n            )\n\n            if is_fsdp:\n                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)\n\n            # make sure to pop weight so that corresponding model is not saved again\n          
  if weights:\n                weights.pop()\n\n            Flux2Pipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        if not is_fsdp:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    transformer_ = unwrap_model(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n        else:\n            transformer_ = Flux2Transformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path,\n                subfolder=\"transformer\",\n            )\n            transformer_.add_adapter(transformer_lora_config)\n\n        lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    if args.aspect_ratio_buckets is not None:\n        buckets = parse_buckets_string(args.aspect_ratio_buckets)\n    else:\n        buckets = [(args.resolution, args.resolution)]\n    logger.info(f\"Using parsed aspect ratio buckets: {buckets}\")\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n        buckets=buckets,\n    )\n    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_sampler=batch_sampler,\n        collate_fn=lambda examples: collate_fn(examples),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        with torch.no_grad():\n            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(\n                prompt=prompt, max_sequence_length=args.max_sequence_length\n            )\n            # prompt_embeds = prompt_embeds.to(accelerator.device)\n            # text_ids = text_ids.to(accelerator.device)\n        return prompt_embeds, text_ids\n\n    def compute_remote_text_embeddings(prompts: str | list[str]):\n        import io\n\n        import requests\n\n        if 
args.hub_token is not None:\n            hf_token = args.hub_token\n        else:\n            from huggingface_hub import get_token\n\n            hf_token = get_token()\n            if hf_token is None:\n                raise ValueError(\n                    \"No HuggingFace token found. To use the remote text encoder please login using `hf auth login` or provide a token using --hub_token\"\n                )\n\n        def _encode_single(prompt: str):\n            response = requests.post(\n                \"https://remote-text-encoder-flux-2.huggingface.co/predict\",\n                json={\"prompt\": prompt},\n                headers={\"Authorization\": f\"Bearer {hf_token}\", \"Content-Type\": \"application/json\"},\n            )\n            assert response.status_code == 200, f\"{response.status_code=}\"\n            return torch.load(io.BytesIO(response.content))\n\n        try:\n            if isinstance(prompts, (list, tuple)):\n                embeds = [_encode_single(p) for p in prompts]\n                prompt_embeds = torch.cat(embeds, dim=0).to(accelerator.device)\n            else:\n                prompt_embeds = _encode_single(prompts).to(accelerator.device)\n\n            text_ids = Flux2Pipeline._prepare_text_ids(prompt_embeds).to(accelerator.device)\n            return prompt_embeds, text_ids\n\n        except Exception as e:\n            raise RuntimeError(\"Remote text encoder inference failed.\") from e\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. 
the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        if args.remote_text_encoder:\n            instance_prompt_hidden_states, instance_text_ids = compute_remote_text_embeddings(args.instance_prompt)\n        else:\n            with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(\n                    args.instance_prompt, text_encoding_pipeline\n                )\n\n    if args.validation_prompt is not None:\n        validation_image = load_image(args.validation_image_path).convert(\"RGB\")\n        validation_kwargs = {\"image\": validation_image}\n        if args.remote_text_encoder:\n            validation_kwargs[\"prompt_embeds\"] = compute_remote_text_embeddings(args.validation_prompt)\n        else:\n            with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                validation_kwargs[\"prompt_embeds\"] = compute_text_embeddings(\n                    args.validation_prompt, text_encoding_pipeline\n                )\n\n    # Init FSDP for text encoder\n    if args.fsdp_text_encoder:\n        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)\n        text_encoder_fsdp = wrap_with_fsdp(\n            model=text_encoding_pipeline.text_encoder,\n            device=accelerator.device,\n            offload=args.offload,\n            limit_all_gathers=True,\n            use_orig_params=True,\n            fsdp_kwargs=fsdp_kwargs,\n        )\n\n        text_encoding_pipeline.text_encoder = text_encoder_fsdp\n        dist.barrier()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        prompt_embeds = instance_prompt_hidden_states\n        text_ids = instance_text_ids\n\n    # if cache_latents is set to True, we encode images to latents and store them.\n    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided\n    # we encode them in advance as well.\n    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts\n    if precompute_latents:\n        prompt_embeds_cache = []\n        text_ids_cache = []\n        latents_cache = []\n        cond_latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                if args.cache_latents:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                            accelerator.device, non_blocking=True, dtype=vae.dtype\n                        )\n                        latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n                        batch[\"cond_pixel_values\"] = batch[\"cond_pixel_values\"].to(\n                            accelerator.device, non_blocking=True, dtype=vae.dtype\n                        )\n                        cond_latents_cache.append(vae.encode(batch[\"cond_pixel_values\"]).latent_dist)\n                if train_dataset.custom_instance_prompts:\n                    if args.remote_text_encoder:\n                        prompt_embeds, text_ids = compute_remote_text_embeddings(batch[\"prompts\"])\n                    elif args.fsdp_text_encoder:\n                        prompt_embeds, text_ids = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n                    else:\n                        with offload_models(text_encoding_pipeline, 
device=accelerator.device, offload=args.offload):\n                            prompt_embeds, text_ids = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n                    prompt_embeds_cache.append(prompt_embeds)\n                    text_ids_cache.append(text_ids)\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if args.cache_latents:\n        vae = vae.to(\"cpu\")\n        del vae\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if not args.remote_text_encoder:\n        text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n        del text_encoder, tokenizer\n    free_memory()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = 
accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux2-image2img-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            prompts = batch[\"prompts\"]\n\n            with accelerator.accumulate(models_to_accumulate):\n                if train_dataset.custom_instance_prompts:\n                    prompt_embeds = prompt_embeds_cache[step]\n                    text_ids = text_ids_cache[step]\n                else:\n                    num_repeat_elements = len(prompts)\n                    prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)\n                    text_ids = text_ids.repeat(num_repeat_elements, 1, 1)\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].mode()\n                    cond_model_input = cond_latents_cache[step].mode()\n                else:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                        cond_pixel_values = batch[\"cond_pixel_values\"].to(dtype=vae.dtype)\n\n                    model_input = vae.encode(pixel_values).latent_dist.mode()\n                    cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()\n\n                    # model_input = Flux2Pipeline._encode_vae_image(pixel_values)\n\n                model_input = Flux2Pipeline._patchify_latents(model_input)\n                model_input = (model_input - latents_bn_mean) / latents_bn_std\n\n                cond_model_input = Flux2Pipeline._patchify_latents(cond_model_input)\n               
 cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std\n\n                model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)\n                cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]\n                cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(\n                    device=cond_model_input.device\n                )\n                cond_model_input_ids = cond_model_input_ids.view(\n                    cond_model_input.shape[0], -1, model_input_ids.shape[-1]\n                )\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                # [B, C, H, W] -> [B, H*W, C]\n                packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)\n                packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)\n\n                
orig_input_shape = packed_noisy_model_input.shape\n                orig_input_ids_shape = model_input_ids.shape\n\n                # concatenate the model inputs with the cond inputs\n                packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)\n                model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)\n\n                # handle guidance\n                guidance = torch.full([1], args.guidance_scale, device=accelerator.device)\n                guidance = guidance.expand(model_input.shape[0])\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,  # (B, image_seq_len, C)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,  # B, text_seq_len, 4\n                    img_ids=model_input_ids,  # B, image_seq_len, 4\n                    return_dict=False,\n                )[0]\n                model_pred = model_pred[:, : orig_input_shape[1], :]\n                model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]\n\n                model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                
accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or is_fsdp:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = Flux2Pipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    text_encoder=None,\n                    tokenizer=None,\n                    transformer=unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_kwargs,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n\n                del pipeline\n                free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n\n    if is_fsdp:\n        transformer = unwrap_model(transformer)\n        state_dict = accelerator.get_state_dict(transformer)\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        if is_fsdp:\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n         
           state_dict = {\n                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n                else:\n                    state_dict = {\n                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n\n            transformer_lora_layers = get_peft_model_state_dict(\n                transformer,\n                state_dict=state_dict,\n            )\n            transformer_lora_layers = {\n                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v\n                for k, v in transformer_lora_layers.items()\n            }\n\n        else:\n            transformer = unwrap_model(transformer)\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    transformer.to(torch.float32)\n                else:\n                    transformer = transformer.to(weight_dtype)\n            transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        modules_to_save[\"transformer\"] = transformer\n\n        Flux2Pipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        images = []\n        run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)\n        should_run_final_inference = not args.skip_final_inference and run_validation\n        if should_run_final_inference:\n            pipeline = Flux2Pipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n            # load attention processors\n            
pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            images = []\n            if args.validation_prompt and args.num_validation_images > 0:\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_kwargs,\n                    epoch=epoch,\n                    is_final_validation=True,\n                    torch_dtype=weight_dtype,\n                )\n            del pipeline\n            free_memory()\n\n        validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n        save_model_card(\n            (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=validation_prompt,\n            repo_folder=args.output_dir,\n            fp8_training=args.do_fp8_training,\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_flux2_klein.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.sampler import BatchSampler\nfrom torchvision import 
transforms\nfrom torchvision.transforms import functional as TF\nfrom tqdm.auto import tqdm\nfrom transformers import Qwen2TokenizerFast, Qwen3ForCausalLM\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKLFlux2,\n    BitsAndBytesConfig,\n    FlowMatchEulerDiscreteScheduler,\n    Flux2KleinPipeline,\n    Flux2Transformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    _to_cpu_contiguous,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    find_nearest_bucket,\n    free_memory,\n    get_fsdp_kwargs_from_accelerator,\n    offload_models,\n    parse_buckets_string,\n    wrap_with_fsdp,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif getattr(torch, \"distributed\", None) is not None:\n    import torch.distributed as dist\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n    quant_training=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Flux.2 [Klein] DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).\n\nQuant training? 
{quant_training}\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained(\"black-forest-labs/FLUX.2\", torch_dtype=torch.bfloat16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"flux2-klein\",\n        \"flux2-klein-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    args.num_validation_images = args.num_validation_images if args.num_validation_images else 1\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(dtype=torch_dtype)\n    pipeline.enable_model_cpu_offload()\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                prompt_embeds=pipeline_args[\"prompt_embeds\"],\n                generator=generator,\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef module_filter_fn(mod: torch.nn.Module, fqn: str):\n    # don't convert the output module\n    if fqn == \"proj_out\":\n        return False\n    # don't convert linear modules with weight dimensions not divisible by 16\n    if isinstance(mod, torch.nn.Linear):\n        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:\n            return False\n    return True\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    
parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--bnb_quantization_config_path\",\n        type=str,\n        default=None,\n        help=\"Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.\",\n    )\n    parser.add_argument(\n        \"--do_fp8_training\",\n        action=\"store_true\",\n        help=\"if we are doing FP8 training.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with the text encoder\",\n    )\n    parser.add_argument(\n        \"--text_encoder_out_layers\",\n        type=int,\n        nargs=\"+\",\n        default=[10, 20, 30],\n        help=\"Text encoder hidden layers to compute the final text embeddings.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--skip_final_inference\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.\",\n    )\n    parser.add_argument(\n        \"--final_validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--aspect_ratio_buckets\",\n        type=str,\n        default=None,\n        help=(\n            \"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. \"\n            \"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'. \"\n            \"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"the FLUX.1 dev variant is a guidance distilled model\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated list. E.g. \"to_k,to_q,to_v,to_out.0\" will result in LoRA training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  
Default to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n    parser.add_argument(\"--fsdp_text_encoder\", action=\"store_true\", help=\"Use FSDP for text encoder\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n    if args.do_fp8_training and args.bnb_quantization_config_path:\n        raise ValueError(\"Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for 
fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n        buckets=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        self.buckets = buckets\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. 
Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        for i, image in enumerate(self.instance_images):\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n\n            width, height = image.size\n\n            # Find the closest bucket\n            bucket_idx = find_nearest_bucket(height, width, self.buckets)\n            target_height, target_width = self.buckets[bucket_idx]\n            self.size = (target_height, target_width)\n\n            # based on the bucket assignment, define the transformations\n            image = self.train_transform(\n                image,\n                size=self.size,\n                center_crop=args.center_crop,\n                random_flip=args.random_flip,\n            )\n            self.pixel_values.append((image, bucket_idx))\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = 
Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n        example[\"bucket_idx\"] = bucket_idx\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so use the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = 
self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n    def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False):\n        # 1. Resize (deterministic)\n        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        image = resize(image)\n\n        # 2. Crop: either center or SAME random crop\n        if center_crop:\n            crop = transforms.CenterCrop(size)\n            image = crop(image)\n        else:\n            # get_params returns (i, j, h, w)\n            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)\n            image = TF.crop(image, i, j, h, w)\n\n        # 3. Random horizontal flip with the SAME coin flip\n        if random_flip:\n            do_flip = random.random() < 0.5\n            if do_flip:\n                image = TF.hflip(image)\n\n        # 4. ToTensor + Normalize (deterministic)\n        to_tensor = transforms.ToTensor()\n        normalize = transforms.Normalize([0.5], [0.5])\n        image = normalize(to_tensor(image))\n\n        return image\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass BucketBatchSampler(BatchSampler):\n    def __init__(self, dataset: DreamBoothDataset, batch_size: int, 
drop_last: bool = False):\n        if not isinstance(batch_size, int) or batch_size <= 0:\n            raise ValueError(\"batch_size should be a positive integer value, but got batch_size={}\".format(batch_size))\n        if not isinstance(drop_last, bool):\n            raise ValueError(\"drop_last should be a boolean value, but got drop_last={}\".format(drop_last))\n\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n\n        # Group indices by bucket\n        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]\n        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):\n            self.bucket_indices[bucket_idx].append(idx)\n\n        self.sampler_len = 0\n        self.batches = []\n\n        # Pre-generate batches for each bucket\n        for indices_in_bucket in self.bucket_indices:\n            # Shuffle indices within the bucket\n            random.shuffle(indices_in_bucket)\n            # Create batches\n            for i in range(0, len(indices_in_bucket), self.batch_size):\n                batch = indices_in_bucket[i : i + self.batch_size]\n                if len(batch) < self.batch_size and self.drop_last:\n                    continue  # Skip partial batch if drop_last is True\n                self.batches.append(batch)\n                self.sampler_len += 1  # Count the number of batches\n\n    def __iter__(self):\n        # Shuffle the order of the batches each epoch\n        random.shuffle(self.batches)\n        for batch in self.batches:\n            yield batch\n\n    def __len__(self):\n        return self.sampler_len\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n       
 example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n    if args.do_fp8_training:\n        from torchao.float8 import Float8LinearConfig, convert_to_float8_training\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, 
main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n\n            pipeline = Flux2KleinPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            
sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):\n                    images = pipeline(prompt=example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer = Qwen2TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        
subfolder=\"scheduler\",\n        revision=args.revision,\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae = AutoencoderKLFlux2.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)\n    latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(\n        accelerator.device\n    )\n\n    quantization_config = None\n    if args.bnb_quantization_config_path is not None:\n        with open(args.bnb_quantization_config_path, \"r\") as f:\n            config_kwargs = json.load(f)\n            if \"load_in_4bit\" in config_kwargs and config_kwargs[\"load_in_4bit\"]:\n                config_kwargs[\"bnb_4bit_compute_dtype\"] = weight_dtype\n        quantization_config = BitsAndBytesConfig(**config_kwargs)\n\n    transformer = Flux2Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        quantization_config=quantization_config,\n        torch_dtype=weight_dtype,\n    )\n    if args.bnb_quantization_config_path is not None:\n        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)\n\n    text_encoder = Qwen3ForCausalLM.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder.requires_grad_(False)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            transformer.set_attention_backend(\"_native_npu\")\n        else:\n            
raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    to_kwargs = {\"dtype\": weight_dtype, \"device\": accelerator.device} if not args.offload else {\"dtype\": weight_dtype}\n    # flux vae is stable in bf16 so load it in weight_dtype to reduce memory\n    vae.to(**to_kwargs)\n    # we never offload the transformer to CPU, so we can just use the accelerator device\n    transformer_to_kwargs = (\n        {\"device\": accelerator.device}\n        if args.bnb_quantization_config_path is not None\n        else {\"device\": accelerator.device, \"dtype\": weight_dtype}\n    )\n\n    is_fsdp = getattr(accelerator.state, \"fsdp_plugin\", None) is not None\n    if not is_fsdp:\n        transformer.to(**transformer_to_kwargs)\n\n    if args.do_fp8_training:\n        convert_to_float8_training(\n            transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)\n        )\n\n    text_encoder.to(**to_kwargs)\n    # Initialize a text encoding pipeline and keep it to CPU for now.\n    text_encoding_pipeline = Flux2KleinPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=None,\n        transformer=None,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        scheduler=None,\n        revision=args.revision,\n    )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n\n    # now we 
will add new LoRA weights to the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        transformer_cls = type(unwrap_model(transformer))\n\n        # 1) Validate and pick the transformer model\n        modules_to_save: dict[str, Any] = {}\n        transformer_model = None\n\n        for model in models:\n            if isinstance(unwrap_model(model), transformer_cls):\n                transformer_model = model\n                modules_to_save[\"transformer\"] = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        if transformer_model is None:\n            raise ValueError(\"No transformer model found in 'models'\")\n\n        # 2) Optionally gather FSDP state dict once\n        state_dict = accelerator.get_state_dict(transformer_model) if is_fsdp else None\n\n        # 3) Only main process materializes the LoRA state dict\n        transformer_lora_layers_to_save = None\n        if accelerator.is_main_process:\n            peft_kwargs = {}\n            if is_fsdp:\n                peft_kwargs[\"state_dict\"] = state_dict\n\n            transformer_lora_layers_to_save = get_peft_model_state_dict(\n                unwrap_model(transformer_model) if is_fsdp else transformer_model,\n                **peft_kwargs,\n            )\n\n            if is_fsdp:\n                transformer_lora_layers_to_save = 
_to_cpu_contiguous(transformer_lora_layers_to_save)\n\n            # make sure to pop weight so that corresponding model is not saved again\n            if weights:\n                weights.pop()\n\n            Flux2KleinPipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        if not is_fsdp:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    transformer_ = unwrap_model(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n        else:\n            transformer_ = Flux2Transformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path,\n                subfolder=\"transformer\",\n            )\n            transformer_.add_adapter(transformer_lora_config)\n\n        lora_state_dict = Flux2KleinPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. 
\"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy].\"\n            \"Defaulting to adamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n   
     logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    if args.aspect_ratio_buckets is not None:\n        buckets = parse_buckets_string(args.aspect_ratio_buckets)\n    else:\n        buckets = [(args.resolution, args.resolution)]\n    logger.info(f\"Using parsed aspect ratio buckets: {buckets}\")\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n        buckets=buckets,\n    )\n    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_sampler=batch_sampler,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        with torch.no_grad():\n            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(\n                prompt=prompt,\n                max_sequence_length=args.max_sequence_length,\n                text_encoder_out_layers=args.text_encoder_out_layers,\n            )\n        
return prompt_embeds, text_ids\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(\n                args.instance_prompt, text_encoding_pipeline\n            )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            class_prompt_hidden_states, class_text_ids = compute_text_embeddings(\n                args.class_prompt, text_encoding_pipeline\n            )\n    validation_embeddings = {}\n    if args.validation_prompt is not None:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            (validation_embeddings[\"prompt_embeds\"], validation_embeddings[\"text_ids\"]) = compute_text_embeddings(\n                args.validation_prompt, text_encoding_pipeline\n            )\n\n    # Init FSDP for text encoder\n    if args.fsdp_text_encoder:\n        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)\n        text_encoder_fsdp = wrap_with_fsdp(\n            model=text_encoding_pipeline.text_encoder,\n            device=accelerator.device,\n            offload=args.offload,\n            limit_all_gathers=True,\n            use_orig_params=True,\n            fsdp_kwargs=fsdp_kwargs,\n        )\n\n        text_encoding_pipeline.text_encoder = text_encoder_fsdp\n        dist.barrier()\n\n    # If custom instance prompts are NOT provided (i.e. 
the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        prompt_embeds = instance_prompt_hidden_states\n        text_ids = instance_text_ids\n        if args.with_prior_preservation:\n            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n            text_ids = torch.cat([text_ids, class_text_ids], dim=0)\n\n    # if cache_latents is set to True, we encode images to latents and store them.\n    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided\n    # we encode them in advance as well.\n    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts\n    if precompute_latents:\n        prompt_embeds_cache = []\n        text_ids_cache = []\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                if args.cache_latents:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                            accelerator.device, non_blocking=True, dtype=vae.dtype\n                        )\n                        latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n                if train_dataset.custom_instance_prompts:\n                    if args.fsdp_text_encoder:\n                        prompt_embeds, text_ids = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n                    else:\n                        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                            prompt_embeds, text_ids = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n              
      prompt_embeds_cache.append(prompt_embeds)\n                    text_ids_cache.append(text_ids)\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if args.cache_latents:\n        vae = vae.to(\"cpu\")\n        del vae\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n    del text_encoder, tokenizer\n    free_memory()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = 
math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux2-klein-lora\"\n        args_cp = vars(args).copy()\n        args_cp[\"text_encoder_out_layers\"] = str(args_cp[\"text_encoder_out_layers\"])\n        accelerator.init_trackers(tracker_name, config=args_cp)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            prompts = batch[\"prompts\"]\n\n            with accelerator.accumulate(models_to_accumulate):\n                if train_dataset.custom_instance_prompts:\n                    prompt_embeds = prompt_embeds_cache[step]\n                    text_ids = text_ids_cache[step]\n                else:\n                    num_repeat_elements = len(prompts)\n                    prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)\n                    text_ids = text_ids.repeat(num_repeat_elements, 1, 1)\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].mode()\n                else:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.mode()\n\n                model_input = Flux2KleinPipeline._patchify_latents(model_input)\n                model_input = (model_input - latents_bn_mean) / latents_bn_std\n\n                model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device)\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u 
= compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                # [B, C, H, W] -> [B, H*W, C]\n                packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)\n\n                # handle guidance\n                if unwrap_model(transformer).config.guidance_embeds:\n                    guidance = torch.full([1], args.guidance_scale, device=accelerator.device)\n                    guidance = guidance.expand(model_input.shape[0])\n                else:\n                    guidance = None\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,  # (B, image_seq_len, C)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,  # B, text_seq_len, 4\n                    img_ids=model_input_ids,  # B, image_seq_len, 4\n                    return_dict=False,\n                )[0]\n                model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]\n\n                model_pred = Flux2KleinPipeline._unpack_latents_with_ids(model_pred, model_input_ids)\n\n                # these weighting 
schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                
progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or is_fsdp:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            
progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = Flux2KleinPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    transformer=unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_embeddings,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n\n                del pipeline\n                free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n\n    if is_fsdp:\n        transformer = unwrap_model(transformer)\n        state_dict = accelerator.get_state_dict(transformer)\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        if is_fsdp:\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    state_dict = {\n                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n                else:\n                    state_dict = {\n                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n\n            transformer_lora_layers = get_peft_model_state_dict(\n                transformer,\n                state_dict=state_dict,\n            )\n            
transformer_lora_layers = {\n                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v\n                for k, v in transformer_lora_layers.items()\n            }\n\n        else:\n            transformer = unwrap_model(transformer)\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    transformer.to(torch.float32)\n                else:\n                    transformer = transformer.to(weight_dtype)\n            transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        modules_to_save[\"transformer\"] = transformer\n\n        Flux2KleinPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        images = []\n        run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)\n        should_run_final_inference = not args.skip_final_inference and run_validation\n        if should_run_final_inference:\n            pipeline = Flux2KleinPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            images = []\n            if args.validation_prompt and args.num_validation_images > 0:\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_embeddings,\n                    epoch=epoch,\n                    is_final_validation=True,\n                    torch_dtype=weight_dtype,\n                )\n            images 
= None\n            del pipeline\n            free_memory()\n\n        validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n        quant_training = None\n        if args.do_fp8_training:\n            quant_training = \"FP8 TorchAO\"\n        elif args.bnb_quantization_config_path:\n            quant_training = \"BitsandBytes\"\n        save_model_card(\n            (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=validation_prompt,\n            repo_folder=args.output_dir,\n            quant_training=quant_training,\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.sampler import BatchSampler\nfrom torchvision import transforms\nfrom torchvision.transforms import functional as TF\nfrom tqdm.auto 
import tqdm\nfrom transformers import Qwen2TokenizerFast, Qwen3ForCausalLM\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKLFlux2,\n    BitsAndBytesConfig,\n    FlowMatchEulerDiscreteScheduler,\n    Flux2KleinPipeline,\n    Flux2Transformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    _to_cpu_contiguous,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    find_nearest_bucket,\n    free_memory,\n    get_fsdp_kwargs_from_accelerator,\n    offload_models,\n    parse_buckets_string,\n    wrap_with_fsdp,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n    load_image,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif getattr(torch, \"distributed\", None) is not None:\n    import torch.distributed as dist\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n    fp8_training=False,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Flux.2 [Klein] DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md).\n\nFP8 training? 
{fp8_training}\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained(\"black-forest-labs/FLUX.2\", torch_dtype=torch.bfloat16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"flux2\",\n        \"flux2-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    args.num_validation_images = args.num_validation_images if args.num_validation_images else 1\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(dtype=torch_dtype)\n    pipeline.enable_model_cpu_offload()\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                image=pipeline_args[\"image\"],\n                prompt_embeds=pipeline_args[\"prompt_embeds\"],\n                negative_prompt_embeds=pipeline_args[\"negative_prompt_embeds\"],\n                generator=generator,\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef module_filter_fn(mod: torch.nn.Module, fqn: str):\n    # don't convert the output module\n    if fqn == \"proj_out\":\n        return False\n    # don't convert linear modules with weight dimensions not divisible by 16\n    if isinstance(mod, torch.nn.Linear):\n        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:\n            return False\n    return True\n\n\ndef 
parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--bnb_quantization_config_path\",\n        type=str,\n        default=None,\n        help=\"Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.\",\n    )\n    parser.add_argument(\n        \"--do_fp8_training\",\n        action=\"store_true\",\n        help=\"if we are doing FP8 training.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. \"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--cond_image_column\",\n        type=str,\n        default=None,\n        help=\"Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with the text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        help=\"Path to an image that is used during validation as the condition image to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--skip_final_inference\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.\",\n    )\n    parser.add_argument(\n        \"--final_validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--aspect_ratio_buckets\",\n        type=str,\n        default=None,\n        help=(\n            \"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. \"\n            \"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'\"\n            \"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"the FLUX.1 dev variant is a guidance distilled model\",\n    )\n\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated list. E.g. \"to_k,to_q,to_v,to_out.0\" will result in LoRA training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n    parser.add_argument(\"--fsdp_text_encoder\", action=\"store_true\", help=\"Use FSDP for text encoder\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.cond_image_column is None:\n        raise ValueError(\n            \"you must provide --cond_image_column for image-to-image training. 
Otherwise please see Flux2 text-to-image training example.\"\n        )\n    else:\n        assert args.image_column is not None\n        assert args.caption_column is not None\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n        buckets=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n\n        self.buckets = buckets\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. 
If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.cond_image_column is not None and args.cond_image_column not in column_names:\n                raise ValueError(\n                    f\"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                )\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n            cond_images = None\n            cond_image_column = args.cond_image_column\n            if cond_image_column is not None:\n                cond_images = [dataset[\"train\"][i][cond_image_column] for i in range(len(dataset[\"train\"]))]\n                assert len(instance_images) == len(cond_images)\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        self.cond_images = []\n        for i, img in enumerate(instance_images):\n            self.instance_images.extend(itertools.repeat(img, repeats))\n            if args.dataset_name is not None and cond_images is not None:\n                self.cond_images.extend(itertools.repeat(cond_images[i], repeats))\n\n        self.pixel_values = []\n        self.cond_pixel_values = []\n        for i, image in enumerate(self.instance_images):\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            dest_image = None\n            if self.cond_images:  # todo: take care of max area for buckets\n                dest_image = self.cond_images[i]\n                image_width, image_height = dest_image.size\n                if image_width * image_height > 1024 * 1024:\n                    dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)\n                    image_width, image_height = dest_image.size\n\n                multiple_of = 2 ** (4 - 1)  # 2 ** (len(vae.config.block_out_channels) - 1), temp!\n                image_width = (image_width 
// multiple_of) * multiple_of\n                image_height = (image_height // multiple_of) * multiple_of\n                image_processor = Flux2ImageProcessor()\n                dest_image = image_processor.preprocess(\n                    dest_image, height=image_height, width=image_width, resize_mode=\"crop\"\n                )\n                # Convert back to PIL\n                dest_image = dest_image.squeeze(0)\n                if dest_image.min() < 0:\n                    dest_image = (dest_image + 1) / 2\n                dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()\n\n                if dest_image.shape[0] == 1:\n                    # Gray scale image\n                    dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode=\"L\")\n                else:\n                    # RGB scale image: (C, H, W) -> (H, W, C)\n                    dest_image = TF.to_pil_image(dest_image)\n\n                dest_image = exif_transpose(dest_image)\n                if not dest_image.mode == \"RGB\":\n                    dest_image = dest_image.convert(\"RGB\")\n\n            width, height = image.size\n\n            # Find the closest bucket\n            bucket_idx = find_nearest_bucket(height, width, self.buckets)\n            target_height, target_width = self.buckets[bucket_idx]\n            self.size = (target_height, target_width)\n\n            # based on the bucket assignment, define the transformations\n            image, dest_image = self.paired_transform(\n                image,\n                dest_image=dest_image,\n                size=self.size,\n                center_crop=args.center_crop,\n                random_flip=args.random_flip,\n            )\n            self.pixel_values.append((image, bucket_idx))\n            if dest_image is not None:\n                self.cond_pixel_values.append((dest_image, bucket_idx))\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = 
self.num_instance_images\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n        example[\"bucket_idx\"] = bucket_idx\n        if self.cond_pixel_values:\n            dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]\n            example[\"cond_images\"] = dest_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so use the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        return example\n\n    def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):\n        # 1. Resize (deterministic)\n        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        image = resize(image)\n        if dest_image is not None:\n            dest_image = resize(dest_image)\n\n        # 2. 
Crop: either center or SAME random crop\n        if center_crop:\n            crop = transforms.CenterCrop(size)\n            image = crop(image)\n            if dest_image is not None:\n                dest_image = crop(dest_image)\n        else:\n            # get_params returns (i, j, h, w)\n            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)\n            image = TF.crop(image, i, j, h, w)\n            if dest_image is not None:\n                dest_image = TF.crop(dest_image, i, j, h, w)\n\n        # 3. Random horizontal flip with the SAME coin flip\n        if random_flip:\n            do_flip = random.random() < 0.5\n            if do_flip:\n                image = TF.hflip(image)\n                if dest_image is not None:\n                    dest_image = TF.hflip(dest_image)\n\n        # 4. ToTensor + Normalize (deterministic)\n        to_tensor = transforms.ToTensor()\n        normalize = transforms.Normalize([0.5], [0.5])\n        image = normalize(to_tensor(image))\n        if dest_image is not None:\n            dest_image = normalize(to_tensor(dest_image))\n\n        return (image, dest_image) if dest_image is not None else (image, None)\n\n\ndef collate_fn(examples):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    if any(\"cond_images\" in example for example in examples):\n        cond_pixel_values = [example[\"cond_images\"] for example in examples]\n        cond_pixel_values = torch.stack(cond_pixel_values)\n        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()\n        batch.update({\"cond_pixel_values\": cond_pixel_values})\n    return batch\n\n\nclass 
BucketBatchSampler(BatchSampler):\n    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):\n        if not isinstance(batch_size, int) or batch_size <= 0:\n            raise ValueError(\"batch_size should be a positive integer value, but got batch_size={}\".format(batch_size))\n        if not isinstance(drop_last, bool):\n            raise ValueError(\"drop_last should be a boolean value, but got drop_last={}\".format(drop_last))\n\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n\n        # Group indices by bucket\n        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]\n        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):\n            self.bucket_indices[bucket_idx].append(idx)\n\n        self.sampler_len = 0\n        self.batches = []\n\n        # Pre-generate batches for each bucket\n        for indices_in_bucket in self.bucket_indices:\n            # Shuffle indices within the bucket\n            random.shuffle(indices_in_bucket)\n            # Create batches\n            for i in range(0, len(indices_in_bucket), self.batch_size):\n                batch = indices_in_bucket[i : i + self.batch_size]\n                if len(batch) < self.batch_size and self.drop_last:\n                    continue  # Skip partial batch if drop_last is True\n                self.batches.append(batch)\n                self.sampler_len += 1  # Count the number of batches\n\n    def __iter__(self):\n        # Shuffle the order of the batches each epoch\n        random.shuffle(self.batches)\n        for batch in self.batches:\n            yield batch\n\n    def __len__(self):\n        return self.sampler_len\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = 
num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n    if args.do_fp8_training:\n        from torchao.float8 import Float8LinearConfig, convert_to_float8_training\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        
datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer = Qwen2TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"scheduler\",\n        revision=args.revision,\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae = AutoencoderKLFlux2.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        
variant=args.variant,\n    )\n    latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device)\n    latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(\n        accelerator.device\n    )\n\n    quantization_config = None\n    if args.bnb_quantization_config_path is not None:\n        with open(args.bnb_quantization_config_path, \"r\") as f:\n            config_kwargs = json.load(f)\n            if \"load_in_4bit\" in config_kwargs and config_kwargs[\"load_in_4bit\"]:\n                config_kwargs[\"bnb_4bit_compute_dtype\"] = weight_dtype\n        quantization_config = BitsAndBytesConfig(**config_kwargs)\n\n    transformer = Flux2Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        quantization_config=quantization_config,\n        torch_dtype=weight_dtype,\n    )\n    if args.bnb_quantization_config_path is not None:\n        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)\n\n    text_encoder = Qwen3ForCausalLM.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder.requires_grad_(False)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            transformer.set_attention_backend(\"_native_npu\")\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n 
           \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    to_kwargs = {\"dtype\": weight_dtype, \"device\": accelerator.device} if not args.offload else {\"dtype\": weight_dtype}\n    # flux vae is stable in bf16 so load it in weight_dtype to reduce memory\n    vae.to(**to_kwargs)\n    # we never offload the transformer to CPU, so we can just use the accelerator device\n    transformer_to_kwargs = (\n        {\"device\": accelerator.device}\n        if args.bnb_quantization_config_path is not None\n        else {\"device\": accelerator.device, \"dtype\": weight_dtype}\n    )\n\n    is_fsdp = getattr(accelerator.state, \"fsdp_plugin\", None) is not None\n    if not is_fsdp:\n        transformer.to(**transformer_to_kwargs)\n\n    if args.do_fp8_training:\n        convert_to_float8_training(\n            transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)\n        )\n\n    text_encoder.to(**to_kwargs)\n    # Initialize a text encoding pipeline and keep it to CPU for now.\n    text_encoding_pipeline = Flux2KleinPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=None,\n        transformer=None,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        scheduler=None,\n        revision=args.revision,\n    )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n\n    # now we will add new LoRA weights the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    
transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        transformer_cls = type(unwrap_model(transformer))\n\n        # 1) Validate and pick the transformer model\n        modules_to_save: dict[str, Any] = {}\n        transformer_model = None\n\n        for model in models:\n            if isinstance(unwrap_model(model), transformer_cls):\n                transformer_model = model\n                modules_to_save[\"transformer\"] = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        if transformer_model is None:\n            raise ValueError(\"No transformer model found in 'models'\")\n\n        # 2) Optionally gather FSDP state dict once\n        state_dict = accelerator.get_state_dict(model) if is_fsdp else None\n\n        # 3) Only main process materializes the LoRA state dict\n        transformer_lora_layers_to_save = None\n        if accelerator.is_main_process:\n            peft_kwargs = {}\n            if is_fsdp:\n                peft_kwargs[\"state_dict\"] = state_dict\n\n            transformer_lora_layers_to_save = get_peft_model_state_dict(\n                unwrap_model(transformer_model) if is_fsdp else transformer_model,\n                **peft_kwargs,\n            )\n\n            if is_fsdp:\n                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)\n\n            # make sure to pop weight so that corresponding model is not saved again\n            if weights:\n                weights.pop()\n\n            Flux2KleinPipeline.save_lora_weights(\n                output_dir,\n                
transformer_lora_layers=transformer_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        if not is_fsdp:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    transformer_ = unwrap_model(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n        else:\n            transformer_ = Flux2Transformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path,\n                subfolder=\"transformer\",\n            )\n            transformer_.add_adapter(transformer_lora_config)\n\n        lora_state_dict = Flux2KleinPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    if args.aspect_ratio_buckets is not None:\n        buckets = parse_buckets_string(args.aspect_ratio_buckets)\n    else:\n        buckets = [(args.resolution, args.resolution)]\n    logger.info(f\"Using parsed aspect ratio buckets: {buckets}\")\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n        buckets=buckets,\n    )\n    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_sampler=batch_sampler,\n        collate_fn=lambda examples: collate_fn(examples),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        with torch.no_grad():\n            prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(\n                prompt=prompt, max_sequence_length=args.max_sequence_length\n            )\n        return prompt_embeds, text_ids\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. 
the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings(\n                args.instance_prompt, text_encoding_pipeline\n            )\n\n    if args.validation_prompt is not None:\n        validation_image = load_image(args.validation_image).convert(\"RGB\")\n        validation_kwargs = {\"image\": validation_image}\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            validation_kwargs[\"prompt_embeds\"], _text_ids = compute_text_embeddings(\n                args.validation_prompt, text_encoding_pipeline\n            )\n            validation_kwargs[\"negative_prompt_embeds\"], _text_ids = compute_text_embeddings(\n                \"\", text_encoding_pipeline\n            )\n\n    # Init FSDP for text encoder\n    if args.fsdp_text_encoder:\n        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)\n        text_encoder_fsdp = wrap_with_fsdp(\n            model=text_encoding_pipeline.text_encoder,\n            device=accelerator.device,\n            offload=args.offload,\n            limit_all_gathers=True,\n            use_orig_params=True,\n            fsdp_kwargs=fsdp_kwargs,\n        )\n\n        text_encoding_pipeline.text_encoder = text_encoder_fsdp\n        dist.barrier()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        prompt_embeds = instance_prompt_hidden_states\n        text_ids = instance_text_ids\n\n    # if cache_latents is set to True, we encode images to latents and store them.\n    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided\n    # we encode them in advance as well.\n    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts\n    if precompute_latents:\n        prompt_embeds_cache = []\n        text_ids_cache = []\n        latents_cache = []\n        cond_latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                if args.cache_latents:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                            accelerator.device, non_blocking=True, dtype=vae.dtype\n                        )\n                        latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n                        batch[\"cond_pixel_values\"] = batch[\"cond_pixel_values\"].to(\n                            accelerator.device, non_blocking=True, dtype=vae.dtype\n                        )\n                        cond_latents_cache.append(vae.encode(batch[\"cond_pixel_values\"]).latent_dist)\n                if train_dataset.custom_instance_prompts:\n                    if args.fsdp_text_encoder:\n                        prompt_embeds, text_ids = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n                    else:\n                        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                            prompt_embeds, text_ids = compute_text_embeddings(batch[\"prompts\"], 
text_encoding_pipeline)\n                    prompt_embeds_cache.append(prompt_embeds)\n                    text_ids_cache.append(text_ids)\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if args.cache_latents:\n        vae = vae.to(\"cpu\")\n        del vae\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n    del text_encoder, tokenizer\n    free_memory()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have 
changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux2-image2img-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the mos recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            prompts = batch[\"prompts\"]\n\n            with accelerator.accumulate(models_to_accumulate):\n                if train_dataset.custom_instance_prompts:\n                    prompt_embeds = prompt_embeds_cache[step]\n                    text_ids = text_ids_cache[step]\n                else:\n                    num_repeat_elements = len(prompts)\n                    prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)\n                    text_ids = text_ids.repeat(num_repeat_elements, 1, 1)\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].mode()\n                    cond_model_input = cond_latents_cache[step].mode()\n                else:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                        cond_pixel_values = batch[\"cond_pixel_values\"].to(dtype=vae.dtype)\n\n                    model_input = vae.encode(pixel_values).latent_dist.mode()\n                    cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()\n\n                model_input = Flux2KleinPipeline._patchify_latents(model_input)\n                model_input = (model_input - latents_bn_mean) / latents_bn_std\n\n                cond_model_input = Flux2KleinPipeline._patchify_latents(cond_model_input)\n                cond_model_input = (cond_model_input - latents_bn_mean) / 
latents_bn_std\n\n                model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device)\n                cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]\n                cond_model_input_ids = Flux2KleinPipeline._prepare_image_ids(cond_model_input_list).to(\n                    device=cond_model_input.device\n                )\n                cond_model_input_ids = cond_model_input_ids.view(\n                    cond_model_input.shape[0], -1, model_input_ids.shape[-1]\n                )\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                # [B, C, H, W] -> [B, H*W, C]\n                # concatenate the model inputs with the cond inputs\n                packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input)\n                packed_cond_model_input = 
Flux2KleinPipeline._pack_latents(cond_model_input)\n                orig_input_shape = packed_noisy_model_input.shape\n                orig_input_ids_shape = model_input_ids.shape\n\n                # concatenate the model inputs with the cond inputs\n                packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)\n                model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)\n\n                # handle guidance\n                if unwrap_model(transformer).config.guidance_embeds:\n                    guidance = torch.full([1], args.guidance_scale, device=accelerator.device)\n                    guidance = guidance.expand(model_input.shape[0])\n                else:\n                    guidance = None\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,  # (B, image_seq_len, C)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,  # B, text_seq_len, 4\n                    img_ids=model_input_ids,  # B, image_seq_len, 4\n                    return_dict=False,\n                )[0]\n                # pruning the condition information\n                model_pred = model_pred[:, : orig_input_shape[1], :]\n                model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]\n\n                model_pred = Flux2KleinPipeline._unpack_latents_with_ids(model_pred, model_input_ids)\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                # Compute regular loss.\n        
        loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or is_fsdp:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                              
  )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = Flux2KleinPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    text_encoder=None,\n                    tokenizer=None,\n                    transformer=unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_kwargs,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n\n                del pipeline\n                free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n\n    if is_fsdp:\n        transformer = unwrap_model(transformer)\n        state_dict = 
accelerator.get_state_dict(transformer)\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        if is_fsdp:\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    state_dict = {\n                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n                else:\n                    state_dict = {\n                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n\n            transformer_lora_layers = get_peft_model_state_dict(\n                transformer,\n                state_dict=state_dict,\n            )\n            transformer_lora_layers = {\n                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v\n                for k, v in transformer_lora_layers.items()\n            }\n\n        else:\n            transformer = unwrap_model(transformer)\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    transformer.to(torch.float32)\n                else:\n                    transformer = transformer.to(weight_dtype)\n            transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        modules_to_save[\"transformer\"] = transformer\n\n        Flux2KleinPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        images = []\n        run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)\n        should_run_final_inference = not args.skip_final_inference and run_validation\n        if should_run_final_inference:\n            pipeline = Flux2KleinPipeline.from_pretrained(\n                
args.pretrained_model_name_or_path,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            images = []\n            if args.validation_prompt and args.num_validation_images > 0:\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_kwargs,\n                    epoch=epoch,\n                    is_final_validation=True,\n                    torch_dtype=weight_dtype,\n                )\n            del pipeline\n            free_memory()\n\n        validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n        save_model_card(\n            (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=validation_prompt,\n            repo_folder=args.output_dir,\n            fp8_training=args.do_fp8_training,\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_flux_kontext.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.sampler import BatchSampler\nfrom torchvision import transforms\nfrom 
torchvision.transforms import functional as TF\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    FluxKontextPipeline,\n    FluxTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    _set_state_dict_into_text_encoder,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    find_nearest_bucket,\n    free_memory,\n    parse_buckets_string,\n)\nfrom diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Flux Kontext DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).\n\nWas LoRA for the text encoder enabled? 
{train_text_encoder}.\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import FluxKontextPipeline\nimport torch\npipeline = FluxKontextPipeline.from_pretrained(\"black-forest-labs/FLUX.1-Kontext-dev\", torch_dtype=torch.bfloat16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"flux\",\n        \"flux-kontext\",\n        \"flux-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef load_text_encoders(class_one, class_two):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = class_two.from_pretrained(\n        args.pretrained_model_name_or_path, 
subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    return text_encoder_one, text_encoder_two\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)\n    pipeline.set_progress_bar_config(disable=True)\n    pipeline_args_cp = pipeline_args.copy()\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    # pre-calculate  prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast\n    with torch.no_grad():\n        prompt = pipeline_args_cp.pop(\"prompt\")\n        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                **pipeline_args_cp,\n                prompt_embeds=prompt_embeds,\n                pooled_prompt_embeds=pooled_prompt_embeds,\n                generator=generator,\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        
wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--vae_encode_mode\",\n        type=str,\n        default=\"mode\",\n        choices=[\"sample\", \"mode\"],\n        help=\"VAE encoding mode.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. \"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--cond_image_column\",\n        type=str,\n        default=None,\n        help=\"Column in the dataset containing the condition image. 
Must be specified when performing I2I fine-tuning\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        help=\"Validation image to use (during I2I fine-tuning) to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X 
epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-kontext-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--aspect_ratio_buckets\",\n        type=str,\n        default=None,\n        help=(\n            \"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. \"\n            \"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'. \"\n            \"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"Whether to randomly flip images horizontally.\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"Guidance scale to use; the FLUX.1 dev variant is a guidance-distilled model.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"Coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for transformer params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Specify the layers as a comma-separated list, e.g. \"to_k,to_q,to_v,to_out.0\" will result in LoRA training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. 
Default to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n        if args.cond_image_column is not None:\n            raise ValueError(\"Prior preservation isn't supported with I2I training.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    if args.cond_image_column is not None:\n        assert args.image_column is not None\n        assert args.caption_column is not None\n        assert args.dataset_name is not None\n        assert not args.train_text_encoder\n        if args.validation_prompt is not None:\n            assert 
args.validation_image is not None and os.path.exists(args.validation_image)\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        repeats=1,\n        center_crop=False,\n        buckets=None,\n        args=None,\n    ):\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        self.buckets = buckets\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. 
Get the column names for input/target.\n            if args.cond_image_column is not None and args.cond_image_column not in column_names:\n                raise ValueError(\n                    f\"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                )\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = [dataset[\"train\"][i][image_column] for i in range(len(dataset[\"train\"]))]\n            cond_images = None\n            cond_image_column = args.cond_image_column\n            if cond_image_column is not None:\n                cond_images = [dataset[\"train\"][i][cond_image_column] for i in range(len(dataset[\"train\"]))]\n                assert len(instance_images) == len(cond_images)\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        self.cond_images = []\n        for i, img in enumerate(instance_images):\n            self.instance_images.extend(itertools.repeat(img, repeats))\n            if args.dataset_name is not None and cond_images is not None:\n                self.cond_images.extend(itertools.repeat(cond_images[i], repeats))\n\n        self.pixel_values = []\n        self.cond_pixel_values = []\n        for i, image in enumerate(self.instance_images):\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            dest_image = None\n            if self.cond_images:\n                dest_image = exif_transpose(self.cond_images[i])\n                if not dest_image.mode == \"RGB\":\n                    dest_image = dest_image.convert(\"RGB\")\n\n            width, height = image.size\n\n            # Find the closest bucket\n            bucket_idx = find_nearest_bucket(height, width, self.buckets)\n            target_height, target_width = self.buckets[bucket_idx]\n            self.size = (target_height, target_width)\n\n            # based on the bucket assignment, define the 
transformations\n            image, dest_image = self.paired_transform(\n                image,\n                dest_image=dest_image,\n                size=self.size,\n                center_crop=args.center_crop,\n                random_flip=args.random_flip,\n            )\n            self.pixel_values.append((image, bucket_idx))\n            if dest_image is not None:\n                self.cond_pixel_values.append((dest_image, bucket_idx))\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n        example[\"bucket_idx\"] = bucket_idx\n        if self.cond_pixel_values:\n            dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]\n            
example[\"cond_images\"] = dest_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so fall back to the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n    def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):\n        # 1. Resize (deterministic)\n        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        image = resize(image)\n        if dest_image is not None:\n            dest_image = resize(dest_image)\n\n        # 2. Crop: either center or SAME random crop\n        if center_crop:\n            crop = transforms.CenterCrop(size)\n            image = crop(image)\n            if dest_image is not None:\n                dest_image = crop(dest_image)\n        else:\n            # get_params returns (i, j, h, w)\n            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)\n            image = TF.crop(image, i, j, h, w)\n            if dest_image is not None:\n                dest_image = TF.crop(dest_image, i, j, h, w)\n\n        # 3. 
Random horizontal flip with the SAME coin flip\n        if random_flip:\n            do_flip = random.random() < 0.5\n            if do_flip:\n                image = TF.hflip(image)\n                if dest_image is not None:\n                    dest_image = TF.hflip(dest_image)\n\n        # 4. ToTensor + Normalize (deterministic)\n        to_tensor = transforms.ToTensor()\n        normalize = transforms.Normalize([0.5], [0.5])\n        image = normalize(to_tensor(image))\n        if dest_image is not None:\n            dest_image = normalize(to_tensor(dest_image))\n\n        return (image, dest_image) if dest_image is not None else (image, None)\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    if any(\"cond_images\" in example for example in examples):\n        cond_pixel_values = [example[\"cond_images\"] for example in examples]\n        cond_pixel_values = torch.stack(cond_pixel_values)\n        cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()\n        batch.update({\"cond_pixel_values\": cond_pixel_values})\n    return batch\n\n\nclass BucketBatchSampler(BatchSampler):\n    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):\n        if not isinstance(batch_size, int) or batch_size <= 0:\n            raise 
ValueError(\"batch_size should be a positive integer value, but got batch_size={}\".format(batch_size))\n        if not isinstance(drop_last, bool):\n            raise ValueError(\"drop_last should be a boolean value, but got drop_last={}\".format(drop_last))\n\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n\n        # Group indices by bucket\n        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]\n        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):\n            self.bucket_indices[bucket_idx].append(idx)\n\n        self.sampler_len = 0\n        self.batches = []\n\n        # Pre-generate batches for each bucket\n        for indices_in_bucket in self.bucket_indices:\n            # Shuffle indices within the bucket\n            random.shuffle(indices_in_bucket)\n            # Create batches\n            for i in range(0, len(indices_in_bucket), self.batch_size):\n                batch = indices_in_bucket[i : i + self.batch_size]\n                if len(batch) < self.batch_size and self.drop_last:\n                    continue  # Skip partial batch if drop_last is True\n                self.batches.append(batch)\n                self.sampler_len += 1  # Count the number of batches\n\n    def __iter__(self):\n        # Shuffle the order of the batches each epoch\n        random.shuffle(self.batches)\n        for batch in self.batches:\n            yield batch\n\n    def __len__(self):\n        return self.sampler_len\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return 
example\n\n\ndef tokenize_prompt(tokenizer, prompt, max_sequence_length):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        return_length=False,\n        return_overflowing_tokens=False,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\ndef _encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length=512,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    if hasattr(text_encoder, \"module\"):\n        dtype = text_encoder.module.dtype\n    else:\n        dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n\n    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    text_input_ids=None,\n    num_images_per_prompt: int = 1,\n):\n  
  prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n    if hasattr(text_encoder, \"module\"):\n        dtype = text_encoder.module.dtype\n    else:\n        dtype = text_encoder.dtype\n    # Use pooled output of CLIPTextModel\n    prompt_embeds = prompt_embeds.pooler_output\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n    return prompt_embeds\n\n\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n    text_input_ids_list=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n\n    if hasattr(text_encoders[0], \"module\"):\n        dtype = text_encoders[0].module.dtype\n    else:\n        dtype = text_encoders[0].dtype\n\n    pooled_prompt_embeds = _encode_prompt_with_clip(\n        text_encoder=text_encoders[0],\n        tokenizer=tokenizers[0],\n        prompt=prompt,\n        device=device if device is not None else text_encoders[0].device,\n        num_images_per_prompt=num_images_per_prompt,\n        text_input_ids=text_input_ids_list[0] if 
text_input_ids_list else None,\n    )\n\n    prompt_embeds = _encode_prompt_with_t5(\n        text_encoder=text_encoders[1],\n        tokenizer=tokenizers[1],\n        max_sequence_length=max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        device=device if device is not None else text_encoders[1].device,\n        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,\n    )\n\n    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n    return prompt_embeds, pooled_prompt_embeds, text_ids\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    if accelerator.distributed_type == DistributedType.DEEPSPEED:\n        AcceleratorState().deepspeed_plugin.deepspeed_config[\"train_micro_batch_size_per_gpu\"] = args.train_batch_size\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            
class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n\n            transformer = FluxTransformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path,\n                subfolder=\"transformer\",\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=torch_dtype,\n            )\n            pipeline = FluxKontextPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                transformer=transformer,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                with 
torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):\n                    images = pipeline(prompt=example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n    vae = AutoencoderKL.from_pretrained(\n    
    args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            transformer.set_attention_backend(\"_native_npu\")\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=weight_dtype)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\n            \"attn.to_k\",\n            \"attn.to_q\",\n            \"attn.to_v\",\n            \"attn.to_out.0\",\n            \"attn.add_k_proj\",\n            \"attn.add_q_proj\",\n            \"attn.add_v_proj\",\n            \"attn.to_add_out\",\n            \"ff.net.0.proj\",\n            \"ff.net.2\",\n            \"ff_context.net.0.proj\",\n            \"ff_context.net.2\",\n            \"proj_mlp\",\n        ]\n\n    # Now we add new LoRA weights to the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.lora_alpha,\n            lora_dropout=args.lora_dropout,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder_one.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # 
create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            modules_to_save = {}\n            for model in models:\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    model = unwrap_model(model)\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                    modules_to_save[\"transformer\"] = model\n                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):\n                    model = unwrap_model(model)\n                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)\n                    modules_to_save[\"text_encoder\"] = model\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                if weights:\n                    weights.pop()\n\n            FluxKontextPipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n        text_encoder_one_ = None\n\n        if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    transformer_ = unwrap_model(model)\n                elif isinstance(unwrap_model(model), 
type(unwrap_model(text_encoder_one))):\n                    text_encoder_one_ = unwrap_model(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        else:\n            transformer_ = FluxTransformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path, subfolder=\"transformer\"\n            )\n            transformer_.add_adapter(transformer_lora_config)\n            text_encoder_one_ = text_encoder_cls_one.from_pretrained(\n                args.pretrained_model_name_or_path, subfolder=\"text_encoder\"\n            )\n\n        lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n        if args.train_text_encoder:\n            # Do we need to call `scale_lora_layers()` here?\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_])\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one])\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n    if args.train_text_encoder:\n        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    if args.train_text_encoder:\n        # different learning rate for text encoder and unet\n        text_parameters_one_with_lr = {\n            \"params\": text_lora_parameters_one,\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr 
else args.learning_rate,\n        }\n        params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]\n    else:\n        params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy].\"\n            \" Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. \"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    if args.aspect_ratio_buckets is not None:\n        buckets = parse_buckets_string(args.aspect_ratio_buckets)\n    else:\n        buckets = [(args.resolution, args.resolution)]\n    logger.info(f\"Using parsed aspect ratio buckets: {buckets}\")\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        buckets=buckets,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n        args=args,\n    )\n    if args.cond_image_column is not None:\n        logger.info(\"I2I fine-tuning enabled.\")\n    batch_sampler = BucketBatchSampler(train_dataset, 
batch_size=args.train_batch_size, drop_last=True)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_sampler=batch_sampler,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two]\n        text_encoders = [text_encoder_one, text_encoder_two]\n\n        def compute_text_embeddings(prompt, text_encoders, tokenizers):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                    text_encoders, tokenizers, prompt, args.max_sequence_length\n                )\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n                text_ids = text_ids.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. 
the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if not args.train_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers\n            )\n\n    # Clear the memory here\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        text_encoder_one.cpu(), text_encoder_two.cpu()\n        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two\n        free_memory()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n\n    if not train_dataset.custom_instance_prompts:\n        if not args.train_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            pooled_prompt_embeds = instance_pooled_prompt_embeds\n            text_ids = instance_text_ids\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)\n                text_ids = torch.cat([text_ids, class_text_ids], dim=0)\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts)\n        # we need to tokenize and encode the batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)\n            tokens_two = tokenize_prompt(\n                tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length\n            )\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77)\n                class_tokens_two = tokenize_prompt(\n                    tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length\n                )\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)\n\n    elif train_dataset.custom_instance_prompts and not args.train_text_encoder:\n        cached_text_embeddings = []\n        for batch in tqdm(train_dataloader, desc=\"Embedding prompts\"):\n            batch_prompts = batch[\"prompts\"]\n            prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(\n                batch_prompts, text_encoders, tokenizers\n  
          )\n            cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))\n\n        if args.validation_prompt is None:\n            text_encoder_one.cpu(), text_encoder_two.cpu()\n            del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two\n            free_memory()\n\n    vae_config_shift_factor = vae.config.shift_factor\n    vae_config_scaling_factor = vae.config.scaling_factor\n    vae_config_block_out_channels = vae.config.block_out_channels\n    has_image_input = args.cond_image_column is not None\n    if args.cache_latents:\n        latents_cache = []\n        cond_latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=weight_dtype\n                )\n                latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n                if has_image_input:\n                    batch[\"cond_pixel_values\"] = batch[\"cond_pixel_values\"].to(\n                        accelerator.device, non_blocking=True, dtype=weight_dtype\n                    )\n                    cond_latents_cache.append(vae.encode(batch[\"cond_pixel_values\"]).latent_dist)\n\n        if args.validation_prompt is None:\n            vae.cpu()\n            del vae\n            free_memory()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        
num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        (\n            transformer,\n            text_encoder_one,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        ) = accelerator.prepare(\n            transformer,\n            text_encoder_one,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        )\n    else:\n        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            transformer, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. 
\"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux-kontext-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    has_guidance = unwrap_model(transformer).config.guidance_embeds\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n            # set requires_grad = True on the top-level embeddings so that gradient checkpointing works\n            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            if args.train_text_encoder:\n                models_to_accumulate.extend([text_encoder_one])\n            with accelerator.accumulate(models_to_accumulate):\n                prompts = batch[\"prompts\"]\n\n                # encode batch prompts when custom prompts are provided for each image\n                if train_dataset.custom_instance_prompts:\n                    if not args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)\n                        tokens_two = tokenize_prompt(\n                            tokenizer_two, prompts, max_sequence_length=args.max_sequence_length\n                        )\n                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                            text_encoders=[text_encoder_one, text_encoder_two],\n                            tokenizers=[None, None],\n                            text_input_ids_list=[tokens_one, tokens_two],\n                            max_sequence_length=args.max_sequence_length,\n                            
device=accelerator.device,\n                            prompt=prompts,\n                        )\n                else:\n                    elems_to_repeat = len(prompts)\n                    if args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                            text_encoders=[text_encoder_one, text_encoder_two],\n                            tokenizers=[None, None],\n                            text_input_ids_list=[\n                                tokens_one.repeat(elems_to_repeat, 1),\n                                tokens_two.repeat(elems_to_repeat, 1),\n                            ],\n                            max_sequence_length=args.max_sequence_length,\n                            device=accelerator.device,\n                            prompt=args.instance_prompt,\n                        )\n                    else:\n                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    if args.vae_encode_mode == \"sample\":\n                        model_input = latents_cache[step].sample()\n                        if has_image_input:\n                            cond_model_input = cond_latents_cache[step].sample()\n                    else:\n                        model_input = latents_cache[step].mode()\n                        if has_image_input:\n                            cond_model_input = cond_latents_cache[step].mode()\n                else:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    if has_image_input:\n                        cond_pixel_values = batch[\"cond_pixel_values\"].to(dtype=vae.dtype)\n                    if args.vae_encode_mode == \"sample\":\n                        
model_input = vae.encode(pixel_values).latent_dist.sample()\n                        if has_image_input:\n                            cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()\n                    else:\n                        model_input = vae.encode(pixel_values).latent_dist.mode()\n                        if has_image_input:\n                            cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()\n                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n                if has_image_input:\n                    cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                    cond_model_input = cond_model_input.to(dtype=weight_dtype)\n\n                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)\n\n                latent_image_ids = FluxKontextPipeline._prepare_latent_image_ids(\n                    model_input.shape[0],\n                    model_input.shape[2] // 2,\n                    model_input.shape[3] // 2,\n                    accelerator.device,\n                    weight_dtype,\n                )\n                if has_image_input:\n                    cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(\n                        cond_model_input.shape[0],\n                        cond_model_input.shape[2] // 2,\n                        cond_model_input.shape[3] // 2,\n                        accelerator.device,\n                        weight_dtype,\n                    )\n                    cond_latents_ids[..., 0] = 1\n                    latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random 
timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n                packed_noisy_model_input = FluxKontextPipeline._pack_latents(\n                    noisy_model_input,\n                    batch_size=model_input.shape[0],\n                    num_channels_latents=model_input.shape[1],\n                    height=model_input.shape[2],\n                    width=model_input.shape[3],\n                )\n                orig_inp_shape = packed_noisy_model_input.shape\n                if has_image_input:\n                    packed_cond_input = FluxKontextPipeline._pack_latents(\n                        cond_model_input,\n                        batch_size=cond_model_input.shape[0],\n                        num_channels_latents=cond_model_input.shape[1],\n                        height=cond_model_input.shape[2],\n                        width=cond_model_input.shape[3],\n                    )\n                    packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)\n\n                # Kontext always has guidance\n                guidance = None\n                if has_guidance:\n               
     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)\n                    guidance = guidance.expand(model_input.shape[0])\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,\n                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    return_dict=False,\n                )[0]\n                if has_image_input:\n                    model_pred = model_pred[:, : orig_inp_shape[1]]\n                model_pred = FluxKontextPipeline._unpack_latents(\n                    model_pred,\n                    height=model_input.shape[2] * vae_scale_factor,\n                    width=model_input.shape[3] * vae_scale_factor,\n                    vae_scale_factor=vae_scale_factor,\n                )\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    
prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(transformer.parameters(), text_encoder_one.parameters())\n                        if args.train_text_encoder\n                        else transformer.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n    
                        checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if not args.train_text_encoder:\n                    text_encoder_one, text_encoder_two = 
load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n                    text_encoder_one.to(weight_dtype)\n                    text_encoder_two.to(weight_dtype)\n                pipeline = FluxKontextPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=unwrap_model(text_encoder_one),\n                    text_encoder_2=unwrap_model(text_encoder_two),\n                    transformer=unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n                if has_image_input and args.validation_image:\n                    pipeline_args.update({\"image\": load_image(args.validation_image)})\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n                if not args.train_text_encoder:\n                    del text_encoder_one, text_encoder_two\n                    free_memory()\n\n                images = None\n                free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        transformer = unwrap_model(transformer)\n        if args.upcast_before_saving:\n            transformer.to(torch.float32)\n        else:\n            transformer = transformer.to(weight_dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n        modules_to_save[\"transformer\"] = transformer\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))\n            modules_to_save[\"text_encoder\"] = text_encoder_one\n        else:\n            text_encoder_lora_layers = None\n\n        FluxKontextPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        # Final inference\n        # Load previous pipeline\n        transformer = FluxTransformer2DModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n        )\n        pipeline = FluxKontextPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            transformer=transformer,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt}\n            if has_image_input and args.validation_image:\n                pipeline_args.update({\"image\": load_image(args.validation_image)})\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n            del pipeline\n            free_memory()\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                
base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_hidream.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, CLIPTokenizer, LlamaForCausalLM, PretrainedConfig, T5Tokenizer\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    BitsAndBytesConfig,\n    FlowMatchEulerDiscreteScheduler,\n    HiDreamImagePipeline,\n    HiDreamImageTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    
cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n    offload_models,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimum version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# HiDream Image DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [HiDream Image diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_hidream.md).\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\n    >>> import torch\n   
 >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM\n    >>> from diffusers import HiDreamImagePipeline\n\n    >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n    >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(\n    ...     \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    ...     output_hidden_states=True,\n    ...     output_attentions=True,\n    ...     torch_dtype=torch.bfloat16,\n    ... )\n\n    >>> pipe = HiDreamImagePipeline.from_pretrained(\n    ...     \"HiDream-ai/HiDream-I1-Full\",\n    ...     tokenizer_4=tokenizer_4,\n    ...     text_encoder_4=text_encoder_4,\n    ...     torch_dtype=torch.bfloat16,\n    ... )\n    >>> pipe.enable_model_cpu_offload()\n    >>> pipe.load_lora_weights(f\"{repo_id}\")\n    >>> image = pipe(f\"{instance_prompt}\").images[0]\n\n\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"mit\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"hidream\",\n        \"hidream-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef load_text_encoders(class_one, class_two, class_three):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = class_two.from_pretrained(\n        
args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_three = class_three.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_3\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_four = LlamaForCausalLM.from_pretrained(\n        args.pretrained_text_encoder_4_name_or_path,\n        output_hidden_states=True,\n        output_attentions=True,\n        torch_dtype=torch.bfloat16,\n    )\n    return text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    args.num_validation_images = args.num_validation_images if args.num_validation_images else 1\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                prompt_embeds_t5=pipeline_args[\"prompt_embeds_t5\"],\n                prompt_embeds_llama3=pipeline_args[\"prompt_embeds_llama3\"],\n                negative_prompt_embeds_t5=pipeline_args[\"negative_prompt_embeds_t5\"],\n                negative_prompt_embeds_llama3=pipeline_args[\"negative_prompt_embeds_llama3\"],\n                pooled_prompt_embeds=pipeline_args[\"pooled_prompt_embeds\"],\n                
negative_pooled_prompt_embeds=pipeline_args[\"negative_pooled_prompt_embeds\"],\n                generator=generator,\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModelWithProjection\" or model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    
parser.add_argument(\n        \"--pretrained_tokenizer_4_name_or_path\",\n        type=str,\n        default=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_text_encoder_4_name_or_path\",\n        type=str,\n        default=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--bnb_quantization_config_path\",\n        type=str,\n        default=None,\n        help=\"Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. \"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=128,\n        help=\"Maximum sequence length to use with t5 and llama encoders\",\n    )\n\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n\n    parser.add_argument(\n        \"--skip_final_inference\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.\",\n    )\n\n    parser.add_argument(\n        \"--final_validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"hidream-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in the cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for the transformer LoRA params\")\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on, specified as a comma-separated list. E.g. \"to_k,to_q,to_v\" will result in LoRA training of the attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
\"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. 
Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n        
    warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. 
If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exists.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n 
           else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # custom prompts were provided, but length does not match size of image dataset\n            
example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate 
with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        
class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            pipeline = HiDreamImagePipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch.bfloat16 if args.mixed_precision == \"bf16\" else torch.float16,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            pipeline.to(\"cpu\")\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                
repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n    tokenizer_three = T5Tokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_3\",\n        revision=args.revision,\n    )\n\n    tokenizer_four = AutoTokenizer.from_pretrained(\n        args.pretrained_tokenizer_4_name_or_path,\n        revision=args.revision,\n    )\n    tokenizer_four.pad_token = tokenizer_four.eos_token\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n    text_encoder_cls_three = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_3\"\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\", 
revision=args.revision, shift=3.0\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(\n        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    quantization_config = None\n    if args.bnb_quantization_config_path is not None:\n        with open(args.bnb_quantization_config_path, \"r\") as f:\n            config_kwargs = json.load(f)\n            if \"load_in_4bit\" in config_kwargs and config_kwargs[\"load_in_4bit\"]:\n                config_kwargs[\"bnb_4bit_compute_dtype\"] = weight_dtype\n        quantization_config = BitsAndBytesConfig(**config_kwargs)\n\n    transformer = HiDreamImageTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        quantization_config=quantization_config,\n        torch_dtype=weight_dtype,\n        force_inference_output=True,\n    )\n    if args.bnb_quantization_config_path is not None:\n        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    text_encoder_three.requires_grad_(False)\n    text_encoder_four.requires_grad_(False)\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    to_kwargs = {\"dtype\": weight_dtype, \"device\": accelerator.device} if not args.offload else {\"dtype\": weight_dtype}\n    # flux vae is stable in bf16 so load it in weight_dtype to reduce memory\n    vae.to(**to_kwargs)\n    text_encoder_one.to(**to_kwargs)\n    text_encoder_two.to(**to_kwargs)\n    text_encoder_three.to(**to_kwargs)\n    text_encoder_four.to(**to_kwargs)\n    # we never offload the transformer to CPU, so we can just use the accelerator device\n    transformer_to_kwargs = (\n        {\"device\": accelerator.device}\n        if args.bnb_quantization_config_path is not None\n        else {\"device\": accelerator.device, \"dtype\": weight_dtype}\n    )\n    transformer.to(**transformer_to_kwargs)\n\n    # Initialize a text encoding pipeline and keep it to CPU for now.\n    text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=None,\n        transformer=None,\n        text_encoder=text_encoder_one,\n        tokenizer=tokenizer_one,\n        text_encoder_2=text_encoder_two,\n        tokenizer_2=tokenizer_two,\n        text_encoder_3=text_encoder_three,\n        tokenizer_3=tokenizer_three,\n        text_encoder_4=text_encoder_four,\n        tokenizer_4=tokenizer_four,\n    )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out\"]\n\n    # now we will add new LoRA weights the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def 
unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    model = unwrap_model(model)\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                if weights:\n                    weights.pop()\n\n            HiDreamImagePipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    model = unwrap_model(model)\n                    transformer_ = model\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n        else:\n            transformer_ = HiDreamImageTransformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path, subfolder=\"transformer\"\n            )\n            transformer_.add_adapter(transformer_lora_config)\n\n        lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n 
           f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, 
dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate 
is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        with torch.no_grad():\n            (\n                t5_prompt_embeds,\n                negative_prompt_embeds_t5,\n                llama3_prompt_embeds,\n                negative_prompt_embeds_llama3,\n                pooled_prompt_embeds,\n                negative_pooled_prompt_embeds,\n            ) = text_encoding_pipeline.encode_prompt(prompt=prompt, max_sequence_length=args.max_sequence_length)\n        return (\n            t5_prompt_embeds,\n            llama3_prompt_embeds,\n            pooled_prompt_embeds,\n            negative_prompt_embeds_t5,\n            negative_prompt_embeds_llama3,\n            
negative_pooled_prompt_embeds,\n        )\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            (\n                instance_prompt_hidden_states_t5,\n                instance_prompt_hidden_states_llama3,\n                instance_pooled_prompt_embeds,\n                _,\n                _,\n                _,\n            ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (\n                compute_text_embeddings(args.class_prompt, text_encoding_pipeline)\n            )\n\n    validation_embeddings = {}\n    if args.validation_prompt is not None:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            (\n                validation_embeddings[\"prompt_embeds_t5\"],\n                validation_embeddings[\"prompt_embeds_llama3\"],\n                validation_embeddings[\"pooled_prompt_embeds\"],\n                validation_embeddings[\"negative_prompt_embeds_t5\"],\n                validation_embeddings[\"negative_prompt_embeds_llama3\"],\n                validation_embeddings[\"negative_pooled_prompt_embeds\"],\n            ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)\n\n    # If custom instance prompts are NOT provided (i.e. 
the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        t5_prompt_embeds = instance_prompt_hidden_states_t5\n        llama3_prompt_embeds = instance_prompt_hidden_states_llama3\n        pooled_prompt_embeds = instance_pooled_prompt_embeds\n        if args.with_prior_preservation:\n            t5_prompt_embeds = torch.cat([instance_prompt_hidden_states_t5, class_prompt_hidden_states_t5], dim=0)\n            llama3_prompt_embeds = torch.cat(\n                [instance_prompt_hidden_states_llama3, class_prompt_hidden_states_llama3], dim=0\n            )\n            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)\n\n    vae_config_scaling_factor = vae.config.scaling_factor\n    vae_config_shift_factor = vae.config.shift_factor\n\n    # if cache_latents is set to True, we encode images to latents and store them.\n    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided\n    # we encode them in advance as well.\n    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts\n    if precompute_latents:\n        t5_prompt_cache = []\n        llama3_prompt_cache = []\n        pooled_prompt_cache = []\n        latents_cache = []\n        if args.offload:\n            vae = vae.to(accelerator.device)\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                if args.cache_latents:\n                    batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                        accelerator.device, non_blocking=True, dtype=vae.dtype\n                    )\n                    latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n                if train_dataset.custom_instance_prompts:\n                   
 text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)\n                    t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds, _, _, _ = compute_text_embeddings(\n                        batch[\"prompts\"], text_encoding_pipeline\n                    )\n                    t5_prompt_cache.append(t5_prompt_embeds)\n                    llama3_prompt_cache.append(llama3_prompt_embeds)\n                    pooled_prompt_cache.append(pooled_prompt_embeds)\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if args.offload or args.cache_latents:\n        vae = vae.to(\"cpu\")\n        if args.cache_latents:\n            del vae\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n    del (\n        text_encoder_one,\n        text_encoder_two,\n        text_encoder_three,\n        text_encoder_four,\n        tokenizer_two,\n        tokenizer_three,\n        tokenizer_four,\n        text_encoding_pipeline,\n    )\n    free_memory()\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, 
train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-hidream-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            prompts = batch[\"prompts\"]\n\n            with accelerator.accumulate(models_to_accumulate):\n                # Use the cached prompt embeddings when custom prompts are provided for each image\n                if train_dataset.custom_instance_prompts:\n                    t5_prompt_embeds = t5_prompt_cache[step]\n                    llama3_prompt_embeds = llama3_prompt_cache[step]\n                    pooled_prompt_embeds = pooled_prompt_cache[step]\n                else:\n                    t5_prompt_embeds = t5_prompt_embeds.repeat(len(prompts), 1, 1)\n                    llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, len(prompts), 1, 1)\n                    pooled_prompt_embeds = pooled_prompt_embeds.repeat(len(prompts), 1)\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].sample()\n                else:\n                    # Encode inside the context so the VAE is still on the accelerator device when offloading is enabled\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                        model_input = vae.encode(pixel_values).latent_dist.sample()\n\n                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random 
timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=noisy_model_input,\n                    encoder_hidden_states_t5=t5_prompt_embeds,\n                    encoder_hidden_states_llama3=llama3_prompt_embeds,\n                    pooled_embeds=pooled_prompt_embeds,\n                    timesteps=timesteps,\n                    return_dict=False,\n                )[0]\n                model_pred = model_pred * -1\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                target = noise - model_input\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, 
dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            
checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = HiDreamImagePipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    tokenizer=None,\n                    text_encoder=None,\n                    tokenizer_2=None,\n                    
text_encoder_2=None,\n                    tokenizer_3=None,\n                    text_encoder_3=None,\n                    tokenizer_4=None,\n                    text_encoder_4=None,\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_embeddings,\n                    torch_dtype=weight_dtype,\n                    epoch=epoch,\n                )\n                del pipeline\n                images = None\n                free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        if args.bnb_quantization_config_path is None:\n            if args.upcast_before_saving:\n                transformer.to(torch.float32)\n            else:\n                transformer = transformer.to(weight_dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        HiDreamImagePipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n        )\n\n        images = []\n        run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)\n        should_run_final_inference = not args.skip_final_inference and run_validation\n        if should_run_final_inference:\n            # Final inference\n            # Load previous pipeline\n            pipeline = HiDreamImagePipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                tokenizer=None,\n                text_encoder=None,\n                tokenizer_2=None,\n  
              text_encoder_2=None,\n                tokenizer_3=None,\n                text_encoder_3=None,\n                tokenizer_4=None,\n                text_encoder_4=None,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=validation_embeddings,\n                epoch=epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n            del pipeline\n            free_memory()\n\n        validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n        save_model_card(\n            (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=validation_prompt,\n            repo_folder=args.output_dir,\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_lumina2.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, Gemma2Model\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    Lumina2Pipeline,\n    Lumina2Transformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    
free_memory,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    system_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Lumina2 DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Lumina2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_lumina2.md).\n\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\nThe following `system_prompt` was also used during training (ignore if `None`): {system_prompt}.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nTODO\n```\n\nFor more details, 
including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"apache-2.0\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"lumina2\",\n        \"lumina2-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {pipeline_args['prompt']}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        
default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. \"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. 
By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=256,\n        help=\"Maximum sequence length to use with the Gemma2 model\",\n    )\n    parser.add_argument(\n        \"--system_prompt\",\n        type=str,\n        default=None,\n        help=\"System prompt to use during inference to give the Gemma2 model certain characteristics.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--final_validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during a final validation to verify that the model is learning. 
Ignored if `--validation_prompt` is provided.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lumina2-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            
'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for transformer params\")\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated list, e.g. \"to_k,to_q,to_v\" will result in LoRA training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise 
ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. 
If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. 
If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exists.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n        if interpolation is None:\n            raise ValueError(f\"Unsupported interpolation mode: {args.image_interpolation_mode}\")\n\n        train_resize = transforms.Resize(size, interpolation=interpolation)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                
transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=interpolation),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n      
  )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # custom prompts were provided, but length does not match size of image dataset\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class 
images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - 
%(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            pipeline = Lumina2Pipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch.bfloat16 if args.mixed_precision == \"bf16\" else torch.float16,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n               
 for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\", revision=args.revision\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder = Gemma2Model.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = Lumina2Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) 
to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    # keep VAE in FP32 to ensure numerical stability.\n    vae.to(dtype=torch.float32)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    # because Gemma2 is particularly suited for bfloat16.\n    text_encoder.to(dtype=torch.bfloat16)\n\n    # Initialize a text encoding pipeline and keep it to CPU for now.\n    text_encoding_pipeline = Lumina2Pipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=None,\n        transformer=None,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n    )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n\n    # now we will add new LoRA weights the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n   
     return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(unwrap_model(transformer))):\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            Lumina2Pipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(transformer))):\n                transformer_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict = Lumina2Pipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected 
keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = 
\"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)\n        with torch.no_grad():\n            prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(\n                prompt,\n                max_sequence_length=args.max_sequence_length,\n                system_prompt=args.system_prompt,\n            )\n        if args.offload:\n            text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n        prompt_embeds = prompt_embeds.to(transformer.dtype)\n        return prompt_embeds, prompt_attention_mask\n\n    # If no type of tuning is done on the text_encoder and custom instance 
prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings(\n            args.instance_prompt, text_encoding_pipeline\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings(\n            args.class_prompt, text_encoding_pipeline\n        )\n\n    # Clear the memory here\n    if not train_dataset.custom_instance_prompts:\n        del text_encoder, tokenizer\n        free_memory()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        prompt_embeds = instance_prompt_hidden_states\n        prompt_attention_mask = instance_prompt_attention_mask\n        if args.with_prior_preservation:\n            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n            prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)\n\n    vae_config_scaling_factor = vae.config.scaling_factor\n    vae_config_shift_factor = vae.config.shift_factor\n    if args.cache_latents:\n        latents_cache = []\n        vae = vae.to(accelerator.device)\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=vae.dtype\n                )\n                latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n\n    
    if args.validation_prompt is None:\n            del vae\n            free_memory()\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-lumina2-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = 
{len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            prompts = batch[\"prompts\"]\n\n            with accelerator.accumulate(models_to_accumulate):\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = 
latents_cache[step].sample()\n                else:\n                    vae = vae.to(accelerator.device)\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n                    if args.offload:\n                        vae = vae.to(\"cpu\")\n                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                # Lumina2 reverses the lerp i.e., sigma of 1.0 should mean `model_input`\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * noise + sigmas * model_input\n\n                # Predict the noise residual\n                # scale the timesteps (reversal not needed as we used a reverse lerp above already)\n                timesteps = timesteps / noise_scheduler.config.num_train_timesteps\n                model_pred = transformer(\n            
        hidden_states=noisy_model_input,\n                    encoder_hidden_states=prompt_embeds.repeat(len(prompts), 1, 1)\n                    if not train_dataset.custom_instance_prompts\n                    else prompt_embeds,\n                    encoder_attention_mask=prompt_attention_mask.repeat(len(prompts), 1)\n                    if not train_dataset.custom_instance_prompts\n                    else prompt_attention_mask,\n                    timestep=timesteps,\n                    return_dict=False,\n                )[0]\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss (reversed)\n                target = model_input - noise\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n              
      loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in 
removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = Lumina2Pipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt, \"system_prompt\": args.system_prompt}\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                )\n                free_memory()\n\n                images = None\n                del pipeline\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        if args.upcast_before_saving:\n            transformer.to(torch.float32)\n        else:\n            transformer = transformer.to(weight_dtype)\n        
transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        Lumina2Pipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n        )\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = Lumina2Pipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt):\n            prompt_to_use = args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n            args.num_validation_images = args.num_validation_images if args.num_validation_images else 1\n            pipeline_args = {\"prompt\": prompt_to_use, \"system_prompt\": args.system_prompt}\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                instance_prompt=args.instance_prompt,\n                system_prompt=args.system_prompt,\n                validation_prompt=validation_prompt,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                
commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n        del pipeline\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_qwen_image.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom 
transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKLQwenImage,\n    BitsAndBytesConfig,\n    FlowMatchEulerDiscreteScheduler,\n    QwenImagePipeline,\n    QwenImageTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n    offload_models,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Qwen Image DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Qwen Image diffusers 
trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_qwen.md).\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\n    >>> import torch\n    >>> from diffusers import QwenImagePipeline\n\n    >>> pipe = QwenImagePipeline.from_pretrained(\n    ...     \"Qwen/Qwen-Image\",\n    ...     torch_dtype=torch.bfloat16,\n    ... )\n    >>> pipe.enable_model_cpu_offload()\n    >>> pipe.load_lora_weights(f\"{repo_id}\")\n    >>> image = pipe(f\"{instance_prompt}\").images[0]\n\n\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"apache-2.0\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"qwen-image\",\n        \"qwen-image-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    args.num_validation_images = args.num_validation_images if args.num_validation_images else 1\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                prompt_embeds=pipeline_args[\"prompt_embeds\"],\n                prompt_embeds_mask=pipeline_args[\"prompt_embeds_mask\"],\n                generator=generator,\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_tokenizer_4_name_or_path\",\n        type=str,\n        
default=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_text_encoder_4_name_or_path\",\n        type=str,\n        default=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--bnb_quantization_config_path\",\n        type=str,\n        default=None,\n        help=\"Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with the Qwen2.5 VL as text encoder.\",\n    )\n\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n\n    parser.add_argument(\n        \"--skip_final_inference\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.\",\n    )\n\n    parser.add_argument(\n        \"--final_validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"hidream-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            
'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"Coefficient for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for transformer params\")\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string. E.g. - \"to_k,to_q,to_v\" will result in LoRA training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise 
ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data_dir directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. 
If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n 
           else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # custom prompts were provided, but length does not match size of image dataset\n            
example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    # Qwen expects a `num_frames` dimension too.\n    if pixel_values.ndim == 4:\n        pixel_values = pixel_values.unsqueeze(2)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both 
--report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        
set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            pipeline = QwenImagePipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch.bfloat16 if args.mixed_precision == \"bf16\" else torch.float16,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            pipeline.to(\"cpu\")\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            
os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer = Qwen2Tokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\", revision=args.revision, shift=3.0\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae = AutoencoderKLQwenImage.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    vae_scale_factor = 2 ** len(vae.temperal_downsample)\n    latents_mean = (torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)).to(accelerator.device)\n    latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(accelerator.device)\n    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, torch_dtype=weight_dtype\n    )\n    quantization_config = None\n    if args.bnb_quantization_config_path is not None:\n        with open(args.bnb_quantization_config_path, \"r\") as 
f:\n            config_kwargs = json.load(f)\n            if \"load_in_4bit\" in config_kwargs and config_kwargs[\"load_in_4bit\"]:\n                config_kwargs[\"bnb_4bit_compute_dtype\"] = weight_dtype\n        quantization_config = BitsAndBytesConfig(**config_kwargs)\n\n    transformer = QwenImageTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        quantization_config=quantization_config,\n        torch_dtype=weight_dtype,\n    )\n    if args.bnb_quantization_config_path is not None:\n        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    to_kwargs = {\"dtype\": weight_dtype, \"device\": accelerator.device} if not args.offload else {\"dtype\": weight_dtype}\n    # flux vae is stable in bf16 so load it in weight_dtype to reduce memory\n    vae.to(**to_kwargs)\n    text_encoder.to(**to_kwargs)\n    # we never offload the transformer to CPU, so we can just use the accelerator device\n    transformer_to_kwargs = (\n        {\"device\": accelerator.device}\n        if args.bnb_quantization_config_path is not None\n        else {\"device\": accelerator.device, \"dtype\": weight_dtype}\n    )\n    transformer.to(**transformer_to_kwargs)\n\n    # Initialize a text encoding pipeline and keep it on CPU for now.\n    text_encoding_pipeline = QwenImagePipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=None,\n        transformer=None,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        scheduler=None,\n    )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n\n    # now we will add new LoRA weights to the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        
if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n            modules_to_save = {}\n\n            for model in models:\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    model = unwrap_model(model)\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                    modules_to_save[\"transformer\"] = model\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                if weights:\n                    weights.pop()\n\n            QwenImagePipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    model = unwrap_model(model)\n                    transformer_ = model\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n        else:\n            transformer_ = QwenImageTransformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path, subfolder=\"transformer\"\n            )\n            transformer_.add_adapter(transformer_lora_config)\n\n        lora_state_dict = QwenImagePipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = 
convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    
transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        with torch.no_grad():\n            prompt_embeds, prompt_embeds_mask = text_encoding_pipeline.encode_prompt(\n                prompt=prompt, max_sequence_length=args.max_sequence_length\n            )\n        return prompt_embeds, prompt_embeds_mask\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. 
the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            instance_prompt_embeds, instance_prompt_embeds_mask = compute_text_embeddings(\n                args.instance_prompt, text_encoding_pipeline\n            )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            class_prompt_embeds, class_prompt_embeds_mask = compute_text_embeddings(\n                args.class_prompt, text_encoding_pipeline\n            )\n\n    validation_embeddings = {}\n    if args.validation_prompt is not None:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            (validation_embeddings[\"prompt_embeds\"], validation_embeddings[\"prompt_embeds_mask\"]) = (\n                compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)\n            )\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        prompt_embeds = instance_prompt_embeds\n        prompt_embeds_mask = instance_prompt_embeds_mask\n        if args.with_prior_preservation:\n            prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)\n            prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)\n\n    # if cache_latents is set to True, we encode images to latents and store them.\n    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided\n    # we encode them in advance as well.\n    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts\n    if precompute_latents:\n        prompt_embeds_cache = []\n        prompt_embeds_mask_cache = []\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                if args.cache_latents:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                            accelerator.device, non_blocking=True, dtype=vae.dtype\n                        )\n                        latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n                if train_dataset.custom_instance_prompts:\n                    with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                        prompt_embeds, prompt_embeds_mask = compute_text_embeddings(\n                            batch[\"prompts\"], text_encoding_pipeline\n                        )\n                    prompt_embeds_cache.append(prompt_embeds)\n                    prompt_embeds_mask_cache.append(prompt_embeds_mask)\n\n    # move back to cpu before deleting to ensure memory is freed see: 
https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if args.cache_latents:\n        vae = vae.to(\"cpu\")\n        del vae\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n    del text_encoder, tokenizer\n    free_memory()\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        
tracker_name = \"dreambooth-qwen-image-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            prompts = batch[\"prompts\"]\n\n            with accelerator.accumulate(models_to_accumulate):\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    prompt_embeds = prompt_embeds_cache[step]\n                    prompt_embeds_mask = prompt_embeds_mask_cache[step]\n                else:\n                    num_repeat_elements = len(prompts)\n                    
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)\n                    if prompt_embeds_mask is not None:\n                        prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].sample()\n                else:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        # Encode inside the context so the VAE is on the accelerator device when offloading\n                        pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                        model_input = vae.encode(pixel_values).latent_dist.sample()\n\n                model_input = (model_input - latents_mean) * latents_std\n                model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                # Predict the noise residual\n                img_shapes = [\n                   
 (1, args.resolution // vae_scale_factor // 2, args.resolution // vae_scale_factor // 2)\n                ] * bsz\n                # transpose the dimensions\n                noisy_model_input = noisy_model_input.permute(0, 2, 1, 3, 4)\n                packed_noisy_model_input = QwenImagePipeline._pack_latents(\n                    noisy_model_input,\n                    batch_size=model_input.shape[0],\n                    num_channels_latents=model_input.shape[1],\n                    height=model_input.shape[3],\n                    width=model_input.shape[4],\n                )\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    encoder_hidden_states_mask=prompt_embeds_mask,\n                    timestep=timesteps / 1000,\n                    img_shapes=img_shapes,\n                    return_dict=False,\n                )[0]\n                model_pred = QwenImagePipeline._unpack_latents(\n                    model_pred, args.resolution, args.resolution, vae_scale_factor\n                )\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                target = noise - model_input\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            
target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if 
len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = QwenImagePipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    tokenizer=None,\n                    text_encoder=None,\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                images = log_validation(\n                    
pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_embeddings,\n                    torch_dtype=weight_dtype,\n                    epoch=epoch,\n                )\n                del pipeline\n                images = None\n                free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        transformer = unwrap_model(transformer)\n        if args.bnb_quantization_config_path is None:\n            if args.upcast_before_saving:\n                transformer.to(torch.float32)\n            else:\n                transformer = transformer.to(weight_dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n        modules_to_save[\"transformer\"] = transformer\n\n        QwenImagePipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        images = []\n        run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)\n        should_run_final_inference = not args.skip_final_inference and run_validation\n        if should_run_final_inference:\n            # Final inference\n            # Load previous pipeline\n            pipeline = QwenImagePipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                tokenizer=None,\n                text_encoder=None,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n    
            accelerator=accelerator,\n                pipeline_args=validation_embeddings,\n                epoch=epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n            del pipeline\n            free_memory()\n\n        validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n        save_model_card(\n            (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=validation_prompt,\n            repo_folder=args.output_dir,\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_sana.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=1.0.0\",\n#     \"transformers>=4.47.0\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.14.0\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, 
Gemma2Model\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderDC,\n    FlowMatchEulerDiscreteScheduler,\n    SanaPipeline,\n    SanaTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Sana DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Sana diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md).\n\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger 
the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nTODO\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nTODO\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"sana\",\n        \"sana-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    if args.enable_vae_tiling:\n        pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)\n\n    pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n\n    images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    
parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. \"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. 
By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=300,\n        help=\"Maximum sequence length to use with the Gemma model\",\n    )\n    parser.add_argument(\n        \"--complex_human_instruction\",\n        type=str,\n        default=None,\n        help=\"Instructions for complex human attention: https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        
\"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sana-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            
'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"Coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the square root of beta2. Ignored if optimizer is AdamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Specify the layers as a comma-separated string. E.g. \"to_k,to_q,to_v\" will result in LoRA training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_vae_tiling\", action=\"store_true\", help=\"Enable VAE tiling in log validation\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = 
int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. 
If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. 
If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == 
\"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n  
      example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # custom prompts were provided, but length does not match size of image dataset\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return 
self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if 
accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            pipeline = SanaPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch.float32,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)\n            pipeline.transformer = pipeline.transformer.to(torch.float16)\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\", revision=args.revision\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder = Gemma2Model.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae = AutoencoderDC.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = SanaTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for 
inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    # VAE should always be kept in fp32 for SANA (?)\n    vae.to(dtype=torch.float32)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    # Keep the text encoder in bfloat16 because Gemma2 is particularly suited for it.\n    text_encoder.to(dtype=torch.bfloat16)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            for block in transformer.transformer_blocks:\n                block.attn2.set_use_npu_flash_attention(True)\n        else:\n            raise ValueError(\"NPU flash attention requires torch_npu extensions and is supported only on NPU devices\")\n\n    # Initialize a text encoding pipeline and keep it on CPU for now.\n    text_encoding_pipeline = SanaPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=None,\n        transformer=None,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n    )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\"]\n\n    # Now add new LoRA weights to the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        
lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n            modules_to_save = {}\n            for model in models:\n                if isinstance(model, type(unwrap_model(transformer))):\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                    modules_to_save[\"transformer\"] = model\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            SanaPipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(transformer))):\n                transformer_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict = SanaPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = 
convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    
transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)\n        with torch.no_grad():\n            prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(\n                prompt,\n                max_sequence_length=args.max_sequence_length,\n                complex_human_instruction=args.complex_human_instruction,\n            )\n        if args.offload:\n            text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n        prompt_embeds = prompt_embeds.to(transformer.dtype)\n        return prompt_embeds, prompt_attention_mask\n\n    # If no type of tuning is done on the 
text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_prompt_attention_mask = compute_text_embeddings(\n            args.instance_prompt, text_encoding_pipeline\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        class_prompt_hidden_states, class_prompt_attention_mask = compute_text_embeddings(\n            args.class_prompt, text_encoding_pipeline\n        )\n\n    # Clear the memory here\n    if not train_dataset.custom_instance_prompts:\n        del text_encoder, tokenizer\n        free_memory()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        prompt_embeds = instance_prompt_hidden_states\n        prompt_attention_mask = instance_prompt_attention_mask\n        if args.with_prior_preservation:\n            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n            prompt_attention_mask = torch.cat([prompt_attention_mask, class_prompt_attention_mask], dim=0)\n\n    vae_config_scaling_factor = vae.config.scaling_factor\n    if args.cache_latents:\n        latents_cache = []\n        vae = vae.to(accelerator.device)\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=vae.dtype\n                )\n                latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent)\n\n        if 
args.validation_prompt is None:\n            del vae\n            free_memory()\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-sana-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = 
{len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            with accelerator.accumulate(models_to_accumulate):\n                prompts = batch[\"prompts\"]\n\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    prompt_embeds, prompt_attention_mask = compute_text_embeddings(prompts, text_encoding_pipeline)\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = 
latents_cache[step]\n                else:\n                    vae = vae.to(accelerator.device)\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent\n                    if args.offload:\n                        vae = vae.to(\"cpu\")\n                model_input = model_input * vae_config_scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=noisy_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    encoder_attention_mask=prompt_attention_mask,\n                    timestep=timesteps,\n                    return_dict=False,\n                )[0]\n\n                # these weighting schemes use a 
uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n           
     global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            
accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = SanaPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=torch.float32,\n                )\n                pipeline_args = {\n                    \"prompt\": args.validation_prompt,\n                    \"complex_human_instruction\": args.complex_human_instruction,\n                }\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                )\n                free_memory()\n\n                images = None\n                del pipeline\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        modules_to_save = {}\n        if args.upcast_before_saving:\n            transformer.to(torch.float32)\n        else:\n            transformer = transformer.to(weight_dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n        modules_to_save[\"transformer\"] = transformer\n\n        SanaPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = SanaPipeline.from_pretrained(\n            
args.pretrained_model_name_or_path,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=torch.float32,\n        )\n        pipeline.transformer = pipeline.transformer.to(torch.float16)\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\n                \"prompt\": args.validation_prompt,\n                \"complex_human_instruction\": args.complex_human_instruction,\n            }\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n        del pipeline\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_sd3.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    SD3Transformer2DModel,\n    StableDiffusion3Pipeline,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _set_state_dict_into_text_encoder,\n    cast_training_params,\n    
compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    if \"large\" in base_model:\n        model_variant = \"SD3.5-Large\"\n        license_url = \"https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md\"\n        variant_tags = [\"sd3.5-large\", \"sd3.5\", \"sd3.5-diffusers\"]\n    else:\n        model_variant = \"SD3\"\n        license_url = \"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md\"\n        variant_tags = [\"sd3\", \"sd3-diffusers\"]\n\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# {model_variant} DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).\n\nWas LoRA for the text 
encoder enabled? {train_text_encoder}.\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\n### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke\n\n- **LoRA**: download **[`diffusers_lora_weights.safetensors` here 💾](/{repo_id}/blob/main/diffusers_lora_weights.safetensors)**.\n    - Rename it and place it on your `models/Lora` folder.\n    - On AUTOMATIC1111, load the LoRA by adding `<lora:your_new_name:1>` to your prompt. 
On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nPlease adhere to the licensing terms as described [here]({license_url}).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"template:sd-lora\",\n    ]\n\n    tags += variant_tags\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef load_text_encoders(class_one, class_two, class_three):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = class_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_three = class_three.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_3\", revision=args.revision, variant=args.variant\n    )\n    return text_encoder_one, text_encoder_two, text_encoder_three\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n    autocast_ctx = nullcontext()\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise 
ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=77,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd3-dreambooth\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder (clip text encoders only). 
If set, the text encoder should be float32 precision.\",\n    )\n\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"logit_normal\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\"],\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--precondition_outputs\",\n        type=int,\n        default=1,\n        help=\"Flag indicating if we are preconditioning the model outputs or not as done in EDM. 
This affects how \"\n        \"model `target` is calculated.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            \"The transformer block layers to apply LoRA training on. 
Please specify the layers in a comma separated string. \"\n            \"For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md\"\n        ),\n    )\n    parser.add_argument(\n        \"--lora_blocks\",\n        type=str,\n        default=None,\n        help=(\n            \"The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma separated manner. \"\n            'E.g. - \"--lora_blocks 12,30\" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
\"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. 
Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt 
= instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root does not exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = 
transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else 
transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so fall back to the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": 
pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=77,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\ndef _encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            add_special_tokens=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n\n    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = 
prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    text_input_ids=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n    pooled_prompt_embeds = prompt_embeds[0]\n    prompt_embeds = prompt_embeds.hidden_states[-2]\n    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n    text_input_ids_list=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n\n    clip_tokenizers = tokenizers[:2]\n    clip_text_encoders = text_encoders[:2]\n\n    clip_prompt_embeds_list = []\n    clip_pooled_prompt_embeds_list = []\n    for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):\n        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(\n          
  text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            prompt=prompt,\n            device=device if device is not None else text_encoder.device,\n            num_images_per_prompt=num_images_per_prompt,\n            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,\n        )\n        clip_prompt_embeds_list.append(prompt_embeds)\n        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)\n\n    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)\n\n    t5_prompt_embed = _encode_prompt_with_t5(\n        text_encoders[-1],\n        tokenizers[-1],\n        max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,\n        device=device if device is not None else text_encoders[-1].device,\n    )\n\n    clip_prompt_embeds = torch.nn.functional.pad(\n        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])\n    )\n    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)\n\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = StableDiffusion3Pipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is 
not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n    tokenizer_three = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_3\",\n        revision=args.revision,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n    text_encoder_cls_three = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_3\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(\n        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = SD3Transformer2DModel.from_pretrained(\n        
args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    text_encoder_three.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=torch.float32)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_three.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n            text_encoder_two.gradient_checkpointing_enable()\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\n            \"attn.add_k_proj\",\n            \"attn.add_q_proj\",\n            \"attn.add_v_proj\",\n            \"attn.to_add_out\",\n            \"attn.to_k\",\n            \"attn.to_out.0\",\n            \"attn.to_q\",\n            \"attn.to_v\",\n        ]\n    if args.lora_blocks is not None:\n        target_blocks = [int(block.strip()) for block in args.lora_blocks.split(\",\")]\n        target_modules = [\n            f\"transformer_blocks.{block}.{module}\" for block in target_blocks for module in target_modules\n        ]\n\n    # now we will add new LoRA weights to the attention layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        lora_dropout=args.lora_dropout,\n        init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.rank,\n            lora_dropout=args.lora_dropout,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        
text_encoder_one.add_adapter(text_lora_config)\n        text_encoder_two.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            text_encoder_two_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    model = unwrap_model(model)\n                    if args.upcast_before_saving:\n                        model = model.to(torch.float32)\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                elif args.train_text_encoder and isinstance(\n                    unwrap_model(model), type(unwrap_model(text_encoder_one))\n                ):  # or text_encoder_two\n                    # both text encoders are of the same class, so we check hidden size to distinguish between the two\n                    model = unwrap_model(model)\n                    hidden_size = model.config.hidden_size\n                    if hidden_size == 768:\n                        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)\n                    elif hidden_size == 1280:\n                        text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                if weights:\n                    weights.pop()\n\n            
StableDiffusion3Pipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n        text_encoder_one_ = None\n        text_encoder_two_ = None\n\n        if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    transformer_ = unwrap_model(model)\n                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):\n                    text_encoder_one_ = unwrap_model(model)\n                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):\n                    text_encoder_two_ = unwrap_model(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        else:\n            transformer_ = SD3Transformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path, subfolder=\"transformer\"\n            )\n            transformer_.add_adapter(transformer_lora_config)\n            if args.train_text_encoder:\n                text_encoder_one_ = text_encoder_cls_one.from_pretrained(\n                    args.pretrained_model_name_or_path, subfolder=\"text_encoder\"\n                )\n                text_encoder_two_ = text_encoder_cls_two.from_pretrained(\n                    args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\"\n                )\n\n        lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v 
in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n        if args.train_text_encoder:\n            # Do we need to call `scale_lora_layers()` here?\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n            _set_state_dict_into_text_encoder(\n                lora_state_dict, prefix=\"text_encoder_2.\", text_encoder=text_encoder_two_\n            )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_, text_encoder_two_])\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one, text_encoder_two])\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n    if args.train_text_encoder:\n        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    if args.train_text_encoder:\n        # different learning rate for text encoder and unet\n        text_lora_parameters_one_with_lr = {\n            \"params\": 
text_lora_parameters_one,\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        text_lora_parameters_two_with_lr = {\n            \"params\": text_lora_parameters_two,\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        params_to_optimize = [\n            transformer_parameters_with_lr,\n            text_lora_parameters_one_with_lr,\n            text_lora_parameters_two_with_lr,\n        ]\n    else:\n        params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. 
\"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n            params_to_optimize[2][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]\n        text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]\n\n        def compute_text_embeddings(prompt, text_encoders, tokenizers):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                    text_encoders, tokenizers, 
prompt, args.max_sequence_length\n                )\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if not args.train_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers\n            )\n\n    # Clear the memory here\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection\n        del tokenizers, text_encoders\n        del text_encoder_one, text_encoder_two, text_encoder_three\n        free_memory()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n\n    if not train_dataset.custom_instance_prompts:\n        if not args.train_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            pooled_prompt_embeds = instance_pooled_prompt_embeds\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)\n            # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the\n        # batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)\n            tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)\n            tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)\n                class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)\n                class_tokens_three = tokenize_prompt(tokenizer_three, args.class_prompt)\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)\n                tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)\n\n    vae_config_shift_factor = vae.config.shift_factor\n    vae_config_scaling_factor = vae.config.scaling_factor\n    if args.cache_latents:\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=weight_dtype\n                )\n     
           latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n\n        if args.validation_prompt is None:\n            del vae\n            free_memory()\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        (\n            transformer,\n            text_encoder_one,\n            text_encoder_two,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        ) = accelerator.prepare(\n            transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler\n        )\n        assert text_encoder_one is not None\n        assert text_encoder_two is not None\n        assert text_encoder_three is not None\n    else:\n        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            transformer, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-sd3-lora\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n            text_encoder_two.train()\n\n            # set top parameter requires_grad = True so that gradient checkpointing works\n            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)\n            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            if args.train_text_encoder:\n                models_to_accumulate.extend([text_encoder_one, 
text_encoder_two])\n            with accelerator.accumulate(models_to_accumulate):\n                prompts = batch[\"prompts\"]\n\n                # encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    if not args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts)\n                        tokens_two = tokenize_prompt(tokenizer_two, prompts)\n                        tokens_three = tokenize_prompt(tokenizer_three, prompts)\n                        prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                            text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],\n                            tokenizers=[None, None, None],\n                            prompt=prompts,\n                            max_sequence_length=args.max_sequence_length,\n                            text_input_ids_list=[tokens_one, tokens_two, tokens_three],\n                        )\n                else:\n                    if args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                            text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],\n                            tokenizers=[None, None, tokenizer_three],\n                            prompt=args.instance_prompt,\n                            max_sequence_length=args.max_sequence_length,\n                            text_input_ids_list=[tokens_one, tokens_two, tokens_three],\n                        )\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].sample()\n                
else:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n\n                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=noisy_model_input,\n                    timestep=timesteps,\n                    encoder_hidden_states=prompt_embeds,\n                    pooled_projections=pooled_prompt_embeds,\n                    return_dict=False,\n                )[0]\n\n                # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                # Preconditioning of the model outputs.\n                if args.precondition_outputs:\n     
               model_pred = model_pred * (-sigmas) + noisy_model_input\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                if args.precondition_outputs:\n                    target = model_input\n                else:\n                    target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(\n                            transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two\n                      
  )\n                        if args.train_text_encoder\n                        else transformer_lora_parameters\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in 
removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                if not args.train_text_encoder:\n                    # create pipeline\n                    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(\n                        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three\n                    )\n                    text_encoder_one.to(weight_dtype)\n                    text_encoder_two.to(weight_dtype)\n                pipeline = StableDiffusion3Pipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=accelerator.unwrap_model(text_encoder_one),\n                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),\n                    text_encoder_3=accelerator.unwrap_model(text_encoder_three),\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n                images = log_validation(\n              
      pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n                if not args.train_text_encoder:\n                    del text_encoder_one, text_encoder_two, text_encoder_three\n                    free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        if args.upcast_before_saving:\n            transformer.to(torch.float32)\n        else:\n            transformer = transformer.to(weight_dtype)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))\n            text_encoder_two = unwrap_model(text_encoder_two)\n            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))\n        else:\n            text_encoder_lora_layers = None\n            text_encoder_2_lora_layers = None\n\n        StableDiffusion3Pipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n            text_encoder_2_lora_layers=text_encoder_2_lora_layers,\n        )\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = StableDiffusion3Pipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n      
  images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt}\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                train_text_encoder=args.train_text_encoder,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, hf_hub_download, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom safetensors.torch import load_file, save_file\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DPMSolverMultistepScheduler,\n    EDMEulerScheduler,\n    EulerDiscreteScheduler,\n    StableDiffusionXLPipeline,\n    
UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr\nfrom diffusers.utils import (\n    check_min_version,\n    convert_all_state_dict_to_peft,\n    convert_state_dict_to_diffusers,\n    convert_state_dict_to_kohya,\n    convert_unet_state_dict_to_peft,\n    is_peft_version,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef determine_scheduler_type(pretrained_model_name_or_path, revision):\n    model_index_filename = \"model_index.json\"\n    if os.path.isdir(pretrained_model_name_or_path):\n        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)\n    else:\n        model_index = hf_hub_download(\n            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision\n        )\n\n    with open(model_index, \"r\") as f:\n        scheduler_type = json.load(f)[\"scheduler\"][1]\n    return scheduler_type\n\n\ndef save_model_card(\n    repo_id: str,\n    use_dora: bool,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n    vae_path=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" 
\", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# {\"SDXL\" if \"playground\" not in base_model else \"Playground\"} LoRA DreamBooth - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} LoRA adaptation weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/).\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\nSpecial VAE used for training: {vae_path}.\n\n## Trigger words\n\nYou should use {instance_prompt} to trigger the image generation.\n\n## Download model\n\nWeights for this model are available in Safetensors format.\n\n[Download]({repo_id}/tree/main) them in the Files & versions tab.\n\n\"\"\"\n    if \"playground\" in base_model:\n        model_description += \"\"\"\\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"openrail++\" if \"playground\" not in base_model else \"playground-v2dot5-community\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\" if not use_dora else \"dora\",\n        \"template:sd-lora\",\n    ]\n    if \"playground\" in base_model:\n        tags.extend([\"playground\", \"playground-diffusers\"])\n    else:\n        tags.extend([\"stable-diffusion-xl\", \"stable-diffusion-xl-diffusers\"])\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    
torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n\n    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if not args.do_edm_style_training:\n        if \"variance_type\" in pipeline.scheduler.config:\n            variance_type = pipeline.scheduler.config.variance_type\n\n            if variance_type in [\"learned\", \"learned_range\"]:\n                variance_type = \"fixed_small\"\n\n            scheduler_args[\"variance_type\"] = variance_type\n\n        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better\n    # way to condition it. 
Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051\n    if torch.backends.mps.is_available() or \"playground\" in args.pretrained_model_name_or_path:\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = 
argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--do_edm_style_training\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to conduct training using the EDM formulation as introduced in https://huggingface.co/papers/2206.00364.\",\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lora-dreambooth-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--output_kohya_format\",\n        action=\"store_true\",\n        help=\"Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. 
\"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. 
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--use_dora\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://huggingface.co/papers/2402.09353. \"\n            \"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        
args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. 
If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        # image processing to prepare for using SD-XL micro-conditioning\n        self.original_sizes = []\n        self.crop_top_lefts = []\n        self.pixel_values = []\n\n        interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n        if interpolation is None:\n            raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode}.\")\n        train_resize = transforms.Resize(size, interpolation=interpolation)\n\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            self.original_sizes.append((image.height, 
image.width))\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            crop_top_left = (y1, x1)\n            self.crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=interpolation),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % 
self.num_instance_images]\n        original_size = self.original_sizes[index % self.num_instance_images]\n        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n        example[\"original_size\"] = original_size\n        example[\"crop_top_left\"] = crop_top_left\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so fall back to the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n    original_sizes = [example[\"original_size\"] for example in examples]\n    crop_top_lefts = [example[\"crop_top_left\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n        original_sizes += [example[\"original_size\"] for example in 
examples]\n        crop_top_lefts += [example[\"crop_top_left\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\n        \"pixel_values\": pixel_values,\n        \"prompts\": prompts,\n        \"original_sizes\": original_sizes,\n        \"crop_top_lefts\": crop_top_lefts,\n    }\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):\n    prompt_embeds_list = []\n\n    for i, text_encoder in enumerate(text_encoders):\n        if tokenizers is not None:\n            tokenizer = tokenizers[i]\n            text_input_ids = tokenize_prompt(tokenizer, prompt)\n        else:\n            assert text_input_ids_list is not None\n            text_input_ids = text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False\n        )\n\n        # We are only ALWAYS interested in the pooled output of the final text encoder\n        pooled_prompt_embeds = prompt_embeds[0]\n      
  prompt_embeds = prompt_embeds[-1][-2]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if args.do_edm_style_training and args.snr_gamma is not None:\n        raise ValueError(\"Min-SNR formulation is not supported when conducting EDM-style training.\")\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = StableDiffusionXLPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if 
accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)\n    if \"EDM\" in scheduler_type:\n        args.do_edm_style_training = True\n        noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n        logger.info(\"Performing EDM-style training!\")\n    elif args.do_edm_style_training:\n        noise_scheduler = EulerDiscreteScheduler.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n        )\n        logger.info(\"Performing EDM-style training!\")\n    else:\n        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, 
subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    latents_mean = latents_std = None\n    if hasattr(vae.config, \"latents_mean\") and vae.config.latents_mean is not None:\n        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)\n    if hasattr(vae.config, \"latents_std\") and vae.config.latents_std is not None:\n        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet 
support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    # The VAE is always in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, \"\n                    \"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n            text_encoder_two.gradient_checkpointing_enable()\n\n    def get_lora_config(rank, dropout, use_dora, target_modules):\n        base_config = {\n            \"r\": rank,\n            \"lora_alpha\": rank,\n            \"lora_dropout\": dropout,\n            \"init_lora_weights\": \"gaussian\",\n            \"target_modules\": target_modules,\n        }\n        if use_dora:\n            if is_peft_version(\"<\", \"0.9.0\"):\n                raise ValueError(\n                    \"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`.\"\n                )\n            else:\n                base_config[\"use_dora\"] = True\n\n        return LoraConfig(**base_config)\n\n    # now we will add new LoRA weights to the attention layers\n    unet_target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n    unet_lora_config = get_lora_config(\n        rank=args.rank,\n        dropout=args.lora_dropout,\n        use_dora=args.use_dora,\n        target_modules=unet_target_modules,\n    )\n    unet.add_adapter(unet_lora_config)\n\n    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.\n    # So, instead, we monkey-patch the forward calls of its attention-blocks.\n    if args.train_text_encoder:\n        text_target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n        text_lora_config = get_lora_config(\n            rank=args.rank,\n            dropout=args.lora_dropout,\n            use_dora=args.use_dora,\n            target_modules=text_target_modules,\n        )\n        text_encoder_one.add_adapter(text_lora_config)\n        text_encoder_two.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = 
accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here: either there are just the unet attn processor layers,\n            # or there are the unet and text encoder attn layers\n            unet_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            text_encoder_two_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionXLPipeline.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = 
None\n        text_encoder_one_ = None\n        text_encoder_two_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(unet))):\n                unet_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                text_encoder_one_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                text_encoder_two_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        if args.train_text_encoder:\n            # Do we need to call `scale_lora_layers()` here?\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n            _set_state_dict_into_text_encoder(\n                lora_state_dict, prefix=\"text_encoder_2.\", text_encoder=text_encoder_two_\n            )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [unet_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_, text_encoder_two_])\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [unet]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one, text_encoder_two])\n\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))\n\n    if args.train_text_encoder:\n        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))\n\n    # Optimization parameters\n    unet_lora_parameters_with_lr = {\"params\": unet_lora_parameters, \"lr\": args.learning_rate}\n    if args.train_text_encoder:\n        # different learning rate for text encoder and unet\n        text_lora_parameters_one_with_lr = {\n            \"params\": text_lora_parameters_one,\n            
\"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        text_lora_parameters_two_with_lr = {\n            \"params\": text_lora_parameters_two,\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        params_to_optimize = [\n            unet_lora_parameters_with_lr,\n            text_lora_parameters_one_with_lr,\n            text_lora_parameters_two_with_lr,\n        ]\n    else:\n        params_to_optimize = [unet_lora_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy].\"\n            \"Defaulting to adamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. 
\"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n            params_to_optimize[2][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Computes additional embeddings/ids required by the SDXL UNet.\n    # regular text embeddings (when `train_text_encoder` is not True)\n    # pooled text embeddings\n    # time ids\n\n    def compute_time_ids(original_size, crops_coords_top_left):\n        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        target_size = (args.resolution, args.resolution)\n        add_time_ids = 
list(original_size + crops_coords_top_left + target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n        return add_time_ids\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two]\n        text_encoders = [text_encoder_one, text_encoder_two]\n\n        def compute_text_embeddings(prompt, text_encoders, tokenizers):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if not args.train_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers\n            )\n\n    # Clear the memory here\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        del tokenizers, text_encoders\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n    # If custom instance prompts are NOT provided (i.e. 
the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. This is so that we don't\n    # have to pass them to the dataloader.\n\n    if not train_dataset.custom_instance_prompts:\n        if not args.train_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            unet_add_text_embeds = instance_pooled_prompt_embeds\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the\n        # batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)\n            tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)\n                class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * 
accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. 
\"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = (\n            \"dreambooth-lora-sd-xl\"\n            if \"playground\" not in args.pretrained_model_name_or_path\n            else \"dreambooth-lora-playground\"\n        )\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n\n        
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n            text_encoder_two.train()\n\n            # set requires_grad = True on the top-level embeddings so that gradient checkpointing works\n            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)\n            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)\n\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                prompts = batch[\"prompts\"]\n\n                # encode batch prompts when custom prompts are provided for each image\n                if train_dataset.custom_instance_prompts:\n                    if not args.train_text_encoder:\n                        prompt_embeds, unet_add_text_embeds = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts)\n                        tokens_two = tokenize_prompt(tokenizer_two, prompts)\n\n                # Convert images to latent space\n                model_input = vae.encode(pixel_values).latent_dist.sample()\n\n                if latents_mean is None and latents_std is None:\n                    model_input = model_input * vae.config.scaling_factor\n                    if args.pretrained_vae_model_name_or_path is None:\n                        model_input = model_input.to(weight_dtype)\n                else:\n                    latents_mean = 
latents_mean.to(device=model_input.device, dtype=model_input.dtype)\n                    latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)\n                    model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std\n                    model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                if not args.do_edm_style_training:\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                    )\n                    timesteps = timesteps.long()\n                else:\n                    # in EDM formulation, the model is conditioned on the pre-conditioned noise levels\n                    # instead of discrete timesteps, so here we sample indices to get the noise levels\n                    # from `scheduler.timesteps`\n                    indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))\n                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n                # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.\n                # We then precondition the final model inputs based on these sigmas instead of the timesteps.\n                # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                if args.do_edm_style_training:\n                    sigmas = get_sigmas(timesteps, 
len(noisy_model_input.shape), noisy_model_input.dtype)\n                    if \"EDM\" in scheduler_type:\n                        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)\n                    else:\n                        inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)\n\n                # time ids\n                add_time_ids = torch.cat(\n                    [\n                        compute_time_ids(original_size=s, crops_coords_top_left=c)\n                        for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])\n                    ]\n                )\n\n                # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.\n                if not train_dataset.custom_instance_prompts:\n                    elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz\n                else:\n                    elems_to_repeat_text_embeds = 1\n\n                # Predict the noise residual\n                if not args.train_text_encoder:\n                    unet_added_conditions = {\n                        \"time_ids\": add_time_ids,\n                        \"text_embeds\": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),\n                    }\n                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)\n                    model_pred = unet(\n                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,\n                        timesteps,\n                        prompt_embeds_input,\n                        added_cond_kwargs=unet_added_conditions,\n                        return_dict=False,\n                    )[0]\n                else:\n                    unet_added_conditions = {\"time_ids\": add_time_ids}\n                    prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                        
text_encoders=[text_encoder_one, text_encoder_two],\n                        tokenizers=None,\n                        prompt=None,\n                        text_input_ids_list=[tokens_one, tokens_two],\n                    )\n                    unet_added_conditions.update(\n                        {\"text_embeds\": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}\n                    )\n                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)\n                    model_pred = unet(\n                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,\n                        timesteps,\n                        prompt_embeds_input,\n                        added_cond_kwargs=unet_added_conditions,\n                        return_dict=False,\n                    )[0]\n\n                weighting = None\n                if args.do_edm_style_training:\n                    # Similar to the input preconditioning, the model predictions are also preconditioned\n                    # on noised model inputs (before preconditioning) and the sigmas.\n                    # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                    if \"EDM\" in scheduler_type:\n                        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)\n                    else:\n                        if noise_scheduler.config.prediction_type == \"epsilon\":\n                            model_pred = model_pred * (-sigmas) + noisy_model_input\n                        elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                            model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (\n                                noisy_model_input / (sigmas**2 + 1)\n                            )\n                    # We are not doing weighting here because it tends to result in numerical problems.\n                    # See: 
https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051\n                    # There might be other alternatives for weighting as well:\n                    # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686\n                    if \"EDM\" not in scheduler_type:\n                        weighting = (sigmas**-2.0).float()\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = model_input if args.do_edm_style_training else noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = (\n                        model_input\n                        if args.do_edm_style_training\n                        else noise_scheduler.get_velocity(model_input, noise, timesteps)\n                    )\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    if weighting is not None:\n                        prior_loss = torch.mean(\n                            (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                                target_prior.shape[0], -1\n                            ),\n                            1,\n                        )\n                        prior_loss = prior_loss.mean()\n                    else:\n                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n          
      if args.snr_gamma is None:\n                    if weighting is not None:\n                        loss = torch.mean(\n                            (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(\n                                target.shape[0], -1\n                            ),\n                            1,\n                        )\n                        loss = loss.mean()\n                    else:\n                        loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    base_weight = (\n                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr\n                    )\n\n                    if noise_scheduler.config.prediction_type == \"v_prediction\":\n                        # Velocity objective needs to be floored to an SNR weight of one.\n                        mse_loss_weights = base_weight + 1\n                    else:\n                        # Epsilon and sample both use the same loss weights.\n                        mse_loss_weights = base_weight\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if 
accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)\n                        if args.train_text_encoder\n                        else unet_lora_parameters\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', 
'.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if not args.train_text_encoder:\n                    text_encoder_one = text_encoder_cls_one.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"text_encoder\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                    text_encoder_two = text_encoder_cls_two.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"text_encoder_2\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                pipeline = StableDiffusionXLPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=accelerator.unwrap_model(text_encoder_one),\n                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),\n                    
unet=accelerator.unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n\n                images = log_validation(\n                    pipeline,\n                    args,\n                    accelerator,\n                    pipeline_args,\n                    epoch,\n                    torch_dtype=weight_dtype,\n                )\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        unet = unet.to(torch.float32)\n        unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_lora_layers = convert_state_dict_to_diffusers(\n                get_peft_model_state_dict(text_encoder_one.to(torch.float32))\n            )\n            text_encoder_two = unwrap_model(text_encoder_two)\n            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(\n                get_peft_model_state_dict(text_encoder_two.to(torch.float32))\n            )\n        else:\n            text_encoder_lora_layers = None\n            text_encoder_2_lora_layers = None\n\n        StableDiffusionXLPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_layers,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n            text_encoder_2_lora_layers=text_encoder_2_lora_layers,\n        )\n        if args.output_kohya_format:\n            lora_state_dict = load_file(f\"{args.output_dir}/pytorch_lora_weights.safetensors\")\n            peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)\n            kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)\n         
   save_file(kohya_state_dict, f\"{args.output_dir}/pytorch_lora_weights_kohya.safetensors\")\n\n        # Final inference\n        # Load previous pipeline\n        vae = AutoencoderKL.from_pretrained(\n            vae_path,\n            subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipeline = StableDiffusionXLPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=vae,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt, \"num_inference_steps\": 25}\n            images = log_validation(\n                pipeline,\n                args,\n                accelerator,\n                pipeline_args,\n                epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                use_dora=args.use_dora,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n                vae_path=args.pretrained_vae_model_name_or_path,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                
ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_lora_z_image.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# /// script\n# dependencies = [\n#     \"diffusers @ git+https://github.com/huggingface/diffusers.git\",\n#     \"torch>=2.0.0\",\n#     \"accelerate>=0.31.0\",\n#     \"transformers>=4.41.2\",\n#     \"ftfy\",\n#     \"tensorboard\",\n#     \"Jinja2\",\n#     \"peft>=0.11.1\",\n#     \"sentencepiece\",\n#     \"torchvision\",\n#     \"datasets\",\n#     \"bitsandbytes\",\n#     \"prodigyopt\",\n# ]\n# ///\n\nimport argparse\nimport copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import Any\n\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.sampler import BatchSampler\nfrom torchvision import 
transforms\nfrom torchvision.transforms import functional as TF\nfrom tqdm.auto import tqdm\nfrom transformers import Qwen2Tokenizer, Qwen3Model\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    BitsAndBytesConfig,\n    FlowMatchEulerDiscreteScheduler,\n    ZImagePipeline,\n    ZImageTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    _collate_lora_metadata,\n    _to_cpu_contiguous,\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    find_nearest_bucket,\n    free_memory,\n    get_fsdp_kwargs_from_accelerator,\n    offload_models,\n    parse_buckets_string,\n    wrap_with_fsdp,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif getattr(torch, \"distributed\", None) is not None:\n    import torch.distributed as dist\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n    quant_training=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Z Image DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Z Image diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_z_image.md).\n\nQuant training? 
{quant_training}\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained(\"Tongyi-MAI/Z-Image\", torch_dtype=torch.bfloat16).to('cuda')\npipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## License\n\nApache License 2.0\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"apache-2.0\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"z-image\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    args.num_validation_images = args.num_validation_images if args.num_validation_images else 1\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(dtype=torch_dtype)\n    pipeline.enable_model_cpu_offload()\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n\n    images = []\n    for _ in range(args.num_validation_images):\n        with autocast_ctx:\n            image = pipeline(\n                prompt=args.validation_prompt,\n                prompt_embeds=pipeline_args[\"prompt_embeds\"],\n                generator=generator,\n            ).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef module_filter_fn(mod: torch.nn.Module, fqn: str):\n    # don't convert the output module\n    if fqn == \"proj_out\":\n        return False\n    # don't convert linear modules with weight dimensions not divisible by 16\n    if isinstance(mod, torch.nn.Linear):\n        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:\n            return False\n    return True\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple 
example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--bnb_quantization_config_path\",\n        type=str,\n        default=None,\n        help=\"Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.\",\n    )\n    parser.add_argument(\n        \"--do_fp8_training\",\n        action=\"store_true\",\n        help=\"if we are doing FP8 training.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--skip_final_inference\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.\",\n    )\n    parser.add_argument(\n        \"--final_validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=4,\n        help=\"LoRA alpha to be used for additional scaling.\",\n    )\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Dropout probability for LoRA layers\")\n\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"z-image-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--aspect_ratio_buckets\",\n        type=str,\n        default=None,\n        help=(\n            \"Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. \"\n            \"e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'. \"\n            \"Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"the FLUX.1 dev variant is a guidance distilled model\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. 
Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - \"to_k,to_q,to_v,to_out.0\" will result in lora training of attention layers only'\n        ),\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  
Default to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n    parser.add_argument(\"--fsdp_text_encoder\", action=\"store_true\", help=\"Use FSDP for text encoder\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n    if args.do_fp8_training and args.bnb_quantization_config_path:\n        raise ValueError(\"Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for 
fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n        buckets=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        self.buckets = buckets\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. 
Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        for i, image in enumerate(self.instance_images):\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n\n            width, height = image.size\n\n            # Find the closest bucket\n            bucket_idx = find_nearest_bucket(height, width, self.buckets)\n            target_height, target_width = self.buckets[bucket_idx]\n            self.size = (target_height, target_width)\n\n            # based on the bucket assignment, define the transformations\n            image = self.train_transform(\n                image,\n                size=self.size,\n                center_crop=args.center_crop,\n                random_flip=args.random_flip,\n            )\n            self.pixel_values.append((image, bucket_idx))\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = 
Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n        example[\"bucket_idx\"] = bucket_idx\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # custom prompts were provided, but length does not match size of image dataset\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = 
self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n    def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False):\n        # 1. Resize (deterministic)\n        resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        image = resize(image)\n\n        # 2. Crop: either center or SAME random crop\n        if center_crop:\n            crop = transforms.CenterCrop(size)\n            image = crop(image)\n        else:\n            # get_params returns (i, j, h, w)\n            i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)\n            image = TF.crop(image, i, j, h, w)\n\n        # 3. Random horizontal flip with the SAME coin flip\n        if random_flip:\n            do_flip = random.random() < 0.5\n            if do_flip:\n                image = TF.hflip(image)\n\n        # 4. ToTensor + Normalize (deterministic)\n        to_tensor = transforms.ToTensor()\n        normalize = transforms.Normalize([0.5], [0.5])\n        image = normalize(to_tensor(image))\n\n        return image\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass BucketBatchSampler(BatchSampler):\n    def __init__(self, dataset: DreamBoothDataset, batch_size: int, 
drop_last: bool = False):\n        if not isinstance(batch_size, int) or batch_size <= 0:\n            raise ValueError(\"batch_size should be a positive integer value, but got batch_size={}\".format(batch_size))\n        if not isinstance(drop_last, bool):\n            raise ValueError(\"drop_last should be a boolean value, but got drop_last={}\".format(drop_last))\n\n        self.dataset = dataset\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n\n        # Group indices by bucket\n        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]\n        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):\n            self.bucket_indices[bucket_idx].append(idx)\n\n        self.sampler_len = 0\n        self.batches = []\n\n        # Pre-generate batches for each bucket\n        for indices_in_bucket in self.bucket_indices:\n            # Shuffle indices within the bucket\n            random.shuffle(indices_in_bucket)\n            # Create batches\n            for i in range(0, len(indices_in_bucket), self.batch_size):\n                batch = indices_in_bucket[i : i + self.batch_size]\n                if len(batch) < self.batch_size and self.drop_last:\n                    continue  # Skip partial batch if drop_last is True\n                self.batches.append(batch)\n                self.sampler_len += 1  # Count the number of batches\n\n    def __iter__(self):\n        # Shuffle the order of the batches each epoch\n        random.shuffle(self.batches)\n        for batch in self.batches:\n            yield batch\n\n    def __len__(self):\n        return self.sampler_len\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n       
 example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n    if args.do_fp8_training:\n        from torchao.float8 import Float8LinearConfig, convert_to_float8_training\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, 
main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n\n            pipeline = ZImagePipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            
sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):\n                    images = pipeline(prompt=example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer = Qwen2Tokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"scheduler\",\n 
       revision=args.revision,\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    vae_config_shift_factor = vae.config.shift_factor\n    vae_config_scaling_factor = vae.config.scaling_factor\n\n    quantization_config = None\n    if args.bnb_quantization_config_path is not None:\n        with open(args.bnb_quantization_config_path, \"r\") as f:\n            config_kwargs = json.load(f)\n            if \"load_in_4bit\" in config_kwargs and config_kwargs[\"load_in_4bit\"]:\n                config_kwargs[\"bnb_4bit_compute_dtype\"] = weight_dtype\n        quantization_config = BitsAndBytesConfig(**config_kwargs)\n\n    transformer = ZImageTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        quantization_config=quantization_config,\n        torch_dtype=weight_dtype,\n    )\n    if args.bnb_quantization_config_path is not None:\n        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)\n\n    text_encoder = Qwen3Model.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    text_encoder.requires_grad_(False)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            transformer.set_attention_backend(\"_native_npu\")\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    to_kwargs = {\"dtype\": weight_dtype, \"device\": accelerator.device} if not args.offload else {\"dtype\": weight_dtype}\n    vae.to(**to_kwargs)\n    # we never offload the transformer to CPU, so we can just use the accelerator device\n    transformer_to_kwargs = (\n        {\"device\": accelerator.device}\n        if args.bnb_quantization_config_path is not None\n        else {\"device\": accelerator.device, \"dtype\": weight_dtype}\n    )\n\n    is_fsdp = getattr(accelerator.state, \"fsdp_plugin\", None) is not None\n    if not is_fsdp:\n        transformer.to(**transformer_to_kwargs)\n\n    if args.do_fp8_training:\n        convert_to_float8_training(\n            transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)\n        )\n\n    text_encoder.to(**to_kwargs)\n    # Initialize a text encoding pipeline and keep it to CPU for now.\n    text_encoding_pipeline = ZImagePipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=None,\n        transformer=None,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        scheduler=None,\n        revision=args.revision,\n    )\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    if args.lora_layers is not None:\n        target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n    else:\n        target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n\n    # now we will add new LoRA weights the transformer layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.lora_alpha,\n        lora_dropout=args.lora_dropout,\n     
   init_lora_weights=\"gaussian\",\n        target_modules=target_modules,\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        transformer_cls = type(unwrap_model(transformer))\n\n        # 1) Validate and pick the transformer model\n        modules_to_save: dict[str, Any] = {}\n        transformer_model = None\n\n        for model in models:\n            if isinstance(unwrap_model(model), transformer_cls):\n                transformer_model = model\n                modules_to_save[\"transformer\"] = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        if transformer_model is None:\n            raise ValueError(\"No transformer model found in 'models'\")\n\n        # 2) Optionally gather FSDP state dict once\n        state_dict = accelerator.get_state_dict(model) if is_fsdp else None\n\n        # 3) Only main process materializes the LoRA state dict\n        transformer_lora_layers_to_save = None\n        if accelerator.is_main_process:\n            peft_kwargs = {}\n            if is_fsdp:\n                peft_kwargs[\"state_dict\"] = state_dict\n\n            transformer_lora_layers_to_save = get_peft_model_state_dict(\n                unwrap_model(transformer_model) if is_fsdp else transformer_model,\n                **peft_kwargs,\n            )\n\n            if is_fsdp:\n                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)\n\n            # make sure to pop weight so that corresponding model is not saved again\n            if weights:\n                weights.pop()\n\n            
ZImagePipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                **_collate_lora_metadata(modules_to_save),\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        if not is_fsdp:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    transformer_ = unwrap_model(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n        else:\n            transformer_ = ZImageTransformer2DModel.from_pretrained(\n                args.pretrained_model_name_or_path,\n                subfolder=\"transformer\",\n            )\n            transformer_.add_adapter(transformer_lora_config)\n\n        lora_state_dict = ZImagePipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    if args.aspect_ratio_buckets is not None:\n        buckets = parse_buckets_string(args.aspect_ratio_buckets)\n    else:\n        buckets = [(args.resolution, args.resolution)]\n    logger.info(f\"Using parsed aspect ratio buckets: {buckets}\")\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n        buckets=buckets,\n    )\n    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_sampler=batch_sampler,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def compute_text_embeddings(prompt, text_encoding_pipeline):\n        with torch.no_grad():\n            prompt_embeds, _ = text_encoding_pipeline.encode_prompt(\n                prompt=prompt,\n                max_sequence_length=args.max_sequence_length,\n            )\n        return prompt_embeds\n\n    # If no type of tuning is done on the text_encoder 
and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not train_dataset.custom_instance_prompts:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            instance_prompt_hidden_states = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            class_prompt_hidden_states = compute_text_embeddings(args.class_prompt, text_encoding_pipeline)\n    validation_embeddings = {}\n    if args.validation_prompt is not None:\n        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n            validation_embeddings[\"prompt_embeds\"] = compute_text_embeddings(\n                args.validation_prompt, text_encoding_pipeline\n            )\n\n    # Init FSDP for text encoder\n    if args.fsdp_text_encoder:\n        fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)\n        text_encoder_fsdp = wrap_with_fsdp(\n            model=text_encoding_pipeline.text_encoder,\n            device=accelerator.device,\n            offload=args.offload,\n            limit_all_gathers=True,\n            use_orig_params=True,\n            fsdp_kwargs=fsdp_kwargs,\n        )\n\n        text_encoding_pipeline.text_encoder = text_encoder_fsdp\n        dist.barrier()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n    if not train_dataset.custom_instance_prompts:\n        prompt_embeds = instance_prompt_hidden_states\n        if args.with_prior_preservation:\n            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n\n    # if cache_latents is set to True, we encode images to latents and store them.\n    # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided\n    # we encode them in advance as well.\n    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts\n    if precompute_latents:\n        prompt_embeds_cache = []\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                if args.cache_latents:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                            accelerator.device, non_blocking=True, dtype=vae.dtype\n                        )\n                        latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n                if train_dataset.custom_instance_prompts:\n                    if args.fsdp_text_encoder:\n                        prompt_embeds = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n                    else:\n                        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):\n                            prompt_embeds = compute_text_embeddings(batch[\"prompts\"], text_encoding_pipeline)\n                    prompt_embeds_cache.append(prompt_embeds)\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    if args.cache_latents:\n        vae = vae.to(\"cpu\")\n        del 
vae\n\n    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624\n    text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n    del text_encoder, tokenizer\n    free_memory()\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps:\n            logger.warning(\n                f\"The length 
of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-z-image-lora\"\n        args_cp = vars(args).copy()\n        accelerator.init_trackers(tracker_name, config=args_cp)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            prompts = batch[\"prompts\"]\n\n            with accelerator.accumulate(models_to_accumulate):\n                if train_dataset.custom_instance_prompts:\n                    prompt_embeds = prompt_embeds_cache[step]\n                else:\n                    num_repeat_elements = len(prompts)\n                    prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_repeat_elements)]\n\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].mode()\n                else:\n                    with offload_models(vae, device=accelerator.device, offload=args.offload):\n                        pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.mode()\n\n                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n       
         )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                timestep_normalized = (1000 - timesteps) / 1000\n\n                noisy_model_input_5d = noisy_model_input.unsqueeze(2)  # (B, C, H, W) -> (B, C, 1, H, W)\n                noisy_model_input_list = list(noisy_model_input_5d.unbind(dim=0))  # List of (C, 1, H, W)\n\n                model_pred_list = transformer(\n                    noisy_model_input_list,\n                    timestep_normalized,\n                    prompt_embeds,  # This is a List[torch.Tensor] for Z-Image\n                    return_dict=False,\n                )[0]\n                model_pred = torch.stack(model_pred_list, dim=0)  # (B, C, 1, H, W)\n                model_pred = model_pred.squeeze(2)  # (B, C, H, W)\n                model_pred = -model_pred  # z-Image negates the prediction\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    
prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or is_fsdp:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new 
checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = ZImagePipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    transformer=unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                images = log_validation(\n               
     pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_embeddings,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n\n                del pipeline\n                free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n\n    if is_fsdp:\n        transformer = unwrap_model(transformer)\n        state_dict = accelerator.get_state_dict(transformer)\n    if accelerator.is_main_process:\n        modules_to_save = {}\n        if is_fsdp:\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    state_dict = {\n                        k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n                else:\n                    state_dict = {\n                        k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()\n                    }\n\n            transformer_lora_layers = get_peft_model_state_dict(\n                transformer,\n                state_dict=state_dict,\n            )\n            transformer_lora_layers = {\n                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v\n                for k, v in transformer_lora_layers.items()\n            }\n\n        else:\n            transformer = unwrap_model(transformer)\n            if args.bnb_quantization_config_path is None:\n                if args.upcast_before_saving:\n                    transformer.to(torch.float32)\n                else:\n                    transformer = transformer.to(weight_dtype)\n            transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        modules_to_save[\"transformer\"] = transformer\n\n        ZImagePipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            
transformer_lora_layers=transformer_lora_layers,\n            **_collate_lora_metadata(modules_to_save),\n        )\n\n        images = []\n        run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt)\n        should_run_final_inference = not args.skip_final_inference and run_validation\n        if should_run_final_inference:\n            pipeline = ZImagePipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            images = []\n            if args.validation_prompt and args.num_validation_images > 0:\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=validation_embeddings,\n                    epoch=epoch,\n                    is_final_validation=True,\n                    torch_dtype=weight_dtype,\n                )\n            images = None\n            del pipeline\n            free_memory()\n\n        validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt\n        quant_training = None\n        if args.do_fp8_training:\n            quant_training = \"FP8 TorchAO\"\n        elif args.bnb_quantization_config_path:\n            quant_training = \"BitsandBytes\"\n        save_model_card(\n            (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id,\n            images=images,\n            base_model=args.pretrained_model_name_or_path,\n            instance_prompt=args.instance_prompt,\n            validation_prompt=validation_prompt,\n            repo_folder=args.output_dir,\n            
quant_training=quant_training,\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/dreambooth/train_dreambooth_sd3.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig, T5EncoderModel, T5TokenizerFast\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    SD3Transformer2DModel,\n    StableDiffusion3Pipeline,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory\nfrom diffusers.utils import (\n    check_min_version,\n    
is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    if \"large\" in base_model:\n        model_variant = \"SD3.5-Large\"\n        license_url = \"https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md\"\n        variant_tags = [\"sd3.5-large\", \"sd3.5\", \"sd3.5-diffusers\"]\n    else:\n        model_variant = \"SD3\"\n        license_url = \"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md\"\n        variant_tags = [\"sd3\", \"sd3-diffusers\"]\n\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# {model_variant} DreamBooth - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).\n\nWas the text encoder fine-tuned? 
{train_text_encoder}.\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)\n\n```py\nfrom diffusers import AutoPipelineForText2Image\nimport torch\npipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.float16).to('cuda')\nimage = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]\n```\n\n## License\n\nPlease adhere to the licensing terms as described `[here]({license_url})`.\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"template:sd-lora\",\n    ]\n    tags += variant_tags\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef load_text_encoders(class_one, class_two, class_three):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = class_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_three = class_three.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_3\", revision=args.revision, variant=args.variant\n    )\n    return text_encoder_one, text_encoder_two, text_encoder_three\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        
f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n    autocast_ctx = nullcontext()\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    free_memory()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    
else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=77,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd3-dreambooth\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. 
 If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation 
steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"logit_normal\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\"],\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--precondition_outputs\",\n        type=int,\n        default=1,\n        help=\"Flag indicating if we are preconditioning the model outputs or not as done in EDM. 
This affects how \"\n        \"model `target` is calculated.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise 
ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. 
If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        self.pixel_values = []\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n 
           else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # custom prompts were provided, but length does not match size of image dataset\n            
example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompts = [example[\"instance_prompt\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\"pixel_values\": pixel_values, \"prompts\": prompts}\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=77,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\ndef 
_encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        add_special_tokens=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n\n    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    text_input_ids=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)\n\n    pooled_prompt_embeds = prompt_embeds[0]\n    prompt_embeds = prompt_embeds.hidden_states[-2]\n    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n    _, 
seq_len, _ = prompt_embeds.shape\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n    text_input_ids_list=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n\n    clip_tokenizers = tokenizers[:2]\n    clip_text_encoders = text_encoders[:2]\n\n    clip_prompt_embeds_list = []\n    clip_pooled_prompt_embeds_list = []\n    for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):\n        prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            prompt=prompt,\n            device=device if device is not None else text_encoder.device,\n            num_images_per_prompt=num_images_per_prompt,\n            text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,\n        )\n        clip_prompt_embeds_list.append(prompt_embeds)\n        clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)\n\n    clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)\n\n    t5_prompt_embed = _encode_prompt_with_t5(\n        text_encoders[-1],\n        tokenizers[-1],\n        max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        device=device if device is not None else text_encoders[-1].device,\n    )\n\n    clip_prompt_embeds = torch.nn.functional.pad(\n        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])\n    )\n    prompt_embeds = 
torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)\n\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        
diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()\n            torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = StableDiffusion3Pipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in 
tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            free_memory()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n    tokenizer_three = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_3\",\n        revision=args.revision,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n    text_encoder_cls_three = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, 
subfolder=\"text_encoder_3\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(\n        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = SD3Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    transformer.requires_grad_(True)\n    vae.requires_grad_(False)\n    if args.train_text_encoder:\n        text_encoder_one.requires_grad_(True)\n        text_encoder_two.requires_grad_(True)\n        text_encoder_three.requires_grad_(True)\n    else:\n        text_encoder_one.requires_grad_(False)\n        text_encoder_two.requires_grad_(False)\n        text_encoder_three.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=torch.float32)\n    if not args.train_text_encoder:\n        text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n        text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n        text_encoder_three.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n            text_encoder_two.gradient_checkpointing_enable()\n            text_encoder_three.gradient_checkpointing_enable()\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            for i, model in enumerate(models):\n                if isinstance(unwrap_model(model), SD3Transformer2DModel):\n                    unwrap_model(model).save_pretrained(os.path.join(output_dir, \"transformer\"))\n                elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):\n                    if isinstance(unwrap_model(model), CLIPTextModelWithProjection):\n                        hidden_size = unwrap_model(model).config.hidden_size\n                        if hidden_size == 768:\n                            unwrap_model(model).save_pretrained(os.path.join(output_dir, \"text_encoder\"))\n                        elif hidden_size == 1280:\n                            unwrap_model(model).save_pretrained(os.path.join(output_dir, \"text_encoder_2\"))\n                    else:\n                        unwrap_model(model).save_pretrained(os.path.join(output_dir, 
\"text_encoder_3\"))\n                else:\n                    raise ValueError(f\"Wrong model supplied: {type(model)=}.\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n    def load_model_hook(models, input_dir):\n        for _ in range(len(models)):\n            # pop models so that they are not loaded again\n            model = models.pop()\n\n            # load diffusers style into model\n            if isinstance(unwrap_model(model), SD3Transformer2DModel):\n                load_model = SD3Transformer2DModel.from_pretrained(input_dir, subfolder=\"transformer\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n            elif isinstance(unwrap_model(model), (CLIPTextModelWithProjection, T5EncoderModel)):\n                try:\n                    load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder=\"text_encoder\")\n                    model(**load_model.config)\n                    model.load_state_dict(load_model.state_dict())\n                except Exception:\n                    try:\n                        load_model = CLIPTextModelWithProjection.from_pretrained(input_dir, subfolder=\"text_encoder_2\")\n                        model(**load_model.config)\n                        model.load_state_dict(load_model.state_dict())\n                    except Exception:\n                        try:\n                            load_model = T5EncoderModel.from_pretrained(input_dir, subfolder=\"text_encoder_3\")\n                            model(**load_model.config)\n                            model.load_state_dict(load_model.state_dict())\n                        except Exception:\n                            raise ValueError(f\"Couldn't load the model of type: ({type(model)}).\")\n            else:\n                raise ValueError(f\"Unsupported model found: 
{type(model)=}\")\n\n            del load_model\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer.parameters(), \"lr\": args.learning_rate}\n    if args.train_text_encoder:\n        # different learning rate for text encoder and unet\n        text_parameters_one_with_lr = {\n            \"params\": text_encoder_one.parameters(),\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        text_parameters_two_with_lr = {\n            \"params\": text_encoder_two.parameters(),\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        text_parameters_three_with_lr = {\n            \"params\": text_encoder_three.parameters(),\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        params_to_optimize = [\n            transformer_parameters_with_lr,\n            text_parameters_one_with_lr,\n            text_parameters_two_with_lr,\n            text_parameters_three_with_lr,\n        ]\n    else:\n        params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not 
(args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW.\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the transformer and the text encoder- e.g. 
text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. \"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n            params_to_optimize[2][\"lr\"] = args.learning_rate\n            params_to_optimize[3][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]\n        text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]\n\n        def compute_text_embeddings(prompt, text_encoders, 
tokenizers):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                    text_encoders, tokenizers, prompt, args.max_sequence_length\n                )\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if not args.train_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers\n            )\n\n    # Clear the memory here\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        del tokenizers, text_encoders\n        # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection\n        del text_encoder_one, text_encoder_two, text_encoder_three\n        free_memory()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n\n    if not train_dataset.custom_instance_prompts:\n        if not args.train_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            pooled_prompt_embeds = instance_pooled_prompt_embeds\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the\n        # batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)\n            tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)\n            tokens_three = tokenize_prompt(tokenizer_three, args.instance_prompt)\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)\n                class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)\n                class_tokens_three = tokenize_prompt(tokenizer_three, args.class_prompt)\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)\n                tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        
optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        (\n            transformer,\n            text_encoder_one,\n            text_encoder_two,\n            text_encoder_three,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        ) = accelerator.prepare(\n            transformer,\n            text_encoder_one,\n            text_encoder_two,\n            text_encoder_three,\n            optimizer,\n            train_dataloader,\n            lr_scheduler,\n        )\n    else:\n        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            transformer, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-sd3\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = 
{len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n            text_encoder_two.train()\n            text_encoder_three.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            if args.train_text_encoder:\n                models_to_accumulate.extend([text_encoder_one, text_encoder_two, text_encoder_three])\n            with accelerator.accumulate(models_to_accumulate):\n                pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                prompts = batch[\"prompts\"]\n\n                # 
encode batch prompts when custom prompts are provided for each image -\n                if train_dataset.custom_instance_prompts:\n                    if not args.train_text_encoder:\n                        prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts)\n                        tokens_two = tokenize_prompt(tokenizer_two, prompts)\n                        tokens_three = tokenize_prompt(tokenizer_three, prompts)\n\n                # Convert images to latent space\n                model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * 
model_input + sigmas * noise\n\n                # Predict the noise residual\n                if not args.train_text_encoder:\n                    model_pred = transformer(\n                        hidden_states=noisy_model_input,\n                        timestep=timesteps,\n                        encoder_hidden_states=prompt_embeds,\n                        pooled_projections=pooled_prompt_embeds,\n                        return_dict=False,\n                    )[0]\n                else:\n                    prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                        text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],\n                        tokenizers=None,\n                        prompt=None,\n                        text_input_ids_list=[tokens_one, tokens_two, tokens_three],\n                    )\n                    model_pred = transformer(\n                        hidden_states=noisy_model_input,\n                        timestep=timesteps,\n                        encoder_hidden_states=prompt_embeds,\n                        pooled_projections=pooled_prompt_embeds,\n                        return_dict=False,\n                    )[0]\n\n                # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                # Preconditioning of the model outputs.\n                if args.precondition_outputs:\n                    model_pred = model_pred * (-sigmas) + noisy_model_input\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                if args.precondition_outputs:\n                    target = model_input\n                else:\n                    target = noise - model_input\n\n                if args.with_prior_preservation:\n                    # Chunk the 
noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = torch.mean(\n                        (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(\n                            target_prior.shape[0], -1\n                        ),\n                        1,\n                    )\n                    prior_loss = prior_loss.mean()\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(\n                            transformer.parameters(),\n                            text_encoder_one.parameters(),\n                            text_encoder_two.parameters(),\n                            text_encoder_three.parameters(),\n                        )\n                        if args.train_text_encoder\n                        else transformer.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n  
              progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            
progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if not args.train_text_encoder:\n                    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(\n                        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three\n                    )\n                    text_encoder_one.to(weight_dtype)\n                    text_encoder_two.to(weight_dtype)\n                    text_encoder_three.to(weight_dtype)\n                pipeline = StableDiffusion3Pipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=accelerator.unwrap_model(text_encoder_one),\n                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),\n                    text_encoder_3=accelerator.unwrap_model(text_encoder_three),\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n                if not args.train_text_encoder:\n                    del text_encoder_one, text_encoder_two, text_encoder_three\n                    free_memory()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n  
  if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_two = unwrap_model(text_encoder_two)\n            text_encoder_three = unwrap_model(text_encoder_three)\n            pipeline = StableDiffusion3Pipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                transformer=transformer,\n                text_encoder=text_encoder_one,\n                text_encoder_2=text_encoder_two,\n                text_encoder_3=text_encoder_three,\n            )\n        else:\n            pipeline = StableDiffusion3Pipeline.from_pretrained(\n                args.pretrained_model_name_or_path, transformer=transformer\n            )\n\n        # save the pipeline\n        pipeline.save_pretrained(args.output_dir)\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = StableDiffusion3Pipeline.from_pretrained(\n            args.output_dir,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt}\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                instance_prompt=args.instance_prompt,\n                
validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/flux-control/README.md",
    "content": "# Training Flux Control\n\nThis (experimental) example shows how to train Control LoRAs with [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about Flux Control family, refer to the following resources:\n\n* [Docs](https://github.com/black-forest-labs/flux/blob/main/docs/structural-conditioning.md) by Black Forest Labs\n* Diffusers docs ([1](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#canny-control), [2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#depth-control))\n\nTo incorporate additional condition latents, we expand the input features of Flux.1-Dev from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `x_embedder` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `FluxControlPipeline`.\n\n> [!NOTE]\n> **Gated model**\n>\n> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:\n\n```bash\nhf auth login\n```\n\nThe example command below shows how to launch fine-tuning for pose conditions. 
The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.\n\n```bash\naccelerate launch train_control_lora_flux.py \\\n  --pretrained_model_name_or_path=\"black-forest-labs/FLUX.1-dev\" \\\n  --dataset_name=\"raulc0399/open_pose_controlnet\" \\\n  --output_dir=\"pose-control-lora\" \\\n  --mixed_precision=\"bf16\" \\\n  --train_batch_size=1 \\\n  --rank=64 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --learning_rate=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=5000 \\\n  --validation_image=\"openpose.png\" \\\n  --validation_prompt=\"A couple, 4k photo, highly detailed\" \\\n  --offload \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).\n\nYou need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). Once it is merged, install `diffusers` from the `main` branch instead.\n\nThe training script exposes additional CLI args that might be useful to experiment with:\n\n* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer. \n* `train_norm_layers`: When set, additionally trains the normalization scales. The script takes care of saving and loading them.\n* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify \"all-linear\", LoRA will be attached to all the linear layers.\n\n### Training with DeepSpeed\n\nIt's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. 
To use it, save the following config to a YAML file (feel free to modify as needed):\n\n```yaml\ncompute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  gradient_clipping: 1.0\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n```\n\nAnd then while launching training, pass the config file:\n\n```bash\naccelerate launch --config_file=CONFIG_FILE.yaml ...\n```\n\n### Inference\n\nThe pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:\n\n```bash\npip install controlnet_aux\n```\n\nAnd then we are ready:\n\n```py\nfrom controlnet_aux import OpenposeDetector\nfrom diffusers import FluxControlPipeline\nfrom diffusers.utils import load_image\nfrom PIL import Image\nimport numpy as np\nimport torch \n\npipe = FluxControlPipeline.from_pretrained(\"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16).to(\"cuda\")\npipe.load_lora_weights(\"...\") # change this.\n\nopen_pose = OpenposeDetector.from_pretrained(\"lllyasviel/Annotators\")\n\n# prepare pose condition.\nurl = \"https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg\"\nimage = load_image(url)\nimage = open_pose(image, detect_resolution=512, image_resolution=1024)\nimage = np.array(image)[:, :, ::-1]           \nimage = Image.fromarray(np.uint8(image))\n\nprompt = \"A couple, 4k photo, highly detailed\"\n\ngen_images = pipe(\n  prompt=prompt,\n  control_image=image,\n  num_inference_steps=50,\n  joint_attention_kwargs={\"scale\": 0.9},\n  guidance_scale=25., 
\n).images[0]\ngen_images.save(\"output.png\")\n```\n\n## Full fine-tuning\n\nWe provide a non-LoRA version of the training script `train_control_flux.py`. Here is an example command:\n\n```bash\naccelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \\\n  --pretrained_model_name_or_path=\"black-forest-labs/FLUX.1-dev\" \\\n  --dataset_name=\"raulc0399/open_pose_controlnet\" \\\n  --output_dir=\"pose-control\" \\\n  --mixed_precision=\"bf16\" \\\n  --train_batch_size=2 \\\n  --dataloader_num_workers=4 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --proportion_empty_prompts=0.2 \\\n  --learning_rate=5e-5 \\\n  --adam_weight_decay=1e-4 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"cosine\" \\\n  --lr_warmup_steps=1000 \\\n  --checkpointing_steps=1000 \\\n  --max_train_steps=10000 \\\n  --validation_steps=200 \\\n  --validation_image \"2_pose_1024.jpg\" \"3_pose_1024.jpg\" \\\n  --validation_prompt \"two friends sitting by each other enjoying a day at the park, full hd, cinematic\" \"person enjoying a day at the park, full hd, cinematic\" \\\n  --offload \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\nChange the `validation_image` and `validation_prompt` as needed.\n\nFor inference, this time, we will run:\n\n```py\nfrom controlnet_aux import OpenposeDetector\nfrom diffusers import FluxControlPipeline, FluxTransformer2DModel\nfrom diffusers.utils import load_image\nfrom PIL import Image\nimport numpy as np\nimport torch \n\ntransformer = FluxTransformer2DModel.from_pretrained(\"...\") # change this.\npipe = FluxControlPipeline.from_pretrained(\n  \"black-forest-labs/FLUX.1-dev\",  transformer=transformer, torch_dtype=torch.bfloat16\n).to(\"cuda\")\n\nopen_pose = OpenposeDetector.from_pretrained(\"lllyasviel/Annotators\")\n\n# prepare pose condition.\nurl = \"https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg\"\nimage = load_image(url)\nimage = open_pose(image, 
detect_resolution=512, image_resolution=1024)\nimage = np.array(image)[:, :, ::-1]\nimage = Image.fromarray(np.uint8(image))\n\nprompt = \"A couple, 4k photo, highly detailed\"\n\ngen_images = pipe(\n  prompt=prompt,\n  control_image=image,\n  num_inference_steps=50,\n  guidance_scale=25., \n).images[0]\ngen_images.save(\"output.png\")\n```\n\n## Things to note\n\n* The scripts provided in this directory are experimental and educational. This means we may have to tweak things to get good results on a given condition. We believe this is best done with the community 🤗\n* The scripts are not memory-optimized, but when `--offload` is specified, the VAE and the text encoders are offloaded to CPU while they are not in use.\n* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py), which provides similar functionality."
  },
  {
    "path": "examples/flux-control/requirements.txt",
    "content": "transformers==4.47.0\nwandb\ntorch\ntorchvision\naccelerate==1.2.0\npeft>=0.14.0\n"
  },
  {
    "path": "examples/flux-control/train_control_flux.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nNORM_LAYER_PREFIXES = [\"norm_q\", \"norm_k\", \"norm_added_q\", \"norm_added_k\"]\n\n\ndef encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):\n    pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()\n    pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor\n    return pixel_latents.to(weight_dtype)\n\n\ndef log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):\n    logger.info(\"Running validation... \")\n\n    if not is_final_validation:\n        flux_transformer = accelerator.unwrap_model(flux_transformer)\n        pipeline = FluxControlPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            transformer=flux_transformer,\n            torch_dtype=weight_dtype,\n        )\n    else:\n        transformer = FluxTransformer2DModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)\n        pipeline = FluxControlPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            transformer=transformer,\n            torch_dtype=weight_dtype,\n        )\n\n    pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * 
len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n    if is_final_validation or torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = load_image(validation_image)\n        # maybe need to inference on 1024 to get a good image\n        validation_image = validation_image.resize((args.resolution, args.resolution))\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with autocast_ctx:\n                image = pipeline(\n                    prompt=validation_prompt,\n                    control_image=validation_image,\n                    num_inference_steps=50,\n                    guidance_scale=args.guidance_scale,\n                    generator=generator,\n                    max_sequence_length=512,\n                    height=args.resolution,\n                    width=args.resolution,\n                ).images[0]\n            image = image.resize((args.resolution, args.resolution))\n            images.append(image)\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n                formatted_images = []\n                formatted_images.append(np.asarray(validation_image))\n     
           for image in images:\n                    formatted_images.append(np.asarray(image))\n                formatted_images = np.stack(formatted_images)\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n                formatted_images.append(wandb.Image(validation_image, caption=\"Conditioning\"))\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({tracker_key: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    # clean up once, after every tracker has been handled (not inside the tracker loop).\n    del pipeline\n    free_memory()\n    return image_logs\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i}](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# flux-control-{repo_id}\n\nThese are Control weights trained on {base_model} with a new type of conditioning.\n{img_str}\n\n## 
License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"flux\",\n        \"flux-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"control\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a Flux Control training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-control\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with 
the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the control conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\"--log_dataset_samples\", action=\"store_true\", help=\"Whether to log some dataset samples.\")\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A 
set of paths to the control conditioning images to be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a\"\n            \" single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=1,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running each prompt in\"\n            \" `args.validation_prompt` `args.num_validation_images` times\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"flux_train_control\",\n        help=(\n            \"The `project_name` argument passed to `Accelerator.init_trackers`. For\"\n            \" more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    parser.add_argument(\n        \"--jsonl_for_train\",\n        type=str,\n        default=None,\n        help=\"Path to the jsonl file containing the training data.\",\n    )\n    parser.add_argument(\n        \"--only_target_transformer_blocks\",\n        action=\"store_true\",\n        help=\"Whether to train only the transformer blocks along with the input layer (`x_embedder`).\",\n    )\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=30.0,\n        help=\"The guidance scale used for the transformer.\",\n    )\n\n    parser.add_argument(\n        
\"--upcast_before_saving\",\n        action=\"store_true\",\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. 
Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoders to CPU when they are not used.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.jsonl_for_train is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--jsonl_for_train`\")\n\n    if args.dataset_name is not None and args.jsonl_for_train is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--jsonl_for_train`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and `--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer.\"\n        )\n\n    return args\n\n\ndef get_train_dataset(args, 
accelerator):\n    dataset = None\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    if args.jsonl_for_train is not None:\n        # load from json\n        dataset = load_dataset(\"json\", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)\n        dataset = dataset.flatten_indices()\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    with accelerator.main_process_first():\n        train_dataset = dataset[\"train\"].shuffle(seed=args.seed)\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.select(range(args.max_train_samples))\n    return train_dataset\n\n\ndef prepare_train_dataset(dataset, accelerator):\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [\n            (image.convert(\"RGB\") if not isinstance(image, str) else Image.open(image).convert(\"RGB\"))\n            for image in examples[args.image_column]\n        ]\n        images = [image_transforms(image) for image in images]\n\n        conditioning_images = [\n            (image.convert(\"RGB\") if not isinstance(image, str) else Image.open(image).convert(\"RGB\"))\n            for image in examples[args.conditioning_image_column]\n        ]\n        conditioning_images = [image_transforms(image) for image in conditioning_images]\n        examples[\"pixel_values\"] = images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n\n        is_caption_list = 
isinstance(examples[args.caption_column][0], list)\n        if is_caption_list:\n            examples[\"captions\"] = [max(example, key=len) for example in examples[args.caption_column]]\n        else:\n            examples[\"captions\"] = list(examples[args.caption_column])\n\n        return examples\n\n    with accelerator.main_process_first():\n        dataset = dataset.with_transform(preprocess_train)\n\n    return dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n    captions = [example[\"captions\"] for example in examples]\n    return {\"pixel_values\": pixel_values, \"conditioning_pixel_values\": conditioning_pixel_values, \"captions\": captions}\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_out_dir = Path(args.output_dir, args.logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.\n    if torch.backends.mps.is_available():\n        logger.info(\"MPS is enabled. Disabling AMP.\")\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        # DEBUG, INFO, WARNING, ERROR, CRITICAL\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load models. 
We will load the text encoders later in a pipeline to compute\n    # embeddings.\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)\n    flux_transformer = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    logger.info(\"All models loaded successfully\")\n\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"scheduler\",\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    if not args.only_target_transformer_blocks:\n        flux_transformer.requires_grad_(True)\n    vae.requires_grad_(False)\n\n    # cast down and move to the CPU\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # let's not move the VAE to the GPU yet.\n    vae.to(dtype=torch.float32)  # keep the VAE in float32.\n\n    # enable image inputs\n    with torch.no_grad():\n        initial_input_channels = flux_transformer.config.in_channels\n        new_linear = torch.nn.Linear(\n            flux_transformer.x_embedder.in_features * 2,\n            flux_transformer.x_embedder.out_features,\n            bias=flux_transformer.x_embedder.bias is not None,\n            dtype=flux_transformer.dtype,\n            device=flux_transformer.device,\n        )\n        new_linear.weight.zero_()\n        new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)\n        if flux_transformer.x_embedder.bias is not None:\n            
new_linear.bias.copy_(flux_transformer.x_embedder.bias)\n        flux_transformer.x_embedder = new_linear\n\n    assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)\n    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)\n\n    if args.only_target_transformer_blocks:\n        flux_transformer.x_embedder.requires_grad_(True)\n        for name, module in flux_transformer.named_modules():\n            if \"transformer_blocks\" in name:\n                module.requires_grad_(True)\n            else:\n                module.requires_grad_(False)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                for model in models:\n                    if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):\n                        model = unwrap_model(model)\n                        model.save_pretrained(os.path.join(output_dir, \"transformer\"))\n                    else:\n                        raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    if weights:\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            transformer_ = None\n\n            if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n                while len(models) > 0:\n                    model = models.pop()\n\n                    if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):\n              
          transformer_ = model  # noqa: F841\n                    else:\n                        raise ValueError(f\"unexpected save model: {unwrap_model(model).__class__}\")\n\n            else:\n                transformer_ = FluxTransformer2DModel.from_pretrained(input_dir, subfolder=\"transformer\")  # noqa: F841\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        flux_transformer.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimization parameters\n    optimizer = optimizer_class(\n        flux_transformer.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Prepare dataset and dataloader.\n    train_dataset = get_train_dataset(args, accelerator)\n    train_dataset = prepare_train_dataset(train_dataset, accelerator)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        
collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n    # Prepare everything with our `accelerator`.\n    flux_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        flux_transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length 
({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Create a pipeline for text encoding. 
We will move this pipeline to GPU/CPU as needed.\n    text_encoding_pipeline = FluxControlPipeline.from_pretrained(\n        args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype\n    )\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            logger.info(f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\")\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            logger.info(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    if accelerator.is_main_process and args.report_to == \"wandb\" and args.log_dataset_samples:\n        logger.info(\"Logging some dataset samples.\")\n        formatted_images = []\n        formatted_control_images = []\n        all_prompts = []\n        for i, batch in enumerate(train_dataloader):\n            images = (batch[\"pixel_values\"] + 1) / 2\n            control_images = (batch[\"conditioning_pixel_values\"] + 1) / 2\n            prompts = batch[\"captions\"]\n\n            if len(formatted_images) > 10:\n                break\n\n            for img, control_img, prompt in zip(images, control_images, prompts):\n       
         formatted_images.append(img)\n                formatted_control_images.append(control_img)\n                all_prompts.append(prompt)\n\n        logged_artifacts = []\n        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):\n            logged_artifacts.append(wandb.Image(control_img, caption=\"Conditioning\"))\n            logged_artifacts.append(wandb.Image(img, caption=prompt))\n\n        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == \"wandb\"]\n        wandb_tracker[0].log({\"dataset_samples\": logged_artifacts})\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        flux_transformer.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(flux_transformer):\n                # Convert images to latent space\n                # vae encode\n                pixel_latents = encode_images(batch[\"pixel_values\"], vae.to(accelerator.device), weight_dtype)\n                control_latents = encode_images(\n                    batch[\"conditioning_pixel_values\"], vae.to(accelerator.device), weight_dtype\n                
)\n                if args.offload:\n                    # offload vae to CPU.\n                    vae.cpu()\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                bsz = pixel_latents.shape[0]\n                noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)\n\n                # Add noise according to flow matching.\n                sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)\n                noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise\n                # Concatenate across channels.\n                # Question: Should we concatenate before adding noise?\n                concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)\n\n                # pack the latents.\n                packed_noisy_model_input = FluxControlPipeline._pack_latents(\n                    concatenated_noisy_model_input,\n                    batch_size=bsz,\n                    num_channels_latents=concatenated_noisy_model_input.shape[1],\n                    height=concatenated_noisy_model_input.shape[2],\n                    width=concatenated_noisy_model_input.shape[3],\n                )\n\n                # latent image ids for RoPE.\n                latent_image_ids = FluxControlPipeline._prepare_latent_image_ids(\n                    bsz,\n            
        concatenated_noisy_model_input.shape[2] // 2,\n                    concatenated_noisy_model_input.shape[3] // 2,\n                    accelerator.device,\n                    weight_dtype,\n                )\n\n                # handle guidance\n                if unwrap_model(flux_transformer).config.guidance_embeds:\n                    guidance_vec = torch.full(\n                        (bsz,),\n                        args.guidance_scale,\n                        device=noisy_model_input.device,\n                        dtype=weight_dtype,\n                    )\n                else:\n                    guidance_vec = None\n\n                # text encoding.\n                captions = batch[\"captions\"]\n                text_encoding_pipeline = text_encoding_pipeline.to(\"cuda\")\n                with torch.no_grad():\n                    prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(\n                        captions, prompt_2=None\n                    )\n                # this could be optimized by not having to do any text encoding and just\n                # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`\n                if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:\n                    prompt_embeds.zero_()\n                    pooled_prompt_embeds.zero_()\n                if args.offload:\n                    text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n\n                # Predict.\n                model_pred = flux_transformer(\n                    hidden_states=packed_noisy_model_input,\n                    timestep=timesteps / 1000,\n                    guidance=guidance_vec,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    
return_dict=False,\n                )[0]\n                model_pred = FluxControlPipeline._unpack_latents(\n                    model_pred,\n                    height=noisy_model_input.shape[2] * vae_scale_factor,\n                    width=noisy_model_input.shape[3] * vae_scale_factor,\n                    vae_scale_factor=vae_scale_factor,\n                )\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow-matching loss\n                target = noise - pixel_latents\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n                accelerator.backward(loss)\n\n                if accelerator.sync_gradients:\n                    params_to_clip = flux_transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n    
                        checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            flux_transformer=flux_transformer,\n                            args=args,\n                            accelerator=accelerator,\n                            weight_dtype=weight_dtype,\n                            step=global_step,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), 
\"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        flux_transformer = unwrap_model(flux_transformer)\n        if args.upcast_before_saving:\n            flux_transformer.to(torch.float32)\n        flux_transformer.save_pretrained(args.output_dir)\n\n        del flux_transformer\n        del text_encoding_pipeline\n        del vae\n        free_memory()\n\n        # Run a final round of validation.\n        image_logs = None\n        if args.validation_prompt is not None:\n            image_logs = log_validation(\n                flux_transformer=None,\n                args=args,\n                accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                step=global_step,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\", \"checkpoint-*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/flux-control/train_control_lora_flux.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom diffusers.utils import check_min_version, is_wandb_available, load_image, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif 
is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\nNORM_LAYER_PREFIXES = [\"norm_q\", \"norm_k\", \"norm_added_q\", \"norm_added_k\"]\n\n\ndef encode_images(pixels: torch.Tensor, vae: torch.nn.Module, weight_dtype):\n    pixel_latents = vae.encode(pixels.to(vae.dtype)).latent_dist.sample()\n    pixel_latents = (pixel_latents - vae.config.shift_factor) * vae.config.scaling_factor\n    return pixel_latents.to(weight_dtype)\n\n\ndef log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_final_validation=False):\n    logger.info(\"Running validation... \")\n\n    if not is_final_validation:\n        flux_transformer = accelerator.unwrap_model(flux_transformer)\n        pipeline = FluxControlPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            transformer=flux_transformer,\n            torch_dtype=weight_dtype,\n        )\n    else:\n        transformer = FluxTransformer2DModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"transformer\", torch_dtype=weight_dtype\n        )\n        initial_channels = transformer.config.in_channels\n        pipeline = FluxControlPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            transformer=transformer,\n            torch_dtype=weight_dtype,\n        )\n        pipeline.load_lora_weights(args.output_dir)\n        assert pipeline.transformer.config.in_channels == initial_channels * 2, (\n            f\"{pipeline.transformer.config.in_channels=}\"\n        )\n\n    pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == 
len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n    if is_final_validation or torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type, weight_dtype)\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = load_image(validation_image)\n        # maybe need to inference on 1024 to get a good image\n        validation_image = validation_image.resize((args.resolution, args.resolution))\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with autocast_ctx:\n                image = pipeline(\n                    prompt=validation_prompt,\n                    control_image=validation_image,\n                    num_inference_steps=50,\n                    guidance_scale=args.guidance_scale,\n                    generator=generator,\n                    max_sequence_length=512,\n                    height=args.resolution,\n                    width=args.resolution,\n                ).images[0]\n            image = image.resize((args.resolution, args.resolution))\n            images.append(image)\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    tracker_key 
= \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n                formatted_images = []\n                formatted_images.append(np.asarray(validation_image))\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n                formatted_images = np.stack(formatted_images)\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n                formatted_images.append(wandb.Image(validation_image, caption=\"Conditioning\"))\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({tracker_key: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n        del pipeline\n        free_memory()\n        return image_logs\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, 
\"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i}](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# control-lora-{repo_id}\n\nThese are Control LoRA weights trained on {base_model} with a new type of conditioning.\n{img_str}\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"flux\",\n        \"flux-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"control-lora\",\n        \"diffusers-training\",\n        \"lora\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a Control LoRA training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"control-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\"--use_lora_bias\", action=\"store_true\", help=\"If training the bias of lora_B layers.\")\n    parser.add_argument(\n        \"--lora_layers\",\n        type=str,\n        default=None,\n        help=(\n            'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. 
- \"to_k,to_q,to_v,to_out.0\" will result in LoRA training of attention layers only'\n        ),\n    )\n    parser.add_argument(\n        \"--gaussian_init_lora\",\n        action=\"store_true\",\n        help=\"Whether to use the Gaussian init strategy. When False, we follow the original LoRA init strategy.\",\n    )\n    parser.add_argument(\"--train_norm_layers\", action=\"store_true\", help=\"Whether to train the norm scales.\")\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with 
the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the control conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\"--log_dataset_samples\", action=\"store_true\", help=\"Whether to log some dataset samples.\")\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A 
set of paths to the control conditioning images to be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a\"\n            \" single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=1,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"flux_train_control_lora\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers. For\"\n            \" more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    parser.add_argument(\n        \"--jsonl_for_train\",\n        type=str,\n        default=None,\n        help=\"Path to the jsonl file containing the training data.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=30.0,\n        help=\"the guidance scale used for transformer.\",\n    )\n\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). 
\"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoders to CPU when they are not used.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.jsonl_for_train is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--jsonl_for_train`\")\n\n    if args.dataset_name is not None and args.jsonl_for_train is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--jsonl_for_train`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n 
       raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and `--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer.\"\n        )\n\n    return args\n\n\ndef get_train_dataset(args, accelerator):\n    dataset = None\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    if args.jsonl_for_train is not None:\n        # load from json\n        dataset = load_dataset(\"json\", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)\n        dataset = dataset.flatten_indices()\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    with accelerator.main_process_first():\n        train_dataset = dataset[\"train\"].shuffle(seed=args.seed)\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.select(range(args.max_train_samples))\n    return train_dataset\n\n\ndef prepare_train_dataset(dataset, accelerator):\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [\n            (image.convert(\"RGB\") if not isinstance(image, str) else Image.open(image).convert(\"RGB\"))\n            for image in examples[args.image_column]\n        ]\n        images = [image_transforms(image) for image in images]\n\n        conditioning_images = [\n            (image.convert(\"RGB\") if not isinstance(image, str) else Image.open(image).convert(\"RGB\"))\n            for image in examples[args.conditioning_image_column]\n        ]\n        conditioning_images = [image_transforms(image) for image in conditioning_images]\n        examples[\"pixel_values\"] = images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n\n        is_caption_list = isinstance(examples[args.caption_column][0], list)\n        if is_caption_list:\n            examples[\"captions\"] = [max(example, key=len) for example in examples[args.caption_column]]\n        else:\n            examples[\"captions\"] = list(examples[args.caption_column])\n\n        return examples\n\n    with accelerator.main_process_first():\n        dataset = dataset.with_transform(preprocess_train)\n\n    return dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n    captions = [example[\"captions\"] for example in examples]\n    return {\"pixel_values\": pixel_values, \"conditioning_pixel_values\": conditioning_pixel_values, \"captions\": captions}\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n    if args.use_lora_bias and args.gaussian_init_lora:\n        raise ValueError(\"`gaussian` LoRA init scheme isn't supported when `use_lora_bias` is True.\")\n\n    logging_out_dir = Path(args.output_dir, args.logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=str(logging_out_dir))\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS. A technique for accelerating machine learning computations on iOS and macOS devices.\n    if torch.backends.mps.is_available():\n        logger.info(\"MPS is enabled. 
Disabling AMP.\")\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        # DEBUG, INFO, WARNING, ERROR, CRITICAL\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load models. 
We will load the text encoders later in a pipeline to compute\n    # embeddings.\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)\n    flux_transformer = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    logger.info(\"All models loaded successfully\")\n\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"scheduler\",\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae.requires_grad_(False)\n    flux_transformer.requires_grad_(False)\n\n    # cast down and move to the CPU\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # let's not move the VAE to the GPU yet.\n    vae.to(dtype=torch.float32)  # keep the VAE in float32.\n    flux_transformer.to(dtype=weight_dtype, device=accelerator.device)\n\n    # enable image inputs\n    with torch.no_grad():\n        initial_input_channels = flux_transformer.config.in_channels\n        new_linear = torch.nn.Linear(\n            flux_transformer.x_embedder.in_features * 2,\n            flux_transformer.x_embedder.out_features,\n            bias=flux_transformer.x_embedder.bias is not None,\n            dtype=flux_transformer.dtype,\n            device=flux_transformer.device,\n        )\n        new_linear.weight.zero_()\n        new_linear.weight[:, :initial_input_channels].copy_(flux_transformer.x_embedder.weight)\n        if flux_transformer.x_embedder.bias is not None:\n            
new_linear.bias.copy_(flux_transformer.x_embedder.bias)\n        flux_transformer.x_embedder = new_linear\n\n    assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)\n    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)\n\n    if args.lora_layers is not None:\n        if args.lora_layers != \"all-linear\":\n            target_modules = [layer.strip() for layer in args.lora_layers.split(\",\")]\n            # add the input layer to the mix.\n            if \"x_embedder\" not in target_modules:\n                target_modules.append(\"x_embedder\")\n        elif args.lora_layers == \"all-linear\":\n            target_modules = set()\n            for name, module in flux_transformer.named_modules():\n                if isinstance(module, torch.nn.Linear):\n                    target_modules.add(name)\n            target_modules = list(target_modules)\n    else:\n        target_modules = [\n            \"x_embedder\",\n            \"attn.to_k\",\n            \"attn.to_q\",\n            \"attn.to_v\",\n            \"attn.to_out.0\",\n            \"attn.add_k_proj\",\n            \"attn.add_q_proj\",\n            \"attn.add_v_proj\",\n            \"attn.to_add_out\",\n            \"ff.net.0.proj\",\n            \"ff.net.2\",\n            \"ff_context.net.0.proj\",\n            \"ff_context.net.2\",\n        ]\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\" if args.gaussian_init_lora else True,\n        target_modules=target_modules,\n        lora_bias=args.use_lora_bias,\n    )\n    flux_transformer.add_adapter(transformer_lora_config)\n\n    if args.train_norm_layers:\n        for name, param in flux_transformer.named_parameters():\n            if any(k in name for k in NORM_LAYER_PREFIXES):\n                param.requires_grad = True\n\n    def unwrap_model(model):\n        model = 
accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                transformer_lora_layers_to_save = None\n\n                for model in models:\n                    if isinstance(unwrap_model(model), type(unwrap_model(flux_transformer))):\n                        model = unwrap_model(model)\n                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                        if args.train_norm_layers:\n                            transformer_norm_layers_to_save = {\n                                f\"transformer.{name}\": param\n                                for name, param in model.named_parameters()\n                                if any(k in name for k in NORM_LAYER_PREFIXES)\n                            }\n                            transformer_lora_layers_to_save = {\n                                **transformer_lora_layers_to_save,\n                                **transformer_norm_layers_to_save,\n                            }\n                    else:\n                        raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    if weights:\n                        weights.pop()\n\n                FluxControlPipeline.save_lora_weights(\n                    output_dir,\n                    transformer_lora_layers=transformer_lora_layers_to_save,\n                )\n\n        def load_model_hook(models, input_dir):\n            transformer_ = None\n\n            if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n                while len(models) > 
0:\n                    model = models.pop()\n\n                    if isinstance(model, type(unwrap_model(flux_transformer))):\n                        transformer_ = model\n                    else:\n                        raise ValueError(f\"unexpected save model: {model.__class__}\")\n            else:\n                transformer_ = FluxTransformer2DModel.from_pretrained(\n                    args.pretrained_model_name_or_path, subfolder=\"transformer\"\n                ).to(accelerator.device, weight_dtype)\n\n                # Handle input dimension doubling before adding adapter\n                with torch.no_grad():\n                    initial_input_channels = transformer_.config.in_channels\n                    new_linear = torch.nn.Linear(\n                        transformer_.x_embedder.in_features * 2,\n                        transformer_.x_embedder.out_features,\n                        bias=transformer_.x_embedder.bias is not None,\n                        dtype=transformer_.dtype,\n                        device=transformer_.device,\n                    )\n                    new_linear.weight.zero_()\n                    new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)\n                    if transformer_.x_embedder.bias is not None:\n                        new_linear.bias.copy_(transformer_.x_embedder.bias)\n                    transformer_.x_embedder = new_linear\n                    transformer_.register_to_config(in_channels=initial_input_channels * 2)\n\n                transformer_.add_adapter(transformer_lora_config)\n\n            lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)\n            transformer_lora_state_dict = {\n                f\"{k.replace('transformer.', '')}\": v\n                for k, v in lora_state_dict.items()\n                if k.startswith(\"transformer.\") and \"lora\" in k\n            }\n            incompatible_keys = set_peft_model_state_dict(\n              
  transformer_, transformer_lora_state_dict, adapter_name=\"default\"\n            )\n            if incompatible_keys is not None:\n                # check only for unexpected keys\n                unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n                if unexpected_keys:\n                    logger.warning(\n                        f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                        f\" {unexpected_keys}. \"\n                    )\n            if args.train_norm_layers:\n                transformer_norm_state_dict = {\n                    k: v\n                    for k, v in lora_state_dict.items()\n                    if k.startswith(\"transformer.\") and any(norm_k in k for norm_k in NORM_LAYER_PREFIXES)\n                }\n                transformer_._transformer_norm_layers = FluxControlPipeline._load_norm_into_transformer(\n                    transformer_norm_state_dict,\n                    transformer=transformer_,\n                    discard_original_layers=False,\n                )\n\n            # Make sure the trainable params are in float32. This is again needed since the base models\n            # are in `weight_dtype`. 
More details:\n            # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n            if args.mixed_precision == \"fp16\":\n                models = [transformer_]\n                # only upcast trainable parameters (LoRA) into fp32\n                cast_training_params(models)\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [flux_transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    if args.gradient_checkpointing:\n        flux_transformer.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimization parameters\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, flux_transformer.parameters()))\n    optimizer = optimizer_class(\n        transformer_lora_parameters,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        
weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Prepare dataset and dataloader.\n    train_dataset = get_train_dataset(args, accelerator)\n    train_dataset = prepare_train_dataset(train_dataset, accelerator)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n    # Prepare everything with our `accelerator`.\n    flux_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        flux_transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * 
num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Create a pipeline for text encoding. 
We will move this pipeline to GPU/CPU as needed.\n    text_encoding_pipeline = FluxControlPipeline.from_pretrained(\n        args.pretrained_model_name_or_path, transformer=None, vae=None, torch_dtype=weight_dtype\n    )\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            logger.info(f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\")\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            logger.info(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    if accelerator.is_main_process and args.report_to == \"wandb\" and args.log_dataset_samples:\n        logger.info(\"Logging some dataset samples.\")\n        formatted_images = []\n        formatted_control_images = []\n        all_prompts = []\n        for i, batch in enumerate(train_dataloader):\n            images = (batch[\"pixel_values\"] + 1) / 2\n            control_images = (batch[\"conditioning_pixel_values\"] + 1) / 2\n            prompts = batch[\"captions\"]\n\n            if len(formatted_images) > 10:\n                break\n\n            for img, control_img, prompt in zip(images, control_images, prompts):\n       
         formatted_images.append(img)\n                formatted_control_images.append(control_img)\n                all_prompts.append(prompt)\n\n        logged_artifacts = []\n        for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):\n            logged_artifacts.append(wandb.Image(control_img, caption=\"Conditioning\"))\n            logged_artifacts.append(wandb.Image(img, caption=prompt))\n\n        wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == \"wandb\"]\n        wandb_tracker[0].log({\"dataset_samples\": logged_artifacts})\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        flux_transformer.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(flux_transformer):\n                # Convert images to latent space\n                # vae encode\n                pixel_latents = encode_images(batch[\"pixel_values\"], vae.to(accelerator.device), weight_dtype)\n                control_latents = encode_images(\n                    batch[\"conditioning_pixel_values\"], vae.to(accelerator.device), weight_dtype\n                
)\n\n                if args.offload:\n                    # offload vae to CPU.\n                    vae.cpu()\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                bsz = pixel_latents.shape[0]\n                noise = torch.randn_like(pixel_latents, device=accelerator.device, dtype=weight_dtype)\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=pixel_latents.device)\n\n                # Add noise according to flow matching.\n                sigmas = get_sigmas(timesteps, n_dim=pixel_latents.ndim, dtype=pixel_latents.dtype)\n                noisy_model_input = (1.0 - sigmas) * pixel_latents + sigmas * noise\n                # Concatenate across channels.\n                # Question: Should we concatenate before adding noise?\n                concatenated_noisy_model_input = torch.cat([noisy_model_input, control_latents], dim=1)\n\n                # pack the latents.\n                packed_noisy_model_input = FluxControlPipeline._pack_latents(\n                    concatenated_noisy_model_input,\n                    batch_size=bsz,\n                    num_channels_latents=concatenated_noisy_model_input.shape[1],\n                    height=concatenated_noisy_model_input.shape[2],\n                    width=concatenated_noisy_model_input.shape[3],\n                )\n\n                # latent image ids for RoPE.\n                latent_image_ids = FluxControlPipeline._prepare_latent_image_ids(\n                    bsz,\n          
          concatenated_noisy_model_input.shape[2] // 2,\n                    concatenated_noisy_model_input.shape[3] // 2,\n                    accelerator.device,\n                    weight_dtype,\n                )\n\n                # handle guidance\n                if unwrap_model(flux_transformer).config.guidance_embeds:\n                    guidance_vec = torch.full(\n                        (bsz,),\n                        args.guidance_scale,\n                        device=noisy_model_input.device,\n                        dtype=weight_dtype,\n                    )\n                else:\n                    guidance_vec = None\n\n                # text encoding.\n                captions = batch[\"captions\"]\n                text_encoding_pipeline = text_encoding_pipeline.to(\"cuda\")\n                with torch.no_grad():\n                    prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(\n                        captions, prompt_2=None\n                    )\n                # this could be optimized by not having to do any text encoding and just\n                # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`\n                if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:\n                    prompt_embeds.zero_()\n                    pooled_prompt_embeds.zero_()\n                if args.offload:\n                    text_encoding_pipeline = text_encoding_pipeline.to(\"cpu\")\n\n                # Predict.\n                model_pred = flux_transformer(\n                    hidden_states=packed_noisy_model_input,\n                    timestep=timesteps / 1000,\n                    guidance=guidance_vec,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    
return_dict=False,\n                )[0]\n                model_pred = FluxControlPipeline._unpack_latents(\n                    model_pred,\n                    height=noisy_model_input.shape[2] * vae_scale_factor,\n                    width=noisy_model_input.shape[3] * vae_scale_factor,\n                    vae_scale_factor=vae_scale_factor,\n                )\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow-matching loss\n                target = noise - pixel_latents\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n                accelerator.backward(loss)\n\n                if accelerator.sync_gradients:\n                    params_to_clip = flux_transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n    
                        checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            flux_transformer=flux_transformer,\n                            args=args,\n                            accelerator=accelerator,\n                            weight_dtype=weight_dtype,\n                            step=global_step,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), 
\"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        flux_transformer = unwrap_model(flux_transformer)\n        if args.upcast_before_saving:\n            flux_transformer.to(torch.float32)\n        transformer_lora_layers = get_peft_model_state_dict(flux_transformer)\n        if args.train_norm_layers:\n            transformer_norm_layers = {\n                f\"transformer.{name}\": param\n                for name, param in flux_transformer.named_parameters()\n                if any(k in name for k in NORM_LAYER_PREFIXES)\n            }\n            transformer_lora_layers = {**transformer_lora_layers, **transformer_norm_layers}\n        FluxControlPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n        )\n\n        del flux_transformer\n        del text_encoding_pipeline\n        del vae\n        free_memory()\n\n        # Run a final round of validation.\n        image_logs = None\n        if args.validation_prompt is not None:\n            image_logs = log_validation(\n                flux_transformer=None,\n                args=args,\n                accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                step=global_step,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                
folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\", \"*.pt\", \"*.bin\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/inference/README.md",
    "content": "# Inference Examples\n\n**The inference examples folder is deprecated and will be removed in a future version**.\n**Officially supported inference examples can be found in the [Pipelines folder](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines)**.\n\n- For `Image-to-Image text-guided generation with Stable Diffusion`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)\n- For `In-painting using Stable Diffusion`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)\n- For `Tweak prompts reusing seeds and latents`, please have a look at the official [Pipeline examples](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines#examples)\n"
  },
  {
    "path": "examples/inference/image_to_image.py",
    "content": "import warnings\n\nfrom diffusers import StableDiffusionImg2ImgPipeline  # noqa F401\n\n\nwarnings.warn(\n    \"The `image_to_image.py` script is outdated. Please use `from diffusers import\"\n    \" StableDiffusionImg2ImgPipeline` directly instead.\"\n)\n"
  },
  {
    "path": "examples/inference/inpainting.py",
    "content": "import warnings\n\nfrom diffusers import StableDiffusionInpaintPipeline as StableDiffusionInpaintPipeline  # noqa F401\n\n\nwarnings.warn(\n    \"The `inpainting.py` script is outdated. Please use `from diffusers import\"\n    \" StableDiffusionInpaintPipeline` directly instead.\"\n)\n"
  },
  {
    "path": "examples/instruct_pix2pix/README.md",
    "content": "# InstructPix2Pix training example\n\n[InstructPix2Pix](https://huggingface.co/papers/2211.09800) is a method to fine-tune text-conditioned diffusion models such that they can follow an edit instruction for an input image. Models fine-tuned using this method take the following as inputs:\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png\" alt=\"instructpix2pix-inputs\" width=600/>\n</p>\n\nThe output is an \"edited\" image that reflects the edit instruction applied to the input image:\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/output-gs%407-igs%401-steps%4050.png\" alt=\"instructpix2pix-output\" width=600/>\n</p>\n\nThe `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.\n\n***Disclaimer: Even though `train_instruct_pix2pix.py` implements the InstructPix2Pix\ntraining procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the example folder and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\n### Toy example\n\nAs mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset\nis a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.\n\nConfigure environment variables such as the dataset identifier and the Stable Diffusion\ncheckpoint:\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport DATASET_ID=\"fusing/instructpix2pix-1000-samples\"\n```\n\nNow, we can launch training:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" train_instruct_pix2pix.py \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --dataset_name=$DATASET_ID \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=256 --random_flip \\\n    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --mixed_precision=fp16 \\\n    --seed=42 \\\n    --push_to_hub\n```\n\nAdditionally, we support performing validation inference to monitor 
training progress\nwith Weights and Biases. You can enable this feature with `report_to=\"wandb\"`:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" train_instruct_pix2pix.py \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --dataset_name=$DATASET_ID \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=256 --random_flip \\\n    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --mixed_precision=fp16 \\\n    --val_image_url=\"https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png\" \\\n    --validation_prompt=\"make the mountains snowy\" \\\n    --seed=42 \\\n    --report_to=wandb \\\n    --push_to_hub\n ```\n\n We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.\n\n [Here](https://wandb.ai/sayakpaul/instruct-pix2pix/runs/ctr3kovq), you can find an example training run that includes some validation samples and the training hyperparameters.\n\n ***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***\n\n ## Training with multiple GPUs\n\n`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)\nfor running distributed training with `accelerate`. 
Here is an example command:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu train_instruct_pix2pix.py \\\n --pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5 \\\n --dataset_name=sayakpaul/instructpix2pix-1000-samples \\\n --use_ema \\\n --enable_xformers_memory_efficient_attention \\\n --resolution=512 --random_flip \\\n --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n --max_train_steps=15000 \\\n --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n --learning_rate=5e-05 --lr_warmup_steps=0 \\\n --conditioning_dropout_prob=0.05 \\\n --mixed_precision=fp16 \\\n --seed=42 \\\n --push_to_hub\n```\n\n ## Inference\n\n Once training is complete, we can perform inference:\n\n ```python\nimport PIL\nimport requests\nimport torch\nfrom diffusers import StableDiffusionInstructPix2PixPipeline\n\nmodel_id = \"your_model_id\" # <- replace this\npipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n\nurl = \"https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png\"\n\n\ndef download_image(url):\n    image = PIL.Image.open(requests.get(url, stream=True).raw)\n    image = PIL.ImageOps.exif_transpose(image)\n    image = image.convert(\"RGB\")\n    return image\n\nimage = download_image(url)\nprompt = \"wipe out the lake\"\nnum_inference_steps = 20\nimage_guidance_scale = 1.5\nguidance_scale = 10\n\nedited_image = pipe(prompt,\n    image=image,\n    num_inference_steps=num_inference_steps,\n    image_guidance_scale=image_guidance_scale,\n    guidance_scale=guidance_scale,\n    generator=generator,\n).images[0]\nedited_image.save(\"edited_image.png\")\n```\n\nAn example model repo obtained using this training script can be found\nhere - [sayakpaul/instruct-pix2pix](https://huggingface.co/sayakpaul/instruct-pix2pix).\n\nWe encourage you to play with the 
following three parameters to control\nspeed and quality at inference time:\n\n* `num_inference_steps`\n* `image_guidance_scale`\n* `guidance_scale`\n\nIn particular, `image_guidance_scale` and `guidance_scale` can have a profound impact\non the generated (\"edited\") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).\n\nIf you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd).\n\n## Stable Diffusion XL\n\nThere's an equivalent `train_instruct_pix2pix_sdxl.py` script for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to the docs [here](./README_sdxl.md) to learn more.\n"
  },
  {
    "path": "examples/instruct_pix2pix/README_sdxl.md",
    "content": "# InstructPix2Pix SDXL training example\n\n***This is based on the original InstructPix2Pix training example.***\n\n[Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (or SDXL) is the latest image generation model that is tailored towards more photorealistic outputs with more detailed imagery and composition compared to previous SD models. It leverages a UNet backbone that is three times larger. The increase in model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder.\n\nThe `train_instruct_pix2pix_sdxl.py` script shows how to implement the training procedure and adapt it for Stable Diffusion XL.\n\n***Disclaimer: Even though `train_instruct_pix2pix_sdxl.py` implements the InstructPix2Pix\ntraining procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix), we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.***\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nRefer to the original InstructPix2Pix training example for installing the dependencies.\n\nYou will also need to get access to SDXL by filling out the [form](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).\n\n### Toy example\n\nAs mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. 
The dataset\nis a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper.\n\nConfigure environment variables such as the dataset identifier and the Stable Diffusion\ncheckpoint:\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport DATASET_ID=\"fusing/instructpix2pix-1000-samples\"\n```\n\nNow, we can launch training:\n\n```bash\naccelerate launch train_instruct_pix2pix_sdxl.py \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --dataset_name=$DATASET_ID \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=256 --random_flip \\\n    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --seed=42 \\\n    --push_to_hub\n```\n\nAdditionally, we support performing validation inference to monitor training progress\nwith Weights and Biases. 
You can enable this feature with `report_to=\"wandb\"`:\n\n```bash\naccelerate launch train_instruct_pix2pix_sdxl.py \\\n    --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \\\n    --dataset_name=$DATASET_ID \\\n    --use_ema \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=512 --random_flip \\\n    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --seed=42 \\\n    --val_image_url_or_path=\"https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg\" \\\n    --validation_prompt=\"make it in japan\" \\\n    --report_to=wandb \\\n    --push_to_hub\n ```\n\n We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`.\n\n [Here](https://wandb.ai/sayakpaul/instruct-pix2pix-sdxl-new/runs/sw53gxmc), you can find an example training run that includes some validation samples and the training hyperparameters.\n\n ***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***\n\n ## Training with multiple GPUs\n\n`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)\nfor running distributed training with `accelerate`. 
Here is an example command:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu train_instruct_pix2pix_sdxl.py \\\n    --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \\\n    --dataset_name=$DATASET_ID \\\n    --use_ema \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=512 --random_flip \\\n    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --seed=42 \\\n    --val_image_url_or_path=\"https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg\" \\\n    --validation_prompt=\"make it in japan\" \\\n    --report_to=wandb \\\n    --push_to_hub\n```\n\n ## Inference\n\n Once training is complete, we can perform inference:\n\n ```python\nimport PIL\nimport requests\nimport torch\nfrom diffusers import StableDiffusionXLInstructPix2PixPipeline\n\nmodel_id = \"your_model_id\" # <- replace this\npipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\ngenerator = torch.Generator(\"cuda\").manual_seed(0)\n\nurl = \"https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg\"\n\n\ndef download_image(url):\n    image = PIL.Image.open(requests.get(url, stream=True).raw)\n    image = PIL.ImageOps.exif_transpose(image)\n    image = image.convert(\"RGB\")\n    return image\n\nimage = download_image(url)\nprompt = \"make it Japan\"\nnum_inference_steps = 20\nimage_guidance_scale = 1.5\nguidance_scale = 10\n\nedited_image = pipe(prompt,\n    image=image,\n    num_inference_steps=num_inference_steps,\n    image_guidance_scale=image_guidance_scale,\n    
guidance_scale=guidance_scale,\n    generator=generator,\n).images[0]\nedited_image.save(\"edited_image.png\")\n```\n\nWe encourage you to play with the following three parameters to control\nspeed and quality during inference:\n\n* `num_inference_steps`\n* `image_guidance_scale`\n* `guidance_scale`\n\nIn particular, `image_guidance_scale` and `guidance_scale` can have a profound impact\non the generated (\"edited\") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example).\n\nIf you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd).\n\n## Compare between SD and SDXL\n\nWe aim to understand the differences resulting from the use of SD-1.5 and SDXL-0.9 as pretrained models. To achieve this, we trained on the [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) using both of these pretrained models. 
The training script is as follows:\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"  # or \"stabilityai/stable-diffusion-xl-base-0.9\"\nexport DATASET_ID=\"fusing/instructpix2pix-1000-samples\"\n\naccelerate launch train_instruct_pix2pix.py \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --dataset_name=$DATASET_ID \\\n    --use_ema \\\n    --enable_xformers_memory_efficient_attention \\\n    --resolution=512 --random_flip \\\n    --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n    --max_train_steps=15000 \\\n    --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n    --learning_rate=5e-05 --lr_warmup_steps=0 \\\n    --conditioning_dropout_prob=0.05 \\\n    --seed=42 \\\n    --val_image_url=\"https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg\" \\\n    --validation_prompt=\"make it in Japan\" \\\n    --report_to=wandb \\\n    --push_to_hub\n```\n\nWe found that, compared to training with SD-1.5 as the pretrained model, SDXL-0.9 results in a lower training loss (0.0599 for SD-1.5 versus 0.0254 for SDXL). Moreover, from a visual perspective, the results obtained using SDXL show fewer artifacts and richer detail. Notably, SDXL starts to preserve the structure of the original image earlier on.\n\nThe following two GIFs provide intuitive visual results. For each validation step, we observed the results produced from the image\n<p align=\"center\">\n    <img src=\"https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg\" alt=\"input for make it Japan\" width=600/>\n</p>\nwith \"make it in Japan\" as the prompt. 
It can be seen that SDXL starts preserving the details of the original image earlier, resulting in higher fidelity outcomes sooner.\n\n* SD-1.5: https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd_ip2p_training_val_img_progress.gif\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd_ip2p_training_val_img_progress.gif\" alt=\"input for make it Japan\" width=600/>\n</p>\n\n* SDXL: https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_ip2p_training_val_img_progress.gif\n\n<p align=\"center\">\n    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_ip2p_training_val_img_progress.gif\" alt=\"input for make it Japan\" width=600/>\n</p>\n"
  },
  {
    "path": "examples/instruct_pix2pix/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\ndatasets\nftfy\ntensorboard"
  },
  {
    "path": "examples/instruct_pix2pix/test_instruct_pix2pix.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass InstructPix2Pix(ExamplesTestsAccelerate):\n    def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/instruct_pix2pix/train_instruct_pix2pix.py\n                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n                --dataset_name=hf-internal-testing/instructpix2pix-10-samples\n                --resolution=64\n                --random_flip\n                --train_batch_size=1\n                --max_train_steps=6\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                --output_dir {tmpdir}\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def 
test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/instruct_pix2pix/train_instruct_pix2pix.py\n                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n                --dataset_name=hf-internal-testing/instructpix2pix-10-samples\n                --resolution=64\n                --random_flip\n                --train_batch_size=1\n                --max_train_steps=4\n                --checkpointing_steps=2\n                --output_dir {tmpdir}\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            resume_run_args = f\"\"\"\n                examples/instruct_pix2pix/train_instruct_pix2pix.py\n                --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe\n                --dataset_name=hf-internal-testing/instructpix2pix-10-samples\n                --resolution=64\n                --random_flip\n                --train_batch_size=1\n                --max_train_steps=8\n                --checkpointing_steps=2\n                --output_dir {tmpdir}\n                --seed=0\n                --resume_from_checkpoint=checkpoint-4\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-6\", \"checkpoint-8\"},\n            )\n"
  },
  {
    "path": "examples/instruct_pix2pix/train_instruct_pix2pix.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Script to fine-tune Stable Diffusion for InstructPix2Pix.\"\"\"\n\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport PIL\nimport requests\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel\nfrom diffusers.utils import check_min_version, deprecate, is_wandb_available\nfrom diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif 
is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"fusing/instructpix2pix-1000-samples\": (\"input_image\", \"edit_prompt\", \"edited_image\"),\n}\nWANDB_TABLE_COL_NAMES = [\"original_image\", \"edited_image\", \"edit_prompt\"]\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    generator,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    original_image = download_image(args.val_image_url)\n    edited_images = []\n    if torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    with autocast_ctx:\n        for _ in range(args.num_validation_images):\n            edited_images.append(\n                pipeline(\n                    args.validation_prompt,\n                    image=original_image,\n                    num_inference_steps=20,\n                    image_guidance_scale=1.5,\n                    guidance_scale=7,\n                    generator=generator,\n                ).images[0]\n            )\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"wandb\":\n            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)\n            for edited_image in edited_images:\n                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)\n            tracker.log({\"validation\": wandb_table})\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script for 
InstructPix2Pix.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--original_image_column\",\n        type=str,\n        default=\"input_image\",\n        help=\"The column of the dataset containing the original image on which edits were made.\",\n    )\n    parser.add_argument(\n        \"--edited_image_column\",\n        type=str,\n        default=\"edited_image\",\n        help=\"The column of the dataset containing the edited image.\",\n    )\n    parser.add_argument(\n        \"--edit_prompt_column\",\n        type=str,\n        default=\"edit_prompt\",\n        help=\"The column of the dataset containing the edit instruction.\",\n    )\n    parser.add_argument(\n        \"--val_image_url\",\n        type=str,\n        default=None,\n        help=\"URL to the original image that you would like to edit (used during inference for debugging purposes).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\", type=str, default=None, help=\"A prompt that is sampled during training for inference.\"\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. 
The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"instruct-pix2pix-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=256,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"Whether to randomly flip images horizontally.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--conditioning_dropout_prob\",\n        type=float,\n        default=None,\n        help=\"Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://huggingface.co/papers/2211.09800.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--non_ema_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or\"\n            \" remote repository specified with --pretrained_model_name_or_path.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    # default to using the same revision for the non-ema model if not specified\n    if args.non_ema_revision is None:\n        args.non_ema_revision = args.revision\n\n    return args\n\n\ndef convert_to_np(image, resolution):\n    image = image.convert(\"RGB\").resize((resolution, resolution))\n    return np.array(image).transpose(2, 0, 1)\n\n\ndef download_image(url):\n    image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)\n    image = PIL.ImageOps.exif_transpose(image)\n    image = image.convert(\"RGB\")\n    return image\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if args.non_ema_revision is not None:\n        deprecate(\n            \"non_ema_revision!=None\",\n            \"0.15.0\",\n            message=(\n                \"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. 
Please make sure to\"\n                \" use `--variant=non_ema` instead.\"\n            ),\n        )\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = 
DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.non_ema_revision\n    )\n\n    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,\n    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is\n    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized\n    # from the pre-trained checkpoints. 
For the extra channels added to the first layer, they are\n    # initialized to zero.\n    logger.info(\"Initializing the InstructPix2Pix UNet from the pretrained UNet.\")\n    in_channels = 8\n    out_channels = unet.conv_in.out_channels\n    unet.register_to_config(in_channels=in_channels)\n\n    with torch.no_grad():\n        new_conv_in = nn.Conv2d(\n            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding\n        )\n        new_conv_in.weight.zero_()\n        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)\n        unet.conv_in = new_conv_in\n\n    # Freeze vae and text_encoder\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    if weights:\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        
unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            
cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.original_image_column is None:\n        original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        original_image_column = args.original_image_column\n        if original_image_column not in column_names:\n            raise ValueError(\n                f\"'--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.edit_prompt_column is None:\n        edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        edit_prompt_column = args.edit_prompt_column\n        if edit_prompt_column not in column_names:\n            raise ValueError(\n                f\"'--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.edited_image_column is None:\n        edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]\n    else:\n        edited_image_column = args.edited_image_column\n        if edited_image_column not in column_names:\n            raise ValueError(\n                f\"'--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(captions):\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", 
truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    # Preprocessing the datasets.\n    train_transforms = transforms.Compose(\n        [\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n        ]\n    )\n\n    def preprocess_images(examples):\n        original_images = np.concatenate(\n            [convert_to_np(image, args.resolution) for image in examples[original_image_column]]\n        )\n        edited_images = np.concatenate(\n            [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]\n        )\n        # We need to ensure that the original and the edited images undergo the same\n        # augmentation transforms.\n        images = np.stack([original_images, edited_images])\n        images = torch.tensor(images)\n        images = 2 * (images / 255) - 1\n        return train_transforms(images)\n\n    def preprocess_train(examples):\n        # Preprocess images.\n        preprocessed_images = preprocess_images(examples)\n        # Since the original and edited images were concatenated before\n        # applying the transformations, we need to separate them and reshape\n        # them accordingly.\n        original_images, edited_images = preprocessed_images\n        original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)\n        edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)\n\n        # Collate the preprocessed images into the `examples`.\n        examples[\"original_pixel_values\"] = original_images\n        examples[\"edited_pixel_values\"] = edited_images\n\n        # Preprocess the captions.\n        captions = list(examples[edit_prompt_column])\n        examples[\"input_ids\"] = tokenize_captions(captions)\n        return examples\n\n    with 
accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        original_pixel_values = torch.stack([example[\"original_pixel_values\"] for example in examples])\n        original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()\n        edited_pixel_values = torch.stack([example[\"edited_pixel_values\"] for example in examples])\n        edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\n            \"original_pixel_values\": original_pixel_values,\n            \"edited_pixel_values\": edited_pixel_values,\n            \"input_ids\": input_ids,\n        }\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move text_encoder and vae to gpu and cast to weight_dtype\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. 
\"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"instruct-pix2pix\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # We want to learn the denoising process w.r.t the edited images which\n                # are conditioned on the original image (which was edited) and the edit instruction.\n                # So, first, convert images to latent space.\n                latents = vae.encode(batch[\"edited_pixel_values\"].to(weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, 
noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning.\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Get the additional image embedding for conditioning.\n                # Instead of getting a diagonal Gaussian here, we simply take the mode.\n                original_image_embeds = vae.encode(batch[\"original_pixel_values\"].to(weight_dtype)).latent_dist.mode()\n\n                # Conditioning dropout to support classifier-free guidance during inference. For more details\n                # check out the section 3.2.1 of the original paper https://huggingface.co/papers/2211.09800.\n                if args.conditioning_dropout_prob is not None:\n                    random_p = torch.rand(bsz, device=latents.device, generator=generator)\n                    # Sample masks for the edit prompts.\n                    prompt_mask = random_p < 2 * args.conditioning_dropout_prob\n                    prompt_mask = prompt_mask.reshape(bsz, 1, 1)\n                    # Final text conditioning.\n                    null_conditioning = text_encoder(tokenize_captions([\"\"]).to(accelerator.device))[0]\n                    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)\n\n                    # Sample masks for the original images.\n                    image_mask_dtype = original_image_embeds.dtype\n                    image_mask = 1 - (\n                        (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)\n                        * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)\n 
                   )\n                    image_mask = image_mask.reshape(bsz, 1, 1, 1)\n                    # Final image conditioning.\n                    original_image_embeds = image_mask * original_image_embeds\n\n                # Concatenate the `original_image_embeds` with the `noisy_latents`.\n                concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Predict the noise residual and compute loss\n                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    
ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        
logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if (\n                (args.val_image_url is not None)\n                and (args.validation_prompt is not None)\n                and (epoch % args.validation_epochs == 0)\n            ):\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n                # The models need unwrapping for compatibility in distributed training mode.\n                pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    unet=unwrap_model(unet),\n                    text_encoder=unwrap_model(text_encoder),\n                    vae=unwrap_model(vae),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                log_validation(\n                    pipeline,\n                    args,\n                    accelerator,\n                    generator,\n                )\n\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_unet.restore(unet.parameters())\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        pipeline = 
StableDiffusionInstructPix2PixPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            text_encoder=unwrap_model(text_encoder),\n            vae=unwrap_model(vae),\n            unet=unwrap_model(unet),\n            revision=args.revision,\n            variant=args.variant,\n        )\n        pipeline.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        if (args.val_image_url is not None) and (args.validation_prompt is not None):\n            log_validation(\n                pipeline,\n                args,\n                accelerator,\n                generator,\n            )\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 Harutatsu Akiyama and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom urllib.parse import urlparse\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport PIL\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import (\n    StableDiffusionXLInstructPix2PixPipeline,\n)\nfrom diffusers.training_utils import EMAModel\nfrom diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom 
diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"fusing/instructpix2pix-1000-samples\": (\"file_name\", \"edited_image\", \"edit_prompt\"),\n}\nWANDB_TABLE_COL_NAMES = [\"file_name\", \"edited_image\", \"edit_prompt\"]\nTORCH_DTYPE_MAPPING = {\"fp32\": torch.float32, \"fp16\": torch.float16, \"bf16\": torch.bfloat16}\n\n\ndef log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    val_save_dir = os.path.join(args.output_dir, \"validation_images\")\n    if not os.path.exists(val_save_dir):\n        os.makedirs(val_save_dir)\n\n    original_image = (\n        lambda image_url_or_path: load_image(image_url_or_path)\n        if urlparse(image_url_or_path).scheme\n        else Image.open(image_url_or_path).convert(\"RGB\")\n    )(args.val_image_url_or_path)\n\n    if torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    with autocast_ctx:\n        edited_images = []\n        # Run inference\n        for val_img_idx in range(args.num_validation_images):\n            a_val_img = pipeline(\n                args.validation_prompt,\n                image=original_image,\n                num_inference_steps=20,\n                image_guidance_scale=1.5,\n                guidance_scale=7,\n                generator=generator,\n            ).images[0]\n            edited_images.append(a_val_img)\n      
      # Save validation images\n            a_val_img.save(os.path.join(val_save_dir, f\"step_{global_step}_val_img_{val_img_idx}.png\"))\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"wandb\":\n            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)\n            for edited_image in edited_images:\n                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)\n            logger_name = \"test\" if is_final_validation else \"validation\"\n            tracker.log({logger_name: wandb_table})\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Script to train Stable Diffusion XL for InstructPix2Pix.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to an improved VAE to stabilize training. 
For more details check out: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--vae_precision\",\n        type=str,\n        choices=[\"fp32\", \"fp16\", \"bf16\"],\n        default=\"fp32\",\n        help=(\n            \"The vanilla SDXL 1.0 VAE can cause NaNs due to large activation values. Some custom models might already have a solution\"\n            \" to this problem, and this flag allows you to use mixed precision to stabilize training.\"\n        ),\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. 
In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--original_image_column\",\n        type=str,\n        default=\"input_image\",\n        help=\"The column of the dataset containing the original image on which edits were made.\",\n    )\n    parser.add_argument(\n        \"--edited_image_column\",\n        type=str,\n        default=\"edited_image\",\n        help=\"The column of the dataset containing the edited image.\",\n    )\n    parser.add_argument(\n        \"--edit_prompt_column\",\n        type=str,\n        default=\"edit_prompt\",\n        help=\"The column of the dataset containing the edit instruction.\",\n    )\n    parser.add_argument(\n        \"--val_image_url_or_path\",\n        type=str,\n        default=None,\n        help=\"URL or local path to the original image that you would like to edit (used during inference for debugging purposes).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\", type=str, default=None, help=\"A prompt that is sampled during training for inference.\"\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run fine-tuning validation every X steps. 
The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"instruct-pix2pix-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=256,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this resolution.\"\n        ),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_h\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_w\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--conditioning_dropout_prob\",\n        type=float,\n        default=None,\n        help=\"Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://huggingface.co/papers/2211.09800.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--non_ema_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or\"\n            \" remote repository specified with --pretrained_model_name_or_path.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    # default to using the same revision for the non-ema model if not specified\n    if args.non_ema_revision is None:\n        args.non_ema_revision = args.revision\n\n    return args\n\n\ndef convert_to_np(image, resolution):\n    if isinstance(image, str):\n        image = PIL.Image.open(image)\n    image = image.convert(\"RGB\").resize((resolution, resolution))\n    return np.array(image).transpose(2, 0, 1)\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if args.non_ema_revision is not None:\n        deprecate(\n            \"non_ema_revision!=None\",\n            \"0.15.0\",\n            message=(\n                \"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. 
Please make sure to\"\n                \" use `--variant=non_ema` instead.\"\n            ),\n        )\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            
os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,\n    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is\n    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized\n    # from the pre-trained checkpoints. 
For the extra channels added to the first layer, they are\n    # initialized to zero.\n    logger.info(\"Initializing the XL InstructPix2Pix UNet from the pretrained UNet.\")\n    in_channels = 8\n    out_channels = unet.conv_in.out_channels\n    unet.register_to_config(in_channels=in_channels)\n\n    with torch.no_grad():\n        new_conv_in = nn.Conv2d(\n            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding\n        )\n        new_conv_in.weight.zero_()\n        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)\n        unet.conv_in = new_conv_in\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    
# Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about 
loading custom images at\n        # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.original_image_column is None:\n        original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        original_image_column = args.original_image_column\n        if original_image_column not in column_names:\n            raise ValueError(\n                f\"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.edit_prompt_column is None:\n        edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        edit_prompt_column = args.edit_prompt_column\n        if edit_prompt_column not in column_names:\n            raise ValueError(\n                f\"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.edited_image_column is None:\n        edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]\n    else:\n        edited_image_column = args.edited_image_column\n        if edited_image_column not in column_names:\n            raise ValueError(\n                f\"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = 
torch.float16\n        warnings.warn(f\"weight_dtype {weight_dtype} may cause nan during vae encoding\", UserWarning)\n\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n        warnings.warn(f\"weight_dtype {weight_dtype} may cause nan during vae encoding\", UserWarning)\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(captions, tokenizer):\n        inputs = tokenizer(\n            captions,\n            max_length=tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        return inputs.input_ids\n\n    # Preprocessing the datasets.\n    train_transforms = transforms.Compose(\n        [\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n        ]\n    )\n\n    def preprocess_images(examples):\n        original_images = np.concatenate(\n            [convert_to_np(image, args.resolution) for image in examples[original_image_column]]\n        )\n        edited_images = np.concatenate(\n            [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]\n        )\n        # We need to ensure that the original and the edited images undergo the same\n        # augmentation transforms.\n        images = np.stack([original_images, edited_images])\n        images = torch.tensor(images)\n        images = 2 * (images / 255) - 1\n        return train_transforms(images)\n\n    # Load scheduler, tokenizer and models.\n    tokenizer_1 = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_2 = AutoTokenizer.from_pretrained(\n        
args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n    text_encoder_cls_2 = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder_1 = text_encoder_cls_1.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_2 = text_encoder_cls_2.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n\n    # We ALWAYS pre-compute the additional condition embeddings needed for SDXL\n    # UNet as the model is already big and it uses two text encoders.\n    text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n    tokenizers = [tokenizer_1, tokenizer_2]\n    text_encoders = [text_encoder_1, text_encoder_2]\n\n    # Freeze vae and text_encoders\n    vae.requires_grad_(False)\n    text_encoder_1.requires_grad_(False)\n    text_encoder_2.requires_grad_(False)\n\n    # Set UNet to trainable.\n    unet.train()\n\n    # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\n    def encode_prompt(text_encoders, tokenizers, prompt):\n        prompt_embeds_list = []\n\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n     
       )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n            )\n\n            # We are only ALWAYS interested in the pooled output of the final text encoder\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n        return prompt_embeds, pooled_prompt_embeds\n\n    # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\n    def encode_prompts(text_encoders, tokenizers, prompts):\n        prompt_embeds_all = []\n        pooled_prompt_embeds_all = []\n\n        for prompt in prompts:\n            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)\n            prompt_embeds_all.append(prompt_embeds)\n            pooled_prompt_embeds_all.append(pooled_prompt_embeds)\n\n        return torch.stack(prompt_embeds_all), torch.stack(pooled_prompt_embeds_all)\n\n    # Adapted from 
examples.dreambooth.train_dreambooth_lora_sdxl\n    # Here, we compute not just the text embeddings but also the additional embeddings\n    # needed for the SD XL UNet to operate.\n    def compute_embeddings_for_prompts(prompts, text_encoders, tokenizers):\n        with torch.no_grad():\n            prompt_embeds_all, pooled_prompt_embeds_all = encode_prompts(text_encoders, tokenizers, prompts)\n            add_text_embeds_all = pooled_prompt_embeds_all\n\n            prompt_embeds_all = prompt_embeds_all.to(accelerator.device)\n            add_text_embeds_all = add_text_embeds_all.to(accelerator.device)\n        return prompt_embeds_all, add_text_embeds_all\n\n    # Get null conditioning\n    def compute_null_conditioning():\n        null_conditioning_list = []\n        for a_tokenizer, a_text_encoder in zip(tokenizers, text_encoders):\n            null_conditioning_list.append(\n                a_text_encoder(\n                    tokenize_captions([\"\"], tokenizer=a_tokenizer).to(accelerator.device),\n                    output_hidden_states=True,\n                ).hidden_states[-2]\n            )\n        return torch.concat(null_conditioning_list, dim=-1)\n\n    null_conditioning = compute_null_conditioning()\n\n    def compute_time_ids():\n        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)\n        original_size = target_size = (args.resolution, args.resolution)\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n        add_time_ids = torch.tensor([add_time_ids], dtype=weight_dtype)\n        return add_time_ids.to(accelerator.device).repeat(args.train_batch_size, 1)\n\n    add_time_ids = compute_time_ids()\n\n    def preprocess_train(examples):\n        # Preprocess images.\n        preprocessed_images = preprocess_images(examples)\n        # Since the original and edited images were concatenated before\n        # applying the transformations, we need to separate them and reshape\n    
    # them accordingly.\n        original_images, edited_images = preprocessed_images\n        original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)\n        edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)\n\n        # Collate the preprocessed images into the `examples`.\n        examples[\"original_pixel_values\"] = original_images\n        examples[\"edited_pixel_values\"] = edited_images\n\n        # Preprocess the captions.\n        captions = list(examples[edit_prompt_column])\n        prompt_embeds_all, add_text_embeds_all = compute_embeddings_for_prompts(captions, text_encoders, tokenizers)\n        examples[\"prompt_embeds\"] = prompt_embeds_all\n        examples[\"add_text_embeds\"] = add_text_embeds_all\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        original_pixel_values = torch.stack([example[\"original_pixel_values\"] for example in examples])\n        original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()\n        edited_pixel_values = torch.stack([example[\"edited_pixel_values\"] for example in examples])\n        edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()\n        prompt_embeds = torch.concat([example[\"prompt_embeds\"] for example in examples], dim=0)\n        add_text_embeds = torch.concat([example[\"add_text_embeds\"] for example in examples], dim=0)\n        return {\n            \"original_pixel_values\": original_pixel_values,\n            \"edited_pixel_values\": edited_pixel_values,\n            \"prompt_embeds\": prompt_embeds,\n            
\"add_text_embeds\": add_text_embeds,\n        }\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    # Move vae, unet and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    if args.pretrained_vae_model_name_or_path is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    else:\n        vae.to(accelerator.device, dtype=TORCH_DTYPE_MAPPING[args.vae_precision])\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / 
num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"instruct-pix2pix-xl\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # We want to learn the denoising process w.r.t the edited images which\n                # are conditioned on the original image (which was edited) and the edit instruction.\n                # So, first, convert images to latent space.\n                if args.pretrained_vae_model_name_or_path is not None:\n                    edited_pixel_values = batch[\"edited_pixel_values\"].to(dtype=weight_dtype)\n                else:\n                    edited_pixel_values = batch[\"edited_pixel_values\"]\n                latents = vae.encode(edited_pixel_values).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, 
noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # SDXL additional inputs\n                encoder_hidden_states = batch[\"prompt_embeds\"]\n                add_text_embeds = batch[\"add_text_embeds\"]\n\n                # Get the additional image embedding for conditioning.\n                # Instead of getting a diagonal Gaussian here, we simply take the mode.\n                if args.pretrained_vae_model_name_or_path is not None:\n                    original_pixel_values = batch[\"original_pixel_values\"].to(dtype=weight_dtype)\n                else:\n                    original_pixel_values = batch[\"original_pixel_values\"]\n                original_image_embeds = vae.encode(original_pixel_values).latent_dist.sample()\n                if args.pretrained_vae_model_name_or_path is None:\n                    original_image_embeds = original_image_embeds.to(weight_dtype)\n\n                # Conditioning dropout to support classifier-free guidance during inference. 
For more details\n                # check out the section 3.2.1 of the original paper https://huggingface.co/papers/2211.09800.\n                if args.conditioning_dropout_prob is not None:\n                    random_p = torch.rand(bsz, device=latents.device, generator=generator)\n                    # Sample masks for the edit prompts.\n                    prompt_mask = random_p < 2 * args.conditioning_dropout_prob\n                    prompt_mask = prompt_mask.reshape(bsz, 1, 1)\n                    # Final text conditioning.\n                    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)\n\n                    # Sample masks for the original images.\n                    image_mask_dtype = original_image_embeds.dtype\n                    image_mask = 1 - (\n                        (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)\n                        * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)\n                    )\n                    image_mask = image_mask.reshape(bsz, 1, 1, 1)\n                    # Final image conditioning.\n                    original_image_embeds = image_mask * original_image_embeds\n\n                # Concatenate the `original_image_embeds` with the `noisy_latents`.\n                concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Predict the noise residual and compute loss\n                
added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n                model_pred = unet(\n                    concatenated_noisy_latents,\n                    timesteps,\n                    encoder_hidden_states,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = 
sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            ### BEGIN: Perform validation every `validation_epochs` steps\n            if global_step % args.validation_steps == 0:\n                if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):\n                    # create pipeline\n                    if args.use_ema:\n                        # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                        ema_unet.store(unet.parameters())\n                        ema_unet.copy_to(unet.parameters())\n\n                  
  # The models need unwrapping because for compatibility in distributed training mode.\n                    pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        unet=unwrap_model(unet),\n                        text_encoder=text_encoder_1,\n                        text_encoder_2=text_encoder_2,\n                        tokenizer=tokenizer_1,\n                        tokenizer_2=tokenizer_2,\n                        vae=vae,\n                        revision=args.revision,\n                        variant=args.variant,\n                        torch_dtype=weight_dtype,\n                    )\n\n                    log_validation(\n                        pipeline,\n                        args,\n                        accelerator,\n                        generator,\n                        global_step,\n                        is_final_validation=False,\n                    )\n\n                    if args.use_ema:\n                        # Switch back to the original UNet parameters.\n                        ema_unet.restore(unet.parameters())\n\n                    del pipeline\n                    torch.cuda.empty_cache()\n            ### END: Perform validation every `validation_epochs` steps\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            text_encoder=text_encoder_1,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer_1,\n            tokenizer_2=tokenizer_2,\n            vae=vae,\n            unet=unwrap_model(unet),\n            
revision=args.revision,\n            variant=args.variant,\n        )\n\n        pipeline.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):\n            log_validation(\n                pipeline,\n                args,\n                accelerator,\n                generator,\n                global_step,\n                is_final_validation=True,\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/kandinsky2_2/text_to_image/README.md",
    "content": "# Kandinsky2.2 text-to-image fine-tuning\n\nKandinsky 2.2 includes a prior pipeline that generates image embeddings from text prompts, and a decoder pipeline that generates the output image based on the image embeddings. We provide `train_text_to_image_prior.py` and `train_text_to_image_decoder.py` scripts to show you how to fine-tune the Kandinsky prior and decoder models separately based on your own dataset. To achieve the best results, you should fine-tune **_both_** your prior and decoder models.\n\n___Note___:\n\n___This script is experimental. The script fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___\n\n\n## Running locally with PyTorch\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd into the example folder and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n___\n\n### Naruto example\n\nFor all our examples, we will directly store the trained weights on the Hub, so we need to be logged in and add the `--push_to_hub` flag. 
In order to do that, you have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to the [User Access Tokens](https://huggingface.co/docs/hub/security-tokens) guide.\n\nRun the following command to authenticate with your token:\n\n```bash\nhf auth login\n```\n\nWe also use [Weights and Biases](https://docs.wandb.ai/quickstart) logging by default, because it is really useful to monitor the training progress by regularly generating sample images during training. To install wandb, run\n\n```bash\npip install wandb\n```\n\nTo disable wandb logging, remove the `--report_to=\"wandb\"` and `--validation_prompts=\"A robot naruto, 4k photo\"` flags from the examples below.\n\n#### Fine-tune decoder\n<br>\n\n<!-- accelerate_snippet_start -->\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image_decoder.py \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"kandi2-decoder-naruto-model\"\n```\n<!-- accelerate_snippet_end -->\n\n\nTo train on your own training files, prepare the dataset according to the format required by `datasets`. 
You can find the instructions for how to do that in the [ImageFolder with metadata](https://huggingface.co/docs/datasets/en/image_load#imagefolder-with-metadata) guide.\nIf you wish to use custom loading logic, you should modify the script and we have left pointers for that in the training script.\n\n```bash\nexport TRAIN_DIR=\"path_to_your_dataset\"\n\naccelerate launch --mixed_precision=\"fp16\" train_text_to_image_decoder.py \\\n  --train_data_dir=$TRAIN_DIR \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"kandi22-decoder-naruto-model\"\n```\n\n\nOnce the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `kandi22-decoder-naruto-model`. 
To load the fine-tuned model for inference, just pass that path to `AutoPipelineForText2Image`:\n\n```python\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\n# Path passed as `--output_dir` during training\noutput_dir = \"kandi22-decoder-naruto-model\"\n\npipe = AutoPipelineForText2Image.from_pretrained(output_dir, torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n\nprompt='A robot naruto, 4k photo'\nimages = pipe(prompt=prompt).images\nimages[0].save(\"robot-naruto.png\")\n```\n\nCheckpoints only save the unet, so to run inference from a checkpoint, just load the unet:\n```python\nfrom diffusers import AutoPipelineForText2Image, UNet2DConditionModel\nimport torch\n\nmodel_path = \"path_to_saved_model\"\n\n# Load the unet in the same dtype as the rest of the pipeline\nunet = UNet2DConditionModel.from_pretrained(model_path + \"/checkpoint-<N>/unet\", torch_dtype=torch.float16)\n\npipe = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", unet=unet, torch_dtype=torch.float16)\npipe.enable_model_cpu_offload()\n\nimage = pipe(prompt=\"A robot naruto, 4k photo\").images[0]\nimage.save(\"robot-naruto.png\")\n```\n\n#### Fine-tune prior\n\nYou can fine-tune the Kandinsky prior model with the `train_text_to_image_prior.py` script. 
Note that we currently do not support `--gradient_checkpointing` for prior model fine-tuning.\n\n<br>\n\n<!-- accelerate_snippet_start -->\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image_prior.py \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"kandi2-prior-naruto-model\"\n```\n<!-- accelerate_snippet_end -->\n\n\nTo perform inference with the fine-tuned prior model, you will need to first create a prior pipeline by passing the `output_dir` to `DiffusionPipeline`. Then create a `KandinskyV22CombinedPipeline` from a pretrained or fine-tuned decoder checkpoint along with all the modules of the prior pipeline you just created.\n\n```python\nfrom diffusers import AutoPipelineForText2Image, DiffusionPipeline\nimport torch\n\n# Path passed as `--output_dir` when fine-tuning the prior\noutput_dir = \"kandi2-prior-naruto-model\"\n\npipe_prior = DiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)\n# Prefix every prior component name with `prior_` so the combined pipeline picks it up\nprior_components = {\"prior_\" + k: v for k,v in pipe_prior.components.items()}\npipe = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", **prior_components, torch_dtype=torch.float16)\n\npipe.enable_model_cpu_offload()\nprompt='A robot naruto, 4k photo'\nimages = pipe(prompt=prompt).images\nimages[0]\n```\n\nIf you want to use a fine-tuned decoder checkpoint along with your fine-tuned prior checkpoint, you can simply replace the \"kandinsky-community/kandinsky-2-2-decoder\" in the code above with your custom model repo name. Note that in order to be able to create a `KandinskyV22CombinedPipeline`, your model repository needs to have a prior tag. 
If you have created your model repo using our training script, the prior tag is automatically included.\n\n#### Training with multiple GPUs\n\n`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)\nfor running distributed training with `accelerate`. Here is an example command:\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu  train_text_to_image_decoder.py \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"kandi2-decoder-naruto-model\"\n```\n\n\n#### Training with Min-SNR weighting\n\nWe support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://huggingface.co/papers/2303.09556), which helps achieve faster convergence\nby rebalancing the loss. Pass the `--snr_gamma` argument and set it to the recommended\nvalue of 5.0.\n\n\n## Training with LoRA\n\nLow-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.\n\nIn a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. 
This has a couple of advantages:\n\n- The pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).\n- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.\n\n[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.\n\nWith LoRA, it's possible to fine-tune Kandinsky 2.2 on a custom image-caption pair dataset\non GPUs like the Tesla T4 or Tesla V100.\n\n### Training\n\nFirst, you need to set up your development environment as explained in the [installation](#running-locally-with-pytorch) section. Make sure to set the `DATASET_NAME` environment variable. 
Here, we will use [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder) and the [Naruto dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).\n\n\n#### Train decoder\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\" train_text_to_image_decoder_lora.py \\\n  --dataset_name=$DATASET_NAME --caption_column=\"text\" \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --num_train_epochs=100 --checkpointing_steps=5000 \\\n  --learning_rate=1e-04 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --seed=42 \\\n  --rank=4 \\\n  --gradient_checkpointing \\\n  --output_dir=\"kandi22-decoder-naruto-lora\" \\\n  --validation_prompt=\"cute dragon creature\" --report_to=\"wandb\" \\\n  --push_to_hub\n```\n\n#### Train prior\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\" train_text_to_image_prior_lora.py \\\n  --dataset_name=$DATASET_NAME --caption_column=\"text\" \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --num_train_epochs=100 --checkpointing_steps=5000 \\\n  --learning_rate=1e-04 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --seed=42 \\\n  --rank=4 \\\n  --output_dir=\"kandi22-prior-naruto-lora\" \\\n  --validation_prompt=\"cute dragon creature\" --report_to=\"wandb\" \\\n  --push_to_hub\n```\n\n**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run the above scripts on GPUs like the T4 or V100.___**\n\n\n### Inference\n\n#### Inference using fine-tuned LoRA checkpoint for decoder\n\nOnce you have trained a Kandinsky decoder model using the above command, inference can be done with the `AutoPipelineForText2Image` after loading the trained LoRA weights.  
You need to pass the `output_dir` for loading the LoRA weights, which in this case is `kandi22-decoder-naruto-lora`.\n\n\n```python\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\n# Path passed as `--output_dir` when training the decoder LoRA\noutput_dir = \"kandi22-decoder-naruto-lora\"\n\npipe = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16)\npipe.unet.load_attn_procs(output_dir)\npipe.enable_model_cpu_offload()\n\nprompt='A robot naruto, 4k photo'\nimage = pipe(prompt=prompt).images[0]\nimage.save(\"robot_naruto.png\")\n```\n\n#### Inference using fine-tuned LoRA checkpoint for prior\n\n```python\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\n# Path passed as `--output_dir` when training the prior LoRA\noutput_dir = \"kandi22-prior-naruto-lora\"\n\npipe = AutoPipelineForText2Image.from_pretrained(\"kandinsky-community/kandinsky-2-2-decoder\", torch_dtype=torch.float16)\npipe.prior_prior.load_attn_procs(output_dir)\npipe.enable_model_cpu_offload()\n\nprompt='A robot naruto, 4k photo'\nimage = pipe(prompt=prompt).images[0]\nimage.save(\"robot_naruto.png\")\n```\n\n### Training with xFormers\n\nYou can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.\n\nxFormers training is not available for fine-tuning the prior model.\n\n**Note**:\n\nAccording to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training on some GPUs. If you observe that problem, please install a development version as indicated in that comment."
  },
  {
    "path": "examples/kandinsky2_2/text_to_image/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\ndatasets\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\nfrom transformers.utils import ContextManagers\n\nimport diffusers\nfrom diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif is_wandb_available():\n    import wandb\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef save_model_card(\n    args,\n    repo_id: str,\n    images=None,\n    repo_folder=None,\n):\n    img_str = \"\"\n    if len(images) > 0:\n        image_grid = make_image_grid(images, 1, len(args.validation_prompts))\n        image_grid.save(os.path.join(repo_folder, \"val_imgs_grid.png\"))\n        img_str += \"![val_imgs_grid](./val_imgs_grid.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {args.pretrained_decoder_model_name_or_path}\ndatasets:\n- {args.dataset_name}\nprior:\n- {args.pretrained_prior_model_name_or_path}\ntags:\n- kandinsky\n- text-to-image\n- diffusers\n- diffusers-training\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# Finetuning - {repo_id}\n\nThis pipeline was finetuned from **{args.pretrained_decoder_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \\n\n{img_str}\n\n## Pipeline usage\n\nYou can use the pipeline like so:\n\n```python\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\"{repo_id}\", torch_dtype=torch.float16)\nprompt = \"{args.validation_prompts[0]}\"\nimage = pipeline(prompt).images[0]\nimage.save(\"my_image.png\")\n```\n\n## Training info\n\nThese are the key hyperparameters used during training:\n\n* Epochs: {args.num_train_epochs}\n* Learning rate: {args.learning_rate}\n* Batch size: {args.train_batch_size}\n* Gradient accumulation steps: {args.gradient_accumulation_steps}\n* Image resolution: {args.resolution}\n* Mixed-precision: {args.mixed_precision}\n\n\"\"\"\n    wandb_info = \"\"\n    # Initialize before the availability check so the name is always bound.\n    wandb_run_url = None\n    if is_wandb_available():\n        if wandb.run is not None:\n            wandb_run_url = wandb.run.url\n\n    if wandb_run_url is not None:\n   
     wandb_info = f\"\"\"\nMore information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).\n\"\"\"\n\n    model_card += wandb_info\n\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef log_validation(vae, image_encoder, image_processor, unet, args, accelerator, weight_dtype, epoch):\n    logger.info(\"Running validation... \")\n\n    pipeline = AutoPipelineForText2Image.from_pretrained(\n        args.pretrained_decoder_model_name_or_path,\n        vae=accelerator.unwrap_model(vae),\n        prior_image_encoder=accelerator.unwrap_model(image_encoder),\n        prior_image_processor=image_processor,\n        unet=accelerator.unwrap_model(unet),\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    images = []\n    for i in range(len(args.validation_prompts)):\n        with torch.autocast(\"cuda\"):\n            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompts[i]}\")\n                        for i, image in enumerate(images)\n                    ]\n  
              }\n            )\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of finetuning Kandinsky 2.2.\")\n    parser.add_argument(\n        \"--pretrained_decoder_model_name_or_path\",\n        type=str,\n        default=\"kandinsky-community/kandinsky-2-2-decoder\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_prior_model_name_or_path\",\n        type=str,\n        default=\"kandinsky-community/kandinsky-2-2-prior\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompts\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\"A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"kandi_2_2-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=1, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Learning rate.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use an EMA model.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\n        \"--adam_weight_decay\",\n        type=float,\n        default=0.0,\n        required=False,\n        help=\"Weight decay to use.\",\n    )\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer.\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=5,\n        help=\"Run validation every X epochs.\",\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers. For\"\n            \" more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        
log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder=\"scheduler\")\n    image_processor = CLIPImageProcessor.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_processor\"\n    )\n\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        returns either a context list that includes one that will disable zero.Init or an empty context list\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n        if 
deepspeed_plugin is None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        vae = VQModel.from_pretrained(\n            args.pretrained_decoder_model_name_or_path, subfolder=\"movq\", torch_dtype=weight_dtype\n        ).eval()\n        image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n            args.pretrained_prior_model_name_or_path, subfolder=\"image_encoder\", torch_dtype=weight_dtype\n        ).eval()\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder=\"unet\")\n\n    # Freeze vae and image_encoder\n    vae.requires_grad_(False)\n    image_encoder.requires_grad_(False)\n\n    # Set unet to trainable.\n    unet.train()\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder=\"unet\")\n        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)\n        ema_unet.to(accelerator.device)\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. 
See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if args.use_ema:\n                ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n            for i, model in enumerate(models):\n                model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n 
   # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to crop, resize, and normalize the images.\n    column_names = 
dataset[\"train\"].column_names\n\n    image_column = args.image_column\n    if image_column not in column_names:\n        raise ValueError(f\"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}\")\n\n    def center_crop(image):\n        width, height = image.size\n        new_size = min(width, height)\n        left = (width - new_size) / 2\n        top = (height - new_size) / 2\n        right = (width + new_size) / 2\n        bottom = (height + new_size) / 2\n        return image.crop((left, top, right, bottom))\n\n    def train_transforms(img):\n        img = center_crop(img)\n        img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)\n        img = np.array(img).astype(np.float32) / 127.5 - 1\n        img = torch.from_numpy(np.transpose(img, [2, 0, 1]))\n        return img\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"clip_pixel_values\"] = image_processor(images, return_tensors=\"pt\").pixel_values\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        clip_pixel_values = torch.stack([example[\"clip_pixel_values\"] for example in examples])\n        clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()\n        return {\"pixel_values\": pixel_values, \"clip_pixel_values\": 
clip_pixel_values}\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n    # Move image_encode and vae to gpu and cast to weight_dtype\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        tracker_config.pop(\"validation_prompts\")\n        
accelerator.init_trackers(args.tracker_project_name, tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                images = batch[\"pixel_values\"].to(weight_dtype)\n                clip_images = batch[\"clip_pixel_values\"].to(weight_dtype)\n                latents = vae.encode(images).latents\n                image_embeds = image_encoder(clip_images).image_embeds\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                target = noise\n\n                # Predict the noise residual and compute loss\n                added_cond_kwargs = {\"image_embeds\": image_embeds}\n\n                model_pred = unet(noisy_latents, timesteps, None, 
added_cond_kwargs=added_cond_kwargs).sample[:, :4]\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            
if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n               
         accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n                log_validation(\n                    vae,\n                    image_encoder,\n                    image_processor,\n                    unet,\n                    args,\n                    accelerator,\n                    weight_dtype,\n                    global_step,\n                )\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_unet.restore(unet.parameters())\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        pipeline = AutoPipelineForText2Image.from_pretrained(\n            args.pretrained_decoder_model_name_or_path,\n            vae=vae,\n            unet=unet,\n        )\n        pipeline.decoder_pipe.save_pretrained(args.output_dir)\n\n        # Run a final round of inference.\n        images = []\n        if args.validation_prompts is not None:\n            logger.info(\"Running inference for collecting generated images...\")\n            pipeline.torch_dtype = weight_dtype\n            
pipeline.set_progress_bar_config(disable=True)\n            pipeline.enable_model_cpu_offload()\n\n            if args.enable_xformers_memory_efficient_attention:\n                pipeline.enable_xformers_memory_efficient_attention()\n\n            if args.seed is None:\n                generator = None\n            else:\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n            for i in range(len(args.validation_prompts)):\n                with torch.autocast(\"cuda\"):\n                    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n                images.append(image)\n\n        if args.push_to_hub:\n            save_model_card(args, repo_id, images, repo_folder=args.output_dir)\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Kandinsky with support for LoRA.\"\"\"\n\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nimport diffusers\nfrom diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel\nfrom diffusers.loaders import AttnProcsLayers\nfrom diffusers.models.attention_processor import LoRAAttnAddedKVProcessor\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef save_model_card(repo_id: str, images=None, base_model: str = None, dataset_name: str = None, repo_folder=None):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- kandinsky\n- text-to-image\n- diffusers\n- diffusers-training\n- lora\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# LoRA text2image fine-tuning - {repo_id}\nThese are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. Some example images are shown below. \\n\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of finetuning Kandinsky 2.2 with LoRA.\")\n    parser.add_argument(\n        \"--pretrained_decoder_model_name_or_path\",\n        type=str,\n        default=\"kandinsky-community/kandinsky-2-2-decoder\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_prior_model_name_or_path\",\n        type=str,\n        default=\"kandinsky-community/kandinsky-2-2-prior\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--validation_prompt\", type=str, default=None, help=\"A prompt that is sampled during training for inference.\"\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. 
The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"kandi_2_2-model-finetuned-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=1, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=0.0, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - 
%(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder=\"scheduler\")\n    image_processor = CLIPImageProcessor.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_processor\"\n    )\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_encoder\"\n    )\n\n    vae = VQModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder=\"movq\")\n\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder=\"unet\")\n    # freeze parameters of models to save more memory\n    unet.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    image_encoder.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and 
non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    lora_attn_procs = {}\n    for name in unet.attn_processors.keys():\n        cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            hidden_size = unet.config.block_out_channels[-1]\n        elif name.startswith(\"up_blocks\"):\n            block_id = int(name[len(\"up_blocks.\")])\n            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n        elif name.startswith(\"down_blocks\"):\n            block_id = int(name[len(\"down_blocks.\")])\n            hidden_size = unet.config.block_out_channels[block_id]\n\n        lora_attn_procs[name] = LoRAAttnAddedKVProcessor(\n            hidden_size=hidden_size,\n            cross_attention_dim=cross_attention_dim,\n            rank=args.rank,\n        )\n\n    unet.set_attn_processor(lora_attn_procs)\n\n    lora_layers = AttnProcsLayers(unet.attn_processors)\n\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        lora_layers.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to prepare pixel values for the MoVQ encoder and the CLIP image encoder.\n    column_names = dataset[\"train\"].column_names\n\n    image_column = args.image_column\n    if image_column not in column_names:\n        raise ValueError(f\"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}\")\n\n    def center_crop(image):\n        width, height = image.size\n        new_size = min(width, height)\n        left = (width - new_size) / 2\n        top = (height - new_size) / 2\n        right = 
(width + new_size) / 2\n        bottom = (height + new_size) / 2\n        return image.crop((left, top, right, bottom))\n\n    def train_transforms(img):\n        img = center_crop(img)\n        img = img.resize((args.resolution, args.resolution), resample=Image.BICUBIC, reducing_gap=1)\n        img = np.array(img).astype(np.float32) / 127.5 - 1\n        img = torch.from_numpy(np.transpose(img, [2, 0, 1]))\n        return img\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"clip_pixel_values\"] = image_processor(images, return_tensors=\"pt\").pixel_values\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        clip_pixel_values = torch.stack([example[\"clip_pixel_values\"] for example in examples])\n        clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()\n        return {\"pixel_values\": pixel_values, \"clip_pixel_values\": clip_pixel_values}\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / 
args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n    # Prepare everything with our `accelerator`.\n    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        lora_layers, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                images = 
batch[\"pixel_values\"].to(weight_dtype)\n                clip_images = batch[\"clip_pixel_values\"].to(weight_dtype)\n                latents = vae.encode(images).latents\n                image_embeds = image_encoder(clip_images).image_embeds\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                target = noise\n\n                # Predict the noise residual and compute loss\n                added_cond_kwargs = {\"image_embeds\": image_embeds}\n\n                model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr 
+ 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = lora_layers.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= 
args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                logger.info(\n                    f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                # create pipeline\n                pipeline = AutoPipelineForText2Image.from_pretrained(\n                    args.pretrained_decoder_model_name_or_path,\n                    unet=accelerator.unwrap_model(unet),\n                    torch_dtype=weight_dtype,\n                )\n                pipeline = pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = torch.Generator(device=accelerator.device)\n                if args.seed is not None:\n                    generator = generator.manual_seed(args.seed)\n                images = []\n                for _ in range(args.num_validation_images):\n                    images.append(\n                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]\n                    )\n\n                for tracker in accelerator.trackers:\n                    if tracker.name == \"tensorboard\":\n                        np_images = np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                \"validation\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                                ]\n                            }\n                        )\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unet.to(torch.float32)\n        
unet.save_attn_procs(args.output_dir)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_decoder_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    # Final inference\n    # Load previous pipeline\n    pipeline = AutoPipelineForText2Image.from_pretrained(\n        args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype\n    )\n    pipeline = pipeline.to(accelerator.device)\n\n    # load attention processors\n    pipeline.unet.load_attn_procs(args.output_dir)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device)\n    if args.seed is not None:\n        generator = generator.manual_seed(args.seed)\n    images = []\n    for _ in range(args.num_validation_images):\n        images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])\n\n    if accelerator.is_main_process:\n        for tracker in accelerator.trackers:\n            if len(images) != 0:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    tracker.log(\n                        {\n                            \"test\": [\n                                wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n                            ]\n                        }\n                    
)\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion for text2image with support for LoRA.\"\"\"\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom tqdm import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection\n\nimport diffusers\nfrom diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer\nfrom diffusers.loaders import AttnProcsLayers\nfrom diffusers.models.attention_processor import LoRAAttnProcessor\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef save_model_card(repo_id: str, images=None, base_model: str = None, dataset_name: str = None, repo_folder=None):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- kandinsky\n- text-to-image\n- diffusers\n- diffusers-training\n- lora\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# LoRA text2image fine-tuning - {repo_id}\nThese are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \\n\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of finetuning Kandinsky 2.2.\")\n    parser.add_argument(\n        \"--pretrained_decoder_model_name_or_path\",\n        type=str,\n        default=\"kandinsky-community/kandinsky-2-2-decoder\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_prior_model_name_or_path\",\n        type=str,\n        default=\"kandinsky-community/kandinsky-2-2-prior\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\", type=str, default=None, help=\"A prompt that is sampled during training for inference.\"\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. 
The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"kandi_2_2-model-finetuned-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=1, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Learning rate for the optimizer.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\n        \"--adam_weight_decay\",\n        type=float,\n        default=0.0,\n        required=False,\n        help=\"Weight decay to use.\",\n    )\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer.\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the 
configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n    # Load scheduler, image_processor, tokenizer and models.\n    noise_scheduler = DDPMScheduler(beta_schedule=\"squaredcos_cap_v2\", prediction_type=\"sample\")\n    image_processor = CLIPImageProcessor.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_processor\"\n    )\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"tokenizer\")\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_encoder\"\n    )\n    text_encoder = CLIPTextModelWithProjection.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"text_encoder\"\n    )\n    prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\n    # freeze 
parameters of models to save more memory\n    image_encoder.requires_grad_(False)\n    prior.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move image_encoder, text_encoder and prior to device and cast to weight_dtype\n    prior.to(accelerator.device, dtype=weight_dtype)\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    lora_attn_procs = {}\n    for name in prior.attn_processors.keys():\n        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=args.rank)\n\n    prior.set_attn_processor(lora_attn_procs)\n    lora_layers = AttnProcsLayers(prior.attn_processors)\n\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        lora_layers.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        text_input_ids = inputs.input_ids\n        text_mask = inputs.attention_mask.bool()\n        return text_input_ids, text_mask\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in 
examples[image_column]]\n        examples[\"clip_pixel_values\"] = image_processor(images, return_tensors=\"pt\").pixel_values\n        examples[\"text_input_ids\"], examples[\"text_mask\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        clip_pixel_values = torch.stack([example[\"clip_pixel_values\"] for example in examples])\n        clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()\n        text_input_ids = torch.stack([example[\"text_input_ids\"] for example in examples])\n        text_mask = torch.stack([example[\"text_mask\"] for example in examples])\n        return {\"clip_pixel_values\": clip_pixel_values, \"text_input_ids\": text_input_ids, \"text_mask\": text_mask}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n   
 clip_mean = prior.clip_mean.clone()\n    clip_std = prior.clip_std.clone()\n    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        lora_layers, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)\n    clip_std = clip_std.to(weight_dtype).to(accelerator.device)\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        prior.train()\n        train_loss = 0.0\n        for step, batch in 
enumerate(train_dataloader):\n            with accelerator.accumulate(prior):\n                # Unpack the tokenized captions and the CLIP image inputs\n                text_input_ids, text_mask, clip_images = (\n                    batch[\"text_input_ids\"],\n                    batch[\"text_mask\"],\n                    batch[\"clip_pixel_values\"].to(weight_dtype),\n                )\n                with torch.no_grad():\n                    text_encoder_output = text_encoder(text_input_ids)\n                    prompt_embeds = text_encoder_output.text_embeds\n                    text_encoder_hidden_states = text_encoder_output.last_hidden_state\n\n                    image_embeds = image_encoder(clip_images).image_embeds\n                    # Sample noise that we'll add to the image_embeds\n                    noise = torch.randn_like(image_embeds)\n                    bsz = image_embeds.shape[0]\n                    # Sample a random timestep for each image\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device\n                    )\n                    timesteps = timesteps.long()\n                    image_embeds = (image_embeds - clip_mean) / clip_std\n                    noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)\n\n                    target = image_embeds\n\n                # Predict the clean image embedding (prediction_type=\"sample\") and compute loss\n                model_pred = prior(\n                    noisy_latents,\n                    timestep=timesteps,\n                    proj_embedding=prompt_embeds,\n                    encoder_hidden_states=text_encoder_hidden_states,\n                    attention_mask=text_mask,\n                ).predicted_image_embedding\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights 
as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step 
% args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n   
     if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                logger.info(\n                    f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                # create pipeline\n                pipeline = AutoPipelineForText2Image.from_pretrained(\n                    args.pretrained_decoder_model_name_or_path,\n                    prior_prior=accelerator.unwrap_model(prior),\n                    torch_dtype=weight_dtype,\n                )\n                pipeline = pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = torch.Generator(device=accelerator.device)\n                if args.seed is not None:\n                    generator = generator.manual_seed(args.seed)\n                images = []\n                for _ in range(args.num_validation_images):\n                    images.append(\n                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]\n                    )\n\n                for tracker in accelerator.trackers:\n                    if tracker.name == \"tensorboard\":\n                        np_images = np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                \"validation\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                                ]\n                            }\n                        )\n\n           
     del pipeline\n                torch.cuda.empty_cache()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        prior = prior.to(torch.float32)\n        prior.save_attn_procs(args.output_dir)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_prior_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    # Final inference\n    # Load previous pipeline\n    pipeline = AutoPipelineForText2Image.from_pretrained(\n        args.pretrained_decoder_model_name_or_path, torch_dtype=weight_dtype\n    )\n    pipeline = pipeline.to(accelerator.device)\n\n    # load attention processors\n    pipeline.prior_prior.load_attn_procs(args.output_dir)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device)\n    if args.seed is not None:\n        generator = generator.manual_seed(args.seed)\n    images = []\n    for _ in range(args.num_validation_images):\n        images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])\n\n    if accelerator.is_main_process:\n        for tracker in accelerator.trackers:\n            if len(images) != 0:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    tracker.log(\n                        {\n                            \"test\": [\n                          
      wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n                            ]\n                        }\n                    )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom tqdm import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection\nfrom transformers.utils import ContextManagers\n\nimport diffusers\nfrom diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\n\n\nif is_wandb_available():\n    import wandb\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef save_model_card(\n    args,\n    repo_id: str,\n    images=None,\n    repo_folder=None,\n):\n    img_str = \"\"\n    if len(images) > 0:\n        image_grid = make_image_grid(images, 1, len(args.validation_prompts))\n        image_grid.save(os.path.join(repo_folder, \"val_imgs_grid.png\"))\n        img_str += \"![val_imgs_grid](./val_imgs_grid.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {args.pretrained_prior_model_name_or_path}\ndatasets:\n- {args.dataset_name}\ntags:\n- kandinsky\n- text-to-image\n- diffusers\n- diffusers-training\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# Finetuning - {repo_id}\n\nThis pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. 
Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \\n\n{img_str}\n\n## Pipeline usage\n\nYou can use the pipeline like so:\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe_prior = DiffusionPipeline.from_pretrained(\"{repo_id}\", torch_dtype=torch.float16)\npipe_t2i = DiffusionPipeline.from_pretrained(\"{args.pretrained_decoder_model_name_or_path}\", torch_dtype=torch.float16)\nprompt = \"{args.validation_prompts[0]}\"\nimage_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()\nimage = pipe_t2i(image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0]\nimage.save(\"my_image.png\")\n```\n\n## Training info\n\nThese are the key hyperparameters used during training:\n\n* Epochs: {args.num_train_epochs}\n* Learning rate: {args.learning_rate}\n* Batch size: {args.train_batch_size}\n* Gradient accumulation steps: {args.gradient_accumulation_steps}\n* Image resolution: {args.resolution}\n* Mixed-precision: {args.mixed_precision}\n\n\"\"\"\n    wandb_info = \"\"\n    # Initialize before the wandb check so the name is always bound.\n    wandb_run_url = None\n    if is_wandb_available():\n        if wandb.run is not None:\n            wandb_run_url = wandb.run.url\n\n    if wandb_run_url is not None:\n        wandb_info = f\"\"\"\nMore information on all the CLI arguments and the environment is available on your [`wandb` run page]({wandb_run_url}).\n\"\"\"\n\n    model_card += wandb_info\n\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef log_validation(\n    image_encoder, image_processor, text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch\n):\n    logger.info(\"Running validation... 
\")\n\n    pipeline = AutoPipelineForText2Image.from_pretrained(\n        args.pretrained_decoder_model_name_or_path,\n        prior_image_encoder=accelerator.unwrap_model(image_encoder),\n        prior_image_processor=image_processor,\n        prior_text_encoder=accelerator.unwrap_model(text_encoder),\n        prior_tokenizer=tokenizer,\n        prior_prior=accelerator.unwrap_model(prior),\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    images = []\n    for i in range(len(args.validation_prompts)):\n        with torch.autocast(\"cuda\"):\n            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompts[i]}\")\n                        for i, image in enumerate(images)\n                    ]\n                }\n            )\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of finetuning Kandinsky 2.2.\")\n    parser.add_argument(\n        \"--pretrained_decoder_model_name_or_path\",\n        type=str,\n        default=\"kandinsky-community/kandinsky-2-2-decoder\",\n        
required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_prior_model_name_or_path\",\n        type=str,\n        default=\"kandinsky-community/kandinsky-2-2-prior\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompts\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\"A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"kandi_2_2-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=1, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        
\"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"learning rate\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\n        \"--adam_weight_decay\",\n        type=float,\n        default=0.0,\n        required=False,\n        help=\"Weight decay to use.\",\n    )\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=5,\n        help=\"Run validation every X epochs.\",\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        
accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load scheduler, image_processor, tokenizer and models.\n    noise_scheduler = DDPMScheduler(beta_schedule=\"squaredcos_cap_v2\", prediction_type=\"sample\")\n    image_processor = CLIPImageProcessor.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"image_processor\"\n    )\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"tokenizer\")\n\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        returns either a context list that includes one that will disable zero.Init or an empty context list\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n        if deepspeed_plugin is 
None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n            args.pretrained_prior_model_name_or_path, subfolder=\"image_encoder\", torch_dtype=weight_dtype\n        ).eval()\n        text_encoder = CLIPTextModelWithProjection.from_pretrained(\n            args.pretrained_prior_model_name_or_path, subfolder=\"text_encoder\", torch_dtype=weight_dtype\n        ).eval()\n\n    prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\n\n    # Freeze text_encoder and image_encoder\n    text_encoder.requires_grad_(False)\n    image_encoder.requires_grad_(False)\n\n    # Set prior to trainable.\n    prior.train()\n\n    # Create EMA for the prior.\n    if args.use_ema:\n        ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\n        ema_prior = EMAModel(ema_prior.parameters(), model_cls=PriorTransformer, model_config=ema_prior.config)\n        ema_prior.to(accelerator.device)\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if args.use_ema:\n                ema_prior.save_pretrained(os.path.join(output_dir, \"prior_ema\"))\n\n            for i, model in enumerate(models):\n                model.save_pretrained(os.path.join(output_dir, \"prior\"))\n\n                # make 
sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"prior_ema\"), PriorTransformer)\n                ema_prior.load_state_dict(load_model.state_dict())\n                ema_prior.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = PriorTransformer.from_pretrained(input_dir, subfolder=\"prior\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n    optimizer = optimizer_cls(\n        prior.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        text_input_ids = inputs.input_ids\n        text_mask = inputs.attention_mask.bool()\n        return text_input_ids, text_mask\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in 
examples[image_column]]\n        examples[\"clip_pixel_values\"] = image_processor(images, return_tensors=\"pt\").pixel_values\n        examples[\"text_input_ids\"], examples[\"text_mask\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        clip_pixel_values = torch.stack([example[\"clip_pixel_values\"] for example in examples])\n        clip_pixel_values = clip_pixel_values.to(memory_format=torch.contiguous_format).float()\n        text_input_ids = torch.stack([example[\"text_input_ids\"] for example in examples])\n        text_mask = torch.stack([example[\"text_mask\"] for example in examples])\n        return {\"clip_pixel_values\": clip_pixel_values, \"text_input_ids\": text_input_ids, \"text_mask\": text_mask}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n 
   clip_mean = prior.clip_mean.clone()\n    clip_std = prior.clip_std.clone()\n\n    prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        prior, optimizer, train_dataloader, lr_scheduler\n    )\n\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        tracker_config.pop(\"validation_prompts\")\n        accelerator.init_trackers(args.tracker_project_name, tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)\n    clip_std = clip_std.to(weight_dtype).to(accelerator.device)\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            
with accelerator.accumulate(prior):\n                # Unpack the batch and cast the CLIP pixel values to the training dtype\n                text_input_ids, text_mask, clip_images = (\n                    batch[\"text_input_ids\"],\n                    batch[\"text_mask\"],\n                    batch[\"clip_pixel_values\"].to(weight_dtype),\n                )\n                with torch.no_grad():\n                    text_encoder_output = text_encoder(text_input_ids)\n                    prompt_embeds = text_encoder_output.text_embeds\n                    text_encoder_hidden_states = text_encoder_output.last_hidden_state\n\n                    image_embeds = image_encoder(clip_images).image_embeds\n                    # Sample noise that we'll add to the image_embeds\n                    noise = torch.randn_like(image_embeds)\n                    bsz = image_embeds.shape[0]\n                    # Sample a random timestep for each image\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embeds.device\n                    )\n                    timesteps = timesteps.long()\n                    # Normalize the image embeddings with the prior's CLIP statistics before adding noise\n                    image_embeds = (image_embeds - clip_mean) / clip_std\n                    noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)\n\n                    # The prior is trained to predict the clean, normalized image embedding\n                    target = image_embeds\n\n                # Predict the image embedding and compute loss\n                model_pred = prior(\n                    noisy_latents,\n                    timestep=timesteps,\n                    proj_embedding=prompt_embeds,\n                    encoder_hidden_states=text_encoder_hidden_states,\n                    attention_mask=text_mask,\n                ).predicted_image_embedding\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of 
https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_prior.step(prior.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n     
           train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if 
global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_prior.store(prior.parameters())\n                    ema_prior.copy_to(prior.parameters())\n                log_validation(\n                    image_encoder,\n                    image_processor,\n                    text_encoder,\n                    tokenizer,\n                    prior,\n                    args,\n                    accelerator,\n                    weight_dtype,\n                    global_step,\n                )\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_prior.restore(prior.parameters())\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        prior = accelerator.unwrap_model(prior)\n        if args.use_ema:\n            ema_prior.copy_to(prior.parameters())\n\n        pipeline = AutoPipelineForText2Image.from_pretrained(\n            args.pretrained_decoder_model_name_or_path,\n            prior_image_encoder=image_encoder,\n            prior_text_encoder=text_encoder,\n            prior_prior=prior,\n        )\n        pipeline.prior_pipe.save_pretrained(args.output_dir)\n\n        # Run a final round of inference.\n        images = []\n        if args.validation_prompts is not None:\n            logger.info(\"Running inference for collecting generated images...\")\n            pipeline = pipeline.to(accelerator.device)\n            pipeline.torch_dtype = weight_dtype\n            pipeline.set_progress_bar_config(disable=True)\n\n            if args.seed is None:\n                generator = None\n       
     else:\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n            for i in range(len(args.validation_prompts)):\n                with torch.autocast(\"cuda\"):\n                    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n                images.append(image)\n\n        if args.push_to_hub:\n            save_model_card(args, repo_id, images, repo_folder=args.output_dir)\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/model_search/README.md",
    "content": "# Search models on Civitai and Hugging Face\n\nThe [auto_diffusers](https://github.com/suzukimain/auto_diffusers) library adds functionality to Diffusers, such as searching for models on Civitai and the Hugging Face Hub.\nFor details, see the original library [on PyPI](https://pypi.org/project/auto-diffusers/).\n\n## Installation\n\nBefore running the scripts, make sure to install the library's dependencies:\n\n> [!IMPORTANT]\n> To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment.\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\nSet up the pipeline by downloading `pipeline_easy.py`. You can also `cd` to this folder and run it.\n```bash\nwget https://raw.githubusercontent.com/suzukimain/auto_diffusers/refs/heads/master/src/auto_diffusers/pipeline_easy.py\n```\n\n## Load from Civitai\n```python\nfrom pipeline_easy import (\n    EasyPipelineForText2Image,\n    EasyPipelineForImage2Image,\n    EasyPipelineForInpainting,\n)\n\n# Text-to-Image\npipeline = EasyPipelineForText2Image.from_civitai(\n    \"search_word\",\n    base_model=\"SD 1.5\",\n).to(\"cuda\")\n\n\n# Image-to-Image\npipeline = EasyPipelineForImage2Image.from_civitai(\n    \"search_word\",\n    base_model=\"SD 1.5\",\n).to(\"cuda\")\n\n\n# Inpainting\npipeline = EasyPipelineForInpainting.from_civitai(\n    \"search_word\",\n    base_model=\"SD 1.5\",\n).to(\"cuda\")\n```\n\n## Load from Hugging Face\n```python\nfrom pipeline_easy import (\n    EasyPipelineForText2Image,\n    EasyPipelineForImage2Image,\n    EasyPipelineForInpainting,\n)\n\n# Text-to-Image\npipeline = EasyPipelineForText2Image.from_huggingface(\n    \"search_word\",\n    
checkpoint_format=\"diffusers\",\n).to(\"cuda\")\n\n\n# Image-to-Image\npipeline = EasyPipelineForImage2Image.from_huggingface(\n    \"search_word\",\n    checkpoint_format=\"diffusers\",\n).to(\"cuda\")\n\n\n# Inpainting\npipeline = EasyPipelineForInpainting.from_huggingface(\n    \"search_word\",\n    checkpoint_format=\"diffusers\",\n).to(\"cuda\")\n```\n\n\n## Search Civitai and Hugging Face\n\n```python\n# Load LoRA weights into the pipeline.\npipeline.auto_load_lora_weights(\"Detail Tweaker\")\n\n# Load a textual inversion embedding into the pipeline.\npipeline.auto_load_textual_inversion(\"EasyNegative\", token=\"EasyNegative\")\n```\n\n### Search Civitai\n\n> [!TIP]\n> **If an error occurs, pass your `token` and run again.**\n\n#### `EasyPipeline.from_civitai` parameters\n\n| Name            | Type                   | Default       | Description                                                                    |\n|:---------------:|:----------------------:|:-------------:|:-----------------------------------------------------------------------------------:|\n| search_word     | string, Path           | ー            | The search query string. Can be a keyword, Civitai URL, local directory or file path. |\n| model_type      | string                 | `Checkpoint`  | The type of model to search for. <br>(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`)      |\n| base_model      | string                 | None          | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) |\n| torch_dtype     | string, torch.dtype    | None          | Override the default `torch.dtype` and load the model with another dtype.     |\n| force_download  | bool                   | False         | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |\n| cache_dir       | string, Path           | None          | Path to the folder where cached files are stored. 
|\n| resume          | bool   | False         | Whether to resume an incomplete download. |\n| token           | string | None          | API token for Civitai authentication. |\n\n\n#### `search_civitai` parameters\n\n| Name            | Type           | Default       | Description                                                                    |\n|:---------------:|:--------------:|:-------------:|:-----------------------------------------------------------------------------------:|\n| search_word     | string, Path   | ー            | The search query string. Can be a keyword, Civitai URL, local directory or file path. |\n| model_type      | string         | `Checkpoint`  | The type of model to search for. <br>(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`)   |\n| base_model      | string         | None          | Trained model tag (for example  `SD 1.5`, `SD 3.5`, `SDXL 1.0`)                        |\n| download        | bool           | False         | Whether to download the model.                                   |\n| force_download  | bool           | False         | Whether to force the download if the model already exists.                          |\n| cache_dir       | string, Path   | None          | Path to the folder where cached files are stored.                              |\n| resume          | bool           | False         | Whether to resume an incomplete download.                                           |\n| token           | string         | None          | API token for Civitai authentication.                                               |\n| include_params  | bool           | False         | Whether to include parameters in the returned data.           |\n| skip_error      | bool           | False         | Whether to skip errors and return None.                                             
|\n\n### Search Hugging Face\n\n> [!TIP]\n> **If an error occurs, pass your `token` and run again.**\n\n#### `EasyPipeline.from_huggingface` parameters\n\n| Name                  | Type                | Default        | Description                                                      |\n|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:|\n| search_word           | string, Path        | ー             | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`<creator>/<repo>`). |\n| checkpoint_format     | string              | `single_file`  | The format of the model checkpoint.<br>● `single_file` to search for `single file checkpoint` <br>● `diffusers` to search for `multifolder diffusers format checkpoint` |\n| torch_dtype           | string, torch.dtype | None           | Override the default `torch.dtype` and load the model with another dtype. |\n| force_download        | bool                | False          | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |\n| cache_dir             | string, Path        | None           | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used.   |\n| token                 | string, bool        | None           | The token to use as HTTP bearer authorization for remote files.  |\n\n\n#### `search_huggingface` parameters\n\n| Name                  | Type                | Default        | Description                                                      |\n|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:|\n| search_word           | string, Path        | ー             | The search query string. 
Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`<creator>/<repo>`). |\n| checkpoint_format     | string              | `single_file`  | The format of the model checkpoint. <br>● `single_file` to search for `single file checkpoint` <br>● `diffusers` to search for `multifolder diffusers format checkpoint` |\n| pipeline_tag          | string              | None           | Tag to filter models by pipeline.                                |\n| download              | bool                | False          | Whether to download the model.                                   |\n| force_download        | bool                | False          | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |\n| cache_dir             | string, Path        | None           | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used.   |\n| token                 | string, bool        | None           | The token to use as HTTP bearer authorization for remote files.  |\n| include_params        | bool                | False          | Whether to include parameters in the returned data.               |\n| skip_error            | bool                | False          | Whether to skip errors and return None.                           |\n"
  },
  {
    "path": "examples/model_search/pipeline_easy.py",
    "content": "# coding=utf-8\n# Copyright 2025 suzukimain\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport re\nimport types\nfrom collections import OrderedDict\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Dict, List, Optional, Union\n\nimport requests\nimport torch\nfrom huggingface_hub import hf_api, hf_hub_download\nfrom huggingface_hub.file_download import http_get\nfrom huggingface_hub.utils import validate_hf_hub_args\n\nfrom diffusers.loaders.single_file_utils import (\n    VALID_URL_PREFIXES,\n    _extract_repo_id_and_weights_name,\n    infer_diffusers_model_type,\n    load_single_file_checkpoint,\n)\nfrom diffusers.pipelines.animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline\nfrom diffusers.pipelines.auto_pipeline import (\n    AutoPipelineForImage2Image,\n    AutoPipelineForInpainting,\n    AutoPipelineForText2Image,\n)\nfrom diffusers.pipelines.controlnet import (\n    StableDiffusionControlNetImg2ImgPipeline,\n    StableDiffusionControlNetInpaintPipeline,\n    StableDiffusionControlNetPipeline,\n    StableDiffusionXLControlNetImg2ImgPipeline,\n    StableDiffusionXLControlNetPipeline,\n)\nfrom diffusers.pipelines.flux import FluxImg2ImgPipeline, FluxPipeline\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import (\n    StableDiffusionImg2ImgPipeline,\n    StableDiffusionInpaintPipeline,\n    StableDiffusionPipeline,\n    
StableDiffusionUpscalePipeline,\n)\nfrom diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline\nfrom diffusers.pipelines.stable_diffusion_xl import (\n    StableDiffusionXLImg2ImgPipeline,\n    StableDiffusionXLInpaintPipeline,\n    StableDiffusionXLPipeline,\n)\nfrom diffusers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nSINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING = OrderedDict(\n    [\n        (\"animatediff_rgb\", AnimateDiffPipeline),\n        (\"animatediff_scribble\", AnimateDiffPipeline),\n        (\"animatediff_sdxl_beta\", AnimateDiffSDXLPipeline),\n        (\"animatediff_v1\", AnimateDiffPipeline),\n        (\"animatediff_v2\", AnimateDiffPipeline),\n        (\"animatediff_v3\", AnimateDiffPipeline),\n        (\"autoencoder-dc-f128c512\", None),\n        (\"autoencoder-dc-f32c32\", None),\n        (\"autoencoder-dc-f32c32-sana\", None),\n        (\"autoencoder-dc-f64c128\", None),\n        (\"controlnet\", StableDiffusionControlNetPipeline),\n        (\"controlnet_xl\", StableDiffusionXLControlNetPipeline),\n        (\"controlnet_xl_large\", StableDiffusionXLControlNetPipeline),\n        (\"controlnet_xl_mid\", StableDiffusionXLControlNetPipeline),\n        (\"controlnet_xl_small\", StableDiffusionXLControlNetPipeline),\n        (\"flux-depth\", FluxPipeline),\n        (\"flux-dev\", FluxPipeline),\n        (\"flux-fill\", FluxPipeline),\n        (\"flux-schnell\", FluxPipeline),\n        (\"hunyuan-video\", None),\n        (\"inpainting\", None),\n        (\"inpainting_v2\", None),\n        (\"ltx-video\", None),\n        (\"ltx-video-0.9.1\", None),\n        (\"mochi-1-preview\", None),\n        (\"playground-v2-5\", StableDiffusionXLPipeline),\n        (\"sd3\", StableDiffusion3Pipeline),\n        (\"sd35_large\", StableDiffusion3Pipeline),\n        (\"sd35_medium\", StableDiffusion3Pipeline),\n        (\"stable_cascade_stage_b\", None),\n        
(\"stable_cascade_stage_b_lite\", None),\n        (\"stable_cascade_stage_c\", None),\n        (\"stable_cascade_stage_c_lite\", None),\n        (\"upscale\", StableDiffusionUpscalePipeline),\n        (\"v1\", StableDiffusionPipeline),\n        (\"v2\", StableDiffusionPipeline),\n        (\"xl_base\", StableDiffusionXLPipeline),\n        (\"xl_inpaint\", None),\n        (\"xl_refiner\", StableDiffusionXLPipeline),\n    ]\n)\n\nSINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING = OrderedDict(\n    [\n        (\"animatediff_rgb\", AnimateDiffPipeline),\n        (\"animatediff_scribble\", AnimateDiffPipeline),\n        (\"animatediff_sdxl_beta\", AnimateDiffSDXLPipeline),\n        (\"animatediff_v1\", AnimateDiffPipeline),\n        (\"animatediff_v2\", AnimateDiffPipeline),\n        (\"animatediff_v3\", AnimateDiffPipeline),\n        (\"autoencoder-dc-f128c512\", None),\n        (\"autoencoder-dc-f32c32\", None),\n        (\"autoencoder-dc-f32c32-sana\", None),\n        (\"autoencoder-dc-f64c128\", None),\n        (\"controlnet\", StableDiffusionControlNetImg2ImgPipeline),\n        (\"controlnet_xl\", StableDiffusionXLControlNetImg2ImgPipeline),\n        (\"controlnet_xl_large\", StableDiffusionXLControlNetImg2ImgPipeline),\n        (\"controlnet_xl_mid\", StableDiffusionXLControlNetImg2ImgPipeline),\n        (\"controlnet_xl_small\", StableDiffusionXLControlNetImg2ImgPipeline),\n        (\"flux-depth\", FluxImg2ImgPipeline),\n        (\"flux-dev\", FluxImg2ImgPipeline),\n        (\"flux-fill\", FluxImg2ImgPipeline),\n        (\"flux-schnell\", FluxImg2ImgPipeline),\n        (\"hunyuan-video\", None),\n        (\"inpainting\", None),\n        (\"inpainting_v2\", None),\n        (\"ltx-video\", None),\n        (\"ltx-video-0.9.1\", None),\n        (\"mochi-1-preview\", None),\n        (\"playground-v2-5\", StableDiffusionXLImg2ImgPipeline),\n        (\"sd3\", StableDiffusion3Img2ImgPipeline),\n        (\"sd35_large\", StableDiffusion3Img2ImgPipeline),\n        
(\"sd35_medium\", StableDiffusion3Img2ImgPipeline),\n        (\"stable_cascade_stage_b\", None),\n        (\"stable_cascade_stage_b_lite\", None),\n        (\"stable_cascade_stage_c\", None),\n        (\"stable_cascade_stage_c_lite\", None),\n        (\"upscale\", StableDiffusionUpscalePipeline),\n        (\"v1\", StableDiffusionImg2ImgPipeline),\n        (\"v2\", StableDiffusionImg2ImgPipeline),\n        (\"xl_base\", StableDiffusionXLImg2ImgPipeline),\n        (\"xl_inpaint\", None),\n        (\"xl_refiner\", StableDiffusionXLImg2ImgPipeline),\n    ]\n)\n\nSINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING = OrderedDict(\n    [\n        (\"animatediff_rgb\", None),\n        (\"animatediff_scribble\", None),\n        (\"animatediff_sdxl_beta\", None),\n        (\"animatediff_v1\", None),\n        (\"animatediff_v2\", None),\n        (\"animatediff_v3\", None),\n        (\"autoencoder-dc-f128c512\", None),\n        (\"autoencoder-dc-f32c32\", None),\n        (\"autoencoder-dc-f32c32-sana\", None),\n        (\"autoencoder-dc-f64c128\", None),\n        (\"controlnet\", StableDiffusionControlNetInpaintPipeline),\n        (\"controlnet_xl\", None),\n        (\"controlnet_xl_large\", None),\n        (\"controlnet_xl_mid\", None),\n        (\"controlnet_xl_small\", None),\n        (\"flux-depth\", None),\n        (\"flux-dev\", None),\n        (\"flux-fill\", None),\n        (\"flux-schnell\", None),\n        (\"hunyuan-video\", None),\n        (\"inpainting\", StableDiffusionInpaintPipeline),\n        (\"inpainting_v2\", StableDiffusionInpaintPipeline),\n        (\"ltx-video\", None),\n        (\"ltx-video-0.9.1\", None),\n        (\"mochi-1-preview\", None),\n        (\"playground-v2-5\", None),\n        (\"sd3\", None),\n        (\"sd35_large\", None),\n        (\"sd35_medium\", None),\n        (\"stable_cascade_stage_b\", None),\n        (\"stable_cascade_stage_b_lite\", None),\n        (\"stable_cascade_stage_c\", None),\n        (\"stable_cascade_stage_c_lite\", 
None),\n        (\"upscale\", StableDiffusionUpscalePipeline),\n        (\"v1\", None),\n        (\"v2\", None),\n        (\"xl_base\", None),\n        (\"xl_inpaint\", StableDiffusionXLInpaintPipeline),\n        (\"xl_refiner\", None),\n    ]\n)\n\n\nCONFIG_FILE_LIST = [\n    \"pytorch_model.bin\",\n    \"pytorch_model.fp16.bin\",\n    \"diffusion_pytorch_model.bin\",\n    \"diffusion_pytorch_model.fp16.bin\",\n    \"diffusion_pytorch_model.safetensors\",\n    \"diffusion_pytorch_model.fp16.safetensors\",\n    \"diffusion_pytorch_model.ckpt\",\n    \"diffusion_pytorch_model.fp16.ckpt\",\n    \"diffusion_pytorch_model.non_ema.bin\",\n    \"diffusion_pytorch_model.non_ema.safetensors\",\n]\n\nDIFFUSERS_CONFIG_DIR = [\n    \"safety_checker\",\n    \"unet\",\n    \"vae\",\n    \"text_encoder\",\n    \"text_encoder_2\",\n]\n\nTOKENIZER_SHAPE_MAP = {\n    768: [\n        \"SD 1.4\",\n        \"SD 1.5\",\n        \"SD 1.5 LCM\",\n        \"SDXL 0.9\",\n        \"SDXL 1.0\",\n        \"SDXL 1.0 LCM\",\n        \"SDXL Distilled\",\n        \"SDXL Turbo\",\n        \"SDXL Lightning\",\n        \"PixArt a\",\n        \"Playground v2\",\n        \"Pony\",\n    ],\n    1024: [\"SD 2.0\", \"SD 2.0 768\", \"SD 2.1\", \"SD 2.1 768\", \"SD 2.1 Unclip\"],\n}\n\n\nEXTENSION = [\".safetensors\", \".ckpt\", \".bin\"]\n\nCACHE_HOME = os.path.expanduser(\"~/.cache\")\n\n\n@dataclass\nclass RepoStatus:\n    r\"\"\"\n    Data class for storing repository status information.\n\n    Attributes:\n        repo_id (`str`):\n            The name of the repository.\n        repo_hash (`str`):\n            The hash of the repository.\n        version (`str`):\n            The version ID of the repository.\n    \"\"\"\n\n    repo_id: str = \"\"\n    repo_hash: str = \"\"\n    version: str = \"\"\n\n\n@dataclass\nclass ModelStatus:\n    r\"\"\"\n    Data class for storing model status information.\n\n    Attributes:\n        search_word (`str`):\n            The search word used to find the 
model.\n        download_url (`str`):\n            The URL to download the model.\n        file_name (`str`):\n            The name of the model file.\n        local (`bool`):\n            Whether the model exists locally.\n        site_url (`str`):\n            The URL of the site where the model is hosted.\n    \"\"\"\n\n    search_word: str = \"\"\n    download_url: str = \"\"\n    file_name: str = \"\"\n    local: bool = False\n    site_url: str = \"\"\n\n\n@dataclass\nclass ExtraStatus:\n    r\"\"\"\n    Data class for storing extra status information.\n\n    Attributes:\n        trained_words (`List[str]`, *optional*):\n            The words used to trigger the model.\n    \"\"\"\n\n    trained_words: Union[List[str], None] = None\n\n\n@dataclass\nclass SearchResult:\n    r\"\"\"\n    Data class for storing model data.\n\n    Attributes:\n        model_path (`str`):\n            The path to the model.\n        loading_method (`str`, *optional*):\n            The loading method used for the model (`None`, `from_single_file`, or `from_pretrained`).\n        checkpoint_format (`str`, *optional*):\n            The format of the model checkpoint (`single_file` or `diffusers`).\n        repo_status (`RepoStatus`):\n            The status of the repository.\n        model_status (`ModelStatus`):\n            The status of the model.\n        extra_status (`ExtraStatus`):\n            Extra status information, such as the trained words.\n    \"\"\"\n\n    model_path: str = \"\"\n    loading_method: Union[str, None] = None\n    checkpoint_format: Union[str, None] = None\n    repo_status: RepoStatus = field(default_factory=RepoStatus)\n    model_status: ModelStatus = field(default_factory=ModelStatus)\n    extra_status: ExtraStatus = field(default_factory=ExtraStatus)\n\n\n@validate_hf_hub_args\ndef load_pipeline_from_single_file(pretrained_model_or_path, pipeline_mapping, **kwargs):\n    r\"\"\"\n    Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`\n    format. 
The pipeline is set in evaluation mode (`model.eval()`) by default.\n\n    Parameters:\n        pretrained_model_or_path (`str` or `os.PathLike`, *optional*):\n            Can be either:\n                - A link to the `.ckpt` file (for example\n                  `\"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt\"`) on the Hub.\n                - A path to a *file* containing all pipeline weights.\n        pipeline_mapping (`dict`):\n            A mapping of model types to their corresponding pipeline classes. This is used to determine\n            which pipeline class to instantiate based on the model type inferred from the checkpoint.\n        torch_dtype (`str` or `torch.dtype`, *optional*):\n            Override the default `torch.dtype` and load the model with another dtype.\n        force_download (`bool`, *optional*, defaults to `False`):\n            Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n            cached versions if they exist.\n        cache_dir (`Union[str, os.PathLike]`, *optional*):\n            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n            is not used.\n        proxies (`Dict[str, str]`, *optional*):\n            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n        local_files_only (`bool`, *optional*, defaults to `False`):\n            Whether to only load local model weights and configuration files or not. If set to `True`, the model\n            won't be downloaded from the Hub.\n        token (`str` or *bool*, *optional*):\n            The token to use as HTTP bearer authorization for remote files. 
If `True`, the token generated from\n            `diffusers-cli login` (stored in `~/.huggingface`) is used.\n        revision (`str`, *optional*, defaults to `\"main\"`):\n            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n            allowed by Git.\n        original_config_file (`str`, *optional*):\n            The path to the original config file that was used to train the model. If not provided, the config file\n            will be inferred from the checkpoint file.\n        config (`str`, *optional*):\n            Can be either:\n                - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline\n                  hosted on the Hub.\n                - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline\n                  component configs in Diffusers format.\n        checkpoint (`dict`, *optional*):\n            The loaded state dictionary of the model.\n        kwargs (remaining dictionary of keyword arguments, *optional*):\n            Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline\n            class). The overwritten components are passed directly to the pipeline's `__init__` method. 
See example\n            below for more information.\n    \"\"\"\n\n    # Load the checkpoint from the provided link or path\n    checkpoint = load_single_file_checkpoint(pretrained_model_or_path)\n\n    # Infer the model type from the loaded checkpoint\n    model_type = infer_diffusers_model_type(checkpoint)\n\n    # Get the corresponding pipeline class from the pipeline mapping\n    pipeline_class = pipeline_mapping[model_type]\n\n    # For tasks not supported by this pipeline\n    if pipeline_class is None:\n        raise ValueError(\n            f\"{model_type} is not supported in this pipeline. \"\n            \"For `Text2Image`, please use `AutoPipelineForText2Image.from_pretrained`, \"\n            \"for `Image2Image`, please use `AutoPipelineForImage2Image.from_pretrained`, \"\n            \"and `inpaint` is only supported in `AutoPipelineForInpainting.from_pretrained`\"\n        )\n\n    else:\n        # Instantiate and return the pipeline with the loaded checkpoint and any additional kwargs\n        return pipeline_class.from_single_file(pretrained_model_or_path, **kwargs)\n\n\ndef get_keyword_types(keyword):\n    r\"\"\"\n    Determine the type and loading method for a given keyword.\n\n    Parameters:\n        keyword (`str`):\n            The input keyword to classify.\n\n    Returns:\n        `dict`: A dictionary containing the model format, loading method,\n                and various types and extra types flags.\n    \"\"\"\n\n    # Initialize the status dictionary with default values\n    status = {\n        \"checkpoint_format\": None,\n        \"loading_method\": None,\n        \"type\": {\n            \"other\": False,\n            \"hf_url\": False,\n            \"hf_repo\": False,\n            \"civitai_url\": False,\n            \"local\": False,\n        },\n        \"extra_type\": {\n            \"url\": False,\n            \"missing_model_index\": None,\n        },\n    }\n\n    # Check if the keyword is an HTTP or HTTPS URL\n    
status[\"extra_type\"][\"url\"] = bool(re.search(r\"^(https?)://\", keyword))\n\n    # Check if the keyword is a file\n    if os.path.isfile(keyword):\n        status[\"type\"][\"local\"] = True\n        status[\"checkpoint_format\"] = \"single_file\"\n        status[\"loading_method\"] = \"from_single_file\"\n\n    # Check if the keyword is a directory\n    elif os.path.isdir(keyword):\n        status[\"type\"][\"local\"] = True\n        status[\"checkpoint_format\"] = \"diffusers\"\n        status[\"loading_method\"] = \"from_pretrained\"\n        if not os.path.exists(os.path.join(keyword, \"model_index.json\")):\n            status[\"extra_type\"][\"missing_model_index\"] = True\n\n    # Check if the keyword is a Civitai URL\n    elif keyword.startswith(\"https://civitai.com/\"):\n        status[\"type\"][\"civitai_url\"] = True\n        status[\"checkpoint_format\"] = \"single_file\"\n        status[\"loading_method\"] = None\n\n    # Check if the keyword starts with any valid URL prefixes\n    elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES):\n        repo_id, weights_name = _extract_repo_id_and_weights_name(keyword)\n        if weights_name:\n            status[\"type\"][\"hf_url\"] = True\n            status[\"checkpoint_format\"] = \"single_file\"\n            status[\"loading_method\"] = \"from_single_file\"\n        else:\n            status[\"type\"][\"hf_repo\"] = True\n            status[\"checkpoint_format\"] = \"diffusers\"\n            status[\"loading_method\"] = \"from_pretrained\"\n\n    # Check if the keyword matches a Hugging Face repository format\n    elif re.match(r\"^[^/]+/[^/]+$\", keyword):\n        status[\"type\"][\"hf_repo\"] = True\n        status[\"checkpoint_format\"] = \"diffusers\"\n        status[\"loading_method\"] = \"from_pretrained\"\n\n    # If none of the above apply\n    else:\n        status[\"type\"][\"other\"] = True\n        status[\"checkpoint_format\"] = None\n        status[\"loading_method\"] 
= None\n\n    return status\n\n\ndef file_downloader(\n    url,\n    save_path,\n    **kwargs,\n) -> None:\n    \"\"\"\n    Downloads a file from a given URL and saves it to the specified path.\n\n    parameters:\n        url (`str`):\n            The URL of the file to download.\n        save_path (`str`):\n            The local path where the file will be saved.\n        resume (`bool`, *optional*, defaults to `False`):\n            Whether to resume an incomplete download.\n        headers (`dict`, *optional*, defaults to `None`):\n            Dictionary of HTTP Headers to send with the request.\n        proxies (`dict`, *optional*, defaults to `None`):\n            Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.\n        force_download (`bool`, *optional*, defaults to `False`):\n            Whether to force the download even if the file already exists.\n        displayed_filename (`str`, *optional*):\n            The filename of the file that is being downloaded. Value is used only to display a nice progress bar. 
If\n            not set, the filename is guessed from the URL or the `Content-Disposition` header.\n\n    returns:\n        None\n    \"\"\"\n\n    # Get optional parameters from kwargs, with their default values\n    resume = kwargs.pop(\"resume\", False)\n    headers = kwargs.pop(\"headers\", None)\n    proxies = kwargs.pop(\"proxies\", None)\n    force_download = kwargs.pop(\"force_download\", False)\n    displayed_filename = kwargs.pop(\"displayed_filename\", None)\n\n    # Default mode for file writing and initial file size\n    mode = \"wb\"\n    file_size = 0\n\n    # Create directory\n    os.makedirs(os.path.dirname(save_path), exist_ok=True)\n\n    # Check if the file already exists at the save path\n    if os.path.exists(save_path):\n        if not force_download:\n            # If the file exists and force_download is False, skip the download\n            logger.info(f\"File already exists: {save_path}, skipping download.\")\n            return None\n        elif resume:\n            # If resuming, set mode to append binary and get current file size\n            mode = \"ab\"\n            file_size = os.path.getsize(save_path)\n\n    # Open the file in the appropriate mode (write or append)\n    with open(save_path, mode) as model_file:\n        # Call the http_get function to perform the file download\n        return http_get(\n            url=url,\n            temp_file=model_file,\n            resume_size=file_size,\n            displayed_filename=displayed_filename,\n            headers=headers,\n            proxies=proxies,\n            **kwargs,\n        )\n\n\ndef search_huggingface(search_word: str, **kwargs) -> Union[str, SearchResult, None]:\n    r\"\"\"\n    Downloads a model from Hugging Face.\n\n    Parameters:\n        search_word (`str`):\n            The search query string.\n        revision (`str`, *optional*):\n            The specific version of the model to download.\n        checkpoint_format (`str`, *optional*, defaults to 
`\"single_file\"`):\n            The format of the model checkpoint.\n        download (`bool`, *optional*, defaults to `False`):\n            Whether to download the model.\n        force_download (`bool`, *optional*, defaults to `False`):\n            Whether to force the download if the model already exists.\n        include_params (`bool`, *optional*, defaults to `False`):\n            Whether to include parameters in the returned data.\n        pipeline_tag (`str`, *optional*):\n            Tag to filter models by pipeline.\n        token (`str`, *optional*):\n            API token for Hugging Face authentication.\n        gated (`bool`, *optional*, defaults to `False` ):\n            A boolean to filter models on the Hub that are gated or not.\n        skip_error (`bool`, *optional*, defaults to `False`):\n            Whether to skip errors and return None.\n\n    Returns:\n        `Union[str,  SearchResult, None]`: The model path or  SearchResult or None.\n    \"\"\"\n    # Extract additional parameters from kwargs\n    revision = kwargs.pop(\"revision\", None)\n    checkpoint_format = kwargs.pop(\"checkpoint_format\", \"single_file\")\n    download = kwargs.pop(\"download\", False)\n    force_download = kwargs.pop(\"force_download\", False)\n    include_params = kwargs.pop(\"include_params\", False)\n    pipeline_tag = kwargs.pop(\"pipeline_tag\", None)\n    token = kwargs.pop(\"token\", None)\n    gated = kwargs.pop(\"gated\", False)\n    skip_error = kwargs.pop(\"skip_error\", False)\n\n    file_list = []\n    hf_repo_info = {}\n    hf_security_info = {}\n    model_path = \"\"\n    repo_id, file_name = \"\", \"\"\n    diffusers_model_exists = False\n\n    # Get the type and loading method for the keyword\n    search_word_status = get_keyword_types(search_word)\n\n    if search_word_status[\"type\"][\"hf_repo\"]:\n        hf_repo_info = hf_api.model_info(repo_id=search_word, securityStatus=True)\n        if download:\n            model_path = 
DiffusionPipeline.download(\n                search_word,\n                revision=revision,\n                token=token,\n                force_download=force_download,\n                **kwargs,\n            )\n        else:\n            model_path = search_word\n    elif search_word_status[\"type\"][\"hf_url\"]:\n        repo_id, weights_name = _extract_repo_id_and_weights_name(search_word)\n        if download:\n            model_path = hf_hub_download(\n                repo_id=repo_id,\n                filename=weights_name,\n                force_download=force_download,\n                token=token,\n            )\n        else:\n            model_path = search_word\n    elif search_word_status[\"type\"][\"local\"]:\n        model_path = search_word\n    elif search_word_status[\"type\"][\"civitai_url\"]:\n        if skip_error:\n            return None\n        else:\n            raise ValueError(\"The URL for Civitai is invalid with `for_hf`. Please use `for_civitai` instead.\")\n    else:\n        # Get model data from HF API\n        hf_models = hf_api.list_models(\n            search=search_word,\n            direction=-1,\n            limit=100,\n            fetch_config=True,\n            pipeline_tag=pipeline_tag,\n            full=True,\n            gated=gated,\n            token=token,\n        )\n        model_dicts = [asdict(value) for value in list(hf_models)]\n\n        # Loop through models to find a suitable candidate\n        for repo_info in model_dicts:\n            repo_id = repo_info[\"id\"]\n            file_list = []\n            hf_repo_info = hf_api.model_info(repo_id=repo_id, securityStatus=True)\n            # Lists files with security issues.\n            hf_security_info = hf_repo_info.security_repo_status\n            exclusion = [issue[\"path\"] for issue in hf_security_info[\"filesWithIssues\"]]\n\n            # Checks for multi-folder diffusers model or valid files (models with security issues are excluded).\n            
if hf_security_info[\"scansDone\"]:\n                for info in repo_info[\"siblings\"]:\n                    file_path = info[\"rfilename\"]\n                    if \"model_index.json\" == file_path and checkpoint_format in [\n                        \"diffusers\",\n                        \"all\",\n                    ]:\n                        diffusers_model_exists = True\n                        break\n\n                    elif (\n                        any(file_path.endswith(ext) for ext in EXTENSION)\n                        and not any(config in file_path for config in CONFIG_FILE_LIST)\n                        and not any(exc in file_path for exc in exclusion)\n                        and os.path.basename(os.path.dirname(file_path)) not in DIFFUSERS_CONFIG_DIR\n                    ):\n                        file_list.append(file_path)\n\n            # Exit from the loop if a multi-folder diffusers model or valid file is found\n            if diffusers_model_exists or file_list:\n                break\n        else:\n            # Handle case where no models match the criteria\n            if skip_error:\n                return None\n            else:\n                raise ValueError(\"No models matching your criteria were found on huggingface.\")\n\n        if diffusers_model_exists:\n            if download:\n                model_path = DiffusionPipeline.download(\n                    repo_id,\n                    token=token,\n                    **kwargs,\n                )\n            else:\n                model_path = repo_id\n\n        elif file_list:\n            # Sort and find the safest model\n            file_name = next(\n                (model for model in sorted(file_list, reverse=True) if re.search(r\"(?i)[-_](safe|sfw)\", model)),\n                file_list[0],\n            )\n\n            if download:\n                model_path = hf_hub_download(\n                    repo_id=repo_id,\n                    filename=file_name,\n    
                revision=revision,\n                    token=token,\n                    force_download=force_download,\n                )\n\n    # `pathlib.PosixPath` may be returned\n    if model_path:\n        model_path = str(model_path)\n\n    if file_name:\n        download_url = f\"https://huggingface.co/{repo_id}/blob/main/{file_name}\"\n    else:\n        download_url = f\"https://huggingface.co/{repo_id}\"\n\n    output_info = get_keyword_types(model_path)\n\n    if include_params:\n        return SearchResult(\n            model_path=model_path or download_url,\n            loading_method=output_info[\"loading_method\"],\n            checkpoint_format=output_info[\"checkpoint_format\"],\n            repo_status=RepoStatus(repo_id=repo_id, repo_hash=hf_repo_info.sha, version=revision),\n            model_status=ModelStatus(\n                search_word=search_word,\n                site_url=download_url,\n                download_url=download_url,\n                file_name=file_name,\n                local=download,\n            ),\n            extra_status=ExtraStatus(trained_words=None),\n        )\n\n    else:\n        return model_path\n\n\ndef search_civitai(search_word: str, **kwargs) -> Union[str, SearchResult, None]:\n    r\"\"\"\n    Downloads a model from Civitai.\n\n    Parameters:\n        search_word (`str`):\n            The search query string.\n        model_type (`str`, *optional*, defaults to `Checkpoint`):\n            The type of model to search for.\n        sort (`str`, *optional*):\n            The order in which you wish to sort the results(for example, `Highest Rated`, `Most Downloaded`, `Newest`).\n        base_model (`str`, *optional*):\n            The base model to filter by.\n        download (`bool`, *optional*, defaults to `False`):\n            Whether to download the model.\n        force_download (`bool`, *optional*, defaults to `False`):\n            Whether to force the download if the model already exists.\n        
token (`str`, *optional*):\n            API token for Civitai authentication.\n        include_params (`bool`, *optional*, defaults to `False`):\n            Whether to include parameters in the returned data.\n        cache_dir (`str`, `Path`, *optional*):\n            Path to the folder where cached files are stored.\n        resume (`bool`, *optional*, defaults to `False`):\n            Whether to resume an incomplete download.\n        skip_error (`bool`, *optional*, defaults to `False`):\n            Whether to skip errors and return None.\n\n    Returns:\n        `Union[str,  SearchResult, None]`: The model path or ` SearchResult` or None.\n    \"\"\"\n\n    # Extract additional parameters from kwargs\n    model_type = kwargs.pop(\"model_type\", \"Checkpoint\")\n    sort = kwargs.pop(\"sort\", None)\n    download = kwargs.pop(\"download\", False)\n    base_model = kwargs.pop(\"base_model\", None)\n    force_download = kwargs.pop(\"force_download\", False)\n    token = kwargs.pop(\"token\", None)\n    include_params = kwargs.pop(\"include_params\", False)\n    resume = kwargs.pop(\"resume\", False)\n    cache_dir = kwargs.pop(\"cache_dir\", None)\n    skip_error = kwargs.pop(\"skip_error\", False)\n\n    # Initialize additional variables with default values\n    model_path = \"\"\n    repo_name = \"\"\n    repo_id = \"\"\n    version_id = \"\"\n    trainedWords = \"\"\n    models_list = []\n    selected_repo = {}\n    selected_model = {}\n    selected_version = {}\n    civitai_cache_dir = cache_dir or os.path.join(CACHE_HOME, \"Civitai\")\n\n    # Set up parameters and headers for the CivitAI API request\n    params = {\n        \"query\": search_word,\n        \"types\": model_type,\n        \"limit\": 20,\n    }\n    if base_model is not None:\n        if not isinstance(base_model, list):\n            base_model = [base_model]\n        params[\"baseModel\"] = base_model\n\n    if sort is not None:\n        params[\"sort\"] = sort\n\n    headers = {}\n    if 
token:\n        headers[\"Authorization\"] = f\"Bearer {token}\"\n\n    try:\n        # Make the request to the CivitAI API\n        response = requests.get(\"https://civitai.com/api/v1/models\", params=params, headers=headers)\n        response.raise_for_status()\n    except requests.exceptions.HTTPError as err:\n        raise requests.HTTPError(f\"Could not get elements from the URL: {err}\")\n    else:\n        try:\n            data = response.json()\n        except ValueError:\n            # `response.json()` raises a `ValueError` subclass when the body is not valid JSON\n            if skip_error:\n                return None\n            else:\n                raise ValueError(\"Invalid JSON response\")\n\n    # Sort repositories by download count in descending order\n    sorted_repos = sorted(data[\"items\"], key=lambda x: x[\"stats\"][\"downloadCount\"], reverse=True)\n\n    for selected_repo in sorted_repos:\n        repo_name = selected_repo[\"name\"]\n        repo_id = selected_repo[\"id\"]\n\n        # Sort versions within the selected repo by download count\n        sorted_versions = sorted(\n            selected_repo[\"modelVersions\"],\n            key=lambda x: x[\"stats\"][\"downloadCount\"],\n            reverse=True,\n        )\n        for selected_version in sorted_versions:\n            version_id = selected_version[\"id\"]\n            trainedWords = selected_version[\"trainedWords\"]\n            models_list = []\n            # When searching for textual inversion, results other than the values entered for the base model may come up, so check again.\n            if base_model is None or selected_version[\"baseModel\"] in base_model:\n                for model_data in selected_version[\"files\"]:\n                    # Check if the file passes security scans and has a valid extension\n                    file_name = model_data[\"name\"]\n                    if (\n                        model_data[\"pickleScanResult\"] == \"Success\"\n                        and model_data[\"virusScanResult\"] == \"Success\"\n                       
 and any(file_name.endswith(ext) for ext in EXTENSION)\n                        and os.path.basename(os.path.dirname(file_name)) not in DIFFUSERS_CONFIG_DIR\n                    ):\n                        file_status = {\n                            \"filename\": file_name,\n                            \"download_url\": model_data[\"downloadUrl\"],\n                        }\n                        models_list.append(file_status)\n\n            if models_list:\n                # Sort the models list by filename and find the safest model\n                sorted_models = sorted(models_list, key=lambda x: x[\"filename\"], reverse=True)\n                selected_model = next(\n                    (\n                        model_data\n                        for model_data in sorted_models\n                        if bool(re.search(r\"(?i)[-_](safe|sfw)\", model_data[\"filename\"]))\n                    ),\n                    sorted_models[0],\n                )\n\n                break\n        else:\n            continue\n        break\n\n    # Exception handling when search candidates are not found\n    if not selected_model:\n        if skip_error:\n            return None\n        else:\n            raise ValueError(\"No model found. 
Please try changing the word you are searching for.\")\n\n    # Define model file status\n    file_name = selected_model[\"filename\"]\n    download_url = selected_model[\"download_url\"]\n\n    # Handle file download and setting model information\n    if download:\n        # The path where the model is to be saved.\n        model_path = os.path.join(str(civitai_cache_dir), str(repo_id), str(version_id), str(file_name))\n        # Download Model File\n        file_downloader(\n            url=download_url,\n            save_path=model_path,\n            resume=resume,\n            force_download=force_download,\n            displayed_filename=file_name,\n            headers=headers,\n            **kwargs,\n        )\n\n    else:\n        model_path = download_url\n\n    output_info = get_keyword_types(model_path)\n\n    if not include_params:\n        return model_path\n    else:\n        return SearchResult(\n            model_path=model_path,\n            loading_method=output_info[\"loading_method\"],\n            checkpoint_format=output_info[\"checkpoint_format\"],\n            repo_status=RepoStatus(repo_id=repo_name, repo_hash=repo_id, version=version_id),\n            model_status=ModelStatus(\n                search_word=search_word,\n                site_url=f\"https://civitai.com/models/{repo_id}?modelVersionId={version_id}\",\n                download_url=download_url,\n                file_name=file_name,\n                local=output_info[\"type\"][\"local\"],\n            ),\n            extra_status=ExtraStatus(trained_words=trainedWords or None),\n        )\n\n\ndef add_methods(pipeline):\n    r\"\"\"\n    Add methods from `AutoConfig` to the pipeline.\n\n    Parameters:\n        pipeline (`Pipeline`):\n            The pipeline to which the methods will be added.\n    \"\"\"\n    for attr_name in dir(AutoConfig):\n        attr_value = getattr(AutoConfig, attr_name)\n        if callable(attr_value) and not attr_name.startswith(\"__\"):\n            
setattr(pipeline, attr_name, types.MethodType(attr_value, pipeline))\n    return pipeline\n\n\nclass AutoConfig:\n    def auto_load_textual_inversion(\n        self,\n        pretrained_model_name_or_path: Union[str, List[str]],\n        token: Optional[Union[str, List[str]]] = None,\n        base_model: Optional[Union[str, List[str]]] = None,\n        tokenizer=None,\n        text_encoder=None,\n        **kwargs,\n    ):\n        r\"\"\"\n        Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and\n        Automatic1111 formats are supported).\n\n        Parameters:\n            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):\n                Can be either one of the following or a list of them:\n\n                    - Search keywords for pretrained model (for example `EasyNegative`).\n                    - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a\n                      pretrained model hosted on the Hub.\n                    - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual\n                      inversion weights.\n                    - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.\n                    - A [torch state\n                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).\n\n            token (`str` or `List[str]`, *optional*):\n                Override the token to use for the textual inversion weights. 
If `pretrained_model_name_or_path` is a\n                list, then `token` must also be a list of equal length.\n            base_model (`str` or `List[str]`, *optional*):\n                The base model(s) used to filter the search results.\n            text_encoder ([`~transformers.CLIPTextModel`], *optional*):\n                Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n                If not specified, the function uses `self.text_encoder`.\n            tokenizer ([`~transformers.CLIPTokenizer`], *optional*):\n                A `CLIPTokenizer` to tokenize text. If not specified, the function uses `self.tokenizer`.\n            weight_name (`str`, *optional*):\n                Name of a custom weight file. This should be used when:\n\n                    - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight\n                      name such as `text_inv.bin`.\n                    - The saved textual inversion file is in the Automatic1111 format.\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. 
If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            subfolder (`str`, *optional*, defaults to `\"\"`):\n                The subfolder location of a model file within a larger model repository on the Hub or locally.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n\n        Examples:\n\n        ```py\n        >>> from auto_diffusers import EasyPipelineForText2Image\n\n        >>> pipeline = EasyPipelineForText2Image.from_huggingface(\"stable-diffusion-v1-5\")\n\n        >>> pipeline.auto_load_textual_inversion(\"EasyNegative\", token=\"EasyNegative\")\n\n        >>> image = pipeline(prompt).images[0]\n        ```\n\n        \"\"\"\n        # 1. Set tokenizer and text encoder\n        tokenizer = tokenizer or getattr(self, \"tokenizer\", None)\n        text_encoder = text_encoder or getattr(self, \"text_encoder\", None)\n\n        # Check if tokenizer and text encoder are provided\n        if tokenizer is None or text_encoder is None:\n            raise ValueError(\"Tokenizer and text encoder must be provided.\")\n\n        # 2. 
Normalize inputs\n        pretrained_model_name_or_paths = (\n            [pretrained_model_name_or_path]\n            if not isinstance(pretrained_model_name_or_path, list)\n            else pretrained_model_name_or_path\n        )\n\n        # 2.1 Normalize tokens\n        tokens = [token] if not isinstance(token, list) else token\n        if tokens[0] is None:\n            tokens = tokens * len(pretrained_model_name_or_paths)\n\n        for check_token in tokens:\n            # Check if token is already in tokenizer vocabulary\n            if check_token in tokenizer.get_vocab():\n                raise ValueError(\n                    f\"Token {check_token} already in tokenizer vocabulary. Please choose a different token name or remove {check_token} and embedding from the tokenizer and text encoder.\"\n                )\n\n        expected_shape = text_encoder.get_input_embeddings().weight.shape[-1]  # Embedding dimension of the text encoder\n\n        for search_word in pretrained_model_name_or_paths:\n            if isinstance(search_word, str):\n                # Update kwargs to ensure the model is downloaded and parameters are included\n                _status = {\n                    \"download\": True,\n                    \"include_params\": True,\n                    \"skip_error\": False,\n                    \"model_type\": \"TextualInversion\",\n                }\n                # Get tags for the base model of textual inversion compatible with tokenizer.\n                # If the tokenizer is 768-dimensional, set tags for SD 1.x and SDXL.\n                # If the tokenizer is 1024-dimensional, set tags for SD 2.x.\n                if expected_shape in TOKENIZER_SHAPE_MAP:\n                    # Copy the tags for this embedding size so that extending them does not mutate TOKENIZER_SHAPE_MAP\n                    tags = list(TOKENIZER_SHAPE_MAP[expected_shape])\n                    if base_model is not None:\n                        if isinstance(base_model, list):\n                        
    tags.extend(base_model)\n                        else:\n                            tags.append(base_model)\n                    _status[\"base_model\"] = tags\n\n                kwargs.update(_status)\n                # Search for the model on Civitai and get the model status\n                textual_inversion_path = search_civitai(search_word, **kwargs)\n                logger.warning(\n                    f\"textual_inversion_path: {search_word} -> {textual_inversion_path.model_status.site_url}\"\n                )\n\n                pretrained_model_name_or_paths[pretrained_model_name_or_paths.index(search_word)] = (\n                    textual_inversion_path.model_path\n                )\n\n        self.load_textual_inversion(\n            pretrained_model_name_or_paths, token=tokens, tokenizer=tokenizer, text_encoder=text_encoder, **kwargs\n        )\n\n    def auto_load_lora_weights(\n        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs\n    ):\n        r\"\"\"\n        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and\n        `self.text_encoder`.\n\n        All kwargs are forwarded to `self.lora_state_dict`.\n\n        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is\n        loaded.\n\n        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is\n        loaded into `self.unet`.\n\n        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state\n        dict is loaded into `self.text_encoder`.\n\n        Parameters:\n            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):\n                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].\n            adapter_name (`str`, *optional*):\n                Adapter name to be used for 
referencing the loaded adapter model. If not specified, it will use\n                `default_{i}` where i is the total number of adapters being loaded.\n            low_cpu_mem_usage (`bool`, *optional*):\n                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random\n                weights.\n            kwargs (`dict`, *optional*):\n                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].\n        \"\"\"\n        if isinstance(pretrained_model_name_or_path_or_dict, str):\n            # Update kwargs to ensure the model is downloaded and parameters are included\n            _status = {\n                \"download\": True,\n                \"include_params\": True,\n                \"skip_error\": False,\n                \"model_type\": \"LORA\",\n            }\n            kwargs.update(_status)\n            # Search for the model on Civitai and get the model status\n            lora_path = search_civitai(pretrained_model_name_or_path_or_dict, **kwargs)\n            logger.warning(f\"lora_path: {lora_path.model_status.site_url}\")\n            logger.warning(f\"trained_words: {lora_path.extra_status.trained_words}\")\n            pretrained_model_name_or_path_or_dict = lora_path.model_path\n\n        self.load_lora_weights(pretrained_model_name_or_path_or_dict, adapter_name=adapter_name, **kwargs)\n\n\nclass EasyPipelineForText2Image(AutoPipelineForText2Image):\n    r\"\"\"\n    [`EasyPipelineForText2Image`] is a generic pipeline class that instantiates a text-to-image pipeline class. 
The\n    specific underlying pipeline class is automatically selected from either the\n    [`~EasyPipelineForText2Image.from_pretrained`], [`~EasyPipelineForText2Image.from_pipe`], [`~EasyPipelineForText2Image.from_huggingface`] or [`~EasyPipelineForText2Image.from_civitai`] methods.\n\n    This class cannot be instantiated using `__init__()` (throws an error).\n\n    Class attributes:\n\n        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the\n          diffusion pipeline's components.\n\n    \"\"\"\n\n    config_name = \"model_index.json\"\n\n    def __init__(self, *args, **kwargs):\n        # The parent `__init__` raises an EnvironmentError\n        super().__init__()\n\n    @classmethod\n    @validate_hf_hub_args\n    def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):\n        r\"\"\"\n        Parameters:\n            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A keyword to search for on Hugging Face (for example `Stable Diffusion`)\n                    - Link to `.ckpt` or `.safetensors` file (for example\n                      `\"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors\"`) on the Hub.\n                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline\n                      hosted on the Hub.\n                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights\n                      saved using\n                    [`~DiffusionPipeline.save_pretrained`].\n            checkpoint_format (`str`, *optional*, defaults to `\"single_file\"`):\n                The format of the model checkpoint.\n            pipeline_tag (`str`, *optional*):\n                Tag to filter models by pipeline.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the 
model with another dtype. If \"auto\" is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            output_loading_info(`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            custom_revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. 
It can be a branch name, a tag name, or a commit id similar to\n                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a\n                custom pipeline from GitHub, otherwise it defaults to `\"main\"` when loading from the Hub.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn’t need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if device_map contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. 
Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the safetensors weights are downloaded if they're available **and** if the\n                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors\n                weights. If set to `False`, safetensors weights are not loaded.\n            gated (`bool`, *optional*, defaults to `False`):\n                A boolean to filter models on the Hub that are gated or not.\n            kwargs (remaining dictionary of keyword arguments, *optional*):\n                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline\n                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example\n                below for more information.\n            variant (`str`, *optional*):\n                Load weights from a specified variant filename such as `\"fp16\"` or `\"ema\"`. 
This is ignored when\n                loading `from_flax`.\n\n        > [!TIP]\n        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with\n        > `hf auth login`.\n\n        Examples:\n\n        ```py\n        >>> from auto_diffusers import EasyPipelineForText2Image\n\n        >>> pipeline = EasyPipelineForText2Image.from_huggingface(\"stable-diffusion-v1-5\")\n        >>> image = pipeline(prompt).images[0]\n        ```\n        \"\"\"\n        # Update kwargs to ensure the model is downloaded and parameters are included\n        _status = {\n            \"download\": True,\n            \"include_params\": True,\n            \"skip_error\": False,\n            \"pipeline_tag\": \"text-to-image\",\n        }\n        kwargs.update(_status)\n\n        # Search for the model on Hugging Face and get the model status\n        hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)\n        logger.warning(f\"checkpoint_path: {hf_checkpoint_status.model_status.download_url}\")\n        checkpoint_path = hf_checkpoint_status.model_path\n\n        # Check the format of the model checkpoint\n        if hf_checkpoint_status.loading_method == \"from_single_file\":\n            # Load the pipeline from a single file checkpoint\n            pipeline = load_pipeline_from_single_file(\n                pretrained_model_or_path=checkpoint_path,\n                pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,\n                **kwargs,\n            )\n        else:\n            pipeline = cls.from_pretrained(checkpoint_path, **kwargs)\n        return add_methods(pipeline)\n\n    @classmethod\n    def from_civitai(cls, pretrained_model_link_or_path, **kwargs):\n        r\"\"\"\n        Parameters:\n            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A keyword to search for on Civitai (for example 
`Stable Diffusion`)\n                    - Link to `.ckpt` or `.safetensors` file (for example\n                      `\"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors\"`) on the Hub.\n                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline\n                      hosted on the Hub.\n                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights\n                      saved using\n                    [`~DiffusionPipeline.save_pretrained`].\n            model_type (`str`, *optional*, defaults to `Checkpoint`):\n                The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`)\n            base_model (`str`, *optional*):\n                The base model to filter by.\n            cache_dir (`str`, `Path`, *optional*):\n                Path to the folder where cached files are stored.\n            resume (`bool`, *optional*, defaults to `False`):\n                Whether to resume an incomplete download.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the model with another dtype. If \"auto\" is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            output_loading_info(`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. 
If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str`, *optional*):\n                The token to use as HTTP bearer authorization for remote files.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn’t need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if device_map contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. 
If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the safetensors weights are downloaded if they're available **and** if the\n                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors\n                weights. If set to `False`, safetensors weights are not loaded.\n            kwargs (remaining dictionary of keyword arguments, *optional*):\n                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline\n                class). The overwritten components are passed directly to the pipeline's `__init__` method. See example\n                below for more information.\n\n        > [!TIP]\n        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with\n        > `hf auth login`.\n\n        Examples:\n\n        ```py\n        >>> from auto_diffusers import EasyPipelineForText2Image\n\n        >>> pipeline = EasyPipelineForText2Image.from_civitai(\"stable-diffusion-v1-5\")\n        >>> image = pipeline(prompt).images[0]\n        ```\n        \"\"\"\n        # Update kwargs to ensure the model is downloaded and parameters are included\n        _status = {\n            \"download\": True,\n            \"include_params\": True,\n            \"skip_error\": False,\n            \"model_type\": \"Checkpoint\",\n        }\n        kwargs.update(_status)\n\n        # Search for the model on Civitai and get the model status\n        checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)\n        logger.warning(f\"checkpoint_path: {checkpoint_status.model_status.site_url}\")\n        checkpoint_path = checkpoint_status.model_path\n\n        # Load the pipeline from a single file checkpoint\n        pipeline = 
load_pipeline_from_single_file(\n            pretrained_model_or_path=checkpoint_path,\n            pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,\n            **kwargs,\n        )\n        return add_methods(pipeline)\n\n\nclass EasyPipelineForImage2Image(AutoPipelineForImage2Image):\n    r\"\"\"\n\n    [`EasyPipelineForImage2Image`] is a generic pipeline class that instantiates an image-to-image pipeline class. The\n    specific underlying pipeline class is automatically selected from either the\n    [`~EasyPipelineForImage2Image.from_pretrained`], [`~EasyPipelineForImage2Image.from_pipe`], [`~EasyPipelineForImage2Image.from_huggingface`] or [`~EasyPipelineForImage2Image.from_civitai`] methods.\n\n    This class cannot be instantiated using `__init__()` (throws an error).\n\n    Class attributes:\n\n        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the\n          diffusion pipeline's components.\n\n    \"\"\"\n\n    config_name = \"model_index.json\"\n\n    def __init__(self, *args, **kwargs):\n        # The parent `__init__` raises an EnvironmentError\n        super().__init__()\n\n    @classmethod\n    @validate_hf_hub_args\n    def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):\n        r\"\"\"\n        Parameters:\n            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A keyword to search for on Hugging Face (for example `Stable Diffusion`)\n                    - Link to `.ckpt` or `.safetensors` file (for example\n                      `\"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors\"`) on the Hub.\n                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline\n                      hosted on the Hub.\n                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights\n              
        saved using\n                    [`~DiffusionPipeline.save_pretrained`].\n            checkpoint_format (`str`, *optional*, defaults to `\"single_file\"`):\n                The format of the model checkpoint.\n            pipeline_tag (`str`, *optional*):\n                Tag to filter models by pipeline.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the model with another dtype. If \"auto\" is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            output_loading_info(`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. 
If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            custom_revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to\n                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a\n                custom pipeline from GitHub, otherwise it defaults to `\"main\"` when loading from the Hub.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn’t need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. 
Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if device_map contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the safetensors weights are downloaded if they're available **and** if the\n                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors\n                weights. If set to `False`, safetensors weights are not loaded.\n            gated (`bool`, *optional*, defaults to `False`):\n                A boolean to filter models on the Hub that are gated or not.\n            kwargs (remaining dictionary of keyword arguments, *optional*):\n                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline\n                class). The overwritten components are passed directly to the pipeline's `__init__` method. 
See example\n                below for more information.\n            variant (`str`, *optional*):\n                Load weights from a specified variant filename such as `\"fp16\"` or `\"ema\"`. This is ignored when\n                loading `from_flax`.\n\n        > [!TIP]\n        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with\n        > `hf auth login`.\n\n        Examples:\n\n        ```py\n        >>> from auto_diffusers import EasyPipelineForImage2Image\n\n        >>> pipeline = EasyPipelineForImage2Image.from_huggingface(\"stable-diffusion-v1-5\")\n        >>> image = pipeline(prompt, image).images[0]\n        ```\n        \"\"\"\n        # Update kwargs to ensure the model is downloaded and parameters are included\n        _status = {\n            \"download\": True,\n            \"include_params\": True,\n            \"skip_error\": False,\n            \"pipeline_tag\": \"image-to-image\",\n        }\n        kwargs.update(_status)\n\n        # Search for the model on Hugging Face and get the model status\n        hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)\n        logger.warning(f\"checkpoint_path: {hf_checkpoint_status.model_status.download_url}\")\n        checkpoint_path = hf_checkpoint_status.model_path\n\n        # Check the format of the model checkpoint\n        if hf_checkpoint_status.loading_method == \"from_single_file\":\n            # Load the pipeline from a single file checkpoint\n            pipeline = load_pipeline_from_single_file(\n                pretrained_model_or_path=checkpoint_path,\n                pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,\n                **kwargs,\n            )\n        else:\n            pipeline = cls.from_pretrained(checkpoint_path, **kwargs)\n\n        return add_methods(pipeline)\n\n    @classmethod\n    def from_civitai(cls, pretrained_model_link_or_path, **kwargs):\n        
r\"\"\"\n        Parameters:\n            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A keyword to search for Hugging Face (for example `Stable Diffusion`)\n                    - Link to `.ckpt` or `.safetensors` file (for example\n                      `\"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors\"`) on the Hub.\n                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline\n                      hosted on the Hub.\n                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights\n                      saved using\n                    [`~DiffusionPipeline.save_pretrained`].\n            model_type (`str`, *optional*, defaults to `Checkpoint`):\n                The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`)\n            base_model (`str`, *optional*):\n                The base model to filter by.\n            cache_dir (`str`, `Path`, *optional*):\n                Path to the folder where cached files are stored.\n            resume (`bool`, *optional*, defaults to `False`):\n                Whether to resume an incomplete download.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the model with another dtype. 
If \"auto\" is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            output_loading_info(`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str`, *optional*):\n                The token to use as HTTP bearer authorization for remote files.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn’t need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. 
Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if device_map contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the safetensors weights are downloaded if they're available **and** if the\n                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors\n                weights. If set to `False`, safetensors weights are not loaded.\n            kwargs (remaining dictionary of keyword arguments, *optional*):\n                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline\n                class). The overwritten components are passed directly to the pipelines `__init__` method. 
See example\n                below for more information.\n\n        > [!TIP]\n        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with\n        > `hf auth login`.\n\n        Examples:\n\n        ```py\n        >>> from auto_diffusers import EasyPipelineForImage2Image\n\n        >>> pipeline = EasyPipelineForImage2Image.from_huggingface(\"stable-diffusion-v1-5\")\n        >>> image = pipeline(prompt, image).images[0]\n        ```\n        \"\"\"\n        # Update kwargs to ensure the model is downloaded and parameters are included\n        _status = {\n            \"download\": True,\n            \"include_params\": True,\n            \"skip_error\": False,\n            \"model_type\": \"Checkpoint\",\n        }\n        kwargs.update(_status)\n\n        # Search for the model on Civitai and get the model status\n        checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)\n        logger.warning(f\"checkpoint_path: {checkpoint_status.model_status.site_url}\")\n        checkpoint_path = checkpoint_status.model_path\n\n        # Load the pipeline from a single file checkpoint\n        pipeline = load_pipeline_from_single_file(\n            pretrained_model_or_path=checkpoint_path,\n            pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,\n            **kwargs,\n        )\n        return add_methods(pipeline)\n\n\nclass EasyPipelineForInpainting(AutoPipelineForInpainting):\n    r\"\"\"\n\n    [`EasyPipelineForInpainting`] is a generic pipeline class that instantiates an inpainting pipeline class. 
The\n    specific underlying pipeline class is automatically selected from either the\n    [`~EasyPipelineForInpainting.from_pretrained`], [`~EasyPipelineForInpainting.from_pipe`], [`~EasyPipelineForInpainting.from_huggingface`] or [`~EasyPipelineForInpainting.from_civitai`] methods.\n\n    This class cannot be instantiated using `__init__()` (throws an error).\n\n    Class attributes:\n\n        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the\n          diffusion pipeline's components.\n\n    \"\"\"\n\n    config_name = \"model_index.json\"\n\n    def __init__(self, *args, **kwargs):\n        # EnvironmentError is returned\n        super().__init__()\n\n    @classmethod\n    @validate_hf_hub_args\n    def from_huggingface(cls, pretrained_model_link_or_path, **kwargs):\n        r\"\"\"\n        Parameters:\n            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A keyword to search for Hugging Face (for example `Stable Diffusion`)\n                    - Link to `.ckpt` or `.safetensors` file (for example\n                      `\"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors\"`) on the Hub.\n                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline\n                      hosted on the Hub.\n                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights\n                      saved using\n                    [`~DiffusionPipeline.save_pretrained`].\n            checkpoint_format (`str`, *optional*, defaults to `\"single_file\"`):\n                The format of the model checkpoint.\n            pipeline_tag (`str`, *optional*):\n                Tag to filter models by pipeline.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the 
model with another dtype. If \"auto\" is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            cache_dir (`Union[str, os.PathLike]`, *optional*):\n                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache\n                is not used.\n\n            proxies (`Dict[str, str]`, *optional*):\n                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',\n                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.\n            output_loading_info(`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str` or *bool*, *optional*):\n                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from\n                `diffusers-cli login` (stored in `~/.huggingface`) is used.\n            revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier\n                allowed by Git.\n            custom_revision (`str`, *optional*, defaults to `\"main\"`):\n                The specific model version to use. 
It can be a branch name, a tag name, or a commit id similar to\n                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a\n                custom pipeline from GitHub, otherwise it defaults to `\"main\"` when loading from the Hub.\n            mirror (`str`, *optional*):\n                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not\n                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more\n                information.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn’t need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if device_map contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. 
Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the safetensors weights are downloaded if they're available **and** if the\n                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors\n                weights. If set to `False`, safetensors weights are not loaded.\n            gated (`bool`, *optional*, defaults to `False` ):\n                A boolean to filter models on the Hub that are gated or not.\n            kwargs (remaining dictionary of keyword arguments, *optional*):\n                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline\n                class). The overwritten components are passed directly to the pipelines `__init__` method. See example\n                below for more information.\n            variant (`str`, *optional*):\n                Load weights from a specified variant filename such as `\"fp16\"` or `\"ema\"`. 
This is ignored when\n                loading `from_flax`.\n\n        > [!TIP]\n        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with\n        > `hf auth login`.\n\n        Examples:\n\n        ```py\n        >>> from auto_diffusers import EasyPipelineForInpainting\n\n        >>> pipeline = EasyPipelineForInpainting.from_huggingface(\"stable-diffusion-2-inpainting\")\n        >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]\n        ```\n        \"\"\"\n        # Update kwargs to ensure the model is downloaded and parameters are included\n        _status = {\n            \"download\": True,\n            \"include_params\": True,\n            \"skip_error\": False,\n            \"pipeline_tag\": \"image-to-image\",\n        }\n        kwargs.update(_status)\n\n        # Search for the model on Hugging Face and get the model status\n        hf_checkpoint_status = search_huggingface(pretrained_model_link_or_path, **kwargs)\n        logger.warning(f\"checkpoint_path: {hf_checkpoint_status.model_status.download_url}\")\n        checkpoint_path = hf_checkpoint_status.model_path\n\n        # Check the format of the model checkpoint\n        if hf_checkpoint_status.loading_method == \"from_single_file\":\n            # Load the pipeline from a single file checkpoint\n            pipeline = load_pipeline_from_single_file(\n                pretrained_model_or_path=checkpoint_path,\n                pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,\n                **kwargs,\n            )\n        else:\n            pipeline = cls.from_pretrained(checkpoint_path, **kwargs)\n        return add_methods(pipeline)\n\n    @classmethod\n    def from_civitai(cls, pretrained_model_link_or_path, **kwargs):\n        r\"\"\"\n        Parameters:\n            pretrained_model_or_path (`str` or `os.PathLike`, *optional*):\n                Can be either:\n\n                    - A 
keyword to search for Hugging Face (for example `Stable Diffusion`)\n                    - Link to `.ckpt` or `.safetensors` file (for example\n                      `\"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors\"`) on the Hub.\n                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline\n                      hosted on the Hub.\n                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights\n                      saved using\n                    [`~DiffusionPipeline.save_pretrained`].\n            model_type (`str`, *optional*, defaults to `Checkpoint`):\n                The type of model to search for. (for example `Checkpoint`, `TextualInversion`, `LORA`, `Controlnet`)\n            base_model (`str`, *optional*):\n                The base model to filter by.\n            cache_dir (`str`, `Path`, *optional*):\n                Path to the folder where cached files are stored.\n            resume (`bool`, *optional*, defaults to `False`):\n                Whether to resume an incomplete download.\n            torch_dtype (`str` or `torch.dtype`, *optional*):\n                Override the default `torch.dtype` and load the model with another dtype. 
If \"auto\" is passed, the\n                dtype is automatically derived from the model's weights.\n            force_download (`bool`, *optional*, defaults to `False`):\n                Whether or not to force the (re-)download of the model weights and configuration files, overriding the\n                cached versions if they exist.\n            output_loading_info(`bool`, *optional*, defaults to `False`):\n                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.\n            local_files_only (`bool`, *optional*, defaults to `False`):\n                Whether to only load local model weights and configuration files or not. If set to `True`, the model\n                won't be downloaded from the Hub.\n            token (`str`, *optional*):\n                The token to use as HTTP bearer authorization for remote files.\n            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):\n                A map that specifies where each submodule should go. It doesn’t need to be defined for each\n                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the\n                same device.\n\n                Set `device_map=\"auto\"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For\n                more information about each option see [designing a device\n                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).\n            max_memory (`Dict`, *optional*):\n                A dictionary device identifier for the maximum memory. 
Will default to the maximum memory available for\n                each GPU and the available CPU RAM if unset.\n            offload_folder (`str` or `os.PathLike`, *optional*):\n                The path to offload weights if device_map contains the value `\"disk\"`.\n            offload_state_dict (`bool`, *optional*):\n                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if\n                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`\n                when there is some disk offload.\n            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):\n                Speed up model loading only loading the pretrained weights and not initializing the weights. This also\n                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.\n                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this\n                argument to `True` will raise an error.\n            use_safetensors (`bool`, *optional*, defaults to `None`):\n                If set to `None`, the safetensors weights are downloaded if they're available **and** if the\n                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors\n                weights. If set to `False`, safetensors weights are not loaded.\n            kwargs (remaining dictionary of keyword arguments, *optional*):\n                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline\n                class). The overwritten components are passed directly to the pipelines `__init__` method. 
See example\n                below for more information.\n\n        > [!TIP]\n        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with\n        > `hf auth login`.\n\n        Examples:\n\n        ```py\n        >>> from auto_diffusers import EasyPipelineForInpainting\n\n        >>> pipeline = EasyPipelineForInpainting.from_huggingface(\"stable-diffusion-2-inpainting\")\n        >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]\n        ```\n        \"\"\"\n        # Update kwargs to ensure the model is downloaded and parameters are included\n        _status = {\n            \"download\": True,\n            \"include_params\": True,\n            \"skip_error\": False,\n            \"model_type\": \"Checkpoint\",\n        }\n        kwargs.update(_status)\n\n        # Search for the model on Civitai and get the model status\n        checkpoint_status = search_civitai(pretrained_model_link_or_path, **kwargs)\n        logger.warning(f\"checkpoint_path: {checkpoint_status.model_status.site_url}\")\n        checkpoint_path = checkpoint_status.model_path\n\n        # Load the pipeline from a single file checkpoint\n        pipeline = load_pipeline_from_single_file(\n            pretrained_model_or_path=checkpoint_path,\n            pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,\n            **kwargs,\n        )\n        return add_methods(pipeline)\n"
  },
  {
    "path": "examples/model_search/requirements.txt",
    "content": "huggingface-hub>=0.26.2\n"
  },
  {
    "path": "examples/reinforcement_learning/README.md",
    "content": "\n## Diffusion-based Policy Learning for RL\n\n`diffusion_policy` implements [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/), a diffusion model that predicts robot action sequences in reinforcement learning tasks.\n\nThis example implements a robot control model for pushing a T-shaped block into a target area. The model takes in current state observations as input, and outputs a trajectory of subsequent steps to follow.\n\nTo execute the script, run `diffusion_policy.py`\n\n## Diffuser Locomotion\n\nThese examples show how to run [Diffuser](https://huggingface.co/papers/2205.09991) in Diffusers.\nThere are two ways to use the script, `run_diffuser_locomotion.py`.\n\nThe key option is a change of the variable `n_guide_steps`.\nWhen `n_guide_steps=0`, the trajectories are sampled from the diffusion model, but not fine-tuned to maximize reward in the environment.\nBy default, `n_guide_steps=2` to match the original implementation.\n\n\nYou will need some RL specific requirements to run the examples:\n\n```sh\npip install -f https://download.pytorch.org/whl/torch_stable.html \\\n                free-mujoco-py \\\n                einops \\\n                gym==0.24.1 \\\n                protobuf==3.20.1 \\\n                git+https://github.com/rail-berkeley/d4rl.git \\\n                mediapy \\\n                Pillow==9.0.0\n```\n"
  },
  {
    "path": "examples/reinforcement_learning/diffusion_policy.py",
    "content": "import numpy as np\nimport numpy.core.multiarray as multiarray\nimport torch\nimport torch.nn as nn\nfrom huggingface_hub import hf_hub_download\nfrom torch.serialization import add_safe_globals\n\nfrom diffusers import DDPMScheduler, UNet1DModel\n\n\nadd_safe_globals(\n    [\n        multiarray._reconstruct,\n        np.ndarray,\n        np.dtype,\n        np.dtype(np.float32).type,\n        np.dtype(np.float64).type,\n        np.dtype(np.int32).type,\n        np.dtype(np.int64).type,\n        type(np.dtype(np.float32)),\n        type(np.dtype(np.float64)),\n        type(np.dtype(np.int32)),\n        type(np.dtype(np.int64)),\n    ]\n)\n\n\"\"\"\nAn example of using HuggingFace's diffusers library for diffusion policy,\ngenerating smooth movement trajectories.\n\nThis implements a robot control model for pushing a T-shaped block into a target area.\nThe model takes in the robot arm position, block position, and block angle,\nthen outputs a sequence of 16 (x,y) positions for the robot arm to follow.\n\"\"\"\n\n\nclass ObservationEncoder(nn.Module):\n    \"\"\"\n    Converts raw robot observations (positions/angles) into a more compact representation\n\n    state_dim (int): Dimension of the input state vector (default: 5)\n        [robot_x, robot_y, block_x, block_y, block_angle]\n\n    - Input shape: (batch_size, state_dim)\n    - Output shape: (batch_size, 256)\n    \"\"\"\n\n    def __init__(self, state_dim):\n        super().__init__()\n        self.net = nn.Sequential(nn.Linear(state_dim, 512), nn.ReLU(), nn.Linear(512, 256))\n\n    def forward(self, x):\n        return self.net(x)\n\n\nclass ObservationProjection(nn.Module):\n    \"\"\"\n    Takes the encoded observation and transforms it into 32 values that represent the current robot/block situation.\n    These values are used as additional contextual information during the diffusion model's trajectory generation.\n\n    - Input: 256-dim vector (padded to 512)\n            Shape: (batch_size, 
256)\n    - Output: 32 contextual information values for the diffusion model\n            Shape: (batch_size, 32)\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.weight = nn.Parameter(torch.randn(32, 512))\n        self.bias = nn.Parameter(torch.zeros(32))\n\n    def forward(self, x):  # pad 256-dim input to 512-dim with zeros\n        if x.size(-1) == 256:\n            x = torch.cat([x, torch.zeros(*x.shape[:-1], 256, device=x.device)], dim=-1)\n        return nn.functional.linear(x, self.weight, self.bias)\n\n\nclass DiffusionPolicy:\n    \"\"\"\n    Implements diffusion policy for generating robot arm trajectories.\n    Uses diffusion to generate sequences of positions for a robot arm, conditioned on\n    the current state of the robot and the block it needs to push.\n\n    The model expects observations in pixel coordinates (0-512 range) and block angle in radians.\n    It generates trajectories as sequences of (x,y) coordinates also in the 0-512 range.\n    \"\"\"\n\n    def __init__(self, state_dim=5, device=\"cpu\"):\n        self.device = device\n\n        # define valid ranges for inputs/outputs\n        self.stats = {\n            \"obs\": {\"min\": torch.zeros(5), \"max\": torch.tensor([512, 512, 512, 512, 2 * np.pi])},\n            \"action\": {\"min\": torch.zeros(2), \"max\": torch.full((2,), 512)},\n        }\n\n        self.obs_encoder = ObservationEncoder(state_dim).to(device)\n        self.obs_projection = ObservationProjection().to(device)\n\n        # UNet model that performs the denoising process\n        # takes in concatenated action (2 channels) and context (32 channels) = 34 channels\n        # outputs predicted action (2 channels for x,y coordinates)\n        self.model = UNet1DModel(\n            sample_size=16,  # length of trajectory sequence\n            in_channels=34,\n            out_channels=2,\n            layers_per_block=2,  # number of layers per each UNet block\n            
block_out_channels=(128,),  # number of output neurons per layer in each block\n            down_block_types=(\"DownBlock1D\",),  # reduce the resolution of data\n            up_block_types=(\"UpBlock1D\",),  # increase the resolution of data\n        ).to(device)\n\n        # noise scheduler that controls the denoising process\n        self.noise_scheduler = DDPMScheduler(\n            num_train_timesteps=100,  # number of denoising steps\n            beta_schedule=\"squaredcos_cap_v2\",  # type of noise schedule\n        )\n\n        # load pre-trained weights from HuggingFace\n        checkpoint = torch.load(\n            hf_hub_download(\"dorsar/diffusion_policy\", \"push_tblock.pt\"), weights_only=True, map_location=device\n        )\n        self.model.load_state_dict(checkpoint[\"model_state_dict\"])\n        self.obs_encoder.load_state_dict(checkpoint[\"encoder_state_dict\"])\n        self.obs_projection.load_state_dict(checkpoint[\"projection_state_dict\"])\n\n    # scales data to [-1, 1] range for neural network processing\n    def normalize_data(self, data, stats):\n        return ((data - stats[\"min\"]) / (stats[\"max\"] - stats[\"min\"])) * 2 - 1\n\n    # converts normalized data back to original range\n    def unnormalize_data(self, ndata, stats):\n        return ((ndata + 1) / 2) * (stats[\"max\"] - stats[\"min\"]) + stats[\"min\"]\n\n    @torch.no_grad()\n    def predict(self, observation):\n        \"\"\"\n        Generates a trajectory of robot arm positions given the current state.\n\n        Args:\n            observation (torch.Tensor): Current state [robot_x, robot_y, block_x, block_y, block_angle]\n                                    Shape: (batch_size, 5)\n\n        Returns:\n            torch.Tensor: Sequence of (x,y) positions for the robot arm to follow\n                        Shape: (batch_size, 16, 2) where:\n                        - 16 is the number of steps in the trajectory\n                        - 2 is the (x,y) coordinates in 
pixel space (0-512)\n\n        The function first encodes the observation, then uses it to condition a diffusion\n        process that gradually denoises random trajectories into smooth, purposeful movements.\n        \"\"\"\n        observation = observation.to(self.device)\n        normalized_obs = self.normalize_data(observation, self.stats[\"obs\"])\n\n        # encode the observation into context values for the diffusion model\n        cond = self.obs_projection(self.obs_encoder(normalized_obs))\n        # keeps first & second dimension sizes unchanged, and multiplies last dimension by 16\n        cond = cond.view(normalized_obs.shape[0], -1, 1).expand(-1, -1, 16)\n\n        # initialize action with noise - random noise that will be refined into a trajectory\n        action = torch.randn((observation.shape[0], 2, 16), device=self.device)\n\n        # denoise\n        # at each step `t`, the current noisy trajectory (`action`) & conditioning info (context) are\n        # fed into the model to predict a denoised trajectory, then uses self.noise_scheduler.step to\n        # apply this prediction & slightly reduce the noise in `action` more\n\n        self.noise_scheduler.set_timesteps(100)\n        for t in self.noise_scheduler.timesteps:\n            model_output = self.model(torch.cat([action, cond], dim=1), t)\n            action = self.noise_scheduler.step(model_output.sample, t, action).prev_sample\n\n        action = action.transpose(1, 2)  # reshape to [batch, 16, 2]\n        action = self.unnormalize_data(action, self.stats[\"action\"])  # scale back to coordinates\n        return action\n\n\nif __name__ == \"__main__\":\n    policy = DiffusionPolicy()\n\n    # sample of a single observation\n    # robot arm starts in center, block is slightly left and up, rotated 90 degrees\n    obs = torch.tensor(\n        [\n            [\n                256.0,  # robot arm x position (middle of screen)\n                256.0,  # robot arm y position (middle of 
screen)\n                200.0,  # block x position\n                300.0,  # block y position\n                np.pi / 2,  # block angle (90 degrees)\n            ]\n        ]\n    )\n\n    action = policy.predict(obs)\n\n    print(\"Action shape:\", action.shape)  # should be [1, 16, 2] - one trajectory of 16 x,y positions\n    print(\"\\nPredicted trajectory:\")\n    for i, (x, y) in enumerate(action[0]):\n        print(f\"Step {i:2d}: x={x:6.1f}, y={y:6.1f}\")\n"
  },
  {
    "path": "examples/reinforcement_learning/run_diffuser_locomotion.py",
    "content": "import d4rl  # noqa\nimport gym\nimport tqdm\nfrom diffusers.experimental import ValueGuidedRLPipeline\n\n\nconfig = {\n    \"n_samples\": 64,\n    \"horizon\": 32,\n    \"num_inference_steps\": 20,\n    \"n_guide_steps\": 2,  # can set to 0 for faster sampling, does not use value network\n    \"scale_grad_by_std\": True,\n    \"scale\": 0.1,\n    \"eta\": 0.0,\n    \"t_grad_cutoff\": 2,\n    \"device\": \"cpu\",\n}\n\n\nif __name__ == \"__main__\":\n    env_name = \"hopper-medium-v2\"\n    env = gym.make(env_name)\n\n    pipeline = ValueGuidedRLPipeline.from_pretrained(\n        \"bglick13/hopper-medium-v2-value-function-hor32\",\n        env=env,\n    )\n\n    env.seed(0)\n    obs = env.reset()\n    total_reward = 0\n    total_score = 0\n    T = 1000\n    rollout = [obs.copy()]\n    try:\n        for t in tqdm.tqdm(range(T)):\n            # call the policy\n            denorm_actions = pipeline(obs, planning_horizon=32)\n\n            # execute action in environment\n            next_observation, reward, terminal, _ = env.step(denorm_actions)\n\n            # update return, then compute the normalized score of the updated return\n            total_reward += reward\n            score = env.get_normalized_score(total_reward)\n            total_score += score\n            print(\n                f\"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:\"\n                f\" {total_score}\"\n            )\n\n            # save observations for rendering\n            rollout.append(next_observation.copy())\n\n            obs = next_observation\n    except KeyboardInterrupt:\n        pass\n\n    print(f\"Total reward: {total_reward}\")\n"
  },
  {
    "path": "examples/research_projects/README.md",
    "content": "# Research projects\n\nThis folder contains various research projects using 🧨 Diffusers.\nThey are not actively maintained by the core maintainers of this library and often require a specific version of Diffusers, which is indicated in the requirements file of each folder.\nUpdating them to the most recent version of the library will require some work.\n\nTo use any of them, run the following command inside the folder of your choice:\n\n```sh\npip install -r requirements.txt\n```\n\nIf you need help with any of them, please open an issue and ping the author(s) directly, as indicated at the top of the README of each folder.\n"
  },
  {
    "path": "examples/research_projects/anytext/README.md",
    "content": "# AnyTextPipeline\n\nProject page: https://aigcdesigngroup.github.io/homepage_anytext\n\n\"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy.\"\n\n> **Note:** Each text line that needs to be generated should be enclosed in double quotes.\n\nFor any usage questions, please refer to the [paper](https://huggingface.co/papers/2311.03054).\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)\n\n```py\n# This example requires the `anytext_controlnet.py` file:\n# !git clone --depth 1 https://github.com/huggingface/diffusers.git\n# %cd diffusers/examples/research_projects/anytext\n# Let's choose a font file shared by an HF staff:\n# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf\n\nimport torch\nfrom diffusers import DiffusionPipeline\nfrom anytext_controlnet import AnyTextControlNetModel\nfrom diffusers.utils import load_image\n\n\nanytext_controlnet = AnyTextControlNetModel.from_pretrained(\"tolgacangoz/anytext-controlnet\", torch_dtype=torch.float16,\n                                                            variant=\"fp16\",)\npipe = DiffusionPipeline.from_pretrained(\"tolgacangoz/anytext\", font_path=\"arial-unicode-ms.ttf\",\n                                          controlnet=anytext_controlnet, torch_dtype=torch.float16,\n                                        
  trust_remote_code=False,  # One needs to give permission to run this pipeline's code\n                                          ).to(\"cuda\")\n\n# generate image\nprompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with \"Any\" \"Text\" written on it using cream'\ndraw_pos = load_image(\"https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png\")\n# There are two modes: \"generate\" and \"edit\". \"edit\" mode requires `ori_image` parameter for the image to be edited.\nimage = pipe(prompt, num_inference_steps=20, mode=\"generate\", draw_pos=draw_pos,\n             ).images[0]\nimage\n```\n"
  },
  {
    "path": "examples/research_projects/anytext/anytext.py",
    "content": "# Copyright 2025 The HuggingFace Team. All rights reserved.\n# Copyright (c) Alibaba, Inc. and its affiliates.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).\n# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie\n# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license\n#\n# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).\n\n\nimport inspect\nimport math\nimport os\nimport re\nimport sys\nimport unicodedata\nfrom functools import partial\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport cv2\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom huggingface_hub import hf_hub_download\nfrom ocr_recog.RecModel import RecModel\nfrom PIL import Image, ImageDraw, ImageFont\nfrom safetensors.torch import load_file\nfrom skimage.transform._geometric import _umeyama as get_sym_mat\nfrom torch import nn\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\nfrom transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.image_processor import 
PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import (\n    FromSingleFileMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    TextualInversionLoaderMixin,\n)\nfrom diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.constants import HF_MODULES_CACHE\nfrom diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor\n\n\nclass Checker:\n    def __init__(self):\n        pass\n\n    def _is_chinese_char(self, cp):\n        \"\"\"Checks whether CP is the codepoint of a CJK character.\"\"\"\n        # This defines a \"chinese character\" as anything in the CJK Unicode block:\n        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)\n        #\n        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,\n        # despite its name. The modern Korean Hangul alphabet is a different block,\n        # as is Japanese Hiragana and Katakana. 
Those alphabets are used to write\n        # space-separated words, so they are not treated specially and handled\n        # like the all of the other languages.\n        if (\n            (cp >= 0x4E00 and cp <= 0x9FFF)\n            or (cp >= 0x3400 and cp <= 0x4DBF)\n            or (cp >= 0x20000 and cp <= 0x2A6DF)\n            or (cp >= 0x2A700 and cp <= 0x2B73F)\n            or (cp >= 0x2B740 and cp <= 0x2B81F)\n            or (cp >= 0x2B820 and cp <= 0x2CEAF)\n            or (cp >= 0xF900 and cp <= 0xFAFF)\n            or (cp >= 0x2F800 and cp <= 0x2FA1F)\n        ):\n            return True\n\n        return False\n\n    def _clean_text(self, text):\n        \"\"\"Performs invalid character removal and whitespace cleanup on text.\"\"\"\n        output = []\n        for char in text:\n            cp = ord(char)\n            if cp == 0 or cp == 0xFFFD or self._is_control(char):\n                continue\n            if self._is_whitespace(char):\n                output.append(\" \")\n            else:\n                output.append(char)\n        return \"\".join(output)\n\n    def _is_control(self, char):\n        \"\"\"Checks whether `chars` is a control character.\"\"\"\n        # These are technically control characters but we count them as whitespace\n        # characters.\n        if char == \"\\t\" or char == \"\\n\" or char == \"\\r\":\n            return False\n        cat = unicodedata.category(char)\n        if cat in (\"Cc\", \"Cf\"):\n            return True\n        return False\n\n    def _is_whitespace(self, char):\n        \"\"\"Checks whether `chars` is a whitespace character.\"\"\"\n        # \\t, \\n, and \\r are technically control characters but we treat them\n        # as whitespace since they are generally considered as such.\n        if char == \" \" or char == \"\\t\" or char == \"\\n\" or char == \"\\r\":\n            return True\n        cat = unicodedata.category(char)\n        if cat == \"Zs\":\n            return True\n        
return False\n\n\nchecker = Checker()\n\n\nPLACE_HOLDER = \"*\"\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # This example requires the `anytext_controlnet.py` file:\n        >>> # !git clone --depth 1 https://github.com/huggingface/diffusers.git\n        >>> # %cd diffusers/examples/research_projects/anytext\n        >>> # Let's choose a font file shared by an HF staff:\n        >>> # !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf\n\n        >>> import torch\n        >>> from diffusers import DiffusionPipeline\n        >>> from anytext_controlnet import AnyTextControlNetModel\n        >>> from diffusers.utils import load_image\n\n        >>> anytext_controlnet = AnyTextControlNetModel.from_pretrained(\"tolgacangoz/anytext-controlnet\", torch_dtype=torch.float16,\n        ...                                                             variant=\"fp16\",)\n        >>> pipe = DiffusionPipeline.from_pretrained(\"tolgacangoz/anytext\", font_path=\"arial-unicode-ms.ttf\",\n        ...                                           controlnet=anytext_controlnet, torch_dtype=torch.float16,\n        ...                                           trust_remote_code=False,  # One needs to give permission to run this pipeline's code\n        ...                                           ).to(\"cuda\")\n\n\n        >>> # generate image\n        >>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with \"Any\" \"Text\" written on it using cream'\n        >>> draw_pos = load_image(\"https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png\")\n        >>> # There are two modes: \"generate\" and \"edit\". 
\"edit\" mode requires `ori_image` parameter for the image to be edited.\n        >>> image = pipe(prompt, num_inference_steps=20, mode=\"generate\", draw_pos=draw_pos,\n        ...              ).images[0]\n        >>> image\n        ```\n\"\"\"\n\n\ndef get_clip_token_for_string(tokenizer, string):\n    batch_encoding = tokenizer(\n        string,\n        truncation=True,\n        max_length=77,\n        return_length=True,\n        return_overflowing_tokens=False,\n        padding=\"max_length\",\n        return_tensors=\"pt\",\n    )\n    tokens = batch_encoding[\"input_ids\"]\n    assert torch.count_nonzero(tokens - 49407) == 2, (\n        f\"String '{string}' maps to more than a single token. Please use another string\"\n    )\n    return tokens[0, 1]\n\n\ndef get_recog_emb(encoder, img_list):\n    _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]\n    encoder.predictor.eval()\n    _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)\n    return preds_neck\n\n\nclass EmbeddingManager(ModelMixin, ConfigMixin):\n    @register_to_config\n    def __init__(\n        self,\n        embedder,\n        placeholder_string=\"*\",\n        use_fp16=False,\n        token_dim=768,\n        get_recog_emb=None,\n    ):\n        super().__init__()\n        get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)\n\n        self.proj = nn.Linear(40 * 64, token_dim)\n        proj_dir = hf_hub_download(\n            repo_id=\"tolgacangoz/anytext\",\n            filename=\"text_embedding_module/proj.safetensors\",\n            cache_dir=HF_MODULES_CACHE,\n        )\n        self.proj.load_state_dict(load_file(proj_dir, device=str(embedder.device)))\n        if use_fp16:\n            self.proj = self.proj.to(dtype=torch.float16)\n\n        self.placeholder_token = get_token_for_string(placeholder_string)\n\n    @torch.no_grad()\n    def encode_text(self, text_info):\n        if self.config.get_recog_emb is None:\n            
self.config.get_recog_emb = partial(get_recog_emb, self.recog)\n\n        gline_list = []\n        for i in range(len(text_info[\"n_lines\"])):  # sample index in a batch\n            n_lines = text_info[\"n_lines\"][i]\n            for j in range(n_lines):  # line\n                gline_list += [text_info[\"gly_line\"][j][i : i + 1]]\n\n        if len(gline_list) > 0:\n            recog_emb = self.config.get_recog_emb(gline_list)\n            enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.dtype))\n\n        self.text_embs_all = []\n        n_idx = 0\n        for i in range(len(text_info[\"n_lines\"])):  # sample index in a batch\n            n_lines = text_info[\"n_lines\"][i]\n            text_embs = []\n            for j in range(n_lines):  # line\n                text_embs += [enc_glyph[n_idx : n_idx + 1]]\n                n_idx += 1\n            self.text_embs_all += [text_embs]\n\n    @torch.no_grad()\n    def forward(\n        self,\n        tokenized_text,\n        embedded_text,\n    ):\n        b, device = tokenized_text.shape[0], tokenized_text.device\n        for i in range(b):\n            idx = tokenized_text[i] == self.placeholder_token.to(device)\n            if sum(idx) > 0:\n                if i >= len(self.text_embs_all):\n                    logger.warning(\"truncation for log images...\")\n                    break\n                text_emb = torch.cat(self.text_embs_all[i], dim=0)\n                if sum(idx) != len(text_emb):\n                    logger.warning(\"truncation for long caption...\")\n                text_emb = text_emb.to(embedded_text.device)\n                embedded_text[i][idx] = text_emb[: sum(idx)]\n        return embedded_text\n\n    def embedding_parameters(self):\n        return self.parameters()\n\n\nsys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), \"..\")))\n\n\ndef min_bounding_rect(img):\n    ret, thresh = cv2.threshold(img, 127, 255, 0)\n    contours, 
hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n    if len(contours) == 0:\n        print(\"Bad contours, using fake bbox...\")\n        return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])\n    max_contour = max(contours, key=cv2.contourArea)\n    rect = cv2.minAreaRect(max_contour)\n    box = cv2.boxPoints(rect)\n    box = np.intp(box)\n    # sort\n    x_sorted = sorted(box, key=lambda x: x[0])\n    left = x_sorted[:2]\n    right = x_sorted[2:]\n    left = sorted(left, key=lambda x: x[1])\n    (tl, bl) = left\n    right = sorted(right, key=lambda x: x[1])\n    (tr, br) = right\n    if tl[1] > bl[1]:\n        (tl, bl) = (bl, tl)\n    if tr[1] > br[1]:\n        (tr, br) = (br, tr)\n    return np.array([tl, tr, br, bl])\n\n\ndef adjust_image(box, img):\n    pts1 = np.float32([box[0], box[1], box[2], box[3]])\n    width = max(np.linalg.norm(pts1[0] - pts1[1]), np.linalg.norm(pts1[2] - pts1[3]))\n    height = max(np.linalg.norm(pts1[0] - pts1[3]), np.linalg.norm(pts1[1] - pts1[2]))\n    pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]])\n    # get transform matrix\n    M = get_sym_mat(pts1, pts2, estimate_scale=True)\n    C, H, W = img.shape\n    T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]])\n    theta = np.linalg.inv(T @ M @ np.linalg.inv(T))\n    theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device)\n    grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True)\n    result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True)\n    result = torch.clamp(result.squeeze(0), 0, 255)\n    # crop\n    result = result[:, : int(height), : int(width)]\n    return result\n\n\ndef crop_image(src_img, mask):\n    box = min_bounding_rect(mask)\n    result = adjust_image(box, src_img)\n    if len(result.shape) == 2:\n        result = torch.stack([result] * 3, axis=-1)\n    return result\n\n\ndef create_predictor(model_lang=\"ch\", device=\"cpu\", 
use_fp16=False):\n    model_dir = hf_hub_download(\n        repo_id=\"tolgacangoz/anytext\",\n        filename=\"text_embedding_module/OCR/ppv3_rec.pth\",\n        cache_dir=HF_MODULES_CACHE,\n    )\n    if not os.path.exists(model_dir):\n        raise ValueError(\"not find model file path {}\".format(model_dir))\n\n    if model_lang == \"ch\":\n        n_class = 6625\n    elif model_lang == \"en\":\n        n_class = 97\n    else:\n        raise ValueError(f\"Unsupported OCR recog model_lang: {model_lang}\")\n    rec_config = {\n        \"in_channels\": 3,\n        \"backbone\": {\"type\": \"MobileNetV1Enhance\", \"scale\": 0.5, \"last_conv_stride\": [1, 2], \"last_pool_type\": \"avg\"},\n        \"neck\": {\n            \"type\": \"SequenceEncoder\",\n            \"encoder_type\": \"svtr\",\n            \"dims\": 64,\n            \"depth\": 2,\n            \"hidden_dims\": 120,\n            \"use_guide\": True,\n        },\n        \"head\": {\"type\": \"CTCHead\", \"fc_decay\": 0.00001, \"out_channels\": n_class, \"return_feats\": True},\n    }\n\n    rec_model = RecModel(rec_config)\n    state_dict = torch.load(model_dir, map_location=device)\n    rec_model.load_state_dict(state_dict)\n    return rec_model\n\n\ndef _check_image_file(path):\n    img_end = (\"tiff\", \"tif\", \"bmp\", \"rgb\", \"jpg\", \"png\", \"jpeg\")\n    return path.lower().endswith(tuple(img_end))\n\n\ndef get_image_file_list(img_file):\n    imgs_lists = []\n    if img_file is None or not os.path.exists(img_file):\n        raise Exception(\"not found any img file in {}\".format(img_file))\n    if os.path.isfile(img_file) and _check_image_file(img_file):\n        imgs_lists.append(img_file)\n    elif os.path.isdir(img_file):\n        for single_file in os.listdir(img_file):\n            file_path = os.path.join(img_file, single_file)\n            if os.path.isfile(file_path) and _check_image_file(file_path):\n                imgs_lists.append(file_path)\n    if len(imgs_lists) == 0:\n        
raise Exception(\"not found any img file in {}\".format(img_file))\n    imgs_lists = sorted(imgs_lists)\n    return imgs_lists\n\n\nclass TextRecognizer(object):\n    def __init__(self, args, predictor):\n        self.rec_image_shape = [int(v) for v in args[\"rec_image_shape\"].split(\",\")]\n        self.rec_batch_num = args[\"rec_batch_num\"]\n        self.predictor = predictor\n        self.chars = self.get_char_dict(args[\"rec_char_dict_path\"])\n        self.char2id = {x: i for i, x in enumerate(self.chars)}\n        self.is_onnx = not isinstance(self.predictor, torch.nn.Module)\n        self.use_fp16 = args[\"use_fp16\"]\n\n    # img: CHW\n    def resize_norm_img(self, img, max_wh_ratio):\n        imgC, imgH, imgW = self.rec_image_shape\n        assert imgC == img.shape[0]\n        imgW = int((imgH * max_wh_ratio))\n\n        h, w = img.shape[1:]\n        ratio = w / float(h)\n        if math.ceil(imgH * ratio) > imgW:\n            resized_w = imgW\n        else:\n            resized_w = int(math.ceil(imgH * ratio))\n        resized_image = torch.nn.functional.interpolate(\n            img.unsqueeze(0),\n            size=(imgH, resized_w),\n            mode=\"bilinear\",\n            align_corners=True,\n        )\n        resized_image /= 255.0\n        resized_image -= 0.5\n        resized_image /= 0.5\n        padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)\n        padding_im[:, :, 0:resized_w] = resized_image[0]\n        return padding_im\n\n    # img_list: list of tensors with shape chw 0-255\n    def pred_imglist(self, img_list, show_debug=False):\n        img_num = len(img_list)\n        assert img_num > 0\n        # Calculate the aspect ratio of all text bars\n        width_list = []\n        for img in img_list:\n            width_list.append(img.shape[2] / float(img.shape[1]))\n        # Sorting can speed up the recognition process\n        indices = torch.from_numpy(np.argsort(np.array(width_list)))\n        
batch_num = self.rec_batch_num\n        preds_all = [None] * img_num\n        preds_neck_all = [None] * img_num\n        for beg_img_no in range(0, img_num, batch_num):\n            end_img_no = min(img_num, beg_img_no + batch_num)\n            norm_img_batch = []\n\n            imgC, imgH, imgW = self.rec_image_shape[:3]\n            max_wh_ratio = imgW / imgH\n            for ino in range(beg_img_no, end_img_no):\n                h, w = img_list[indices[ino]].shape[1:]\n                if h > w * 1.2:\n                    img = img_list[indices[ino]]\n                    img = torch.transpose(img, 1, 2).flip(dims=[1])\n                    img_list[indices[ino]] = img\n                    h, w = img.shape[1:]\n                # wh_ratio = w * 1.0 / h\n                # max_wh_ratio = max(max_wh_ratio, wh_ratio)  # comment to not use different ratio\n            for ino in range(beg_img_no, end_img_no):\n                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)\n                if self.use_fp16:\n                    norm_img = norm_img.half()\n                norm_img = norm_img.unsqueeze(0)\n                norm_img_batch.append(norm_img)\n            norm_img_batch = torch.cat(norm_img_batch, dim=0)\n            if show_debug:\n                for i in range(len(norm_img_batch)):\n                    _img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()\n                    _img = (_img + 0.5) * 255\n                    _img = _img[:, :, ::-1]\n                    file_name = f\"{indices[beg_img_no + i]}\"\n                    if os.path.exists(file_name + \".jpg\"):\n                        file_name += \"_2\"  # ori image\n                    cv2.imwrite(file_name + \".jpg\", _img)\n            if self.is_onnx:\n                input_dict = {}\n                input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy()\n                outputs = self.predictor.run(None, input_dict)\n             
   preds = {}\n                preds[\"ctc\"] = torch.from_numpy(outputs[0])\n                preds[\"ctc_neck\"] = [torch.zeros(1)] * img_num\n            else:\n                preds = self.predictor(norm_img_batch.to(next(self.predictor.parameters()).device))\n            for rno in range(preds[\"ctc\"].shape[0]):\n                preds_all[indices[beg_img_no + rno]] = preds[\"ctc\"][rno]\n                preds_neck_all[indices[beg_img_no + rno]] = preds[\"ctc_neck\"][rno]\n\n        return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)\n\n    def get_char_dict(self, character_dict_path):\n        character_str = []\n        with open(character_dict_path, \"rb\") as fin:\n            lines = fin.readlines()\n            for line in lines:\n                line = line.decode(\"utf-8\").strip(\"\\n\").strip(\"\\r\\n\")\n                character_str.append(line)\n        dict_character = list(character_str)\n        dict_character = [\"sos\"] + dict_character + [\" \"]  # eos is space\n        return dict_character\n\n    def get_text(self, order):\n        char_list = [self.chars[text_id] for text_id in order]\n        return \"\".join(char_list)\n\n    def decode(self, mat):\n        text_index = mat.detach().cpu().numpy().argmax(axis=1)\n        ignored_tokens = [0]\n        selection = np.ones(len(text_index), dtype=bool)\n        selection[1:] = text_index[1:] != text_index[:-1]\n        for ignored_token in ignored_tokens:\n            selection &= text_index != ignored_token\n        return text_index[selection], np.where(selection)[0]\n\n    def get_ctcloss(self, preds, gt_text, weight):\n        if not isinstance(weight, torch.Tensor):\n            weight = torch.tensor(weight).to(preds.device)\n        ctc_loss = torch.nn.CTCLoss(reduction=\"none\")\n        log_probs = preds.log_softmax(dim=2).permute(1, 0, 2)  # NTC-->TNC\n        targets = []\n        target_lengths = []\n        for t in gt_text:\n            targets += 
[self.char2id.get(i, len(self.chars) - 1) for i in t]\n            target_lengths += [len(t)]\n        targets = torch.tensor(targets).to(preds.device)\n        target_lengths = torch.tensor(target_lengths).to(preds.device)\n        input_lengths = torch.tensor([log_probs.shape[0]] * (log_probs.shape[1])).to(preds.device)\n        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)\n        loss = loss / input_lengths * weight\n        return loss\n\n\nclass AbstractEncoder(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def encode(self, *args, **kwargs):\n        raise NotImplementedError\n\n\nclass FrozenCLIPEmbedderT3(AbstractEncoder, ModelMixin, ConfigMixin):\n    \"\"\"Uses the CLIP transformer encoder for text (from Hugging Face)\"\"\"\n\n    @register_to_config\n    def __init__(\n        self,\n        device=\"cpu\",\n        max_length=77,\n        freeze=True,\n        use_fp16=False,\n        variant: str | None = None,\n    ):\n        super().__init__()\n        self.tokenizer = CLIPTokenizer.from_pretrained(\"tolgacangoz/anytext\", subfolder=\"tokenizer\")\n        self.transformer = CLIPTextModel.from_pretrained(\n            \"tolgacangoz/anytext\",\n            subfolder=\"text_encoder\",\n            torch_dtype=torch.float16 if use_fp16 else torch.float32,\n            variant=\"fp16\" if use_fp16 else None,\n        )\n\n        if freeze:\n            self.freeze()\n\n        def embedding_forward(\n            self,\n            input_ids=None,\n            position_ids=None,\n            inputs_embeds=None,\n            embedding_manager=None,\n        ):\n            seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]\n            if position_ids is None:\n                position_ids = self.position_ids[:, :seq_length]\n            if inputs_embeds is None:\n                inputs_embeds = self.token_embedding(input_ids)\n            if embedding_manager is not 
None:\n                inputs_embeds = embedding_manager(input_ids, inputs_embeds)\n            position_embeddings = self.position_embedding(position_ids)\n            embeddings = inputs_embeds + position_embeddings\n            return embeddings\n\n        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(\n            self.transformer.text_model.embeddings\n        )\n\n        def encoder_forward(\n            self,\n            inputs_embeds,\n            attention_mask=None,\n            causal_attention_mask=None,\n            output_attentions=None,\n            output_hidden_states=None,\n            return_dict=None,\n        ):\n            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n            output_hidden_states = (\n                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n            )\n            return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n            encoder_states = () if output_hidden_states else None\n            all_attentions = () if output_attentions else None\n            hidden_states = inputs_embeds\n            for idx, encoder_layer in enumerate(self.layers):\n                if output_hidden_states:\n                    encoder_states = encoder_states + (hidden_states,)\n                layer_outputs = encoder_layer(\n                    hidden_states,\n                    attention_mask,\n                    causal_attention_mask,\n                    output_attentions=output_attentions,\n                )\n                hidden_states = layer_outputs[0]\n                if output_attentions:\n                    all_attentions = all_attentions + (layer_outputs[1],)\n            if output_hidden_states:\n                encoder_states = encoder_states + (hidden_states,)\n            return hidden_states\n\n        
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)\n\n        def text_encoder_forward(\n            self,\n            input_ids=None,\n            attention_mask=None,\n            position_ids=None,\n            output_attentions=None,\n            output_hidden_states=None,\n            return_dict=None,\n            embedding_manager=None,\n        ):\n            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n            output_hidden_states = (\n                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n            )\n            return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n            if input_ids is None:\n                raise ValueError(\"You have to specify either input_ids\")\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n            hidden_states = self.embeddings(\n                input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager\n            )\n            # CLIP's text model uses causal mask, prepare it here.\n            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324\n            causal_attention_mask = _create_4d_causal_attention_mask(\n                input_shape, hidden_states.dtype, device=hidden_states.device\n            )\n            # expand attention_mask\n            if attention_mask is not None:\n                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n                attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)\n            last_hidden_state = self.encoder(\n                inputs_embeds=hidden_states,\n                attention_mask=attention_mask,\n                causal_attention_mask=causal_attention_mask,\n                
output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n            last_hidden_state = self.final_layer_norm(last_hidden_state)\n            return last_hidden_state\n\n        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)\n\n        def transformer_forward(\n            self,\n            input_ids=None,\n            attention_mask=None,\n            position_ids=None,\n            output_attentions=None,\n            output_hidden_states=None,\n            return_dict=None,\n            embedding_manager=None,\n        ):\n            return self.text_model(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n                embedding_manager=embedding_manager,\n            )\n\n        self.transformer.forward = transformer_forward.__get__(self.transformer)\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text, **kwargs):\n        batch_encoding = self.tokenizer(\n            text,\n            truncation=False,\n            max_length=self.config.max_length,\n            return_length=True,\n            return_overflowing_tokens=False,\n            padding=\"longest\",\n            return_tensors=\"pt\",\n        )\n        input_ids = batch_encoding[\"input_ids\"]\n        tokens_list = self.split_chunks(input_ids)\n        z_list = []\n        for tokens in tokens_list:\n            tokens = tokens.to(self.device)\n            _z = self.transformer(input_ids=tokens, **kwargs)\n            z_list += [_z]\n        return torch.cat(z_list, dim=1)\n\n   
 def encode(self, text, **kwargs):\n        return self(text, **kwargs)\n\n    def split_chunks(self, input_ids, chunk_size=75):\n        tokens_list = []\n        bs, n = input_ids.shape\n        id_start = input_ids[:, 0].unsqueeze(1)  # dim --> [bs, 1]\n        id_end = input_ids[:, -1].unsqueeze(1)\n        if n == 2:  # empty caption\n            tokens_list.append(torch.cat((id_start,) + (id_end,) * (chunk_size + 1), dim=1))\n\n        trimmed_encoding = input_ids[:, 1:-1]\n        num_full_groups = (n - 2) // chunk_size\n\n        for i in range(num_full_groups):\n            group = trimmed_encoding[:, i * chunk_size : (i + 1) * chunk_size]\n            group_pad = torch.cat((id_start, group, id_end), dim=1)\n            tokens_list.append(group_pad)\n\n        remaining_columns = (n - 2) % chunk_size\n        if remaining_columns > 0:\n            remaining_group = trimmed_encoding[:, -remaining_columns:]\n            padding_columns = chunk_size - remaining_group.shape[1]\n            padding = id_end.expand(bs, padding_columns)\n            remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1)\n            tokens_list.append(remaining_group_pad)\n        return tokens_list\n\n\nclass TextEmbeddingModule(ModelMixin, ConfigMixin):\n    @register_to_config\n    def __init__(self, font_path, use_fp16=False, device=\"cpu\"):\n        super().__init__()\n        font = ImageFont.truetype(font_path, 60)\n\n        self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)\n        self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)\n        self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()\n        args = {\n            \"rec_image_shape\": \"3, 48, 320\",\n            \"rec_batch_num\": 6,\n            \"rec_char_dict_path\": hf_hub_download(\n                repo_id=\"tolgacangoz/anytext\",\n                
filename=\"text_embedding_module/OCR/ppocr_keys_v1.txt\",\n                cache_dir=HF_MODULES_CACHE,\n            ),\n            \"use_fp16\": use_fp16,\n        }\n        self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)\n\n        self.register_to_config(font=font)\n\n    @torch.no_grad()\n    def forward(\n        self,\n        prompt,\n        texts,\n        negative_prompt,\n        num_images_per_prompt,\n        mode,\n        draw_pos,\n        sort_priority=\"↕\",\n        max_chars=77,\n        revise_pos=False,\n        h=512,\n        w=512,\n    ):\n        if prompt is None and texts is None:\n            raise ValueError(\"Prompt or texts must be provided!\")\n        # preprocess pos_imgs (if numpy, make sure it's white pos in black bg)\n        if draw_pos is None:\n            pos_imgs = np.zeros((h, w, 1))\n        elif isinstance(draw_pos, PIL.Image.Image):\n            pos_imgs = np.array(draw_pos)[..., ::-1]\n            pos_imgs = 255 - pos_imgs\n        elif isinstance(draw_pos, str):\n            # cv2.imread returns None for an unreadable path, so check before indexing\n            pos_imgs = cv2.imread(draw_pos)\n            if pos_imgs is None:\n                raise ValueError(f\"Can't read draw_pos image from {draw_pos}!\")\n            pos_imgs = 255 - pos_imgs[..., ::-1]\n        elif isinstance(draw_pos, torch.Tensor):\n            pos_imgs = draw_pos.cpu().numpy()\n        elif isinstance(draw_pos, np.ndarray):\n            pos_imgs = draw_pos\n        else:\n            raise ValueError(f\"Unknown format of draw_pos: {type(draw_pos)}\")\n        if mode == \"edit\":\n            pos_imgs = cv2.resize(pos_imgs, (w, h))\n        pos_imgs = pos_imgs[..., 0:1]\n        pos_imgs = cv2.convertScaleAbs(pos_imgs)\n        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)\n        # separate pos_imgs\n        pos_imgs = self.separate_pos_imgs(pos_imgs, sort_priority)\n        if len(pos_imgs) == 0:\n            pos_imgs = [np.zeros((h, w, 1))]\n        n_lines = len(texts)\n        if len(pos_imgs) < 
n_lines:\n            if n_lines == 1 and texts[0] == \" \":\n                pass  # text-to-image without text\n            else:\n                raise ValueError(\n                    f\"Found {len(pos_imgs)} positions that < needed {n_lines} from prompt, check and try again!\"\n                )\n        elif len(pos_imgs) > n_lines:\n            str_warning = f\"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt.\"\n            logger.warning(str_warning)\n        # get pre_pos, poly_list, hint that needed for anytext\n        pre_pos = []\n        poly_list = []\n        for input_pos in pos_imgs:\n            if input_pos.mean() != 0:\n                input_pos = input_pos[..., np.newaxis] if len(input_pos.shape) == 2 else input_pos\n                poly, pos_img = self.find_polygon(input_pos)\n                pre_pos += [pos_img / 255.0]\n                poly_list += [poly]\n            else:\n                pre_pos += [np.zeros((h, w, 1))]\n                poly_list += [None]\n        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)\n        # prepare info dict\n        text_info = {}\n        text_info[\"glyphs\"] = []\n        text_info[\"gly_line\"] = []\n        text_info[\"positions\"] = []\n        text_info[\"n_lines\"] = [len(texts)] * num_images_per_prompt\n        for i in range(len(texts)):\n            text = texts[i]\n            if len(text) > max_chars:\n                str_warning = f'\"{text}\" length > max_chars: {max_chars}, will be cut off...'\n                logger.warning(str_warning)\n                text = text[:max_chars]\n            gly_scale = 2\n            if pre_pos[i].mean() != 0:\n                gly_line = self.draw_glyph(self.config.font, text)\n                glyphs = self.draw_glyph2(\n                    self.config.font, text, poly_list[i], scale=gly_scale, width=w, height=h, add_space=False\n                )\n                if revise_pos:\n                    resize_gly = 
cv2.resize(glyphs, (pre_pos[i].shape[1], pre_pos[i].shape[0]))\n                    new_pos = cv2.morphologyEx(\n                        (resize_gly * 255).astype(np.uint8),\n                        cv2.MORPH_CLOSE,\n                        kernel=np.ones((resize_gly.shape[0] // 10, resize_gly.shape[1] // 10), dtype=np.uint8),\n                        iterations=1,\n                    )\n                    new_pos = new_pos[..., np.newaxis] if len(new_pos.shape) == 2 else new_pos\n                    contours, _ = cv2.findContours(new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n                    if len(contours) != 1:\n                        str_warning = f\"Fail to revise position {i} to bounding rect, remain position unchanged...\"\n                        logger.warning(str_warning)\n                    else:\n                        rect = cv2.minAreaRect(contours[0])\n                        poly = np.int0(cv2.boxPoints(rect))\n                        pre_pos[i] = cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0\n            else:\n                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))\n                gly_line = np.zeros((80, 512, 1))\n            pos = pre_pos[i]\n            text_info[\"glyphs\"] += [self.arr2tensor(glyphs, num_images_per_prompt)]\n            text_info[\"gly_line\"] += [self.arr2tensor(gly_line, num_images_per_prompt)]\n            text_info[\"positions\"] += [self.arr2tensor(pos, num_images_per_prompt)]\n\n        self.embedding_manager.encode_text(text_info)\n        prompt_embeds = self.frozen_CLIP_embedder_t3.encode([prompt], embedding_manager=self.embedding_manager)\n\n        self.embedding_manager.encode_text(text_info)\n        negative_prompt_embeds = self.frozen_CLIP_embedder_t3.encode(\n            [negative_prompt or \"\"], embedding_manager=self.embedding_manager\n        )\n\n        return prompt_embeds, negative_prompt_embeds, text_info, np_hint\n\n    def arr2tensor(self, arr, bs):\n        arr 
= np.transpose(arr, (2, 0, 1))\n        _arr = torch.from_numpy(arr.copy()).float().cpu()\n        if self.config.use_fp16:\n            _arr = _arr.half()\n        _arr = torch.stack([_arr for _ in range(bs)], dim=0)\n        return _arr\n\n    def separate_pos_imgs(self, img, sort_priority, gap=102):\n        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img)\n        components = []\n        for label in range(1, num_labels):\n            component = np.zeros_like(img)\n            component[labels == label] = 255\n            components.append((component, centroids[label]))\n        if sort_priority == \"↕\":\n            fir, sec = 1, 0  # top-down first\n        elif sort_priority == \"↔\":\n            fir, sec = 0, 1  # left-right first\n        else:\n            raise ValueError(f\"Unknown sort_priority: {sort_priority}\")\n        components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))\n        sorted_components = [c[0] for c in components]\n        return sorted_components\n\n    def find_polygon(self, image, min_rect=False):\n        contours, hierarchy = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)\n        max_contour = max(contours, key=cv2.contourArea)  # get contour with max area\n        if min_rect:\n            # get minimum enclosing rectangle\n            rect = cv2.minAreaRect(max_contour)\n            poly = np.int0(cv2.boxPoints(rect))\n        else:\n            # get approximate polygon\n            epsilon = 0.01 * cv2.arcLength(max_contour, True)\n            poly = cv2.approxPolyDP(max_contour, epsilon, True)\n            n, _, xy = poly.shape\n            poly = poly.reshape(n, xy)\n        cv2.drawContours(image, [poly], -1, 255, -1)\n        return poly, image\n\n    def draw_glyph(self, font, text):\n        g_size = 50\n        W, H = (512, 80)\n        new_font = font.font_variant(size=g_size)\n        img = Image.new(mode=\"1\", size=(W, H), color=0)\n        draw = 
ImageDraw.Draw(img)\n        left, top, right, bottom = new_font.getbbox(text)\n        text_width = max(right - left, 5)\n        text_height = max(bottom - top, 5)\n        ratio = min(W * 0.9 / text_width, H * 0.9 / text_height)\n        new_font = font.font_variant(size=int(g_size * ratio))\n\n        left, top, right, bottom = new_font.getbbox(text)\n        text_width = right - left\n        text_height = bottom - top\n        x = (img.width - text_width) // 2\n        y = (img.height - text_height) // 2 - top // 2\n        draw.text((x, y), text, font=new_font, fill=\"white\")\n        img = np.expand_dims(np.array(img), axis=2).astype(np.float64)\n        return img\n\n    def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, height=512, add_space=True):\n        enlarge_polygon = polygon * scale\n        rect = cv2.minAreaRect(enlarge_polygon)\n        box = cv2.boxPoints(rect)\n        box = np.int0(box)\n        w, h = rect[1]\n        angle = rect[2]\n        if angle < -45:\n            angle += 90\n        angle = -angle\n        if w < h:\n            angle += 90\n\n        vert = False\n        if abs(angle) % 90 < vertAng or abs(90 - abs(angle) % 90) % 90 < vertAng:\n            _w = max(box[:, 0]) - min(box[:, 0])\n            _h = max(box[:, 1]) - min(box[:, 1])\n            if _h >= _w:\n                vert = True\n                angle = 0\n\n        img = np.zeros((height * scale, width * scale, 3), np.uint8)\n        img = Image.fromarray(img)\n\n        # infer font size\n        image4ratio = Image.new(\"RGB\", img.size, \"white\")\n        draw = ImageDraw.Draw(image4ratio)\n        _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)\n        text_w = min(w, h) * (_tw / _th)\n        if text_w <= max(w, h):\n            # add space\n            if len(text) > 1 and not vert and add_space:\n                for i in range(1, 100):\n                    text_space = self.insert_spaces(text, i)\n                
    _, _, _tw2, _th2 = draw.textbbox(xy=(0, 0), text=text_space, font=font)\n                    if min(w, h) * (_tw2 / _th2) > max(w, h):\n                        break\n                text = self.insert_spaces(text, i - 1)\n            font_size = min(w, h) * 0.80\n        else:\n            shrink = 0.75 if vert else 0.85\n            font_size = min(w, h) / (text_w / max(w, h)) * shrink\n        new_font = font.font_variant(size=int(font_size))\n\n        left, top, right, bottom = new_font.getbbox(text)\n        text_width = right - left\n        text_height = bottom - top\n\n        layer = Image.new(\"RGBA\", img.size, (0, 0, 0, 0))\n        draw = ImageDraw.Draw(layer)\n        if not vert:\n            draw.text(\n                (rect[0][0] - text_width // 2, rect[0][1] - text_height // 2 - top),\n                text,\n                font=new_font,\n                fill=(255, 255, 255, 255),\n            )\n        else:\n            x_s = min(box[:, 0]) + _w // 2 - text_height // 2\n            y_s = min(box[:, 1])\n            for c in text:\n                draw.text((x_s, y_s), c, font=new_font, fill=(255, 255, 255, 255))\n                _, _t, _, _b = new_font.getbbox(c)\n                y_s += _b\n\n        rotated_layer = layer.rotate(angle, expand=1, center=(rect[0][0], rect[0][1]))\n\n        x_offset = int((img.width - rotated_layer.width) / 2)\n        y_offset = int((img.height - rotated_layer.height) / 2)\n        img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)\n        img = np.expand_dims(np.array(img.convert(\"1\")), axis=2).astype(np.float64)\n        return img\n\n    def insert_spaces(self, string, nSpace):\n        if nSpace == 0:\n            return string\n        new_string = \"\"\n        for char in string:\n            new_string += char + \" \" * nSpace\n        return new_string[:-nSpace]\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef 
retrieve_latents(\n    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\nclass AuxiliaryLatentModule(ModelMixin, ConfigMixin):\n    @register_to_config\n    def __init__(\n        self,\n        vae,\n        device=\"cpu\",\n    ):\n        super().__init__()\n\n    @torch.no_grad()\n    def forward(\n        self,\n        text_info,\n        mode,\n        draw_pos,\n        ori_image,\n        num_images_per_prompt,\n        np_hint,\n        h=512,\n        w=512,\n    ):\n        if mode == \"generate\":\n            edit_image = np.ones((h, w, 3)) * 127.5  # empty mask image\n        elif mode == \"edit\":\n            if draw_pos is None or ori_image is None:\n                raise ValueError(\"Reference image and position image are needed for text editing!\")\n            if isinstance(ori_image, str):\n                # cv2.imread returns None for an unreadable path, so check before indexing\n                image_path = ori_image\n                ori_image = cv2.imread(image_path)\n                if ori_image is None:\n                    raise ValueError(f\"Can't read ori_image from {image_path}!\")\n                ori_image = ori_image[..., ::-1]\n            elif isinstance(ori_image, torch.Tensor):\n                ori_image = ori_image.cpu().numpy()\n            elif isinstance(ori_image, PIL.Image.Image):\n                ori_image = np.array(ori_image.convert(\"RGB\"))\n            else:\n                if not isinstance(ori_image, np.ndarray):\n                    raise ValueError(f\"Unknown format of ori_image: {type(ori_image)}\")\n            edit_image = ori_image.clip(1, 255)  # for mask 
reason\n            edit_image = self.check_channels(edit_image)\n            edit_image = self.resize_image(\n                edit_image, max_length=768\n            )  # make w h multiple of 64, resize if w or h > max_length\n\n        # get masked_x\n        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)\n        masked_img = np.transpose(masked_img, (2, 0, 1))\n        device = next(self.config.vae.parameters()).device\n        dtype = next(self.config.vae.parameters()).dtype\n        masked_img = torch.from_numpy(masked_img.copy()).float().to(device)\n        if dtype == torch.float16:\n            masked_img = masked_img.half()\n        masked_x = (\n            retrieve_latents(self.config.vae.encode(masked_img[None, ...])) * self.config.vae.config.scaling_factor\n        ).detach()\n        if dtype == torch.float16:\n            masked_x = masked_x.half()\n        text_info[\"masked_x\"] = torch.cat([masked_x for _ in range(num_images_per_prompt)], dim=0)\n\n        glyphs = torch.cat(text_info[\"glyphs\"], dim=1).sum(dim=1, keepdim=True)\n        positions = torch.cat(text_info[\"positions\"], dim=1).sum(dim=1, keepdim=True)\n\n        return glyphs, positions, text_info\n\n    def check_channels(self, image):\n        channels = image.shape[2] if len(image.shape) == 3 else 1\n        if channels == 1:\n            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)\n        elif channels > 3:\n            image = image[:, :, :3]\n        return image\n\n    def resize_image(self, img, max_length=768):\n        height, width = img.shape[:2]\n        max_dimension = max(height, width)\n\n        if max_dimension > max_length:\n            scale_factor = max_length / max_dimension\n            new_width = int(round(width * scale_factor))\n            new_height = int(round(height * scale_factor))\n            new_size = (new_width, new_height)\n            img = cv2.resize(img, new_size)\n        height, width = img.shape[:2]\n     
   img = cv2.resize(img, (width - (width % 64), height - (height % 64)))\n        return img\n\n    def insert_spaces(self, string, nSpace):\n        if nSpace == 0:\n            return string\n        new_string = \"\"\n        for char in string:\n            new_string += char + \" \" * nSpace\n        return new_string[:-nSpace]\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass AnyTextPipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    StableDiffusionLoraLoaderMixin,\n    IPAdapterMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the 
encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the `unet` during the denoising process. If you set multiple\n            ControlNets as a list, the outputs from each ControlNet are added together to create one combined\n            additional conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        font_path: str = None,\n        text_embedding_module: Optional[TextEmbeddingModule] = None,\n       
 auxiliary_latent_module: Optional[AuxiliaryLatentModule] = None,\n        trust_remote_code: bool = False,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n        if font_path is None:\n            raise ValueError(\"font_path is required!\")\n\n        text_embedding_module = TextEmbeddingModule(font_path=font_path, use_fp16=unet.dtype == torch.float16)\n        auxiliary_latent_module = AuxiliaryLatentModule(vae=vae)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. 
If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n            text_embedding_module=text_embedding_module,\n            auxiliary_latent_module=auxiliary_latent_module,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def modify_prompt(self, prompt):\n        prompt = prompt.replace(\"“\", '\"')\n        prompt = prompt.replace(\"”\", '\"')\n        p = '\"(.*?)\"'\n        strs = re.findall(p, prompt)\n        if len(strs) == 0:\n            strs = [\" \"]\n        else:\n            for s in strs:\n                prompt = prompt.replace(f'\"{s}\"', f\" {PLACE_HOLDER} \", 1)\n        if self.is_chinese(prompt):\n            if self.trans_pipe is None:\n                return None, None\n            old_prompt = prompt\n            prompt = self.trans_pipe(input=prompt + \" .\")[\"translation\"][:-1]\n            print(f\"Translate: {old_prompt} --> {prompt}\")\n        return prompt, strs\n\n    def is_chinese(self, text):\n        text = checker._clean_text(text)\n        for char in text:\n            cp = ord(char)\n            if 
checker._is_chinese_char(cp):\n                return True\n        return False\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: 
Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                
unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        image_embeds = []\n        if do_classifier_free_guidance:\n            negative_image_embeds = []\n        if ip_adapter_image_embeds is None:\n            
if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n\n                image_embeds.append(single_image_embeds[None, :])\n                if do_classifier_free_guidance:\n                    negative_image_embeds.append(single_negative_image_embeds[None, :])\n        else:\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    negative_image_embeds.append(single_negative_image_embeds)\n                image_embeds.append(single_image_embeds)\n\n        ip_adapter_image_embeds = []\n        for i, single_image_embeds in enumerate(image_embeds):\n            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n            if do_classifier_free_guidance:\n                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)\n                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)\n\n            single_image_embeds = 
single_image_embeds.to(device=device)\n            ip_adapter_image_embeds.append(single_image_embeds)\n\n        return ip_adapter_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) 
instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        # image,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if callback_steps is not None and (not 
isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                print(controlnet_conditioning_scale)\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\n                        \"A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. 
\"\n                        \"The conditioning scale must be fixed across the batch.\"\n                    )\n            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n                )\n        else:\n            assert False\n\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. 
Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not 
image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    def prepare_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\n    ) -> torch.Tensor:\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            w (`torch.Tensor`):\n                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.\n            embedding_dim (`int`, *optional*, defaults to 512):\n                Dimension of the embeddings to generate.\n            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):\n                Data type of the generated embeddings.\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def 
guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        mode: str | None = \"generate\",\n        draw_pos: Optional[Union[str, torch.Tensor]] = None,\n        ori_image: Optional[Union[str, torch.Tensor]] = None,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, 
List[float]] = 1.0,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted\n                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or\n                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,\n                images must be passed as a list such that each element of the list can be correctly batched for input\n                to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single\n                ControlNet, each will be paired with each prompt in the `prompt` list. 
This also applies to multiple\n                ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. 
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that calls every `callback_steps` steps during inference. The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. 
If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during inference, with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as the `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            
control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            # image,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        prompt, texts = self.modify_prompt(prompt)\n\n        # 3. Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos\n        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module(\n            prompt,\n            texts,\n            negative_prompt,\n            num_images_per_prompt,\n            mode,\n            draw_pos,\n        )\n\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n          
      device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 3.5 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 4. Prepare image\n        if isinstance(controlnet, ControlNetModel):\n            guided_hint = self.auxiliary_latent_module(\n                text_info=text_info,\n                mode=mode,\n                draw_pos=draw_pos,\n                ori_image=ori_image,\n                num_images_per_prompt=num_images_per_prompt,\n                np_hint=np_hint,\n            )\n            height, width = 512, 512\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n        self._num_timesteps = len(timesteps)\n\n        # 6. Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if ip_adapter_image is not None or ip_adapter_image_embeds is not None\n            else None\n        )\n\n        # 7.2 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 8. Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        is_unet_compiled = is_compiled_module(self.unet)\n        is_controlnet_compiled = is_compiled_module(self.controlnet)\n        is_torch_higher_equal_2_1 = is_torch_version(\">=\", \"2.1\")\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Relevant thread:\n                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428\n                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:\n                    torch._inductor.cudagraph_mark_step_begin()\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # controlnet(s) inference\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the 
conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input.to(self.controlnet.dtype),\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_cond=guided_hint,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                )\n\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    
latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # If we 
do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n\n    def to(self, *args, **kwargs):\n        super().to(*args, **kwargs)\n        self.text_embedding_module.to(*args, **kwargs)\n        self.auxiliary_latent_module.to(*args, **kwargs)\n        return self\n"
  },
  {
    "path": "examples/research_projects/anytext/anytext_controlnet.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).\n# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie\n# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license\n#\n# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).\n\n\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\n\nfrom diffusers.configuration_utils import register_to_config\nfrom diffusers.models.controlnets.controlnet import (\n    ControlNetModel,\n    ControlNetOutput,\n)\nfrom diffusers.utils import logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass AnyTextControlNetConditioningEmbedding(nn.Module):\n    \"\"\"\n    Quoting from https://huggingface.co/papers/2302.05543: \"Stable Diffusion uses a pre-processing method similar to VQ-GAN\n    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized\n    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the\n    convolution size. 
We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides\n    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full\n    model) to encode image-space conditions ... into feature maps ...\"\n    \"\"\"\n\n    def __init__(\n        self,\n        conditioning_embedding_channels: int,\n        glyph_channels=1,\n        position_channels=1,\n    ):\n        super().__init__()\n\n        self.glyph_block = nn.Sequential(\n            nn.Conv2d(glyph_channels, 8, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(8, 8, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(8, 16, 3, padding=1, stride=2),\n            nn.SiLU(),\n            nn.Conv2d(16, 16, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(16, 32, 3, padding=1, stride=2),\n            nn.SiLU(),\n            nn.Conv2d(32, 32, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(32, 96, 3, padding=1, stride=2),\n            nn.SiLU(),\n            nn.Conv2d(96, 96, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(96, 256, 3, padding=1, stride=2),\n            nn.SiLU(),\n        )\n\n        self.position_block = nn.Sequential(\n            nn.Conv2d(position_channels, 8, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(8, 8, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(8, 16, 3, padding=1, stride=2),\n            nn.SiLU(),\n            nn.Conv2d(16, 16, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(16, 32, 3, padding=1, stride=2),\n            nn.SiLU(),\n            nn.Conv2d(32, 32, 3, padding=1),\n            nn.SiLU(),\n            nn.Conv2d(32, 64, 3, padding=1, stride=2),\n            nn.SiLU(),\n        )\n\n        self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)\n\n    def forward(self, glyphs, positions, text_info):\n        glyph_embedding = 
self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))\n        position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))\n        guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info[\"masked_x\"]], dim=1))\n\n        return guided_hint\n\n\nclass AnyTextControlNetModel(ControlNetModel):\n    \"\"\"\n    An AnyTextControlNetModel model.\n\n    Args:\n        in_channels (`int`, defaults to 4):\n            The number of channels in the input sample.\n        flip_sin_to_cos (`bool`, defaults to `True`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, defaults to 0):\n            The frequency shift to apply to the time embedding.\n        down_block_types (`tuple[str]`, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):\n        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, defaults to 2):\n            The number of layers per block.\n        downsample_padding (`int`, defaults to 1):\n            The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, defaults to 1):\n            The scale factor to use for the mid block.\n        act_fn (`str`, defaults to \"silu\"):\n            The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32):\n            The number of groups to use for the normalization. 
If None, normalization and activation layers are skipped\n            in post-processing.\n        norm_eps (`float`, defaults to 1e-5):\n            The epsilon to use for the normalization.\n        cross_attention_dim (`int`, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.\n        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):\n            The dimension of the attention heads.\n        use_linear_projection (`bool`, defaults to `False`):\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". 
\"text\" will use the `TextTimeEmbedding` layer.\n        num_class_embeds (`int`, *optional*, defaults to 0):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        upcast_attention (`bool`, defaults to `False`):\n        resnet_time_scale_shift (`str`, defaults to `\"default\"`):\n            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.\n        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):\n            The dimension of the `class_labels` input when `class_embed_type=\"projection\"`. Required when\n            `class_embed_type=\"projection\"`.\n        controlnet_conditioning_channel_order (`str`, defaults to `\"rgb\"`):\n            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.\n        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):\n            The tuple of output channel for each block in the `conditioning_embedding` layer.\n        global_pool_conditions (`bool`, defaults to `False`):\n            TODO(Patrick) - unused parameter.\n        addition_embed_type_num_heads (`int`, defaults to 64):\n            The number of heads to use for the `TextTimeEmbedding` layer.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        in_channels: int = 4,\n        conditioning_channels: int = 1,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str, ...] 
= (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: str | None = \"UNetMidBlock2DCrossAttn\",\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),\n        layers_per_block: int = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: int = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: str | None = None,\n        attention_head_dim: Union[int, Tuple[int, ...]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,\n        use_linear_projection: bool = False,\n        class_embed_type: str | None = None,\n        addition_embed_type: str | None = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        controlnet_conditioning_channel_order: str = \"rgb\",\n        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),\n        global_pool_conditions: bool = False,\n        addition_embed_type_num_heads: int = 64,\n    ):\n        super().__init__(\n            in_channels,\n            conditioning_channels,\n            flip_sin_to_cos,\n            freq_shift,\n            down_block_types,\n            mid_block_type,\n            only_cross_attention,\n            block_out_channels,\n            layers_per_block,\n            downsample_padding,\n            mid_block_scale_factor,\n         
   act_fn,\n            norm_num_groups,\n            norm_eps,\n            cross_attention_dim,\n            transformer_layers_per_block,\n            encoder_hid_dim,\n            encoder_hid_dim_type,\n            attention_head_dim,\n            num_attention_heads,\n            use_linear_projection,\n            class_embed_type,\n            addition_embed_type,\n            addition_time_embed_dim,\n            num_class_embeds,\n            upcast_attention,\n            resnet_time_scale_shift,\n            projection_class_embeddings_input_dim,\n            controlnet_conditioning_channel_order,\n            conditioning_embedding_out_channels,\n            global_pool_conditions,\n            addition_embed_type_num_heads,\n        )\n\n        # control net conditioning embedding\n        self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(\n            conditioning_embedding_channels=block_out_channels[0],\n            glyph_channels=conditioning_channels,\n            position_channels=conditioning_channels,\n        )\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        controlnet_cond: torch.Tensor,\n        conditioning_scale: float = 1.0,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guess_mode: bool = False,\n        return_dict: bool = True,\n    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:\n        \"\"\"\n        The [`AnyTextControlNetModel`] forward method.\n\n        Args:\n            sample (`torch.Tensor`):\n                The noisy input tensor.\n            timestep (`Union[torch.Tensor, float, 
int]`):\n                The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.Tensor`):\n                The encoder hidden states.\n            #controlnet_cond (`torch.Tensor`):\n            #    The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.\n            conditioning_scale (`float`, defaults to `1.0`):\n                The scale factor for ControlNet outputs.\n            class_labels (`torch.Tensor`, *optional*, defaults to `None`):\n                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.\n            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):\n                Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the\n                timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep\n                embeddings.\n            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):\n                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask\n                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large\n                negative values to the attention scores corresponding to \"discard\" tokens.\n            added_cond_kwargs (`dict`):\n                Additional conditions for the Stable Diffusion XL UNet.\n            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):\n                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.\n            guess_mode (`bool`, defaults to `False`):\n                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if\n                you remove all prompts. 
A `guidance_scale` between 3.0 and 5.0 is recommended.\n            return_dict (`bool`, defaults to `True`):\n                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:\n                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is\n                returned where the first element is the sample tensor.\n        \"\"\"\n        # check channel order\n        channel_order = self.config.controlnet_conditioning_channel_order\n\n        if channel_order == \"rgb\":\n            # in rgb order by default\n            ...\n        # elif channel_order == \"bgr\":\n        #    controlnet_cond = torch.flip(controlnet_cond, dims=[1])\n        else:\n            raise ValueError(f\"unknown `controlnet_conditioning_channel_order`: {channel_order}\")\n\n        # prepare attention_mask\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)\n            emb = emb + class_emb\n\n        if self.config.addition_embed_type is not None:\n            if self.config.addition_embed_type == \"text\":\n                aug_emb = self.add_embedding(encoder_hidden_states)\n\n            elif self.config.addition_embed_type == \"text_time\":\n                if \"text_embeds\" not in added_cond_kwargs:\n                    raise ValueError(\n                        f\"{self.__class__} has the config param `addition_embed_type` set to 
'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                    )\n                text_embeds = added_cond_kwargs.get(\"text_embeds\")\n                if \"time_ids\" not in added_cond_kwargs:\n                    raise ValueError(\n                        f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                    )\n                time_ids = added_cond_kwargs.get(\"time_ids\")\n                time_embeds = self.add_time_proj(time_ids.flatten())\n                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n                add_embeds = add_embeds.to(emb.dtype)\n                aug_emb = self.add_embedding(add_embeds)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)\n        sample = sample + controlnet_cond\n\n        # 3. down\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n            down_block_res_samples += res_samples\n\n        # 4. 
mid\n        if self.mid_block is not None:\n            if hasattr(self.mid_block, \"has_cross_attention\") and self.mid_block.has_cross_attention:\n                sample = self.mid_block(\n                    sample,\n                    emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                )\n            else:\n                sample = self.mid_block(sample, emb)\n\n        # 5. Control net blocks\n        controlnet_down_block_res_samples = ()\n\n        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):\n            down_block_res_sample = controlnet_block(down_block_res_sample)\n            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)\n\n        down_block_res_samples = controlnet_down_block_res_samples\n\n        mid_block_res_sample = self.controlnet_mid_block(sample)\n\n        # 6. 
scaling\n        if guess_mode and not self.config.global_pool_conditions:\n            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0\n            scales = scales * conditioning_scale\n            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]\n            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one\n        else:\n            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]\n            mid_block_res_sample = mid_block_res_sample * conditioning_scale\n\n        if self.config.global_pool_conditions:\n            down_block_res_samples = [\n                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples\n            ]\n            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)\n\n        if not return_dict:\n            return (down_block_res_samples, mid_block_res_sample)\n\n        return ControlNetOutput(\n            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample\n        )\n\n\n# Copied from diffusers.models.controlnet.zero_module\ndef zero_module(module):\n    for p in module.parameters():\n        nn.init.zeros_(p)\n    return module\n"
  },
  {
    "path": "examples/research_projects/anytext/ocr_recog/RNN.py",
    "content": "import torch\nfrom torch import nn\n\nfrom .RecSVTR import Block\n\n\nclass Swish(nn.Module):\n    def __int__(self):\n        super(Swish, self).__int__()\n\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\nclass Im2Im(nn.Module):\n    def __init__(self, in_channels, **kwargs):\n        super().__init__()\n        self.out_channels = in_channels\n\n    def forward(self, x):\n        return x\n\n\nclass Im2Seq(nn.Module):\n    def __init__(self, in_channels, **kwargs):\n        super().__init__()\n        self.out_channels = in_channels\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # assert H == 1\n        x = x.reshape(B, C, H * W)\n        x = x.permute((0, 2, 1))\n        return x\n\n\nclass EncoderWithRNN(nn.Module):\n    def __init__(self, in_channels, **kwargs):\n        super(EncoderWithRNN, self).__init__()\n        hidden_size = kwargs.get(\"hidden_size\", 256)\n        self.out_channels = hidden_size * 2\n        self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)\n\n    def forward(self, x):\n        self.lstm.flatten_parameters()\n        x, _ = self.lstm(x)\n        return x\n\n\nclass SequenceEncoder(nn.Module):\n    def __init__(self, in_channels, encoder_type=\"rnn\", **kwargs):\n        super(SequenceEncoder, self).__init__()\n        self.encoder_reshape = Im2Seq(in_channels)\n        self.out_channels = self.encoder_reshape.out_channels\n        self.encoder_type = encoder_type\n        if encoder_type == \"reshape\":\n            self.only_reshape = True\n        else:\n            support_encoder_dict = {\"reshape\": Im2Seq, \"rnn\": EncoderWithRNN, \"svtr\": EncoderWithSVTR}\n            assert encoder_type in support_encoder_dict, \"{} must in {}\".format(\n                encoder_type, support_encoder_dict.keys()\n            )\n\n            self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)\n    
        self.out_channels = self.encoder.out_channels\n            self.only_reshape = False\n\n    def forward(self, x):\n        if self.encoder_type != \"svtr\":\n            x = self.encoder_reshape(x)\n            if not self.only_reshape:\n                x = self.encoder(x)\n            return x\n        else:\n            x = self.encoder(x)\n            x = self.encoder_reshape(x)\n            return x\n\n\nclass ConvBNLayer(nn.Module):\n    def __init__(\n        self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU\n    ):\n        super().__init__()\n        self.conv = nn.Conv2d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            groups=groups,\n            # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),\n            bias=bias_attr,\n        )\n        self.norm = nn.BatchNorm2d(out_channels)\n        self.act = Swish()\n\n    def forward(self, inputs):\n        out = self.conv(inputs)\n        out = self.norm(out)\n        out = self.act(out)\n        return out\n\n\nclass EncoderWithSVTR(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        dims=64,  # XS\n        depth=2,\n        hidden_dims=120,\n        use_guide=False,\n        num_heads=8,\n        qkv_bias=True,\n        mlp_ratio=2.0,\n        drop_rate=0.1,\n        attn_drop_rate=0.1,\n        drop_path=0.0,\n        qk_scale=None,\n    ):\n        super(EncoderWithSVTR, self).__init__()\n        self.depth = depth\n        self.use_guide = use_guide\n        self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act=\"swish\")\n        self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act=\"swish\")\n\n        self.svtr_block = nn.ModuleList(\n            [\n                Block(\n                    dim=hidden_dims,\n 
                   num_heads=num_heads,\n                    mixer=\"Global\",\n                    HW=None,\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    act_layer=\"swish\",\n                    attn_drop=attn_drop_rate,\n                    drop_path=drop_path,\n                    norm_layer=\"nn.LayerNorm\",\n                    epsilon=1e-05,\n                    prenorm=False,\n                )\n                for i in range(depth)\n            ]\n        )\n        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)\n        self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act=\"swish\")\n        # last conv-nxn, the input is concat of input tensor and conv3 output tensor\n        self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act=\"swish\")\n\n        self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act=\"swish\")\n        self.out_channels = dims\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        # weight initialization\n        if isinstance(m, nn.Conv2d):\n            nn.init.kaiming_normal_(m.weight, mode=\"fan_out\")\n            if m.bias is not None:\n                nn.init.zeros_(m.bias)\n        elif isinstance(m, nn.BatchNorm2d):\n            nn.init.ones_(m.weight)\n            nn.init.zeros_(m.bias)\n        elif isinstance(m, nn.Linear):\n            nn.init.normal_(m.weight, 0, 0.01)\n            if m.bias is not None:\n                nn.init.zeros_(m.bias)\n        elif isinstance(m, nn.ConvTranspose2d):\n            nn.init.kaiming_normal_(m.weight, mode=\"fan_out\")\n            if m.bias is not None:\n                nn.init.zeros_(m.bias)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.ones_(m.weight)\n            nn.init.zeros_(m.bias)\n\n    def forward(self, x):\n        # for use guide\n        if 
self.use_guide:\n            z = x.clone().detach()  # block gradients (PyTorch equivalent of Paddle's stop_gradient)\n        else:\n            z = x\n        # for short cut\n        h = z\n        # reduce dim\n        z = self.conv1(z)\n        z = self.conv2(z)\n        # SVTR global block\n        B, C, H, W = z.shape\n        z = z.flatten(2).permute(0, 2, 1)\n\n        for blk in self.svtr_block:\n            z = blk(z)\n\n        z = self.norm(z)\n        # last stage\n        z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)\n        z = self.conv3(z)\n        z = torch.cat((h, z), dim=1)\n        z = self.conv1x1(self.conv4(z))\n\n        return z\n\n\nif __name__ == \"__main__\":\n    svtrRNN = EncoderWithSVTR(56)\n    print(svtrRNN)\n"
  },
  {
    "path": "examples/research_projects/anytext/ocr_recog/RecCTCHead.py",
    "content": "from torch import nn\n\n\nclass CTCHead(nn.Module):\n    def __init__(\n        self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs\n    ):\n        super(CTCHead, self).__init__()\n        if mid_channels is None:\n            self.fc = nn.Linear(\n                in_channels,\n                out_channels,\n                bias=True,\n            )\n        else:\n            self.fc1 = nn.Linear(\n                in_channels,\n                mid_channels,\n                bias=True,\n            )\n            self.fc2 = nn.Linear(\n                mid_channels,\n                out_channels,\n                bias=True,\n            )\n\n        self.out_channels = out_channels\n        self.mid_channels = mid_channels\n        self.return_feats = return_feats\n\n    def forward(self, x, labels=None):\n        if self.mid_channels is None:\n            predicts = self.fc(x)\n        else:\n            x = self.fc1(x)\n            predicts = self.fc2(x)\n\n        if self.return_feats:\n            result = {}\n            result[\"ctc\"] = predicts\n            result[\"ctc_neck\"] = x\n        else:\n            result = predicts\n\n        return result\n"
  },
  {
    "path": "examples/research_projects/anytext/ocr_recog/RecModel.py",
    "content": "import torch\nfrom torch import nn\n\nfrom .RecCTCHead import CTCHead\nfrom .RecMv1_enhance import MobileNetV1Enhance\nfrom .RNN import Im2Im, Im2Seq, SequenceEncoder\n\n\nbackbone_dict = {\"MobileNetV1Enhance\": MobileNetV1Enhance}\nneck_dict = {\"SequenceEncoder\": SequenceEncoder, \"Im2Seq\": Im2Seq, \"None\": Im2Im}\nhead_dict = {\"CTCHead\": CTCHead}\n\n\nclass RecModel(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        assert \"in_channels\" in config, \"in_channels must be in model config\"\n        backbone_type = config[\"backbone\"].pop(\"type\")\n        assert backbone_type in backbone_dict, f\"backbone.type must be in {backbone_dict}\"\n        self.backbone = backbone_dict[backbone_type](config[\"in_channels\"], **config[\"backbone\"])\n\n        neck_type = config[\"neck\"].pop(\"type\")\n        assert neck_type in neck_dict, f\"neck.type must be in {neck_dict}\"\n        self.neck = neck_dict[neck_type](self.backbone.out_channels, **config[\"neck\"])\n\n        head_type = config[\"head\"].pop(\"type\")\n        assert head_type in head_dict, f\"head.type must be in {head_dict}\"\n        self.head = head_dict[head_type](self.neck.out_channels, **config[\"head\"])\n\n        self.name = f\"RecModel_{backbone_type}_{neck_type}_{head_type}\"\n\n    def load_3rd_state_dict(self, _3rd_name, _state):\n        self.backbone.load_3rd_state_dict(_3rd_name, _state)\n        self.neck.load_3rd_state_dict(_3rd_name, _state)\n        self.head.load_3rd_state_dict(_3rd_name, _state)\n\n    def forward(self, x):\n        x = x.to(torch.float32)\n        x = self.backbone(x)\n        x = self.neck(x)\n        x = self.head(x)\n        return x\n\n    def encode(self, x):\n        x = self.backbone(x)\n        x = self.neck(x)\n        x = self.head.ctc_encoder(x)\n        return x\n"
  },
  {
    "path": "examples/research_projects/anytext/ocr_recog/RecMv1_enhance.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .common import Activation\n\n\nclass ConvBNLayer(nn.Module):\n    def __init__(\n        self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act=\"hard_swish\"\n    ):\n        super(ConvBNLayer, self).__init__()\n        self.act = act\n        self._conv = nn.Conv2d(\n            in_channels=num_channels,\n            out_channels=num_filters,\n            kernel_size=filter_size,\n            stride=stride,\n            padding=padding,\n            groups=num_groups,\n            bias=False,\n        )\n\n        self._batch_norm = nn.BatchNorm2d(\n            num_filters,\n        )\n        if self.act is not None:\n            self._act = Activation(act_type=act, inplace=True)\n\n    def forward(self, inputs):\n        y = self._conv(inputs)\n        y = self._batch_norm(y)\n        if self.act is not None:\n            y = self._act(y)\n        return y\n\n\nclass DepthwiseSeparable(nn.Module):\n    def __init__(\n        self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False\n    ):\n        super(DepthwiseSeparable, self).__init__()\n        self.use_se = use_se\n        self._depthwise_conv = ConvBNLayer(\n            num_channels=num_channels,\n            num_filters=int(num_filters1 * scale),\n            filter_size=dw_size,\n            stride=stride,\n            padding=padding,\n            num_groups=int(num_groups * scale),\n        )\n        if use_se:\n            self._se = SEModule(int(num_filters1 * scale))\n        self._pointwise_conv = ConvBNLayer(\n            num_channels=int(num_filters1 * scale),\n            filter_size=1,\n            num_filters=int(num_filters2 * scale),\n            stride=1,\n            padding=0,\n        )\n\n    def forward(self, inputs):\n        y = self._depthwise_conv(inputs)\n        if self.use_se:\n            y = 
self._se(y)\n        y = self._pointwise_conv(y)\n        return y\n\n\nclass MobileNetV1Enhance(nn.Module):\n    def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type=\"max\", **kwargs):\n        super().__init__()\n        self.scale = scale\n        self.block_list = []\n\n        self.conv1 = ConvBNLayer(\n            num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1\n        )\n\n        conv2_1 = DepthwiseSeparable(\n            num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale\n        )\n        self.block_list.append(conv2_1)\n\n        conv2_2 = DepthwiseSeparable(\n            num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale\n        )\n        self.block_list.append(conv2_2)\n\n        conv3_1 = DepthwiseSeparable(\n            num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale\n        )\n        self.block_list.append(conv3_1)\n\n        conv3_2 = DepthwiseSeparable(\n            num_channels=int(128 * scale),\n            num_filters1=128,\n            num_filters2=256,\n            num_groups=128,\n            stride=(2, 1),\n            scale=scale,\n        )\n        self.block_list.append(conv3_2)\n\n        conv4_1 = DepthwiseSeparable(\n            num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale\n        )\n        self.block_list.append(conv4_1)\n\n        conv4_2 = DepthwiseSeparable(\n            num_channels=int(256 * scale),\n            num_filters1=256,\n            num_filters2=512,\n            num_groups=256,\n            stride=(2, 1),\n            scale=scale,\n        )\n        self.block_list.append(conv4_2)\n\n        for _ in range(5):\n            conv5 = DepthwiseSeparable(\n                num_channels=int(512 * scale),\n                
num_filters1=512,\n                num_filters2=512,\n                num_groups=512,\n                stride=1,\n                dw_size=5,\n                padding=2,\n                scale=scale,\n                use_se=False,\n            )\n            self.block_list.append(conv5)\n\n        conv5_6 = DepthwiseSeparable(\n            num_channels=int(512 * scale),\n            num_filters1=512,\n            num_filters2=1024,\n            num_groups=512,\n            stride=(2, 1),\n            dw_size=5,\n            padding=2,\n            scale=scale,\n            use_se=True,\n        )\n        self.block_list.append(conv5_6)\n\n        conv6 = DepthwiseSeparable(\n            num_channels=int(1024 * scale),\n            num_filters1=1024,\n            num_filters2=1024,\n            num_groups=1024,\n            stride=last_conv_stride,\n            dw_size=5,\n            padding=2,\n            use_se=True,\n            scale=scale,\n        )\n        self.block_list.append(conv6)\n\n        self.block_list = nn.Sequential(*self.block_list)\n        if last_pool_type == \"avg\":\n            self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)\n        else:\n            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)\n        self.out_channels = int(1024 * scale)\n\n    def forward(self, inputs):\n        y = self.conv1(inputs)\n        y = self.block_list(y)\n        y = self.pool(y)\n        return y\n\n\ndef hardsigmoid(x):\n    return F.relu6(x + 3.0, inplace=True) / 6.0\n\n\nclass SEModule(nn.Module):\n    def __init__(self, channel, reduction=4):\n        super(SEModule, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.conv1 = nn.Conv2d(\n            in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True\n        )\n        self.conv2 = nn.Conv2d(\n            in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, 
padding=0, bias=True\n        )\n\n    def forward(self, inputs):\n        outputs = self.avg_pool(inputs)\n        outputs = self.conv1(outputs)\n        outputs = F.relu(outputs)\n        outputs = self.conv2(outputs)\n        outputs = hardsigmoid(outputs)\n        x = torch.mul(inputs, outputs)\n\n        return x\n"
  },
  {
    "path": "examples/research_projects/anytext/ocr_recog/RecSVTR.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional\nfrom torch.nn.init import ones_, trunc_normal_, zeros_\n\n\ndef drop_path(x, drop_prob=0.0, training=False):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...\n    \"\"\"\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = torch.tensor(1 - drop_prob)\n    shape = (x.size()[0],) + (1,) * (x.ndim - 1)\n    # sample the mask on the same device as x so the multiplication below also works on GPU\n    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor = torch.floor(random_tensor)  # binarize\n    output = x.divide(keep_prob) * random_tensor\n    return output\n\n\nclass Swish(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\nclass ConvBNLayer(nn.Module):\n    def __init__(\n        self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU\n    ):\n        super().__init__()\n        self.conv = nn.Conv2d(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n            padding=padding,\n            groups=groups,\n            # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),\n            bias=bias_attr,\n        )\n        self.norm = nn.BatchNorm2d(out_channels)\n        self.act = act()\n\n    def forward(self, inputs):\n        out = self.conv(inputs)\n        out = self.norm(out)\n        out = self.act(out)\n        return out\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, 
drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n\n\nclass Identity(nn.Module):\n    def __init__(self):\n        super(Identity, self).__init__()\n\n    def forward(self, input):\n        return input\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        if isinstance(act_layer, str):\n            self.act = Swish()\n        else:\n            self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass ConvMixer(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads=8,\n        HW=(8, 25),\n        local_k=(3, 3),\n    ):\n        super().__init__()\n        self.HW = HW\n        self.dim = dim\n        self.local_mixer = nn.Conv2d(\n            dim,\n            dim,\n            local_k,\n            1,\n            (local_k[0] // 2, local_k[1] // 2),\n            groups=num_heads,\n            # weight_attr=ParamAttr(initializer=KaimingNormal())\n        )\n\n    def forward(self, x):\n        h = self.HW[0]\n        w = self.HW[1]\n        # (B, N, C) -> (B, C, h, w); PyTorch uses permute and -1 where Paddle used transpose([...]) and 0\n        x = x.permute(0, 2, 1).reshape(-1, self.dim, h, w)\n        x = self.local_mixer(x)\n        x = x.flatten(2).permute(0, 2, 1)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        num_heads=8,\n        mixer=\"Global\",\n        HW=(8, 25),\n        local_k=(7, 11),\n        
qkv_bias=False,\n        qk_scale=None,\n        attn_drop=0.0,\n        proj_drop=0.0,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n        self.HW = HW\n        if HW is not None:\n            H = HW[0]\n            W = HW[1]\n            self.N = H * W\n            self.C = dim\n        if mixer == \"Local\" and HW is not None:\n            hk = local_k[0]\n            wk = local_k[1]\n            mask = torch.ones([H * W, H + hk - 1, W + wk - 1])\n            for h in range(0, H):\n                for w in range(0, W):\n                    mask[h * W + w, h : h + hk, w : w + wk] = 0.0\n            mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)\n            mask_inf = torch.full([H * W, H * W], fill_value=float(\"-inf\"))\n            mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)\n            self.mask = mask[None, None, :]\n            # self.mask = mask.unsqueeze([0, 1])\n        self.mixer = mixer\n\n    def forward(self, x):\n        if self.HW is not None:\n            N = self.N\n            C = self.C\n        else:\n            _, N, C = x.shape\n        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))\n        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]\n\n        attn = q.matmul(k.permute((0, 1, 3, 2)))\n        if self.mixer == \"Local\":\n            attn += self.mask\n        attn = functional.softmax(attn, dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(\n  
      self,\n        dim,\n        num_heads,\n        mixer=\"Global\",\n        local_mixer=(7, 11),\n        HW=(8, 25),\n        mlp_ratio=4.0,\n        qkv_bias=False,\n        qk_scale=None,\n        drop=0.0,\n        attn_drop=0.0,\n        drop_path=0.0,\n        act_layer=nn.GELU,\n        norm_layer=\"nn.LayerNorm\",\n        epsilon=1e-6,\n        prenorm=True,\n    ):\n        super().__init__()\n        if isinstance(norm_layer, str):\n            self.norm1 = eval(norm_layer)(dim, eps=epsilon)\n        else:\n            self.norm1 = norm_layer(dim)\n        if mixer == \"Global\" or mixer == \"Local\":\n            self.mixer = Attention(\n                dim,\n                num_heads=num_heads,\n                mixer=mixer,\n                HW=HW,\n                local_k=local_mixer,\n                qkv_bias=qkv_bias,\n                qk_scale=qk_scale,\n                attn_drop=attn_drop,\n                proj_drop=drop,\n            )\n        elif mixer == \"Conv\":\n            self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)\n        else:\n            raise TypeError(\"The mixer must be one of [Global, Local, Conv]\")\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()\n        if isinstance(norm_layer, str):\n            self.norm2 = eval(norm_layer)(dim, eps=epsilon)\n        else:\n            self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp_ratio = mlp_ratio\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        self.prenorm = prenorm\n\n    def forward(self, x):\n        if self.prenorm:\n            x = self.norm1(x + self.drop_path(self.mixer(x)))\n            x = self.norm2(x + self.drop_path(self.mlp(x)))\n        else:\n            x = x + self.drop_path(self.mixer(self.norm1(x)))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return 
x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"Image to Patch Embedding\"\"\"\n\n    def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):\n        super().__init__()\n        num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))\n        self.img_size = img_size\n        self.num_patches = num_patches\n        self.embed_dim = embed_dim\n        self.norm = None\n        if sub_num == 2:\n            self.proj = nn.Sequential(\n                ConvBNLayer(\n                    in_channels=in_channels,\n                    out_channels=embed_dim // 2,\n                    kernel_size=3,\n                    stride=2,\n                    padding=1,\n                    act=nn.GELU,\n                    bias_attr=False,\n                ),\n                ConvBNLayer(\n                    in_channels=embed_dim // 2,\n                    out_channels=embed_dim,\n                    kernel_size=3,\n                    stride=2,\n                    padding=1,\n                    act=nn.GELU,\n                    bias_attr=False,\n                ),\n            )\n        if sub_num == 3:\n            self.proj = nn.Sequential(\n                ConvBNLayer(\n                    in_channels=in_channels,\n                    out_channels=embed_dim // 4,\n                    kernel_size=3,\n                    stride=2,\n                    padding=1,\n                    act=nn.GELU,\n                    bias_attr=False,\n                ),\n                ConvBNLayer(\n                    in_channels=embed_dim // 4,\n                    out_channels=embed_dim // 2,\n                    kernel_size=3,\n                    stride=2,\n                    padding=1,\n                    act=nn.GELU,\n                    bias_attr=False,\n                ),\n                ConvBNLayer(\n                    in_channels=embed_dim // 2,\n                    out_channels=embed_dim,\n                    kernel_size=3,\n        
            stride=2,\n                    padding=1,\n                    act=nn.GELU,\n                    bias_attr=False,\n                ),\n            )\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        assert H == self.img_size[0] and W == self.img_size[1], (\n            f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        )\n        x = self.proj(x).flatten(2).permute(0, 2, 1)\n        return x\n\n\nclass SubSample(nn.Module):\n    def __init__(self, in_channels, out_channels, types=\"Pool\", stride=(2, 1), sub_norm=\"nn.LayerNorm\", act=None):\n        super().__init__()\n        self.types = types\n        if types == \"Pool\":\n            self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))\n            self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))\n            self.proj = nn.Linear(in_channels, out_channels)\n        else:\n            self.conv = nn.Conv2d(\n                in_channels,\n                out_channels,\n                kernel_size=3,\n                stride=stride,\n                padding=1,\n                # weight_attr=ParamAttr(initializer=KaimingNormal())\n            )\n        self.norm = eval(sub_norm)(out_channels)\n        if act is not None:\n            self.act = act()\n        else:\n            self.act = None\n\n    def forward(self, x):\n        if self.types == \"Pool\":\n            x1 = self.avgpool(x)\n            x2 = self.maxpool(x)\n            x = (x1 + x2) * 0.5\n            out = self.proj(x.flatten(2).permute((0, 2, 1)))\n        else:\n            x = self.conv(x)\n            out = x.flatten(2).permute((0, 2, 1))\n        out = self.norm(out)\n        if self.act is not None:\n            out = self.act(out)\n\n        return out\n\n\nclass SVTRNet(nn.Module):\n    def __init__(\n        self,\n        img_size=[48, 100],\n        in_channels=3,\n        embed_dim=[64, 128, 
256],\n        depth=[3, 6, 3],\n        num_heads=[2, 4, 8],\n        mixer=[\"Local\"] * 6 + [\"Global\"] * 6,  # Local atten, Global atten, Conv\n        local_mixer=[[7, 11], [7, 11], [7, 11]],\n        patch_merging=\"Conv\",  # Conv, Pool, None\n        mlp_ratio=4,\n        qkv_bias=True,\n        qk_scale=None,\n        drop_rate=0.0,\n        last_drop=0.1,\n        attn_drop_rate=0.0,\n        drop_path_rate=0.1,\n        norm_layer=\"nn.LayerNorm\",\n        sub_norm=\"nn.LayerNorm\",\n        epsilon=1e-6,\n        out_channels=192,\n        out_char_num=25,\n        block_unit=\"Block\",\n        act=\"nn.GELU\",\n        last_stage=True,\n        sub_num=2,\n        prenorm=True,\n        use_lenhead=False,\n        **kwargs,\n    ):\n        super().__init__()\n        self.img_size = img_size\n        self.embed_dim = embed_dim\n        self.out_channels = out_channels\n        self.prenorm = prenorm\n        patch_merging = None if patch_merging != \"Conv\" and patch_merging != \"Pool\" else patch_merging\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num\n        )\n        num_patches = self.patch_embed.num_patches\n        self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))\n        # self.pos_embed = self.create_parameter(\n        #     shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)\n\n        # self.add_parameter(\"pos_embed\", self.pos_embed)\n\n        self.pos_drop = nn.Dropout(p=drop_rate)\n        Block_unit = eval(block_unit)\n\n        dpr = np.linspace(0, drop_path_rate, sum(depth))\n        self.blocks1 = nn.ModuleList(\n            [\n                Block_unit(\n                    dim=embed_dim[0],\n                    num_heads=num_heads[0],\n                    mixer=mixer[0 : depth[0]][i],\n                    HW=self.HW,\n     
               local_mixer=local_mixer[0],\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    act_layer=eval(act),\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[0 : depth[0]][i],\n                    norm_layer=norm_layer,\n                    epsilon=epsilon,\n                    prenorm=prenorm,\n                )\n                for i in range(depth[0])\n            ]\n        )\n        if patch_merging is not None:\n            self.sub_sample1 = SubSample(\n                embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging\n            )\n            HW = [self.HW[0] // 2, self.HW[1]]\n        else:\n            HW = self.HW\n        self.patch_merging = patch_merging\n        self.blocks2 = nn.ModuleList(\n            [\n                Block_unit(\n                    dim=embed_dim[1],\n                    num_heads=num_heads[1],\n                    mixer=mixer[depth[0] : depth[0] + depth[1]][i],\n                    HW=HW,\n                    local_mixer=local_mixer[1],\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    act_layer=eval(act),\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[depth[0] : depth[0] + depth[1]][i],\n                    norm_layer=norm_layer,\n                    epsilon=epsilon,\n                    prenorm=prenorm,\n                )\n                for i in range(depth[1])\n            ]\n        )\n        if patch_merging is not None:\n            self.sub_sample2 = SubSample(\n                embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging\n            )\n            HW = [self.HW[0] // 4, self.HW[1]]\n        else:\n            HW = 
self.HW\n        self.blocks3 = nn.ModuleList(\n            [\n                Block_unit(\n                    dim=embed_dim[2],\n                    num_heads=num_heads[2],\n                    mixer=mixer[depth[0] + depth[1] :][i],\n                    HW=HW,\n                    local_mixer=local_mixer[2],\n                    mlp_ratio=mlp_ratio,\n                    qkv_bias=qkv_bias,\n                    qk_scale=qk_scale,\n                    drop=drop_rate,\n                    act_layer=eval(act),\n                    attn_drop=attn_drop_rate,\n                    drop_path=dpr[depth[0] + depth[1] :][i],\n                    norm_layer=norm_layer,\n                    epsilon=epsilon,\n                    prenorm=prenorm,\n                )\n                for i in range(depth[2])\n            ]\n        )\n        self.last_stage = last_stage\n        if last_stage:\n            self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))\n            self.last_conv = nn.Conv2d(\n                in_channels=embed_dim[2],\n                out_channels=self.out_channels,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                bias=False,\n            )\n            self.hardswish = nn.Hardswish()\n            self.dropout = nn.Dropout(p=last_drop)\n        if not prenorm:\n            # torch.nn.LayerNorm takes `eps`, not Paddle's `epsilon`\n            self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)\n        self.use_lenhead = use_lenhead\n        if use_lenhead:\n            self.len_conv = nn.Linear(embed_dim[2], self.out_channels)\n            self.hardswish_len = nn.Hardswish()\n            self.dropout_len = nn.Dropout(p=last_drop)\n\n        trunc_normal_(self.pos_embed, std=0.02)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                zeros_(m.bias)\n        elif isinstance(m, 
nn.LayerNorm):\n            zeros_(m.bias)\n            ones_(m.weight)\n\n    def forward_features(self, x):\n        x = self.patch_embed(x)\n        x = x + self.pos_embed\n        x = self.pos_drop(x)\n        for blk in self.blocks1:\n            x = blk(x)\n        if self.patch_merging is not None:\n            x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]]))\n        for blk in self.blocks2:\n            x = blk(x)\n        if self.patch_merging is not None:\n            x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))\n        for blk in self.blocks3:\n            x = blk(x)\n        if not self.prenorm:\n            x = self.norm(x)\n        return x\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        if self.use_lenhead:\n            len_x = self.len_conv(x.mean(1))\n            len_x = self.dropout_len(self.hardswish_len(len_x))\n        if self.last_stage:\n            if self.patch_merging is not None:\n                h = self.HW[0] // 4\n            else:\n                h = self.HW[0]\n            x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]]))\n            x = self.last_conv(x)\n            x = self.hardswish(x)\n            x = self.dropout(x)\n        if self.use_lenhead:\n            return x, len_x\n        return x\n\n\nif __name__ == \"__main__\":\n    a = torch.rand(1, 3, 48, 100)\n    svtr = SVTRNet()\n\n    out = svtr(a)\n    print(svtr)\n    print(out.size())\n"
  },
  {
    "path": "examples/research_projects/anytext/ocr_recog/common.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Hswish(nn.Module):\n    def __init__(self, inplace=True):\n        super(Hswish, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0\n\n\n# out = max(0, min(1, slope*x+offset))\n# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)\nclass Hsigmoid(nn.Module):\n    def __init__(self, inplace=True):\n        super(Hsigmoid, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        # torch: F.relu6(x + 3., inplace=self.inplace) / 6.\n        # paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.\n        return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0\n\n\nclass GELU(nn.Module):\n    def __init__(self, inplace=True):\n        super(GELU, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return torch.nn.functional.gelu(x)\n\n\nclass Swish(nn.Module):\n    def __init__(self, inplace=True):\n        super(Swish, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        if self.inplace:\n            x.mul_(torch.sigmoid(x))\n            return x\n        else:\n            return x * torch.sigmoid(x)\n\n\nclass Activation(nn.Module):\n    def __init__(self, act_type, inplace=True):\n        super(Activation, self).__init__()\n        act_type = act_type.lower()\n        if act_type == \"relu\":\n            self.act = nn.ReLU(inplace=inplace)\n        elif act_type == \"relu6\":\n            self.act = nn.ReLU6(inplace=inplace)\n        elif act_type == \"sigmoid\":\n            raise NotImplementedError\n        elif act_type == \"hard_sigmoid\":\n            self.act = Hsigmoid(inplace)\n        elif act_type == \"hard_swish\":\n            self.act = Hswish(inplace=inplace)\n        elif act_type == \"leakyrelu\":\n            self.act = 
nn.LeakyReLU(inplace=inplace)\n        elif act_type == \"gelu\":\n            self.act = GELU(inplace=inplace)\n        elif act_type == \"swish\":\n            self.act = Swish(inplace=inplace)\n        else:\n            raise NotImplementedError\n\n    def forward(self, inputs):\n        return self.act(inputs)\n"
  },
  {
    "path": "examples/research_projects/anytext/ocr_recog/en_dict.txt",
    "content": "0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n:\n;\n<\n=\n>\n?\n@\nA\nB\nC\nD\nE\nF\nG\nH\nI\nJ\nK\nL\nM\nN\nO\nP\nQ\nR\nS\nT\nU\nV\nW\nX\nY\nZ\n[\n\\\n]\n^\n_\n`\na\nb\nc\nd\ne\nf\ng\nh\ni\nj\nk\nl\nm\nn\no\np\nq\nr\ns\nt\nu\nv\nw\nx\ny\nz\n{\n|\n}\n~\n!\n\"\n#\n$\n%\n&\n'\n(\n)\n*\n+\n,\n-\n.\n/\n \n"
  },
  {
    "path": "examples/research_projects/autoencoder_rae/README.md",
    "content": "# Training AutoencoderRAE\n\nThis example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen.\n\nIt follows the same high-level training recipe as the official RAE stage-1 setup:\n- frozen encoder\n- train decoder\n- pixel reconstruction loss\n- optional encoder feature consistency loss\n\n## Quickstart\n\n### Resume or finetune from pretrained weights\n\n```bash\naccelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \\\n  --pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \\\n  --train_data_dir /path/to/imagenet_like_folder \\\n  --output_dir /tmp/autoencoder-rae \\\n  --resolution 256 \\\n  --train_batch_size 8 \\\n  --learning_rate 1e-4 \\\n  --num_train_epochs 10 \\\n  --report_to wandb \\\n  --reconstruction_loss_type l1 \\\n  --use_encoder_loss \\\n  --encoder_loss_weight 0.1\n```\n\n### Train from scratch with a pretrained encoder\nThe following command launches RAE training with \"facebook/dinov2-with-registers-base\" as the base.\n\n```bash\naccelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \\\n  --train_data_dir /path/to/imagenet_like_folder \\\n  --output_dir /tmp/autoencoder-rae \\\n  --resolution 256 \\\n  --encoder_type dinov2 \\\n  --encoder_name_or_path facebook/dinov2-with-registers-base \\\n  --encoder_input_size 224 \\\n  --patch_size 16 \\\n  --image_size 256 \\\n  --decoder_hidden_size 1152 \\\n  --decoder_num_hidden_layers 28 \\\n  --decoder_num_attention_heads 16 \\\n  --decoder_intermediate_size 4096 \\\n  --train_batch_size 8 \\\n  --learning_rate 1e-4 \\\n  --num_train_epochs 10 \\\n  --report_to wandb \\\n  --reconstruction_loss_type l1 \\\n  --use_encoder_loss \\\n  --encoder_loss_weight 0.1\n```\n\nNote: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`.\n\nDataset format is expected to be 
`ImageFolder`-compatible:\n\n```text\ntrain_data_dir/\n  class_a/\n    img_0001.jpg\n  class_b/\n    img_0002.jpg\n```\n"
  },
  {
    "path": "examples/research_projects/autoencoder_rae/train_autoencoder_rae.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\nfrom torchvision.datasets import ImageFolder\nfrom tqdm.auto import tqdm\n\nfrom diffusers import AutoencoderRAE\nfrom diffusers.optimization import get_scheduler\n\n\nlogger = get_logger(__name__)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Train a stage-1 Representation Autoencoder (RAE) decoder.\")\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        required=True,\n        help=\"Path to an ImageFolder-style dataset root.\",\n    )\n    parser.add_argument(\n        \"--output_dir\", type=str, default=\"autoencoder-rae\", help=\"Directory to save checkpoints/model.\"\n    )\n    parser.add_argument(\"--logging_dir\", type=str, default=\"logs\", help=\"Accelerate logging directory.\")\n    parser.add_argument(\"--seed\", type=int, default=42)\n\n    parser.add_argument(\"--resolution\", type=int, default=256)\n    parser.add_argument(\"--center_crop\", action=\"store_true\")\n    
parser.add_argument(\"--random_flip\", action=\"store_true\")\n\n    parser.add_argument(\"--train_batch_size\", type=int, default=8)\n    parser.add_argument(\"--dataloader_num_workers\", type=int, default=4)\n    parser.add_argument(\"--num_train_epochs\", type=int, default=10)\n    parser.add_argument(\"--max_train_steps\", type=int, default=None)\n    parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=1)\n    parser.add_argument(\"--max_grad_norm\", type=float, default=1.0)\n\n    parser.add_argument(\"--learning_rate\", type=float, default=1e-4)\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9)\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999)\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2)\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-8)\n    parser.add_argument(\"--lr_scheduler\", type=str, default=\"cosine\")\n    parser.add_argument(\"--lr_warmup_steps\", type=int, default=500)\n\n    parser.add_argument(\"--checkpointing_steps\", type=int, default=1000)\n    parser.add_argument(\"--validation_steps\", type=int, default=500)\n\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to a pretrained AutoencoderRAE model (or HF Hub id) to resume training from.\",\n    )\n    parser.add_argument(\n        \"--encoder_name_or_path\",\n        type=str,\n        default=None,\n        help=(\n            \"HF Hub id or local path of the pretrained encoder (e.g. 'facebook/dinov2-with-registers-base'). \"\n            \"When --pretrained_model_name_or_path is not set, the encoder weights are loaded from this path \"\n            \"into a freshly constructed AutoencoderRAE. 
Ignored when --pretrained_model_name_or_path is set.\"\n        ),\n    )\n\n    parser.add_argument(\"--encoder_type\", type=str, choices=[\"dinov2\", \"siglip2\", \"mae\"], default=\"dinov2\")\n    parser.add_argument(\"--encoder_hidden_size\", type=int, default=768)\n    parser.add_argument(\"--encoder_patch_size\", type=int, default=14)\n    parser.add_argument(\"--encoder_num_hidden_layers\", type=int, default=12)\n    parser.add_argument(\"--encoder_input_size\", type=int, default=224)\n    parser.add_argument(\"--patch_size\", type=int, default=16)\n    parser.add_argument(\"--image_size\", type=int, default=256)\n    parser.add_argument(\"--num_channels\", type=int, default=3)\n\n    parser.add_argument(\"--decoder_hidden_size\", type=int, default=1152)\n    parser.add_argument(\"--decoder_num_hidden_layers\", type=int, default=28)\n    parser.add_argument(\"--decoder_num_attention_heads\", type=int, default=16)\n    parser.add_argument(\"--decoder_intermediate_size\", type=int, default=4096)\n\n    parser.add_argument(\"--noise_tau\", type=float, default=0.0)\n    parser.add_argument(\"--scaling_factor\", type=float, default=1.0)\n    parser.add_argument(\"--reshape_to_2d\", action=argparse.BooleanOptionalAction, default=True)\n\n    parser.add_argument(\n        \"--reconstruction_loss_type\",\n        type=str,\n        choices=[\"l1\", \"mse\"],\n        default=\"l1\",\n        help=\"Pixel reconstruction loss.\",\n    )\n    parser.add_argument(\n        \"--encoder_loss_weight\",\n        type=float,\n        default=0.0,\n        help=\"Weight for encoder feature consistency loss in the training loop.\",\n    )\n    parser.add_argument(\n        \"--use_encoder_loss\",\n        action=\"store_true\",\n        help=\"Enable encoder feature consistency loss term in the training loop.\",\n    )\n    parser.add_argument(\"--report_to\", type=str, default=\"tensorboard\")\n\n    return parser.parse_args()\n\n\ndef build_transforms(args):\n    
image_transforms = [\n        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),\n        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n    ]\n    if args.random_flip:\n        image_transforms.append(transforms.RandomHorizontalFlip())\n    image_transforms.append(transforms.ToTensor())\n    return transforms.Compose(image_transforms)\n\n\ndef compute_losses(\n    model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float\n):\n    decoded = model(pixel_values).sample\n\n    if decoded.shape[-2:] != pixel_values.shape[-2:]:\n        raise ValueError(\n            \"Training requires matching reconstruction and target sizes, got \"\n            f\"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}.\"\n        )\n\n    if reconstruction_loss_type == \"l1\":\n        reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float())\n    else:\n        reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float())\n\n    encoder_loss = torch.zeros_like(reconstruction_loss)\n    if use_encoder_loss and encoder_loss_weight > 0:\n        base_model = model.module if hasattr(model, \"module\") else model\n        target_encoder_input = base_model._resize_and_normalize(pixel_values)\n        reconstructed_encoder_input = base_model._resize_and_normalize(decoded)\n\n        encoder_forward_kwargs = {\"model\": base_model.encoder}\n        if base_model.config.encoder_type == \"mae\":\n            encoder_forward_kwargs[\"patch_size\"] = base_model.config.encoder_patch_size\n        with torch.no_grad():\n            target_tokens = base_model._encoder_forward_fn(images=target_encoder_input, **encoder_forward_kwargs)\n        reconstructed_tokens = base_model._encoder_forward_fn(\n            images=reconstructed_encoder_input, **encoder_forward_kwargs\n        )\n        encoder_loss = 
F.mse_loss(reconstructed_tokens.float(), target_tokens.float())\n\n    loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss\n    return decoded, loss, reconstruction_loss, encoder_loss\n\n\ndef _strip_final_layernorm_affine(state_dict, prefix=\"\"):\n    \"\"\"Remove final layernorm weight/bias so the model keeps its default init (identity).\"\"\"\n    keys_to_strip = {f\"{prefix}weight\", f\"{prefix}bias\"}\n    return {k: v for k, v in state_dict.items() if k not in keys_to_strip}\n\n\ndef _load_pretrained_encoder_weights(model, encoder_type, encoder_name_or_path):\n    \"\"\"Load pretrained HF transformers encoder weights into the model's encoder.\"\"\"\n    if encoder_type == \"dinov2\":\n        from transformers import Dinov2WithRegistersModel\n\n        hf_encoder = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)\n        state_dict = hf_encoder.state_dict()\n        state_dict = _strip_final_layernorm_affine(state_dict, prefix=\"layernorm.\")\n    elif encoder_type == \"siglip2\":\n        from transformers import SiglipModel\n\n        hf_encoder = SiglipModel.from_pretrained(encoder_name_or_path).vision_model\n        state_dict = {f\"vision_model.{k}\": v for k, v in hf_encoder.state_dict().items()}\n        state_dict = _strip_final_layernorm_affine(state_dict, prefix=\"vision_model.post_layernorm.\")\n    elif encoder_type == \"mae\":\n        from transformers import ViTMAEForPreTraining\n\n        hf_encoder = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit\n        state_dict = hf_encoder.state_dict()\n        state_dict = _strip_final_layernorm_affine(state_dict, prefix=\"layernorm.\")\n    else:\n        raise ValueError(f\"Unknown encoder_type: {encoder_type}\")\n\n    model.encoder.load_state_dict(state_dict, strict=False)\n\n\ndef main():\n    args = parse_args()\n    if args.resolution != args.image_size:\n        raise ValueError(\n            f\"`--resolution` ({args.resolution}) must match 
`--image_size` ({args.image_size}) \"\n            \"for stage-1 reconstruction loss.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        project_config=accelerator_project_config,\n        log_with=args.report_to,\n    )\n\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    if accelerator.is_main_process:\n        os.makedirs(args.output_dir, exist_ok=True)\n    accelerator.wait_for_everyone()\n\n    dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args))\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[0] for example in examples]).float()\n        return {\"pixel_values\": pixel_values}\n\n    train_dataloader = DataLoader(\n        dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n        pin_memory=True,\n        drop_last=True,\n    )\n\n    if args.pretrained_model_name_or_path is not None:\n        model = AutoencoderRAE.from_pretrained(args.pretrained_model_name_or_path)\n        logger.info(f\"Loaded pretrained AutoencoderRAE from {args.pretrained_model_name_or_path}\")\n    else:\n        model = AutoencoderRAE(\n            encoder_type=args.encoder_type,\n            encoder_hidden_size=args.encoder_hidden_size,\n            encoder_patch_size=args.encoder_patch_size,\n            encoder_num_hidden_layers=args.encoder_num_hidden_layers,\n            decoder_hidden_size=args.decoder_hidden_size,\n            
decoder_num_hidden_layers=args.decoder_num_hidden_layers,\n            decoder_num_attention_heads=args.decoder_num_attention_heads,\n            decoder_intermediate_size=args.decoder_intermediate_size,\n            patch_size=args.patch_size,\n            encoder_input_size=args.encoder_input_size,\n            image_size=args.image_size,\n            num_channels=args.num_channels,\n            noise_tau=args.noise_tau,\n            reshape_to_2d=args.reshape_to_2d,\n            use_encoder_loss=args.use_encoder_loss,\n            scaling_factor=args.scaling_factor,\n        )\n        if args.encoder_name_or_path is not None:\n            _load_pretrained_encoder_weights(model, args.encoder_type, args.encoder_name_or_path)\n            logger.info(f\"Loaded pretrained encoder weights from {args.encoder_name_or_path}\")\n    model.encoder.requires_grad_(False)\n    model.decoder.requires_grad_(True)\n    model.train()\n\n    optimizer = torch.optim.AdamW(\n        (p for p in model.parameters() if p.requires_grad),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if overrode_max_train_steps:\n        num_update_steps_per_epoch = 
math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"train_autoencoder_rae\", config=vars(args))\n\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n    global_step = 0\n\n    for epoch in range(args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(model):\n                pixel_values = batch[\"pixel_values\"]\n\n                _, loss, reconstruction_loss, encoder_loss = compute_losses(\n                    model,\n                    pixel_values,\n                    reconstruction_loss_type=args.reconstruction_loss_type,\n                    use_encoder_loss=args.use_encoder_loss,\n                    encoder_loss_weight=args.encoder_loss_weight,\n                )\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            if 
accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                logs = {\n                    \"loss\": loss.detach().item(),\n                    \"reconstruction_loss\": reconstruction_loss.detach().item(),\n                    \"encoder_loss\": encoder_loss.detach().item(),\n                    \"lr\": lr_scheduler.get_last_lr()[0],\n                }\n                progress_bar.set_postfix(**logs)\n                accelerator.log(logs, step=global_step)\n\n                if global_step % args.validation_steps == 0:\n                    with torch.no_grad():\n                        _, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses(\n                            model,\n                            pixel_values,\n                            reconstruction_loss_type=args.reconstruction_loss_type,\n                            use_encoder_loss=args.use_encoder_loss,\n                            encoder_loss_weight=args.encoder_loss_weight,\n                        )\n                    accelerator.log(\n                        {\n                            \"val/loss\": val_loss.detach().item(),\n                            \"val/reconstruction_loss\": val_reconstruction_loss.detach().item(),\n                            \"val/encoder_loss\": val_encoder_loss.detach().item(),\n                        },\n                        step=global_step,\n                    )\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        unwrapped_model = accelerator.unwrap_model(model)\n                        unwrapped_model.save_pretrained(save_path)\n                        logger.info(f\"Saved checkpoint to {save_path}\")\n\n            if global_step >= args.max_train_steps:\n                break\n\n        
if global_step >= args.max_train_steps:\n            break\n\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unwrapped_model = accelerator.unwrap_model(model)\n        unwrapped_model.save_pretrained(args.output_dir)\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/autoencoderkl/README.md",
    "content": "# AutoencoderKL training example\n\n## Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd in the example folder  and run\n```bash\npip install -r requirements.txt\n```\n\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n## Training on CIFAR10\n\nPlease replace the validation image with your own image.\n\n```bash\naccelerate launch train_autoencoderkl.py \\\n    --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \\\n    --dataset_name=cifar10 \\\n    --image_column=img \\\n    --validation_image images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \\\n    --num_train_epochs 100 \\\n    --gradient_accumulation_steps 2 \\\n    --learning_rate 4.5e-6 \\\n    --lr_scheduler cosine \\\n    --report_to wandb \\\n```\n\n## Training on ImageNet\n\n```bash\naccelerate launch train_autoencoderkl.py \\\n    --pretrained_model_name_or_path stabilityai/sd-vae-ft-mse \\\n    --num_train_epochs 100 \\\n    --gradient_accumulation_steps 2 \\\n    --learning_rate 4.5e-6 \\\n    --lr_scheduler cosine \\\n    --report_to wandb \\\n    --mixed_precision bf16 \\\n    --train_data_dir /path/to/ImageNet/train \\\n    --validation_image ./image.png \\\n    --decoder_only\n```\n"
  },
  {
    "path": "examples/research_projects/autoencoderkl/requirements.txt",
    "content": "accelerate>=0.16.0\nbitsandbytes\ndatasets\nhuggingface_hub\nlpips\nnumpy\npackaging\nPillow\ntaming_transformers\ntorch\ntorchvision\ntqdm\ntransformers\nwandb\nxformers\n"
  },
  {
    "path": "examples/research_projects/autoencoderkl/train_autoencoderkl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport contextlib\nimport gc\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport lpips\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom taming.modules.losses.vqperceptual import NLayerDiscriminator, hinge_d_loss, vanilla_d_loss, weights_init\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import AutoencoderKL\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is 
not installed. Remove at your own risks.\ncheck_min_version(\"0.33.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\n@torch.no_grad()\ndef log_validation(vae, args, accelerator, weight_dtype, step, is_final_validation=False):\n    logger.info(\"Running validation... \")\n\n    if not is_final_validation:\n        vae = accelerator.unwrap_model(vae)\n    else:\n        vae = AutoencoderKL.from_pretrained(args.output_dir, torch_dtype=weight_dtype)\n\n    images = []\n    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(\"cuda\")\n\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    for i, validation_image in enumerate(args.validation_image):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n        targets = image_transforms(validation_image).to(accelerator.device, weight_dtype)\n        targets = targets.unsqueeze(0)\n\n        with inference_ctx:\n            reconstructions = vae(targets).sample\n\n        images.append(torch.cat([targets.cpu(), reconstructions.cpu()], axis=0))\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(f\"{tracker_key}: Original (left), Reconstruction (right)\", np_images, step)\n        elif tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    f\"{tracker_key}: Original (left), Reconstruction (right)\": [\n                        wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(images)\n                    ]\n                }\n            
)\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n        gc.collect()\n        torch.cuda.empty_cache()\n\n    return images\n\n\ndef save_model_card(repo_id: str, images=None, base_model: str = None, repo_folder=None):\n    img_str = \"\"\n    if images is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, \"images.png\"))\n        img_str += \"![images](./images.png)\\n\"\n\n    model_description = f\"\"\"\n# autoencoderkl-{repo_id}\n\nThese are autoencoderkl weights trained on {base_model}.\n{img_str}\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n        \"image-to-image\",\n        \"diffusers\",\n        \"autoencoderkl\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of an AutoencoderKL training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--model_config_name_or_path\",\n        type=str,\n        default=None,\n        help=\"The config of the VAE model to train, leave as None to use standard VAE model configuration.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        
required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"autoencoderkl-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=4.5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--disc_learning_rate\",\n        type=float,\n        default=4.5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by 
the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--disc_lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. 
Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training images. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. No captions are needed. If\"\n            \" `--dataset_name` is specified, this folder is passed to `load_dataset` as `data_dir`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=\"A set of paths to the images to be evaluated every `--validation_steps` and logged to `--report_to`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. 
Validation consists of reconstructing each image in\"\n            \" `args.validation_image` with the VAE\"\n            \" and logging the originals and reconstructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"train_autoencoderkl\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers. For\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    parser.add_argument(\n        \"--rec_loss\",\n        type=str,\n        default=\"l2\",\n        help=\"The loss function for VAE reconstruction loss.\",\n    )\n    parser.add_argument(\n        \"--kl_scale\",\n        type=float,\n        default=1e-6,\n        help=\"Scaling factor for the Kullback-Leibler divergence penalty term.\",\n    )\n    parser.add_argument(\n        \"--perceptual_scale\",\n        type=float,\n        default=0.5,\n        help=\"Scaling factor for the LPIPS metric\",\n    )\n    parser.add_argument(\n        \"--disc_start\",\n        type=int,\n        default=50001,\n        help=\"Start for the discriminator\",\n    )\n    parser.add_argument(\n        \"--disc_factor\",\n        type=float,\n        default=1.0,\n        help=\"Scaling factor for the discriminator\",\n    )\n    parser.add_argument(\n        \"--disc_scale\",\n        type=float,\n        default=1.0,\n        help=\"Scaling factor for the discriminator\",\n    )\n    parser.add_argument(\n        \"--disc_loss\",\n        type=str,\n        default=\"hinge\",\n        help=\"Loss function for the discriminator\",\n    )\n    parser.add_argument(\n        \"--decoder_only\",\n        action=\"store_true\",\n        help=\"Only train the VAE decoder.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n  
  if args.pretrained_model_name_or_path is not None and args.model_config_name_or_path is not None:\n        raise ValueError(\"Cannot specify both `--pretrained_model_name_or_path` and `--model_config_name_or_path`\")\n\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--train_data_dir`\")\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the diffusion model.\"\n        )\n\n    return args\n\n\ndef make_train_dataset(args, accelerator):\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        images = [image_transforms(image) for image in images]\n\n        examples[\"pixel_values\"] = images\n\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    return train_dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    return {\"pixel_values\": pixel_values}\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = 
Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load AutoencoderKL\n    if args.pretrained_model_name_or_path is None and args.model_config_name_or_path is None:\n        config = AutoencoderKL.load_config(\"stabilityai/sd-vae-ft-mse\")\n        vae = AutoencoderKL.from_config(config)\n    elif args.pretrained_model_name_or_path is not None:\n        vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, 
revision=args.revision)\n    else:\n        config = AutoencoderKL.load_config(args.model_config_name_or_path)\n        vae = AutoencoderKL.from_config(config)\n    if args.use_ema:\n        ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)\n    perceptual_loss = lpips.LPIPS(net=\"vgg\").eval()\n    discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)\n    discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)\n\n    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    sub_dir = \"autoencoderkl_ema\"\n                    ema_vae.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                i = len(weights) - 1\n\n                while len(weights) > 0:\n                    weights.pop()\n                    model = models[i]\n\n                    if isinstance(model, AutoencoderKL):\n                        sub_dir = \"autoencoderkl\"\n                        model.save_pretrained(os.path.join(output_dir, sub_dir))\n                    else:\n                        sub_dir = \"discriminator\"\n                        os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)\n                        torch.save(model.state_dict(), os.path.join(output_dir, sub_dir, \"pytorch_model.bin\"))\n\n                    i 
-= 1\n\n        def load_model_hook(models, input_dir):\n            while len(models) > 0:\n                if args.use_ema:\n                    sub_dir = \"autoencoderkl_ema\"\n                    load_model = EMAModel.from_pretrained(os.path.join(input_dir, sub_dir), AutoencoderKL)\n                    ema_vae.load_state_dict(load_model.state_dict())\n                    ema_vae.to(accelerator.device)\n                    del load_model\n\n                # pop models so that they are not loaded again\n                model = models.pop()\n                # `load_state_dict` expects a state dict, not a path, so read the checkpoint file first\n                load_state = torch.load(\n                    os.path.join(input_dir, \"discriminator\", \"pytorch_model.bin\"), map_location=\"cpu\"\n                )\n                model.load_state_dict(load_state)\n                del load_state\n\n                model = models.pop()\n                load_model = AutoencoderKL.from_pretrained(input_dir, subfolder=\"autoencoderkl\")\n                model.register_to_config(**load_model.config)\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    vae.requires_grad_(True)\n    if args.decoder_only:\n        vae.encoder.requires_grad_(False)\n        if getattr(vae, \"quant_conv\", None):\n            vae.quant_conv.requires_grad_(False)\n    vae.train()\n    discriminator.requires_grad_(True)\n    discriminator.train()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. 
If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            vae.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        vae.enable_gradient_checkpointing()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(vae).dtype != torch.float32:\n        raise ValueError(f\"VAE loaded as datatype {unwrap_model(vae).dtype}. {low_precision_error_string}\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    params_to_optimize = filter(lambda p: p.requires_grad, vae.parameters())\n    disc_params_to_optimize = filter(lambda p: p.requires_grad, discriminator.parameters())\n    
optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n    disc_optimizer = optimizer_class(\n        disc_params_to_optimize,\n        lr=args.disc_learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    train_dataset = make_train_dataset(args, accelerator)\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n    disc_lr_scheduler = get_scheduler(\n        args.disc_lr_scheduler,\n        optimizer=disc_optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    (\n        vae,\n        discriminator,\n        optimizer,\n        disc_optimizer,\n        train_dataloader,\n        lr_scheduler,\n        
disc_lr_scheduler,\n    ) = accelerator.prepare(\n        vae, discriminator, optimizer, disc_optimizer, train_dataloader, lr_scheduler, disc_lr_scheduler\n    )\n\n    # Pick the dtype that matches the mixed precision mode. The VAE, the perceptual loss, the discriminator\n    # and the input batches are all cast to it below.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move VAE, perceptual loss and discriminator to device and cast to weight_dtype\n    vae.to(accelerator.device, dtype=weight_dtype)\n    perceptual_loss.to(accelerator.device, dtype=weight_dtype)\n    discriminator.to(accelerator.device, dtype=weight_dtype)\n    if args.use_ema:\n        ema_vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches 
each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        vae.train()\n        discriminator.train()\n        for step, batch in enumerate(train_dataloader):\n            # Convert images to latent space and reconstruct from them\n            targets = batch[\"pixel_values\"].to(dtype=weight_dtype)\n            posterior = accelerator.unwrap_model(vae).encode(targets).latent_dist\n            latents = posterior.sample()\n            reconstructions = accelerator.unwrap_model(vae).decode(latents).sample\n\n            if (step // args.gradient_accumulation_steps) % 2 == 0 or global_step < args.disc_start:\n                with accelerator.accumulate(vae):\n                    # reconstruction loss. 
Pixel level differences between input vs output\n                    if args.rec_loss == \"l2\":\n                        rec_loss = F.mse_loss(reconstructions.float(), targets.float(), reduction=\"none\")\n                    elif args.rec_loss == \"l1\":\n                        rec_loss = F.l1_loss(reconstructions.float(), targets.float(), reduction=\"none\")\n                    else:\n                        raise ValueError(f\"Invalid reconstruction loss type: {args.rec_loss}\")\n                    # perceptual loss. The high level feature mean squared error loss\n                    with torch.no_grad():\n                        p_loss = perceptual_loss(reconstructions, targets)\n\n                    rec_loss = rec_loss + args.perceptual_scale * p_loss\n                    nll_loss = rec_loss\n                    nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n\n                    kl_loss = posterior.kl()\n                    kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n\n                    logits_fake = discriminator(reconstructions)\n                    g_loss = -torch.mean(logits_fake)\n                    last_layer = accelerator.unwrap_model(vae).decoder.conv_out.weight\n                    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n                    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n                    disc_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n                    disc_weight = torch.clamp(disc_weight, 0.0, 1e4).detach()\n                    disc_weight = disc_weight * args.disc_scale\n                    disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0\n\n                    loss = nll_loss + args.kl_scale * kl_loss + disc_weight * disc_factor * g_loss\n\n                    logs = {\n                        \"loss\": loss.detach().mean().item(),\n                        \"nll_loss\": 
nll_loss.detach().mean().item(),\n                        \"rec_loss\": rec_loss.detach().mean().item(),\n                        \"p_loss\": p_loss.detach().mean().item(),\n                        \"kl_loss\": kl_loss.detach().mean().item(),\n                        \"disc_weight\": disc_weight.detach().mean().item(),\n                        \"disc_factor\": disc_factor,\n                        \"g_loss\": g_loss.detach().mean().item(),\n                        \"lr\": lr_scheduler.get_last_lr()[0],\n                    }\n\n                    accelerator.backward(loss)\n                    if accelerator.sync_gradients:\n                        params_to_clip = vae.parameters()\n                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                    optimizer.step()\n                    lr_scheduler.step()\n                    optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n            else:\n                with accelerator.accumulate(discriminator):\n                    logits_real = discriminator(targets)\n                    logits_fake = discriminator(reconstructions)\n                    disc_loss = hinge_d_loss if args.disc_loss == \"hinge\" else vanilla_d_loss\n                    disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0\n                    d_loss = disc_factor * disc_loss(logits_real, logits_fake)\n                    logs = {\n                        \"disc_loss\": d_loss.detach().mean().item(),\n                        \"logits_real\": logits_real.detach().mean().item(),\n                        \"logits_fake\": logits_fake.detach().mean().item(),\n                        \"disc_lr\": disc_lr_scheduler.get_last_lr()[0],\n                    }\n                    accelerator.backward(d_loss)\n                    if accelerator.sync_gradients:\n                        params_to_clip = discriminator.parameters()\n                        
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                    disc_optimizer.step()\n                    disc_lr_scheduler.step()\n                    disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                if args.use_ema:\n                    ema_vae.step(vae.parameters())\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = 
os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if global_step == 1 or global_step % args.validation_steps == 0:\n                        if args.use_ema:\n                            ema_vae.store(vae.parameters())\n                            ema_vae.copy_to(vae.parameters())\n                        image_logs = log_validation(\n                            vae,\n                            args,\n                            accelerator,\n                            weight_dtype,\n                            global_step,\n                        )\n                        if args.use_ema:\n                            ema_vae.restore(vae.parameters())\n\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        vae = accelerator.unwrap_model(vae)\n        discriminator = accelerator.unwrap_model(discriminator)\n        if args.use_ema:\n            ema_vae.copy_to(vae.parameters())\n        vae.save_pretrained(args.output_dir)\n        torch.save(discriminator.state_dict(), os.path.join(args.output_dir, \"pytorch_model.bin\"))\n        # Run a final round of validation.\n        image_logs = None\n        image_logs = log_validation(\n            vae=vae,\n            args=args,\n            accelerator=accelerator,\n            weight_dtype=weight_dtype,\n            step=global_step,\n            is_final_validation=True,\n        )\n\n        if args.push_to_hub:\n       
     save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/colossalai/README.md",
    "content": "# [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) by [colossalai](https://github.com/hpcaitech/ColossalAI.git)\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.\nThe `train_dreambooth_colossalai.py` script shows how to implement the training procedure and adapt it for stable diffusion.\n\nBy accommodating model data in CPU and GPU and moving the data to the computing device when necessary, [Gemini](https://www.colossalai.org/docs/advanced_tutorials/meet_gemini), the Heterogeneous Memory Manager of [Colossal-AI](https://github.com/hpcaitech/ColossalAI), can break through the GPU memory wall by using GPU and CPU memory (composed of CPU DRAM or NVMe SSD memory) together at the same time. Moreover, the model scale can be further improved by combining heterogeneous training with other parallel approaches, such as data parallel, tensor parallel and pipeline parallel.\n\n## Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n```bash\npip install -r requirement.txt\n```\n\n## Install [ColossalAI](https://github.com/hpcaitech/ColossalAI.git)\n\n**From PyPI**\n```bash\npip install colossalai\n```\n\n**From source**\n\n```bash\ngit clone https://github.com/hpcaitech/ColossalAI.git\ncd ColossalAI\n\n# install colossalai\npip install .\n```\n\n## Dataset for Teyvat BLIP captions\nDataset used to train [Teyvat characters text to image model](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion).\n\nBLIP-generated captions for character images from [genshin-impact fandom wiki](https://genshin-impact.fandom.com/wiki/Character#Playable_Characters) and [biligame wiki for genshin impact](https://wiki.biligame.com/ys/%E8%A7%92%E8%89%B2).\n\nFor each row, the dataset contains `image` and `text` keys. 
`image` is a varying size PIL png, and `text` is the accompanying text caption. Only a train split is provided.\n\nThe `text` includes the tags `Teyvat`, `Name`, `Element`, `Weapon`, `Region`, `Model type`, and `Description`; the `Description` is captioned with the [pre-trained BLIP model](https://github.com/salesforce/BLIP).\n\n## Training\n\nThe `placement` argument can be `cpu`, `auto`, or `cuda`. With `cpu`, the required GPU RAM can be minimized to 4GB, but training slows down. With `cuda`, you can still reduce GPU memory by half while training faster. With `auto`, a more balanced trade-off between speed and memory can be obtained.\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\ntorchrun --nproc_per_node 2 train_dreambooth_colossalai.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=400 \\\n  --placement=\"cuda\"\n```\n\n\n### Training with prior-preservation loss\n\nPrior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.\nAccording to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. 
You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` at training time.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\ntorchrun --nproc_per_node 2 train_dreambooth_colossalai.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=800 \\\n  --placement=\"cuda\"\n```\n\n## Inference\n\nOnce you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_id = \"path-to-save-model\"\npipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A photo of sks dog in a bucket\"\nimage = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\n\nimage.save(\"dog-bucket.png\")\n```\n"
  },
  {
    "path": "examples/research_projects/colossalai/inference.py",
    "content": "import torch\n\nfrom diffusers import StableDiffusionPipeline\n\n\nmodel_id = \"path-to-your-trained-model\"\npipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A photo of sks dog in a bucket\"\nimage = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\n\nimage.save(\"dog-bucket.png\")\n"
  },
  {
    "path": "examples/research_projects/colossalai/requirement.txt",
    "content": "diffusers\ntorch\ntorchvision\nftfy\ntensorboard\nJinja2\ntransformers"
  },
  {
    "path": "examples/research_projects/colossalai/train_dreambooth_colossalai.py",
    "content": "import argparse\nimport math\nimport os\nfrom pathlib import Path\n\nimport colossalai\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom colossalai.context.parallel_mode import ParallelMode\nfrom colossalai.core import global_context as gpc\nfrom colossalai.logging import disable_existing_loggers, get_dist_logger\nfrom colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer\nfrom colossalai.nn.parallel.utils import get_static_torch_model\nfrom colossalai.utils import get_current_device\nfrom colossalai.utils.model.colo_init_context import ColoInitContext\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nfrom diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\n\n\ndisable_existing_loggers()\nlogger = get_dist_logger()\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=args.revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training 
script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=\"a photo of sks dog\",\n        required=False,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--placement\",\n        type=str,\n        default=\"cpu\",\n        help=\"Placement Policy for Gemini. Valid when using colossalai as dist plan.\",\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\"--save_steps\", type=int, default=500, help=\"Save checkpoint every X updates steps.\")\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        
\"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        if args.class_data_dir is not None:\n            logger.warning(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            logger.warning(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes 
the images and the tokenizes prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        size=512,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(\"Instance images root doesn't exists.\")\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n 
       example[\"instance_images\"] = self.image_transforms(instance_image)\n        example[\"instance_prompt_ids\"] = self.tokenizer(\n            self.instance_prompt,\n            padding=\"do_not_pad\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n        ).input_ids\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt_ids\"] = self.tokenizer(\n                self.class_prompt,\n                padding=\"do_not_pad\",\n                truncation=True,\n                max_length=self.tokenizer.model_max_length,\n            ).input_ids\n\n        return example\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\n# Gemini + ZeRO DDP\ndef gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = \"auto\"):\n    from colossalai.nn.parallel import GeminiDDP\n\n    model = GeminiDDP(\n        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=64\n    )\n    return model\n\n\ndef main(args):\n    if args.seed is None:\n        colossalai.launch_from_torch(config={})\n    else:\n        colossalai.launch_from_torch(config={}, seed=args.seed)\n\n    local_rank = gpc.get_local_rank(ParallelMode.DATA)\n    world_size = 
gpc.get_world_size(ParallelMode.DATA)\n\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = torch.float16 if get_current_device() == \"cuda\" else torch.float32\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                safety_checker=None,\n                revision=args.revision,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            pipeline.to(get_current_device())\n\n            for example in tqdm(\n                sample_dataloader,\n                desc=\"Generating class images\",\n                disable=not local_rank == 0,\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n\n    # Handle the repository creation\n    if local_rank == 0:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        logger.info(f\"Loading tokenizer from {args.tokenizer_name}\", ranks=[0])\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.tokenizer_name,\n            revision=args.revision,\n            use_fast=False,\n        )\n    elif args.pretrained_model_name_or_path:\n        logger.info(\"Loading tokenizer from pretrained model\", ranks=[0])\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n        # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)\n\n    # Load models and create wrapper for stable diffusion\n\n    logger.info(f\"Loading text_encoder from {args.pretrained_model_name_or_path}\", ranks=[0])\n\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=args.revision,\n    )\n\n    logger.info(f\"Loading AutoencoderKL from {args.pretrained_model_name_or_path}\", ranks=[0])\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n    )\n\n    logger.info(f\"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}\", ranks=[0])\n    with ColoInitContext(device=get_current_device()):\n        unet = UNet2DConditionModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, low_cpu_mem_usage=False\n        )\n\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    if args.gradient_checkpointing:\n        
unet.enable_gradient_checkpointing()\n\n    if args.scale_lr:\n        args.learning_rate = args.learning_rate * args.train_batch_size * world_size\n\n    unet = gemini_zero_dpp(unet, args.placement)\n\n    # config optimizer for colossalai zero\n    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)\n\n    # load noise_scheduler\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    # prepare dataset\n    logger.info(f\"Prepare dataset from {args.instance_data_dir}\", ranks=[0])\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n    )\n\n    def collate_fn(examples):\n        input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n        pixel_values = [example[\"instance_images\"] for example in examples]\n\n        # Concat class and instance examples for prior preservation.\n        # We do this to avoid doing two forward passes.\n        if args.with_prior_preservation:\n            input_ids += [example[\"class_prompt_ids\"] for example in examples]\n            pixel_values += [example[\"class_images\"] for example in examples]\n\n        pixel_values = torch.stack(pixel_values)\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n        input_ids = tokenizer.pad(\n            {\"input_ids\": input_ids},\n            padding=\"max_length\",\n            max_length=tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        batch = {\n            \"input_ids\": input_ids,\n            \"pixel_values\": pixel_values,\n        }\n   
     return batch\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader))\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps,\n        num_training_steps=args.max_train_steps,\n    )\n    weight_dtype = torch.float32\n    if args.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif args.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move text_encode and vae to gpu.\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    vae.to(get_current_device(), dtype=weight_dtype)\n    text_encoder.to(get_current_device(), dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader))\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # Train!\n    total_batch_size = args.train_batch_size * world_size\n\n    logger.info(\"***** Running training *****\", ranks=[0])\n    logger.info(f\"  Num examples = {len(train_dataset)}\", ranks=[0])\n    logger.info(f\"  Num batches each epoch = 
{len(train_dataloader)}\", ranks=[0])\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\", ranks=[0])\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\", ranks=[0])\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\", ranks=[0])\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\", ranks=[0])\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)\n    progress_bar.set_description(\"Steps\")\n    global_step = 0\n\n    torch.cuda.synchronize()\n    for epoch in range(args.num_train_epochs):\n        unet.train()\n        for step, batch in enumerate(train_dataloader):\n            torch.cuda.reset_peak_memory_stats()\n            # Move batch to gpu\n            for key, value in batch.items():\n                batch[key] = value.to(get_current_device(), non_blocking=True)\n\n            # Convert images to latent space\n            optimizer.zero_grad()\n\n            latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n            latents = latents * 0.18215\n\n            # Sample noise that we'll add to the latents\n            noise = torch.randn_like(latents)\n            bsz = latents.shape[0]\n            # Sample a random timestep for each image\n            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n            timesteps = timesteps.long()\n\n            # Add noise to the latents according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n            # Get the text embedding for conditioning\n            encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n            # Predict the noise residual\n            
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n\n            # Get the target for loss depending on the prediction type\n            if noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                target = noise_scheduler.get_velocity(latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n            if args.with_prior_preservation:\n                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                target, target_prior = torch.chunk(target, 2, dim=0)\n\n                # Compute instance loss\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\").mean([1, 2, 3]).mean()\n\n                # Compute prior loss\n                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                # Add the prior loss to the instance loss.\n                loss = loss + args.prior_loss_weight * prior_loss\n            else:\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n            optimizer.backward(loss)\n\n            optimizer.step()\n            lr_scheduler.step()\n            logger.info(f\"max GPU_mem cost is {torch.cuda.max_memory_allocated() / 2**20} MB\", ranks=[0])\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            progress_bar.update(1)\n            global_step += 1\n            logs = {\n                \"loss\": loss.detach().item(),\n                \"lr\": optimizer.param_groups[0][\"lr\"],\n            }  # lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n      
      if global_step % args.save_steps == 0:\n                torch.cuda.synchronize()\n                torch_unet = get_static_torch_model(unet)\n                if local_rank == 0:\n                    pipeline = DiffusionPipeline.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        unet=torch_unet,\n                        revision=args.revision,\n                    )\n                    save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                    pipeline.save_pretrained(save_path)\n                    logger.info(f\"Saving model checkpoint to {save_path}\", ranks=[0])\n            if global_step >= args.max_train_steps:\n                break\n\n    torch.cuda.synchronize()\n    unet = get_static_torch_model(unet)\n\n    if local_rank == 0:\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=unet,\n            revision=args.revision,\n        )\n\n        pipeline.save_pretrained(args.output_dir)\n        logger.info(f\"Saving model checkpoint to {args.output_dir}\", ranks=[0])\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/consistency_training/README.md",
    "content": "# Consistency Training\n\n`train_cm_ct_unconditional.py` trains a consistency model (CM) from scratch following the consistency training (CT) algorithm introduced in [Consistency Models](https://huggingface.co/papers/2303.01469) and refined in [Improved Techniques for Training Consistency Models](https://huggingface.co/papers/2310.14189). Both unconditional and class-conditional training are supported.\n\nA usage example is as follows:\n\n```bash\naccelerate launch examples/research_projects/consistency_training/train_cm_ct_unconditional.py \\\n    --dataset_name=\"cifar10\" \\\n    --dataset_image_column_name=\"img\" \\\n    --output_dir=\"/path/to/output/dir\" \\\n    --mixed_precision=fp16 \\\n    --resolution=32 \\\n    --max_train_steps=1000 --max_train_samples=10000 \\\n    --dataloader_num_workers=8 \\\n    --noise_precond_type=\"cm\" --input_precond_type=\"cm\" \\\n    --train_batch_size=4 \\\n    --learning_rate=1e-04 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n    --use_8bit_adam \\\n    --use_ema \\\n    --validation_steps=100 --eval_batch_size=4 \\\n    --checkpointing_steps=100 --checkpoints_total_limit=10 \\\n    --class_conditional --num_classes=10 \\\n```"
  },
  {
    "path": "examples/research_projects/consistency_training/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2"
  },
  {
    "path": "examples/research_projects/consistency_training/train_cm_ct_unconditional.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\"\"\"Script to train a consistency model from scratch via (improved) consistency training.\"\"\"\n\nimport argparse\nimport gc\nimport logging\nimport math\nimport os\nimport shutil\nfrom datetime import timedelta\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nfrom accelerate import Accelerator, InitProcessGroupKwargs\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import (\n    CMStochasticIterativeScheduler,\n    ConsistencyModelPipeline,\n    UNet2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, resolve_interpolation_mode\nfrom diffusers.utils import is_tensorboard_available, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef _extract_into_tensor(arr, timesteps, broadcast_shape):\n    \"\"\"\n    Extract values from a 1-D numpy array for a 
batch of indices.\n\n    :param arr: the 1-D numpy array.\n    :param timesteps: a tensor of indices into the array to extract.\n    :param broadcast_shape: a larger shape of K dimensions with the batch\n                            dimension equal to the length of timesteps.\n    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.\n    \"\"\"\n    if not isinstance(arr, torch.Tensor):\n        arr = torch.from_numpy(arr)\n    res = arr[timesteps].float().to(timesteps.device)\n    while len(res.shape) < len(broadcast_shape):\n        res = res[..., None]\n    return res.expand(broadcast_shape)\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\")\n    return x[(...,) + (None,) * dims_to_append]\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef get_discretization_steps(global_step: int, max_train_steps: int, s_0: int = 10, s_1: int = 1280, constant=False):\n    \"\"\"\n    Calculates the current discretization steps at global step k using the discretization curriculum N(k).\n    \"\"\"\n    if constant:\n        return s_0 + 1\n\n    k_prime = math.floor(max_train_steps / (math.log2(math.floor(s_1 / s_0)) + 1))\n    num_discretization_steps = min(s_0 * 2 ** math.floor(global_step / k_prime), s_1) + 1\n\n    return num_discretization_steps\n\n\ndef get_skip_steps(global_step, initial_skip: int = 1):\n    # Currently only support constant skip curriculum.\n    return initial_skip\n\n\ndef get_karras_sigmas(\n    num_discretization_steps: int,\n    sigma_min: float = 0.002,\n    sigma_max: float = 80.0,\n    rho: float = 7.0,\n    dtype=torch.float32,\n):\n    \"\"\"\n    Calculates the 
Karras sigmas timestep discretization of [sigma_min, sigma_max].\n    \"\"\"\n    ramp = np.linspace(0, 1, num_discretization_steps)\n    min_inv_rho = sigma_min ** (1 / rho)\n    max_inv_rho = sigma_max ** (1 / rho)\n    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho\n    # Make sure sigmas are in increasing rather than decreasing order (see section 2 of the iCT paper)\n    sigmas = sigmas[::-1].copy()\n    sigmas = torch.from_numpy(sigmas).to(dtype=dtype)\n    return sigmas\n\n\ndef get_discretized_lognormal_weights(noise_levels: torch.Tensor, p_mean: float = -1.1, p_std: float = 2.0):\n    \"\"\"\n    Calculates the unnormalized weights for a 1D array of noise levels sigma_i based on the discretized lognormal\n    distribution used in the iCT paper (given in Equation 10).\n    \"\"\"\n    upper_prob = torch.special.erf((torch.log(noise_levels[1:]) - p_mean) / (math.sqrt(2) * p_std))\n    lower_prob = torch.special.erf((torch.log(noise_levels[:-1]) - p_mean) / (math.sqrt(2) * p_std))\n    weights = upper_prob - lower_prob\n    return weights\n\n\ndef get_loss_weighting_schedule(noise_levels: torch.Tensor):\n    \"\"\"\n    Calculates the loss weighting schedule lambda given a set of noise levels.\n    \"\"\"\n    return 1.0 / (noise_levels[1:] - noise_levels[:-1])\n\n\ndef add_noise(original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):\n    # Make sure timesteps (Karras sigmas) have the same device and dtype as original_samples\n    sigmas = timesteps.to(device=original_samples.device, dtype=original_samples.dtype)\n    while len(sigmas.shape) < len(original_samples.shape):\n        sigmas = sigmas.unsqueeze(-1)\n\n    noisy_samples = original_samples + noise * sigmas\n\n    return noisy_samples\n\n\ndef get_noise_preconditioning(sigmas, noise_precond_type: str = \"cm\"):\n    \"\"\"\n    Calculates the noise preconditioning function c_noise, which is used to transform the raw Karras sigmas into the\n    timestep 
input for the U-Net.\n    \"\"\"\n    if noise_precond_type == \"none\":\n        return sigmas\n    elif noise_precond_type == \"edm\":\n        return 0.25 * torch.log(sigmas)\n    elif noise_precond_type == \"cm\":\n        return 1000 * 0.25 * torch.log(sigmas + 1e-44)\n    else:\n        raise ValueError(\n            f\"Noise preconditioning type {noise_precond_type} is not currently supported. Currently supported noise\"\n            f\" preconditioning types are `none` (which uses the sigmas as is), `edm`, and `cm`.\"\n        )\n\n\ndef get_input_preconditioning(sigmas, sigma_data=0.5, input_precond_type: str = \"cm\"):\n    \"\"\"\n    Calculates the input preconditioning factor c_in, which is used to scale the U-Net image input.\n    \"\"\"\n    if input_precond_type == \"none\":\n        return 1\n    elif input_precond_type == \"cm\":\n        return 1.0 / (sigmas**2 + sigma_data**2)\n    else:\n        raise ValueError(\n            f\"Input preconditioning type {input_precond_type} is not currently supported. Currently supported input\"\n            f\" preconditioning types are `none` (which uses a scaling factor of 1.0) and `cm`.\"\n        )\n\n\ndef scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=1.0):\n    scaled_timestep = timestep_scaling * timestep\n    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)\n    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5\n    return c_skip, c_out\n\n\ndef log_validation(unet, scheduler, args, accelerator, weight_dtype, step, name=\"teacher\"):\n    logger.info(\"Running validation... 
\")\n\n    unet = accelerator.unwrap_model(unet)\n    pipeline = ConsistencyModelPipeline(\n        unet=unet,\n        scheduler=scheduler,\n    )\n    pipeline = pipeline.to(device=accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    class_labels = [None]\n    if args.class_conditional:\n        if args.num_classes is not None:\n            class_labels = list(range(args.num_classes))\n        else:\n            logger.warning(\n                \"The model is class-conditional but the number of classes is not set. The generated images will be\"\n                \" unconditional rather than class-conditional.\"\n            )\n\n    image_logs = []\n\n    for class_label in class_labels:\n        images = []\n        with torch.autocast(\"cuda\"):\n            images = pipeline(\n                num_inference_steps=1,\n                batch_size=args.eval_batch_size,\n                class_labels=[class_label] * args.eval_batch_size,\n                generator=generator,\n            ).images\n        log = {\"images\": images}\n        if args.class_conditional and class_label is not None:\n            log[\"class_label\"] = str(class_label)\n        else:\n            log[\"class_label\"] = \"images\"\n        image_logs.append(log)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                class_label = log[\"class_label\"]\n                formatted_images = []\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n               
 tracker.writer.add_images(class_label, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                class_label = log[\"class_label\"]\n                for image in images:\n                    image = wandb.Image(image, caption=class_label)\n                    formatted_images.append(image)\n\n            tracker.log({f\"validation/{name}\": formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    # ------------Model Arguments-----------\n    parser.add_argument(\n        \"--model_config_name_or_path\",\n        type=str,\n        default=None,\n        help=\"The config of the UNet model to train, leave as None to use standard DDPM configuration.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        help=(\n            \"If initializing the weights from a pretrained model, the path to the pretrained model or model identifier\"\n            \" from huggingface.co/models.\"\n        ),\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=(\n            \"Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. 
`fp16`,\"\n            \" `non_ema`, etc.\",\n        ),\n    )\n    # ------------Dataset Arguments-----------\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that HF Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--dataset_image_column_name\",\n        type=str,\n        default=\"image\",\n        help=\"The name of the image column in the dataset to use for training.\",\n    )\n    parser.add_argument(\n        \"--dataset_class_label_column_name\",\n        type=str,\n        default=\"label\",\n        help=\"If doing class-conditional training, the name of the class label column in the dataset to use.\",\n    )\n    # ------------Image Processing Arguments-----------\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=64,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        
),\n    )\n    parser.add_argument(\n        \"--interpolation_type\",\n        type=str,\n        default=\"bilinear\",\n        help=(\n            \"The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,\"\n            \" `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--class_conditional\",\n        action=\"store_true\",\n        help=(\n            \"Whether to train a class-conditional model. If set, the class labels will be taken from the `label`\"\n            \" column of the provided dataset.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_classes\",\n        type=int,\n        default=None,\n        help=\"The number of classes in the training data, if training a class-conditional model.\",\n    )\n    parser.add_argument(\n        \"--class_embed_type\",\n        type=str,\n        default=None,\n        help=(\n            \"The class embedding type to use. Choose from `None`, `identity`, and `timestep`. 
If `class_conditional`\"\n            \" and `num_classes` are set, but `class_embed_type` is `None`, an embedding matrix will be used.\"\n        ),\n    )\n    # ------------Dataloader Arguments-----------\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main\"\n            \" process.\"\n        ),\n    )\n    # ------------Training Arguments-----------\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"ddpm-model-64\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--overwrite_output_dir\", action=\"store_true\")\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    # ----Batch Size and Training Length----\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    # ----Learning Rate----\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"cosine\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    # ----Optimizer (Adam) Arguments----\n    parser.add_argument(\n        \"--optimizer_type\",\n        type=str,\n        default=\"adamw\",\n        help=(\n            \"The optimizer algorithm to use for training. Choose between `radam` and `adamw`. 
The iCT paper uses\"\n            \" RAdam.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.95, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\n        \"--adam_weight_decay\", type=float, default=1e-6, help=\"Weight decay magnitude for the Adam optimizer.\"\n    )\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer.\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    # ----Consistency Training (CT) Specific Arguments----\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=\"sample\",\n        choices=[\"sample\"],\n        help=\"Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.\",\n    )\n    parser.add_argument(\"--ddpm_num_steps\", type=int, default=1000)\n    parser.add_argument(\"--ddpm_num_inference_steps\", type=int, default=1000)\n    parser.add_argument(\"--ddpm_beta_schedule\", type=str, default=\"linear\")\n    parser.add_argument(\n        \"--sigma_min\",\n        type=float,\n        default=0.002,\n        help=(\n            \"The lower boundary for the timestep discretization, which should be set to a small positive value close\"\n            \" to zero to avoid numerical issues when solving the PF-ODE backwards in time.\"\n        ),\n    )\n    parser.add_argument(\n        \"--sigma_max\",\n        type=float,\n        default=80.0,\n        help=(\n            \"The upper boundary for the timestep discretization, which also determines the variance of the Gaussian\"\n            \" prior.\"\n        
),\n    )\n    parser.add_argument(\n        \"--rho\",\n        type=float,\n        default=7.0,\n        help=\"The rho parameter for the Karras sigmas timestep discretization.\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=None,\n        help=(\n            \"The Pseudo-Huber loss parameter c. If not set, this will default to the value recommended in the Improved\"\n            \" Consistency Training (iCT) paper of 0.00054 * sqrt(d), where d is the data dimensionality.\"\n        ),\n    )\n    parser.add_argument(\n        \"--discretization_s_0\",\n        type=int,\n        default=10,\n        help=(\n            \"The s_0 parameter in the discretization curriculum N(k). This controls the number of training steps after\"\n            \" which the number of discretization steps N will be doubled.\"\n        ),\n    )\n    parser.add_argument(\n        \"--discretization_s_1\",\n        type=int,\n        default=1280,\n        help=(\n            \"The s_1 parameter in the discretization curriculum N(k). This controls the upper limit to the number of\"\n            \" discretization steps used. Increasing this value will reduce the bias at the cost of higher variance.\"\n        ),\n    )\n    parser.add_argument(\n        \"--constant_discretization_steps\",\n        action=\"store_true\",\n        help=(\n            \"Whether to set the discretization curriculum N(k) to be the constant value `discretization_s_0 + 1`. 
This\"\n            \" is useful for testing when `max_train_steps` is small, when `k_prime` would otherwise be 0, causing\"\n            \" a divide-by-zero error.\"\n        ),\n    )\n    parser.add_argument(\n        \"--p_mean\",\n        type=float,\n        default=-1.1,\n        help=(\n            \"The mean parameter P_mean for the (discretized) lognormal noise schedule, which controls the probability\"\n            \" of sampling a (discrete) noise level sigma_i.\"\n        ),\n    )\n    parser.add_argument(\n        \"--p_std\",\n        type=float,\n        default=2.0,\n        help=(\n            \"The standard deviation parameter P_std for the (discretized) noise schedule, which controls the\"\n            \" probability of sampling a (discrete) noise level sigma_i.\"\n        ),\n    )\n    parser.add_argument(\n        \"--noise_precond_type\",\n        type=str,\n        default=\"cm\",\n        help=(\n            \"The noise preconditioning function to use for transforming the raw Karras sigmas into the timestep\"\n            \" argument of the U-Net. Choose between `none` (the identity function), `edm`, and `cm`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--input_precond_type\",\n        type=str,\n        default=\"cm\",\n        help=(\n            \"The input preconditioning function to use for scaling the image input of the U-Net. Choose between `none`\"\n            \" (a scaling factor of 1) and `cm`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--skip_steps\",\n        type=int,\n        default=1,\n        help=(\n            \"The gap in indices between the student and teacher noise levels. 
In the iCT paper this is always set to\"\n            \" 1, but theoretically this could be greater than 1 and/or altered according to a curriculum throughout\"\n            \" training, much like the number of discretization steps is.\"\n        ),\n    )\n    parser.add_argument(\n        \"--cast_teacher\",\n        action=\"store_true\",\n        help=\"Whether to cast the teacher U-Net model to `weight_dtype` or leave it in full precision.\",\n    )\n    # ----Exponential Moving Average (EMA) Arguments----\n    parser.add_argument(\n        \"--use_ema\",\n        action=\"store_true\",\n        help=\"Whether to use Exponential Moving Average for the final model weights.\",\n    )\n    parser.add_argument(\n        \"--ema_min_decay\",\n        type=float,\n        default=None,\n        help=(\n            \"The minimum decay magnitude for EMA. If not set, this will default to the value of `ema_max_decay`,\"\n            \" resulting in a constant EMA decay rate.\"\n        ),\n    )\n    parser.add_argument(\n        \"--ema_max_decay\",\n        type=float,\n        default=0.99993,\n        help=(\n            \"The maximum decay magnitude for EMA. Setting `ema_min_decay` equal to this value will result in a\"\n            \" constant decay rate.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_ema_warmup\",\n        action=\"store_true\",\n        help=\"Whether to use EMA warmup.\",\n    )\n    parser.add_argument(\"--ema_inv_gamma\", type=float, default=1.0, help=\"The inverse gamma value for the EMA decay.\")\n    parser.add_argument(\"--ema_power\", type=float, default=3 / 4, help=\"The power value for the EMA decay.\")\n    # ----Training Optimization Arguments----\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    # ----Distributed Training Arguments----\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    # ------------Validation Arguments-----------\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    parser.add_argument(\n        \"--eval_batch_size\",\n        type=int,\n        default=16,\n        help=(\n            \"The number of images to generate for evaluation. 
Note that if `class_conditional` and `num_classes` are\"\n            \" set, the effective number of images generated per evaluation step is `eval_batch_size * num_classes`.\"\n        ),\n    )\n    parser.add_argument(\"--save_images_epochs\", type=int, default=10, help=\"How often to save images during training.\")\n    # ------------Checkpointing Arguments-----------\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--save_model_epochs\", type=int, default=10, help=\"How often to save the model during training.\"\n    )\n    # ------------Logging Arguments-----------\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    # ------------HuggingFace Hub Arguments-----------\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--hub_private_repo\", action=\"store_true\", help=\"Whether or not to create a private repository.\"\n    )\n    # ------------Accelerate Arguments-----------\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"consistency-training\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers. For\"\n            \" more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"You must specify either a dataset name from the hub or a train data directory.\")\n\n    return args\n\n\ndef main(args):\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))  # a big number for high resolution or big dataset\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    if args.report_to == \"tensorboard\":\n        if not is_tensorboard_available():\n            raise ImportError(\"Make sure to install tensorboard if you want to use it for logging during training.\")\n\n    elif args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # 1. 
Initialize the noise scheduler.\n    initial_discretization_steps = get_discretization_steps(\n        0,\n        args.max_train_steps,\n        s_0=args.discretization_s_0,\n        s_1=args.discretization_s_1,\n        constant=args.constant_discretization_steps,\n    )\n    noise_scheduler = CMStochasticIterativeScheduler(\n        num_train_timesteps=initial_discretization_steps,\n        sigma_min=args.sigma_min,\n        sigma_max=args.sigma_max,\n        rho=args.rho,\n    )\n\n    # 2. Initialize the student U-Net model.\n    if args.pretrained_model_name_or_path is not None:\n        logger.info(f\"Loading pretrained U-Net weights from {args.pretrained_model_name_or_path}... \")\n        unet = UNet2DModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n        )\n    elif args.model_config_name_or_path is None:\n        # TODO: use default architectures from iCT paper\n        if not args.class_conditional and (args.num_classes is not None or args.class_embed_type is not None):\n            logger.warning(\n                f\"`--class_conditional` is set to `False` but `--num_classes` is set to {args.num_classes} and\"\n                f\" `--class_embed_type` is set to {args.class_embed_type}. 
These values will be overridden to `None`.\"\n            )\n            args.num_classes = None\n            args.class_embed_type = None\n        elif args.class_conditional and args.num_classes is None and args.class_embed_type is None:\n            logger.warning(\n                \"`--class_conditional` is set to `True` but neither `--num_classes` nor `--class_embed_type` is set.\"\n                \" `class_conditional` will be overridden to `False`.\"\n            )\n            args.class_conditional = False\n        unet = UNet2DModel(\n            sample_size=args.resolution,\n            in_channels=3,\n            out_channels=3,\n            layers_per_block=2,\n            block_out_channels=(128, 128, 256, 256, 512, 512),\n            down_block_types=(\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"AttnDownBlock2D\",\n                \"DownBlock2D\",\n            ),\n            up_block_types=(\n                \"UpBlock2D\",\n                \"AttnUpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n            ),\n            class_embed_type=args.class_embed_type,\n            num_class_embeds=args.num_classes,\n        )\n    else:\n        config = UNet2DModel.load_config(args.model_config_name_or_path)\n        unet = UNet2DModel.from_config(config)\n    unet.train()\n\n    # Create EMA for the student U-Net model.\n    if args.use_ema:\n        if args.ema_min_decay is None:\n            args.ema_min_decay = args.ema_max_decay\n        ema_unet = EMAModel(\n            unet.parameters(),\n            decay=args.ema_max_decay,\n            min_decay=args.ema_min_decay,\n            use_ema_warmup=args.use_ema_warmup,\n            inv_gamma=args.ema_inv_gamma,\n            power=args.ema_power,\n            model_cls=UNet2DModel,\n            
model_config=unet.config,\n        )\n\n    # 3. Initialize the teacher U-Net model from the student U-Net model.\n    # Note that following the improved Consistency Training paper, the teacher U-Net is not updated via EMA (i.e. the\n    # EMA decay rate is 0.)\n    teacher_unet = UNet2DModel.from_config(unet.config)\n    teacher_unet.load_state_dict(unet.state_dict())\n    teacher_unet.train()\n    teacher_unet.requires_grad_(False)\n\n    # 4. Handle mixed precision and device placement\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n        args.mixed_precision = accelerator.mixed_precision\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n        args.mixed_precision = accelerator.mixed_precision\n\n    # Cast teacher_unet to weight_dtype if cast_teacher is set.\n    if args.cast_teacher:\n        teacher_dtype = weight_dtype\n    else:\n        teacher_dtype = torch.float32\n\n    teacher_unet.to(accelerator.device)\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    # 5. 
Handle saving and loading of checkpoints.\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                teacher_unet.save_pretrained(os.path.join(output_dir, \"unet_teacher\"))\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            load_model = UNet2DModel.from_pretrained(os.path.join(input_dir, \"unet_teacher\"))\n            teacher_unet.load_state_dict(load_model.state_dict())\n            teacher_unet.to(accelerator.device)\n            del load_model\n\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        
accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # 6. Enable optimizations\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training on some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            teacher_unet.enable_xformers_memory_efficient_attention()\n            if args.use_ema:\n                ema_unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    if args.optimizer_type == \"radam\":\n        optimizer_class = torch.optim.RAdam\n    elif args.optimizer_type == \"adamw\":\n        # Use 8-bit Adam for lower memory usage or to fine-tune the model for 16GB GPUs\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n    else:\n        raise ValueError(\n      
      f\"Optimizer type {args.optimizer_type} is not supported. Currently supported optimizer types are `radam`\"\n            f\" and `adamw`.\"\n        )\n\n    # 7. Initialize the optimizer\n    optimizer = optimizer_class(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # 8. Dataset creation and data preprocessing\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            split=\"train\",\n        )\n    else:\n        dataset = load_dataset(\"imagefolder\", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split=\"train\")\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets and DataLoaders creation.\n    interpolation_mode = resolve_interpolation_mode(args.interpolation_type)\n    augmentations = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=interpolation_mode),\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def transform_images(examples):\n        images = 
[augmentations(image.convert(\"RGB\")) for image in examples[args.dataset_image_column_name]]\n        batch_dict = {\"images\": images}\n        if args.class_conditional:\n            batch_dict[\"class_labels\"] = examples[args.dataset_class_label_column_name]\n        return batch_dict\n\n    logger.info(f\"Dataset size: {len(dataset)}\")\n\n    dataset.set_transform(transform_images)\n    train_dataloader = torch.utils.data.DataLoader(\n        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n    )\n\n    # 9. Initialize the learning rate scheduler\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps,\n        num_training_steps=args.max_train_steps,\n    )\n\n    # 10. 
Prepare for training\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    def recalculate_num_discretization_step_values(discretization_steps, skip_steps):\n        \"\"\"\n        Recalculates all quantities depending on the number of discretization steps N.\n        \"\"\"\n        noise_scheduler = CMStochasticIterativeScheduler(\n            num_train_timesteps=discretization_steps,\n            sigma_min=args.sigma_min,\n            sigma_max=args.sigma_max,\n            rho=args.rho,\n        )\n        current_timesteps = get_karras_sigmas(discretization_steps, args.sigma_min, args.sigma_max, args.rho)\n        valid_teacher_timesteps_plus_one = current_timesteps[: len(current_timesteps) - skip_steps + 1]\n        # timestep_weights are the unnormalized probabilities of sampling the timestep/noise level at each index\n        timestep_weights = get_discretized_lognormal_weights(\n            valid_teacher_timesteps_plus_one, p_mean=args.p_mean, p_std=args.p_std\n        )\n        # timestep_loss_weights is the timestep-dependent loss weighting schedule lambda(sigma_i)\n        timestep_loss_weights = get_loss_weighting_schedule(valid_teacher_timesteps_plus_one)\n\n        current_timesteps = current_timesteps.to(accelerator.device)\n        timestep_weights = timestep_weights.to(accelerator.device)\n        timestep_loss_weights = timestep_loss_weights.to(accelerator.device)\n\n        return noise_scheduler, current_timesteps, timestep_weights, timestep_loss_weights\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we 
recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Function for unwrapping if torch.compile() was used in accelerate.\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    # Resolve the c parameter for the Pseudo-Huber loss\n    if args.huber_c is None:\n        args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)\n\n    # Get current number of discretization steps N according to our discretization curriculum\n    current_discretization_steps = get_discretization_steps(\n        initial_global_step,\n        args.max_train_steps,\n        s_0=args.discretization_s_0,\n        s_1=args.discretization_s_1,\n        constant=args.constant_discretization_steps,\n    )\n  
  current_skip_steps = get_skip_steps(initial_global_step, initial_skip=args.skip_steps)\n    if current_skip_steps >= current_discretization_steps:\n        raise ValueError(\n            f\"The current skip steps is {current_skip_steps}, but should be smaller than the current number of\"\n            f\" discretization steps {current_discretization_steps}\"\n        )\n    # Recalculate all quantities depending on the number of discretization steps N\n    (\n        noise_scheduler,\n        current_timesteps,\n        timestep_weights,\n        timestep_loss_weights,\n    ) = recalculate_num_discretization_step_values(current_discretization_steps, current_skip_steps)\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    # 11. Train!\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        for step, batch in enumerate(train_dataloader):\n            # 1. Get batch of images from dataloader (sample x ~ p_data(x))\n            clean_images = batch[\"images\"].to(weight_dtype)\n            if args.class_conditional:\n                class_labels = batch[\"class_labels\"]\n            else:\n                class_labels = None\n            bsz = clean_images.shape[0]\n\n            # 2. 
Sample a random timestep for each image according to the noise schedule.\n            # Sample random indices i ~ p(i), where p(i) is the discretized lognormal distribution in the iCT paper\n            # NOTE: timestep_indices should be in the range [0, len(current_timesteps) - k - 1] inclusive,\n            # where k = current_skip_steps\n            timestep_indices = torch.multinomial(timestep_weights, bsz, replacement=True).long()\n            teacher_timesteps = current_timesteps[timestep_indices]\n            student_timesteps = current_timesteps[timestep_indices + current_skip_steps]\n\n            # 3. Sample noise and add it to the clean images for both teacher and student unets\n            # Sample noise z ~ N(0, I) that we'll add to the images\n            noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)\n            # Add noise to the clean images according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            teacher_noisy_images = add_noise(clean_images, noise, teacher_timesteps)\n            student_noisy_images = add_noise(clean_images, noise, student_timesteps)\n\n            # 4. 
Calculate preconditioning and scalings for boundary conditions for the consistency model.\n            teacher_rescaled_timesteps = get_noise_preconditioning(teacher_timesteps, args.noise_precond_type)\n            student_rescaled_timesteps = get_noise_preconditioning(student_timesteps, args.noise_precond_type)\n\n            c_in_teacher = get_input_preconditioning(teacher_timesteps, input_precond_type=args.input_precond_type)\n            c_in_student = get_input_preconditioning(student_timesteps, input_precond_type=args.input_precond_type)\n\n            c_skip_teacher, c_out_teacher = scalings_for_boundary_conditions(teacher_timesteps)\n            c_skip_student, c_out_student = scalings_for_boundary_conditions(student_timesteps)\n\n            c_skip_teacher, c_out_teacher, c_in_teacher = [\n                append_dims(x, clean_images.ndim) for x in [c_skip_teacher, c_out_teacher, c_in_teacher]\n            ]\n            c_skip_student, c_out_student, c_in_student = [\n                append_dims(x, clean_images.ndim) for x in [c_skip_student, c_out_student, c_in_student]\n            ]\n\n            with accelerator.accumulate(unet):\n                # 5. Get the student unet denoising prediction on the student timesteps\n                # Get rng state now to ensure that dropout is synced between the student and teacher models.\n                dropout_state = torch.get_rng_state()\n                student_model_output = unet(\n                    c_in_student * student_noisy_images, student_rescaled_timesteps, class_labels=class_labels\n                ).sample\n                # NOTE: currently only support prediction_type == sample, so no need to convert model_output\n                student_denoise_output = c_skip_student * student_noisy_images + c_out_student * student_model_output\n\n                # 6. 
Get the teacher unet denoising prediction on the teacher timesteps\n                with torch.no_grad(), torch.autocast(\"cuda\", dtype=teacher_dtype):\n                    torch.set_rng_state(dropout_state)\n                    teacher_model_output = teacher_unet(\n                        c_in_teacher * teacher_noisy_images, teacher_rescaled_timesteps, class_labels=class_labels\n                    ).sample\n                    # NOTE: currently only support prediction_type == sample, so no need to convert model_output\n                    teacher_denoise_output = (\n                        c_skip_teacher * teacher_noisy_images + c_out_teacher * teacher_model_output\n                    )\n\n                # 7. Calculate the weighted Pseudo-Huber loss\n                if args.prediction_type == \"sample\":\n                    # Note that the loss weights should be those at the (teacher) timestep indices.\n                    lambda_t = _extract_into_tensor(\n                        timestep_loss_weights, timestep_indices, (bsz,) + (1,) * (clean_images.ndim - 1)\n                    )\n                    loss = lambda_t * (\n                        torch.sqrt(\n                            (student_denoise_output.float() - teacher_denoise_output.float()) ** 2 + args.huber_c**2\n                        )\n                        - args.huber_c\n                    )\n                    loss = loss.mean()\n                else:\n                    raise ValueError(\n                        f\"Unsupported prediction type: {args.prediction_type}. Currently, only `sample` is supported.\"\n                    )\n\n                # 8. 
Backpropagate on the consistency training loss\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                # 9. Update teacher_unet and ema_unet parameters using unet's parameters.\n                teacher_unet.load_state_dict(unet.state_dict())\n                if args.use_ema:\n                    ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    # 10. Recalculate quantities depending on the global step, if necessary.\n                    new_discretization_steps = get_discretization_steps(\n                        global_step,\n                        args.max_train_steps,\n                        s_0=args.discretization_s_0,\n                        s_1=args.discretization_s_1,\n                        constant=args.constant_discretization_steps,\n                    )\n                    current_skip_steps = get_skip_steps(global_step, initial_skip=args.skip_steps)\n                    if current_skip_steps >= new_discretization_steps:\n                        raise ValueError(\n                            f\"The current skip steps is {current_skip_steps}, but should be smaller than the current\"\n                            f\" number of discretization steps {new_discretization_steps}.\"\n                        )\n                    if new_discretization_steps != current_discretization_steps:\n                        (\n                            noise_scheduler,\n                            current_timesteps,\n                            
timestep_weights,\n                            timestep_loss_weights,\n                        ) = recalculate_num_discretization_step_values(new_discretization_steps, current_skip_steps)\n                        current_discretization_steps = new_discretization_steps\n\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        
logger.info(f\"Saved state to {save_path}\")\n\n                    if global_step % args.validation_steps == 0:\n                        # NOTE: since we do not use EMA for the teacher model, the teacher parameters and student\n                        # parameters are the same at this point in time\n                        log_validation(unet, noise_scheduler, args, accelerator, weight_dtype, global_step, \"teacher\")\n                        # teacher_unet.to(dtype=teacher_dtype)\n\n                        if args.use_ema:\n                            # Store the student unet weights and load the EMA weights.\n                            ema_unet.store(unet.parameters())\n                            ema_unet.copy_to(unet.parameters())\n\n                            log_validation(\n                                unet,\n                                noise_scheduler,\n                                args,\n                                accelerator,\n                                weight_dtype,\n                                global_step,\n                                \"ema_student\",\n                            )\n\n                            # Restore student unet weights\n                            ema_unet.restore(unet.parameters())\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n            if args.use_ema:\n                logs[\"ema_decay\"] = ema_unet.cur_decay_value\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n        # progress_bar.close()\n\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        pipeline = ConsistencyModelPipeline(unet=unet, scheduler=noise_scheduler)\n        pipeline.save_pretrained(args.output_dir)\n\n        # If using EMA, save EMA weights as well.\n        if 
args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n            unet.save_pretrained(os.path.join(args.output_dir, \"ema_unet\"))\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/control_lora/README.md",
    "content": "# Control-LoRA inference example\n\nControl-LoRA is introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) by adding low-rank parameter efficient fine tuning to ControlNet. This approach offers a more efficient and compact method to bring model control to a wider variety of consumer GPUs.\n\n## Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd in the example folder  and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n## Inference on SDXL\n\n[stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) provides a set of Control-LoRA weights for SDXL. Here we use the `canny` condition to generate an image from a text prompt and a reference image.\n\n```bash\npython control_lora.py\n```\n\n## Acknowledgements\n\n- [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora)\n- [comfyanonymous/ControlNet-v1-1_fp16_safetensors](https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors)\n- [HighCWu/control-lora-v2](https://github.com/HighCWu/control-lora-v2)"
  },
  {
    "path": "examples/research_projects/control_lora/control_lora.py",
    "content": "import cv2\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom diffusers import (\n    AutoencoderKL,\n    ControlNetModel,\n    StableDiffusionXLControlNetPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.utils import load_image, make_image_grid\n\n\npipe_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\nlora_id = \"stabilityai/control-lora\"\nlora_filename = \"control-LoRAs-rank128/control-lora-canny-rank128.safetensors\"\n\nunet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder=\"unet\", torch_dtype=torch.bfloat16).to(\"cuda\")\ncontrolnet = ControlNetModel.from_unet(unet).to(device=\"cuda\", dtype=torch.bfloat16)\ncontrolnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)\n\nprompt = \"aerial view, a futuristic research complex in a bright foggy jungle, hard lighting\"\nnegative_prompt = \"low quality, bad quality, sketches\"\n\nimage = load_image(\n    \"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png\"\n)\n\ncontrolnet_conditioning_scale = 1.0  # recommended for good generalization\n\nvae = AutoencoderKL.from_pretrained(\"stabilityai/sdxl-vae\", torch_dtype=torch.bfloat16)\npipe = StableDiffusionXLControlNetPipeline.from_pretrained(\n    pipe_id,\n    unet=unet,\n    controlnet=controlnet,\n    vae=vae,\n    torch_dtype=torch.bfloat16,\n    safety_checker=None,\n).to(\"cuda\")\n\nimage = np.array(image)\nimage = cv2.Canny(image, 100, 200)\nimage = image[:, :, None]\nimage = np.concatenate([image, image, image], axis=2)\nimage = Image.fromarray(image)\n\nimages = pipe(\n    prompt,\n    negative_prompt=negative_prompt,\n    image=image,\n    controlnet_conditioning_scale=controlnet_conditioning_scale,\n    num_images_per_prompt=4,\n).images\n\nfinal_image = [image] + images\ngrid = make_image_grid(final_image, 1, 5)\ngrid.save(\"hf-logo_canny.png\")\n"
  },
  {
    "path": "examples/research_projects/controlnet/train_controlnet_webdataset.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport functools\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\nfrom typing import List, Optional, Union\n\nimport accelerate\nimport cv2\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nimport webdataset as wds\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom braceexpand import braceexpand\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import default_collate\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, DPTForDepthEstimation, DPTImageProcessor, PretrainedConfig\nfrom webdataset.tariterators import (\n    base_plus_ext,\n    tar_file_expander,\n    url_opener,\n    valid_sample,\n)\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    ControlNetModel,\n    EulerDiscreteScheduler,\n    StableDiffusionXLControlNetPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom 
diffusers.utils.import_utils import is_xformers_available\n\n\nMAX_SEQ_LENGTH = 77\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.18.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef filter_keys(key_set):\n    def _f(dictionary):\n        return {k: v for k, v in dictionary.items() if k in key_set}\n\n    return _f\n\n\ndef group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):\n    \"\"\"Return function over iterator that groups key, value pairs into samples.\n\n    :param keys: function that splits the key into key and extension (base_plus_ext)\n    :param lcase: convert suffixes to lower case (Default value = True)\n    \"\"\"\n    current_sample = None\n    for filesample in data:\n        assert isinstance(filesample, dict)\n        fname, value = filesample[\"fname\"], filesample[\"data\"]\n        prefix, suffix = keys(fname)\n        if prefix is None:\n            continue\n        if lcase:\n            suffix = suffix.lower()\n        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for\n        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next\n        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset\n        if current_sample is None or prefix != current_sample[\"__key__\"] or suffix in current_sample:\n            if valid_sample(current_sample):\n                yield current_sample\n            current_sample = {\"__key__\": prefix, \"__url__\": filesample[\"__url__\"]}\n        if suffixes is None or suffix in suffixes:\n            current_sample[suffix] = value\n    if valid_sample(current_sample):\n        yield current_sample\n\n\ndef tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):\n    # NOTE this is a re-impl of the webdataset impl with group_by_keys 
that doesn't throw\n    streams = url_opener(src, handler=handler)\n    files = tar_file_expander(streams, handler=handler)\n    samples = group_by_keys_nothrow(files, handler=handler)\n    return samples\n\n\ndef control_transform(image):\n    image = np.array(image)\n\n    low_threshold = 100\n    high_threshold = 200\n\n    image = cv2.Canny(image, low_threshold, high_threshold)\n    image = image[:, :, None]\n    image = np.concatenate([image, image, image], axis=2)\n    control_image = Image.fromarray(image)\n    return control_image\n\n\ndef canny_image_transform(example, resolution=1024):\n    image = example[\"image\"]\n    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)\n    # get crop coordinates\n    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))\n    image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)\n    control_image = control_transform(image)\n\n    image = transforms.ToTensor()(image)\n    image = transforms.Normalize([0.5], [0.5])(image)\n    control_image = transforms.ToTensor()(control_image)\n\n    example[\"image\"] = image\n    example[\"control_image\"] = control_image\n    example[\"crop_coords\"] = (c_top, c_left)\n\n    return example\n\n\ndef depth_image_transform(example, feature_extractor, resolution=1024):\n    image = example[\"image\"]\n    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR)(image)\n    # get crop coordinates\n    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))\n    image = transforms.functional.crop(image, c_top, c_left, resolution, resolution)\n\n    control_image = feature_extractor(images=image, return_tensors=\"pt\").pixel_values.squeeze(0)\n\n    image = transforms.ToTensor()(image)\n    image = transforms.Normalize([0.5], [0.5])(image)\n\n    example[\"image\"] = image\n    
example[\"control_image\"] = control_image\n    example[\"crop_coords\"] = (c_top, c_left)\n\n    return example\n\n\nclass WebdatasetFilter:\n    def __init__(self, min_size=1024, max_pwatermark=0.5):\n        self.min_size = min_size\n        self.max_pwatermark = max_pwatermark\n\n    def __call__(self, x):\n        try:\n            if \"json\" in x:\n                x_json = json.loads(x[\"json\"])\n                filter_size = (x_json.get(\"original_width\", 0.0) or 0.0) >= self.min_size and x_json.get(\n                    \"original_height\", 0\n                ) >= self.min_size\n                filter_watermark = (x_json.get(\"pwatermark\", 1.0) or 1.0) <= self.max_pwatermark\n                return filter_size and filter_watermark\n            else:\n                return False\n        except Exception:\n            return False\n\n\nclass Text2ImageDataset:\n    def __init__(\n        self,\n        train_shards_path_or_url: Union[str, List[str]],\n        eval_shards_path_or_url: Union[str, List[str]],\n        num_train_examples: int,\n        per_gpu_batch_size: int,\n        global_batch_size: int,\n        num_workers: int,\n        resolution: int = 256,\n        center_crop: bool = True,\n        random_flip: bool = False,\n        shuffle_buffer_size: int = 1000,\n        pin_memory: bool = False,\n        persistent_workers: bool = False,\n        control_type: str = \"canny\",\n        feature_extractor: Optional[DPTImageProcessor] = None,\n    ):\n        if not isinstance(train_shards_path_or_url, str):\n            train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]\n            # flatten list using itertools\n            train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))\n\n        if not isinstance(eval_shards_path_or_url, str):\n            eval_shards_path_or_url = [list(braceexpand(urls)) for urls in eval_shards_path_or_url]\n            # flatten list using 
itertools\n            eval_shards_path_or_url = list(itertools.chain.from_iterable(eval_shards_path_or_url))\n\n        def get_orig_size(json):\n            return (int(json.get(\"original_width\", 0.0)), int(json.get(\"original_height\", 0.0)))\n\n        if control_type == \"canny\":\n            image_transform = functools.partial(canny_image_transform, resolution=resolution)\n        elif control_type == \"depth\":\n            image_transform = functools.partial(\n                depth_image_transform, feature_extractor=feature_extractor, resolution=resolution\n            )\n\n        processing_pipeline = [\n            wds.decode(\"pil\", handler=wds.ignore_and_continue),\n            wds.rename(\n                image=\"jpg;png;jpeg;webp\",\n                control_image=\"jpg;png;jpeg;webp\",\n                text=\"text;txt;caption\",\n                orig_size=\"json\",\n                handler=wds.warn_and_continue,\n            ),\n            wds.map(filter_keys({\"image\", \"control_image\", \"text\", \"orig_size\"})),\n            wds.map_dict(orig_size=get_orig_size),\n            wds.map(image_transform),\n            wds.to_tuple(\"image\", \"control_image\", \"text\", \"orig_size\", \"crop_coords\"),\n        ]\n\n        # Create train dataset and loader\n        pipeline = [\n            wds.ResampledShards(train_shards_path_or_url),\n            tarfile_to_samples_nothrow,\n            wds.select(WebdatasetFilter(min_size=512)),\n            wds.shuffle(shuffle_buffer_size),\n            *processing_pipeline,\n            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),\n        ]\n\n        num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker\n        num_batches = num_worker_batches * num_workers\n        num_samples = num_batches * global_batch_size\n\n        # each worker is iterating over this\n        self._train_dataset = 
wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)\n        self._train_dataloader = wds.WebLoader(\n            self._train_dataset,\n            batch_size=None,\n            shuffle=False,\n            num_workers=num_workers,\n            pin_memory=pin_memory,\n            persistent_workers=persistent_workers,\n        )\n        # add meta-data to dataloader instance for convenience\n        self._train_dataloader.num_batches = num_batches\n        self._train_dataloader.num_samples = num_samples\n\n        # Create eval dataset and loader\n        pipeline = [\n            wds.SimpleShardList(eval_shards_path_or_url),\n            wds.split_by_worker,\n            wds.tarfile_to_samples(handler=wds.ignore_and_continue),\n            *processing_pipeline,\n            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),\n        ]\n        self._eval_dataset = wds.DataPipeline(*pipeline)\n        self._eval_dataloader = wds.WebLoader(\n            self._eval_dataset,\n            batch_size=None,\n            shuffle=False,\n            num_workers=num_workers,\n            pin_memory=pin_memory,\n            persistent_workers=persistent_workers,\n        )\n\n    @property\n    def train_dataset(self):\n        return self._train_dataset\n\n    @property\n    def train_dataloader(self):\n        return self._train_dataloader\n\n    @property\n    def eval_dataset(self):\n        return self._eval_dataset\n\n    @property\n    def eval_dataloader(self):\n        return self._eval_dataloader\n\n\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows * cols\n\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\n\ndef log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):\n    logger.info(\"Running validation... 
\")\n\n    controlnet = accelerator.unwrap_model(controlnet)\n\n    pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        unet=unet,\n        controlnet=controlnet,\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    # pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n        validation_image = validation_image.resize((args.resolution, args.resolution))\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with torch.autocast(\"cuda\"):\n                image = pipeline(\n                    validation_prompt, image=validation_image, 
num_inference_steps=20, generator=generator\n                ).images[0]\n            images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = [np.asarray(validation_image)]\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({\"validation\": formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, 
subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i}](./images_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- stable-diffusion-xl\n- stable-diffusion-xl-diffusers\n- text-to-image\n- diffusers\n- controlnet\n- diffusers-training\n- webdataset\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# controlnet-{repo_id}\n\nThese are controlnet weights trained on {base_model} with a new type of conditioning.\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n    
    default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained controlnet model or model identifier from huggingface.co/models.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be\"\n            \" float32 precision.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"controlnet-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        
),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_h\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_w\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=3,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n          
  'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=1,\n        help=(\"Number of subprocesses to use for data loading.\"),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    
parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. 
More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_shards_path_or_url\",\n        type=str,\n        default=None,\n        help=\"Path or URL of the WebDataset shards used for training.\",\n    )\n    parser.add_argument(\n        \"--eval_shards_path_or_url\",\n        type=str,\n        default=None,\n        help=\"Path or URL of the WebDataset shards used for evaluation.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. 
Provide either a matching number of `--validation_prompt`s, a\"\n            \" single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` `args.num_validation_images` times\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"sd_xl_train_controlnet\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers. For\"\n            \" more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    parser.add_argument(\n        \"--control_type\",\n        type=str,\n        default=\"canny\",\n        help=(\"The type of controlnet conditioning image to use. One of `canny` or `depth`. Defaults to `canny`.\"),\n    )\n    parser.add_argument(\n        \"--transformer_layers_per_block\",\n        type=str,\n        default=None,\n        help=(\"Comma-separated number of transformer layers per block. If None, the loaded controlnet configuration is kept.\"),\n    )\n    parser.add_argument(\n        \"--old_style_controlnet\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Use the old style controlnet, which is a single transformer layer with a single head. 
Defaults to False.\"\n        ),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and `--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder.\"\n        )\n\n    return args\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):\n    prompt_embeds_list = []\n\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if 
is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n            )\n\n            # Only the pooled output of the final text encoder is kept; earlier ones are overwritten\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp 
= False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n                token=args.hub_token,\n                private=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision, use_fast=False\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer_2\", revision=args.revision, use_fast=False\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    # noise_scheduler = 
DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n\n    if args.controlnet_model_name_or_path:\n        logger.info(\"Loading existing controlnet weights\")\n        pre_controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)\n    else:\n        logger.info(\"Initializing controlnet weights from unet\")\n        pre_controlnet = ControlNetModel.from_unet(unet)\n\n    if args.transformer_layers_per_block is not None:\n        transformer_layers_per_block = [int(x) for x in args.transformer_layers_per_block.split(\",\")]\n        down_block_types = [\"DownBlock2D\" if l == 0 else \"CrossAttnDownBlock2D\" for l in transformer_layers_per_block]\n        controlnet = ControlNetModel.from_config(\n            pre_controlnet.config,\n            down_block_types=down_block_types,\n            transformer_layers_per_block=transformer_layers_per_block,\n        )\n        controlnet.load_state_dict(pre_controlnet.state_dict(), strict=False)\n        del 
pre_controlnet\n    else:\n        controlnet = pre_controlnet\n\n    if args.control_type == \"depth\":\n        feature_extractor = DPTImageProcessor.from_pretrained(\"Intel/dpt-hybrid-midas\")\n        depth_model = DPTForDepthEstimation.from_pretrained(\"Intel/dpt-hybrid-midas\")\n        depth_model.requires_grad_(False)\n    else:\n        feature_extractor = None\n        depth_model = None\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                i = len(weights) - 1\n\n                while len(weights) > 0:\n                    weights.pop()\n                    model = models[i]\n\n                    sub_dir = \"controlnet\"\n                    model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                    i -= 1\n\n        def load_model_hook(models, input_dir):\n            while len(models) > 0:\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = ControlNetModel.from_pretrained(input_dir, subfolder=\"controlnet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    controlnet.train()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n      
      import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            controlnet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        controlnet.enable_gradient_checkpointing()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if accelerator.unwrap_model(controlnet).dtype != torch.float32:\n        raise ValueError(\n            f\"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. 
{low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = controlnet.parameters()\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae, unet and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    if args.pretrained_vae_model_name_or_path is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    else:\n        vae.to(accelerator.device, dtype=torch.float32)\n    unet.to(accelerator.device, 
dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n    if args.control_type == \"depth\":\n        depth_model.to(accelerator.device, dtype=weight_dtype)\n\n    # Here, we compute not just the text embeddings but also the additional embeddings\n    # needed for the SD XL UNet to operate.\n    def compute_embeddings(\n        prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True\n    ):\n        target_size = (args.resolution, args.resolution)\n        original_sizes = list(map(list, zip(*original_sizes)))\n        crops_coords_top_left = list(map(list, zip(*crop_coords)))\n\n        original_sizes = torch.tensor(original_sizes, dtype=torch.long)\n        crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)\n\n        # crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)\n        prompt_embeds, pooled_prompt_embeds = encode_prompt(\n            prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train\n        )\n        add_text_embeds = pooled_prompt_embeds\n\n        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        # add_time_ids = list(crops_coords_top_left + target_size)\n        add_time_ids = list(target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n        add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)\n        # add_time_ids = torch.cat([torch.tensor(original_sizes, dtype=torch.long), add_time_ids], dim=-1)\n        add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)\n\n        prompt_embeds = prompt_embeds.to(accelerator.device)\n        add_text_embeds = add_text_embeds.to(accelerator.device)\n        unet_added_cond_kwargs = {\"text_embeds\": 
add_text_embeds, \"time_ids\": add_time_ids}\n\n        return {\"prompt_embeds\": prompt_embeds, **unet_added_cond_kwargs}\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    dataset = Text2ImageDataset(\n        train_shards_path_or_url=args.train_shards_path_or_url,\n        eval_shards_path_or_url=args.eval_shards_path_or_url,\n        num_train_examples=args.max_train_samples,\n        per_gpu_batch_size=args.train_batch_size,\n        global_batch_size=args.train_batch_size * accelerator.num_processes,\n        num_workers=args.dataloader_num_workers,\n        resolution=args.resolution,\n        center_crop=False,\n        random_flip=False,\n        shuffle_buffer_size=1000,\n        pin_memory=True,\n        persistent_workers=True,\n        control_type=args.control_type,\n        feature_extractor=feature_extractor,\n    )\n    train_dataloader = dataset.train_dataloader\n\n    # Let's first compute all the embeddings so that we can free up the text encoders\n    # from memory.\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n\n    compute_embeddings_fn = functools.partial(\n        compute_embeddings,\n        proportion_empty_prompts=args.proportion_empty_prompts,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / 
args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    controlnet, optimizer, lr_scheduler = accelerator.prepare(controlnet, optimizer, lr_scheduler)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {train_dataloader.num_batches}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  
Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(controlnet):\n                image, control_image, text, orig_size, crop_coords = batch\n\n                encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)\n                image = image.to(accelerator.device, non_blocking=True)\n                control_image = control_image.to(accelerator.device, non_blocking=True)\n\n                if args.pretrained_vae_model_name_or_path is not None:\n                    pixel_values = image.to(dtype=weight_dtype)\n                    if vae.dtype != weight_dtype:\n                        vae.to(dtype=weight_dtype)\n                else:\n                    pixel_values = image\n\n                # latents = vae.encode(pixel_values).latent_dist.sample()\n                # encode pixel values with batch size of at most 8\n                latents = []\n                for i in range(0, pixel_values.shape[0], 8):\n                    latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())\n                latents = torch.cat(latents, dim=0)\n\n                latents = 
latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n\n                if args.control_type == \"depth\":\n                    control_image = control_image.to(weight_dtype)\n                    with torch.autocast(\"cuda\"):\n                        depth_map = depth_model(control_image).predicted_depth\n                    depth_map = torch.nn.functional.interpolate(\n                        depth_map.unsqueeze(1),\n                        size=image.shape[2:],\n                        mode=\"bicubic\",\n                        align_corners=False,\n                    )\n                    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)\n                    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)\n                    depth_map = (depth_map - depth_min) / (depth_max - depth_min)\n                    control_image = (depth_map * 255.0).to(torch.uint8).float() / 255.0  # hack to match inference\n                    control_image = torch.cat([control_image] * 3, dim=1)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n                sigmas = get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)\n                inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)\n\n                # ControlNet conditioning.\n                
controlnet_image = control_image.to(dtype=weight_dtype)\n                prompt_embeds = encoded_text.pop(\"prompt_embeds\")\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    inp_noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=prompt_embeds,\n                    added_cond_kwargs=encoded_text,\n                    controlnet_cond=controlnet_image,\n                    return_dict=False,\n                )\n\n                # Predict the noise residual\n                model_pred = unet(\n                    inp_noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=prompt_embeds,\n                    added_cond_kwargs=encoded_text,\n                    down_block_additional_residuals=[\n                        sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                    ],\n                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n                ).sample\n\n                model_pred = model_pred * (-sigmas) + noisy_latents\n                weighing = sigmas**-2.0\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = latents  # compute loss against the denoised latents\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n                loss = torch.mean(\n                    (weighing.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1\n                )\n                loss = loss.mean()\n\n                accelerator.backward(loss)\n                if 
accelerator.sync_gradients:\n                    params_to_clip = controlnet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = 
os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            vae, unet, controlnet, args, accelerator, weight_dtype, global_step\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        controlnet = accelerator.unwrap_model(controlnet)\n        controlnet.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/diffusion_dpo/README.md",
    "content": "# Diffusion Model Alignment Using Direct Preference Optimization\n\nThis directory provides LoRA implementations of Diffusion DPO proposed in [DiffusionModel Alignment Using Direct Preference Optimization](https://huggingface.co/papers/2311.12908) by Bram Wallace, Meihua Dang, Rafael Rafailov, Linqi Zhou, Aaron Lou, Senthil Purushwalkam, Stefano Ermon, Caiming Xiong, Shafiq Joty, and Nikhil Naik.\n\nWe provide implementations for both Stable Diffusion (SD) and Stable Diffusion XL (SDXL). The original checkpoints are available at the URLs below:\n\n* [mhdang/dpo-sd1.5-text2image-v1](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1)\n* [mhdang/dpo-sdxl-text2image-v1](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)\n\n> 💡 Note: The scripts are highly experimental and were only tested on low-data regimes. Proceed with caution. Feel free to let us know about your findings via GitHub issues.\n\n## SD training command\n\n```bash\naccelerate launch train_diffusion_dpo.py \\\n  --pretrained_model_name_or_path=stable-diffusion-v1-5/stable-diffusion-v1-5  \\\n  --output_dir=\"diffusion-dpo\" \\\n  --mixed_precision=\"fp16\" \\\n  --dataset_name=kashif/pickascore \\\n  --resolution=512 \\\n  --train_batch_size=16 \\\n  --gradient_accumulation_steps=2 \\\n  --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --rank=8 \\\n  --learning_rate=1e-5 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=10000 \\\n  --checkpointing_steps=2000 \\\n  --run_validation --validation_steps=200 \\\n  --seed=\"0\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub\n```\n\n## SDXL training command\n\n```bash\naccelerate launch train_diffusion_dpo_sdxl.py \\\n  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0  \\\n  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \\\n  --output_dir=\"diffusion-sdxl-dpo\" \\\n  --mixed_precision=\"fp16\" \\\n  
--dataset_name=kashif/pickascore \\\n  --train_batch_size=8 \\\n  --gradient_accumulation_steps=2 \\\n  --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --rank=8 \\\n  --learning_rate=1e-5 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=2000 \\\n  --checkpointing_steps=500 \\\n  --run_validation --validation_steps=50 \\\n  --seed=\"0\" \\\n  --push_to_hub\n```\n\n## SDXL Turbo training command\n\n```bash\naccelerate launch train_diffusion_dpo_sdxl.py \\\n  --pretrained_model_name_or_path=stabilityai/sdxl-turbo \\\n  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \\\n  --output_dir=\"diffusion-sdxl-turbo-dpo\" \\\n  --mixed_precision=\"fp16\" \\\n  --dataset_name=kashif/pickascore \\\n  --train_batch_size=8 \\\n  --gradient_accumulation_steps=2 \\\n  --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --rank=8 \\\n  --learning_rate=1e-5 \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=2000 \\\n  --checkpointing_steps=500 \\\n  --run_validation --validation_steps=50 \\\n  --seed=\"0\" \\\n  --is_turbo --resolution=512 \\\n  --push_to_hub\n```\n\n\n## Acknowledgements\n\nThis is based on the amazing work on Diffusion DPO by [Bram](https://github.com/bram-w): https://github.com/bram-w/trl/blob/dpo/.\n"
  },
  {
    "path": "examples/research_projects/diffusion_dpo/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\npeft\nwandb"
  },
  {
    "path": "examples/research_projects/diffusion_dpo/train_diffusion_dpo.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 bram-w, The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport contextlib\nimport io\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nimport wandb\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, convert_state_dict_to_diffusers\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.25.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\nVALIDATION_PROMPTS = [\n    \"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography\",\n    \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\",\n    \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n    \"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece\",\n]\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef log_validation(args, unet, accelerator, weight_dtype, epoch, is_final_validation=False):\n    logger.info(f\"Running validation... 
\\n Generating images with prompts:\\n {VALIDATION_PROMPTS}.\")\n\n    # create pipeline\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    if not is_final_validation:\n        pipeline.unet = accelerator.unwrap_model(unet)\n    else:\n        pipeline.load_lora_weights(args.output_dir, weight_name=\"pytorch_lora_weights.safetensors\")\n\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    # Compare against None so that a seed of 0 is honored.\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    images = []\n    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()\n\n    for prompt in VALIDATION_PROMPTS:\n        with context:\n            image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]\n            images.append(image)\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    tracker_key: [\n                        wandb.Image(image, caption=f\"{i}: {VALIDATION_PROMPTS[i]}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    # Also log images without the LoRA params for comparison.\n    if is_final_validation:\n        pipeline.disable_lora()\n        no_lora_images = [\n            pipeline(prompt, num_inference_steps=25, generator=generator).images[0] for prompt in 
VALIDATION_PROMPTS\n        ]\n\n        for tracker in accelerator.trackers:\n            if tracker.name == \"tensorboard\":\n                np_images = np.stack([np.asarray(img) for img in no_lora_images])\n                tracker.writer.add_images(\"test_without_lora\", np_images, epoch, dataformats=\"NHWC\")\n            if tracker.name == \"wandb\":\n                tracker.log(\n                    {\n                        \"test_without_lora\": [\n                            wandb.Image(image, caption=f\"{i}: {VALIDATION_PROMPTS[i]}\")\n                            for i, image in enumerate(no_lora_images)\n                        ]\n                    }\n                )\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_split_name\",\n        type=str,\n        default=\"validation\",\n        help=\"Dataset split to be used during training. 
Helpful to specify for conducting experimental runs.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--run_validation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to run validation inference in between training and also after training. Helps to track progress.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"diffusion-dpo-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size to use for VAE encoding of the images for efficient processing.\",\n    )\n    parser.add_argument(\n        
\"--no_hflip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--random_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to random crop the input images to the resolution. If not set, the images will be center-cropped.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--beta_dpo\",\n        type=int,\n        default=2500,\n        help=\"DPO KL divergence penalty.\",\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"sigmoid\",\n        help=\"DPO loss type. Can be one of 'sigmoid' (default), 'ipo', or 'cpo'.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--tracker_name\",\n        type=str,\n        default=\"diffusion-dpo-lora\",\n        help=(\"The name of the tracker to report results to.\"),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None:\n        raise ValueError(\"Must provide a `dataset_name`.\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef tokenize_captions(tokenizer, examples):\n    max_length = tokenizer.model_max_length\n    captions = []\n    for caption in examples[\"caption\"]:\n        captions.append(caption)\n\n    text_inputs = tokenizer(\n        captions, truncation=True, padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n    )\n\n    return text_inputs.input_ids\n\n\n@torch.no_grad()\ndef encode_prompt(text_encoder, input_ids):\n    text_input_ids = input_ids.to(text_encoder.device)\n    attention_mask = None\n\n    prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)\n    prompt_embeds = prompt_embeds[0]\n\n    return prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your 
token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder class\n    text_encoder_cls = 
import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Set up LoRA.\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    # Add adapter and make sure the trainable params are in float32.\n    unet.add_adapter(unet_lora_config)\n    if args.mixed_precision == \"fp16\":\n        for param in unet.parameters():\n            # only upcast 
trainable parameters (LoRA) into fp32\n            if param.requires_grad:\n                param.data = param.to(torch.float32)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here. 
Either are just the unet attn processor layers\n            # or there are the unet and text encoder atten layers\n            unet_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(accelerator.unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionLoraLoaderMixin.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=None,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(accelerator.unwrap_model(unet))):\n                unet_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n        StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to 
fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = load_dataset(\n        args.dataset_name,\n        cache_dir=args.cache_dir,\n        split=args.dataset_split_name,\n    )\n\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(int(args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution),\n            transforms.Lambda(lambda x: x) if args.no_hflip else transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        all_pixel_values = []\n        for col_name in [\"jpg_0\", \"jpg_1\"]:\n            images = [Image.open(io.BytesIO(im_bytes)).convert(\"RGB\") for im_bytes in examples[col_name]]\n            pixel_values = [train_transforms(image) for image in images]\n            all_pixel_values.append(pixel_values)\n\n        # Double on channel dim, jpg_y then jpg_w\n        im_tup_iterator = zip(*all_pixel_values)\n        combined_pixel_values = []\n        for im_tup, label_0 in zip(im_tup_iterator, examples[\"label_0\"]):\n        
    if label_0 == 0:\n                im_tup = im_tup[::-1]\n            combined_im = torch.cat(im_tup, dim=0)  # no batch dim\n            combined_pixel_values.append(combined_im)\n        examples[\"pixel_values\"] = combined_pixel_values\n\n        examples[\"input_ids\"] = tokenize_captions(tokenizer, examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = train_dataset.with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        final_dict = {\"pixel_values\": pixel_values}\n        final_dict[\"input_ids\"] = torch.stack([example[\"input_ids\"] for example in examples])\n        return final_dict\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=collate_fn,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    unet, optimizer, 
train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(args.tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the mos recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    unet.train()\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)\n                
pixel_values = batch[\"pixel_values\"].to(dtype=weight_dtype)\n                feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))\n\n                latents = []\n                for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(\n                        vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()\n                    )\n                latents = torch.cat(latents, dim=0)\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)\n\n                # Sample a random timestep for each image\n                bsz = latents.shape[0] // 2\n                timesteps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long\n                ).repeat(2)\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = encode_prompt(text_encoder, batch[\"input_ids\"]).repeat(2, 1, 1)\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    encoder_hidden_states,\n                ).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n       
         else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Compute losses.\n                model_losses = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))\n                model_losses_w, model_losses_l = model_losses.chunk(2)\n\n                # For logging\n                raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())\n                model_diff = model_losses_w - model_losses_l  # These are both LBS (as is t)\n\n                # Reference model predictions.\n                accelerator.unwrap_model(unet).disable_adapters()\n                with torch.no_grad():\n                    ref_preds = unet(\n                        noisy_model_input,\n                        timesteps,\n                        encoder_hidden_states,\n                    ).sample.detach()\n                    ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction=\"none\")\n                    ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))\n\n                    ref_losses_w, ref_losses_l = ref_loss.chunk(2)\n                    ref_diff = ref_losses_w - ref_losses_l\n                    raw_ref_loss = ref_loss.mean()\n\n                # Re-enable adapters.\n                accelerator.unwrap_model(unet).enable_adapters()\n\n                # Final loss.\n                logits = ref_diff - model_diff\n                if args.loss_type == \"sigmoid\":\n                    loss = -1 * F.logsigmoid(args.beta_dpo * logits).mean()\n                elif args.loss_type == \"hinge\":\n                    loss = torch.relu(1 - args.beta_dpo * logits).mean()\n                elif args.loss_type == \"ipo\":\n                    # The parser defines `--beta_dpo`; there is no `args.beta`.\n                    losses = (logits - 1 / (2 * args.beta_dpo)) ** 2\n                    loss = losses.mean()\n                else:\n       
             raise ValueError(f\"Unknown loss type {args.loss_type}\")\n\n                implicit_acc = (logits > 0).sum().float() / logits.size(0)\n                implicit_acc += 0.5 * (logits == 0).sum().float() / logits.size(0)\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing 
checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.run_validation and global_step % args.validation_steps == 0:\n                        log_validation(\n                            args, unet=unet, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch\n                        )\n\n            logs = {\n                \"loss\": loss.detach().item(),\n                \"raw_model_loss\": raw_model_loss.detach().item(),\n                \"ref_loss\": raw_ref_loss.detach().item(),\n                \"implicit_acc\": implicit_acc.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet = unet.to(torch.float32)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        StableDiffusionLoraLoaderMixin.save_lora_weights(\n            save_directory=args.output_dir, unet_lora_layers=unet_lora_state_dict, text_encoder_lora_layers=None\n        )\n\n        # Final validation?\n        if args.run_validation:\n            log_validation(\n                args,\n                unet=None,\n                accelerator=accelerator,\n         
       weight_dtype=weight_dtype,\n                epoch=epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 bram-w, The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport contextlib\nimport io\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nimport wandb\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionXLLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, convert_state_dict_to_diffusers\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.25.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\nVALIDATION_PROMPTS = [\n    \"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography\",\n    \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\",\n    \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n    \"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece\",\n]\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):\n    logger.info(f\"Running validation... 
\\n Generating images with prompts:\\n {VALIDATION_PROMPTS}.\")\n\n    if is_final_validation:\n        if args.mixed_precision == \"fp16\":\n            vae.to(weight_dtype)\n\n    # create pipeline\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    if not is_final_validation:\n        pipeline.unet = accelerator.unwrap_model(unet)\n    else:\n        pipeline.load_lora_weights(args.output_dir, weight_name=\"pytorch_lora_weights.safetensors\")\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n    images = []\n    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()\n\n    guidance_scale = 5.0\n    num_inference_steps = 25\n    if args.is_turbo:\n        guidance_scale = 0.0\n        num_inference_steps = 4\n    for prompt in VALIDATION_PROMPTS:\n        with context:\n            image = pipeline(\n                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator\n            ).images[0]\n            images.append(image)\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    tracker_key: [\n                        wandb.Image(image, caption=f\"{i}: {VALIDATION_PROMPTS[i]}\") for i, image in enumerate(images)\n                    ]\n                }\n            
)\n\n    # Also log images without the LoRA params for comparison.\n    if is_final_validation:\n        pipeline.disable_lora()\n        no_lora_images = [\n            pipeline(\n                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator\n            ).images[0]\n            for prompt in VALIDATION_PROMPTS\n        ]\n\n        for tracker in accelerator.trackers:\n            if tracker.name == \"tensorboard\":\n                np_images = np.stack([np.asarray(img) for img in no_lora_images])\n                tracker.writer.add_images(\"test_without_lora\", np_images, epoch, dataformats=\"NHWC\")\n            if tracker.name == \"wandb\":\n                tracker.log(\n                    {\n                        \"test_without_lora\": [\n                            wandb.Image(image, caption=f\"{i}: {VALIDATION_PROMPTS[i]}\")\n                            for i, image in enumerate(no_lora_images)\n                        ]\n                    }\n                )\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_split_name\",\n        type=str,\n        default=\"validation\",\n        help=\"Dataset split to be used during training. Helpful to specify for conducting experimental runs.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--run_validation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to run validation inference in between training and also after training. 
Helps to track progress.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"diffusion-dpo-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size to use for VAE encoding of the images for efficient processing.\",\n    )\n    parser.add_argument(\n        \"--no_hflip\",\n        action=\"store_true\",\n        help=\"Disable random horizontal flipping of the input images.\",\n    )\n    parser.add_argument(\n        \"--random_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to randomly crop the input images to the resolution. 
If not set, the images will be center-cropped.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--beta_dpo\",\n        type=int,\n        default=5000,\n        help=\"DPO KL Divergence penalty.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--is_turbo\",\n        action=\"store_true\",\n        help=(\"Use if tuning SDXL Turbo instead of SDXL\"),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--tracker_name\",\n        type=str,\n        default=\"diffusion-dpo-lora-sdxl\",\n        help=(\"The name of the tracker to report results to.\"),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None:\n        raise ValueError(\"Must provide a `dataset_name`.\")\n\n    if args.is_turbo:\n        assert \"turbo\" in args.pretrained_model_name_or_path\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef tokenize_captions(tokenizers, examples):\n    captions = []\n    for caption in examples[\"caption\"]:\n        captions.append(caption)\n\n    tokens_one = tokenizers[0](\n        captions, truncation=True, padding=\"max_length\", max_length=tokenizers[0].model_max_length, return_tensors=\"pt\"\n    ).input_ids\n    tokens_two = tokenizers[1](\n        captions, truncation=True, padding=\"max_length\", max_length=tokenizers[1].model_max_length, return_tensors=\"pt\"\n    ).input_ids\n\n    return tokens_one, tokens_two\n\n\n@torch.no_grad()\ndef encode_prompt(text_encoders, text_input_ids_list):\n    prompt_embeds_list = []\n\n    
for i, text_encoder in enumerate(text_encoders):\n        text_input_ids = text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device),\n            output_hidden_states=True,\n        )\n\n        # We are only ALWAYS interested in the pooled output of the final text encoder\n        pooled_prompt_embeds = prompt_embeds[0]\n        prompt_embeds = prompt_embeds.hidden_states[-2]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n    
    transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    def enforce_zero_terminal_snr(scheduler):\n        # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L93\n        # Original implementation https://huggingface.co/papers/2305.08891\n        # Turbo needs zero terminal SNR\n        # Turbo: 
https://static1.squarespace.com/static/6213c340453c3f502425776e/t/65663480a92fba51d0e1023f/1701197769659/adversarial_diffusion_distillation.pdf\n        # Convert betas to alphas_bar_sqrt\n        alphas = 1 - scheduler.betas\n        alphas_bar = alphas.cumprod(0)\n        alphas_bar_sqrt = alphas_bar.sqrt()\n\n        # Store old values.\n        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()\n        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()\n        # Shift so last timestep is zero.\n        alphas_bar_sqrt -= alphas_bar_sqrt_T\n        # Scale so first timestep is back to old value.\n        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)\n\n        alphas_bar = alphas_bar_sqrt**2\n        alphas = alphas_bar[1:] / alphas_bar[:-1]\n        alphas = torch.cat([alphas_bar[0:1], alphas])\n\n        alphas_cumprod = torch.cumprod(alphas, dim=0)\n        scheduler.alphas_cumprod = alphas_cumprod\n        return\n\n    if args.is_turbo:\n        enforce_zero_terminal_snr(noise_scheduler)\n\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We 
only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet and text_encoders to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # The VAE is always in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n\n    # Set up LoRA.\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    # Add adapter and make sure the trainable params are in float32.\n    unet.add_adapter(unet_lora_config)\n    if args.mixed_precision == \"fp16\":\n        for param in unet.parameters():\n            # only upcast trainable parameters (LoRA) into fp32\n            if param.requires_grad:\n                param.data = param.to(torch.float32)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. 
If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # Only the unet LoRA (attention) layers are trained in this script,\n            # so the unet is the only model expected here; anything else raises below.\n            unet_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(accelerator.unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionXLLoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save)\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(accelerator.unwrap_model(unet))):\n                unet_ = model\n            else:\n                raise ValueError(f\"unexpected load model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)\n        StableDiffusionXLLoraLoaderMixin.load_lora_into_unet(\n            lora_state_dict, 
network_alphas=network_alphas, unet=unet_\n        )\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = load_dataset(\n        args.dataset_name,\n        cache_dir=args.cache_dir,\n        split=args.dataset_split_name,\n    )\n\n    # Preprocessing the datasets.\n    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)\n    train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    to_tensor = transforms.ToTensor()\n    normalize = transforms.Normalize([0.5], 
[0.5])\n\n    def preprocess_train(examples):\n        all_pixel_values = []\n        images = [Image.open(io.BytesIO(im_bytes)).convert(\"RGB\") for im_bytes in examples[\"jpg_0\"]]\n        original_sizes = [(image.height, image.width) for image in images]\n        crop_top_lefts = []\n\n        for col_name in [\"jpg_0\", \"jpg_1\"]:\n            images = [Image.open(io.BytesIO(im_bytes)).convert(\"RGB\") for im_bytes in examples[col_name]]\n            if col_name == \"jpg_1\":\n                # Need to bring down the image to the same resolution.\n                # This seems like the simplest reasonable approach.\n                # \"::-1\" because PIL resize takes (width, height).\n                images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)]\n            pixel_values = [to_tensor(image) for image in images]\n            all_pixel_values.append(pixel_values)\n\n        # Double on channel dim, jpg_y then jpg_w\n        im_tup_iterator = zip(*all_pixel_values)\n        combined_pixel_values = []\n        for im_tup, label_0 in zip(im_tup_iterator, examples[\"label_0\"]):\n            if label_0 == 0:\n                im_tup = im_tup[::-1]\n\n            combined_im = torch.cat(im_tup, dim=0)  # no batch dim\n\n            # Resize.\n            combined_im = train_resize(combined_im)\n\n            # Flipping.\n            if not args.no_hflip and random.random() < 0.5:\n                combined_im = train_flip(combined_im)\n\n            # Cropping.\n            if not args.random_crop:\n                y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))\n                x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))\n                combined_im = train_crop(combined_im)\n            else:\n                y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))\n                combined_im = crop(combined_im, y1, x1, h, w)\n\n            
crop_top_left = (y1, x1)\n            crop_top_lefts.append(crop_top_left)\n            combined_im = normalize(combined_im)\n            combined_pixel_values.append(combined_im)\n\n        examples[\"pixel_values\"] = combined_pixel_values\n        examples[\"original_sizes\"] = original_sizes\n        examples[\"crop_top_lefts\"] = crop_top_lefts\n        tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples)\n        examples[\"input_ids_one\"] = tokens_one\n        examples[\"input_ids_two\"] = tokens_two\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = train_dataset.with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        original_sizes = [example[\"original_sizes\"] for example in examples]\n        crop_top_lefts = [example[\"crop_top_lefts\"] for example in examples]\n        input_ids_one = torch.stack([example[\"input_ids_one\"] for example in examples])\n        input_ids_two = torch.stack([example[\"input_ids_two\"] for example in examples])\n\n        return {\n            \"pixel_values\": pixel_values,\n            \"input_ids_one\": input_ids_one,\n            \"input_ids_two\": input_ids_two,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=collate_fn,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training 
steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(args.tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    
logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    unet.train()\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # (batch_size, 2*channels, h, w) -> (2*batch_size, 
channels, h, w)\n                pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))\n\n                latents = []\n                for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(\n                        vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()\n                    )\n                latents = torch.cat(latents, dim=0)\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)\n\n                # Sample a random timestep for each image\n                bsz = latents.shape[0] // 2\n                timesteps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long\n                ).repeat(2)\n                if args.is_turbo:\n                    # Learn a 4 timestep schedule\n                    timesteps_0_to_3 = timesteps % 4\n                    timesteps = 250 * timesteps_0_to_3 + 249\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # time ids\n                def compute_time_ids(original_size, crops_coords_top_left):\n                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n                    target_size = (args.resolution, args.resolution)\n                    add_time_ids = list(original_size + crops_coords_top_left + target_size)\n                    add_time_ids = 
torch.tensor([add_time_ids])\n                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n                    return add_time_ids\n\n                add_time_ids = torch.cat(\n                    [compute_time_ids(s, c) for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])]\n                ).repeat(2, 1)\n\n                # Get the text embedding for conditioning\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                    [text_encoder_one, text_encoder_two], [batch[\"input_ids_one\"], batch[\"input_ids_two\"]]\n                )\n                prompt_embeds = prompt_embeds.repeat(2, 1, 1)\n                pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    prompt_embeds,\n                    added_cond_kwargs={\"time_ids\": add_time_ids, \"text_embeds\": pooled_prompt_embeds},\n                ).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Compute losses.\n                model_losses = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))\n                model_losses_w, model_losses_l = model_losses.chunk(2)\n\n                # For logging\n                raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())\n      
          model_diff = model_losses_w - model_losses_l  # These are both LBS (as is t)\n\n                # Reference model predictions.\n                accelerator.unwrap_model(unet).disable_adapters()\n                with torch.no_grad():\n                    ref_preds = unet(\n                        noisy_model_input,\n                        timesteps,\n                        prompt_embeds,\n                        added_cond_kwargs={\"time_ids\": add_time_ids, \"text_embeds\": pooled_prompt_embeds},\n                    ).sample\n                    ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction=\"none\")\n                    ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))\n\n                    ref_losses_w, ref_losses_l = ref_loss.chunk(2)\n                    ref_diff = ref_losses_w - ref_losses_l\n                    raw_ref_loss = ref_loss.mean()\n\n                # Re-enable adapters.\n                accelerator.unwrap_model(unet).enable_adapters()\n\n                # Final loss.\n                scale_term = -0.5 * args.beta_dpo\n                inside_term = scale_term * (model_diff - ref_diff)\n                loss = -1 * F.logsigmoid(inside_term).mean()\n\n                implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)\n                implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if 
global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.run_validation and global_step % args.validation_steps == 0:\n                        log_validation(\n                            args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch\n                 
       )\n\n            logs = {\n                \"loss\": loss.detach().item(),\n                \"raw_model_loss\": raw_model_loss.detach().item(),\n                \"ref_loss\": raw_ref_loss.detach().item(),\n                \"implicit_acc\": implicit_acc.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet = unet.to(torch.float32)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        StableDiffusionXLLoraLoaderMixin.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            text_encoder_lora_layers=None,\n            text_encoder_2_lora_layers=None,\n        )\n\n        # Final validation?\n        if args.run_validation:\n            log_validation(\n                args,\n                unet=None,\n                vae=vae,\n                accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                epoch=epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/diffusion_orpo/README.md",
    "content": "This project has moved to a new home: [https://mapo-t2i.github.io/](https://mapo-t2i.github.io/). We formally studied the use of ORPO in the context of diffusion models and open-sourced our codebase, models, and datasets. We also released our paper.\n"
  },
  {
    "path": "examples/research_projects/diffusion_orpo/requirements.txt",
    "content": "datasets\naccelerate\ntransformers\ntorchvision\nwandb\npeft\nwebdataset"
  },
  {
    "path": "examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport contextlib\nimport io\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nimport wandb\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionXLLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version 
of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.25.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\nVALIDATION_PROMPTS = [\n    \"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography\",\n    \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\",\n    \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n    \"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece\",\n]\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):\n    logger.info(f\"Running validation... 
\\n Generating images with prompts:\\n {VALIDATION_PROMPTS}.\")\n\n    if is_final_validation:\n        if args.mixed_precision == \"fp16\":\n            vae.to(weight_dtype)\n\n    # create pipeline\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    if not is_final_validation:\n        pipeline.unet = accelerator.unwrap_model(unet)\n    else:\n        pipeline.load_lora_weights(args.output_dir, weight_name=\"pytorch_lora_weights.safetensors\")\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n    images = []\n    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()\n\n    guidance_scale = 5.0\n    num_inference_steps = 25\n    for prompt in VALIDATION_PROMPTS:\n        with context:\n            image = pipeline(\n                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator\n            ).images[0]\n            images.append(image)\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    tracker_key: [\n                        wandb.Image(image, caption=f\"{i}: {VALIDATION_PROMPTS[i]}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    # Also log images without the LoRA params for comparison.\n    if 
is_final_validation:\n        pipeline.disable_lora()\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n        no_lora_images = [\n            pipeline(\n                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator\n            ).images[0]\n            for prompt in VALIDATION_PROMPTS\n        ]\n\n        for tracker in accelerator.trackers:\n            if tracker.name == \"tensorboard\":\n                np_images = np.stack([np.asarray(img) for img in no_lora_images])\n                tracker.writer.add_images(\"test_without_lora\", np_images, epoch, dataformats=\"NHWC\")\n            if tracker.name == \"wandb\":\n                tracker.log(\n                    {\n                        \"test_without_lora\": [\n                            wandb.Image(image, caption=f\"{i}: {VALIDATION_PROMPTS[i]}\")\n                            for i, image in enumerate(no_lora_images)\n                        ]\n                    }\n                )\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_split_name\",\n        type=str,\n        default=\"validation\",\n        help=\"Dataset split to be used during training. Helpful to specify for conducting experimental runs.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--run_validation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to run validation inference in between training and also after training. 
Helps to track progress.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"diffusion-orpo-lora-sdxl\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size to use for VAE encoding of the images for efficient processing.\",\n    )\n    parser.add_argument(\n        \"--no_hflip\",\n        action=\"store_true\",\n        help=\"Disable random horizontal flipping of the training images.\",\n    )\n    parser.add_argument(\n        \"--random_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to random crop the input images to the resolution. 
If not set, the images will be center-cropped.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--beta_orpo\",\n        type=float,\n        default=0.1,\n        help=\"ORPO contribution factor.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--tracker_name\",\n        type=str,\n        default=\"diffusion-orpo-lora-sdxl\",\n        help=(\"The name of the tracker to report results to.\"),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None:\n        raise ValueError(\"Must provide a `dataset_name`.\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef tokenize_captions(tokenizers, examples):\n    captions = []\n    for caption in examples[\"caption\"]:\n        captions.append(caption)\n\n    tokens_one = tokenizers[0](\n        captions, truncation=True, padding=\"max_length\", max_length=tokenizers[0].model_max_length, return_tensors=\"pt\"\n    ).input_ids\n    tokens_two = tokenizers[1](\n        captions, truncation=True, padding=\"max_length\", max_length=tokenizers[1].model_max_length, return_tensors=\"pt\"\n    ).input_ids\n\n    return tokens_one, tokens_two\n\n\n@torch.no_grad()\ndef encode_prompt(text_encoders, text_input_ids_list):\n    prompt_embeds_list = []\n\n    for i, text_encoder in enumerate(text_encoders):\n        text_input_ids = text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device),\n            output_hidden_states=True,\n        
)\n\n        # We are always only interested in the pooled output of the final text encoder\n        pooled_prompt_embeds = prompt_embeds[0]\n        prompt_embeds = prompt_embeds.hidden_states[-2]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    
# If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = 
AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet and text_encoders to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # The VAE is always in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n\n    # Set up LoRA.\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    # Add adapter and make sure the trainable params are in float32.\n    unet.add_adapter(unet_lora_config)\n    if args.mixed_precision == \"fp16\":\n        for param in unet.parameters():\n            # only upcast trainable parameters (LoRA) into fp32\n            if param.requires_grad:\n                param.data = param.to(torch.float32)\n\n    if 
args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here. 
Either there are just the unet attn processor layers\n            # or there are the unet and text encoder attn layers\n            unet_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(accelerator.unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionXLLoraLoaderMixin.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=None,\n                text_encoder_2_lora_layers=None,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(accelerator.unwrap_model(unet))):\n                unet_ = model\n            else:\n                raise ValueError(f\"unexpected load model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n   
                 f\" {unexpected_keys}. \"\n                )\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = load_dataset(\n        args.dataset_name,\n        cache_dir=args.cache_dir,\n        split=args.dataset_split_name,\n    )\n\n    # Preprocessing the datasets.\n    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)\n    train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    to_tensor = transforms.ToTensor()\n    normalize = 
transforms.Normalize([0.5], [0.5])\n\n    def preprocess_train(examples):\n        all_pixel_values = []\n        images = [Image.open(io.BytesIO(im_bytes)).convert(\"RGB\") for im_bytes in examples[\"jpg_0\"]]\n        original_sizes = [(image.height, image.width) for image in images]\n        crop_top_lefts = []\n\n        for col_name in [\"jpg_0\", \"jpg_1\"]:\n            images = [Image.open(io.BytesIO(im_bytes)).convert(\"RGB\") for im_bytes in examples[col_name]]\n            if col_name == \"jpg_1\":\n                # Need to bring down the image to the same resolution.\n                # This seems like the simplest reasonable approach.\n                # \"::-1\" because PIL resize takes (width, height).\n                images = [image.resize(original_sizes[i][::-1]) for i, image in enumerate(images)]\n            pixel_values = [to_tensor(image) for image in images]\n            all_pixel_values.append(pixel_values)\n\n        # Double on channel dim, jpg_y then jpg_w\n        im_tup_iterator = zip(*all_pixel_values)\n        combined_pixel_values = []\n        for im_tup, label_0 in zip(im_tup_iterator, examples[\"label_0\"]):\n            # We randomize selection and rejection.\n            if label_0 == 0.5:\n                if random.random() < 0.5:\n                    label_0 = 0\n                else:\n                    label_0 = 1\n\n            if label_0 == 0:\n                im_tup = im_tup[::-1]\n\n            combined_im = torch.cat(im_tup, dim=0)  # no batch dim\n\n            # Resize.\n            combined_im = train_resize(combined_im)\n\n            # Flipping.\n            if not args.no_hflip and random.random() < 0.5:\n                combined_im = train_flip(combined_im)\n\n            # Cropping.\n            if not args.random_crop:\n                y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))\n                x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))\n                
combined_im = train_crop(combined_im)\n            else:\n                y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))\n                combined_im = crop(combined_im, y1, x1, h, w)\n\n            crop_top_left = (y1, x1)\n            crop_top_lefts.append(crop_top_left)\n            combined_im = normalize(combined_im)\n            combined_pixel_values.append(combined_im)\n\n        examples[\"pixel_values\"] = combined_pixel_values\n        examples[\"original_sizes\"] = original_sizes\n        examples[\"crop_top_lefts\"] = crop_top_lefts\n        tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], examples)\n        examples[\"input_ids_one\"] = tokens_one\n        examples[\"input_ids_two\"] = tokens_two\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = train_dataset.with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        original_sizes = [example[\"original_sizes\"] for example in examples]\n        crop_top_lefts = [example[\"crop_top_lefts\"] for example in examples]\n        input_ids_one = torch.stack([example[\"input_ids_one\"] for example in examples])\n        input_ids_two = torch.stack([example[\"input_ids_two\"] for example in examples])\n\n        return {\n            \"pixel_values\": pixel_values,\n            \"input_ids_one\": input_ids_one,\n            \"input_ids_two\": input_ids_two,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n    train_dataloader = torch.utils.data.DataLoader(\n      
  train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=collate_fn,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(args.tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    
logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    unet.train()\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)\n                pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))\n\n                latents = []\n                for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(\n                        vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()\n                    )\n                latents = torch.cat(latents, dim=0)\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)\n\n                # Sample a random timestep for each image\n                bsz = latents.shape[0] // 2\n                timesteps = 
torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long\n                ).repeat(2)\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # time ids\n                def compute_time_ids(original_size, crops_coords_top_left):\n                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n                    target_size = (args.resolution, args.resolution)\n                    add_time_ids = list(original_size + crops_coords_top_left + target_size)\n                    add_time_ids = torch.tensor([add_time_ids])\n                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n                    return add_time_ids\n\n                add_time_ids = torch.cat(\n                    [compute_time_ids(s, c) for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])]\n                ).repeat(2, 1)\n\n                # Get the text embedding for conditioning\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                    [text_encoder_one, text_encoder_two], [batch[\"input_ids_one\"], batch[\"input_ids_two\"]]\n                )\n                prompt_embeds = prompt_embeds.repeat(2, 1, 1)\n                pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    prompt_embeds,\n                    added_cond_kwargs={\"time_ids\": add_time_ids, \"text_embeds\": pooled_prompt_embeds},\n                ).sample\n\n                # Get the target for loss depending on the prediction type\n                if 
noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # ODDS ratio loss.\n                # In the diffusion formulation, we're assuming that the MSE loss\n                # approximates the logp.\n                model_losses = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))\n                model_losses_w, model_losses_l = model_losses.chunk(2)\n                log_odds = model_losses_w - model_losses_l\n\n                # Ratio loss.\n                ratio = F.logsigmoid(log_odds)\n                ratio_losses = args.beta_orpo * ratio\n\n                # Full ORPO loss\n                loss = model_losses_w.mean() - ratio_losses.mean()\n\n                # Backprop.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            
checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.run_validation and global_step % args.validation_steps == 0:\n                        log_validation(\n                            args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n   
             break\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet = unet.to(torch.float32)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        StableDiffusionXLLoraLoaderMixin.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            text_encoder_lora_layers=None,\n            text_encoder_2_lora_layers=None,\n        )\n\n        # Final validation?\n        if args.run_validation:\n            log_validation(\n                args,\n                unet=None,\n                vae=vae,\n                accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                epoch=epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport contextlib\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nimport wandb\nimport webdataset as wds\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionXLLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.25.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\nVALIDATION_PROMPTS = [\n    \"portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography\",\n    \"Self-portrait oil painting, a beautiful cyborg with golden hair, 8k\",\n    \"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k\",\n    \"A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece\",\n]\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef log_validation(args, unet, vae, accelerator, weight_dtype, epoch, is_final_validation=False):\n    logger.info(f\"Running validation... 
\\n Generating images with prompts:\\n {VALIDATION_PROMPTS}.\")\n\n    if is_final_validation:\n        if args.mixed_precision == \"fp16\":\n            vae.to(weight_dtype)\n\n    # create pipeline\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    if not is_final_validation:\n        pipeline.unet = accelerator.unwrap_model(unet)\n    else:\n        pipeline.load_lora_weights(args.output_dir, weight_name=\"pytorch_lora_weights.safetensors\")\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference (compare against None so that a seed of 0 is still honored)\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    images = []\n    context = contextlib.nullcontext() if is_final_validation else torch.cuda.amp.autocast()\n\n    guidance_scale = 5.0\n    num_inference_steps = 25\n    for prompt in VALIDATION_PROMPTS:\n        with context:\n            image = pipeline(\n                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator\n            ).images[0]\n            images.append(image)\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    tracker_key: [\n                        wandb.Image(image, caption=f\"{i}: {VALIDATION_PROMPTS[i]}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    # Also log images without the LoRA params for comparison.\n    if 
is_final_validation:\n        pipeline.disable_lora()\n        # compare against None so that a seed of 0 is still honored\n        generator = (\n            torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n        )\n        no_lora_images = [\n            pipeline(\n                prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator\n            ).images[0]\n            for prompt in VALIDATION_PROMPTS\n        ]\n\n        for tracker in accelerator.trackers:\n            if tracker.name == \"tensorboard\":\n                np_images = np.stack([np.asarray(img) for img in no_lora_images])\n                tracker.writer.add_images(\"test_without_lora\", np_images, epoch, dataformats=\"NHWC\")\n            if tracker.name == \"wandb\":\n                tracker.log(\n                    {\n                        \"test_without_lora\": [\n                            wandb.Image(image, caption=f\"{i}: {VALIDATION_PROMPTS[i]}\")\n                            for i, image in enumerate(no_lora_images)\n                        ]\n                    }\n                )\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_path\",\n        type=str,\n        default=\"pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -\",\n    )\n    parser.add_argument(\n        \"--num_train_examples\",\n        type=int,\n        default=1001352,\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--run_validation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to run validation inference in between training and also after training. Helps to track progress.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=200,\n        help=\"Run validation every X steps.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"diffusion-orpo-lora-sdxl\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    
parser.add_argument(\n        \"--vae_encode_batch_size\",\n        type=int,\n        default=8,\n        help=\"Batch size to use for VAE encoding of the images for efficient processing.\",\n    )\n    parser.add_argument(\n        \"--no_hflip\",\n        action=\"store_true\",\n        help=\"Disable random horizontal flipping of the images.\",\n    )\n    parser.add_argument(\n        \"--random_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to random crop the input images to the resolution. If not set, the images will be center-cropped.\"\n        ),\n    )\n    parser.add_argument(\"--global_batch_size\", type=int, default=64, help=\"Total batch size.\")\n    parser.add_argument(\n        \"--per_gpu_batch_size\", type=int, default=8, help=\"Number of samples in a batch for a single GPU.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--beta_orpo\",\n        type=float,\n        default=0.1,\n        help=\"ORPO contribution factor.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--tracker_name\",\n        type=str,\n        default=\"diffusion-orpo-lora-sdxl\",\n        help=(\"The name of the tracker to report results to.\"),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef tokenize_captions(tokenizers, sample):\n    tokens_one = tokenizers[0](\n        sample[\"original_prompt\"],\n        truncation=True,\n        padding=\"max_length\",\n        max_length=tokenizers[0].model_max_length,\n        return_tensors=\"pt\",\n    ).input_ids\n    tokens_two = tokenizers[1](\n        sample[\"original_prompt\"],\n        truncation=True,\n        padding=\"max_length\",\n        max_length=tokenizers[1].model_max_length,\n        return_tensors=\"pt\",\n    ).input_ids\n\n    return tokens_one, tokens_two\n\n\n@torch.no_grad()\ndef encode_prompt(text_encoders, text_input_ids_list):\n    prompt_embeds_list = []\n\n    for i, text_encoder in enumerate(text_encoders):\n        text_input_ids = text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device),\n            output_hidden_states=True,\n        )\n\n        # We are only ALWAYS interested in the pooled output of the final text 
encoder\n        pooled_prompt_embeds = prompt_embeds[0]\n        prompt_embeds = prompt_embeds.hidden_states[-2]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef get_dataset(args):\n    dataset = (\n        wds.WebDataset(args.dataset_path, resampled=True, handler=wds.warn_and_continue)\n        .shuffle(690, handler=wds.warn_and_continue)\n        .decode(\"pil\", handler=wds.warn_and_continue)\n        .rename(\n            original_prompt=\"original_prompt.txt\",\n            jpg_0=\"jpg_0.jpg\",\n            jpg_1=\"jpg_1.jpg\",\n            label_0=\"label_0.txt\",\n            label_1=\"label_1.txt\",\n            handler=wds.warn_and_continue,\n        )\n    )\n    return dataset\n\n\ndef get_loader(args, tokenizer_one, tokenizer_two):\n    # 1,001,352\n    num_batches = math.ceil(args.num_train_examples / args.global_batch_size)\n    num_worker_batches = math.ceil(\n        args.num_train_examples / (args.global_batch_size * args.dataloader_num_workers)\n    )  # per dataloader worker\n    num_batches = num_worker_batches * args.dataloader_num_workers\n    num_samples = num_batches * args.global_batch_size\n\n    dataset = get_dataset(args)\n\n    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)\n    train_crop = transforms.RandomCrop(args.resolution) if args.random_crop else transforms.CenterCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    to_tensor = transforms.ToTensor()\n    normalize = transforms.Normalize([0.5], [0.5])\n\n    def preprocess_images(sample):\n        jpg_0_image = sample[\"jpg_0\"]\n        original_size = (jpg_0_image.height, 
jpg_0_image.width)\n        crop_top_left = []\n\n        jpg_1_image = sample[\"jpg_1\"]\n        # Need to bring down the image to the same resolution.\n        # This seems like the simplest reasonable approach.\n        # \"::-1\" because PIL resize takes (width, height).\n        jpg_1_image = jpg_1_image.resize(original_size[::-1])\n\n        # We randomize selection and rejection.\n        label_0 = sample[\"label_0\"]\n        if sample[\"label_0\"] == 0.5:\n            if random.random() < 0.5:\n                label_0 = 0\n            else:\n                label_0 = 1\n\n        # Double on channel dim: the preferred image (w) first, then the rejected image (l).\n        if label_0 == 0:\n            pixel_values = torch.cat([to_tensor(image) for image in [jpg_1_image, jpg_0_image]])\n        else:\n            pixel_values = torch.cat([to_tensor(image) for image in [jpg_0_image, jpg_1_image]])\n\n        # Resize.\n        combined_im = train_resize(pixel_values)\n\n        # Flipping.\n        if not args.no_hflip and random.random() < 0.5:\n            combined_im = train_flip(combined_im)\n\n        # Cropping.\n        if not args.random_crop:\n            y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))\n            x1 = max(0, int(round((combined_im.shape[2] - args.resolution) / 2.0)))\n            combined_im = train_crop(combined_im)\n        else:\n            y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))\n            combined_im = crop(combined_im, y1, x1, h, w)\n\n        crop_top_left = (y1, x1)\n        combined_im = normalize(combined_im)\n        tokens_one, tokens_two = tokenize_captions([tokenizer_one, tokenizer_two], sample)\n\n        return {\n            \"pixel_values\": combined_im,\n            \"original_size\": original_size,\n            \"crop_top_left\": crop_top_left,\n            \"tokens_one\": tokens_one,\n            \"tokens_two\": tokens_two,\n        }\n\n    def collate_fn(samples):\n        
pixel_values = torch.stack([sample[\"pixel_values\"] for sample in samples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n        original_sizes = [example[\"original_size\"] for example in samples]\n        crop_top_lefts = [example[\"crop_top_left\"] for example in samples]\n        input_ids_one = torch.stack([example[\"tokens_one\"] for example in samples])\n        input_ids_two = torch.stack([example[\"tokens_two\"] for example in samples])\n\n        return {\n            \"pixel_values\": pixel_values,\n            \"input_ids_one\": input_ids_one,\n            \"input_ids_two\": input_ids_two,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n    dataset = dataset.map(preprocess_images, handler=wds.warn_and_continue)\n    dataset = dataset.batched(args.per_gpu_batch_size, partial=False, collation_fn=collate_fn)\n    dataset = dataset.with_epoch(num_worker_batches)\n\n    dataloader = wds.WebLoader(\n        dataset,\n        batch_size=None,\n        shuffle=False,\n        num_workers=args.dataloader_num_workers,\n        pin_memory=True,\n        persistent_workers=True,\n    )\n    # add meta-data to dataloader instance for convenience\n    dataloader.num_batches = num_batches\n    dataloader.num_samples = num_samples\n    return dataloader\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        
mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = 
import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet and text_encoders to device and cast to weight_dtype\n    
unet.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # The VAE is always in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n\n    # Set up LoRA.\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    # Add adapter and make sure the trainable params are in float32.\n    unet.add_adapter(unet_lora_config)\n    if args.mixed_precision == \"fp16\":\n        for param in unet.parameters():\n            # only upcast trainable parameters (LoRA) into fp32\n            if param.requires_grad:\n                param.data = param.to(torch.float32)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here. 
Either there are just the unet attn processor layers\n            # or there are the unet and text encoder attn layers\n            unet_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(accelerator.unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionXLLoraLoaderMixin.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=None,\n                text_encoder_2_lora_layers=None,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(accelerator.unwrap_model(unet))):\n                unet_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionXLLoraLoaderMixin.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n   
                 f\" {unexpected_keys}. \"\n                )\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes\n    train_dataloader = get_loader(args, tokenizer_one=tokenizer_one, tokenizer_two=tokenizer_two)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        
overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(args.tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.per_gpu_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {train_dataloader.num_samples}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.per_gpu_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    unet.train()\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # (batch_size, 2*channels, h, w) -> (2*batch_size, channels, h, w)\n                
pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype, device=accelerator.device, non_blocking=True)\n                feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))\n\n                latents = []\n                for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):\n                    latents.append(\n                        vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()\n                    )\n                latents = torch.cat(latents, dim=0)\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1, 1)\n\n                # Sample a random timestep for each image\n                bsz = latents.shape[0] // 2\n                timesteps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device, dtype=torch.long\n                ).repeat(2)\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # time ids\n                def compute_time_ids(original_size, crops_coords_top_left):\n                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n                    target_size = (args.resolution, args.resolution)\n                    add_time_ids = list(tuple(original_size) + tuple(crops_coords_top_left) + target_size)\n                    add_time_ids = torch.tensor([add_time_ids])\n                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n                    return add_time_ids\n\n                
add_time_ids = torch.cat(\n                    [compute_time_ids(s, c) for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])]\n                ).repeat(2, 1)\n\n                # Get the text embedding for conditioning\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                    [text_encoder_one, text_encoder_two], [batch[\"input_ids_one\"], batch[\"input_ids_two\"]]\n                )\n                prompt_embeds = prompt_embeds.repeat(2, 1, 1)\n                pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1)\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    prompt_embeds,\n                    added_cond_kwargs={\"time_ids\": add_time_ids, \"text_embeds\": pooled_prompt_embeds},\n                ).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # ODDS ratio loss.\n                # In the diffusion formulation, we're assuming that the MSE loss\n                # approximates the logp.\n                model_losses = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))\n                model_losses_w, model_losses_l = model_losses.chunk(2)\n                log_odds = model_losses_w - model_losses_l\n\n                # Ratio loss.\n                ratio = F.logsigmoid(log_odds)\n                ratio_losses = 
args.beta_orpo * ratio\n\n                # Full ORPO loss\n                loss = model_losses_w.mean() - ratio_losses.mean()\n\n                # Backprop.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                        
        for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.run_validation and global_step % args.validation_steps == 0:\n                        log_validation(\n                            args, unet=unet, vae=vae, accelerator=accelerator, weight_dtype=weight_dtype, epoch=epoch\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        unet = unet.to(torch.float32)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        StableDiffusionXLLoraLoaderMixin.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            text_encoder_lora_layers=None,\n            text_encoder_2_lora_layers=None,\n        )\n\n        # Final validation?\n        if args.run_validation:\n            log_validation(\n                args,\n                unet=None,\n                vae=vae,\n                accelerator=accelerator,\n                weight_dtype=weight_dtype,\n                epoch=epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            upload_folder(\n                
repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/dreambooth_inpaint/README.md",
    "content": "# Dreambooth for the inpainting model\n\nThis script was added by @thedarkzeno .\n\nPlease note that this script is not actively maintained, you can open an issue and tag @thedarkzeno or @patil-suraj though.\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-inpainting\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth_inpaint.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=400\n```\n\n### Training with prior-preservation loss\n\nPrior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.\nAccording to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 
200-300 works well for most cases.\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-inpainting\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth_inpaint.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n\n### Training with gradient checkpointing and 8-bit optimizer:\n\nWith the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.\n\nTo install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-inpainting\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth_inpaint.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=2 --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  
--max_train_steps=800\n```\n\n### Fine-tune text encoder with the UNet.\n\nThe script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning the `text_encoder` gives much better results, especially on faces.\nPass the `--train_text_encoder` argument to the script to enable training the `text_encoder`.\n\n___Note: Training the text encoder requires more memory; with this option, training won't fit on a 16GB GPU. It needs at least 24GB VRAM.___\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-inpainting\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth_inpaint.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --train_text_encoder \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --use_8bit_adam \\\n  --gradient_checkpointing \\\n  --learning_rate=2e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n"
  },
  {
    "path": "examples/research_projects/dreambooth_inpaint/requirements.txt",
    "content": "diffusers==0.9.0\naccelerate>=0.16.0\ntorchvision\ntransformers>=4.21.0\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py",
    "content": "import argparse\nimport itertools\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom PIL import Image, ImageDraw\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionInpaintPipeline,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.13.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef prepare_mask_and_masked_image(image, mask):\n    image = np.array(image.convert(\"RGB\"))\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    mask = np.array(mask.convert(\"L\"))\n    mask = mask.astype(np.float32) / 255.0\n    mask = mask[None, None]\n    mask[mask < 0.5] = 0\n    mask[mask >= 0.5] = 1\n    mask = torch.from_numpy(mask)\n\n    masked_image = image * (mask < 0.5)\n\n    return mask, masked_image\n\n\n# generate random masks\ndef random_mask(im_shape, ratio=1, mask_full_image=False):\n    mask = Image.new(\"L\", im_shape, 0)\n    draw = ImageDraw.Draw(mask)\n    size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))\n    # use this to always mask the whole image\n    if mask_full_image:\n        size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))\n    
limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)\n    center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))\n    draw_type = random.randint(0, 1)\n    if draw_type == 0 or mask_full_image:\n        draw.rectangle(\n            (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),\n            fill=255,\n        )\n    else:\n        draw.ellipse(\n            (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),\n            fill=255,\n        )\n\n    return mask\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        
\"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. If not have enough images, additional images will be\"\n            \" sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\"--train_text_encoder\", action=\"store_true\", help=\"Whether to train the text encoder\")\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. 
Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint and are suitable for resuming training\"\n            \" using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.instance_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and tokenizes the prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        size=512,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(\"Instance images root doesn't exist.\")\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = 
list(self.class_data_root.iterdir())\n            self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        self.image_transforms_resize_and_crop = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n            ]\n        )\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        instance_image = self.image_transforms_resize_and_crop(instance_image)\n\n        example[\"PIL_images\"] = instance_image\n        example[\"instance_images\"] = self.image_transforms(instance_image)\n\n        example[\"instance_prompt_ids\"] = self.tokenizer(\n            self.instance_prompt,\n            padding=\"do_not_pad\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n        ).input_ids\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            class_image = self.image_transforms_resize_and_crop(class_image)\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_PIL_images\"] = class_image\n 
           example[\"class_prompt_ids\"] = self.tokenizer(\n                self.class_prompt,\n                padding=\"do_not_pad\",\n                truncation=True,\n                max_length=self.tokenizer.model_max_length,\n            ).input_ids\n\n        return example\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main():\n    args = parse_args()\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=\"tensorboard\",\n        project_config=project_config,\n    )\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. \"\n            \"Please set gradient_accumulation_steps to 1. 
This feature will be supported in the future.\"\n        )\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n            pipeline = StableDiffusionInpaintPipeline.from_pretrained(\n                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(\n                sample_dataset, batch_size=args.sample_batch_size, num_workers=1\n            )\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n            transform_to_pil = transforms.ToPILImage()\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                bsz = len(example[\"prompt\"])\n                fake_images = torch.rand((3, args.resolution, args.resolution))\n                transform_to_pil = transforms.ToPILImage()\n                fake_pil_images = transform_to_pil(fake_images)\n\n                fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)\n\n                images = pipeline(prompt=example[\"prompt\"], mask_image=fake_mask, image=fake_pil_images).images\n\n     
           for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Load models and create wrapper for stable diffusion\n    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n\n    vae.requires_grad_(False)\n    if not args.train_text_encoder:\n        text_encoder.requires_grad_(False)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the 
model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    params_to_optimize = (\n        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()\n    )\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n    )\n\n    def collate_fn(examples):\n        input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n        pixel_values = [example[\"instance_images\"] for example in examples]\n\n        # Concat class and instance examples for prior preservation.\n        # We do this to avoid doing two forward passes.\n        if args.with_prior_preservation:\n            input_ids += [example[\"class_prompt_ids\"] for example in examples]\n            pixel_values += [example[\"class_images\"] for example in examples]\n            pior_pil = [example[\"class_PIL_images\"] for example in examples]\n\n        masks = []\n        masked_images = []\n        for example in examples:\n            pil_image = 
example[\"PIL_images\"]\n            # generate a random mask\n            mask = random_mask(pil_image.size, 1, False)\n            # prepare mask and masked image\n            mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)\n\n            masks.append(mask)\n            masked_images.append(masked_image)\n\n        if args.with_prior_preservation:\n            for pil_image in pior_pil:\n                # generate a random mask\n                mask = random_mask(pil_image.size, 1, False)\n                # prepare mask and masked image\n                mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)\n\n                masks.append(mask)\n                masked_images.append(masked_image)\n\n        pixel_values = torch.stack(pixel_values)\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n        input_ids = tokenizer.pad({\"input_ids\": input_ids}, padding=True, return_tensors=\"pt\").input_ids\n        masks = torch.stack(masks)\n        masked_images = torch.stack(masked_images)\n        batch = {\"input_ids\": input_ids, \"pixel_values\": pixel_values, \"masks\": masks, \"masked_images\": masked_images}\n        return batch\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * 
accelerator.num_processes,\n    )\n\n    if args.train_text_encoder:\n        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n    accelerator.register_for_checkpointing(lr_scheduler)\n\n    weight_dtype = torch.float32\n    if args.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif args.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move text_encode and vae to gpu.\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    vae.to(accelerator.device, dtype=weight_dtype)\n    if not args.train_text_encoder:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"dreambooth\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    
logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Convert masked images to latent space\n                masked_latents = vae.encode(\n                    batch[\"masked_images\"].reshape(batch[\"pixel_values\"].shape).to(dtype=weight_dtype)\n                ).latent_dist.sample()\n                masked_latents = masked_latents * vae.config.scaling_factor\n\n                masks = batch[\"masks\"]\n                # resize the mask to latents shape as we concatenate the mask to the latents\n                mask = torch.stack(\n        
            [\n                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))\n                        for mask in masks\n                    ]\n                )\n                mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # concatenate the noised latents with the mask and the masked latents\n                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Predict the noise residual\n                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and noise_pred into two parts 
and compute the loss on each part separately.\n                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute instance loss\n                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction=\"none\").mean([1, 2, 3]).mean()\n\n                    # Compute prior loss\n                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n                else:\n                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet.parameters(), text_encoder.parameters())\n                        if args.train_text_encoder\n                        else unet.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            
progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        accelerator.wait_for_everyone()\n\n    # Create the pipeline using using the trained modules and save it.\n    if accelerator.is_main_process:\n        pipeline = StableDiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=accelerator.unwrap_model(unet),\n            text_encoder=accelerator.unwrap_model(text_encoder),\n        )\n        pipeline.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py",
    "content": "import argparse\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom PIL import Image, ImageDraw\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel\nfrom diffusers.loaders import AttnProcsLayers\nfrom diffusers.models.attention_processor import LoRAAttnProcessor\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.13.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef prepare_mask_and_masked_image(image, mask):\n    image = np.array(image.convert(\"RGB\"))\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    mask = np.array(mask.convert(\"L\"))\n    mask = mask.astype(np.float32) / 255.0\n    mask = mask[None, None]\n    mask[mask < 0.5] = 0\n    mask[mask >= 0.5] = 1\n    mask = torch.from_numpy(mask)\n\n    masked_image = image * (mask < 0.5)\n\n    return mask, masked_image\n\n\n# generate random masks\ndef random_mask(im_shape, ratio=1, mask_full_image=False):\n    mask = Image.new(\"L\", im_shape, 0)\n    draw = ImageDraw.Draw(mask)\n    size = (random.randint(0, int(im_shape[0] * ratio)), random.randint(0, int(im_shape[1] * ratio)))\n    # use this to always mask the whole image\n    if mask_full_image:\n        size = (int(im_shape[0] * ratio), int(im_shape[1] * ratio))\n    limits = (im_shape[0] - size[0] // 2, im_shape[1] - size[1] // 2)\n    center = (random.randint(size[0] // 2, limits[0]), random.randint(size[1] // 2, limits[1]))\n    draw_type = random.randint(0, 1)\n    if draw_type == 0 or mask_full_image:\n        draw.rectangle(\n            (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),\n            fill=255,\n        )\n    else:\n        draw.ellipse(\n            (center[0] - size[0] // 2, center[1] - size[1] // 2, center[0] + size[0] // 2, center[1] + size[1] // 2),\n            fill=255,\n        )\n\n    return mask\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    
parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images, additional images will be\"\n            \" sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"dreambooth-inpaint-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\"--train_text_encoder\", action=\"store_true\", help=\"Whether to train the text encoder\")\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. 
Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint and are suitable for resuming training\"\n            \" using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.instance_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and tokenizes the prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        size=512,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(\"Instance images root doesn't exist.\")\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = 
Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        self.image_transforms_resize_and_crop = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n            ]\n        )\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        instance_image = self.image_transforms_resize_and_crop(instance_image)\n\n        example[\"PIL_images\"] = instance_image\n        example[\"instance_images\"] = self.image_transforms(instance_image)\n\n        example[\"instance_prompt_ids\"] = self.tokenizer(\n            self.instance_prompt,\n            padding=\"do_not_pad\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n        ).input_ids\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            class_image = self.image_transforms_resize_and_crop(class_image)\n     
       example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_PIL_images\"] = class_image\n            example[\"class_prompt_ids\"] = self.tokenizer(\n                self.class_prompt,\n                padding=\"do_not_pad\",\n                truncation=True,\n                max_length=self.tokenizer.model_max_length,\n            ).input_ids\n\n        return example\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef main():\n    args = parse_args()\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=\"tensorboard\",\n        project_config=accelerator_project_config,\n    )\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. 
\"\n            \"Please set gradient_accumulation_steps to 1. This feature will be supported in the future.\"\n        )\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n            pipeline = StableDiffusionInpaintPipeline.from_pretrained(\n                args.pretrained_model_name_or_path, torch_dtype=torch_dtype, safety_checker=None\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(\n                sample_dataset, batch_size=args.sample_batch_size, num_workers=1\n            )\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n            transform_to_pil = transforms.ToPILImage()\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                bsz = len(example[\"prompt\"])\n                fake_images = torch.rand((3, args.resolution, args.resolution))\n                transform_to_pil = transforms.ToPILImage()\n                fake_pil_images = transform_to_pil(fake_images)\n\n                fake_mask = random_mask((args.resolution, args.resolution), ratio=1, mask_full_image=True)\n\n                images = 
pipeline(prompt=example[\"prompt\"], mask_image=fake_mask, image=fake_pil_images).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Load models and create wrapper for stable diffusion\n    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    weight_dtype = torch.float32\n    if args.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif args.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move text_encode and vae to gpu.\n    # For mixed precision training we cast the 
text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # now we will add new LoRA weights to the attention layers\n    # It's important to realize here how many attention weights will be added and of which sizes\n    # The sizes of the attention layers consist only of two different variables:\n    # 1) - the \"hidden_size\", which is increased according to `unet.config.block_out_channels`.\n    # 2) - the \"cross attention size\", which is set to `unet.config.cross_attention_dim`.\n\n    # Let's first see how many attention processors we will have to set.\n    # For Stable Diffusion, it should be equal to:\n    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12\n    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2\n    # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18\n    # => 32 layers\n\n    # Set correct lora layers\n    lora_attn_procs = {}\n    for name in unet.attn_processors.keys():\n        cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            hidden_size = unet.config.block_out_channels[-1]\n        elif name.startswith(\"up_blocks\"):\n            block_id = int(name[len(\"up_blocks.\")])\n            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n        elif 
name.startswith(\"down_blocks\"):\n            block_id = int(name[len(\"down_blocks.\")])\n            hidden_size = unet.config.block_out_channels[block_id]\n\n        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)\n\n    unet.set_attn_processor(lora_attn_procs)\n    lora_layers = AttnProcsLayers(unet.attn_processors)\n\n    accelerator.register_for_checkpointing(lora_layers)\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    optimizer = optimizer_class(\n        lora_layers.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n    )\n\n    def collate_fn(examples):\n        input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n        pixel_values = [example[\"instance_images\"] for example 
in examples]\n\n        # Concat class and instance examples for prior preservation.\n        # We do this to avoid doing two forward passes.\n        if args.with_prior_preservation:\n            input_ids += [example[\"class_prompt_ids\"] for example in examples]\n            pixel_values += [example[\"class_images\"] for example in examples]\n            pior_pil = [example[\"class_PIL_images\"] for example in examples]\n\n        masks = []\n        masked_images = []\n        for example in examples:\n            pil_image = example[\"PIL_images\"]\n            # generate a random mask\n            mask = random_mask(pil_image.size, 1, False)\n            # prepare mask and masked image\n            mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)\n\n            masks.append(mask)\n            masked_images.append(masked_image)\n\n        if args.with_prior_preservation:\n            for pil_image in pior_pil:\n                # generate a random mask\n                mask = random_mask(pil_image.size, 1, False)\n                # prepare mask and masked image\n                mask, masked_image = prepare_mask_and_masked_image(pil_image, mask)\n\n                masks.append(mask)\n                masked_images.append(masked_image)\n\n        pixel_values = torch.stack(pixel_values)\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n        input_ids = tokenizer.pad({\"input_ids\": input_ids}, padding=True, return_tensors=\"pt\").input_ids\n        masks = torch.stack(masks)\n        masked_images = torch.stack(masked_images)\n        batch = {\"input_ids\": input_ids, \"pixel_values\": pixel_values, \"masks\": masks, \"masked_images\": masked_images}\n        return batch\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn\n    )\n\n    # Scheduler and math around the number of training steps.\n    
overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare everything with our `accelerator`.\n    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        lora_layers, optimizer, train_dataloader, lr_scheduler\n    )\n    # accelerator.register_for_checkpointing(lr_scheduler)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"dreambooth-inpaint-lora\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size 
per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and 
epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Convert masked images to latent space\n                masked_latents = vae.encode(\n                    batch[\"masked_images\"].reshape(batch[\"pixel_values\"].shape).to(dtype=weight_dtype)\n                ).latent_dist.sample()\n                masked_latents = masked_latents * vae.config.scaling_factor\n\n                masks = batch[\"masks\"]\n                # resize the mask to latents shape as we concatenate the mask to the latents\n                mask = torch.stack(\n                    [\n                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))\n                        for mask in masks\n                    ]\n                ).to(dtype=weight_dtype)\n                mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # concatenate the noised latents with the mask and the masked 
latents\n                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Predict the noise residual\n                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.\n                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute instance loss\n                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction=\"none\").mean([1, 2, 3]).mean()\n\n                    # Compute prior loss\n                    prior_loss = F.mse_loss(noise_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n                else:\n                    loss = F.mse_loss(noise_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = lora_layers.parameters()\n                    
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        accelerator.wait_for_everyone()\n\n    # Save the lora layers\n    if accelerator.is_main_process:\n        unet = unet.to(torch.float32)\n        unet.save_attn_procs(args.output_dir)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/flux_lora_quantization/README.md",
    "content": "## LoRA fine-tuning Flux.1 Dev with quantization\n\n> [!NOTE]  \n> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further.\n\nThis example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. The steps below summarize the workflow:\n\n* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.\n  * Although optional, we load T5-xxl in 8-bit precision through `bitsandbytes` to further reduce the memory footprint. \n* `train_dreambooth_lora_flux_miniature.py` takes care of training:\n  * Since we already precomputed the text embeddings, we don't load the text encoders.\n  * We load the VAE, use it to precompute the image latents, and then delete it. \n  * Load the Flux transformer, quantize it with the [NF4 datatype](https://huggingface.co/papers/2305.14314) through `bitsandbytes`, and prepare it for 4-bit training. \n  * Add LoRA adapter layers to it and then ensure they are kept in FP32 precision.\n  * Train!\n\nTo run training in a memory-optimized manner, we additionally use:\n\n* 8-bit Adam\n* Gradient checkpointing \n\nWe have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow. \n\n## Training\n\nEnsure you have installed the required libraries:\n\n```bash\npip install -U transformers accelerate bitsandbytes peft datasets \npip install git+https://github.com/huggingface/diffusers -U\n```\n\nNow, compute the text embeddings:\n\n```bash\npython compute_embeddings.py\n```\n\nIt should create a file named `embeddings.parquet`. We're then ready to launch training. 
First, authenticate so that you can access the Flux.1 Dev model: \n\n```bash\nhf auth login\n```\n\nThen launch:\n\n```bash\naccelerate launch --config_file=accelerate.yaml \\\n  train_dreambooth_lora_flux_miniature.py \\\n  --pretrained_model_name_or_path=\"black-forest-labs/FLUX.1-dev\" \\\n  --data_df_path=\"embeddings.parquet\" \\\n  --output_dir=\"yarn_art_lora_flux_nf4\" \\\n  --mixed_precision=\"fp16\" \\\n  --use_8bit_adam \\\n  --weighting_scheme=\"none\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --repeats=1 \\\n  --learning_rate=1e-4 \\\n  --guidance_scale=1 \\\n  --report_to=\"wandb\" \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --cache_latents \\\n  --rank=4 \\\n  --max_train_steps=700 \\\n  --seed=\"0\"\n```\n\nWe can directly pass a quantized checkpoint path, too:\n\n```diff\n+ --quantized_model_path=\"hf-internal-testing/flux.1-dev-nf4-pkg\"\n```\n\nTraining time varies with the machine; in our case, it took 1.5 hours. It may be possible to speed this up by using `torch.bfloat16`. \n\nWe support training with the DeepSpeed Zero2 optimizer, too. 
To use it, first install DeepSpeed:\n\n```bash\npip install -Uq deepspeed\n```\n\nAnd then launch:\n\n```bash\naccelerate launch --config_file=ds2.yaml \\\n  train_dreambooth_lora_flux_miniature.py \\\n  --pretrained_model_name_or_path=\"black-forest-labs/FLUX.1-dev\" \\\n  --data_df_path=\"embeddings.parquet\" \\\n  --output_dir=\"yarn_art_lora_flux_nf4\" \\\n  --mixed_precision=\"no\" \\\n  --use_8bit_adam \\\n  --weighting_scheme=\"none\" \\\n  --resolution=1024 \\\n  --train_batch_size=1 \\\n  --repeats=1 \\\n  --learning_rate=1e-4 \\\n  --guidance_scale=1 \\\n  --report_to=\"wandb\" \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --cache_latents \\\n  --rank=4 \\\n  --max_train_steps=700 \\\n  --seed=\"0\"\n```\n\n## Inference\n\nWhen loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example:\n\n1. First, load the original model and merge the LoRA params into it:\n\n```py\nfrom diffusers import FluxPipeline \nimport torch \n\nckpt_id = \"black-forest-labs/FLUX.1-dev\"\npipeline = FluxPipeline.from_pretrained(\n    ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16\n)\npipeline.load_lora_weights(\"yarn_art_lora_flux_nf4\", weight_name=\"pytorch_lora_weights.safetensors\")\npipeline.fuse_lora()\npipeline.unload_lora_weights()\n\npipeline.transformer.save_pretrained(\"fused_transformer\")\n```\n\n2. 
Quantize the model and run inference:\n\n```py\nfrom diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig\nimport torch\n\nckpt_id = \"black-forest-labs/FLUX.1-dev\"\nbnb_4bit_compute_dtype = torch.float16\nnf4_config = BitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_quant_type=\"nf4\",\n    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,\n)\ntransformer = FluxTransformer2DModel.from_pretrained(\n    \"fused_transformer\",\n    quantization_config=nf4_config,\n    torch_dtype=bnb_4bit_compute_dtype,\n)\npipeline = AutoPipelineForText2Image.from_pretrained(\n    ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype\n)\npipeline.enable_model_cpu_offload()\n\nimage = pipeline(\n    \"a puppy in a pond, yarn art style\", num_inference_steps=28, guidance_scale=3.5, height=768\n).images[0]\nimage.save(\"yarn_merged.png\")\n```\n\n|   Dequantize, merge, quantize   |   Merging directly into quantized model   |\n|-------|-------|\n| ![Image A](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/merged.png) | ![Image B](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/unmerged.png) |\n\nAs we can see, the result in the first column follows the style more closely.\n"
  },
  {
    "path": "examples/research_projects/flux_lora_quantization/accelerate.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndistributed_type: NO\ndowncast_bf16: 'no'\nenable_cpu_affinity: true\ngpu_ids: all\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "examples/research_projects/flux_lora_quantization/compute_embeddings.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\n\nimport pandas as pd\nimport torch\nfrom datasets import load_dataset\nfrom huggingface_hub.utils import insecure_hashlib\nfrom tqdm.auto import tqdm\nfrom transformers import T5EncoderModel\n\nfrom diffusers import FluxPipeline\n\n\nMAX_SEQ_LENGTH = 77\nOUTPUT_PATH = \"embeddings.parquet\"\n\n\ndef generate_image_hash(image):\n    return insecure_hashlib.sha256(image.tobytes()).hexdigest()\n\n\ndef load_flux_dev_pipeline():\n    id = \"black-forest-labs/FLUX.1-dev\"\n    text_encoder = T5EncoderModel.from_pretrained(id, subfolder=\"text_encoder_2\", load_in_8bit=True, device_map=\"auto\")\n    pipeline = FluxPipeline.from_pretrained(\n        id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map=\"balanced\"\n    )\n    return pipeline\n\n\n@torch.no_grad()\ndef compute_embeddings(pipeline, prompts, max_sequence_length):\n    all_prompt_embeds = []\n    all_pooled_prompt_embeds = []\n    all_text_ids = []\n    for prompt in tqdm(prompts, desc=\"Encoding prompts.\"):\n        (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            text_ids,\n        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)\n        all_prompt_embeds.append(prompt_embeds)\n        
all_pooled_prompt_embeds.append(pooled_prompt_embeds)\n        all_text_ids.append(text_ids)\n\n    max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024\n    print(f\"Max memory allocated: {max_memory:.3f} GB\")\n    return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids\n\n\ndef run(args):\n    dataset = load_dataset(\"Norod78/Yarn-art-style\", split=\"train\")\n    image_prompts = {generate_image_hash(sample[\"image\"]): sample[\"text\"] for sample in dataset}\n    all_prompts = list(image_prompts.values())\n    print(f\"{len(all_prompts)=}\")\n\n    pipeline = load_flux_dev_pipeline()\n    all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(\n        pipeline, all_prompts, args.max_sequence_length\n    )\n\n    data = []\n    for i, (image_hash, _) in enumerate(image_prompts.items()):\n        data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))\n    print(f\"{len(data)=}\")\n\n    # Create a DataFrame\n    embedding_cols = [\"prompt_embeds\", \"pooled_prompt_embeds\", \"text_ids\"]\n    df = pd.DataFrame(data, columns=[\"image_hash\"] + embedding_cols)\n    print(f\"{len(df)=}\")\n\n    # Convert embedding lists to arrays (for proper storage in parquet)\n    for col in embedding_cols:\n        df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())\n\n    # Save the dataframe to a parquet file\n    df.to_parquet(args.output_path)\n    print(f\"Data successfully serialized to {args.output_path}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=MAX_SEQ_LENGTH,\n        help=\"Maximum sequence length to use for computing the embeddings. 
Longer sequences increase the computational cost.\",\n    )\n    parser.add_argument(\"--output_path\", type=str, default=OUTPUT_PATH, help=\"Path to serialize the parquet file.\")\n    args = parser.parse_args()\n\n    run(args)\n"
  },
  {
    "path": "examples/research_projects/flux_lora_quantization/ds2.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  gradient_clipping: 1.0\n  offload_optimizer_device: cpu\n  offload_param_device: cpu\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: 'no'\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false"
  },
  {
    "path": "examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator, DistributedType\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    BitsAndBytesConfig,\n    FlowMatchEulerDiscreteScheduler,\n    FluxPipeline,\n    FluxTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom 
diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    pass\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.31.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    base_model: str = None,\n    instance_prompt=None,\n    repo_folder=None,\n    quantization_config=None,\n):\n    widget_dict = []\n\n    model_description = f\"\"\"\n# Flux DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth LoRA weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).\n\nWas LoRA for the text encoder enabled? 
False.\n\nQuantization config:\n\n```yaml\n{quantization_config}\n```\n\n## Trigger words\n\nYou should use `{instance_prompt}` to trigger the image generation.\n\n## Download model\n\n[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.\n\nFor more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)\n\n## Usage\n\nTODO\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        \"flux\",\n        \"flux-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--quantized_model_path\",\n        type=str,\n        default=None,\n        help=\"Path to the quantized model.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    
)\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--data_df_path\",\n        type=str,\n        default=None,\n        help=(\"Path to the parquet file serialized with compute_embeddings.py.\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=77,\n        help=\"Used for reading the embeddings. Needs to be the same as used during `compute_embeddings.py`.\",\n    )\n\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-dreambooth-lora-nf4\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"Guidance scale passed to the transformer; the FLUX.1 dev variant is a guidance-distilled model.\",\n    )\n\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. 
Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        choices=[\"AdamW\", \"Prodigy\", \"AdEMAMix\"],\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n    parser.add_argument(\n        \"--use_8bit_ademamix\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit AdEMAMix from bitsandbytes.\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  
Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    def __init__(\n        self,\n        data_df_path,\n        dataset_name,\n        size=1024,\n        max_sequence_length=77,\n        center_crop=False,\n    ):\n        # Logistics\n        self.size = size\n        self.center_crop = center_crop\n        self.max_sequence_length = max_sequence_length\n\n        self.data_df_path = Path(data_df_path)\n        if not self.data_df_path.exists():\n            raise ValueError(\"`data_df_path` does not exist.\")\n\n        # Load images.\n        dataset = load_dataset(dataset_name, split=\"train\")\n        instance_images = [sample[\"image\"] for sample in dataset]\n        image_hashes = [self.generate_image_hash(image) for image in instance_images]\n        self.instance_images = instance_images\n        self.image_hashes = image_hashes\n\n        # Image transformations\n        self.pixel_values = self.apply_image_transformations(\n            instance_images=instance_images, size=size, center_crop=center_crop\n        )\n\n        # Map hashes to embeddings.\n        self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)\n\n        self.num_instance_images = len(instance_images)\n        self._length = self.num_instance_images\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, 
index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        image_hash = self.image_hashes[index % self.num_instance_images]\n        prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[image_hash]\n        example[\"instance_images\"] = instance_image\n        example[\"prompt_embeds\"] = prompt_embeds\n        example[\"pooled_prompt_embeds\"] = pooled_prompt_embeds\n        example[\"text_ids\"] = text_ids\n        return example\n\n    def apply_image_transformations(self, instance_images, size, center_crop):\n        pixel_values = []\n\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            pixel_values.append(image)\n\n        return pixel_values\n\n    def convert_to_torch_tensor(self, embeddings: list):\n        
prompt_embeds = embeddings[0]\n        pooled_prompt_embeds = embeddings[1]\n        text_ids = embeddings[2]\n        prompt_embeds = np.array(prompt_embeds).reshape(self.max_sequence_length, 4096)\n        pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(768)\n        text_ids = np.array(text_ids).reshape(self.max_sequence_length, 3)\n        return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds), torch.from_numpy(text_ids)\n\n    def map_image_hash_embedding(self, data_df_path):\n        hashes_df = pd.read_parquet(data_df_path)\n        data_dict = {}\n        for i, row in hashes_df.iterrows():\n            embeddings = [row[\"prompt_embeds\"], row[\"pooled_prompt_embeds\"], row[\"text_ids\"]]\n            prompt_embeds, pooled_prompt_embeds, text_ids = self.convert_to_torch_tensor(embeddings=embeddings)\n            data_dict.update({row[\"image_hash\"]: (prompt_embeds, pooled_prompt_embeds, text_ids)})\n        return data_dict\n\n    def generate_image_hash(self, image):\n        return insecure_hashlib.sha256(image.tobytes()).hexdigest()\n\n\ndef collate_fn(examples):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompt_embeds = [example[\"prompt_embeds\"] for example in examples]\n    pooled_prompt_embeds = [example[\"pooled_prompt_embeds\"] for example in examples]\n    text_ids = [example[\"text_ids\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    prompt_embeds = torch.stack(prompt_embeds)\n    pooled_prompt_embeds = torch.stack(pooled_prompt_embeds)\n    text_ids = torch.stack(text_ids)[0]  # just 2D tensor\n\n    batch = {\n        \"pixel_values\": pixel_values,\n        \"prompt_embeds\": prompt_embeds,\n        \"pooled_prompt_embeds\": pooled_prompt_embeds,\n        \"text_ids\": text_ids,\n    }\n    return batch\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and 
args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        
diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    bnb_4bit_compute_dtype = torch.float32\n    if args.mixed_precision == \"fp16\":\n        bnb_4bit_compute_dtype = torch.float16\n    elif args.mixed_precision == \"bf16\":\n        bnb_4bit_compute_dtype = torch.bfloat16\n    if args.quantized_model_path is not None:\n        transformer = FluxTransformer2DModel.from_pretrained(\n            args.quantized_model_path,\n            subfolder=\"transformer\",\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=bnb_4bit_compute_dtype,\n        )\n    else:\n        nf4_config = BitsAndBytesConfig(\n            load_in_4bit=True,\n            bnb_4bit_quant_type=\"nf4\",\n            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,\n        )\n        transformer = FluxTransformer2DModel.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"transformer\",\n            revision=args.revision,\n            variant=args.variant,\n            quantization_config=nf4_config,\n            
torch_dtype=bnb_4bit_compute_dtype,\n        )\n    transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)\n\n    # We only train the additional adapter LoRA layers\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=weight_dtype)\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    # now we will add new LoRA weights to the attention layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n\n            for model in models:\n                if 
isinstance(unwrap_model(model), type(unwrap_model(transformer))):\n                    model = unwrap_model(model)\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                if weights:\n                    weights.pop()\n\n            FluxPipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n                text_encoder_lora_layers=None,\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n            while len(models) > 0:\n                model = models.pop()\n\n                if isinstance(model, type(unwrap_model(transformer))):\n                    transformer_ = model\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n        else:\n            if args.quantized_model_path is not None:\n                transformer_ = FluxTransformer2DModel.from_pretrained(\n                    args.quantized_model_path,\n                    subfolder=\"transformer\",\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=bnb_4bit_compute_dtype,\n                )\n            else:\n                nf4_config = BitsAndBytesConfig(\n                    load_in_4bit=True,\n                    bnb_4bit_quant_type=\"nf4\",\n                    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,\n                )\n                transformer_ = FluxTransformer2DModel.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    subfolder=\"transformer\",\n                    
revision=args.revision,\n                    variant=args.variant,\n                    quantization_config=nf4_config,\n                    torch_dtype=bnb_4bit_compute_dtype,\n                )\n            transformer_ = prepare_model_for_kbit_training(transformer_, use_gradient_checkpointing=False)\n            transformer_.add_adapter(transformer_lora_config)\n\n        lora_state_dict = FluxPipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n\n    # Optimization parameters\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.use_8bit_ademamix and not args.optimizer.lower() == \"ademamix\":\n        logger.warning(\n            f\"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    elif args.optimizer.lower() == \"ademamix\":\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`.\"\n            )\n        if args.use_8bit_ademamix:\n            optimizer_class = bnb.optim.AdEMAMix8bit\n        else:\n            optimizer_class = bnb.optim.AdEMAMix\n\n        optimizer = optimizer_class(params_to_optimize)\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        data_df_path=args.data_df_path,\n        dataset_name=\"Norod78/Yarn-art-style\",\n        size=args.resolution,\n        max_sequence_length=args.max_sequence_length,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=collate_fn,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    vae_config_shift_factor = vae.config.shift_factor\n    vae_config_scaling_factor = vae.config.scaling_factor\n    vae_config_block_out_channels = vae.config.block_out_channels\n    if args.cache_latents:\n        latents_cache = []\n        for batch in tqdm(train_dataloader, desc=\"Caching latents\"):\n            with torch.no_grad():\n                batch[\"pixel_values\"] = batch[\"pixel_values\"].to(\n                    accelerator.device, non_blocking=True, dtype=weight_dtype\n                )\n                latents_cache.append(vae.encode(batch[\"pixel_values\"]).latent_dist)\n\n        del vae\n        free_memory()\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = 
args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-flux-dev-lora-nf4\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            with accelerator.accumulate(models_to_accumulate):\n                # Convert images to latent space\n                if args.cache_latents:\n                    model_input = latents_cache[step].sample()\n                else:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)\n\n                latent_image_ids = FluxPipeline._prepare_latent_image_ids(\n                    model_input.shape[0],\n                    model_input.shape[2] // 2,\n                    model_input.shape[3] // 2,\n                    accelerator.device,\n                    weight_dtype,\n                )\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    
mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                # zt = (1 - texp) * x + texp * z1\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                packed_noisy_model_input = FluxPipeline._pack_latents(\n                    noisy_model_input,\n                    batch_size=model_input.shape[0],\n                    num_channels_latents=model_input.shape[1],\n                    height=model_input.shape[2],\n                    width=model_input.shape[3],\n                )\n\n                # handle guidance\n                if unwrap_model(transformer).config.guidance_embeds:\n                    guidance = torch.tensor([args.guidance_scale], device=accelerator.device)\n                    guidance = guidance.expand(model_input.shape[0])\n                else:\n                    guidance = None\n\n                # Predict the noise\n                prompt_embeds = batch[\"prompt_embeds\"].to(device=accelerator.device, dtype=weight_dtype)\n                pooled_prompt_embeds = batch[\"pooled_prompt_embeds\"].to(device=accelerator.device, dtype=weight_dtype)\n                text_ids = batch[\"text_ids\"].to(device=accelerator.device, dtype=weight_dtype)\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,\n                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    
pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    return_dict=False,\n                )[0]\n                model_pred = FluxPipeline._unpack_latents(\n                    model_pred,\n                    height=model_input.shape[2] * vae_scale_factor,\n                    width=model_input.shape[3] * vae_scale_factor,\n                    vae_scale_factor=vae_scale_factor,\n                )\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n                accelerator.backward(loss)\n\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the 
`checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = 
unwrap_model(transformer)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        FluxPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n            text_encoder_lora_layers=None,\n        )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                base_model=args.pretrained_model_name_or_path,\n                instance_prompt=None,\n                repo_folder=args.output_dir,\n                quantization_config=transformer.config[\"quantization_config\"],\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/geodiff/README.md",
    "content": "# GeoDiff\n\n> [!TIP]\n> This notebook is not actively maintained by the Diffusers team. For any questions or comments, please contact [natolambert](https://twitter.com/natolambert).\n\nThis is an experimental research notebook demonstrating how to generate stable 3D structures of molecules with [GeoDiff](https://github.com/MinkaiXu/GeoDiff) and Diffusers.\n"
  },
  {
    "path": "examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"F88mignPnalS\"\n   },\n   \"source\": [\n    \"# Introduction\\n\",\n    \"\\n\",\n    \"This colab is design to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\\n\",\n    \"The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\\n\",\n    \"\\n\",\n    \"The goal is to generate physically accurate molecules. Given the input of a molecule graph (atom and bond structures with their connectivity -- in the form of a 2d graph). What we want to generate is a stable 3d structure of the molecule.\\n\",\n    \"\\n\",\n    \"This colab uses GEOM datasets that have multiple 3d targets per configuration, which provide more compelling targets for generative methods.\\n\",\n    \"\\n\",\n    \"> Colab made by [natolambert](https://twitter.com/natolambert).\\n\",\n    \"\\n\",\n    \"![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"7cnwXMocnuzB\"\n   },\n   \"source\": [\n    \"## Installations\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"ff9SxWnaNId9\"\n   },\n   \"source\": [\n    \"### Install Conda\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"1g_6zOabItDk\"\n   },\n   \"source\": [\n    \"Here we check the `cuda` version of colab. 
When this was built, the version was always 11.1, which impacts some installation decisions below.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"K0ofXobG5Y-X\",\n    \"outputId\": \"572c3d25-6f19-4c1e-83f5-a1d084a3207f\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"nvcc: NVIDIA (R) Cuda compiler driver\\n\",\n      \"Copyright (c) 2005-2021 NVIDIA Corporation\\n\",\n      \"Built on Sun_Feb_14_21:12:58_PST_2021\\n\",\n      \"Cuda compilation tools, release 11.2, V11.2.152\\n\",\n      \"Build cuda_11.2.r11.2/compiler.29618528_0\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!nvcc --version\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"VfthW90vI0nw\"\n   },\n   \"source\": [\n    \"Install Conda for some more complex dependencies for geometric networks.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"2WNFzSnbiE0k\",\n    \"outputId\": \"690d0d4d-9d0a-4ead-c6dc-086f113f532f\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\u001b[33m\\n\",\n      \"\\u001b[0m\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!pip install -q condacolab\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"NUsbWYCUI7Km\"\n   },\n   \"source\": [\n    \"Setup Conda\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"FZelreINdmd0\",\n    \"outputId\": \"635f0cb8-0af4-499f-e0a4-b3790cb12e9f\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"✨🍰✨ Everything looks OK!\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import condacolab\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"condacolab.install()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"JzDHaPU7I9Sn\"\n   },\n   \"source\": [\n    \"Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"JMxRjHhL7w8V\",\n    \"outputId\": \"6ed511b3-9262-49e8-b340-08e76b05ebd8\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Collecting package metadata (current_repodata.json): - \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\bdone\\n\",\n      \"Solving environment: \\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ 
\\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\bdone\\n\",\n      \"\\n\",\n      \"## Package Plan ##\\n\",\n      \"\\n\",\n      \"  environment location: /usr/local\\n\",\n      \"\\n\",\n      \"  added / updated specs:\\n\",\n      \"    - cudatoolkit=11.1\\n\",\n      \"    - pytorch\\n\",\n      \"    - torchaudio\\n\",\n      \"    - torchvision\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"The following packages will be downloaded:\\n\",\n      \"\\n\",\n      \"    package                    |            build\\n\",\n      \"    ---------------------------|-----------------\\n\",\n      \"    conda-22.9.0               |   py37h89c1867_1         960 KB  conda-forge\\n\",\n      \"    ------------------------------------------------------------\\n\",\n      \"                                           Total:         960 KB\\n\",\n      \"\\n\",\n      \"The following packages will be UPDATED:\\n\",\n      \"\\n\",\n      \"  conda                               4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"Downloading and Extracting Packages\\n\",\n      \"conda-22.9.0         | 960 KB    | : 100% 1.0/1 [00:00<00:00,  4.15it/s]\\n\",\n      \"Preparing transaction: / \\b\\bdone\\n\",\n      \"Verifying transaction: \\\\ \\b\\bdone\\n\",\n      \"Executing transaction: / \\b\\bdone\\n\",\n      \"Retrieving notices: ...working... 
done\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\\n\",\n    \"# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"QDS6FPZ0Tu5b\"\n   },\n   \"source\": [\n    \"Need to remove a pathspec for colab that specifies the incorrect cuda version.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"dq1lxR10TtrR\",\n    \"outputId\": \"ed9c5a71-b449-418f-abb7-072b74e7f6c8\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!rm /usr/local/conda-meta/pinned\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"Z1L3DdZOJB30\"\n   },\n   \"source\": [\n    \"Install torch geometric (used in the model later)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"D5ukfCOWfjzK\",\n    \"outputId\": \"8437485a-5aa6-4d53-8f7f-23517ac1ace6\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Collecting package metadata (current_repodata.json): - \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\bdone\\n\",\n      \"Solving environment: | \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ 
\\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\bdone\\n\",\n      \"\\n\",\n      \"## Package Plan ##\\n\",\n      \"\\n\",\n      \"  environment location: /usr/local\\n\",\n      \"\\n\",\n      \"  added / updated specs:\\n\",\n      \"    - pytorch-geometric=1.7.2\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"The following packages will be downloaded:\\n\",\n      \"\\n\",\n      \"    package                    |            build\\n\",\n      \"    ---------------------------|-----------------\\n\",\n      \"    decorator-4.4.2            |             py_0          11 KB  conda-forge\\n\",\n      \"    googledrivedownloader-0.4  |     pyhd3deb0d_1           7 KB  conda-forge\\n\",\n      \"    jinja2-3.1.2               |     pyhd8ed1ab_1          99 KB  conda-forge\\n\",\n      \"    joblib-1.2.0               |     pyhd8ed1ab_0         205 KB  conda-forge\\n\",\n      \"    markupsafe-2.1.1           |   
py37h540881e_1          22 KB  conda-forge\\n\",\n      \"    networkx-2.5.1             |     pyhd8ed1ab_0         1.2 MB  conda-forge\\n\",\n      \"    pandas-1.2.3               |   py37hdc94413_0        11.8 MB  conda-forge\\n\",\n      \"    pyparsing-3.0.9            |     pyhd8ed1ab_0          79 KB  conda-forge\\n\",\n      \"    python-dateutil-2.8.2      |     pyhd8ed1ab_0         240 KB  conda-forge\\n\",\n      \"    python-louvain-0.15        |     pyhd8ed1ab_1          13 KB  conda-forge\\n\",\n      \"    pytorch-cluster-1.5.9      |py37_torch_1.8.0_cu111         1.2 MB  rusty1s\\n\",\n      \"    pytorch-geometric-1.7.2    |py37_torch_1.8.0_cu111         445 KB  rusty1s\\n\",\n      \"    pytorch-scatter-2.0.8      |py37_torch_1.8.0_cu111         6.1 MB  rusty1s\\n\",\n      \"    pytorch-sparse-0.6.12      |py37_torch_1.8.0_cu111         2.9 MB  rusty1s\\n\",\n      \"    pytorch-spline-conv-1.2.1  |py37_torch_1.8.0_cu111         736 KB  rusty1s\\n\",\n      \"    pytz-2022.4                |     pyhd8ed1ab_0         232 KB  conda-forge\\n\",\n      \"    scikit-learn-1.0.2         |   py37hf9e9bfc_0         7.8 MB  conda-forge\\n\",\n      \"    scipy-1.7.3                |   py37hf2a6cf1_0        21.8 MB  conda-forge\\n\",\n      \"    setuptools-59.8.0          |   py37h89c1867_1         1.0 MB  conda-forge\\n\",\n      \"    threadpoolctl-3.1.0        |     pyh8a188c0_0          18 KB  conda-forge\\n\",\n      \"    ------------------------------------------------------------\\n\",\n      \"                                           Total:        55.9 MB\\n\",\n      \"\\n\",\n      \"The following NEW packages will be INSTALLED:\\n\",\n      \"\\n\",\n      \"  decorator          conda-forge/noarch::decorator-4.4.2-py_0 None\\n\",\n      \"  googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\\n\",\n      \"  jinja2             conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\\n\",\n      \"  joblib            
 conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\\n\",\n      \"  markupsafe         conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\\n\",\n      \"  networkx           conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\\n\",\n      \"  pandas             conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\\n\",\n      \"  pyparsing          conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\\n\",\n      \"  python-dateutil    conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\\n\",\n      \"  python-louvain     conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\\n\",\n      \"  pytorch-cluster    rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\\n\",\n      \"  pytorch-geometric  rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\\n\",\n      \"  pytorch-scatter    rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\\n\",\n      \"  pytorch-sparse     rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\\n\",\n      \"  pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\\n\",\n      \"  pytz               conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\\n\",\n      \"  scikit-learn       conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\\n\",\n      \"  scipy              conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\\n\",\n      \"  threadpoolctl      conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\\n\",\n      \"\\n\",\n      \"The following packages will be DOWNGRADED:\\n\",\n      \"\\n\",\n      \"  setuptools                          65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"\\n\",\n      \"Downloading and Extracting Packages\\n\",\n      \"scikit-learn-1.0.2   | 7.8 MB    | : 100% 1.0/1 [00:01<00:00,  1.37s/it]              \\n\",\n      \"pytorch-scatter-2.0. 
| 6.1 MB    | : 100% 1.0/1 [00:06<00:00,  6.18s/it]\\n\",\n      \"pytorch-geometric-1. | 445 KB    | : 100% 1.0/1 [00:02<00:00,  2.53s/it]\\n\",\n      \"scipy-1.7.3          | 21.8 MB   | : 100% 1.0/1 [00:03<00:00,  3.06s/it]\\n\",\n      \"python-dateutil-2.8. | 240 KB    | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\\n\",\n      \"pytorch-spline-conv- | 736 KB    | : 100% 1.0/1 [00:01<00:00,  1.00s/it]\\n\",\n      \"pytorch-sparse-0.6.1 | 2.9 MB    | : 100% 1.0/1 [00:07<00:00,  7.51s/it]\\n\",\n      \"pyparsing-3.0.9      | 79 KB     | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\\n\",\n      \"pytorch-cluster-1.5. | 1.2 MB    | : 100% 1.0/1 [00:02<00:00,  2.78s/it]\\n\",\n      \"jinja2-3.1.2         | 99 KB     | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\\n\",\n      \"decorator-4.4.2      | 11 KB     | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\\n\",\n      \"joblib-1.2.0         | 205 KB    | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\\n\",\n      \"pytz-2022.4          | 232 KB    | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\\n\",\n      \"python-louvain-0.15  | 13 KB     | : 100% 1.0/1 [00:00<00:00,  3.34it/s]\\n\",\n      \"googledrivedownloade | 7 KB      | : 100% 1.0/1 [00:00<00:00,  3.33it/s]\\n\",\n      \"threadpoolctl-3.1.0  | 18 KB     | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\\n\",\n      \"markupsafe-2.1.1     | 22 KB     | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\\n\",\n      \"pandas-1.2.3         | 11.8 MB   | : 100% 1.0/1 [00:02<00:00,  2.08s/it]               \\n\",\n      \"networkx-2.5.1       | 1.2 MB    | : 100% 1.0/1 [00:01<00:00,  1.39s/it]\\n\",\n      \"setuptools-59.8.0    | 1.0 MB    | : 100% 1.0/1 [00:00<00:00,  4.25it/s]\\n\",\n      \"Preparing transaction: / \\b\\b- \\b\\b\\\\ \\b\\bdone\\n\",\n      \"Verifying transaction: / \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\bdone\\n\",\n      \"Executing transaction: / \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| 
\\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\b\\\\ \\b\\b| \\b\\b/ \\b\\b- \\b\\bdone\\n\",\n      \"Retrieving notices: ...working... done\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!conda install -c rusty1s pytorch-geometric=1.7.2\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"ppxv6Mdkalbc\"\n   },\n   \"source\": [\n    \"### Install Diffusers\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"mgQA_XN-XGY2\",\n    \"outputId\": \"85392615-b6a4-4052-9d2a-79604be62c94\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/content\\n\",\n      \"Cloning into 'diffusers'...\\n\",\n      \"remote: Enumerating objects: 9298, done.\\u001b[K\\n\",\n      \"remote: Counting objects: 100% (40/40), done.\\u001b[K\\n\",\n      \"remote: Compressing objects: 100% (23/23), done.\\u001b[K\\n\",\n      \"remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\\u001b[K\\n\",\n      \"Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\\n\",\n      \"Resolving deltas: 100% (6168/6168), done.\\n\",\n      \"  Installing build dependencies ... \\u001b[?25l\\u001b[?25hdone\\n\",\n      \"  Getting requirements to build wheel ... \\u001b[?25l\\u001b[?25hdone\\n\",\n      \"  Preparing metadata (pyproject.toml) ... 
\\u001b[?25l\\u001b[?25hdone\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m757.0/757.0 kB\\u001b[0m \\u001b[31m52.8 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m163.5/163.5 kB\\u001b[0m \\u001b[31m21.9 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m40.8/40.8 kB\\u001b[0m \\u001b[31m5.5 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m596.3/596.3 kB\\u001b[0m \\u001b[31m51.7 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25h  Building wheel for diffusers (pyproject.toml) ... \\u001b[?25l\\u001b[?25hdone\\n\",\n      \"\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\u001b[33m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m432.7/432.7 kB\\u001b[0m \\u001b[31m36.8 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m5.3/5.3 MB\\u001b[0m \\u001b[31m90.9 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m35.3/35.3 MB\\u001b[0m \\u001b[31m39.7 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m115.1/115.1 kB\\u001b[0m \\u001b[31m16.3 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m948.0/948.0 kB\\u001b[0m \\u001b[31m63.6 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m212.2/212.2 kB\\u001b[0m \\u001b[31m21.5 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m95.8/95.8 kB\\u001b[0m \\u001b[31m12.8 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m140.8/140.8 kB\\u001b[0m \\u001b[31m18.8 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m7.6/7.6 MB\\u001b[0m \\u001b[31m104.3 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m148.0/148.0 kB\\u001b[0m \\u001b[31m20.8 MB/s\\u001b[0m eta 
\\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m231.3/231.3 kB\\u001b[0m \\u001b[31m30.0 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m94.8/94.8 kB\\u001b[0m \\u001b[31m14.0 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m58.8/58.8 kB\\u001b[0m \\u001b[31m8.0 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25h\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\u001b[33m\\n\",\n      \"\\u001b[0m\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%cd /content\\n\",\n    \"\\n\",\n    \"# install latest HF diffusers (will update to the release once added)\\n\",\n    \"!git clone https://github.com/huggingface/diffusers.git\\n\",\n    \"!pip install -q /content/diffusers\\n\",\n    \"\\n\",\n    \"# dependencies for diffusers\\n\",\n    \"!pip install -q datasets transformers\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"LZO6AJKuJKO8\"\n   },\n   \"source\": [\n    \"Check that torch is installed correctly and utilizing the GPU in the colab\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 53\n    },\n    \"id\": \"gZt7BNi1e1PA\",\n    \"outputId\": \"a0e1832c-9c02-49aa-cff8-1339e6cdc889\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"True\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      
\"application/vnd.google.colaboratory.intrinsic+json\": {\n       \"type\": \"string\"\n      },\n      \"text/plain\": [\n       \"'1.8.2'\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"import torch\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"print(torch.cuda.is_available())\\n\",\n    \"torch.__version__\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"KLE7CqlfJNUO\"\n   },\n   \"source\": [\n    \"### Install Chemistry-specific Dependencies\\n\",\n    \"\\n\",\n    \"Install RDKit, a toolkit for working with and visualizing chemistry in Python (you will use it to visualize the generated molecules later).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"0CPv_NvehRz3\",\n    \"outputId\": \"6ee0ae4e-4511-4816-de29-22b1c21d49bc\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\\n\",\n      \"Collecting rdkit\\n\",\n      \"  Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m36.8/36.8 MB\\u001b[0m \\u001b[31m34.6 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\\n\",\n      \"Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\\n\",\n      \"Installing collected packages: rdkit\\n\",\n      \"Successfully installed rdkit-2022.3.5\\n\",\n      \"\\u001b[33mWARNING: Running pip as the 'root' user can result in 
broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\u001b[33m\\n\",\n      \"\\u001b[0m\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"!pip install rdkit\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"88GaDbDPxJ5I\"\n   },\n   \"source\": [\n    \"### Get viewer from nglview\\n\",\n    \"\\n\",\n    \"The model you will use outputs a position matrix tensor. The PyTorch Geometric data object you work with has many features (positions, known features, edge features), all of them tensors.\\n\",\n    \"The data we give to the model will also have an rdmol object (which can extract features to geometric if needed).\\n\",\n    \"The rdmol in this object is a source of ground truth for the generated molecules.\\n\",\n    \"\\n\",\n    \"You will use one rendering function from nglview later!\\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 1000\n    },\n    \"id\": \"jcl8GCS2mz6t\",\n    \"outputId\": \"99b5cc40-67bb-4d8e-faa0-47d7cb33e98f\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\\n\",\n      \"Collecting nglview\\n\",\n      \"  Downloading nglview-3.0.3.tar.gz (5.7 MB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m5.7/5.7 MB\\u001b[0m \\u001b[31m91.2 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25h  Installing build dependencies ... \\u001b[?25l\\u001b[?25hdone\\n\",\n      \"  Getting requirements to build wheel ... \\u001b[?25l\\u001b[?25hdone\\n\",\n      \"  Preparing metadata (pyproject.toml) ... 
\\u001b[?25l\\u001b[?25hdone\\n\",\n      \"Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\\n\",\n      \"Collecting jupyterlab-widgets\\n\",\n      \"  Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m384.1/384.1 kB\\u001b[0m \\u001b[31m40.6 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting ipywidgets>=7\\n\",\n      \"  Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m134.4/134.4 kB\\u001b[0m \\u001b[31m21.2 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting widgetsnbextension~=4.0\\n\",\n      \"  Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m2.0/2.0 MB\\u001b[0m \\u001b[31m84.4 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting ipython>=6.1.0\\n\",\n      \"  Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m793.8/793.8 kB\\u001b[0m \\u001b[31m60.7 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting ipykernel>=4.5.1\\n\",\n      \"  Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m138.4/138.4 kB\\u001b[0m \\u001b[31m20.9 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting traitlets>=4.3.1\\n\",\n      \"  Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m107.1/107.1 kB\\u001b[0m 
\\u001b[31m17.0 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\\n\",\n      \"Collecting pyzmq>=17\\n\",\n      \"  Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m1.1/1.1 MB\\u001b[0m \\u001b[31m68.4 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting matplotlib-inline>=0.1\\n\",\n      \"  Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\\n\",\n      \"Collecting tornado>=6.1\\n\",\n      \"  Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m424.0/424.0 kB\\u001b[0m \\u001b[31m41.2 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting nest-asyncio\\n\",\n      \"  Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\\n\",\n      \"Collecting debugpy>=1.0\\n\",\n      \"  Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m1.8/1.8 MB\\u001b[0m \\u001b[31m83.4 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting psutil\\n\",\n      \"  Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m281.3/281.3 kB\\u001b[0m \\u001b[31m33.1 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting 
jupyter-client>=6.1.12\\n\",\n      \"  Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m132.2/132.2 kB\\u001b[0m \\u001b[31m19.7 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting pickleshare\\n\",\n      \"  Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\\n\",\n      \"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\\n\",\n      \"Collecting backcall\\n\",\n      \"  Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\\n\",\n      \"Collecting pexpect>4.3\\n\",\n      \"  Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m59.0/59.0 kB\\u001b[0m \\u001b[31m7.3 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting pygments\\n\",\n      \"  Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m1.1/1.1 MB\\u001b[0m \\u001b[31m70.9 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting jedi>=0.16\\n\",\n      \"  Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m1.6/1.6 MB\\u001b[0m \\u001b[31m83.5 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\\n\",\n      \"  Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m382.3/382.3 kB\\u001b[0m \\u001b[31m40.0 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hRequirement already 
satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\\n\",\n      \"Collecting parso<0.9.0,>=0.8.0\\n\",\n      \"  Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m100.8/100.8 kB\\u001b[0m \\u001b[31m14.7 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\\n\",\n      \"Collecting entrypoints\\n\",\n      \"  Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\\n\",\n      \"Collecting jupyter-core>=4.9.2\\n\",\n      \"  Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\\n\",\n      \"\\u001b[2K     \\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\\u001b[0m \\u001b[32m88.4/88.4 kB\\u001b[0m \\u001b[31m14.0 MB/s\\u001b[0m eta \\u001b[36m0:00:00\\u001b[0m\\n\",\n      \"\\u001b[?25hCollecting ptyprocess>=0.5\\n\",\n      \"  Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\\n\",\n      \"Collecting wcwidth\\n\",\n      \"  Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\\n\",\n      \"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\\n\",\n      \"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\\n\",\n      \"Building wheels for collected packages: nglview\\n\",\n      \"  Building wheel for nglview (pyproject.toml) ... 
\\u001b[?25l\\u001b[?25hdone\\n\",\n      \"  Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\\n\",\n      \"  Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\\n\",\n      \"Successfully built nglview\\n\",\n      \"Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\\n\",\n      \"Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\\n\",\n      \"\\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\\u001b[0m\\u001b[33m\\n\",\n      \"\\u001b[0m\"\n     ]\n    },\n    {\n     \"data\": {\n      \"application/vnd.colab-display-data+json\": {\n       \"pip_warning\": {\n        \"packages\": [\n         \"pexpect\",\n         \"pickleshare\",\n         \"wcwidth\"\n        ]\n       }\n      }\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"!pip install nglview\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"8t8_e_uVLdKB\"\n   },\n   \"source\": [\n    \"## Create a diffusion model\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"G0rMncVtNSqU\"\n   },\n   \"source\": [\n    \"### Model class(es)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"L5FEXz5oXkzt\"\n   },\n   \"source\": [\n    \"Imports\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"-3-P4w5sXkRU\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\\n\",\n    \"# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\\n\",\n    \"from dataclasses import dataclass\\n\",\n    \"from typing import Callable, Tuple, Union\\n\",\n    \"\\n\",\n    \"import numpy as np\\n\",\n    \"import torch\\n\",\n    \"import torch.nn.functional as F\\n\",\n    \"from torch import Tensor, nn\\n\",\n    \"from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\\n\",\n    \"from torch_geometric.nn import MessagePassing, radius, radius_graph\\n\",\n    \"from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\\n\",\n    \"from torch_geometric.utils import dense_to_sparse, to_dense_adj\\n\",\n    \"from torch_scatter import scatter_add\\n\",\n    \"from 
torch_sparse import SparseTensor, coalesce\\n\",\n    \"\\n\",\n    \"from diffusers.configuration_utils import ConfigMixin, register_to_config\\n\",\n    \"from diffusers.modeling_utils import ModelMixin\\n\",\n    \"from diffusers.utils import BaseOutput\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"EzJQXPN_XrMX\"\n   },\n   \"source\": [\n    \"Helper classes\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"oR1Y56QiLY90\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"@dataclass\\n\",\n    \"class MoleculeGNNOutput(BaseOutput):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Args:\\n\",\n    \"        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\\n\",\n    \"            Hidden states output. Output of last layer of model.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    sample: torch.Tensor\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class MultiLayerPerceptron(nn.Module):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Multi-layer Perceptron. 
Note there is no activation or dropout in the last layer.\\n\",\n    \"    Args:\\n\",\n    \"        input_dim (int): input dimension\\n\",\n    \"        hidden_dims (list of int): hidden dimensions\\n\",\n    \"        activation (str, optional): name of an activation function in `torch.nn.functional`; non-string values are ignored\\n\",\n    \"        dropout (float, optional): dropout rate\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    def __init__(self, input_dim, hidden_dims, activation=\\\"relu\\\", dropout=0):\\n\",\n    \"        super(MultiLayerPerceptron, self).__init__()\\n\",\n    \"\\n\",\n    \"        self.dims = [input_dim] + hidden_dims\\n\",\n    \"        if isinstance(activation, str):\\n\",\n    \"            self.activation = getattr(F, activation)\\n\",\n    \"        else:\\n\",\n    \"            print(f\\\"Warning, activation passed {activation} is not string and ignored\\\")\\n\",\n    \"            self.activation = None\\n\",\n    \"        if dropout > 0:\\n\",\n    \"            self.dropout = nn.Dropout(dropout)\\n\",\n    \"        else:\\n\",\n    \"            self.dropout = None\\n\",\n    \"\\n\",\n    \"        self.layers = nn.ModuleList()\\n\",\n    \"        for i in range(len(self.dims) - 1):\\n\",\n    \"            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        \\\"\\\"\\\"\\\"\\\"\\\"\\n\",\n    \"        for i, layer in enumerate(self.layers):\\n\",\n    \"            x = layer(x)\\n\",\n    \"            if i < len(self.layers) - 1:\\n\",\n    \"                if self.activation:\\n\",\n    \"                    x = self.activation(x)\\n\",\n    \"                if self.dropout:\\n\",\n    \"                    x = self.dropout(x)\\n\",\n    \"        return x\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class ShiftedSoftplus(torch.nn.Module):\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super(ShiftedSoftplus, self).__init__()\\n\",\n    \"        self.shift = 
torch.log(torch.tensor(2.0)).item()\\n\",\n    \"\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        return F.softplus(x) - self.shift\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class CFConv(MessagePassing):\\n\",\n    \"    def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\\n\",\n    \"        super(CFConv, self).__init__(aggr=\\\"add\\\")\\n\",\n    \"        self.lin1 = Linear(in_channels, num_filters, bias=False)\\n\",\n    \"        self.lin2 = Linear(num_filters, out_channels)\\n\",\n    \"        self.nn = mlp\\n\",\n    \"        self.cutoff = cutoff\\n\",\n    \"        self.smooth = smooth\\n\",\n    \"\\n\",\n    \"        self.reset_parameters()\\n\",\n    \"\\n\",\n    \"    def reset_parameters(self):\\n\",\n    \"        torch.nn.init.xavier_uniform_(self.lin1.weight)\\n\",\n    \"        torch.nn.init.xavier_uniform_(self.lin2.weight)\\n\",\n    \"        self.lin2.bias.data.fill_(0)\\n\",\n    \"\\n\",\n    \"    def forward(self, x, edge_index, edge_length, edge_attr):\\n\",\n    \"        if self.smooth:\\n\",\n    \"            C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\\n\",\n    \"            C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0)  # Modification: cutoff\\n\",\n    \"        else:\\n\",\n    \"            C = (edge_length <= self.cutoff).float()\\n\",\n    \"        W = self.nn(edge_attr) * C.view(-1, 1)\\n\",\n    \"\\n\",\n    \"        x = self.lin1(x)\\n\",\n    \"        x = self.propagate(edge_index, x=x, W=W)\\n\",\n    \"        x = self.lin2(x)\\n\",\n    \"        return x\\n\",\n    \"\\n\",\n    \"    def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\\n\",\n    \"        return x_j * W\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class InteractionBlock(torch.nn.Module):\\n\",\n    \"    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\\n\",\n    \"        super(InteractionBlock, self).__init__()\\n\",\n    \"        mlp 
= Sequential(\\n\",\n    \"            Linear(num_gaussians, num_filters),\\n\",\n    \"            ShiftedSoftplus(),\\n\",\n    \"            Linear(num_filters, num_filters),\\n\",\n    \"        )\\n\",\n    \"        self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\\n\",\n    \"        self.act = ShiftedSoftplus()\\n\",\n    \"        self.lin = Linear(hidden_channels, hidden_channels)\\n\",\n    \"\\n\",\n    \"    def forward(self, x, edge_index, edge_length, edge_attr):\\n\",\n    \"        x = self.conv(x, edge_index, edge_length, edge_attr)\\n\",\n    \"        x = self.act(x)\\n\",\n    \"        x = self.lin(x)\\n\",\n    \"        return x\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class SchNetEncoder(Module):\\n\",\n    \"    def __init__(\\n\",\n    \"        self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\\n\",\n    \"    ):\\n\",\n    \"        super().__init__()\\n\",\n    \"\\n\",\n    \"        self.hidden_channels = hidden_channels\\n\",\n    \"        self.num_filters = num_filters\\n\",\n    \"        self.num_interactions = num_interactions\\n\",\n    \"        self.cutoff = cutoff\\n\",\n    \"\\n\",\n    \"        self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\\n\",\n    \"\\n\",\n    \"        self.interactions = ModuleList()\\n\",\n    \"        for _ in range(num_interactions):\\n\",\n    \"            block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\\n\",\n    \"            self.interactions.append(block)\\n\",\n    \"\\n\",\n    \"    def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\\n\",\n    \"        if embed_node:\\n\",\n    \"            assert z.dim() == 1 and z.dtype == torch.long\\n\",\n    \"            h = self.embedding(z)\\n\",\n    \"        else:\\n\",\n    \"            h = z\\n\",\n    \"        for interaction in self.interactions:\\n\",\n    
\"            h = h + interaction(h, edge_index, edge_length, edge_attr)\\n\",\n    \"\\n\",\n    \"        return h\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class GINEConv(MessagePassing):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Custom class of the graph isomorphism operator from the \\\"How Powerful are Graph Neural Networks?\\\"\\n\",\n    \"    https://huggingface.co/papers/1810.00826 paper. Note that this implementation has the added option of a custom activation.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\\\"softplus\\\", **kwargs):\\n\",\n    \"        super(GINEConv, self).__init__(aggr=\\\"add\\\", **kwargs)\\n\",\n    \"        self.nn = mlp\\n\",\n    \"        self.initial_eps = eps\\n\",\n    \"\\n\",\n    \"        if isinstance(activation, str):\\n\",\n    \"            self.activation = getattr(F, activation)\\n\",\n    \"        else:\\n\",\n    \"            self.activation = None\\n\",\n    \"\\n\",\n    \"        if train_eps:\\n\",\n    \"            self.eps = torch.nn.Parameter(torch.Tensor([eps]))\\n\",\n    \"        else:\\n\",\n    \"            self.register_buffer(\\\"eps\\\", torch.Tensor([eps]))\\n\",\n    \"\\n\",\n    \"    def forward(\\n\",\n    \"        self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\\n\",\n    \"    ) -> torch.Tensor:\\n\",\n    \"        \\\"\\\"\\\"\\\"\\\"\\\"\\n\",\n    \"        if isinstance(x, torch.Tensor):\\n\",\n    \"            x: OptPairTensor = (x, x)\\n\",\n    \"\\n\",\n    \"        # Node and edge feature dimensionalities need to match.\\n\",\n    \"        if isinstance(edge_index, torch.Tensor):\\n\",\n    \"            assert edge_attr is not None\\n\",\n    \"            assert x[0].size(-1) == edge_attr.size(-1)\\n\",\n    \"        elif isinstance(edge_index, SparseTensor):\\n\",\n    \"            assert x[0].size(-1) == 
edge_index.size(-1)\\n\",\n    \"\\n\",\n    \"        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\\n\",\n    \"        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\\n\",\n    \"\\n\",\n    \"        x_r = x[1]\\n\",\n    \"        if x_r is not None:\\n\",\n    \"            out += (1 + self.eps) * x_r\\n\",\n    \"\\n\",\n    \"        return self.nn(out)\\n\",\n    \"\\n\",\n    \"    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\\n\",\n    \"        if self.activation:\\n\",\n    \"            return self.activation(x_j + edge_attr)\\n\",\n    \"        else:\\n\",\n    \"            return x_j + edge_attr\\n\",\n    \"\\n\",\n    \"    def __repr__(self):\\n\",\n    \"        return \\\"{}(nn={})\\\".format(self.__class__.__name__, self.nn)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class GINEncoder(torch.nn.Module):\\n\",\n    \"    def __init__(self, hidden_dim, num_convs=3, activation=\\\"relu\\\", short_cut=True, concat_hidden=False):\\n\",\n    \"        super().__init__()\\n\",\n    \"\\n\",\n    \"        self.hidden_dim = hidden_dim\\n\",\n    \"        self.num_convs = num_convs\\n\",\n    \"        self.short_cut = short_cut\\n\",\n    \"        self.concat_hidden = concat_hidden\\n\",\n    \"        self.node_emb = nn.Embedding(100, hidden_dim)\\n\",\n    \"\\n\",\n    \"        if isinstance(activation, str):\\n\",\n    \"            self.activation = getattr(F, activation)\\n\",\n    \"        else:\\n\",\n    \"            self.activation = None\\n\",\n    \"\\n\",\n    \"        self.convs = nn.ModuleList()\\n\",\n    \"        for i in range(self.num_convs):\\n\",\n    \"            self.convs.append(\\n\",\n    \"                GINEConv(\\n\",\n    \"                    MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\\n\",\n    \"                    activation=activation,\\n\",\n    \"                )\\n\",\n    \"            )\\n\",\n   
 \"\\n\",\n    \"    def forward(self, z, edge_index, edge_attr):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Input:\\n\",\n    \"            data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\\n\",\n    \"            hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\\n\",\n    \"        Output:\\n\",\n    \"            node_feature: graph feature\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"        node_attr = self.node_emb(z)  # (num_node, hidden)\\n\",\n    \"\\n\",\n    \"        hiddens = []\\n\",\n    \"        conv_input = node_attr  # (num_node, hidden)\\n\",\n    \"\\n\",\n    \"        for conv_idx, conv in enumerate(self.convs):\\n\",\n    \"            hidden = conv(conv_input, edge_index, edge_attr)\\n\",\n    \"            if conv_idx < len(self.convs) - 1 and self.activation is not None:\\n\",\n    \"                hidden = self.activation(hidden)\\n\",\n    \"            assert hidden.shape == conv_input.shape\\n\",\n    \"            if self.short_cut and hidden.shape == conv_input.shape:\\n\",\n    \"                hidden += conv_input\\n\",\n    \"\\n\",\n    \"            hiddens.append(hidden)\\n\",\n    \"            conv_input = hidden\\n\",\n    \"\\n\",\n    \"        if self.concat_hidden:\\n\",\n    \"            node_feature = torch.cat(hiddens, dim=-1)\\n\",\n    \"        else:\\n\",\n    \"            node_feature = hiddens[-1]\\n\",\n    \"\\n\",\n    \"        return node_feature\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class MLPEdgeEncoder(Module):\\n\",\n    \"    def __init__(self, hidden_dim=100, activation=\\\"relu\\\"):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.hidden_dim = hidden_dim\\n\",\n    \"        self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\\n\",\n    \"        self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\\n\",\n    \"\\n\",\n  
  \"    @property\\n\",\n    \"    def out_channels(self):\\n\",\n    \"        return self.hidden_dim\\n\",\n    \"\\n\",\n    \"    def forward(self, edge_length, edge_type):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Input:\\n\",\n    \"            edge_length: The length of edges, shape=(E, 1). edge_type: The type of edges, shape=(E,)\\n\",\n    \"        Returns:\\n\",\n    \"            edge_attr: The representation of edges. (E, 2 * num_gaussians)\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        d_emb = self.mlp(edge_length)  # (num_edge, hidden_dim)\\n\",\n    \"        edge_attr = self.bond_emb(edge_type)  # (num_edge, hidden_dim)\\n\",\n    \"        return d_emb * edge_attr  # (num_edge, hidden)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\\n\",\n    \"    h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\\n\",\n    \"    h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1)  # (E, 2H)\\n\",\n    \"    return h_pair\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Args:\\n\",\n    \"        num_nodes:  Number of atoms.\\n\",\n    \"        edge_index: Bond indices of the original graph.\\n\",\n    \"        edge_type:  Bond types of the original graph.\\n\",\n    \"        order:  Extension order.\\n\",\n    \"    Returns:\\n\",\n    \"        new_edge_index: Extended edge indices. 
new_edge_type: Extended edge types.\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    def binarize(x):\\n\",\n    \"        return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\\n\",\n    \"\\n\",\n    \"    def get_higher_order_adj_matrix(adj, order):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Args:\\n\",\n    \"            adj:        (N, N)\\n\",\n    \"            type_mat:   (N, N)\\n\",\n    \"        Returns:\\n\",\n    \"            Following attributes will be updated:\\n\",\n    \"              - edge_index\\n\",\n    \"              - edge_type\\n\",\n    \"            Following attributes will be added to the data object:\\n\",\n    \"              - bond_edge_index: Original edge_index.\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        adj_mats = [\\n\",\n    \"            torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\\n\",\n    \"            binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\\n\",\n    \"        ]\\n\",\n    \"\\n\",\n    \"        for i in range(2, order + 1):\\n\",\n    \"            adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\\n\",\n    \"        order_mat = torch.zeros_like(adj)\\n\",\n    \"\\n\",\n    \"        for i in range(1, order + 1):\\n\",\n    \"            order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\\n\",\n    \"\\n\",\n    \"        return order_mat\\n\",\n    \"\\n\",\n    \"    num_types = 22\\n\",\n    \"    # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\\n\",\n    \"    # from rdkit.Chem.rdchem import BondType as BT\\n\",\n    \"    N = num_nodes\\n\",\n    \"    adj = to_dense_adj(edge_index).squeeze(0)\\n\",\n    \"    adj_order = get_higher_order_adj_matrix(adj, order)  # (N, N)\\n\",\n    \"\\n\",\n    \"    type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0)  # (N, N)\\n\",\n    \"    type_highorder = torch.where(adj_order > 1, num_types + 
adj_order - 1, torch.zeros_like(adj_order))\\n\",\n    \"    assert (type_mat * type_highorder == 0).all()\\n\",\n    \"    type_new = type_mat + type_highorder\\n\",\n    \"\\n\",\n    \"    new_edge_index, new_edge_type = dense_to_sparse(type_new)\\n\",\n    \"    _, edge_order = dense_to_sparse(adj_order)\\n\",\n    \"\\n\",\n    \"    # data.bond_edge_index = data.edge_index  # Save original edges\\n\",\n    \"    new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N)  # modify data\\n\",\n    \"\\n\",\n    \"    return new_edge_index, new_edge_type\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\\n\",\n    \"    assert edge_type.dim() == 1\\n\",\n    \"    N = pos.size(0)\\n\",\n    \"\\n\",\n    \"    bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\\n\",\n    \"\\n\",\n    \"    if is_sidechain is None:\\n\",\n    \"        rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch)  # (2, E_r)\\n\",\n    \"    else:\\n\",\n    \"        # fetch sidechain and its batch index\\n\",\n    \"        is_sidechain = is_sidechain.bool()\\n\",\n    \"        dummy_index = torch.arange(pos.size(0), device=pos.device)\\n\",\n    \"        sidechain_pos = pos[is_sidechain]\\n\",\n    \"        sidechain_index = dummy_index[is_sidechain]\\n\",\n    \"        sidechain_batch = batch[is_sidechain]\\n\",\n    \"\\n\",\n    \"        assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\\n\",\n    \"        r_edge_index_x = assign_index[1]\\n\",\n    \"        r_edge_index_y = assign_index[0]\\n\",\n    \"        r_edge_index_y = sidechain_index[r_edge_index_y]\\n\",\n    \"\\n\",\n    \"        rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y))  # (2, E)\\n\",\n    \"        rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x))  # 
(2, E)\\n\",\n    \"        rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1)  # (2, 2E)\\n\",\n    \"        # delete self loop\\n\",\n    \"        rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\\n\",\n    \"\\n\",\n    \"    rgraph_adj = torch.sparse.LongTensor(\\n\",\n    \"        rgraph_edge_index,\\n\",\n    \"        torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\\n\",\n    \"        torch.Size([N, N]),\\n\",\n    \"    )\\n\",\n    \"\\n\",\n    \"    composed_adj = (bgraph_adj + rgraph_adj).coalesce()  # Sparse (N, N, T)\\n\",\n    \"\\n\",\n    \"    new_edge_index = composed_adj.indices()\\n\",\n    \"    new_edge_type = composed_adj.values().long()\\n\",\n    \"\\n\",\n    \"    return new_edge_index, new_edge_type\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def extend_graph_order_radius(\\n\",\n    \"    num_nodes,\\n\",\n    \"    pos,\\n\",\n    \"    edge_index,\\n\",\n    \"    edge_type,\\n\",\n    \"    batch,\\n\",\n    \"    order=3,\\n\",\n    \"    cutoff=10.0,\\n\",\n    \"    extend_order=True,\\n\",\n    \"    extend_radius=True,\\n\",\n    \"    is_sidechain=None,\\n\",\n    \"):\\n\",\n    \"    if extend_order:\\n\",\n    \"        edge_index, edge_type = _extend_graph_order(\\n\",\n    \"            num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"    if extend_radius:\\n\",\n    \"        edge_index, edge_type = _extend_to_radius_graph(\\n\",\n    \"            pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"    return edge_index, edge_type\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def get_distance(pos, edge_index):\\n\",\n    \"    return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def 
graph_field_network(score_d, pos, edge_index, edge_length):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Transformation to make the epsilon predicted from the diffusion model roto-translationally equivariant. See equations\\n\",\n    \"    5-7 of the GeoDiff paper https://huggingface.co/papers/2203.02923\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    N = pos.size(0)\\n\",\n    \"    dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]])  # (E, 3)\\n\",\n    \"    score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\\n\",\n    \"        -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\\n\",\n    \"    )  # (N, 3)\\n\",\n    \"    return score_pos\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def clip_norm(vec, limit, p=2):\\n\",\n    \"    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)\\n\",\n    \"    denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\\n\",\n    \"    return vec * denom\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def is_local_edge(edge_type):\\n\",\n    \"    return edge_type > 0\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"QWrHJFcYXyUB\"\n   },\n   \"source\": [\n    \"Main model class!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"MCeZA1qQXzoK\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"class MoleculeGNN(ModelMixin, ConfigMixin):\\n\",\n    \"    @register_to_config\\n\",\n    \"    def __init__(\\n\",\n    \"        self,\\n\",\n    \"        hidden_dim=128,\\n\",\n    \"        num_convs=6,\\n\",\n    \"        num_convs_local=4,\\n\",\n    \"        cutoff=10.0,\\n\",\n    \"        mlp_act=\\\"relu\\\",\\n\",\n    \"        edge_order=3,\\n\",\n    \"        edge_encoder=\\\"mlp\\\",\\n\",\n    \"        smooth_conv=True,\\n\",\n    \"    ):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.cutoff = cutoff\\n\",\n    \"        
self.edge_encoder = edge_encoder\\n\",\n    \"        self.edge_order = edge_order\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\\n\",\n    \"        in SchNetEncoder\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)\\n\",\n    \"        self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        The graph neural network that extracts node-wise features.\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        self.encoder_global = SchNetEncoder(\\n\",\n    \"            hidden_channels=hidden_dim,\\n\",\n    \"            num_filters=hidden_dim,\\n\",\n    \"            num_interactions=num_convs,\\n\",\n    \"            edge_channels=self.edge_encoder_global.out_channels,\\n\",\n    \"            cutoff=cutoff,\\n\",\n    \"            smooth=smooth_conv,\\n\",\n    \"        )\\n\",\n    \"        self.encoder_local = GINEncoder(\\n\",\n    \"            hidden_dim=hidden_dim,\\n\",\n    \"            num_convs=num_convs_local,\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\\n\",\n    \"            gradients w.r.t. 
edge_length (out_dim = 1).\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        self.grad_global_dist_mlp = MultiLayerPerceptron(\\n\",\n    \"            2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        self.grad_local_dist_mlp = MultiLayerPerceptron(\\n\",\n    \"            2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\\n\",\n    \"        )\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Incorporate parameters together\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\\n\",\n    \"        self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\\n\",\n    \"\\n\",\n    \"    def _forward(\\n\",\n    \"        self,\\n\",\n    \"        atom_type,\\n\",\n    \"        pos,\\n\",\n    \"        bond_index,\\n\",\n    \"        bond_type,\\n\",\n    \"        batch,\\n\",\n    \"        time_step,  # NOTE, model trained without timestep performed best\\n\",\n    \"        edge_index=None,\\n\",\n    \"        edge_type=None,\\n\",\n    \"        edge_length=None,\\n\",\n    \"        return_edges=False,\\n\",\n    \"        extend_order=True,\\n\",\n    \"        extend_radius=True,\\n\",\n    \"        is_sidechain=None,\\n\",\n    \"    ):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Args:\\n\",\n    \"            atom_type:  Types of atoms, (N, ).\\n\",\n    \"            bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\\n\",\n    \"            bond_type:  Bond types, (E, ).\\n\",\n    \"            batch:      Node index to graph index, (N, ).\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        N = atom_type.size(0)\\n\",\n    \"        if edge_index is None or edge_type is None or edge_length is None:\\n\",\n    \"            edge_index, 
edge_type = extend_graph_order_radius(\\n\",\n    \"                num_nodes=N,\\n\",\n    \"                pos=pos,\\n\",\n    \"                edge_index=bond_index,\\n\",\n    \"                edge_type=bond_type,\\n\",\n    \"                batch=batch,\\n\",\n    \"                order=self.edge_order,\\n\",\n    \"                cutoff=self.cutoff,\\n\",\n    \"                extend_order=extend_order,\\n\",\n    \"                extend_radius=extend_radius,\\n\",\n    \"                is_sidechain=is_sidechain,\\n\",\n    \"            )\\n\",\n    \"            edge_length = get_distance(pos, edge_index).unsqueeze(-1)  # (E, 1)\\n\",\n    \"        local_edge_mask = is_local_edge(edge_type)  # (E, )\\n\",\n    \"\\n\",\n    \"        # with the parameterization of NCSNv2\\n\",\n    \"        # DDPM loss implicit handle the noise variance scale conditioning\\n\",\n    \"        sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device)  # (E, 1)\\n\",\n    \"\\n\",\n    \"        # Encoding global\\n\",\n    \"        edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type)  # Embed edges\\n\",\n    \"\\n\",\n    \"        # Global\\n\",\n    \"        node_attr_global = self.encoder_global(\\n\",\n    \"            z=atom_type,\\n\",\n    \"            edge_index=edge_index,\\n\",\n    \"            edge_length=edge_length,\\n\",\n    \"            edge_attr=edge_attr_global,\\n\",\n    \"        )\\n\",\n    \"        # Assemble pairwise features\\n\",\n    \"        h_pair_global = assemble_atom_pair_feature(\\n\",\n    \"            node_attr=node_attr_global,\\n\",\n    \"            edge_index=edge_index,\\n\",\n    \"            edge_attr=edge_attr_global,\\n\",\n    \"        )  # (E_global, 2H)\\n\",\n    \"        # Invariant features of edges (radius graph, global)\\n\",\n    \"        edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge)  # (E_global, 1)\\n\",\n    
\"\\n\",\n    \"        # Encoding local\\n\",\n    \"        edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type)  # Embed edges\\n\",\n    \"        # edge_attr += temb_edge\\n\",\n    \"\\n\",\n    \"        # Local\\n\",\n    \"        node_attr_local = self.encoder_local(\\n\",\n    \"            z=atom_type,\\n\",\n    \"            edge_index=edge_index[:, local_edge_mask],\\n\",\n    \"            edge_attr=edge_attr_local[local_edge_mask],\\n\",\n    \"        )\\n\",\n    \"        # Assemble pairwise features\\n\",\n    \"        h_pair_local = assemble_atom_pair_feature(\\n\",\n    \"            node_attr=node_attr_local,\\n\",\n    \"            edge_index=edge_index[:, local_edge_mask],\\n\",\n    \"            edge_attr=edge_attr_local[local_edge_mask],\\n\",\n    \"        )  # (E_local, 2H)\\n\",\n    \"\\n\",\n    \"        # Invariant features of edges (bond graph, local)\\n\",\n    \"        if isinstance(sigma_edge, torch.Tensor):\\n\",\n    \"            edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\\n\",\n    \"                1.0 / sigma_edge[local_edge_mask]\\n\",\n    \"            )  # (E_local, 1)\\n\",\n    \"        else:\\n\",\n    \"            edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge)  # (E_local, 1)\\n\",\n    \"\\n\",\n    \"        if return_edges:\\n\",\n    \"            return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\\n\",\n    \"        else:\\n\",\n    \"            return edge_inv_global, edge_inv_local\\n\",\n    \"\\n\",\n    \"    def forward(\\n\",\n    \"        self,\\n\",\n    \"        sample,\\n\",\n    \"        timestep: Union[torch.Tensor, float, int],\\n\",\n    \"        return_dict: bool = True,\\n\",\n    \"        sigma=1.0,\\n\",\n    \"        global_start_sigma=0.5,\\n\",\n    \"        w_global=1.0,\\n\",\n    \"        extend_order=False,\\n\",\n    \"        
extend_radius=True,\\n\",\n    \"        clip_local=None,\\n\",\n    \"        clip_global=1000.0,\\n\",\n    \"    ) -> Union[MoleculeGNNOutput, Tuple]:\\n\",\n    \"        r\\\"\\\"\\\"\\n\",\n    \"        Args:\\n\",\n    \"            sample: packed torch geometric object\\n\",\n    \"            timestep (`torch.Tensor` or `float` or `int`): Scalar diffusion timestep, broadcast to one value per graph.\\n\",\n    \"            return_dict (`bool`, *optional*, defaults to `True`):\\n\",\n    \"                Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\\n\",\n    \"        Returns:\\n\",\n    \"            [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\\n\",\n    \"            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"        # unpack sample\\n\",\n    \"        atom_type = sample.atom_type\\n\",\n    \"        bond_index = sample.edge_index\\n\",\n    \"        bond_type = sample.edge_type\\n\",\n    \"        num_graphs = sample.num_graphs\\n\",\n    \"        pos = sample.pos\\n\",\n    \"\\n\",\n    \"        timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\\n\",\n    \"\\n\",\n    \"        edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\\n\",\n    \"            atom_type=atom_type,\\n\",\n    \"            pos=sample.pos,\\n\",\n    \"            bond_index=bond_index,\\n\",\n    \"            bond_type=bond_type,\\n\",\n    \"            batch=sample.batch,\\n\",\n    \"            time_step=timesteps,\\n\",\n    \"            return_edges=True,\\n\",\n    \"            extend_order=extend_order,\\n\",\n    \"            extend_radius=extend_radius,\\n\",\n    \"        )  # (E_global, 1), (E_local, 1)\\n\",\n    \"\\n\",\n    \"        
# Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\\n\",\n    \"        node_eq_local = graph_field_network(\\n\",\n    \"            edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\\n\",\n    \"        )\\n\",\n    \"        if clip_local is not None:\\n\",\n    \"            node_eq_local = clip_norm(node_eq_local, limit=clip_local)\\n\",\n    \"\\n\",\n    \"        # Global\\n\",\n    \"        if sigma < global_start_sigma:\\n\",\n    \"            edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\\n\",\n    \"            node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\\n\",\n    \"            node_eq_global = clip_norm(node_eq_global, limit=clip_global)\\n\",\n    \"        else:\\n\",\n    \"            node_eq_global = 0\\n\",\n    \"\\n\",\n    \"        # Sum\\n\",\n    \"        eps_pos = node_eq_local + node_eq_global * w_global\\n\",\n    \"\\n\",\n    \"        if not return_dict:\\n\",\n    \"            return (-eps_pos,)\\n\",\n    \"\\n\",\n    \"        return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"CCIrPYSJj9wd\"\n   },\n   \"source\": [\n    \"### Load pretrained model\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"YdrAr6Ch--Ab\"\n   },\n   \"source\": [\n    \"#### Load a model\\n\",\n    \"The model uses an equivariant convolutional layer, named graph field network (GFN).\\n\",\n    \"\\n\",\n    \"The warning about `betas` and `alphas` can be ignored; those parameters were moved to the scheduler.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 172,\n     \"referenced_widgets\": [\n      
\"d90f304e9560472eacfbdd11e46765eb\",\n      \"1c6246f15b654f4daa11c9bcf997b78c\",\n      \"c2321b3bff6f490ca12040a20308f555\",\n      \"b7feb522161f4cf4b7cc7c1a078ff12d\",\n      \"e2d368556e494ae7ae4e2e992af2cd4f\",\n      \"bbef741e76ec41b7ab7187b487a383df\",\n      \"561f742d418d4721b0670cc8dd62e22c\",\n      \"872915dd1bb84f538c44e26badabafdd\",\n      \"d022575f1fa2446d891650897f187b4d\",\n      \"fdc393f3468c432aa0ada05e238a5436\",\n      \"2c9362906e4b40189f16d14aa9a348da\",\n      \"6010fc8daa7a44d5aec4b830ec2ebaa1\",\n      \"7e0bb1b8d65249d3974200686b193be2\",\n      \"ba98aa6d6a884e4ab8bbb5dfb5e4cf7a\",\n      \"6526646be5ed415c84d1245b040e629b\",\n      \"24d31fc3576e43dd9f8301d2ef3a37ab\",\n      \"2918bfaadc8d4b1a9832522c40dfefb8\",\n      \"a4bfdca35cc54dae8812720f1b276a08\",\n      \"e4901541199b45c6a18824627692fc39\",\n      \"f915cf874246446595206221e900b2fe\",\n      \"a9e388f22a9742aaaf538e22575c9433\",\n      \"42f6c3db29d7484ba6b4f73590abd2f4\"\n     ]\n    },\n    \"id\": \"DyCo0nsqjbml\",\n    \"outputId\": \"d6bce9d5-c51e-43a4-e680-e1e81bdfaf45\"\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"d90f304e9560472eacfbdd11e46765eb\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Downloading:   0%|          | 0.00/3.27M [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"6010fc8daa7a44d5aec4b830ec2ebaa1\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Downloading:   0%|          | 0.00/401 [00:00<?, ?B/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"The 
config attributes {'type': 'diffusion', 'network': 'dualenc', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'num_diffusion_timesteps': 5000} were passed to MoleculeGNN, but are not expected and will be ignored. Please verify your config.json configuration file.\\n\",\n      \"Some weights of the model checkpoint at fusing/gfn-molecule-gen-drugs were not used when initializing MoleculeGNN: ['betas', 'alphas']\\n\",\n      \"- This IS expected if you are initializing MoleculeGNN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\\n\",\n      \"- This IS NOT expected if you are initializing MoleculeGNN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"DEVICE = \\\"cuda\\\"\\n\",\n    \"model = MoleculeGNN.from_pretrained(\\\"fusing/gfn-molecule-gen-drugs\\\").to(DEVICE)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"HdclRaqoUWUD\"\n   },\n   \"source\": [\n    \"The warnings above are because the pre-trained model was uploaded before cleaning the code!\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"PlOkPySoJ1m9\"\n   },\n   \"source\": [\n    \"#### Create scheduler\\n\",\n    \"Note, other schedulers are used in the paper for slightly improved performance over DDPM.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"nNHnIk9CkAb2\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from diffusers import DDPMScheduler\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"RnDJdDBztjFF\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"num_timesteps = 
1000\\n\",\n    \"scheduler = DDPMScheduler(\\n\",\n    \"    num_train_timesteps=num_timesteps, beta_schedule=\\\"sigmoid\\\", beta_start=1e-7, beta_end=2e-3, clip_sample=False\\n\",\n    \")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"1vh3fpSAflkL\"\n   },\n   \"source\": [\n    \"### Get a dataset\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"B6qzaGjVKFVk\"\n   },\n   \"source\": [\n    \"Grab a Google Colab tool so we can upload our data directly. Note that you first need to download the data from ***this [file](https://huggingface.co/datasets/fusing/geodiff-example-data/blob/main/data/molecules.pkl)***\\n\",\n    \"\\n\",\n    \"(Direct downloading from the Hub does not yet work for this data type.)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"jbLl3EJdgj3x\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# from google.colab import files\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"E591lVuTgxPE\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"# uploaded = files.upload()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"KUNxfK3ln98Q\"\n   },\n   \"source\": [\n    \"Load the dataset with torch.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"7L4iOShTpcQX\",\n    \"outputId\": \"7f2dcd29-493e-44de-98d1-3ad50f109a4a\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"--2022-10-12 18:32:19--  https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\\n\",\n      \"Resolving huggingface.co (huggingface.co)... 
44.195.102.200, 52.5.54.249, 54.210.225.113, ...\\n\",\n      \"Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.\\n\",\n      \"HTTP request sent, awaiting response... 200 OK\\n\",\n      \"Length: 127774 (125K) [application/octet-stream]\\n\",\n      \"Saving to: ‘molecules.pkl’\\n\",\n      \"\\n\",\n      \"molecules.pkl       100%[===================>] 124.78K   180KB/s    in 0.7s    \\n\",\n      \"\\n\",\n      \"2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\\n\",\n      \"\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import torch\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\\n\",\n    \"dataset = torch.load(\\\"/content/molecules.pkl\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"QZcmy1EvKQRk\"\n   },\n   \"source\": [\n    \"Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"JVjz6iH_H6Eh\",\n    \"outputId\": \"898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe\"\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=<rdkit.Chem.rdchem.Mol object at 0x7f707d2cb130>, smiles=\\\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\\\")\"\n      ]\n     },\n     \"execution_count\": 20,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"dataset[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"vHNiZAUxNgoy\"\n   },\n 
  \"source\": [\n    \"## Run the diffusion process\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"jZ1KZrxKqENg\"\n   },\n   \"source\": [\n    \"#### Helper Functions\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"s240tYueqKKf\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import copy\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"from torch_geometric.data import Batch, Data\\n\",\n    \"from torch_scatter import scatter_mean\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def repeat_data(data: Data, num_repeat) -> Batch:\\n\",\n    \"    datas = [copy.deepcopy(data) for i in range(num_repeat)]\\n\",\n    \"    return Batch.from_data_list(datas)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def repeat_batch(batch: Batch, num_repeat) -> Batch:\\n\",\n    \"    datas = batch.to_data_list()\\n\",\n    \"    new_data = []\\n\",\n    \"    for i in range(num_repeat):\\n\",\n    \"        new_data += copy.deepcopy(datas)\\n\",\n    \"    return Batch.from_data_list(new_data)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"AMnQTk0eqT7Z\"\n   },\n   \"source\": [\n    \"#### Constants\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"WYGkzqgzrHmF\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"num_samples = 1  # solutions per molecule\\n\",\n    \"num_molecules = 3\\n\",\n    \"\\n\",\n    \"DEVICE = \\\"cuda\\\"\\n\",\n    \"sampling_type = \\\"ddpm_noisy\\\"  #'' # paper also uses \\\"generalize\\\" and \\\"ld\\\"\\n\",\n    \"# constants for inference\\n\",\n    \"w_global = 0.5  # 0,.3 for qm9\\n\",\n    \"global_start_sigma = 0.5\\n\",\n    \"eta = 1.0\\n\",\n    \"clip_local = None\\n\",\n    \"clip_pos = None\\n\",\n    \"\\n\",\n    \"# constants for data handling\\n\",\n    \"save_traj = False\\n\",\n    \"save_data 
= False\\n\",\n    \"output_dir = \\\"/content/\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"-xD5bJ3SqM7t\"\n   },\n   \"source\": [\n    \"#### Generate samples!\\n\",\n    \"Note that the 3d representation of a molecule is referred to as the **conformation**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"x9xuLUNg26z1\",\n    \"outputId\": \"236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\\n\",\n      \"  after removing the cwd from sys.path.\\n\",\n      \"100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import pickle\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"results = []\\n\",\n    \"\\n\",\n    \"# define sigmas\\n\",\n    \"sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\\n\",\n    \"sigmas = sigmas.to(DEVICE)\\n\",\n    \"\\n\",\n    \"for count, data in enumerate(tqdm(dataset)):\\n\",\n    \"    num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\\n\",\n    \"\\n\",\n    \"    data_input = data.clone()\\n\",\n    \"    data_input[\\\"pos_ref\\\"] = None\\n\",\n    \"    batch = repeat_data(data_input, num_samples).to(DEVICE)\\n\",\n    \"\\n\",\n    \"    # initial configuration\\n\",\n    \"    pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\\n\",\n    \"\\n\",\n    \"    # for logging animation of denoising\\n\",\n    \"    pos_traj = []\\n\",\n    \"    with torch.no_grad():\\n\",\n  
  \"        # scale initial sample\\n\",\n    \"        pos = pos_init * sigmas[-1]\\n\",\n    \"        for t in scheduler.timesteps:\\n\",\n    \"            batch.pos = pos\\n\",\n    \"\\n\",\n    \"            # generate geometry with model, then filter it\\n\",\n    \"            epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\\n\",\n    \"\\n\",\n    \"            # Update\\n\",\n    \"            reconstructed_pos = scheduler.step(epsilon, t, pos)[\\\"prev_sample\\\"].to(DEVICE)\\n\",\n    \"\\n\",\n    \"            pos = reconstructed_pos\\n\",\n    \"\\n\",\n    \"            if torch.isnan(pos).any():\\n\",\n    \"                print(\\\"NaN detected. Please restart.\\\")\\n\",\n    \"                raise FloatingPointError()\\n\",\n    \"\\n\",\n    \"            # recenter graph of positions for next iteration\\n\",\n    \"            pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\\n\",\n    \"\\n\",\n    \"            # optional clipping\\n\",\n    \"            if clip_pos is not None:\\n\",\n    \"                pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\\n\",\n    \"            pos_traj.append(pos.clone().cpu())\\n\",\n    \"\\n\",\n    \"    pos_gen = pos.cpu()\\n\",\n    \"    if save_traj:\\n\",\n    \"        pos_gen_traj = pos_traj.cpu()\\n\",\n    \"        data.pos_gen = torch.stack(pos_gen_traj)\\n\",\n    \"    else:\\n\",\n    \"        data.pos_gen = pos_gen\\n\",\n    \"    results.append(data)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"if save_data:\\n\",\n    \"    save_path = os.path.join(output_dir, \\\"samples_all.pkl\\\")\\n\",\n    \"\\n\",\n    \"    with open(save_path, \\\"wb\\\") as f:\\n\",\n    \"        pickle.dump(results, f)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"fSApwSaZNndW\"\n   },\n   \"source\": [\n    \"## Render the results!\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": 
\"d47Zxo2OKdgZ\"\n   },\n   \"source\": [\n    \"This function allows us to render 3d in colab.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"e9Cd0kCAv9b8\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from google.colab import output\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"output.enable_custom_widget_manager()\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"RjaVuR15NqzF\"\n   },\n   \"source\": [\n    \"### Helper functions\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"28rBYa9NKhlz\"\n   },\n   \"source\": [\n    \"Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"LKdKdwxcyTQ6\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from copy import deepcopy\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def set_rdmol_positions(rdkit_mol, pos):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Args:\\n\",\n    \"        rdkit_mol:  An `rdkit.Chem.rdchem.Mol` object.\\n\",\n    \"        pos: (N_atoms, 3)\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    mol = deepcopy(rdkit_mol)\\n\",\n    \"    set_rdmol_positions_(mol, pos)\\n\",\n    \"    return mol\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def set_rdmol_positions_(mol, pos):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    Args:\\n\",\n    \"        rdkit_mol:  An `rdkit.Chem.rdchem.Mol` object.\\n\",\n    \"        pos: (N_atoms, 3)\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    for i in range(pos.shape[0]):\\n\",\n    \"        mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\\n\",\n    \"    return mol\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"NuE10hcpKmzK\"\n   },\n   \"source\": [\n    \"Process the generated data to make it easy to view.\"\n   ]\n  },\n  
{\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\"\n    },\n    \"id\": \"KieVE1vc0_Vs\",\n    \"outputId\": \"6faa185d-b1bc-47e8-be18-30d1e557e7c8\"\n   },\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"collect 5 generated molecules in `mols`\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# the model can generate multiple conformations per 2d geometry\\n\",\n    \"num_gen = results[0][\\\"pos_gen\\\"].shape[0]\\n\",\n    \"\\n\",\n    \"# init storage objects\\n\",\n    \"mols_gen = []\\n\",\n    \"mols_orig = []\\n\",\n    \"for to_process in results:\\n\",\n    \"    # store the reference 3d position\\n\",\n    \"    to_process[\\\"pos_ref\\\"] = to_process[\\\"pos_ref\\\"].reshape(-1, to_process[\\\"rdmol\\\"].GetNumAtoms(), 3)\\n\",\n    \"\\n\",\n    \"    # store the generated 3d position\\n\",\n    \"    to_process[\\\"pos_gen\\\"] = to_process[\\\"pos_gen\\\"].reshape(-1, to_process[\\\"rdmol\\\"].GetNumAtoms(), 3)\\n\",\n    \"\\n\",\n    \"    # copy data to new object\\n\",\n    \"    new_mol = set_rdmol_positions(to_process.rdmol, to_process[\\\"pos_gen\\\"][0])\\n\",\n    \"\\n\",\n    \"    # append results\\n\",\n    \"    mols_gen.append(new_mol)\\n\",\n    \"    mols_orig.append(to_process.rdmol)\\n\",\n    \"\\n\",\n    \"print(f\\\"collect {len(mols_gen)} generated molecules in `mols`\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"tin89JwMKp4v\"\n   },\n   \"source\": [\n    \"Import tools to visualize the 2d chemical diagram of the molecule.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"yqV6gllSZn38\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from IPython.display import SVG, display\\n\",\n    \"from rdkit import Chem\\n\",\n    \"from rdkit.Chem.Draw 
import rdMolDraw2D as MD2\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"TFNKmGddVoOk\"\n   },\n   \"source\": [\n    \"Select the molecule to visualize.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"KzuwLlrrVaGc\"\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"idx = 0\\n\",\n    \"assert idx < len(results), \\\"idx must select one of the generated molecules\\\"\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"hkb8w0_SNtU8\"\n   },\n   \"source\": [\n    \"### Viewing\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"I3R4QBQeKttN\"\n   },\n   \"source\": [\n    \"This 2D rendering is equivalent to the **input to the model**!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 321\n    },\n    \"id\": \"gkQRWjraaKex\",\n    \"outputId\": \"9c3d1a91-a51d-475d-9e34-2be2459abc47\"\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"image/svg+xml\": [\n       \"<svg baseProfile=\\\"full\\\" height=\\\"300px\\\" version=\\\"1.1\\\" viewBox=\\\"0 0 450 300\\\" width=\\\"450px\\\" xml:space=\\\"preserve\\\" xmlns=\\\"http://www.w3.org/2000/svg\\\" xmlns:rdkit=\\\"http://www.rdkit.org/xml\\\" xmlns:xlink=\\\"http://www.w3.org/1999/xlink\\\">\\n\",\n       \"<!-- END OF HEADER -->\\n\",\n       \"<rect height=\\\"300.0\\\" style=\\\"opacity:1.0;fill:#FFFFFF;stroke:none\\\" width=\\\"450.0\\\" x=\\\"0.0\\\" y=\\\"0.0\\\"> </rect>\\n\",\n       \"<path class=\\\"bond-0 atom-0 atom-1\\\" d=\\\"M 20.5,147.6 L 57.8,136.7\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-1 atom-1 atom-2\\\" d=\\\"M 57.8,136.7 L 67.1,98.9\\\" 
style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-2 atom-2 atom-3\\\" d=\\\"M 67.1,98.9 L 104.4,88.1\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-3 atom-3 atom-4\\\" d=\\\"M 104.4,88.1 L 132.5,115.0\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-4 atom-4 atom-5\\\" d=\\\"M 132.5,115.0 L 128.7,130.5\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-4 atom-4 atom-5\\\" d=\\\"M 128.7,130.5 L 124.9,146.0\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-5 atom-5 atom-6\\\" d=\\\"M 128.7,158.0 L 140.0,168.8\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-5 atom-5 atom-6\\\" d=\\\"M 140.0,168.8 L 151.3,179.7\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-6 atom-6 atom-7\\\" d=\\\"M 155.1,180.6 L 151.3,196.1\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-6 atom-6 atom-7\\\" d=\\\"M 151.3,196.1 L 147.5,211.5\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n    
   \"<path class=\\\"bond-6 atom-6 atom-7\\\" d=\\\"M 147.5,178.8 L 143.7,194.2\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-6 atom-6 atom-7\\\" d=\\\"M 143.7,194.2 L 139.9,209.7\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-7 atom-6 atom-8\\\" d=\\\"M 151.3,179.7 L 188.7,168.8\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-8 atom-8 atom-9\\\" d=\\\"M 188.7,168.8 L 216.7,195.8\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-9 atom-9 atom-10\\\" d=\\\"M 216.7,195.8 L 254.1,184.9\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-10 atom-10 atom-11\\\" d=\\\"M 254.1,184.9 L 257.9,169.4\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-10 atom-10 atom-11\\\" d=\\\"M 257.9,169.4 L 261.7,153.9\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-11 atom-11 atom-12\\\" d=\\\"M 268.8,145.5 L 282.4,141.6\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-11 atom-11 atom-12\\\" d=\\\"M 282.4,141.6 L 295.9,137.7\\\" 
style=\\\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-12 atom-12 atom-13\\\" d=\\\"M 295.0,130.6 L 291.6,118.8\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-12 atom-12 atom-13\\\" d=\\\"M 291.6,118.8 L 288.2,107.0\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-12 atom-12 atom-13\\\" d=\\\"M 302.5,128.4 L 299.1,116.6\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-12 atom-12 atom-13\\\" d=\\\"M 299.1,116.6 L 295.6,104.9\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-13 atom-12 atom-14\\\" d=\\\"M 306.5,142.3 L 309.9,154.0\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-13 atom-12 atom-14\\\" d=\\\"M 309.9,154.0 L 313.3,165.7\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-13 atom-12 atom-14\\\" d=\\\"M 299.0,144.4 L 302.4,156.1\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-13 atom-12 atom-14\\\" d=\\\"M 302.4,156.1 L 305.8,167.9\\\" 
style=\\\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-14 atom-12 atom-15\\\" d=\\\"M 305.5,134.9 L 321.8,130.1\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-14 atom-12 atom-15\\\" d=\\\"M 321.8,130.1 L 338.1,125.4\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-15 atom-15 atom-16\\\" d=\\\"M 338.1,125.4 L 347.4,87.6\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-15 atom-15 atom-16\\\" d=\\\"M 347.0,121.6 L 353.5,95.2\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-16 atom-16 atom-17\\\" d=\\\"M 347.4,87.6 L 384.7,76.8\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-17 atom-17 atom-18\\\" d=\\\"M 384.7,76.8 L 412.8,103.7\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-17 atom-17 atom-18\\\" d=\\\"M 383.5,86.4 L 403.2,105.3\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-18 atom-18 atom-19\\\" d=\\\"M 412.8,103.7 L 403.5,141.5\\\" 
style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-19 atom-19 atom-20\\\" d=\\\"M 403.5,141.5 L 412.1,154.2\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-19 atom-19 atom-20\\\" d=\\\"M 412.1,154.2 L 420.8,166.9\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-19 atom-19 atom-20\\\" d=\\\"M 399.7,149.7 L 405.7,158.6\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-19 atom-19 atom-20\\\" d=\\\"M 405.7,158.6 L 411.7,167.4\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-20 atom-20 atom-21\\\" d=\\\"M 420.1,180.5 L 413.5,189.0\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-20 atom-20 atom-21\\\" d=\\\"M 413.5,189.0 L 406.8,197.5\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-21 atom-21 atom-22\\\" d=\\\"M 395.2,202.1 L 382.8,197.7\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-21 atom-21 atom-22\\\" d=\\\"M 382.8,197.7 L 370.4,193.2\\\" 
style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-22 atom-22 atom-23\\\" d=\\\"M 365.1,184.4 L 365.6,168.4\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-22 atom-22 atom-23\\\" d=\\\"M 365.6,168.4 L 366.2,152.3\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-22 atom-22 atom-23\\\" d=\\\"M 373.1,179.9 L 373.4,168.6\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-22 atom-22 atom-23\\\" d=\\\"M 373.4,168.6 L 373.8,157.4\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-23 atom-11 atom-24\\\" d=\\\"M 257.9,141.9 L 246.6,131.1\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-23 atom-11 atom-24\\\" d=\\\"M 246.6,131.1 L 235.3,120.2\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-24 atom-24 atom-25\\\" d=\\\"M 235.3,120.2 L 197.9,131.1\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-25 atom-5 atom-26\\\" d=\\\"M 117.8,154.4 L 101.8,159.0\\\" 
style=\\\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-25 atom-5 atom-26\\\" d=\\\"M 101.8,159.0 L 85.9,163.6\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-26 atom-26 atom-1\\\" d=\\\"M 85.9,163.6 L 57.8,136.7\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-27 atom-25 atom-8\\\" d=\\\"M 197.9,131.1 L 188.7,168.8\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-28 atom-23 atom-15\\\" d=\\\"M 366.2,152.3 L 338.1,125.4\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"bond-29 atom-23 atom-19\\\" d=\\\"M 366.2,152.3 L 403.5,141.5\\\" style=\\\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\\\"/>\\n\",\n       \"<path class=\\\"atom-5\\\" d=\\\"M 120.8 147.3 L 124.4 153.1 Q 124.8 153.7, 125.3 154.7 Q 125.9 155.8, 126.0 155.8 L 126.0 147.3 L 127.4 147.3 L 127.4 158.3 L 125.9 158.3 L 122.0 151.9 Q 121.6 151.2, 121.1 150.3 Q 120.6 149.4, 120.5 149.2 L 120.5 158.3 L 119.1 158.3 L 119.1 147.3 L 120.8 147.3 \\\" fill=\\\"#0000FF\\\"/>\\n\",\n       \"<path class=\\\"atom-7\\\" d=\\\"M 137.0 217.5 Q 137.0 214.9, 138.3 213.4 Q 139.6 211.9, 142.0 211.9 Q 144.5 211.9, 145.8 213.4 Q 147.1 214.9, 147.1 217.5 Q 147.1 220.2, 145.8 221.7 Q 144.4 223.2, 142.0 223.2 Q 139.6 223.2, 138.3 221.7 Q 137.0 220.2, 137.0 217.5 M 142.0 222.0 Q 143.7 222.0, 144.6 220.8 Q 145.5 219.7, 145.5 217.5 Q 145.5 
215.3, 144.6 214.2 Q 143.7 213.1, 142.0 213.1 Q 140.4 213.1, 139.4 214.2 Q 138.5 215.3, 138.5 217.5 Q 138.5 219.7, 139.4 220.8 Q 140.4 222.0, 142.0 222.0 \\\" fill=\\\"#FF0000\\\"/>\\n\",\n       \"<path class=\\\"atom-11\\\" d=\\\"M 260.9 141.6 L 264.5 147.5 Q 264.9 148.0, 265.5 149.1 Q 266.1 150.1, 266.1 150.2 L 266.1 141.6 L 267.5 141.6 L 267.5 152.6 L 266.0 152.6 L 262.2 146.3 Q 261.7 145.5, 261.2 144.7 Q 260.8 143.8, 260.6 143.5 L 260.6 152.6 L 259.2 152.6 L 259.2 141.6 L 260.9 141.6 \\\" fill=\\\"#0000FF\\\"/>\\n\",\n       \"<path class=\\\"atom-12\\\" d=\\\"M 297.6 140.1 Q 297.7 140.1, 298.2 140.3 Q 298.8 140.5, 299.3 140.7 Q 299.9 140.8, 300.5 140.8 Q 301.5 140.8, 302.1 140.3 Q 302.7 139.8, 302.7 138.9 Q 302.7 138.3, 302.4 137.9 Q 302.1 137.6, 301.6 137.3 Q 301.2 137.1, 300.4 136.9 Q 299.4 136.6, 298.8 136.3 Q 298.2 136.1, 297.8 135.5 Q 297.4 134.9, 297.4 133.9 Q 297.4 132.5, 298.4 131.6 Q 299.3 130.8, 301.2 130.8 Q 302.4 130.8, 303.9 131.4 L 303.5 132.6 Q 302.2 132.0, 301.2 132.0 Q 300.1 132.0, 299.6 132.5 Q 299.0 132.9, 299.0 133.7 Q 299.0 134.3, 299.3 134.6 Q 299.6 135.0, 300.0 135.2 Q 300.5 135.4, 301.2 135.6 Q 302.2 135.9, 302.8 136.3 Q 303.4 136.6, 303.8 137.2 Q 304.3 137.8, 304.3 138.9 Q 304.3 140.4, 303.2 141.3 Q 302.2 142.1, 300.5 142.1 Q 299.5 142.1, 298.8 141.8 Q 298.1 141.6, 297.2 141.3 L 297.6 140.1 \\\" fill=\\\"#CCCC00\\\"/>\\n\",\n       \"<path class=\\\"atom-13\\\" d=\\\"M 284.8 99.0 Q 284.8 96.3, 286.1 94.8 Q 287.4 93.4, 289.9 93.4 Q 292.3 93.4, 293.6 94.8 Q 294.9 96.3, 294.9 99.0 Q 294.9 101.6, 293.6 103.2 Q 292.3 104.7, 289.9 104.7 Q 287.4 104.7, 286.1 103.2 Q 284.8 101.6, 284.8 99.0 M 289.9 103.4 Q 291.5 103.4, 292.5 102.3 Q 293.4 101.2, 293.4 99.0 Q 293.4 96.8, 292.5 95.7 Q 291.5 94.6, 289.9 94.6 Q 288.2 94.6, 287.3 95.7 Q 286.4 96.8, 286.4 99.0 Q 286.4 101.2, 287.3 102.3 Q 288.2 103.4, 289.9 103.4 \\\" fill=\\\"#FF0000\\\"/>\\n\",\n       \"<path class=\\\"atom-14\\\" d=\\\"M 306.5 173.7 Q 306.5 171.0, 307.8 169.5 Q 309.1 168.1, 
311.6 168.1 Q 314.0 168.1, 315.3 169.5 Q 316.6 171.0, 316.6 173.7 Q 316.6 176.3, 315.3 177.9 Q 314.0 179.4, 311.6 179.4 Q 309.1 179.4, 307.8 177.9 Q 306.5 176.4, 306.5 173.7 M 311.6 178.1 Q 313.3 178.1, 314.2 177.0 Q 315.1 175.9, 315.1 173.7 Q 315.1 171.5, 314.2 170.4 Q 313.3 169.3, 311.6 169.3 Q 309.9 169.3, 309.0 170.4 Q 308.1 171.5, 308.1 173.7 Q 308.1 175.9, 309.0 177.0 Q 309.9 178.1, 311.6 178.1 \\\" fill=\\\"#FF0000\\\"/>\\n\",\n       \"<path class=\\\"atom-20\\\" d=\\\"M 422.9 168.2 L 426.5 174.0 Q 426.9 174.6, 427.5 175.6 Q 428.1 176.6, 428.1 176.7 L 428.1 168.2 L 429.5 168.2 L 429.5 179.2 L 428.0 179.2 L 424.2 172.8 Q 423.7 172.0, 423.2 171.2 Q 422.8 170.3, 422.6 170.1 L 422.6 179.2 L 421.2 179.2 L 421.2 168.2 L 422.9 168.2 \\\" fill=\\\"#0000FF\\\"/>\\n\",\n       \"<path class=\\\"atom-21\\\" d=\\\"M 396.5 204.4 Q 396.5 201.8, 397.8 200.3 Q 399.1 198.8, 401.5 198.8 Q 404.0 198.8, 405.3 200.3 Q 406.6 201.8, 406.6 204.4 Q 406.6 207.1, 405.3 208.6 Q 403.9 210.1, 401.5 210.1 Q 399.1 210.1, 397.8 208.6 Q 396.5 207.1, 396.5 204.4 M 401.5 208.9 Q 403.2 208.9, 404.1 207.8 Q 405.0 206.6, 405.0 204.4 Q 405.0 202.3, 404.1 201.2 Q 403.2 200.1, 401.5 200.1 Q 399.8 200.1, 398.9 201.2 Q 398.0 202.2, 398.0 204.4 Q 398.0 206.7, 398.9 207.8 Q 399.8 208.9, 401.5 208.9 \\\" fill=\\\"#FF0000\\\"/>\\n\",\n       \"<path class=\\\"atom-22\\\" d=\\\"M 362.5 185.7 L 366.1 191.5 Q 366.5 192.1, 367.0 193.2 Q 367.6 194.2, 367.6 194.3 L 367.6 185.7 L 369.1 185.7 L 369.1 196.7 L 367.6 196.7 L 363.7 190.4 Q 363.3 189.6, 362.8 188.7 Q 362.3 187.9, 362.2 187.6 L 362.2 196.7 L 360.8 196.7 L 360.8 185.7 L 362.5 185.7 \\\" fill=\\\"#0000FF\\\"/>\\n\",\n       \"</svg>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.SVG object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"mc = Chem.MolFromSmiles(dataset[0][\\\"smiles\\\"])\\n\",\n    \"molSize = (450, 300)\\n\",\n    \"drawer = 
MD2.MolDraw2DSVG(molSize[0], molSize[1])\\n\",\n    \"drawer.DrawMolecule(mc)\\n\",\n    \"drawer.FinishDrawing()\\n\",\n    \"svg = drawer.GetDrawingText()\\n\",\n    \"display(SVG(svg.replace(\\\"svg:\\\", \\\"\\\")))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"id\": \"z4FDMYMxKw2I\"\n   },\n   \"source\": [\n    \"Generate the 3d molecule!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 17,\n     \"referenced_widgets\": [\n      \"695ab5bbf30a4ab19df1f9f33469f314\",\n      \"eac6a8dcdc9d4335a2e51031793ead29\"\n     ]\n    },\n    \"id\": \"aT1Bkb8YxJfV\",\n    \"outputId\": \"b98870ae-049d-4386-b676-166e9526bda2\"\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"695ab5bbf30a4ab19df1f9f33469f314\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": []\n     },\n     \"metadata\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"colab\": {\n        \"custom_widget_manager\": {\n         \"url\": \"https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js\"\n        }\n       }\n      }\n     },\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"from nglview import show_rdkit as show\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"colab\": {\n     \"base_uri\": \"https://localhost:8080/\",\n     \"height\": 337,\n     \"referenced_widgets\": [\n      \"be446195da2b4ff2aec21ec5ff963a54\",\n      \"c6596896148b4a8a9c57963b67c7782f\",\n      \"2489b5e5648541fbbdceadb05632a050\",\n      \"01e0ba4e5da04914b4652b8d58565d7b\",\n      \"c30e6c2f3e2a44dbbb3d63bd519acaa4\",\n      \"f31c6e40e9b2466a9064a2669933ecd5\",\n      
\"19308ccac642498ab8b58462e3f1b0bb\",\n      \"4a081cdc2ec3421ca79dd933b7e2b0c4\",\n      \"e5c0d75eb5e1447abd560c8f2c6017e1\",\n      \"5146907ef6764654ad7d598baebc8b58\",\n      \"144ec959b7604a2cabb5ca46ae5e5379\",\n      \"abce2a80e6304df3899109c6d6cac199\",\n      \"65195cb7a4134f4887e9dd19f3676462\"\n     ]\n    },\n    \"id\": \"pxtq8I-I18C-\",\n    \"outputId\": \"72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7\"\n   },\n   \"outputs\": [\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"be446195da2b4ff2aec21ec5ff963a54\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"NGLWidget()\"\n      ]\n     },\n     \"metadata\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"colab\": {\n        \"custom_widget_manager\": {\n         \"url\": \"https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js\"\n        }\n       }\n      }\n     },\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"# new molecule\\n\",\n    \"show(mols_gen[idx])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"id\": \"KJr4h2mwXeTo\"\n   },\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"accelerator\": \"GPU\",\n  \"colab\": {\n   \"provenance\": []\n  },\n  \"gpuClass\": \"standard\",\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  },\n  \"widgets\": {\n   \"application/vnd.jupyter.widget-state+json\": {\n    
\"01e0ba4e5da04914b4652b8d58565d7b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1\",\n       \"IPY_MODEL_5146907ef6764654ad7d598baebc8b58\"\n      ],\n      \"layout\": \"IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379\"\n     }\n    },\n    \"144ec959b7604a2cabb5ca46ae5e5379\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": 
null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"19308ccac642498ab8b58462e3f1b0bb\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": 
null\n     }\n    },\n    \"1c6246f15b654f4daa11c9bcf997b78c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_bbef741e76ec41b7ab7187b487a383df\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_561f742d418d4721b0670cc8dd62e22c\",\n      \"value\": \"Downloading: 100%\"\n     }\n    },\n    \"2489b5e5648541fbbdceadb05632a050\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ButtonModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ButtonModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ButtonView\",\n      \"button_style\": \"\",\n      \"description\": \"\",\n      \"disabled\": false,\n      \"icon\": \"compress\",\n      \"layout\": \"IPY_MODEL_abce2a80e6304df3899109c6d6cac199\",\n      \"style\": \"IPY_MODEL_65195cb7a4134f4887e9dd19f3676462\",\n      \"tooltip\": \"\"\n     }\n    },\n    \"24d31fc3576e43dd9f8301d2ef3a37ab\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": 
\"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"2918bfaadc8d4b1a9832522c40dfefb8\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      
\"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"2c9362906e4b40189f16d14aa9a348da\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"42f6c3db29d7484ba6b4f73590abd2f4\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      
\"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"4a081cdc2ec3421ca79dd933b7e2b0c4\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"SliderStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"SliderStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\",\n      \"handle_color\": null\n     }\n    },\n    \"5146907ef6764654ad7d598baebc8b58\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"IntSliderModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"IntSliderModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"IntSliderView\",\n      \"continuous_update\": true,\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"disabled\": false,\n      \"layout\": \"IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb\",\n      \"max\": 0,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"readout\": true,\n      \"readout_format\": \"d\",\n      \"step\": 1,\n      \"style\": \"IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4\",\n      \"value\": 0\n     }\n    },\n    \"561f742d418d4721b0670cc8dd62e22c\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": 
\"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"6010fc8daa7a44d5aec4b830ec2ebaa1\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_7e0bb1b8d65249d3974200686b193be2\",\n       \"IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a\",\n       \"IPY_MODEL_6526646be5ed415c84d1245b040e629b\"\n      ],\n      \"layout\": \"IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab\"\n     }\n    },\n    \"65195cb7a4134f4887e9dd19f3676462\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ButtonStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ButtonStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"button_color\": null,\n      \"font_weight\": \"\"\n     }\n    },\n    \"6526646be5ed415c84d1245b040e629b\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": 
\"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_a9e388f22a9742aaaf538e22575c9433\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4\",\n      \"value\": \" 401/401 [00:00&lt;00:00, 13.5kB/s]\"\n     }\n    },\n    \"695ab5bbf30a4ab19df1f9f33469f314\": {\n     \"model_module\": \"nglview-js-widgets\",\n     \"model_module_version\": \"3.0.1\",\n     \"model_name\": \"ColormakerRegistryModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"nglview-js-widgets\",\n      \"_model_module_version\": \"3.0.1\",\n      \"_model_name\": \"ColormakerRegistryModel\",\n      \"_msg_ar\": [],\n      \"_msg_q\": [],\n      \"_ready\": false,\n      \"_view_count\": null,\n      \"_view_module\": \"nglview-js-widgets\",\n      \"_view_module_version\": \"3.0.1\",\n      \"_view_name\": \"ColormakerRegistryView\",\n      \"layout\": \"IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29\"\n     }\n    },\n    \"7e0bb1b8d65249d3974200686b193be2\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8\",\n      \"placeholder\": 
\"​\",\n      \"style\": \"IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08\",\n      \"value\": \"Downloading: 100%\"\n     }\n    },\n    \"872915dd1bb84f538c44e26badabafdd\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"a4bfdca35cc54dae8812720f1b276a08\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": 
\"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"a9e388f22a9742aaaf538e22575c9433\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"abce2a80e6304df3899109c6d6cac199\": 
{\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": \"34px\"\n     }\n    },\n    \"b7feb522161f4cf4b7cc7c1a078ff12d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HTMLModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HTMLModel\",\n      \"_view_count\": null,\n      \"_view_module\": 
\"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HTMLView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_fdc393f3468c432aa0ada05e238a5436\",\n      \"placeholder\": \"​\",\n      \"style\": \"IPY_MODEL_2c9362906e4b40189f16d14aa9a348da\",\n      \"value\": \" 3.27M/3.27M [00:01&lt;00:00, 3.25MB/s]\"\n     }\n    },\n    \"ba98aa6d6a884e4ab8bbb5dfb5e4cf7a\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_e4901541199b45c6a18824627692fc39\",\n      \"max\": 401,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_f915cf874246446595206221e900b2fe\",\n      \"value\": 401\n     }\n    },\n    \"bbef741e76ec41b7ab7187b487a383df\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": 
null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"be446195da2b4ff2aec21ec5ff963a54\": {\n     \"model_module\": \"nglview-js-widgets\",\n     \"model_module_version\": \"3.0.1\",\n     \"model_name\": \"NGLModel\",\n     \"state\": {\n      \"_camera_orientation\": [\n       -15.519693580202304,\n       -14.065056548036177,\n       -23.53197484807691,\n       0,\n       -23.357853515109753,\n       20.94055073042662,\n       2.888695042134944,\n       0,\n       14.352363398292775,\n       18.870825741878015,\n       -20.744689572909344,\n       0,\n       0.2724999189376831,\n       0.6940000057220459,\n       -0.3734999895095825,\n       1\n      ],\n      \"_camera_str\": \"orthographic\",\n      \"_dom_classes\": [],\n      \"_gui_theme\": null,\n      \"_ibtn_fullscreen\": \"IPY_MODEL_2489b5e5648541fbbdceadb05632a050\",\n      \"_igui\": null,\n      \"_iplayer\": \"IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b\",\n      \"_model_module\": \"nglview-js-widgets\",\n      \"_model_module_version\": \"3.0.1\",\n      \"_model_name\": \"NGLModel\",\n      \"_ngl_color_dict\": 
{},\n      \"_ngl_coordinate_resource\": {},\n      \"_ngl_full_stage_parameters\": {\n       \"ambientColor\": 14540253,\n       \"ambientIntensity\": 0.2,\n       \"backgroundColor\": \"white\",\n       \"cameraEyeSep\": 0.3,\n       \"cameraFov\": 40,\n       \"cameraType\": \"perspective\",\n       \"clipDist\": 10,\n       \"clipFar\": 100,\n       \"clipNear\": 0,\n       \"fogFar\": 100,\n       \"fogNear\": 50,\n       \"hoverTimeout\": 0,\n       \"impostor\": true,\n       \"lightColor\": 14540253,\n       \"lightIntensity\": 1,\n       \"mousePreset\": \"default\",\n       \"panSpeed\": 1,\n       \"quality\": \"medium\",\n       \"rotateSpeed\": 2,\n       \"sampleLevel\": 0,\n       \"tooltip\": true,\n       \"workerDefault\": true,\n       \"zoomSpeed\": 1.2\n      },\n      \"_ngl_msg_archive\": [\n       {\n        \"args\": [\n         {\n          \"binary\": false,\n          \"data\": \"HETATM    1  C1  UNL     1      -0.025   3.128   2.316  1.00  0.00           C  \\nHETATM    2  H1  UNL     1       0.183   3.657   2.823  1.00  0.00           H  \\nHETATM    3  C2  UNL     1       0.590   3.559   0.963  1.00  0.00           C  \\nHETATM    4  C3  UNL     1       0.056   4.479   0.406  1.00  0.00           C  \\nHETATM    5  C4  UNL     1      -0.219   4.802  -1.065  1.00  0.00           C  \\nHETATM    6  H2  UNL     1       0.686   4.431  -1.575  1.00  0.00           H  \\nHETATM    7  H3  UNL     1      -0.524   5.217  -1.274  1.00  0.00           H  \\nHETATM    8  C5  UNL     1      -1.284   3.766  -1.342  1.00  0.00           C  \\nHETATM    9  N1  UNL     1      -1.073   2.494  -0.580  1.00  0.00           N  \\nHETATM   10  C6  UNL     1      -1.909   1.494  -0.964  1.00  0.00           C  \\nHETATM   11  O1  UNL     1      -2.487   1.531  -2.092  1.00  0.00           O  \\nHETATM   12  C7  UNL     1      -2.232   0.242  -0.130  1.00  0.00           C  \\nHETATM   13  C8  UNL     1      -2.161  -1.057  -1.037  1.00  0.00           C  
\\nHETATM   14  C9  UNL     1      -0.744  -1.111  -1.610  1.00  0.00           C  \\nHETATM   15  N2  UNL     1       0.290  -0.917  -0.628  1.00  0.00           N  \\nHETATM   16  S1  UNL     1       1.717  -1.597  -0.914  1.00  0.00           S  \\nHETATM   17  O2  UNL     1       1.960  -1.671  -2.338  1.00  0.00           O  \\nHETATM   18  O3  UNL     1       2.713  -0.968  -0.082  1.00  0.00           O  \\nHETATM   19  C10 UNL     1       1.425  -3.170  -0.345  1.00  0.00           C  \\nHETATM   20  C11 UNL     1       1.225  -4.400  -1.271  1.00  0.00           C  \\nHETATM   21  C12 UNL     1       1.314  -5.913  -0.895  1.00  0.00           C  \\nHETATM   22  C13 UNL     1       1.823  -6.229   0.386  1.00  0.00           C  \\nHETATM   23  C14 UNL     1       2.031  -5.110   1.365  1.00  0.00           C  \\nHETATM   24  N3  UNL     1       1.850  -5.267   2.712  1.00  0.00           N  \\nHETATM   25  O4  UNL     1       1.382  -4.029   3.126  1.00  0.00           O  \\nHETATM   26  N4  UNL     1       1.300  -3.023   2.154  1.00  0.00           N  \\nHETATM   27  C15 UNL     1       1.731  -3.672   1.032  1.00  0.00           C  \\nHETATM   28  H4  UNL     1       2.380  -6.874   0.436  1.00  0.00           H  \\nHETATM   29  H5  UNL     1       0.704  -6.526  -1.420  1.00  0.00           H  \\nHETATM   30  H6  UNL     1       1.144  -4.035  -2.291  1.00  0.00           H  \\nHETATM   31  C16 UNL     1       0.044  -0.371   0.685  1.00  0.00           C  \\nHETATM   32  C17 UNL     1      -1.352  -0.045   1.077  1.00  0.00           C  \\nHETATM   33  H7  UNL     1      -1.395   0.770   1.768  1.00  0.00           H  \\nHETATM   34  H8  UNL     1      -1.792  -0.941   1.582  1.00  0.00           H  \\nHETATM   35  H9  UNL     1       0.583  -1.035   1.393  1.00  0.00           H  \\nHETATM   36  H10 UNL     1       0.664   0.613   0.663  1.00  0.00           H  \\nHETATM   37  H11 UNL     1      -0.631  -0.267  -2.335  1.00  0.00           H  
\\nHETATM   38  H12 UNL     1      -0.571  -2.046  -2.098  1.00  0.00           H  \\nHETATM   39  H13 UNL     1      -2.872  -0.992  -1.826  1.00  0.00           H  \\nHETATM   40  H14 UNL     1      -2.370  -1.924  -0.444  1.00  0.00           H  \\nHETATM   41  H15 UNL     1      -3.258   0.364   0.197  1.00  0.00           H  \\nHETATM   42  C18 UNL     1       0.276   2.337  -0.078  1.00  0.00           C  \\nHETATM   43  H16 UNL     1       0.514   1.371   0.252  1.00  0.00           H  \\nHETATM   44  H17 UNL     1       0.988   2.413  -0.949  1.00  0.00           H  \\nHETATM   45  H18 UNL     1      -1.349   3.451  -2.379  1.00  0.00           H  \\nHETATM   46  H19 UNL     1      -2.224   4.055  -0.958  1.00  0.00           H  \\nHETATM   47  H20 UNL     1       0.793   5.486   0.669  1.00  0.00           H  \\nHETATM   48  H21 UNL     1      -0.849   4.974   0.937  1.00  0.00           H  \\nHETATM   49  H22 UNL     1       1.667   3.431   1.070  1.00  0.00           H  \\nHETATM   50  H23 UNL     1       0.379   2.143   2.689  1.00  0.00           H  \\nHETATM   51  H24 UNL     1      -1.094   2.983   2.223  1.00  0.00           H  \\nCONECT    1    2    3   50   51\\nCONECT    3    4   42   49\\nCONECT    4    5   47   48\\nCONECT    5    6    7    8\\nCONECT    8    9   45   46\\nCONECT    9   10   42\\nCONECT   10   11   11   12\\nCONECT   12   13   32   41\\nCONECT   13   14   39   40\\nCONECT   14   15   37   38\\nCONECT   15   16   31\\nCONECT   16   17   17   18   18\\nCONECT   16   19\\nCONECT   19   20   20   27\\nCONECT   20   21   30\\nCONECT   21   22   22   29\\nCONECT   22   23   28\\nCONECT   23   24   24   27\\nCONECT   24   25\\nCONECT   25   26\\nCONECT   26   27   27\\nCONECT   31   32   35   36\\nCONECT   32   33   34\\nCONECT   42   43   44\\nEND\\n\",\n          \"type\": \"blob\"\n         }\n        ],\n        \"kwargs\": {\n         \"defaultRepresentation\": true,\n         \"ext\": \"pdb\"\n        },\n        \"methodName\": 
\"loadFile\",\n        \"reconstruc_color_scheme\": false,\n        \"target\": \"Stage\",\n        \"type\": \"call_method\"\n       }\n      ],\n      \"_ngl_original_stage_parameters\": {\n       \"ambientColor\": 14540253,\n       \"ambientIntensity\": 0.2,\n       \"backgroundColor\": \"white\",\n       \"cameraEyeSep\": 0.3,\n       \"cameraFov\": 40,\n       \"cameraType\": \"perspective\",\n       \"clipDist\": 10,\n       \"clipFar\": 100,\n       \"clipNear\": 0,\n       \"fogFar\": 100,\n       \"fogNear\": 50,\n       \"hoverTimeout\": 0,\n       \"impostor\": true,\n       \"lightColor\": 14540253,\n       \"lightIntensity\": 1,\n       \"mousePreset\": \"default\",\n       \"panSpeed\": 1,\n       \"quality\": \"medium\",\n       \"rotateSpeed\": 2,\n       \"sampleLevel\": 0,\n       \"tooltip\": true,\n       \"workerDefault\": true,\n       \"zoomSpeed\": 1.2\n      },\n      \"_ngl_repr_dict\": {\n       \"0\": {\n        \"0\": {\n         \"params\": {\n          \"aspectRatio\": 1.5,\n          \"assembly\": \"default\",\n          \"bondScale\": 0.3,\n          \"bondSpacing\": 0.75,\n          \"clipCenter\": {\n           \"x\": 0,\n           \"y\": 0,\n           \"z\": 0\n          },\n          \"clipNear\": 0,\n          \"clipRadius\": 0,\n          \"colorMode\": \"hcl\",\n          \"colorReverse\": false,\n          \"colorScale\": \"\",\n          \"colorScheme\": \"element\",\n          \"colorValue\": 9474192,\n          \"cylinderOnly\": false,\n          \"defaultAssembly\": \"\",\n          \"depthWrite\": true,\n          \"diffuse\": 16777215,\n          \"diffuseInterior\": false,\n          \"disableImpostor\": false,\n          \"disablePicking\": false,\n          \"flatShaded\": false,\n          \"interiorColor\": 2236962,\n          \"interiorDarkening\": 0,\n          \"lazy\": false,\n          \"lineOnly\": false,\n          \"linewidth\": 2,\n          \"matrix\": {\n           \"elements\": [\n            1,\n    
        0,\n            0,\n            0,\n            0,\n            1,\n            0,\n            0,\n            0,\n            0,\n            1,\n            0,\n            0,\n            0,\n            0,\n            1\n           ]\n          },\n          \"metalness\": 0,\n          \"multipleBond\": \"off\",\n          \"opacity\": 1,\n          \"openEnded\": true,\n          \"quality\": \"high\",\n          \"radialSegments\": 20,\n          \"radiusData\": {},\n          \"radiusScale\": 2,\n          \"radiusSize\": 0.15,\n          \"radiusType\": \"size\",\n          \"roughness\": 0.4,\n          \"sele\": \"\",\n          \"side\": \"double\",\n          \"sphereDetail\": 2,\n          \"useInteriorColor\": true,\n          \"visible\": true,\n          \"wireframe\": false\n         },\n         \"type\": \"ball+stick\"\n        }\n       },\n       \"1\": {\n        \"0\": {\n         \"params\": {\n          \"aspectRatio\": 1.5,\n          \"assembly\": \"default\",\n          \"bondScale\": 0.3,\n          \"bondSpacing\": 0.75,\n          \"clipCenter\": {\n           \"x\": 0,\n           \"y\": 0,\n           \"z\": 0\n          },\n          \"clipNear\": 0,\n          \"clipRadius\": 0,\n          \"colorMode\": \"hcl\",\n          \"colorReverse\": false,\n          \"colorScale\": \"\",\n          \"colorScheme\": \"element\",\n          \"colorValue\": 9474192,\n          \"cylinderOnly\": false,\n          \"defaultAssembly\": \"\",\n          \"depthWrite\": true,\n          \"diffuse\": 16777215,\n          \"diffuseInterior\": false,\n          \"disableImpostor\": false,\n          \"disablePicking\": false,\n          \"flatShaded\": false,\n          \"interiorColor\": 2236962,\n          \"interiorDarkening\": 0,\n          \"lazy\": false,\n          \"lineOnly\": false,\n          \"linewidth\": 2,\n          \"matrix\": {\n           \"elements\": [\n            1,\n            0,\n            0,\n            0,\n 
           0,\n            1,\n            0,\n            0,\n            0,\n            0,\n            1,\n            0,\n            0,\n            0,\n            0,\n            1\n           ]\n          },\n          \"metalness\": 0,\n          \"multipleBond\": \"off\",\n          \"opacity\": 1,\n          \"openEnded\": true,\n          \"quality\": \"high\",\n          \"radialSegments\": 20,\n          \"radiusData\": {},\n          \"radiusScale\": 2,\n          \"radiusSize\": 0.15,\n          \"radiusType\": \"size\",\n          \"roughness\": 0.4,\n          \"sele\": \"\",\n          \"side\": \"double\",\n          \"sphereDetail\": 2,\n          \"useInteriorColor\": true,\n          \"visible\": true,\n          \"wireframe\": false\n         },\n         \"type\": \"ball+stick\"\n        }\n       }\n      },\n      \"_ngl_serialize\": false,\n      \"_ngl_version\": \"\",\n      \"_ngl_view_id\": [\n       \"FB989FD1-5B9C-446B-8914-6B58AF85446D\"\n      ],\n      \"_player_dict\": {},\n      \"_scene_position\": {},\n      \"_scene_rotation\": {},\n      \"_synced_model_ids\": [],\n      \"_synced_repr_model_ids\": [],\n      \"_view_count\": null,\n      \"_view_height\": \"\",\n      \"_view_module\": \"nglview-js-widgets\",\n      \"_view_module_version\": \"3.0.1\",\n      \"_view_name\": \"NGLView\",\n      \"_view_width\": \"\",\n      \"background\": \"white\",\n      \"frame\": 0,\n      \"gui_style\": null,\n      \"layout\": \"IPY_MODEL_c6596896148b4a8a9c57963b67c7782f\",\n      \"max_frame\": 0,\n      \"n_components\": 2,\n      \"picked\": {}\n     }\n    },\n    \"c2321b3bff6f490ca12040a20308f555\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"FloatProgressModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": 
\"FloatProgressModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"ProgressView\",\n      \"bar_style\": \"success\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"layout\": \"IPY_MODEL_872915dd1bb84f538c44e26badabafdd\",\n      \"max\": 3271865,\n      \"min\": 0,\n      \"orientation\": \"horizontal\",\n      \"style\": \"IPY_MODEL_d022575f1fa2446d891650897f187b4d\",\n      \"value\": 3271865\n     }\n    },\n    \"c30e6c2f3e2a44dbbb3d63bd519acaa4\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": 
null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"c6596896148b4a8a9c57963b67c7782f\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"d022575f1fa2446d891650897f187b4d\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n   
   \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"d90f304e9560472eacfbdd11e46765eb\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"HBoxModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"HBoxModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"HBoxView\",\n      \"box_style\": \"\",\n      \"children\": [\n       \"IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c\",\n       \"IPY_MODEL_c2321b3bff6f490ca12040a20308f555\",\n       \"IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d\"\n      ],\n      \"layout\": \"IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f\"\n     }\n    },\n    \"e2d368556e494ae7ae4e2e992af2cd4f\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      
\"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"e4901541199b45c6a18824627692fc39\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n 
     \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"e5c0d75eb5e1447abd560c8f2c6017e1\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"PlayModel\",\n     \"state\": {\n      \"_dom_classes\": [],\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"PlayModel\",\n      \"_playing\": false,\n      \"_repeat\": false,\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/controls\",\n      \"_view_module_version\": \"1.5.0\",\n      \"_view_name\": \"PlayView\",\n      \"description\": \"\",\n      \"description_tooltip\": null,\n      \"disabled\": false,\n      \"interval\": 100,\n      \"layout\": \"IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4\",\n      \"max\": 0,\n      \"min\": 0,\n      \"show_repeat\": true,\n      \"step\": 1,\n      \"style\": \"IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5\",\n      \"value\": 0\n     }\n    },\n    \"eac6a8dcdc9d4335a2e51031793ead29\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": 
null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    },\n    \"f31c6e40e9b2466a9064a2669933ecd5\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"DescriptionStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"DescriptionStyleModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"description_width\": \"\"\n     }\n    },\n    \"f915cf874246446595206221e900b2fe\": {\n     \"model_module\": \"@jupyter-widgets/controls\",\n     \"model_module_version\": \"1.5.0\",\n     \"model_name\": \"ProgressStyleModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/controls\",\n      \"_model_module_version\": \"1.5.0\",\n      \"_model_name\": \"ProgressStyleModel\",\n      
\"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"StyleView\",\n      \"bar_color\": null,\n      \"description_width\": \"\"\n     }\n    },\n    \"fdc393f3468c432aa0ada05e238a5436\": {\n     \"model_module\": \"@jupyter-widgets/base\",\n     \"model_module_version\": \"1.2.0\",\n     \"model_name\": \"LayoutModel\",\n     \"state\": {\n      \"_model_module\": \"@jupyter-widgets/base\",\n      \"_model_module_version\": \"1.2.0\",\n      \"_model_name\": \"LayoutModel\",\n      \"_view_count\": null,\n      \"_view_module\": \"@jupyter-widgets/base\",\n      \"_view_module_version\": \"1.2.0\",\n      \"_view_name\": \"LayoutView\",\n      \"align_content\": null,\n      \"align_items\": null,\n      \"align_self\": null,\n      \"border\": null,\n      \"bottom\": null,\n      \"display\": null,\n      \"flex\": null,\n      \"flex_flow\": null,\n      \"grid_area\": null,\n      \"grid_auto_columns\": null,\n      \"grid_auto_flow\": null,\n      \"grid_auto_rows\": null,\n      \"grid_column\": null,\n      \"grid_gap\": null,\n      \"grid_row\": null,\n      \"grid_template_areas\": null,\n      \"grid_template_columns\": null,\n      \"grid_template_rows\": null,\n      \"height\": null,\n      \"justify_content\": null,\n      \"justify_items\": null,\n      \"left\": null,\n      \"margin\": null,\n      \"max_height\": null,\n      \"max_width\": null,\n      \"min_height\": null,\n      \"min_width\": null,\n      \"object_fit\": null,\n      \"object_position\": null,\n      \"order\": null,\n      \"overflow\": null,\n      \"overflow_x\": null,\n      \"overflow_y\": null,\n      \"padding\": null,\n      \"right\": null,\n      \"top\": null,\n      \"visibility\": null,\n      \"width\": null\n     }\n    }\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/research_projects/gligen/README.md",
    "content": "# GLIGEN: Open-Set Grounded Text-to-Image Generation\n\nThese scripts contain the code to prepare the grounding data and train the GLIGEN model on the COCO dataset.\n\n### Install the requirements\n\n```bash\nconda create -n diffusers python==3.10\nconda activate diffusers\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr, for a default Accelerate configuration without answering questions about your environment:\n\n```bash\naccelerate config default\n```\n\nOr, if your environment doesn't support an interactive shell (e.g., a notebook):\n\n```python\nfrom accelerate.utils import write_basic_config\n\nwrite_basic_config()\n```\n\n### Prepare the training data\n\nIf you want to make your own grounding data, you need to install a few additional requirements.\n\nI used [RAM](https://github.com/xinyu1205/recognize-anything) to tag\nimages, [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO/issues?q=refer) to detect objects,\nand [BLIP2](https://huggingface.co/docs/transformers/en/model_doc/blip-2) to caption instances.\n\nOnly RAM needs to be installed manually:\n\n```bash\npip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps\n```\n\nDownload the pre-trained models:\n\n```bash\nhf download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth\nhf download --resume-download IDEA-Research/grounding-dino-base\nhf download --resume-download Salesforce/blip2-flan-t5-xxl\nhf download --resume-download openai/clip-vit-large-patch14\nhf download --resume-download masterful/gligen-1-4-generation-text-box\n```\n\nMake the training data on 8 GPUs:\n\n```bash\ntorchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \\\n    --data_root /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \\\n    --save_root /root/gligen_data \\\n    --ram_checkpoint 
/root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth\n```\n\nYou can download the prepared COCO training data with:\n\n```bash\nhf download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth\n```\n\nIt has the following format:\n\n```json\n[\n  ...\n  {\n    'file_path': Path,\n    'annos': [\n      {\n        'caption': Instance Caption,\n        'bbox': bbox in xyxy,\n        'text_embeddings_before_projection': CLIP text embedding before linear projection\n      }\n    ]\n  }\n  ...\n]\n```\n\n### Training commands\n\nThe training script is heavily based\non https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py\n\n```bash\naccelerate launch train_gligen_text.py \\\n    --data_path /root/data/zhizhonghuang/coco_train2017.pth \\\n    --image_path /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \\\n    --train_batch_size 8 \\\n    --max_train_steps 100000 \\\n    --checkpointing_steps 1000 \\\n    --checkpoints_total_limit 10 \\\n    --learning_rate 5e-5 \\\n    --dataloader_num_workers 16 \\\n    --mixed_precision fp16 \\\n    --report_to wandb \\\n    --tracker_project_name gligen \\\n    --output_dir /root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\n```\n\nI trained the model on 8 A100 GPUs for about 11 hours (at least 24GB GPU memory). 
The generated images may start to follow the\nlayout at around 50k iterations.\n\nNote that although the pre-trained GLIGEN model has been loaded, the parameters of `fuser` and `position_net` have been reset (see line 420 in `train_gligen_text.py`).\n\nThe trained model can be downloaded with:\n\n```bash\nhf download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors\n```\n\nYou can run `demo.ipynb` to visualize the generated images.\n\nExample prompts:\n\n```python\nprompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\nboxes = [[0.041015625, 0.548828125, 0.453125, 0.859375],\n         [0.525390625, 0.552734375, 0.93359375, 0.865234375],\n         [0.12890625, 0.015625, 0.412109375, 0.279296875],\n         [0.578125, 0.08203125, 0.857421875, 0.27734375]]\ngligen_phrases = ['a green car', 'a blue truck', 'a red air balloon', 'a bird']\n```\n\nExample images:\n![alt text](generated-images-100000-00.png)\n\n### Citation\n\n```\n@article{li2023gligen,\n  title={GLIGEN: Open-Set Grounded Text-to-Image Generation},\n  author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae},\n  journal={CVPR},\n  year={2023}\n}\n```"
  },
  {
    "path": "examples/research_projects/gligen/dataset.py",
    "content": "import os\nimport random\n\nimport torch\nimport torchvision.transforms as transforms\nfrom PIL import Image\n\n\ndef recalculate_box_and_verify_if_valid(x, y, w, h, image_size, original_image_size, min_box_size):\n    scale = image_size / min(original_image_size)\n    crop_y = (original_image_size[1] * scale - image_size) // 2\n    crop_x = (original_image_size[0] * scale - image_size) // 2\n    x0 = max(x * scale - crop_x, 0)\n    y0 = max(y * scale - crop_y, 0)\n    x1 = min((x + w) * scale - crop_x, image_size)\n    y1 = min((y + h) * scale - crop_y, image_size)\n    if (x1 - x0) * (y1 - y0) / (image_size * image_size) < min_box_size:\n        return False, (None, None, None, None)\n    return True, (x0, y0, x1, y1)\n\n\nclass COCODataset(torch.utils.data.Dataset):\n    def __init__(\n        self,\n        data_path,\n        image_path,\n        image_size=512,\n        min_box_size=0.01,\n        max_boxes_per_data=8,\n        tokenizer=None,\n    ):\n        super().__init__()\n        self.min_box_size = min_box_size\n        self.max_boxes_per_data = max_boxes_per_data\n        self.image_size = image_size\n        self.image_path = image_path\n        self.tokenizer = tokenizer\n        self.transforms = transforms.Compose(\n            [\n                transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(image_size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n        self.data_list = torch.load(data_path, map_location=\"cpu\")\n\n    def __getitem__(self, index):\n        if self.max_boxes_per_data > 99:\n            assert False, \"Are you sure setting such large number of boxes per image?\"\n\n        out = {}\n\n        data = self.data_list[index]\n        image = Image.open(os.path.join(self.image_path, data[\"file_path\"])).convert(\"RGB\")\n        original_image_size = image.size\n 
       out[\"pixel_values\"] = self.transforms(image)\n\n        annos = data[\"annos\"]\n\n        areas, valid_annos = [], []\n        for anno in annos:\n            # x, y, w, h = anno['bbox']\n            x0, y0, x1, y1 = anno[\"bbox\"]\n            x, y, w, h = x0, y0, x1 - x0, y1 - y0\n            valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(\n                x, y, w, h, self.image_size, original_image_size, self.min_box_size\n            )\n            if valid:\n                anno[\"bbox\"] = [x0, y0, x1, y1]\n                areas.append((x1 - x0) * (y1 - y0))\n                valid_annos.append(anno)\n\n        # Sort according to area and choose the largest N objects\n        wanted_idxs = torch.tensor(areas).sort(descending=True)[1]\n        wanted_idxs = wanted_idxs[: self.max_boxes_per_data]\n        valid_annos = [valid_annos[i] for i in wanted_idxs]\n\n        out[\"boxes\"] = torch.zeros(self.max_boxes_per_data, 4)\n        out[\"masks\"] = torch.zeros(self.max_boxes_per_data)\n        out[\"text_embeddings_before_projection\"] = torch.zeros(self.max_boxes_per_data, 768)\n\n        for i, anno in enumerate(valid_annos):\n            out[\"boxes\"][i] = torch.tensor(anno[\"bbox\"]) / self.image_size\n            out[\"masks\"][i] = 1\n            out[\"text_embeddings_before_projection\"][i] = anno[\"text_embeddings_before_projection\"]\n\n        prob_drop_boxes = 0.1\n        if random.random() < prob_drop_boxes:\n            out[\"masks\"][:] = 0\n\n        caption = random.choice(data[\"captions\"])\n\n        prob_drop_captions = 0.5\n        if random.random() < prob_drop_captions:\n            caption = \"\"\n        caption = self.tokenizer(\n            caption,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        out[\"caption\"] = caption\n\n        return out\n\n    def __len__(self):\n        
return len(self.data_list)\n"
  },
  {
    "path": "examples/research_projects/gligen/demo.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"The autoreload extension is already loaded. To reload it, use:\\n\",\n      \"  %reload_ext autoreload\\n\"\n     ]\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\\n\",\n      \"  from .autonotebook import tqdm as notebook_tqdm\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"%load_ext autoreload\\n\",\n    \"%autoreload 2\\n\",\n    \"\\n\",\n    \"from diffusers import StableDiffusionGLIGENPipeline\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import CLIPTextModel, CLIPTokenizer\\n\",\n    \"\\n\",\n    \"import diffusers\\n\",\n    \"from diffusers import (\\n\",\n    \"    AutoencoderKL,\\n\",\n    \"    DDPMScheduler,\\n\",\n    \"    EulerDiscreteScheduler,\\n\",\n    \"    UNet2DConditionModel,\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\\n\",\n    \"\\n\",\n    \"pretrained_model_name_or_path = \\\"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\\\"\\n\",\n    \"\\n\",\n    \"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\\\"tokenizer\\\")\\n\",\n    \"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\\\"scheduler\\\")\\n\",\n    \"text_encoder = 
CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\\\"text_encoder\\\")\\n\",\n    \"vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\\\"vae\\\")\\n\",\n    \"# unet = UNet2DConditionModel.from_pretrained(\\n\",\n    \"#     pretrained_model_name_or_path, subfolder=\\\"unet\\\"\\n\",\n    \"# )\\n\",\n    \"\\n\",\n    \"noise_scheduler = EulerDiscreteScheduler.from_config(noise_scheduler.config)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"unet = UNet2DConditionModel.from_pretrained(\\\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion_gligen.pipeline_stable_diffusion_gligen.StableDiffusionGLIGENPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. 
For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"pipe = StableDiffusionGLIGENPipeline(\\n\",\n    \"    vae,\\n\",\n    \"    text_encoder,\\n\",\n    \"    tokenizer,\\n\",\n    \"    unet,\\n\",\n    \"    noise_scheduler,\\n\",\n    \"    safety_checker=None,\\n\",\n    \"    feature_extractor=None,\\n\",\n    \")\\n\",\n    \"pipe = pipe.to(\\\"cuda\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\\n\",\n    \"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\\n\",\n    \"\\n\",\n    \"# prompt = 'A realistic top-down view of a wooden table with two apples on it'\\n\",\n    \"# gen_boxes = [('a wooden table', [20, 148, 472, 216]), ('an apple', [150, 226, 100, 100]), ('an apple', [280, 226, 100, 100])]\\n\",\n    \"\\n\",\n    \"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\\n\",\n    \"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\\n\",\n    \"\\n\",\n    \"prompt = \\\"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\\\"\\n\",\n    \"gen_boxes = [(\\\"a steam boat\\\", [232, 225, 257, 149]), (\\\"a jumping pink dolphin\\\", [21, 249, 189, 123])]\\n\",\n    \"\\n\",\n    \"boxes = np.array([x[1] for x in gen_boxes])\\n\",\n    \"boxes = boxes / 512\\n\",\n    \"boxes[:, 2] = boxes[:, 0] + boxes[:, 2]\\n\",\n    \"boxes[:, 3] = boxes[:, 1] + boxes[:, 
3]\\n\",\n    \"boxes = boxes.tolist()\\n\",\n    \"gligen_phrases = [x[0] for x in gen_boxes]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py:683: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.\\n\",\n      \"  num_channels_latents = self.unet.in_channels\\n\",\n      \"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py:716: FutureWarning: Accessing config attribute `cross_attention_dim` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'cross_attention_dim' over 'UNet2DConditionModel's config object instead, e.g. 
'unet.config.cross_attention_dim'.\\n\",\n      \"  max_objs, self.unet.cross_attention_dim, device=device, dtype=self.text_encoder.dtype\\n\",\n      \"100%|██████████| 50/50 [01:21<00:00,  1.64s/it]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"images = pipe(\\n\",\n    \"    prompt=prompt,\\n\",\n    \"    gligen_phrases=gligen_phrases,\\n\",\n    \"    gligen_boxes=boxes,\\n\",\n    \"    gligen_scheduled_sampling_beta=1.0,\\n\",\n    \"    output_type=\\\"pil\\\",\\n\",\n    \"    num_inference_steps=50,\\n\",\n    \"    negative_prompt=\\\"artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate\\\",\\n\",\n    \"    num_images_per_prompt=16,\\n\",\n    \").images\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"diffusers.utils.make_image_grid(images, 4, len(images) // 4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.11.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "examples/research_projects/gligen/make_datasets.py",
    "content": "import argparse\nimport os\nimport random\n\nimport torch\nimport torchvision\nimport torchvision.transforms as TS\nfrom PIL import Image\nfrom ram import inference_ram\nfrom ram.models import ram\nfrom tqdm import tqdm\nfrom transformers import (\n    AutoModelForZeroShotObjectDetection,\n    AutoProcessor,\n    Blip2ForConditionalGeneration,\n    Blip2Processor,\n    CLIPTextModel,\n    CLIPTokenizer,\n)\n\n\ntorch.autograd.set_grad_enabled(False)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"Caption Generation script\", add_help=False)\n    parser.add_argument(\"--data_root\", type=str, required=True, help=\"path to COCO\")\n    parser.add_argument(\"--save_root\", type=str, required=True, help=\"path to save\")\n    parser.add_argument(\"--ram_checkpoint\", type=str, required=True, help=\"path to the RAM checkpoint\")\n    args = parser.parse_args()\n\n    # ram_checkpoint = '/root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth'\n    # data_root = '/mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017'\n    # save_root = '/root/gligen_data'\n    box_threshold = 0.25\n    text_threshold = 0.2\n\n    import torch.distributed as dist\n\n    dist.init_process_group(backend=\"nccl\", init_method=\"env://\")\n    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()\n    device = f\"cuda:{local_rank}\"\n    torch.cuda.set_device(local_rank)\n\n    ram_model = ram(pretrained=args.ram_checkpoint, image_size=384, vit=\"swin_l\").cuda().eval()\n    ram_processor = TS.Compose(\n        [TS.Resize((384, 384)), TS.ToTensor(), TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]\n    )\n\n    grounding_dino_processor = AutoProcessor.from_pretrained(\"IDEA-Research/grounding-dino-base\")\n    grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(\n        \"IDEA-Research/grounding-dino-base\"\n   
 ).cuda()\n\n    blip2_processor = Blip2Processor.from_pretrained(\"Salesforce/blip2-flan-t5-xxl\")\n    blip2_model = Blip2ForConditionalGeneration.from_pretrained(\n        \"Salesforce/blip2-flan-t5-xxl\", torch_dtype=torch.float16\n    ).cuda()\n\n    clip_text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\").cuda()\n    clip_tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\n\n    image_paths = [os.path.join(args.data_root, x) for x in os.listdir(args.data_root)]\n    random.shuffle(image_paths)\n\n    for image_path in tqdm(image_paths):\n        pth_path = os.path.join(args.save_root, os.path.basename(image_path))\n        if os.path.exists(pth_path):\n            continue\n\n        sample = {\"file_path\": os.path.basename(image_path), \"annos\": []}\n\n        raw_image = Image.open(image_path).convert(\"RGB\")\n\n        res = inference_ram(ram_processor(raw_image).unsqueeze(0).cuda(), ram_model)\n\n        text = res[0].replace(\" |\", \".\")\n\n        inputs = grounding_dino_processor(images=raw_image, text=text, return_tensors=\"pt\")\n        inputs = {k: v.cuda() for k, v in inputs.items()}\n        outputs = grounding_dino_model(**inputs)\n\n        results = grounding_dino_processor.post_process_grounded_object_detection(\n            outputs,\n            inputs[\"input_ids\"],\n            box_threshold=box_threshold,\n            text_threshold=text_threshold,\n            target_sizes=[raw_image.size[::-1]],\n        )\n        boxes = results[0][\"boxes\"]\n        labels = results[0][\"labels\"]\n        scores = results[0][\"scores\"]\n        indices = torchvision.ops.nms(boxes, scores, 0.5)\n        boxes = boxes[indices]\n        category_names = [labels[i] for i in indices]\n\n        for i, bbox in enumerate(boxes):\n            bbox = bbox.tolist()\n            inputs = blip2_processor(images=raw_image.crop(bbox), return_tensors=\"pt\")\n            inputs = {k: 
v.cuda().to(torch.float16) for k, v in inputs.items()}\n            outputs = blip2_model.generate(**inputs)\n            caption = blip2_processor.decode(outputs[0], skip_special_tokens=True)\n            inputs = clip_tokenizer(\n                caption,\n                padding=\"max_length\",\n                max_length=clip_tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            inputs = {k: v.cuda() for k, v in inputs.items()}\n            text_embeddings_before_projection = clip_text_encoder(**inputs).pooler_output.squeeze(0)\n\n            sample[\"annos\"].append(\n                {\n                    \"caption\": caption,\n                    \"bbox\": bbox,\n                    \"text_embeddings_before_projection\": text_embeddings_before_projection,\n                }\n            )\n        torch.save(sample, pth_path)\n"
  },
  {
    "path": "examples/research_projects/gligen/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\ndiffusers\nscipy\ntimm\nfairscale\nwandb"
  },
  {
    "path": "examples/research_projects/gligen/train_gligen_text.py",
    "content": "# from accelerate.utils import write_basic_config\n#\n# write_basic_config()\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom packaging import version\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    EulerDiscreteScheduler,\n    StableDiffusionGLIGENPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import is_wandb_available, make_image_grid\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    pass\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\n# check_min_version(\"0.28.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\n@torch.no_grad()\ndef log_validation(vae, text_encoder, tokenizer, unet, noise_scheduler, args, accelerator, step, weight_dtype):\n    if accelerator.is_main_process:\n        print(\"generate test images...\")\n    unet = accelerator.unwrap_model(unet)\n    vae.to(accelerator.device, dtype=torch.float32)\n\n    pipeline = StableDiffusionGLIGENPipeline(\n        vae,\n        text_encoder,\n        tokenizer,\n        unet,\n        EulerDiscreteScheduler.from_config(noise_scheduler.config),\n        safety_checker=None,\n        feature_extractor=None,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=not accelerator.is_main_process)\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    prompt = \"A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky\"\n    boxes = [\n        [0.041015625, 0.548828125, 0.453125, 0.859375],\n        [0.525390625, 0.552734375, 0.93359375, 0.865234375],\n        [0.12890625, 0.015625, 0.412109375, 0.279296875],\n        [0.578125, 0.08203125, 0.857421875, 0.27734375],\n    ]\n    gligen_phrases = [\"a green car\", \"a blue truck\", \"a red air balloon\", \"a bird\"]\n    images = pipeline(\n        prompt=prompt,\n        gligen_phrases=gligen_phrases,\n        gligen_boxes=boxes,\n        gligen_scheduled_sampling_beta=1.0,\n        output_type=\"pil\",\n        num_inference_steps=50,\n        negative_prompt=\"artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate\",\n        num_images_per_prompt=4,\n        
generator=generator,\n    ).images\n    os.makedirs(os.path.join(args.output_dir, \"images\"), exist_ok=True)\n    make_image_grid(images, 1, 4).save(\n        os.path.join(args.output_dir, \"images\", f\"generated-images-{step:06d}-{accelerator.process_index:02d}.png\")\n    )\n\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n    parser.add_argument(\n        \"--data_path\",\n        type=str,\n        default=\"coco_train2017.pth\",\n        help=\"Path to training dataset.\",\n    )\n    parser.add_argument(\n        \"--image_path\",\n        type=str,\n        default=\"coco_train2017.pth\",\n        help=\"Path to training images.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"controlnet-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. \"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. \"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step \"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  
Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"train_controlnet\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers. For\"\n            \" more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    # Parse the supplied argument list; `None` falls back to sys.argv.\n    args = parser.parse_args(input_args)\n    return args\n\n\ndef main(args):\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y 
%H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # import correct text encoder class\n    # text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n    # Load scheduler and models\n    from transformers import CLIPTextModel, CLIPTokenizer\n\n    pretrained_model_name_or_path = \"masterful/gligen-1-4-generation-text-box\"\n    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n\n    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"unet\")\n\n    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom 
saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                i = len(weights) - 1\n\n                while len(weights) > 0:\n                    weights.pop()\n                    model = models[i]\n\n                    sub_dir = \"unet\"\n                    model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                    i -= 1\n\n        def load_model_hook(models, input_dir):\n            while len(models) > 0:\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = unet.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n            # controlnet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    # if args.gradient_checkpointing:\n    #     controlnet.enable_gradient_checkpointing()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(f\"UNet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    optimizer_class = torch.optim.AdamW\n    # Optimizer creation\n    for n, m in unet.named_modules():\n        if (\"fuser\" in n) or (\"position_net\" in n):\n            import torch.nn as nn\n\n            if isinstance(m, (nn.Linear, nn.LayerNorm)):\n                m.reset_parameters()\n    params_to_optimize = []\n    for n, p in unet.named_parameters():\n        if (\"fuser\" in n) or (\"position_net\" in n):\n            p.requires_grad = True\n            params_to_optimize.append(p)\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    from dataset import COCODataset\n\n    train_dataset = COCODataset(\n        data_path=args.data_path,\n        image_path=args.image_path,\n        tokenizer=tokenizer,\n        image_size=args.resolution,\n     
   max_boxes_per_data=30,\n    )\n\n    print(\"num samples: \", len(train_dataset))\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        # collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae, unet and text_encoder to device and cast to weight_dtype\n    vae.to(accelerator.device, dtype=weight_dtype)\n    # unet.to(accelerator.device, dtype=weight_dtype)\n    unet.to(accelerator.device, dtype=torch.float32)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training 
dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        # tracker_config.pop(\"validation_prompt\")\n        # tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    # logger.info(\"***** Running training *****\")\n    # logger.info(f\"  Num examples = {len(train_dataset)}\")\n    # logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    # logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    # logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    # logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    # logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    # logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    log_validation(\n        vae,\n        text_encoder,\n        tokenizer,\n        unet,\n        noise_scheduler,\n        args,\n        accelerator,\n        global_step,\n        weight_dtype,\n    )\n\n    # image_logs = None\n    for epoch in range(first_epoch, 
args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                with torch.no_grad():\n                    # Get the text embedding for conditioning\n                    encoder_hidden_states = text_encoder(\n                        batch[\"caption\"][\"input_ids\"].squeeze(1),\n                        # batch['caption']['attention_mask'].squeeze(1),\n                        return_dict=False,\n                    )[0]\n\n                cross_attention_kwargs = {}\n                cross_attention_kwargs[\"gligen\"] = {\n                    \"boxes\": batch[\"boxes\"],\n                    \"positive_embeddings\": batch[\"text_embeddings_before_projection\"],\n                    \"masks\": batch[\"masks\"],\n                }\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    return_dict=False,\n            
    )[0]\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - 
args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step:06d}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    # if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                    log_validation(\n                        vae,\n                        text_encoder,\n                        tokenizer,\n                        unet,\n                        noise_scheduler,\n                        args,\n                        accelerator,\n                        global_step,\n                        weight_dtype,\n                    )\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        unet.save_pretrained(args.output_dir)\n    #\n    #     # Run a final round of validation.\n    #     
image_logs = None\n    #     if args.validation_prompt is not None:\n    #         image_logs = log_validation(\n    #             vae=vae,\n    #             text_encoder=text_encoder,\n    #             tokenizer=tokenizer,\n    #             unet=unet,\n    #             controlnet=None,\n    #             args=args,\n    #             accelerator=accelerator,\n    #             weight_dtype=weight_dtype,\n    #             step=global_step,\n    #             is_final_validation=True,\n    #         )\n    #\n    #     if args.push_to_hub:\n    #         save_model_card(\n    #             repo_id,\n    #             image_logs=image_logs,\n    #             base_model=args.pretrained_model_name_or_path,\n    #             repo_folder=args.output_dir,\n    #         )\n    #         upload_folder(\n    #             repo_id=repo_id,\n    #             folder_path=args.output_dir,\n    #             commit_message=\"End of training\",\n    #             ignore_patterns=[\"step_*\", \"epoch_*\"],\n    #         )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/instructpix2pix_lora/README.md",
    "content": "# InstructPix2Pix text-to-edit-image fine-tuning\nThis extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).\nThis is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support add LoRA layers for unet model.\n\n## Running locally with PyTorch\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd in the example folder  and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n\n## Training script example\n\n```bash\nexport MODEL_ID=\"timbrooks/instruct-pix2pix\"\nexport DATASET_ID=\"instruction-tuning-sd/cartoonization\"\nexport OUTPUT_DIR=\"instructPix2Pix-cartoonization\"\n\naccelerate launch train_instruct_pix2pix_lora.py \\\n  --pretrained_model_name_or_path=$MODEL_ID \\\n  --dataset_name=$DATASET_ID \\\n  --enable_xformers_memory_efficient_attention \\\n  --resolution=256 --random_flip \\\n  --train_batch_size=2 --gradient_accumulation_steps=4 --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --checkpointing_steps=5000 --checkpoints_total_limit=1 \\\n  --learning_rate=5e-05 --lr_warmup_steps=0 \\\n 
 --val_image_url=\"https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png\" \\\n  --validation_prompt=\"Generate a cartoonized version of the natural image\" \\\n  --seed=42 \\\n  --rank=4 \\\n  --output_dir=$OUTPUT_DIR \\\n  --report_to=wandb \\\n  --push_to_hub \\\n  --original_image_column=\"original_image\" \\\n  --edited_image_column=\"cartoonized_image\" \\\n  --edit_prompt_column=\"edit_prompt\"\n```\n\n## Inference\nAfter training, the LoRA weights of the model are stored in `$OUTPUT_DIR`.\n\n```py\nfrom PIL import Image\n\nfrom diffusers import StableDiffusionInstructPix2PixPipeline\n\n# load the base model pipeline\npipe_lora = StableDiffusionInstructPix2PixPipeline.from_pretrained(\"timbrooks/instruct-pix2pix\")\n\n# Load LoRA weights from the provided path\noutput_dir = \"path/to/lora_weight_directory\"\npipe_lora.unet.load_attn_procs(output_dir)\n\n# Edit instruction; this example reuses the validation prompt from the training command above\nedit_prompt = \"Generate a cartoonized version of the natural image\"\n\ninput_image_path = \"/path/to/input_image\"\ninput_image = Image.open(input_image_path)\nedited_images = pipe_lora(num_images_per_prompt=1, prompt=edit_prompt, image=input_image, num_inference_steps=1000).images\nedited_images[0].show()\n```\n\n## Results\n\nHere is an example of using the script to train an InstructPix2Pix model.\nIt was trained on a Google Colab T4 GPU.\n\n```bash\nMODEL_ID=\"timbrooks/instruct-pix2pix\"\nDATASET_ID=\"instruction-tuning-sd/cartoonization\"\nTRAIN_EPOCHS=100\n```\n\nBelow are a few examples showing the input image, the edit_prompt, and the edited_image (the output of the model):\n\n<p align=\"center\">\n    <img src=\"https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-/blob/main/diffusers_result_assets/edited_image_results.png?raw=true\" alt=\"instructpix2pix-inputs\" width=600/>\n</p>\n\n\nHere are some rough statistics from training the model with this script:\n\n<p align=\"center\">\n    <img src=\"https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-/blob/main/diffusers_result_assets/results.png?raw=true\" 
alt=\"instructpix2pix-inputs\" width=600/>\n</p>\n\n## References\n\n* InstructPix2Pix - https://github.com/timothybrooks/instruct-pix2pix\n* Dataset and example training script - https://huggingface.co/blog/instruction-tuning-sd\n* For more information about the project - https://github.com/Aiden-Frost/Efficiently-teaching-counting-and-cartoonization-to-InstructPix2Pix.-"
  },
  {
    "path": "examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nScript to fine-tune Stable Diffusion for LORA InstructPix2Pix.\nBase code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py\n\"\"\"\n\nimport argparse\nimport logging\nimport math\nimport os\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport PIL\nimport requests\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, cast_training_params\nfrom diffusers.utils import check_min_version, 
convert_state_dict_to_diffusers, deprecate, is_wandb_available\nfrom diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.32.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"fusing/instructpix2pix-1000-samples\": (\"input_image\", \"edit_prompt\", \"edited_image\"),\n}\nWANDB_TABLE_COL_NAMES = [\"original_image\", \"edited_image\", \"edit_prompt\"]\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    base_model: str = None,\n    dataset_name: str = None,\n    repo_folder: str = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# LoRA text2image fine-tuning - {repo_id}\nThese are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. 
\\n\n{img_str}\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n        \"text-to-image\",\n        \"instruct-pix2pix\",\n        \"diffusers\",\n        \"diffusers-training\",\n        \"lora\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    generator,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    original_image = download_image(args.val_image_url)\n    edited_images = []\n    if torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    with autocast_ctx:\n        for _ in range(args.num_validation_images):\n            edited_images.append(\n                pipeline(\n                    args.validation_prompt,\n                    image=original_image,\n                    num_inference_steps=20,\n                    image_guidance_scale=1.5,\n                    guidance_scale=7,\n                    generator=generator,\n                ).images[0]\n            )\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"wandb\":\n            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)\n            for edited_image in edited_images:\n                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), 
args.validation_prompt)\n            tracker.log({\"validation\": wandb_table})\n\n    return edited_images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script for InstructPix2Pix.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--original_image_column\",\n        type=str,\n        default=\"input_image\",\n        help=\"The column of the dataset containing the original image on which edits were made.\",\n    )\n    parser.add_argument(\n        \"--edited_image_column\",\n        type=str,\n        default=\"edited_image\",\n        help=\"The column of the dataset containing the edited image.\",\n    )\n    parser.add_argument(\n        \"--edit_prompt_column\",\n        type=str,\n        default=\"edit_prompt\",\n        help=\"The column of the dataset containing the edit instruction.\",\n    )\n    parser.add_argument(\n        \"--val_image_url\",\n        type=str,\n        default=None,\n        help=\"URL to the original image that you would like to edit (used during inference for debugging purposes).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\", type=str, default=None, help=\"A prompt that is sampled during training for inference.\"\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. 
The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"instruct-pix2pix-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=256,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--conditioning_dropout_prob\",\n        type=float,\n        default=None,\n        help=\"Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://huggingface.co/papers/2211.09800.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--non_ema_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or\"\n            \" remote repository specified with --pretrained_model_name_or_path.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    # default to using the same revision for the non-ema model if not specified\n    if args.non_ema_revision is None:\n        args.non_ema_revision = args.revision\n\n    return args\n\n\ndef convert_to_np(image, resolution):\n    image = image.convert(\"RGB\").resize((resolution, resolution))\n    return np.array(image).transpose(2, 0, 1)\n\n\ndef download_image(url):\n    image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)\n    image = PIL.ImageOps.exif_transpose(image)\n    image = image.convert(\"RGB\")\n    return image\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if args.non_ema_revision is not None:\n        deprecate(\n            \"non_ema_revision!=None\",\n            \"0.15.0\",\n            message=(\n                \"Downloading 'non_ema' weights from revision branches of the Hub 
is deprecated. Please make sure to\"\n                \" use `--variant=non_ema` instead.\"\n            ),\n        )\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = 
DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.non_ema_revision\n    )\n\n    # InstructPix2Pix uses an additional image for conditioning. To accommodate that,\n    # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is\n    # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized\n    # from the pre-trained checkpoints. 
For the extra channels added to the first layer, they are\n    # initialized to zero.\n    logger.info(\"Initializing the InstructPix2Pix UNet from the pretrained UNet.\")\n    in_channels = 8\n    out_channels = unet.conv_in.out_channels\n    unet.register_to_config(in_channels=in_channels)\n\n    with torch.no_grad():\n        new_conv_in = nn.Conv2d(\n            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding\n        )\n        new_conv_in.weight.zero_()\n        new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight)\n        unet.conv_in = new_conv_in\n\n    # Freeze vae, text_encoder and unet\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Freeze the unet parameters before adding adapters\n    unet.requires_grad_(False)\n\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Add adapter and make sure the trainable params are in float32.\n    unet.add_adapter(unet_lora_config)\n 
   if args.mixed_precision == \"fp16\":\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(unet, dtype=torch.float32)\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    trainable_params = filter(lambda p: p.requires_grad, unet.parameters())\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    if weights:\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        
accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    # train on only lora_layers\n    optimizer = optimizer_cls(\n        trainable_params,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, 
\"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/main/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.original_image_column is None:\n        original_image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        original_image_column = args.original_image_column\n        if original_image_column not in column_names:\n            raise ValueError(\n                f\"--original_image_column' value '{args.original_image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.edit_prompt_column is None:\n        edit_prompt_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        edit_prompt_column = args.edit_prompt_column\n        if edit_prompt_column not in column_names:\n            raise ValueError(\n                f\"--edit_prompt_column' value '{args.edit_prompt_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.edited_image_column is None:\n        edited_image_column = dataset_columns[2] if dataset_columns is not None else column_names[2]\n    else:\n        edited_image_column = args.edited_image_column\n        if edited_image_column not in column_names:\n            raise ValueError(\n                f\"--edited_image_column' value '{args.edited_image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def 
tokenize_captions(captions):\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    # Preprocessing the datasets.\n    train_transforms = transforms.Compose(\n        [\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n        ]\n    )\n\n    def preprocess_images(examples):\n        original_images = np.concatenate(\n            [convert_to_np(image, args.resolution) for image in examples[original_image_column]]\n        )\n        edited_images = np.concatenate(\n            [convert_to_np(image, args.resolution) for image in examples[edited_image_column]]\n        )\n        # We need to ensure that the original and the edited images undergo the same\n        # augmentation transforms.\n        images = np.concatenate([original_images, edited_images])\n        images = torch.tensor(images)\n        images = 2 * (images / 255) - 1\n        return train_transforms(images)\n\n    def preprocess_train(examples):\n        # Preprocess images.\n        preprocessed_images = preprocess_images(examples)\n        # Since the original and edited images were concatenated before\n        # applying the transformations, we need to separate them and reshape\n        # them accordingly.\n        original_images, edited_images = preprocessed_images.chunk(2)\n        original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)\n        edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)\n\n        # Collate the preprocessed images into the `examples`.\n        examples[\"original_pixel_values\"] = original_images\n        examples[\"edited_pixel_values\"] = edited_images\n\n        # Preprocess the captions.\n      
  captions = list(examples[edit_prompt_column])\n        examples[\"input_ids\"] = tokenize_captions(captions)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        original_pixel_values = torch.stack([example[\"original_pixel_values\"] for example in examples])\n        original_pixel_values = original_pixel_values.to(memory_format=torch.contiguous_format).float()\n        edited_pixel_values = torch.stack([example[\"edited_pixel_values\"] for example in examples])\n        edited_pixel_values = edited_pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\n            \"original_pixel_values\": original_pixel_values,\n            \"edited_pixel_values\": edited_pixel_values,\n            \"input_ids\": input_ids,\n        }\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        
num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move text_encode and vae to gpu and cast to weight_dtype\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the 
learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"instruct-pix2pix\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # We want to learn the denoising process w.r.t the edited images which\n                # are conditioned on the original image (which was edited) and the edit instruction.\n                # So, first, convert images to latent space.\n                latents = vae.encode(batch[\"edited_pixel_values\"].to(weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, 
noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning.\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Get the additional image embedding for conditioning.\n                # Instead of getting a diagonal Gaussian here, we simply take the mode.\n                original_image_embeds = vae.encode(batch[\"original_pixel_values\"].to(weight_dtype)).latent_dist.mode()\n\n                # Conditioning dropout to support classifier-free guidance during inference. For more details\n                # check out the section 3.2.1 of the original paper https://huggingface.co/papers/2211.09800.\n                if args.conditioning_dropout_prob is not None:\n                    random_p = torch.rand(bsz, device=latents.device, generator=generator)\n                    # Sample masks for the edit prompts.\n                    prompt_mask = random_p < 2 * args.conditioning_dropout_prob\n                    prompt_mask = prompt_mask.reshape(bsz, 1, 1)\n                    # Final text conditioning.\n                    null_conditioning = text_encoder(tokenize_captions([\"\"]).to(accelerator.device))[0]\n                    encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)\n\n                    # Sample masks for the original images.\n                    image_mask_dtype = original_image_embeds.dtype\n                    image_mask = 1 - (\n                        (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)\n                        * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)\n 
                   )\n                    image_mask = image_mask.reshape(bsz, 1, 1, 1)\n                    # Final image conditioning.\n                    original_image_embeds = image_mask * original_image_embeds\n\n                # Concatenate the `original_image_embeds` with the `noisy_latents`.\n                concatenated_noisy_latents = torch.cat([noisy_latents, original_image_embeds], dim=1)\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Predict the noise residual and compute loss\n                model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    
ema_unet.step(trainable_params)\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        unwrapped_unet = 
unwrap_model(unet)\n                        unet_lora_state_dict = convert_state_dict_to_diffusers(\n                            get_peft_model_state_dict(unwrapped_unet)\n                        )\n\n                        StableDiffusionInstructPix2PixPipeline.save_lora_weights(\n                            save_directory=save_path,\n                            unet_lora_layers=unet_lora_state_dict,\n                            safe_serialization=True,\n                        )\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if (\n                (args.val_image_url is not None)\n                and (args.validation_prompt is not None)\n                and (epoch % args.validation_epochs == 0)\n            ):\n                logger.info(\n                    f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                # create pipeline\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n                # The models need unwrapping because for compatibility in distributed training mode.\n                pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    unet=unwrap_model(unet),\n                    text_encoder=unwrap_model(text_encoder),\n                    vae=unwrap_model(vae),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                # run inference\n                log_validation(\n                    pipeline,\n                    args,\n                    accelerator,\n                    generator,\n                )\n\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_unet.restore(unet.parameters())\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        # store only LORA layers\n        unet = unet.to(torch.float32)\n\n        unwrapped_unet = unwrap_model(unet)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))\n        StableDiffusionInstructPix2PixPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            
unet_lora_layers=unet_lora_state_dict,\n            safe_serialization=True,\n        )\n\n        pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            text_encoder=unwrap_model(text_encoder),\n            vae=unwrap_model(vae),\n            unet=unwrap_model(unet),\n            revision=args.revision,\n            variant=args.variant,\n        )\n        pipeline.load_lora_weights(args.output_dir)\n\n        images = None\n        if (args.val_image_url is not None) and (args.validation_prompt is not None):\n            images = log_validation(\n                pipeline,\n                args,\n                accelerator,\n                generator,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/intel_opts/README.md",
    "content": "## Diffusers examples with Intel optimizations\n\n**This research project is not actively maintained by the diffusers team. For any questions or comments, please make sure to tag @hshen14 .**\n\nThis aims to provide diffusers examples with Intel optimizations such as Bfloat16 for training/fine-tuning acceleration and 8-bit integer (INT8) for inference acceleration on Intel platforms.\n\n## Accelerating the fine-tuning for textual inversion\n\nWe accelerate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.\n\n## Accelerating the inference for Stable Diffusion using Bfloat16\n\nWe start the inference acceleration with Bfloat16 using Intel Extension for PyTorch. The [script](inference_bf16.py) is generally designed to support standard Stable Diffusion models with Bfloat16 support.\n```bash\npip install diffusers transformers accelerate scipy safetensors\n\nexport KMP_BLOCKTIME=1\nexport KMP_SETTINGS=1\nexport KMP_AFFINITY=granularity=fine,compact,1,0\n\n# Intel OpenMP\nexport OMP_NUM_THREADS=< Cores to use >\nexport LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libiomp5.so\n# Jemalloc is a recommended malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support.\nexport LD_PRELOAD=${LD_PRELOAD}:/path/to/lib/libjemalloc.so\nexport MALLOC_CONF=\"oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:9000000000\"\n\n# Launch with default DDIM\nnumactl --membind <node N> -C <cpu list> python python inference_bf16.py\n# Launch with DPMSolverMultistepScheduler\nnumactl --membind <node N> -C <cpu list> python python inference_bf16.py --dpm\n```\n\n## Accelerating the inference for Stable Diffusion using INT8\n\nComing soon ...\n"
  },
  {
    "path": "examples/research_projects/intel_opts/inference_bf16.py",
    "content": "import argparse\n\nimport intel_extension_for_pytorch as ipex\nimport torch\n\nfrom diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline\n\n\nparser = argparse.ArgumentParser(\"Stable Diffusion script with intel optimization\", add_help=False)\nparser.add_argument(\"--dpm\", action=\"store_true\", help=\"Enable DPMSolver or not\")\nparser.add_argument(\"--steps\", default=None, type=int, help=\"Num inference steps\")\nargs = parser.parse_args()\n\n\ndevice = \"cpu\"\nprompt = \"a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brightly buildings\"\n\nmodel_id = \"path-to-your-trained-model\"\npipe = StableDiffusionPipeline.from_pretrained(model_id)\nif args.dpm:\n    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\npipe = pipe.to(device)\n\n# to channels last\npipe.unet = pipe.unet.to(memory_format=torch.channels_last)\npipe.vae = pipe.vae.to(memory_format=torch.channels_last)\npipe.text_encoder = pipe.text_encoder.to(memory_format=torch.channels_last)\nif pipe.requires_safety_checker:\n    pipe.safety_checker = pipe.safety_checker.to(memory_format=torch.channels_last)\n\n# optimize with ipex\nsample = torch.randn(2, 4, 64, 64)\ntimestep = torch.rand(1) * 999\nencoder_hidden_status = torch.randn(2, 77, 768)\ninput_example = (sample, timestep, encoder_hidden_status)\ntry:\n    pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=input_example)\nexcept Exception:\n    pipe.unet = ipex.optimize(pipe.unet.eval(), dtype=torch.bfloat16, inplace=True)\npipe.vae = ipex.optimize(pipe.vae.eval(), dtype=torch.bfloat16, inplace=True)\npipe.text_encoder = ipex.optimize(pipe.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)\nif pipe.requires_safety_checker:\n    pipe.safety_checker = ipex.optimize(pipe.safety_checker.eval(), dtype=torch.bfloat16, inplace=True)\n\n# compute\nseed = 666\ngenerator = 
torch.Generator(device).manual_seed(seed)\ngenerate_kwargs = {\"generator\": generator}\nif args.steps is not None:\n    generate_kwargs[\"num_inference_steps\"] = args.steps\n\nwith torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n    image = pipe(prompt, **generate_kwargs).images[0]\n\n# save image\nimage.save(\"generated.png\")\n"
  },
  {
    "path": "examples/research_projects/intel_opts/textual_inversion/README.md",
    "content": "## Textual Inversion fine-tuning example\n\n[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.\nThe `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.\n\n## Training with Intel Extension for PyTorch\n\nIntel Extension for PyTorch provides the optimizations for faster training and inference on CPUs. You can leverage the training example \"textual_inversion.py\". Follow the [instructions](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) to get the model and [dataset](https://huggingface.co/sd-concepts-library/dicoo2) before running the script.\n\nThe example supports both single node and multi-node distributed training:\n\n### Single node training\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport DATA_DIR=\"path-to-dir-containing-dicoo-images\"\n\npython textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<dicoo>\" --initializer_token=\"toy\" \\\n  --seed=7 \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --max_train_steps=3000 \\\n  --learning_rate=2.5e-03 --scale_lr \\\n  --output_dir=\"textual_inversion_dicoo\"\n```\n\nNote: Bfloat16 is available on Intel Xeon Scalable Processors Cooper Lake or Sapphire Rapids. 
You may not see a speedup on processors without Bfloat16 support.\n\n### Multi-node distributed training\n\nBefore running the script, make sure the oneCCL bindings for PyTorch are installed:\n\n```bash\npython -m pip install oneccl_bind_pt==1.13 -f https://developer.intel.com/ipex-whl-stable-cpu\n```\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport DATA_DIR=\"path-to-dir-containing-dicoo-images\"\n\noneccl_bindings_for_pytorch_path=$(python -c \"from oneccl_bindings_for_pytorch import cwd; print(cwd)\")\nsource $oneccl_bindings_for_pytorch_path/env/setvars.sh\n\npython -m intel_extension_for_pytorch.cpu.launch --distributed \\\n  --hostfile hostfile --nnodes 2 --nproc_per_node 2 textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<dicoo>\" --initializer_token=\"toy\" \\\n  --seed=7 \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --max_train_steps=750 \\\n  --learning_rate=2.5e-03 --scale_lr \\\n  --output_dir=\"textual_inversion_dicoo\"\n```\nThe command above runs distributed training on 2 nodes with 2 processes per node. Add the correct hostnames or IP addresses to the \"hostfile\" and make sure the 2 nodes can reach each other. For more details, please refer to the [user guide](https://github.com/intel/torch-ccl).\n\n\n### Reference\n\nWe published a [Medium blog post](https://medium.com/intel-analytics-software/personalized-stable-diffusion-with-few-shot-fine-tuning-on-a-single-cpu-f01a3316b13) on how to create your own Stable Diffusion model on CPUs using textual inversion. Try it out if you are interested.\n"
  },
  {
    "path": "examples/research_projects/intel_opts/textual_inversion/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.21.0\nftfy\ntensorboard\nJinja2\nintel_extension_for_pytorch>=1.13\n"
  },
  {
    "path": "examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py",
    "content": "import argparse\nimport itertools\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport intel_extension_for_pytorch as ipex\nimport numpy as np\nimport PIL\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker\nfrom diffusers.utils import check_min_version\n\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.13.0.dev0\")\n\n\nlogger = get_logger(__name__)\n\n\ndef save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):\n    logger.info(\"Saving embeddings\")\n    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]\n    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}\n    torch.save(learned_embeds_dict, save_path)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n        \"--only_save_embeds\",\n        action=\"store_true\",\n        default=False,\n        help=\"Save only the embeddings for the new concept.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\", type=str, default=None, required=True, help=\"A folder containing the training data.\"\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, 
default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=True,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the 
repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.train_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo 
of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": 
PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (\n                h,\n                w,\n            ) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef freeze_params(params):\n    for param in params:\n        param.requires_grad = False\n\n\ndef main():\n  
  args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer and add the placeholder token as a additional special token\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Add the placeholder token in tokenizer\n    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)\n    if num_added_tokens == 0:\n        raise ValueError(\n            f\"The tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n\n    # Convert the initializer_token, placeholder_token to ids\n    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)\n    # Check if initializer_token is a single token or a sequence of tokens\n    if len(token_ids) > 1:\n        raise ValueError(\"The initializer token must be a single token.\")\n\n    initializer_token_id = token_ids[0]\n    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)\n\n    # Load models and create wrapper for stable diffusion\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=args.revision,\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"unet\",\n        revision=args.revision,\n    )\n\n    # Resize the token embeddings as we are adding new special tokens to the tokenizer\n    text_encoder.resize_token_embeddings(len(tokenizer))\n\n    # Initialise the newly added placeholder token with the embeddings of the initializer token\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]\n\n    # Freeze vae and unet\n    freeze_params(vae.parameters())\n    freeze_params(unet.parameters())\n    # Freeze all parameters except for the token embeddings in text encoder\n    params_to_freeze = itertools.chain(\n        text_encoder.text_model.encoder.parameters(),\n        text_encoder.text_model.final_layer_norm.parameters(),\n        text_encoder.text_model.embeddings.position_embedding.parameters(),\n    )\n    freeze_params(params_to_freeze)\n\n    if 
args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=args.placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        text_encoder, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # Move vae and unet to device\n    vae.to(accelerator.device)\n    unet.to(accelerator.device)\n\n    # Keep vae and 
unet in eval model as we don't train these\n    vae.eval()\n    unet.eval()\n\n    unet = ipex.optimize(unet, dtype=torch.bfloat16, inplace=True)\n    vae = ipex.optimize(vae, dtype=torch.bfloat16, inplace=True)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n    global_step = 0\n\n    text_encoder.train()\n    text_encoder, optimizer = ipex.optimize(text_encoder, optimizer=optimizer, dtype=torch.bfloat16)\n\n    for epoch in range(args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):\n                with accelerator.accumulate(text_encoder):\n                    # Convert images to latent space\n                    latents = vae.encode(batch[\"pixel_values\"]).latent_dist.sample().detach()\n                    latents = latents * vae.config.scaling_factor\n\n                    # Sample noise that we'll add to the latents\n                    noise = torch.randn(latents.shape).to(latents.device)\n                    bsz = latents.shape[0]\n                    # Sample a random timestep for each image\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device\n                    ).long()\n\n                    # Add noise to the latents according to the noise magnitude at each timestep\n                    # (this is the forward diffusion process)\n                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                    # Get the text embedding for conditioning\n                    encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                    # Predict the noise residual\n                    model_pred = unet(noisy_latents, timesteps, 
encoder_hidden_states).sample\n\n                    # Get the target for loss depending on the prediction type\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        target = noise\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                    else:\n                        raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                    loss = F.mse_loss(model_pred, target, reduction=\"none\").mean([1, 2, 3]).mean()\n                    accelerator.backward(loss)\n\n                    # Zero out the gradients for all token embeddings except the newly added\n                    # embeddings for the concept, as we only want to optimize the concept embeddings\n                    if accelerator.num_processes > 1:\n                        grads = text_encoder.module.get_input_embeddings().weight.grad\n                    else:\n                        grads = text_encoder.get_input_embeddings().weight.grad\n                    # Get the index for tokens that we want to zero the grads for\n                    index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id\n                    grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)\n\n                    optimizer.step()\n                    lr_scheduler.step()\n                    optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                if global_step % args.save_steps == 0:\n                    save_path = os.path.join(args.output_dir, f\"learned_embeds-steps-{global_step}.bin\")\n                    save_progress(text_encoder, 
placeholder_token_id, accelerator, args, save_path)\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        accelerator.wait_for_everyone()\n\n    # Create the pipeline using using the trained modules and save it.\n    if accelerator.is_main_process:\n        if args.push_to_hub and args.only_save_embeds:\n            logger.warning(\"Enabling full model saving because --push_to_hub=True was specified.\")\n            save_full_model = True\n        else:\n            save_full_model = not args.only_save_embeds\n        if save_full_model:\n            pipeline = StableDiffusionPipeline(\n                text_encoder=accelerator.unwrap_model(text_encoder),\n                vae=vae,\n                unet=unet,\n                tokenizer=tokenizer,\n                scheduler=PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\"),\n                safety_checker=StableDiffusionSafetyChecker.from_pretrained(\"CompVis/stable-diffusion-safety-checker\"),\n                feature_extractor=CLIPImageProcessor.from_pretrained(\"openai/clip-vit-base-patch32\"),\n            )\n            pipeline.save_pretrained(args.output_dir)\n        # Save the newly trained embeddings\n        save_path = os.path.join(args.output_dir, \"learned_embeds.bin\")\n        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/intel_opts/textual_inversion_dfq/README.md",
    "content": "# Distillation for quantization on Textual Inversion models to personalize text2image\n\n[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images._By using just 3-5 images new concepts can be taught to Stable Diffusion and the model personalized on your own images_\nThe `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.\nWe have enabled distillation for quantization in `textual_inversion.py` to do quantization aware training as well as distillation on the model generated by Textual Inversion method.\n\n## Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n```bash\npip install -r requirements.txt\n```\n\n## Prepare Datasets\n\nOne picture which is from the huggingface datasets [sd-concepts-library/dicoo2](https://huggingface.co/sd-concepts-library/dicoo2) is needed, and save it to the `./dicoo` directory. 
The picture is shown below:\n\n<a href=\"https://huggingface.co/sd-concepts-library/dicoo2/blob/main/concept_images/1.jpeg\">\n    <img src=\"https://huggingface.co/sd-concepts-library/dicoo2/resolve/main/concept_images/1.jpeg\" width = \"300\" height=\"300\">\n</a>\n\n## Get an FP32 Textual Inversion model\n\nUse the following command to fine-tune the Stable Diffusion model on the above dataset to obtain the FP32 Textual Inversion model.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport DATA_DIR=\"./dicoo\"\n\naccelerate launch textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<dicoo>\" --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 --scale_lr \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=\"dicoo_model\"\n```\n\n## Do distillation for quantization\n\nDistillation for quantization is a method that combines [intermediate layer knowledge distillation](https://github.com/intel/neural-compressor/blob/master/docs/source/distillation.md#intermediate-layer-knowledge-distillation) and [quantization aware training](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization.md#quantization-aware-training) in the same training process to improve the performance of the quantized model. Given an FP32 model, the distillation for quantization approach uses this model itself as the teacher model and transfers the knowledge of the specified layers to the student model, i.e. 
the quantized version of the FP32 model, during the quantization aware training process.\n\nOnce you have the FP32 Textual Inversion model, the following command takes it as input, performs distillation for quantization, and generates the INT8 Textual Inversion model.\n\n```bash\nexport FP32_MODEL_NAME=\"./dicoo_model\"\nexport DATA_DIR=\"./dicoo\"\n\naccelerate launch textual_inversion.py \\\n  --pretrained_model_name_or_path=$FP32_MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --use_ema --learnable_property=\"object\" \\\n  --placeholder_token=\"<dicoo>\" --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=300 \\\n  --learning_rate=5.0e-04 --max_grad_norm=3 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=\"int8_model\" \\\n  --do_quantization --do_distillation --verify_loading\n```\n\nAfter distillation for quantization, the quantized UNet is about 4 times smaller (3279MB -> 827MB).\n\n## Inference\n\nOnce you have trained an INT8 model with the above command, you can run inference with the `text2images.py` script. 
Make sure to include the `placeholder_token` in your prompt.\n\n```bash\nexport INT8_MODEL_NAME=\"./int8_model\"\n\npython text2images.py \\\n  --pretrained_model_name_or_path=$INT8_MODEL_NAME \\\n  --caption \"a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brightly buildings.\" \\\n  --images_num 4\n```\n\nHere is the comparison of images generated by the FP32 model (left) and INT8 model (right) respectively:\n\n<p float=\"left\">\n  <img src=\"https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/FP32.png\" width = \"300\" height = \"300\" alt=\"FP32\" align=center />\n  <img src=\"https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/INT8.png\" width = \"300\" height = \"300\" alt=\"INT8\" align=center />\n</p>\n\n"
  },
  {
    "path": "examples/research_projects/intel_opts/textual_inversion_dfq/requirements.txt",
    "content": "accelerate\ntorchvision\ntransformers>=4.25.0\nftfy\ntensorboard\nmodelcards\nneural-compressor"
  },
  {
    "path": "examples/research_projects/intel_opts/textual_inversion_dfq/text2images.py",
    "content": "import argparse\nimport math\nimport os\n\nimport torch\nfrom neural_compressor.utils.pytorch import load\nfrom PIL import Image\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-m\",\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"-c\",\n        \"--caption\",\n        type=str,\n        default=\"robotic cat with wings\",\n        help=\"Text used to generate images.\",\n    )\n    parser.add_argument(\n        \"-n\",\n        \"--images_num\",\n        type=int,\n        default=4,\n        help=\"How much images to generate.\",\n    )\n    parser.add_argument(\n        \"-s\",\n        \"--seed\",\n        type=int,\n        default=42,\n        help=\"Seed for random process.\",\n    )\n    parser.add_argument(\n        \"-ci\",\n        \"--cuda_id\",\n        type=int,\n        default=0,\n        help=\"cuda_id.\",\n    )\n    args = parser.parse_args()\n    return args\n\n\ndef image_grid(imgs, rows, cols):\n    if not len(imgs) == rows * cols:\n        raise ValueError(\"The specified number of rows and columns are not correct.\")\n\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n    grid_w, grid_h = grid.size\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\n\ndef generate_images(\n    pipeline,\n    prompt=\"robotic cat with wings\",\n    guidance_scale=7.5,\n    num_inference_steps=50,\n    num_images_per_prompt=1,\n    seed=42,\n):\n    generator = torch.Generator(pipeline.device).manual_seed(seed)\n    images = pipeline(\n        prompt,\n        
guidance_scale=guidance_scale,\n        num_inference_steps=num_inference_steps,\n        generator=generator,\n        num_images_per_prompt=num_images_per_prompt,\n    ).images\n    _rows = int(math.sqrt(num_images_per_prompt))\n    grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)\n    return grid, images\n\n\nargs = parse_args()\n# Load models and create wrapper for stable diffusion\ntokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\ntext_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\nvae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\nunet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n\npipeline = StableDiffusionPipeline.from_pretrained(\n    args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer\n)\npipeline.safety_checker = lambda images, clip_input: (images, False)\nif os.path.exists(os.path.join(args.pretrained_model_name_or_path, \"best_model.pt\")):\n    unet = load(args.pretrained_model_name_or_path, model=unet)\n    unet.eval()\n    setattr(pipeline, \"unet\", unet)\nelse:\n    unet = unet.to(torch.device(\"cuda\", args.cuda_id))\npipeline = pipeline.to(unet.device)\ngrid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)\ngrid.save(os.path.join(args.pretrained_model_name_or_path, \"{}.png\".format(\"_\".join(args.caption.split()))))\ndirname = os.path.join(args.pretrained_model_name_or_path, \"_\".join(args.caption.split()))\nos.makedirs(dirname, exist_ok=True)\nfor idx, image in enumerate(images):\n    image.save(os.path.join(dirname, \"{}.png\".format(idx + 1)))\n"
  },
  {
    "path": "examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py",
    "content": "import argparse\nimport itertools\nimport math\nimport os\nimport random\nfrom pathlib import Path\nfrom typing import Iterable\n\nimport numpy as np\nimport PIL\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom accelerate import Accelerator\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom neural_compressor.utils import logger\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import make_image_grid\n\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n\ndef save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):\n    logger.info(\"Saving embeddings\")\n    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]\n    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}\n    torch.save(learned_embeds_dict, save_path)\n\n\ndef parse_args():\n    parser = 
argparse.ArgumentParser(description=\"Example of distillation for quantization on Textual Inversion.\")\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\", type=str, default=None, required=True, help=\"A folder containing the training data.\"\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The 
directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=42, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \"between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--do_quantization\", action=\"store_true\", help=\"Whether or not to do quantization.\")\n    parser.add_argument(\"--do_distillation\", action=\"store_true\", help=\"Whether or not to do distillation.\")\n    parser.add_argument(\n        \"--verify_loading\", action=\"store_true\", help=\"Whether or not to verify the loading of the quantized model.\"\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.train_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n   
 \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\n# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14\nclass EMAModel:\n    \"\"\"\n    Exponential Moving Average of models weights\n    \"\"\"\n\n    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):\n        parameters = list(parameters)\n        self.shadow_params = [p.clone().detach() for p in parameters]\n\n        self.decay = decay\n        self.optimization_step = 0\n\n    def get_decay(self, optimization_step):\n        \"\"\"\n        Compute the decay factor for the exponential moving average.\n        \"\"\"\n        value = (1 + optimization_step) / (10 + optimization_step)\n        return 1 - min(self.decay, value)\n\n    @torch.no_grad()\n    def step(self, parameters):\n        parameters = list(parameters)\n\n        self.optimization_step += 1\n        self.decay = self.get_decay(self.optimization_step)\n\n        for s_param, param in zip(self.shadow_params, parameters):\n            if param.requires_grad:\n                tmp = self.decay * (s_param - param)\n      
          s_param.sub_(tmp)\n            else:\n                s_param.copy_(param)\n\n        torch.cuda.empty_cache()\n\n    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:\n        \"\"\"\n        Copy current averaged parameters into given collection of parameters.\n        Args:\n            parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n                updated with the stored moving averages. If `None`, the\n                parameters with which this `ExponentialMovingAverage` was\n                initialized will be used.\n        \"\"\"\n        parameters = list(parameters)\n        for s_param, param in zip(self.shadow_params, parameters):\n            param.data.copy_(s_param.data)\n\n    def to(self, device=None, dtype=None) -> None:\n        r\"\"\"Move internal buffers of the ExponentialMovingAverage to `device`.\n        Args:\n            device: like `device` argument to `torch.Tensor.to`\n        \"\"\"\n        # .to() on the tensors handles None correctly\n        self.shadow_params = [\n            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)\n            for p in self.shadow_params\n        ]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images 
= len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (\n                h,\n                w,\n            ) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n       
 example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef freeze_params(params):\n    for param in params:\n        param.requires_grad = False\n\n\ndef generate_images(pipeline, prompt=\"\", guidance_scale=7.5, num_inference_steps=50, num_images_per_prompt=1, seed=42):\n    generator = torch.Generator(pipeline.device).manual_seed(seed)\n    images = pipeline(\n        prompt,\n        guidance_scale=guidance_scale,\n        num_inference_steps=num_inference_steps,\n        generator=generator,\n        num_images_per_prompt=num_images_per_prompt,\n    ).images\n    _rows = int(math.sqrt(num_images_per_prompt))\n    grid = make_image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)\n    return grid\n\n\ndef main():\n    args = parse_args()\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=\"tensorboard\",\n        project_config=accelerator_project_config,\n    )\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer and add the placeholder token as a additional special token\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = 
CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Load models and create wrapper for stable diffusion\n    noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=args.revision,\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"unet\",\n        revision=args.revision,\n    )\n\n    train_unet = False\n    # Freeze vae and unet\n    freeze_params(vae.parameters())\n    if not args.do_quantization and not args.do_distillation:\n        # Add the placeholder token in tokenizer\n        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)\n        if num_added_tokens == 0:\n            raise ValueError(\n                f\"The tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n                \" `placeholder_token` that is not already in the tokenizer.\"\n            )\n\n        # Convert the initializer_token, placeholder_token to ids\n        token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)\n        # Check if initializer_token is a single token or a sequence of tokens\n        if len(token_ids) > 1:\n            raise ValueError(\"The initializer token must be a single token.\")\n\n        initializer_token_id = token_ids[0]\n        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)\n        # Resize the token embeddings as we are adding new special tokens to the tokenizer\n        text_encoder.resize_token_embeddings(len(tokenizer))\n\n        # Initialise the newly added placeholder token with the embeddings of the initializer token\n        token_embeds = text_encoder.get_input_embeddings().weight.data\n        token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]\n\n        freeze_params(unet.parameters())\n        # Freeze all parameters except for the token embeddings in text encoder\n        params_to_freeze = itertools.chain(\n            text_encoder.text_model.encoder.parameters(),\n            text_encoder.text_model.final_layer_norm.parameters(),\n            text_encoder.text_model.embeddings.position_embedding.parameters(),\n        )\n        freeze_params(params_to_freeze)\n    else:\n        train_unet = True\n        freeze_params(text_encoder.parameters())\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        # only optimize the unet or embeddings of text_encoder\n        unet.parameters() if train_unet else text_encoder.get_input_embeddings().parameters(),\n        lr=args.learning_rate,\n        
betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=args.placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    if not train_unet:\n        text_encoder = accelerator.prepare(text_encoder)\n        unet.to(accelerator.device)\n        unet.eval()\n    else:\n        unet = accelerator.prepare(unet)\n        text_encoder.to(accelerator.device)\n        text_encoder.eval()\n    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)\n\n    # Move vae to device\n    vae.to(accelerator.device)\n\n    # Keep vae in eval mode as we don't train it\n    vae.eval()\n\n    compression_manager = None\n\n    def train_func(model):\n        if train_unet:\n            unet_ = model\n            text_encoder_ = text_encoder\n        else:\n            unet_ = unet\n            text_encoder_ = model\n        # We need to recalculate 
our total training steps as the size of the training dataloader may have changed.\n        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n        if overrode_max_train_steps:\n            args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        # Afterwards we recalculate our number of training epochs\n        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n        # We need to initialize the trackers we use, and also store our configuration.\n        # The trackers initializes automatically on the main process.\n        if accelerator.is_main_process:\n            accelerator.init_trackers(\"textual_inversion\", config=vars(args))\n\n        # Train!\n        total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n        logger.info(\"***** Running training *****\")\n        logger.info(f\"  Num examples = {len(train_dataset)}\")\n        logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n        logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n        logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n        logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n        logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n        # Only show the progress bar once on each machine.\n        progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)\n        progress_bar.set_description(\"Steps\")\n        global_step = 0\n\n        if train_unet and args.use_ema:\n            ema_unet = EMAModel(unet_.parameters())\n\n        for epoch in range(args.num_train_epochs):\n            model.train()\n            train_loss = 0.0\n            for step, batch in enumerate(train_dataloader):\n                with accelerator.accumulate(model):\n                    # Convert images to latent space\n                    latents = vae.encode(batch[\"pixel_values\"]).latent_dist.sample().detach()\n                    latents = latents * 0.18215\n\n                    # Sample noise that we'll add to the latents\n                    noise = torch.randn(latents.shape).to(latents.device)\n                    bsz = latents.shape[0]\n                    # Sample a random timestep for each image\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device\n                    ).long()\n\n                    # Add noise to the latents according to the noise magnitude at each timestep\n                    # (this is the forward diffusion process)\n                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                    # Get the text embedding for conditioning\n                    encoder_hidden_states = text_encoder_(batch[\"input_ids\"])[0]\n\n                    # Predict the noise residual\n                    model_pred = unet_(noisy_latents, timesteps, encoder_hidden_states).sample\n\n                    
loss = F.mse_loss(model_pred, noise, reduction=\"none\").mean([1, 2, 3]).mean()\n                    if train_unet and compression_manager:\n                        unet_inputs = {\n                            \"sample\": noisy_latents,\n                            \"timestep\": timesteps,\n                            \"encoder_hidden_states\": encoder_hidden_states,\n                        }\n                        loss = compression_manager.callbacks.on_after_compute_loss(unet_inputs, model_pred, loss)\n\n                    # Gather the losses across all processes for logging (if we use distributed training).\n                    avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                    train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                    # Backpropagate\n                    accelerator.backward(loss)\n\n                    if train_unet:\n                        if accelerator.sync_gradients:\n                            accelerator.clip_grad_norm_(unet_.parameters(), args.max_grad_norm)\n                    else:\n                        # Zero out the gradients for all token embeddings except the newly added\n                        # embeddings for the concept, as we only want to optimize the concept embeddings\n                        if accelerator.num_processes > 1:\n                            grads = text_encoder_.module.get_input_embeddings().weight.grad\n                        else:\n                            grads = text_encoder_.get_input_embeddings().weight.grad\n                        # Get the index for tokens that we want to zero the grads for\n                        index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id\n                        grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)\n\n                    optimizer.step()\n                    lr_scheduler.step()\n                    
optimizer.zero_grad()\n\n                # Checks if the accelerator has performed an optimization step behind the scenes\n                if accelerator.sync_gradients:\n                    if train_unet and args.use_ema:\n                        ema_unet.step(unet_.parameters())\n                    progress_bar.update(1)\n                    global_step += 1\n                    accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                    train_loss = 0.0\n                    if not train_unet and global_step % args.save_steps == 0:\n                        save_path = os.path.join(args.output_dir, f\"learned_embeds-steps-{global_step}.bin\")\n                        save_progress(text_encoder_, placeholder_token_id, accelerator, args, save_path)\n\n                logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n                progress_bar.set_postfix(**logs)\n                accelerator.log(logs, step=global_step)\n\n                if global_step >= args.max_train_steps:\n                    break\n            accelerator.wait_for_everyone()\n\n        if train_unet and args.use_ema:\n            ema_unet.copy_to(unet_.parameters())\n\n        if not train_unet:\n            return text_encoder_\n\n    if not train_unet:\n        text_encoder = train_func(text_encoder)\n    else:\n        import copy\n\n        model = copy.deepcopy(unet)\n        confs = []\n        if args.do_quantization:\n            from neural_compressor import QuantizationAwareTrainingConfig\n\n            q_conf = QuantizationAwareTrainingConfig()\n            confs.append(q_conf)\n\n        if args.do_distillation:\n            teacher_model = copy.deepcopy(model)\n\n            def attention_fetcher(x):\n                return x.sample\n\n            layer_mappings = [\n                [\n                    [\n                        \"conv_in\",\n                    ]\n                ],\n                [\n              
      [\n                        \"time_embedding\",\n                    ]\n                ],\n                [[\"down_blocks.0.attentions.0\", attention_fetcher]],\n                [[\"down_blocks.0.attentions.1\", attention_fetcher]],\n                [\n                    [\n                        \"down_blocks.0.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"down_blocks.0.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"down_blocks.0.downsamplers.0\",\n                    ]\n                ],\n                [[\"down_blocks.1.attentions.0\", attention_fetcher]],\n                [[\"down_blocks.1.attentions.1\", attention_fetcher]],\n                [\n                    [\n                        \"down_blocks.1.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"down_blocks.1.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"down_blocks.1.downsamplers.0\",\n                    ]\n                ],\n                [[\"down_blocks.2.attentions.0\", attention_fetcher]],\n                [[\"down_blocks.2.attentions.1\", attention_fetcher]],\n                [\n                    [\n                        \"down_blocks.2.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"down_blocks.2.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"down_blocks.2.downsamplers.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"down_blocks.3.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        
\"down_blocks.3.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.0.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.0.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.0.resnets.2\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.0.upsamplers.0\",\n                    ]\n                ],\n                [[\"up_blocks.1.attentions.0\", attention_fetcher]],\n                [[\"up_blocks.1.attentions.1\", attention_fetcher]],\n                [[\"up_blocks.1.attentions.2\", attention_fetcher]],\n                [\n                    [\n                        \"up_blocks.1.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.1.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.1.resnets.2\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.1.upsamplers.0\",\n                    ]\n                ],\n                [[\"up_blocks.2.attentions.0\", attention_fetcher]],\n                [[\"up_blocks.2.attentions.1\", attention_fetcher]],\n                [[\"up_blocks.2.attentions.2\", attention_fetcher]],\n                [\n                    [\n                        \"up_blocks.2.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.2.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.2.resnets.2\",\n                    ]\n              
  ],\n                [\n                    [\n                        \"up_blocks.2.upsamplers.0\",\n                    ]\n                ],\n                [[\"up_blocks.3.attentions.0\", attention_fetcher]],\n                [[\"up_blocks.3.attentions.1\", attention_fetcher]],\n                [[\"up_blocks.3.attentions.2\", attention_fetcher]],\n                [\n                    [\n                        \"up_blocks.3.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.3.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"up_blocks.3.resnets.2\",\n                    ]\n                ],\n                [[\"mid_block.attentions.0\", attention_fetcher]],\n                [\n                    [\n                        \"mid_block.resnets.0\",\n                    ]\n                ],\n                [\n                    [\n                        \"mid_block.resnets.1\",\n                    ]\n                ],\n                [\n                    [\n                        \"conv_out\",\n                    ]\n                ],\n            ]\n            layer_names = [layer_mapping[0][0] for layer_mapping in layer_mappings]\n            if not set(layer_names).issubset([n[0] for n in model.named_modules()]):\n                raise ValueError(\n                    \"Provided model is not compatible with the default layer_mappings, \"\n                    'please use the model fine-tuned from \"CompVis/stable-diffusion-v1-4\", '\n                    \"or modify the layer_mappings variable to fit your model.\"\n                    f\"\\nDefault layer_mappings are as such:\\n{layer_mappings}\"\n                )\n            from neural_compressor.config import DistillationConfig, IntermediateLayersKnowledgeDistillationLossConfig\n\n            distillation_criterion = 
IntermediateLayersKnowledgeDistillationLossConfig(\n                layer_mappings=layer_mappings,\n                loss_types=[\"MSE\"] * len(layer_mappings),\n                loss_weights=[1.0 / len(layer_mappings)] * len(layer_mappings),\n                add_origin_loss=True,\n            )\n            d_conf = DistillationConfig(teacher_model=teacher_model, criterion=distillation_criterion)\n            confs.append(d_conf)\n\n        from neural_compressor.training import prepare_compression\n\n        compression_manager = prepare_compression(model, confs)\n        compression_manager.callbacks.on_train_begin()\n        model = compression_manager.model\n        train_func(model)\n        compression_manager.callbacks.on_train_end()\n\n        # Save the resulting model and its corresponding configuration in the given directory\n        model.save(args.output_dir)\n\n        logger.info(f\"Optimized model saved to: {args.output_dir}.\")\n\n        # Change to the framework model for further use\n        model = model.model\n\n    # Create the pipeline using the trained modules and save it.\n    templates = imagenet_style_templates_small if args.learnable_property == \"style\" else imagenet_templates_small\n    prompt = templates[0].format(args.placeholder_token)\n    if accelerator.is_main_process:\n        pipeline = StableDiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            text_encoder=accelerator.unwrap_model(text_encoder),\n            vae=vae,\n            unet=accelerator.unwrap_model(unet),\n            tokenizer=tokenizer,\n        )\n        pipeline.save_pretrained(args.output_dir)\n        pipeline = pipeline.to(unet.device)\n        baseline_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)\n        baseline_model_images.save(\n            os.path.join(args.output_dir, \"{}_baseline_model.png\".format(\"_\".join(prompt.split())))\n        )\n\n        if not train_unet:\n        
    # Also save the newly trained embeddings\n            save_path = os.path.join(args.output_dir, \"learned_embeds.bin\")\n            save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)\n        else:\n            setattr(pipeline, \"unet\", accelerator.unwrap_model(model))\n            if args.do_quantization:\n                pipeline = pipeline.to(torch.device(\"cpu\"))\n\n            optimized_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)\n            optimized_model_images.save(\n                os.path.join(args.output_dir, \"{}_optimized_model.png\".format(\"_\".join(prompt.split())))\n            )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n    if args.do_quantization and args.verify_loading:\n        # Load the model obtained after Intel Neural Compressor quantization\n        from neural_compressor.utils.pytorch import load\n\n        loaded_model = load(args.output_dir, model=unet)\n        loaded_model.eval()\n\n        setattr(pipeline, \"unet\", loaded_model)\n        if args.do_quantization:\n            pipeline = pipeline.to(torch.device(\"cpu\"))\n\n        loaded_model_images = generate_images(pipeline, prompt=prompt, seed=args.seed)\n        if loaded_model_images != optimized_model_images:\n            logger.info(\"The quantized model was not successfully loaded.\")\n        else:\n            logger.info(\"The quantized model was successfully loaded.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/ip_adapter/README.md",
    "content": "# IP Adapter Training Example \n\n[IP Adapter](https://huggingface.co/papers/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that \"an image is worth a thousand words.\" By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.\n\n## Training locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the example folder and run\n\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nCertainly! 
Below is the documentation in pure Markdown format:\n\n### Accelerate Launch Command Documentation\n\n#### Description:\nThe Accelerate launch command is used to train a model using multiple GPUs and mixed precision training. It launches the training script `tutorial_train_ip-adapter.py` with specified parameters and configurations.\n\n#### Usage Example:\n\n```\naccelerate launch --mixed_precision \"fp16\" \\\ntutorial_train_ip-adapter.py \\\n--pretrained_model_name_or_path=\"stable-diffusion-v1-5/stable-diffusion-v1-5/\" \\\n--image_encoder_path=\"{image_encoder_path}\" \\\n--data_json_file=\"{data.json}\" \\\n--data_root_path=\"{image_path}\" \\\n--mixed_precision=\"fp16\" \\\n--resolution=512 \\\n--train_batch_size=8 \\\n--dataloader_num_workers=4 \\\n--learning_rate=1e-04 \\\n--weight_decay=0.01 \\\n--output_dir=\"{output_dir}\" \\\n--save_steps=10000\n```\n\n### Multi-GPU Script:\n```\naccelerate launch --num_processes 8 --multi_gpu --mixed_precision \"fp16\" \\\n  tutorial_train_ip-adapter.py \\\n  --pretrained_model_name_or_path=\"stable-diffusion-v1-5/stable-diffusion-v1-5/\" \\\n  --image_encoder_path=\"{image_encoder_path}\" \\\n  --data_json_file=\"{data.json}\" \\\n  --data_root_path=\"{image_path}\" \\\n  --mixed_precision=\"fp16\" \\\n  --resolution=512 \\\n  --train_batch_size=8 \\\n  --dataloader_num_workers=4 \\\n  --learning_rate=1e-04 \\\n  --weight_decay=0.01 \\\n  --output_dir=\"{output_dir}\" \\\n  --save_steps=10000\n```\n\n#### Parameters:\n- `--num_processes`: Number of processes to launch for distributed training (in this example, 8 processes).\n- `--multi_gpu`: Flag indicating the usage of multiple GPUs for training.\n- `--mixed_precision \"fp16\"`: Enables mixed precision training with 16-bit floating-point precision.\n- `tutorial_train_ip-adapter.py`: Name of the training script to be executed.\n- `--pretrained_model_name_or_path`: Path or identifier for a pretrained model.\n- `--image_encoder_path`: Path to the CLIP image encoder.\n- 
`--data_json_file`: Path to the training data in JSON format.\n- `--data_root_path`: Root path where training images are located.\n- `--resolution`: Resolution of input images (512x512 in this example).\n- `--train_batch_size`: Batch size for training data (8 in this example).\n- `--dataloader_num_workers`: Number of subprocesses for data loading (4 in this example).\n- `--learning_rate`: Learning rate for training (1e-04 in this example).\n- `--weight_decay`: Weight decay for regularization (0.01 in this example).\n- `--output_dir`: Directory to save model checkpoints and predictions.\n- `--save_steps`: Frequency of saving checkpoints during training (10000 in this example).\n\n### Inference\n\n#### Description:\nThe provided inference code loads a trained model checkpoint and extracts the weights of the image projection model and the IP (Image Prompt) adapter modules. These weights are then saved into two separate safetensors files for later use in inference.\n\n#### Usage Example:\n```python\nfrom safetensors.torch import load_file, save_file\n\n# Load the trained model checkpoint in safetensors format\nckpt = \"checkpoint-50000/pytorch_model.safetensors\"\nsd = load_file(ckpt)  # Using safetensors load function\n\n# Extract image projection and IP adapter components\nimage_proj_sd = {}\nip_sd = {}\n\nfor k in sd:\n    if k.startswith(\"unet\"):\n        pass  # Skip unet-related keys\n    elif k.startswith(\"image_proj_model\"):\n        image_proj_sd[k.replace(\"image_proj_model.\", \"\")] = sd[k]\n    elif k.startswith(\"adapter_modules\"):\n        ip_sd[k.replace(\"adapter_modules.\", \"\")] = sd[k]\n\n# Save the components into separate safetensors files\nsave_file(image_proj_sd, \"image_proj.safetensors\")\nsave_file(ip_sd, \"ip_adapter.safetensors\")\n```\n\n### Sample Inference Script using the CLIP Model\n\n```python\n\nimport torch\nfrom safetensors.torch import load_file\nfrom transformers import CLIPProcessor, CLIPModel  # Using the Hugging Face CLIP model 
\nfrom PIL import Image\n\n# Load model components from safetensors\nimage_proj_ckpt = \"image_proj.safetensors\"\nip_adapter_ckpt = \"ip_adapter.safetensors\"\n\n# Load the saved weights\nimage_proj_sd = load_file(image_proj_ckpt)\nip_adapter_sd = load_file(ip_adapter_ckpt)\n\n# Define the model modules\n# NOTE: these single-layer modules are illustrative placeholders. Replace them with the module\n# definitions used during training so that the parameter names match the saved state dicts.\nclass ImageProjectionModel(torch.nn.Module):\n    def __init__(self, input_dim=512, output_dim=512):  # clip-vit-base-patch32 image features are 512-dimensional\n        super().__init__()\n        self.model = torch.nn.Linear(input_dim, output_dim)\n\n    def forward(self, x):\n        return self.model(x)\n\nclass IPAdapterModel(torch.nn.Module):\n    def __init__(self, input_dim=512, output_dim=10):  # Example for 10 classes\n        super().__init__()\n        self.model = torch.nn.Linear(input_dim, output_dim)\n\n    def forward(self, x):\n        return self.model(x)\n\n# Initialize models\nimage_proj_model = ImageProjectionModel()\nip_adapter_model = IPAdapterModel()\n\n# Load weights into models\nimage_proj_model.load_state_dict(image_proj_sd)\nip_adapter_model.load_state_dict(ip_adapter_sd)\n\n# Set models to evaluation mode\nimage_proj_model.eval()\nip_adapter_model.eval()\n\n# Inference pipeline\ndef inference(image_tensor):\n    \"\"\"\n    Run inference using the loaded models.\n\n    Args:\n        image_tensor: Image feature tensor produced by the CLIP model\n\n    Returns:\n        Final inference results\n    \"\"\"\n    with torch.no_grad():\n        # Step 1: Project the image features\n        image_proj = image_proj_model(image_tensor)\n\n        # Step 2: Pass the projected features through the IP Adapter\n        result = ip_adapter_model(image_proj)\n\n    return result\n\n# Using CLIP for image preprocessing\nprocessor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\nclip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n# Image file path\nimage_path = \"path/to/image.jpg\"\n\n# Open the image file and preprocess it (the processor expects an image, not a path)\nimage = Image.open(image_path).convert(\"RGB\")\ninputs = processor(images=image, return_tensors=\"pt\")\nwith torch.no_grad():\n    image_features = clip_model.get_image_features(pixel_values=inputs[\"pixel_values\"])\n\n# L2-normalize the image features\nimage_features = image_features / image_features.norm(dim=-1, keepdim=True)\n\n# Run inference\noutput = inference(image_features)\nprint(\"Inference output:\", output)\n```\n\n#### Parameters:\n- `ckpt`: Path to the trained model checkpoint file.\n- `image_proj_sd`: Dictionary to store the components related to image projection.\n- `ip_sd`: Dictionary to store the components related to the IP adapter.\n- `\"unet\"`, `\"image_proj_model\"`, `\"adapter_modules\"`: Prefixes indicating components of the model."
  },
  {
    "path": "examples/research_projects/ip_adapter/requirements.txt",
    "content": "accelerate\ntorchvision\ntransformers>=4.25.1\nip_adapter\n"
  },
  {
    "path": "examples/research_projects/ip_adapter/tutorial_train_faceid.py",
    "content": "import argparse\nimport itertools\nimport json\nimport os\nimport random\nimport time\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\nfrom accelerate.utils import ProjectConfiguration\nfrom ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor\nfrom ip_adapter.ip_adapter_faceid import MLPProjModel\nfrom PIL import Image\nfrom torchvision import transforms\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\n\n\n# Dataset\nclass MyDataset(torch.utils.data.Dataset):\n    def __init__(\n        self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=\"\"\n    ):\n        super().__init__()\n\n        self.tokenizer = tokenizer\n        self.size = size\n        self.i_drop_rate = i_drop_rate\n        self.t_drop_rate = t_drop_rate\n        self.ti_drop_rate = ti_drop_rate\n        self.image_root_path = image_root_path\n\n        self.data = json.load(\n            open(json_file)\n        )  # list of dict: [{\"image_file\": \"1.png\", \"id_embed_file\": \"faceid.bin\"}]\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(self.size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __getitem__(self, idx):\n        item = self.data[idx]\n        text = item[\"text\"]\n        image_file = item[\"image_file\"]\n\n        # read image\n        raw_image = Image.open(os.path.join(self.image_root_path, image_file))\n        image = self.transform(raw_image.convert(\"RGB\"))\n\n        face_id_embed = torch.load(item[\"id_embed_file\"], map_location=\"cpu\")\n        face_id_embed = 
torch.from_numpy(face_id_embed)\n\n        # drop\n        drop_image_embed = 0\n        rand_num = random.random()\n        if rand_num < self.i_drop_rate:\n            drop_image_embed = 1\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate):\n            text = \"\"\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):\n            text = \"\"\n            drop_image_embed = 1\n        if drop_image_embed:\n            face_id_embed = torch.zeros_like(face_id_embed)\n        # get text and tokenize\n        text_input_ids = self.tokenizer(\n            text,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        return {\n            \"image\": image,\n            \"text_input_ids\": text_input_ids,\n            \"face_id_embed\": face_id_embed,\n            \"drop_image_embed\": drop_image_embed,\n        }\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef collate_fn(data):\n    images = torch.stack([example[\"image\"] for example in data])\n    text_input_ids = torch.cat([example[\"text_input_ids\"] for example in data], dim=0)\n    face_id_embed = torch.stack([example[\"face_id_embed\"] for example in data])\n    drop_image_embeds = [example[\"drop_image_embed\"] for example in data]\n\n    return {\n        \"images\": images,\n        \"text_input_ids\": text_input_ids,\n        \"face_id_embed\": face_id_embed,\n        \"drop_image_embeds\": drop_image_embeds,\n    }\n\n\nclass IPAdapter(torch.nn.Module):\n    \"\"\"IP-Adapter\"\"\"\n\n    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):\n        super().__init__()\n        self.unet = unet\n        self.image_proj_model = image_proj_model\n        self.adapter_modules = adapter_modules\n\n        if ckpt_path is not None:\n            self.load_from_checkpoint(ckpt_path)\n\n    def 
forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):\n        ip_tokens = self.image_proj_model(image_embeds)\n        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)\n        # Predict the noise residual\n        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample\n        return noise_pred\n\n    def load_from_checkpoint(self, ckpt_path: str):\n        # Calculate original checksums\n        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        state_dict = torch.load(ckpt_path, map_location=\"cpu\")\n\n        # Load state dict for image_proj_model and adapter_modules\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"], strict=True)\n        self.adapter_modules.load_state_dict(state_dict[\"ip_adapter\"], strict=True)\n\n        # Calculate new checksums\n        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        # Verify if the weights have changed\n        assert orig_ip_proj_sum != new_ip_proj_sum, \"Weights of image_proj_model did not change!\"\n        assert orig_adapter_sum != new_adapter_sum, \"Weights of adapter_modules did not change!\"\n\n        print(f\"Successfully loaded weights from checkpoint {ckpt_path}\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        
\"--pretrained_ip_adapter_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained ip adapter model. If not specified weights are initialized randomly.\",\n    )\n    parser.add_argument(\n        \"--data_json_file\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Training data\",\n    )\n    parser.add_argument(\n        \"--data_root_path\",\n        type=str,\n        default=\"\",\n        required=True,\n        help=\"Training data root path\",\n    )\n    parser.add_argument(\n        \"--image_encoder_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to CLIP image encoder\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-ip_adapter\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\"The resolution for input images\"),\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Learning rate to use.\",\n    )\n    parser.add_argument(\"--weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=8, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=2000,\n        help=(\"Save a checkpoint of the training state every X updates\"),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n    # image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)\n    # freeze parameters of models to save more memory\n    unet.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    # image_encoder.requires_grad_(False)\n\n    # ip-adapter\n    image_proj_model = MLPProjModel(\n        cross_attention_dim=unet.config.cross_attention_dim,\n        id_embeddings_dim=512,\n        num_tokens=4,\n    )\n    # init 
adapter modules\n    lora_rank = 128\n    attn_procs = {}\n    unet_sd = unet.state_dict()\n    for name in unet.attn_processors.keys():\n        cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            hidden_size = unet.config.block_out_channels[-1]\n        elif name.startswith(\"up_blocks\"):\n            block_id = int(name[len(\"up_blocks.\")])\n            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n        elif name.startswith(\"down_blocks\"):\n            block_id = int(name[len(\"down_blocks.\")])\n            hidden_size = unet.config.block_out_channels[block_id]\n        if cross_attention_dim is None:\n            attn_procs[name] = LoRAAttnProcessor(\n                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank\n            )\n        else:\n            layer_name = name.split(\".processor\")[0]\n            weights = {\n                \"to_k_ip.weight\": unet_sd[layer_name + \".to_k.weight\"],\n                \"to_v_ip.weight\": unet_sd[layer_name + \".to_v.weight\"],\n            }\n            attn_procs[name] = LoRAIPAttnProcessor(\n                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank\n            )\n            attn_procs[name].load_state_dict(weights, strict=False)\n    unet.set_attn_processor(attn_procs)\n    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n\n    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    # unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, 
dtype=weight_dtype)\n    # image_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # optimizer\n    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())\n    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)\n\n    # dataloader\n    train_dataset = MyDataset(\n        args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Prepare everything with our `accelerator`.\n    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)\n\n    global_step = 0\n    for epoch in range(0, args.num_train_epochs):\n        begin = time.perf_counter()\n        for step, batch in enumerate(train_dataloader):\n            load_data_time = time.perf_counter() - begin\n            with accelerator.accumulate(ip_adapter):\n                # Convert images to latent space\n                with torch.no_grad():\n                    latents = vae.encode(\n                        batch[\"images\"].to(accelerator.device, dtype=weight_dtype)\n                    ).latent_dist.sample()\n                    latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the 
forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                image_embeds = batch[\"face_id_embed\"].to(accelerator.device, dtype=weight_dtype)\n\n                with torch.no_grad():\n                    encoder_hidden_states = text_encoder(batch[\"text_input_ids\"].to(accelerator.device))[0]\n\n                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)\n\n                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()\n\n                # Backpropagate\n                accelerator.backward(loss)\n                optimizer.step()\n                optimizer.zero_grad()\n\n                if accelerator.is_main_process:\n                    print(\n                        \"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}\".format(\n                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss\n                        )\n                    )\n\n            global_step += 1\n\n            if global_step % args.save_steps == 0:\n                save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                accelerator.save_state(save_path)\n\n            begin = time.perf_counter()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/ip_adapter/tutorial_train_ip-adapter.py",
    "content": "import argparse\nimport itertools\nimport json\nimport os\nimport random\nimport time\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\nfrom accelerate.utils import ProjectConfiguration\nfrom ip_adapter.ip_adapter import ImageProjModel\nfrom ip_adapter.utils import is_torch2_available\nfrom PIL import Image\nfrom torchvision import transforms\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\n\n\nif is_torch2_available():\n    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor\n    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor\nelse:\n    from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor\n\n\n# Dataset\nclass MyDataset(torch.utils.data.Dataset):\n    def __init__(\n        self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=\"\"\n    ):\n        super().__init__()\n\n        self.tokenizer = tokenizer\n        self.size = size\n        self.i_drop_rate = i_drop_rate\n        self.t_drop_rate = t_drop_rate\n        self.ti_drop_rate = ti_drop_rate\n        self.image_root_path = image_root_path\n\n        self.data = json.load(open(json_file))  # list of dict: [{\"image_file\": \"1.png\", \"text\": \"A dog\"}]\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(self.size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        self.clip_image_processor = CLIPImageProcessor()\n\n    def __getitem__(self, idx):\n        item = self.data[idx]\n        text = item[\"text\"]\n        image_file = 
item[\"image_file\"]\n\n        # read image\n        raw_image = Image.open(os.path.join(self.image_root_path, image_file))\n        image = self.transform(raw_image.convert(\"RGB\"))\n        clip_image = self.clip_image_processor(images=raw_image, return_tensors=\"pt\").pixel_values\n\n        # drop\n        drop_image_embed = 0\n        rand_num = random.random()\n        if rand_num < self.i_drop_rate:\n            drop_image_embed = 1\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate):\n            text = \"\"\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):\n            text = \"\"\n            drop_image_embed = 1\n        # get text and tokenize\n        text_input_ids = self.tokenizer(\n            text,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        return {\n            \"image\": image,\n            \"text_input_ids\": text_input_ids,\n            \"clip_image\": clip_image,\n            \"drop_image_embed\": drop_image_embed,\n        }\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef collate_fn(data):\n    images = torch.stack([example[\"image\"] for example in data])\n    text_input_ids = torch.cat([example[\"text_input_ids\"] for example in data], dim=0)\n    clip_images = torch.cat([example[\"clip_image\"] for example in data], dim=0)\n    drop_image_embeds = [example[\"drop_image_embed\"] for example in data]\n\n    return {\n        \"images\": images,\n        \"text_input_ids\": text_input_ids,\n        \"clip_images\": clip_images,\n        \"drop_image_embeds\": drop_image_embeds,\n    }\n\n\nclass IPAdapter(torch.nn.Module):\n    \"\"\"IP-Adapter\"\"\"\n\n    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):\n        super().__init__()\n        self.unet = unet\n        self.image_proj_model = 
image_proj_model\n        self.adapter_modules = adapter_modules\n\n        if ckpt_path is not None:\n            self.load_from_checkpoint(ckpt_path)\n\n    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):\n        ip_tokens = self.image_proj_model(image_embeds)\n        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)\n        # Predict the noise residual\n        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample\n        return noise_pred\n\n    def load_from_checkpoint(self, ckpt_path: str):\n        # Calculate original checksums\n        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        state_dict = torch.load(ckpt_path, map_location=\"cpu\")\n\n        # Load state dict for image_proj_model and adapter_modules\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"], strict=True)\n        self.adapter_modules.load_state_dict(state_dict[\"ip_adapter\"], strict=True)\n\n        # Calculate new checksums\n        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        # Verify if the weights have changed\n        assert orig_ip_proj_sum != new_ip_proj_sum, \"Weights of image_proj_model did not change!\"\n        assert orig_adapter_sum != new_adapter_sum, \"Weights of adapter_modules did not change!\"\n\n        print(f\"Successfully loaded weights from checkpoint {ckpt_path}\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        
required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_ip_adapter_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained ip adapter model. If not specified weights are initialized randomly.\",\n    )\n    parser.add_argument(\n        \"--data_json_file\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Training data\",\n    )\n    parser.add_argument(\n        \"--data_root_path\",\n        type=str,\n        default=\"\",\n        required=True,\n        help=\"Training data root path\",\n    )\n    parser.add_argument(\n        \"--image_encoder_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to CLIP image encoder\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-ip_adapter\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\"The resolution for input images\"),\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Learning rate to use.\",\n    )\n    parser.add_argument(\"--weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=8, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=2000,\n        help=(\"Save a checkpoint of the training state every X updates\"),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)\n    # freeze parameters of models to save more memory\n    unet.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    image_encoder.requires_grad_(False)\n\n    # ip-adapter\n    image_proj_model = ImageProjModel(\n        cross_attention_dim=unet.config.cross_attention_dim,\n        clip_embeddings_dim=image_encoder.config.projection_dim,\n        
clip_extra_context_tokens=4,\n    )\n    # init adapter modules\n    attn_procs = {}\n    unet_sd = unet.state_dict()\n    for name in unet.attn_processors.keys():\n        cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            hidden_size = unet.config.block_out_channels[-1]\n        elif name.startswith(\"up_blocks\"):\n            block_id = int(name[len(\"up_blocks.\")])\n            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n        elif name.startswith(\"down_blocks\"):\n            block_id = int(name[len(\"down_blocks.\")])\n            hidden_size = unet.config.block_out_channels[block_id]\n        if cross_attention_dim is None:\n            attn_procs[name] = AttnProcessor()\n        else:\n            layer_name = name.split(\".processor\")[0]\n            weights = {\n                \"to_k_ip.weight\": unet_sd[layer_name + \".to_k.weight\"],\n                \"to_v_ip.weight\": unet_sd[layer_name + \".to_v.weight\"],\n            }\n            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)\n            attn_procs[name].load_state_dict(weights)\n    unet.set_attn_processor(attn_procs)\n    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n\n    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    # unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # optimizer\n    params_to_opt = 
itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())\n    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)\n\n    # dataloader\n    train_dataset = MyDataset(\n        args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Prepare everything with our `accelerator`.\n    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)\n\n    global_step = 0\n    for epoch in range(0, args.num_train_epochs):\n        begin = time.perf_counter()\n        for step, batch in enumerate(train_dataloader):\n            load_data_time = time.perf_counter() - begin\n            with accelerator.accumulate(ip_adapter):\n                # Convert images to latent space\n                with torch.no_grad():\n                    latents = vae.encode(\n                        batch[\"images\"].to(accelerator.device, dtype=weight_dtype)\n                    ).latent_dist.sample()\n                    latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n         
       with torch.no_grad():\n                    image_embeds = image_encoder(\n                        batch[\"clip_images\"].to(accelerator.device, dtype=weight_dtype)\n                    ).image_embeds\n                image_embeds_ = []\n                for image_embed, drop_image_embed in zip(image_embeds, batch[\"drop_image_embeds\"]):\n                    if drop_image_embed == 1:\n                        image_embeds_.append(torch.zeros_like(image_embed))\n                    else:\n                        image_embeds_.append(image_embed)\n                image_embeds = torch.stack(image_embeds_)\n\n                with torch.no_grad():\n                    encoder_hidden_states = text_encoder(batch[\"text_input_ids\"].to(accelerator.device))[0]\n\n                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)\n\n                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()\n\n                # Backpropagate\n                accelerator.backward(loss)\n                optimizer.step()\n                optimizer.zero_grad()\n\n                if accelerator.is_main_process:\n                    print(\n                        \"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}\".format(\n                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss\n                        )\n                    )\n\n            global_step += 1\n\n            if global_step % args.save_steps == 0:\n                save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                accelerator.save_state(save_path)\n\n            begin = time.perf_counter()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/ip_adapter/tutorial_train_plus.py",
    "content": "import argparse\nimport itertools\nimport json\nimport os\nimport random\nimport time\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\nfrom accelerate.utils import ProjectConfiguration\nfrom ip_adapter.resampler import Resampler\nfrom ip_adapter.utils import is_torch2_available\nfrom PIL import Image\nfrom torchvision import transforms\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\n\n\nif is_torch2_available():\n    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor\n    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor\nelse:\n    from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor\n\n\n# Dataset\nclass MyDataset(torch.utils.data.Dataset):\n    def __init__(\n        self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=\"\"\n    ):\n        super().__init__()\n\n        self.tokenizer = tokenizer\n        self.size = size\n        self.i_drop_rate = i_drop_rate\n        self.t_drop_rate = t_drop_rate\n        self.ti_drop_rate = ti_drop_rate\n        self.image_root_path = image_root_path\n\n        self.data = json.load(open(json_file))  # list of dict: [{\"image_file\": \"1.png\", \"text\": \"A dog\"}]\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(self.size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        self.clip_image_processor = CLIPImageProcessor()\n\n    def __getitem__(self, idx):\n        item = self.data[idx]\n        text = item[\"text\"]\n        image_file = 
item[\"image_file\"]\n\n        # read image\n        raw_image = Image.open(os.path.join(self.image_root_path, image_file))\n        image = self.transform(raw_image.convert(\"RGB\"))\n        clip_image = self.clip_image_processor(images=raw_image, return_tensors=\"pt\").pixel_values\n\n        # drop\n        drop_image_embed = 0\n        rand_num = random.random()\n        if rand_num < self.i_drop_rate:\n            drop_image_embed = 1\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate):\n            text = \"\"\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):\n            text = \"\"\n            drop_image_embed = 1\n        # get text and tokenize\n        text_input_ids = self.tokenizer(\n            text,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        return {\n            \"image\": image,\n            \"text_input_ids\": text_input_ids,\n            \"clip_image\": clip_image,\n            \"drop_image_embed\": drop_image_embed,\n        }\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef collate_fn(data):\n    images = torch.stack([example[\"image\"] for example in data])\n    text_input_ids = torch.cat([example[\"text_input_ids\"] for example in data], dim=0)\n    clip_images = torch.cat([example[\"clip_image\"] for example in data], dim=0)\n    drop_image_embeds = [example[\"drop_image_embed\"] for example in data]\n\n    return {\n        \"images\": images,\n        \"text_input_ids\": text_input_ids,\n        \"clip_images\": clip_images,\n        \"drop_image_embeds\": drop_image_embeds,\n    }\n\n\nclass IPAdapter(torch.nn.Module):\n    \"\"\"IP-Adapter\"\"\"\n\n    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):\n        super().__init__()\n        self.unet = unet\n        self.image_proj_model = 
image_proj_model\n        self.adapter_modules = adapter_modules\n\n        if ckpt_path is not None:\n            self.load_from_checkpoint(ckpt_path)\n\n    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):\n        ip_tokens = self.image_proj_model(image_embeds)\n        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)\n        # Predict the noise residual\n        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample\n        return noise_pred\n\n    def load_from_checkpoint(self, ckpt_path: str):\n        # Calculate original checksums\n        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        state_dict = torch.load(ckpt_path, map_location=\"cpu\")\n\n        # Check if 'latents' exists in both the saved state_dict and the current model's state_dict\n        strict_load_image_proj_model = True\n        if \"latents\" in state_dict[\"image_proj\"] and \"latents\" in self.image_proj_model.state_dict():\n            # Check if the shapes are mismatched\n            if state_dict[\"image_proj\"][\"latents\"].shape != self.image_proj_model.state_dict()[\"latents\"].shape:\n                print(f\"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.\")\n                print(\"Removing 'latents' from checkpoint and loading the rest of the weights.\")\n                del state_dict[\"image_proj\"][\"latents\"]\n                strict_load_image_proj_model = False\n\n        # Load state dict for image_proj_model and adapter_modules\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"], strict=strict_load_image_proj_model)\n        self.adapter_modules.load_state_dict(state_dict[\"ip_adapter\"], strict=True)\n\n        # Calculate new checksums\n        
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        # Verify if the weights have changed\n        assert orig_ip_proj_sum != new_ip_proj_sum, \"Weights of image_proj_model did not change!\"\n        assert orig_adapter_sum != new_adapter_sum, \"Weights of adapter_modules did not change!\"\n\n        print(f\"Successfully loaded weights from checkpoint {ckpt_path}\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_ip_adapter_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained ip adapter model. 
If not specified weights are initialized randomly.\",\n    )\n    parser.add_argument(\n        \"--num_tokens\",\n        type=int,\n        default=16,\n        help=\"Number of tokens to query from the CLIP image encoding.\",\n    )\n    parser.add_argument(\n        \"--data_json_file\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Training data\",\n    )\n    parser.add_argument(\n        \"--data_root_path\",\n        type=str,\n        default=\"\",\n        required=True,\n        help=\"Training data root path\",\n    )\n    parser.add_argument(\n        \"--image_encoder_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to CLIP image encoder\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-ip_adapter\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\"The resolution for input images\"),\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Learning rate to use.\",\n    )\n    parser.add_argument(\"--weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=8, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=2000,\n        help=(\"Save a checkpoint of the training state every X updates\"),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)\n    # freeze parameters of models to save more memory\n    unet.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    image_encoder.requires_grad_(False)\n\n    # ip-adapter-plus\n    image_proj_model = Resampler(\n        dim=unet.config.cross_attention_dim,\n        depth=4,\n        dim_head=64,\n        heads=12,\n        
num_queries=args.num_tokens,\n        embedding_dim=image_encoder.config.hidden_size,\n        output_dim=unet.config.cross_attention_dim,\n        ff_mult=4,\n    )\n    # init adapter modules\n    attn_procs = {}\n    unet_sd = unet.state_dict()\n    for name in unet.attn_processors.keys():\n        cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            hidden_size = unet.config.block_out_channels[-1]\n        elif name.startswith(\"up_blocks\"):\n            block_id = int(name[len(\"up_blocks.\")])\n            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n        elif name.startswith(\"down_blocks\"):\n            block_id = int(name[len(\"down_blocks.\")])\n            hidden_size = unet.config.block_out_channels[block_id]\n        if cross_attention_dim is None:\n            attn_procs[name] = AttnProcessor()\n        else:\n            layer_name = name.split(\".processor\")[0]\n            weights = {\n                \"to_k_ip.weight\": unet_sd[layer_name + \".to_k.weight\"],\n                \"to_v_ip.weight\": unet_sd[layer_name + \".to_v.weight\"],\n            }\n            attn_procs[name] = IPAttnProcessor(\n                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens\n            )\n            attn_procs[name].load_state_dict(weights)\n    unet.set_attn_processor(attn_procs)\n    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n\n    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    # unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    
text_encoder.to(accelerator.device, dtype=weight_dtype)\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # optimizer\n    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())\n    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)\n\n    # dataloader\n    train_dataset = MyDataset(\n        args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Prepare everything with our `accelerator`.\n    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)\n\n    global_step = 0\n    for epoch in range(0, args.num_train_epochs):\n        begin = time.perf_counter()\n        for step, batch in enumerate(train_dataloader):\n            load_data_time = time.perf_counter() - begin\n            with accelerator.accumulate(ip_adapter):\n                # Convert images to latent space\n                with torch.no_grad():\n                    latents = vae.encode(\n                        batch[\"images\"].to(accelerator.device, dtype=weight_dtype)\n                    ).latent_dist.sample()\n                    latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each 
timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                clip_images = []\n                for clip_image, drop_image_embed in zip(batch[\"clip_images\"], batch[\"drop_image_embeds\"]):\n                    if drop_image_embed == 1:\n                        clip_images.append(torch.zeros_like(clip_image))\n                    else:\n                        clip_images.append(clip_image)\n                clip_images = torch.stack(clip_images, dim=0)\n                with torch.no_grad():\n                    image_embeds = image_encoder(\n                        clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True\n                    ).hidden_states[-2]\n\n                with torch.no_grad():\n                    encoder_hidden_states = text_encoder(batch[\"text_input_ids\"].to(accelerator.device))[0]\n\n                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)\n\n                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()\n\n                # Backpropagate\n                accelerator.backward(loss)\n                optimizer.step()\n                optimizer.zero_grad()\n\n                if accelerator.is_main_process:\n                    print(\n                        \"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}\".format(\n                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss\n                        )\n                    )\n\n            global_step += 1\n\n            if global_step % args.save_steps == 0:\n                save_path = os.path.join(args.output_dir, 
f\"checkpoint-{global_step}\")\n                accelerator.save_state(save_path)\n\n            begin = time.perf_counter()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/ip_adapter/tutorial_train_sdxl.py",
    "content": "import argparse\nimport itertools\nimport json\nimport os\nimport random\nimport time\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\nfrom accelerate.utils import ProjectConfiguration\nfrom ip_adapter.ip_adapter import ImageProjModel\nfrom ip_adapter.utils import is_torch2_available\nfrom PIL import Image\nfrom torchvision import transforms\nfrom transformers import (\n    CLIPImageProcessor,\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\n\n\nif is_torch2_available():\n    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor\n    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor\nelse:\n    from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor\n\n\n# Dataset\nclass MyDataset(torch.utils.data.Dataset):\n    def __init__(\n        self,\n        json_file,\n        tokenizer,\n        tokenizer_2,\n        size=1024,\n        center_crop=True,\n        t_drop_rate=0.05,\n        i_drop_rate=0.05,\n        ti_drop_rate=0.05,\n        image_root_path=\"\",\n    ):\n        super().__init__()\n\n        self.tokenizer = tokenizer\n        self.tokenizer_2 = tokenizer_2\n        self.size = size\n        self.center_crop = center_crop\n        self.i_drop_rate = i_drop_rate\n        self.t_drop_rate = t_drop_rate\n        self.ti_drop_rate = ti_drop_rate\n        self.image_root_path = image_root_path\n\n        self.data = json.load(open(json_file))  # list of dict: [{\"image_file\": \"1.png\", \"text\": \"A dog\"}]\n\n        self.transform = transforms.Compose(\n            [\n                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.ToTensor(),\n                
transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n        self.clip_image_processor = CLIPImageProcessor()\n\n    def __getitem__(self, idx):\n        item = self.data[idx]\n        text = item[\"text\"]\n        image_file = item[\"image_file\"]\n\n        # read image\n        raw_image = Image.open(os.path.join(self.image_root_path, image_file))\n\n        # original size\n        original_width, original_height = raw_image.size\n        original_size = torch.tensor([original_height, original_width])\n\n        image_tensor = self.transform(raw_image.convert(\"RGB\"))\n        # random crop\n        delta_h = image_tensor.shape[1] - self.size\n        delta_w = image_tensor.shape[2] - self.size\n        assert not all([delta_h, delta_w])\n\n        if self.center_crop:\n            top = delta_h // 2\n            left = delta_w // 2\n        else:\n            top = np.random.randint(0, delta_h + 1)\n            left = np.random.randint(0, delta_w + 1)\n        image = transforms.functional.crop(image_tensor, top=top, left=left, height=self.size, width=self.size)\n        crop_coords_top_left = torch.tensor([top, left])\n\n        clip_image = self.clip_image_processor(images=raw_image, return_tensors=\"pt\").pixel_values\n\n        # drop\n        drop_image_embed = 0\n        rand_num = random.random()\n        if rand_num < self.i_drop_rate:\n            drop_image_embed = 1\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate):\n            text = \"\"\n        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):\n            text = \"\"\n            drop_image_embed = 1\n\n        # get text and tokenize\n        text_input_ids = self.tokenizer(\n            text,\n            max_length=self.tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        text_input_ids_2 = self.tokenizer_2(\n            text,\n    
        max_length=self.tokenizer_2.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        return {\n            \"image\": image,\n            \"text_input_ids\": text_input_ids,\n            \"text_input_ids_2\": text_input_ids_2,\n            \"clip_image\": clip_image,\n            \"drop_image_embed\": drop_image_embed,\n            \"original_size\": original_size,\n            \"crop_coords_top_left\": crop_coords_top_left,\n            \"target_size\": torch.tensor([self.size, self.size]),\n        }\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef collate_fn(data):\n    images = torch.stack([example[\"image\"] for example in data])\n    text_input_ids = torch.cat([example[\"text_input_ids\"] for example in data], dim=0)\n    text_input_ids_2 = torch.cat([example[\"text_input_ids_2\"] for example in data], dim=0)\n    clip_images = torch.cat([example[\"clip_image\"] for example in data], dim=0)\n    drop_image_embeds = [example[\"drop_image_embed\"] for example in data]\n    original_size = torch.stack([example[\"original_size\"] for example in data])\n    crop_coords_top_left = torch.stack([example[\"crop_coords_top_left\"] for example in data])\n    target_size = torch.stack([example[\"target_size\"] for example in data])\n\n    return {\n        \"images\": images,\n        \"text_input_ids\": text_input_ids,\n        \"text_input_ids_2\": text_input_ids_2,\n        \"clip_images\": clip_images,\n        \"drop_image_embeds\": drop_image_embeds,\n        \"original_size\": original_size,\n        \"crop_coords_top_left\": crop_coords_top_left,\n        \"target_size\": target_size,\n    }\n\n\nclass IPAdapter(torch.nn.Module):\n    \"\"\"IP-Adapter\"\"\"\n\n    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):\n        super().__init__()\n        self.unet = unet\n        self.image_proj_model = image_proj_model\n  
      self.adapter_modules = adapter_modules\n\n        if ckpt_path is not None:\n            self.load_from_checkpoint(ckpt_path)\n\n    def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):\n        ip_tokens = self.image_proj_model(image_embeds)\n        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)\n        # Predict the noise residual\n        noise_pred = self.unet(\n            noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs\n        ).sample\n        return noise_pred\n\n    def load_from_checkpoint(self, ckpt_path: str):\n        # Calculate original checksums\n        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        state_dict = torch.load(ckpt_path, map_location=\"cpu\")\n\n        # Load state dict for image_proj_model and adapter_modules\n        self.image_proj_model.load_state_dict(state_dict[\"image_proj\"], strict=True)\n        self.adapter_modules.load_state_dict(state_dict[\"ip_adapter\"], strict=True)\n\n        # Calculate new checksums\n        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))\n        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))\n\n        # Verify if the weights have changed\n        assert orig_ip_proj_sum != new_ip_proj_sum, \"Weights of image_proj_model did not change!\"\n        assert orig_adapter_sum != new_adapter_sum, \"Weights of adapter_modules did not change!\"\n\n        print(f\"Successfully loaded weights from checkpoint {ckpt_path}\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        
\"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_ip_adapter_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained ip adapter model. If not specified weights are initialized randomly.\",\n    )\n    parser.add_argument(\n        \"--data_json_file\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Training data\",\n    )\n    parser.add_argument(\n        \"--data_root_path\",\n        type=str,\n        default=\"\",\n        required=True,\n        help=\"Training data root path\",\n    )\n    parser.add_argument(\n        \"--image_encoder_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to CLIP image encoder\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-ip_adapter\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\"The resolution for input images\"),\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Learning rate to use.\",\n    )\n    parser.add_argument(\"--weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=8, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=None, help=\"noise offset\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=2000,\n        help=(\"Save a checkpoint of the training state every X updates\"),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. 
Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\n    tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer_2\")\n    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\"\n    )\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)\n    # freeze parameters of models to 
save more memory\n    unet.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    text_encoder_2.requires_grad_(False)\n    image_encoder.requires_grad_(False)\n\n    # ip-adapter\n    num_tokens = 4\n    image_proj_model = ImageProjModel(\n        cross_attention_dim=unet.config.cross_attention_dim,\n        clip_embeddings_dim=image_encoder.config.projection_dim,\n        clip_extra_context_tokens=num_tokens,\n    )\n    # init adapter modules\n    attn_procs = {}\n    unet_sd = unet.state_dict()\n    for name in unet.attn_processors.keys():\n        cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n        if name.startswith(\"mid_block\"):\n            hidden_size = unet.config.block_out_channels[-1]\n        elif name.startswith(\"up_blocks\"):\n            block_id = int(name[len(\"up_blocks.\")])\n            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n        elif name.startswith(\"down_blocks\"):\n            block_id = int(name[len(\"down_blocks.\")])\n            hidden_size = unet.config.block_out_channels[block_id]\n        if cross_attention_dim is None:\n            attn_procs[name] = AttnProcessor()\n        else:\n            layer_name = name.split(\".processor\")[0]\n            weights = {\n                \"to_k_ip.weight\": unet_sd[layer_name + \".to_k.weight\"],\n                \"to_v_ip.weight\": unet_sd[layer_name + \".to_v.weight\"],\n            }\n            attn_procs[name] = IPAttnProcessor(\n                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens\n            )\n            attn_procs[name].load_state_dict(weights)\n    unet.set_attn_processor(attn_procs)\n    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n\n    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)\n\n    weight_dtype = 
torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    # unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device)  # use fp32\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # optimizer\n    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())\n    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)\n\n    # dataloader\n    train_dataset = MyDataset(\n        args.data_json_file,\n        tokenizer=tokenizer,\n        tokenizer_2=tokenizer_2,\n        size=args.resolution,\n        image_root_path=args.data_root_path,\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Prepare everything with our `accelerator`.\n    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)\n\n    global_step = 0\n    for epoch in range(0, args.num_train_epochs):\n        begin = time.perf_counter()\n        for step, batch in enumerate(train_dataloader):\n            load_data_time = time.perf_counter() - begin\n            with accelerator.accumulate(ip_adapter):\n                # Convert images to latent space\n                with torch.no_grad():\n                    # vae of sdxl should use fp32\n                    latents = vae.encode(batch[\"images\"].to(accelerator.device, dtype=vae.dtype)).latent_dist.sample()\n                    latents = latents * vae.config.scaling_factor\n                    latents = 
latents.to(accelerator.device, dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(\n                        accelerator.device, dtype=weight_dtype\n                    )\n\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                with torch.no_grad():\n                    image_embeds = image_encoder(\n                        batch[\"clip_images\"].to(accelerator.device, dtype=weight_dtype)\n                    ).image_embeds\n                image_embeds_ = []\n                for image_embed, drop_image_embed in zip(image_embeds, batch[\"drop_image_embeds\"]):\n                    if drop_image_embed == 1:\n                        image_embeds_.append(torch.zeros_like(image_embed))\n                    else:\n                        image_embeds_.append(image_embed)\n                image_embeds = torch.stack(image_embeds_)\n\n                with torch.no_grad():\n                    encoder_output = text_encoder(\n                        batch[\"text_input_ids\"].to(accelerator.device), output_hidden_states=True\n                    )\n                    text_embeds = encoder_output.hidden_states[-2]\n                    encoder_output_2 = text_encoder_2(\n                        
batch[\"text_input_ids_2\"].to(accelerator.device), output_hidden_states=True\n                    )\n                    pooled_text_embeds = encoder_output_2[0]\n                    text_embeds_2 = encoder_output_2.hidden_states[-2]\n                    text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1)  # concat\n\n                # add cond\n                add_time_ids = [\n                    batch[\"original_size\"].to(accelerator.device),\n                    batch[\"crop_coords_top_left\"].to(accelerator.device),\n                    batch[\"target_size\"].to(accelerator.device),\n                ]\n                add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)\n                unet_added_cond_kwargs = {\"text_embeds\": pooled_text_embeds, \"time_ids\": add_time_ids}\n\n                noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds)\n\n                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()\n\n                # Backpropagate\n                accelerator.backward(loss)\n                optimizer.step()\n                optimizer.zero_grad()\n\n                if accelerator.is_main_process:\n                    print(\n                        \"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}\".format(\n                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss\n                        )\n                    )\n\n            global_step += 1\n\n            if global_step % args.save_steps == 0:\n                save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                accelerator.save_state(save_path)\n\n            begin = 
time.perf_counter()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/lora/README.md",
    "content": "# Stable Diffusion text-to-image fine-tuning\nThis extended LoRA training script was authored by [haofanwang](https://github.com/haofanwang).\nThis is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py). We further support add LoRA layers for text encoder.\n\n## Training with LoRA\n\nLow-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.\n\nIn a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:\n\n- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).\n- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter.\n\n[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.\n\nWith LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset\non consumer GPUs like Tesla T4, Tesla V100.\n\n### Training\n\nFirst, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. 
Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Narutos dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___**\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n```\n\nFor this example we want to directly store the trained LoRA embeddings on the Hub, so\nwe need to be logged in and add the `--push_to_hub` flag.\n\n```bash\nhf auth login\n```\n\nNow we can start training!\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" train_text_to_image_lora.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME --caption_column=\"text\" \\\n  --resolution=512 --random_flip \\\n  --train_batch_size=1 \\\n  --num_train_epochs=100 --checkpointing_steps=5000 \\\n  --learning_rate=1e-04 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --seed=42 \\\n  --output_dir=\"sd-naruto-model-lora\" \\\n  --validation_prompt=\"cute dragon creature\" --report_to=\"wandb\" \\\n  --use_peft \\\n  --lora_r=4 --lora_alpha=32 \\\n  --lora_text_encoder_r=4 --lora_text_encoder_alpha=32\n```\n\nThe above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.\n\n**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. 
Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` on consumer GPUs like T4 or V100.___**\n\nThe final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). **___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**\n\nYou can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw).\n\n### Inference\n\nOnce you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline` after loading the trained LoRA weights. You\nneed to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-naruto-model-lora`.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_path = \"sayakpaul/sd-model-finetuned-lora-t4\"\npipe = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16)\npipe.unet.load_attn_procs(model_path)\npipe.to(\"cuda\")\n\nprompt = \"A naruto with green eyes and red legs.\"\nimage = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]\nimage.save(\"naruto.png\")\n```"
  },
  {
    "path": "examples/research_projects/lora/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\ndatasets\nftfy\ntensorboard\nJinja2\ngit+https://github.com/huggingface/peft.git"
  },
  {
    "path": "examples/research_projects/lora/train_text_to_image_lora.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion for text2image with support for LoRA.\"\"\"\n\nimport argparse\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel\nfrom diffusers.loaders import AttnProcsLayers\nfrom diffusers.models.attention_processor import LoRAAttnProcessor\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.14.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- stable-diffusion\n- stable-diffusion-diffusers\n- text-to-image\n- diffusers\n- diffusers-training\n- lora\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# LoRA text2image fine-tuning - {repo_id}\nThese are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \\n\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\", type=str, default=None, help=\"A prompt that is sampled during training for inference.\"\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. 
The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\"--train_text_encoder\", action=\"store_true\", help=\"Whether to train the text encoder\")\n\n    # lora args\n    parser.add_argument(\"--use_peft\", action=\"store_true\", help=\"Whether to use peft to support lora\")\n    parser.add_argument(\"--lora_r\", type=int, default=4, help=\"Lora rank, only used if use_peft is True\")\n    parser.add_argument(\"--lora_alpha\", type=int, default=32, help=\"Lora alpha, only used if use_peft is True\")\n    parser.add_argument(\"--lora_dropout\", type=float, default=0.0, help=\"Lora dropout, only used if use_peft is True\")\n    parser.add_argument(\n        \"--lora_bias\",\n        type=str,\n        default=\"none\",\n        help=\"Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_peft is True\",\n    )\n    parser.add_argument(\n        \"--lora_text_encoder_r\",\n        type=int,\n        default=4,\n        help=\"Lora rank for text encoder, only used if `use_peft` and `train_text_encoder` are True\",\n    )\n    parser.add_argument(\n        \"--lora_text_encoder_alpha\",\n        type=int,\n        default=32,\n        help=\"Lora alpha for text encoder, only used if `use_peft` and `train_text_encoder` are True\",\n    )\n    parser.add_argument(\n        \"--lora_text_encoder_dropout\",\n        type=float,\n        default=0.0,\n        help=\"Lora dropout for text encoder, only used if `use_peft` and `train_text_encoder` are True\",\n    )\n    parser.add_argument(\n        \"--lora_text_encoder_bias\",\n        type=str,\n        default=\"none\",\n        help=\"Bias type for Lora. 
Can be 'none', 'all' or 'lora_only', only used if `use_peft` and `train_text_encoder` are True\",\n    )\n\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of 
the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. 
Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        
mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n    text_encoder = 
CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if args.use_peft:\n        from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict\n\n        UNET_TARGET_MODULES = [\"to_q\", \"to_v\", \"query\", \"value\"]\n        TEXT_ENCODER_TARGET_MODULES = [\"q_proj\", \"v_proj\"]\n\n        config = LoraConfig(\n            r=args.lora_r,\n            lora_alpha=args.lora_alpha,\n            target_modules=UNET_TARGET_MODULES,\n            lora_dropout=args.lora_dropout,\n            bias=args.lora_bias,\n        )\n        unet = LoraModel(config, unet)\n\n        vae.requires_grad_(False)\n        if args.train_text_encoder:\n            config = LoraConfig(\n                r=args.lora_text_encoder_r,\n                lora_alpha=args.lora_text_encoder_alpha,\n                target_modules=TEXT_ENCODER_TARGET_MODULES,\n                lora_dropout=args.lora_text_encoder_dropout,\n                bias=args.lora_text_encoder_bias,\n            )\n            text_encoder = LoraModel(config, text_encoder)\n    else:\n        # freeze parameters of models to save more memory\n        unet.requires_grad_(False)\n        vae.requires_grad_(False)\n\n        
text_encoder.requires_grad_(False)\n\n        # now we will add new LoRA weights to the attention layers\n        # It's important to realize here how many attention weights will be added and of which sizes\n        # The sizes of the attention layers consist only of two different variables:\n        # 1) - the \"hidden_size\", which is increased according to `unet.config.block_out_channels`.\n        # 2) - the \"cross attention size\", which is set to `unet.config.cross_attention_dim`.\n\n        # Let's first see how many attention processors we will have to set.\n        # For Stable Diffusion, it should be equal to:\n        # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12\n        # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2\n        # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18\n        # => 32 layers\n\n        # Set correct lora layers\n        lora_attn_procs = {}\n        for name in unet.attn_processors.keys():\n            cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n            if name.startswith(\"mid_block\"):\n                hidden_size = unet.config.block_out_channels[-1]\n            elif name.startswith(\"up_blocks\"):\n                block_id = int(name[len(\"up_blocks.\")])\n                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n            elif name.startswith(\"down_blocks\"):\n                block_id = int(name[len(\"down_blocks.\")])\n                hidden_size = unet.config.block_out_channels[block_id]\n\n            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)\n\n        unet.set_attn_processor(lora_attn_procs)\n        lora_layers = AttnProcsLayers(unet.attn_processors)\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    
vae.to(accelerator.device, dtype=weight_dtype)\n    if not args.train_text_encoder:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    if args.use_peft:\n        # Optimizer creation\n        params_to_optimize = (\n            itertools.chain(unet.parameters(), text_encoder.parameters())\n            if args.train_text_encoder\n            else unet.parameters()\n        )\n        optimizer = optimizer_cls(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n    else:\n        optimizer = optimizer_cls(\n            lora_layers.parameters(),\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # 
https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        
return inputs.input_ids\n\n    # Preprocessing the datasets.\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"input_ids\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if 
args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.use_peft:\n        if args.train_text_encoder:\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n                unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n            )\n        else:\n            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n                unet, optimizer, train_dataloader, lr_scheduler\n            )\n    else:\n        lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            lora_layers, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training 
*****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the 
latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Predict the noise residual and compute loss\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    if args.use_peft:\n                        params_to_clip = (\n                            itertools.chain(unet.parameters(), text_encoder.parameters())\n                            if args.train_text_encoder\n                            else unet.parameters()\n                        )\n                    else:\n                        params_to_clip = lora_layers.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                
optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                logger.info(\n                    f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                # create pipeline\n                pipeline = DiffusionPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    unet=accelerator.unwrap_model(unet),\n                    text_encoder=accelerator.unwrap_model(text_encoder),\n                    revision=args.revision,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline = pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n                images = []\n                for _ in range(args.num_validation_images):\n                    images.append(\n                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]\n                    )\n\n                if accelerator.is_main_process:\n                    for tracker in accelerator.trackers:\n                        if tracker.name == \"tensorboard\":\n                            np_images = np.stack([np.asarray(img) for img in images])\n                            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                        if tracker.name == \"wandb\":\n                            tracker.log(\n                                {\n                                    \"validation\": [\n                                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                        for i, image in enumerate(images)\n                                    ]\n                                }\n                            )\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Save the lora layers\n    
accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if args.use_peft:\n            lora_config = {}\n            unwarpped_unet = accelerator.unwrap_model(unet)\n            state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))\n            lora_config[\"peft_config\"] = unwarpped_unet.get_peft_config_as_dict(inference=True)\n            if args.train_text_encoder:\n                unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)\n                text_encoder_state_dict = get_peft_model_state_dict(\n                    unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)\n                )\n                text_encoder_state_dict = {f\"text_encoder_{k}\": v for k, v in text_encoder_state_dict.items()}\n                state_dict.update(text_encoder_state_dict)\n                lora_config[\"text_encoder_peft_config\"] = unwarpped_text_encoder.get_peft_config_as_dict(\n                    inference=True\n                )\n\n            accelerator.save(state_dict, os.path.join(args.output_dir, f\"{global_step}_lora.pt\"))\n            with open(os.path.join(args.output_dir, f\"{global_step}_lora_config.json\"), \"w\") as f:\n                json.dump(lora_config, f)\n        else:\n            unet = unet.to(torch.float32)\n            unet.save_attn_procs(args.output_dir)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    # Final inference\n    # Load previous pipeline\n    
pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype\n    )\n\n    if args.use_peft:\n\n        def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):\n            with open(os.path.join(args.output_dir, f\"{global_step}_lora_config.json\"), \"r\") as f:\n                lora_config = json.load(f)\n            print(lora_config)\n\n            checkpoint = os.path.join(args.output_dir, f\"{global_step}_lora.pt\")\n            lora_checkpoint_sd = torch.load(checkpoint)\n            unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if \"text_encoder_\" not in k}\n            text_encoder_lora_ds = {\n                k.replace(\"text_encoder_\", \"\"): v for k, v in lora_checkpoint_sd.items() if \"text_encoder_\" in k\n            }\n\n            unet_config = LoraConfig(**lora_config[\"peft_config\"])\n            pipe.unet = LoraModel(unet_config, pipe.unet)\n            set_peft_model_state_dict(pipe.unet, unet_lora_ds)\n\n            if \"text_encoder_peft_config\" in lora_config:\n                text_encoder_config = LoraConfig(**lora_config[\"text_encoder_peft_config\"])\n                pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)\n                set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)\n\n            if dtype in (torch.float16, torch.bfloat16):\n                pipe.unet.half()\n                pipe.text_encoder.half()\n\n            pipe.to(device)\n            return pipe\n\n        pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)\n\n    else:\n        pipeline = pipeline.to(accelerator.device)\n        # load attention processors\n        pipeline.unet.load_attn_procs(args.output_dir)\n\n    # run inference\n    if args.seed is not None:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n    
else:\n        generator = None\n    images = []\n    for _ in range(args.num_validation_images):\n        images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])\n\n    if accelerator.is_main_process:\n        for tracker in accelerator.trackers:\n            if tracker.name == \"tensorboard\":\n                np_images = np.stack([np.asarray(img) for img in images])\n                tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n            if tracker.name == \"wandb\":\n                tracker.log(\n                    {\n                        \"test\": [\n                            wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                            for i, image in enumerate(images)\n                        ]\n                    }\n                )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/lpl/README.md",
    "content": "# Latent Perceptual Loss (LPL) for Stable Diffusion XL\n\nThis directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper: [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE, helping to improve the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation is based on the reference implementation provided by Tariq Berrada.\n\n## Overview\n\nLPL addresses a key limitation in latent diffusion models (LDMs): the disconnect between the diffusion model training and the autoencoder decoder. While LDMs train in the latent space, they don't receive direct feedback about how well their outputs decode into high-quality images. This can lead to:\n\n- Loss of fine details in generated images\n- Inconsistent image quality\n- Structural artifacts\n- Reduced sharpness and realism\n\nLPL works by comparing intermediate features from the VAE decoder between the predicted and target latents. This helps the model learn better perceptual features and can lead to:\n\n- Improved image quality and consistency (6-20% FID improvement)\n- Better preservation of fine details\n- More stable training, especially at high noise levels\n- Better handling of structural information\n- Sharper and more realistic textures\n\n## Implementation Details\n\nThe LPL implementation follows the paper's methodology and includes several key features:\n\n1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including:\n   - Middle block features\n   - Up block features (configurable number of blocks)\n   - Proper gradient checkpointing for memory efficiency\n   - Features are extracted only for timesteps below the threshold (high SNR)\n\n2. 
**Feature Normalization**: Multiple normalization options as validated in the paper:\n   - `default`: Normalize each feature map independently\n   - `shared`: Cross-normalize features using target statistics (recommended)\n   - `batch`: Batch-wise normalization\n\n3. **Outlier Handling**: Optional removal of outliers in feature maps using:\n   - Quantile-based filtering (2% quantiles)\n   - Morphological operations (opening/closing)\n   - Adaptive thresholding based on standard deviation\n\n4. **Loss Types**:\n   - MSE loss (default)\n   - L1 loss\n   - Optional power law weighting (2^(-i) for layer i)\n\n## Usage\n\nTo use LPL in your training, add the following arguments to your training command:\n\n```bash\n# --lpl_weight: weight for the LPL loss (1.0-2.0 recommended)\n# --lpl_t_threshold: apply LPL only for timesteps < threshold (high SNR)\n# --lpl_loss_type: \"mse\" or \"l1\"\n# --lpl_norm_type: \"default\", \"shared\" (recommended), or \"batch\"\n# --lpl_pow_law: use power law weighting for layers\n# --lpl_num_blocks: number of up blocks to use (1-4)\n# --lpl_remove_outliers: remove outliers in feature maps\n# --lpl_scale: scale LPL loss by noise level weights\n# --lpl_start: step to start applying LPL\n# Add your other training arguments as further continued lines.\npython examples/research_projects/lpl/train_sdxl_lpl.py \\\n    --use_lpl \\\n    --lpl_weight 1.0 \\\n    --lpl_t_threshold 200 \\\n    --lpl_loss_type mse \\\n    --lpl_norm_type shared \\\n    --lpl_pow_law \\\n    --lpl_num_blocks 4 \\\n    --lpl_remove_outliers \\\n    --lpl_scale \\\n    --lpl_start 0\n```\n\n### Key Parameters\n\n- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training.\n- `lpl_t_threshold`: LPL is only applied for timesteps below this threshold (high SNR). Lower values (100-200) restrict LPL to the least noisy timesteps.\n- `lpl_loss_type`: Choose between MSE (default) and L1 loss. 
MSE is recommended for most cases.\n- `lpl_norm_type`: Feature normalization strategy. \"shared\" is recommended as it showed the best results in the paper.\n- `lpl_pow_law`: Whether to use power law weighting (2^(-i) for layer i). Recommended for better feature balance.\n- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory.\n- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training.\n- `lpl_scale`: Whether to scale LPL loss by noise level weights. Helps focus on more important timesteps.\n- `lpl_start`: Training step to start applying LPL. Can be used to warm up training.\n\n## Recommendations\n\n1. **Starting Point** (based on paper results):\n   ```bash\n   --use_lpl \\\n   --lpl_weight 1.0 \\\n   --lpl_t_threshold 200 \\\n   --lpl_loss_type mse \\\n   --lpl_norm_type shared \\\n   --lpl_pow_law \\\n   --lpl_num_blocks 4 \\\n   --lpl_remove_outliers \\\n   --lpl_scale\n   ```\n\n2. **Memory Efficiency**:\n   - Pass `--gradient_checkpointing` to enable gradient checkpointing in the training script (the flag is off unless set); `LatentPerceptualLoss` itself defaults to `grad_ckpt=True`\n   - Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results)\n   - Consider using `--lpl_scale` to focus on more important timesteps\n   - Features are extracted only for timesteps below threshold to save memory\n\n3. **Quality vs Speed**:\n   - Higher `lpl_weight` (1.0-2.0) for better quality\n   - Lower `lpl_t_threshold` (100-200) for faster training\n   - Use `lpl_remove_outliers` for more stable training\n   - `lpl_norm_type shared` provides the best quality/speed trade-off\n\n## Technical Details\n\n### Feature Extraction\n\nThe LPL implementation extracts features from the VAE decoder in the following order:\n1. Middle block output\n2. Up block outputs (configurable number of blocks)\n\nEach feature map is processed with:\n1. Optional outlier removal (2% quantiles, morphological operations)\n2. 
Feature normalization (shared statistics recommended)\n3. Loss calculation (MSE or L1)\n4. Optional power law weighting (2^(-i) for layer i)\n\n### Loss Calculation\n\nFor each feature map:\n1. Features are normalized according to the chosen strategy\n2. Loss is calculated between normalized features\n3. Outliers are masked out (if enabled)\n4. Loss is weighted by layer depth (if power law enabled)\n5. Final loss is averaged across all layers\n\n### Memory Considerations\n\n- Gradient checkpointing is used by default\n- Features are extracted only for timesteps below the threshold\n- Outlier removal is done in-place to save memory\n- Feature normalization is done efficiently using vectorized operations\n- Memory usage scales linearly with number of blocks used\n\n## Results\n\nBased on the paper's findings, LPL provides:\n- 6-20% improvement in FID scores\n- Better preservation of fine details\n- More realistic textures and structures\n- Improved consistency across different resolutions\n- Better performance on both small and large datasets\n\n## Citation\n\nIf you use this implementation in your research, please cite:\n\n```bibtex\n@inproceedings{berrada2025boosting,\n    title={Boosting Latent Diffusion with Perceptual Objectives},\n    author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek},\n    booktitle={The Thirteenth International Conference on Learning Representations},\n    year={2025},\n    url={https://openreview.net/forum?id=y4DtzADzd1}\n}\n```\n"
  },
  {
    "path": "examples/research_projects/lpl/lpl_loss.py",
    "content": "# Copyright 2025 Berrada et al.\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef normalize_tensor(in_feat, eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))\n    return in_feat / (norm_factor + eps)\n\n\ndef cross_normalize(input, target, eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))\n    return input / (norm_factor + eps), target / (norm_factor + eps)\n\n\ndef remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):\n    opening = int(np.ceil(opening / down_f))\n    closing = int(np.ceil(closing / down_f))\n    if opening == 2:\n        opening = 3\n    if closing == 2:\n        closing = 1\n\n    # replace quantile with kth value here.\n    feat_flat = feat.flatten(-2, -1)\n    k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant))\n    q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]\n    q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]\n\n    m = 2 * feat_flat.std(-1)[..., None, None].detach()\n    mask = (q1 - m < feat) * (feat < q2 + m)\n\n    # dilate the mask.\n    mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float())  # closing\n    mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool()  # opening\n    feat = feat * mask\n    return mask, feat\n\n\nclass LatentPerceptualLoss(nn.Module):\n    def __init__(\n        self,\n        vae,\n        loss_type=\"mse\",\n        grad_ckpt=True,\n        pow_law=False,\n        norm_type=\"default\",\n        num_mid_blocks=4,\n        feature_type=\"feature\",\n        remove_outliers=True,\n    ):\n        super().__init__()\n        self.vae = vae\n        self.decoder = self.vae.decoder\n        # Store scaling factors as tensors on the correct device\n        device = next(self.vae.parameters()).device\n\n        # Get scaling 
factors with proper defaults and handle None values\n        scale_factor = getattr(self.vae.config, \"scaling_factor\", None)\n        shift_factor = getattr(self.vae.config, \"shift_factor\", None)\n\n        # Convert to tensors with proper defaults\n        self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)\n        self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)\n\n        self.gradient_checkpointing = grad_ckpt\n        self.pow_law = pow_law\n        self.norm_type = norm_type.lower()\n        self.outlier_mask = remove_outliers\n        self.last_feature_stats = []  # Store feature statistics for logging\n\n        assert feature_type in [\"feature\", \"image\"]\n        self.feature_type = feature_type\n\n        assert self.norm_type in [\"default\", \"shared\", \"batch\"]\n        assert num_mid_blocks >= 0 and num_mid_blocks <= 4\n        self.n_blocks = num_mid_blocks\n\n        assert loss_type in [\"mse\", \"l1\"]\n        if loss_type == \"mse\":\n            self.loss_fn = nn.MSELoss(reduction=\"none\")\n        elif loss_type == \"l1\":\n            self.loss_fn = nn.L1Loss(reduction=\"none\")\n\n    def get_features(self, z, latent_embeds=None, disable_grads=False):\n        with torch.set_grad_enabled(not disable_grads):\n            if self.gradient_checkpointing and not disable_grads:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                features = []\n                upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype\n                sample = z\n                sample = self.decoder.conv_in(sample)\n\n                # middle\n                sample = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(self.decoder.mid_block),\n                    sample,\n        
            latent_embeds,\n                    use_reentrant=False,\n                )\n                sample = sample.to(upscale_dtype)\n                features.append(sample)\n\n                # up\n                for up_block in self.decoder.up_blocks[: self.n_blocks]:\n                    sample = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(up_block),\n                        sample,\n                        latent_embeds,\n                        use_reentrant=False,\n                    )\n                    features.append(sample)\n                return features\n            else:\n                features = []\n                upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype\n                sample = z\n                sample = self.decoder.conv_in(sample)\n\n                # middle\n                sample = self.decoder.mid_block(sample, latent_embeds)\n                sample = sample.to(upscale_dtype)\n                features.append(sample)\n\n                # up\n                for up_block in self.decoder.up_blocks[: self.n_blocks]:\n                    sample = up_block(sample, latent_embeds)\n                    features.append(sample)\n                return features\n\n    def get_loss(self, input, target, get_hist=False):\n        if self.feature_type == \"feature\":\n            inp_f = self.get_features(self.shift + input / self.scale)\n            tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)\n            losses = []\n            self.last_feature_stats = []  # Reset feature stats\n\n            for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):\n                my = torch.ones_like(y).bool()\n                outlier_ratio = 0.0\n\n                if self.outlier_mask:\n                    with torch.no_grad():\n                        if i == 2:\n                            my, y = remove_outliers(y, down_f=2)\n                      
      outlier_ratio = 1.0 - my.float().mean().item()\n                        elif i in [3, 4, 5]:\n                            my, y = remove_outliers(y, down_f=1)\n                            outlier_ratio = 1.0 - my.float().mean().item()\n\n                # Store feature statistics before normalization\n                with torch.no_grad():\n                    stats = {\n                        \"mean\": y.mean().item(),\n                        \"std\": y.std().item(),\n                        \"outlier_ratio\": outlier_ratio,\n                    }\n                    self.last_feature_stats.append(stats)\n\n                # normalize feature tensors\n                if self.norm_type == \"default\":\n                    x = normalize_tensor(x)\n                    y = normalize_tensor(y)\n                elif self.norm_type == \"shared\":\n                    x, y = cross_normalize(x, y, eps=1e-6)\n\n                term_loss = self.loss_fn(x, y) * my\n                # reduce loss term\n                loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0\n                term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))\n                losses.append(term_loss.mean((1,)))\n\n            if get_hist:\n                return losses\n            else:\n                loss = sum(losses)\n                return loss / len(inp_f)\n        elif self.feature_type == \"image\":\n            inp_f = self.vae.decode(input / self.scale).sample\n            tar_f = self.vae.decode(target / self.scale).sample\n            return F.mse_loss(inp_f, tar_f)\n\n    def get_first_conv(self, z):\n        sample = self.decoder.conv_in(z)\n        return sample\n\n    def get_first_block(self, z):\n        sample = self.decoder.conv_in(z)\n        sample = self.decoder.mid_block(sample)\n        for resnet in self.decoder.up_blocks[0].resnets:\n            sample = resnet(sample, None)\n        return sample\n\n    def get_first_layer(self, input, target, 
target_layer=\"conv\"):\n        if target_layer == \"conv\":\n            feat_in = self.get_first_conv(input)\n            with torch.no_grad():\n                feat_tar = self.get_first_conv(target)\n        else:\n            feat_in = self.get_first_block(input)\n            with torch.no_grad():\n                feat_tar = self.get_first_block(target)\n\n        feat_in, feat_tar = cross_normalize(feat_in, feat_tar)\n\n        return F.mse_loss(feat_in, feat_tar, reduction=\"mean\")\n"
  },
  {
    "path": "examples/research_projects/lpl/train_sdxl_lpl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"LPL training script for Stable Diffusion XL for text2image.\"\"\"\n\nimport argparse\nimport functools\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport re\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\nfrom typing import Dict, List, Tuple\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom datasets import concatenate_datasets, load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom lpl_loss import LatentPerceptualLoss\nfrom packaging import version\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom 
diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.33.0.dev0\")\n\nlogger = get_logger(__name__)\nif is_torch_npu_available():\n    import torch_npu\n\n    torch.npu.config.allow_internal_format = False\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n# Global dictionary to store intermediate features from hooks\nhook_features: Dict[str, torch.Tensor] = {}\n\n\ndef get_intermediate_features_hook(name: str):\n    \"\"\"Creates a hook function that saves the output of a layer.\"\"\"\n\n    def hook(model, input, output):\n        # Some layers might return tuples (e.g., attention blocks)\n        # We are usually interested in the first element (hidden states)\n        if isinstance(output, tuple):\n            hook_features[name] = output[0]\n        else:\n            hook_features[name] = output\n\n    return hook\n\n\ndef clear_hook_features():\n    \"\"\"Clears the global feature dictionary.\"\"\"\n    global hook_features\n    hook_features = {}\n\n\ndef normalize_features(\n    feat1: torch.Tensor, feat2: torch.Tensor, eps: float = 1e-6\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Normalizes feat1 and feat2 using the statistics of feat2 (predicted features).\n    Normalization is done per-channel.\n    \"\"\"\n    # Calculate stats over spatial dimensions (H, W)\n    dims = tuple(range(2, feat2.ndim))  # Dims to reduce over (usually 2, 3 for H, W)\n    mean = torch.mean(feat2, dim=dims, keepdim=True)\n    std = torch.std(feat2, dim=dims, keepdim=True) + eps\n\n    feat1_norm = (feat1 - mean) / std\n    feat2_norm = (feat2 - mean) / std\n    return feat1_norm, feat2_norm\n\n\ndef 
get_decoder_layer_names(decoder: nn.Module) -> List[str]:\n    \"\"\"Helper to get potential layer names for hooks in the VAE decoder.\"\"\"\n    layer_names = []\n    for name, module in decoder.named_modules():\n        # Example: Target ResnetBlocks and potentially UpBlocks\n        if isinstance(module, (diffusers.models.resnet.ResnetBlock2D, diffusers.models.unet_2d_blocks.UpBlock2D)):\n            # Filter out redundant names if UpBlock contains ResnetBlocks already named\n            is_child = any(\n                name.startswith(parent + \".\")\n                for parent in layer_names\n                if isinstance(decoder.get_submodule(parent), diffusers.models.unet_2d_blocks.UpBlock2D)\n            )\n            if not is_child:\n                layer_names.append(name)\n    # A basic default selection if complex logic fails\n    if not layer_names:\n        layer_names = [\n            name for name, module in decoder.named_modules() if re.match(r\"up_blocks\\.\\d+\\.resnets\\.\\d+$\", name)\n        ]\n    return layer_names\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    validation_prompt: str = None,\n    base_model: str = None,\n    dataset_name: str = None,\n    repo_folder: str = None,\n    vae_path: str = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# Text-to-image finetuning - {repo_id}\n\nThis pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. 
Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \\n\n{img_str}\n\nSpecial VAE used for training: {vae_path}.\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"LPL based training script of Stable Diffusion XL.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to 
pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. 'fp16'.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sdxl-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--timestep_bias_strategy\",\n        type=str,\n        default=\"none\",\n        choices=[\"earlier\", \"later\", \"range\", \"none\"],\n        help=(\n            \"The timestep bias strategy, which may help direct the model toward learning low or high frequency details.\"\n            \" Choices: ['earlier', 'later', 'range', 'none'].\"\n            \" The default is 'none', which means no bias is applied, and training proceeds normally.\"\n            \" The value of 'later' will increase the frequency of the model's final training timesteps.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_multiplier\",\n        type=float,\n        default=1.0,\n        help=(\n            \"The multiplier for the bias. 
Defaults to 1.0, which means no bias is applied.\"\n            \" A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_begin\",\n        type=int,\n        default=0,\n        help=(\n            \"When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias.\"\n            \" Defaults to zero, which equates to having no specific bias.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_end\",\n        type=int,\n        default=1000,\n        help=(\n            \"When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias.\"\n            \" Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_portion\",\n        type=float,\n        default=0.25,\n        help=(\n            \"The portion of timesteps to bias. Defaults to 0.25, which means 25% of timesteps will be biased.\"\n            \" A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines\"\n            \" whether the biased portions are in the earlier or later timesteps.\"\n        ),\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://arxiv.org/abs/2303.09556.\",\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_npu_flash_attention\", action=\"store_true\", help=\"Whether or not to use npu flash attention.\"\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n\n    parser.add_argument(\n        \"--use_lpl\",\n        action=\"store_true\",\n        help=\"Whether to use Latent Perceptual Loss (LPL). Increases memory usage.\",\n    )\n    parser.add_argument(\n        \"--lpl_weight\",\n        type=float,\n        default=1.0,\n        help=\"Weight for the Latent Perceptual Loss.\",\n    )\n    parser.add_argument(\n        \"--lpl_t_threshold\",\n        type=int,\n        default=200,\n        help=\"Apply LPL only for timesteps t < lpl_t_threshold. 
Corresponds to high SNR.\",\n    )\n    parser.add_argument(\n        \"--lpl_loss_type\",\n        type=str,\n        default=\"mse\",\n        choices=[\"mse\", \"l1\"],\n        help=\"Type of loss to use for LPL.\",\n    )\n    parser.add_argument(\n        \"--lpl_norm_type\",\n        type=str,\n        default=\"default\",\n        choices=[\"default\", \"shared\", \"batch\"],\n        help=\"Type of normalization to use for LPL features.\",\n    )\n    parser.add_argument(\n        \"--lpl_pow_law\",\n        action=\"store_true\",\n        help=\"Whether to use power law weighting for LPL layers.\",\n    )\n    parser.add_argument(\n        \"--lpl_num_blocks\",\n        type=int,\n        default=4,\n        help=\"Number of up blocks to use for LPL feature extraction.\",\n    )\n    parser.add_argument(\n        \"--lpl_remove_outliers\",\n        action=\"store_true\",\n        help=\"Whether to remove outliers in LPL feature maps.\",\n    )\n    parser.add_argument(\n        \"--lpl_scale\",\n        action=\"store_true\",\n        help=\"Whether to scale LPL loss by noise level weights.\",\n    )\n    parser.add_argument(\n        \"--lpl_start\",\n        type=int,\n        default=0,\n        help=\"Step to start applying LPL loss.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    return args\n\n\n# Adapted from 
pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):\n    prompt_embeds_list = []\n    prompt_batch = batch[caption_column]\n\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n                return_dict=False,\n            )\n\n            # We are always only interested in the pooled output of the final text encoder\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds[-1][-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return {\"prompt_embeds\": prompt_embeds.cpu(), \"pooled_prompt_embeds\": pooled_prompt_embeds.cpu()}\n\n\ndef compute_vae_encodings(batch, vae):\n    images = batch.pop(\"pixel_values\")\n    pixel_values = 
torch.stack(list(images))\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)\n\n    with torch.no_grad():\n        model_input = vae.encode(pixel_values).latent_dist.sample()\n    model_input = model_input * vae.config.scaling_factor\n\n    # There might be a slight performance improvement\n    # from changing model_input.cpu() to accelerator.gather(model_input)\n    return {\"model_input\": model_input.cpu()}\n\n\ndef generate_timestep_weights(args, num_timesteps):\n    weights = torch.ones(num_timesteps)\n\n    # Determine the indices to bias\n    num_to_bias = int(args.timestep_bias_portion * num_timesteps)\n\n    if args.timestep_bias_strategy == \"later\":\n        bias_indices = slice(-num_to_bias, None)\n    elif args.timestep_bias_strategy == \"earlier\":\n        bias_indices = slice(0, num_to_bias)\n    elif args.timestep_bias_strategy == \"range\":\n        # Out of the possible 1000 timesteps, we might want to focus on e.g. 
200-500.\n        range_begin = args.timestep_bias_begin\n        range_end = args.timestep_bias_end\n        if range_begin < 0:\n            raise ValueError(\n                \"When using the range strategy for timestep bias, you must provide a beginning timestep greater than or equal to zero.\"\n            )\n        if range_end > num_timesteps:\n            raise ValueError(\n                \"When using the range strategy for timestep bias, you must provide an ending timestep no greater than the number of timesteps.\"\n            )\n        bias_indices = slice(range_begin, range_end)\n    else:  # 'none' or any other string\n        return weights\n    if args.timestep_bias_multiplier <= 0:\n        # Raise (not return) the error so an invalid multiplier stops training instead of being used as weights.\n        raise ValueError(\n            \"The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps.\"\n            \" If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead.\"\n            \" A timestep bias multiplier less than or equal to 0 is not allowed.\"\n        )\n\n    # Apply the bias\n    weights[bias_indices] *= args.timestep_bias_multiplier\n\n    # Normalize\n    weights /= weights.sum()\n\n    return weights\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `huggingface-cli login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        
revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    # Check for terminal SNR in combination with SNR Gamma\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # Freeze vae and text encoders.\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    # Set unet as trainable.\n    unet.train()\n\n    # For mixed precision 
training we cast all non-trainable weights to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = UNet2DConditionModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n        )\n        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            unet.enable_npu_flash_attention()\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu devices.\")\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. 
See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    if weights:\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for _ in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        
accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = unet.parameters()\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir\n        )\n    else:\n        data_files = {}\n        if 
args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)\n    train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        # 
image aug\n        original_sizes = []\n        all_images = []\n        crop_top_lefts = []\n        for image in images:\n            original_sizes.append((image.height, image.width))\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            crop_top_left = (y1, x1)\n            crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            all_images.append(image)\n\n        examples[\"original_sizes\"] = original_sizes\n        examples[\"crop_top_lefts\"] = crop_top_lefts\n        examples[\"pixel_values\"] = all_images\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    # Let's first compute all the embeddings so that we can free up the text encoders\n    # from memory. 
We will pre-compute the VAE encodings too.\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n    compute_embeddings_fn = functools.partial(\n        encode_prompt,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n        proportion_empty_prompts=args.proportion_empty_prompts,\n        caption_column=args.caption_column,\n    )\n    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)\n    with accelerator.main_process_first():\n        from datasets.fingerprint import Hasher\n\n        # fingerprint used by the cache for the other processes to load the result\n        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401\n        new_fingerprint = Hasher.hash(args)\n        new_fingerprint_for_vae = Hasher.hash((vae_path, args))\n        train_dataset_with_embeddings = train_dataset.map(\n            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint\n        )\n        train_dataset_with_vae = train_dataset.map(\n            compute_vae_encodings_fn,\n            batched=True,\n            batch_size=args.train_batch_size,\n            new_fingerprint=new_fingerprint_for_vae,\n        )\n        precomputed_dataset = concatenate_datasets(\n            [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns([\"image\", \"text\"])], axis=1\n        )\n        precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)\n\n    del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two\n    del text_encoders, tokenizers\n    if not args.use_lpl:\n        del vae\n    gc.collect()\n\n    if is_torch_npu_available():\n        torch_npu.npu.empty_cache()\n    elif torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    def collate_fn(examples):\n        model_input = torch.stack([torch.tensor(example[\"model_input\"]) for example in examples])\n        
original_sizes = [example[\"original_sizes\"] for example in examples]\n        crop_top_lefts = [example[\"crop_top_lefts\"] for example in examples]\n        prompt_embeds = torch.stack([torch.tensor(example[\"prompt_embeds\"]) for example in examples])\n        pooled_prompt_embeds = torch.stack([torch.tensor(example[\"pooled_prompt_embeds\"]) for example in examples])\n\n        return {\n            \"model_input\": model_input,\n            \"prompt_embeds\": prompt_embeds,\n            \"pooled_prompt_embeds\": pooled_prompt_embeds,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        precomputed_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / 
args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune-sdxl\", config=vars(args))\n\n    if args.use_lpl:\n        lpl_fn = LatentPerceptualLoss(\n            vae=vae,\n            loss_type=args.lpl_loss_type,\n            grad_ckpt=args.gradient_checkpointing,\n            pow_law=args.lpl_pow_law,\n            norm_type=args.lpl_norm_type,\n            num_mid_blocks=args.lpl_num_blocks,\n            feature_type=\"feature\",\n            remove_outliers=args.lpl_remove_outliers,\n        )\n        lpl_fn.to(accelerator.device)\n    else:\n        lpl_fn = None\n\n    # Function for unwrapping if torch.compile() was used in accelerate.\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    if torch.backends.mps.is_available() or \"playground\" in args.pretrained_model_name_or_path:\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(precomputed_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    # Get scheduler alphas and sigmas for LPL z0_hat calculation\n    alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device)\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n    
        with accelerator.accumulate(unet):\n                # Sample noise that we'll add to the latents\n                model_input = batch[\"model_input\"].to(accelerator.device)\n                noise = torch.randn_like(model_input)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device\n                    )\n\n                bsz = model_input.shape[0]\n                if args.timestep_bias_strategy == \"none\":\n                    # Sample a random timestep for each image without bias.\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                    )\n                else:\n                    # Sample a random timestep for each image, potentially biased by the timestep weights.\n                    # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.\n                    weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(\n                        model_input.device\n                    )\n                    timesteps = torch.multinomial(weights, bsz, replacement=True).long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)\n\n                # time ids\n                def compute_time_ids(original_size, crops_coords_top_left):\n                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n                    target_size = (args.resolution, args.resolution)\n                    add_time_ids = 
list(original_size + crops_coords_top_left + target_size)\n                    add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)\n                    return add_time_ids\n\n                add_time_ids = torch.cat(\n                    [compute_time_ids(s, c) for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])]\n                )\n\n                # Predict the noise residual\n                unet_added_conditions = {\"time_ids\": add_time_ids}\n                prompt_embeds = batch[\"prompt_embeds\"].to(accelerator.device, dtype=weight_dtype)\n                pooled_prompt_embeds = batch[\"pooled_prompt_embeds\"].to(accelerator.device)\n                unet_added_conditions.update({\"text_embeds\": pooled_prompt_embeds})\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    prompt_embeds,\n                    added_cond_kwargs=unet_added_conditions,\n                    return_dict=False,\n                )[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                elif noise_scheduler.config.prediction_type == \"sample\":\n                    # We set the target to latents here, but the model_pred will return the noise sample prediction.\n                    target = model_input\n                    # We will have to subtract the noise residual from the prediction to get the target sample.\n            
        model_pred = model_pred - noise\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                lpl_loss_value = torch.tensor(0.0, device=accelerator.device)\n                if args.use_lpl and lpl_fn is not None and global_step >= args.lpl_start:\n                    # Apply LPL only below the timestep threshold\n                    lpl_mask = timesteps < args.lpl_t_threshold\n                    if lpl_mask.any():\n                        # Select samples that meet the threshold\n                        masked_indices = torch.where(lpl_mask)[0]\n                        z0_masked = model_input[masked_indices]\n                        zt_masked = 
noisy_model_input[masked_indices]\n                        t_masked = timesteps[masked_indices]\n                        model_pred_masked = model_pred[masked_indices]\n\n                        # Calculate z0_hat for the masked samples\n                        alpha_t = alphas_cumprod[t_masked].sqrt().to(torch.float32)\n                        sigma_t = (1 - alphas_cumprod[t_masked]).sqrt().to(torch.float32)\n                        alpha_t = alpha_t.view(-1, 1, 1, 1)\n                        sigma_t = sigma_t.view(-1, 1, 1, 1)\n\n                        if noise_scheduler.config.prediction_type == \"epsilon\":\n                            z0_hat_masked = (zt_masked.float() - sigma_t * model_pred_masked.float()) / alpha_t\n                        elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                            z0_hat_masked = alpha_t * zt_masked.float() - sigma_t * model_pred_masked.float()\n                        else:  # sample prediction\n                            z0_hat_masked = model_pred_masked.float()\n\n                        with accelerator.autocast():\n                            lpl_loss_value = lpl_fn.get_loss(z0_hat_masked, z0_masked)\n\n                            if args.lpl_scale:\n                                if args.snr_gamma is not None:\n                                    # Use SNR-based weights if available\n                                    snr = compute_snr(noise_scheduler, t_masked)\n                                    snr_weights = torch.stack(\n                                        [snr, args.snr_gamma * torch.ones_like(t_masked)], dim=1\n                                    ).min(dim=1)[0]\n                                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                                        snr_weights = snr_weights / snr\n                                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                                        
snr_weights = snr_weights / (snr + 1)\n                                    lpl_loss_value = (lpl_loss_value * snr_weights).mean()\n                                else:\n                                    # If no SNR weighting, just use mean\n                                    lpl_loss_value = lpl_loss_value.mean()\n                            else:\n                                lpl_loss_value = lpl_loss_value.mean()\n\n                # Combine losses\n                total_loss = loss + args.lpl_weight * lpl_loss_value\n\n                # Gather the losses across all processes for logging\n                avg_loss = accelerator.gather(total_loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(total_loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = unet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n\n                # Enhanced logging for LPL metrics\n                log_data = {\n                    \"train_loss\": train_loss,\n                    \"diffusion_loss\": loss.item(),\n                    \"learning_rate\": lr_scheduler.get_last_lr()[0],\n                }\n\n                if args.use_lpl and lpl_fn is not None and global_step >= args.lpl_start:\n                    if lpl_mask.any():\n                        # LPL application statistics\n                        log_data.update(\n                         
   {\n                                \"lpl/loss\": lpl_loss_value.item(),\n                                \"lpl/num_samples\": lpl_mask.sum().item(),\n                                \"lpl/application_ratio\": lpl_mask.float().mean().item(),\n                                \"lpl/weight\": args.lpl_weight,\n                                \"lpl/weighted_loss\": (args.lpl_weight * lpl_loss_value).item(),\n                            }\n                        )\n\n                        # SNR statistics for LPL-applied samples. Recompute from the masked timesteps:\n                        # `snr` may already hold masked-only values, so indexing it with batch indices is unsafe.\n                        if args.snr_gamma is not None:\n                            snr_values = compute_snr(noise_scheduler, timesteps[masked_indices])\n                            log_data.update(\n                                {\n                                    \"lpl/snr_mean\": snr_values.mean().item(),\n                                    \"lpl/snr_std\": snr_values.std().item(),\n                                    \"lpl/snr_min\": snr_values.min().item(),\n                                    \"lpl/snr_max\": snr_values.max().item(),\n                                }\n                            )\n\n                        # Feature statistics if available\n                        if hasattr(lpl_fn, \"last_feature_stats\"):\n                            for layer_idx, stats in enumerate(lpl_fn.last_feature_stats):\n                                log_data.update(\n                                    {\n                                        f\"lpl/features/layer_{layer_idx}/mean\": stats[\"mean\"],\n                                        f\"lpl/features/layer_{layer_idx}/std\": stats[\"std\"],\n                                        f\"lpl/features/layer_{layer_idx}/outlier_ratio\": stats.get(\n                                            \"outlier_ratio\", 0.0\n                                        ),\n                                    }\n                                )\n\n                        # Memory usage if available\n                        if 
torch.cuda.is_available():\n                            log_data.update(\n                                {\n                                    \"lpl/memory/allocated\": torch.cuda.memory_allocated() / 1024**2,  # MB\n                                    \"lpl/memory/reserved\": torch.cuda.memory_reserved() / 1024**2,  # MB\n                                }\n                            )\n\n                # Log to accelerator\n                accelerator.log(log_data, step=global_step)\n\n                # Update progress bar with more metrics\n                progress_bar_logs = {\n                    \"loss\": loss.detach().item(),\n                    \"lr\": lr_scheduler.get_last_lr()[0],\n                }\n                if args.use_lpl and lpl_loss_value.item() > 0:\n                    progress_bar_logs.update(\n                        {\n                            \"lpl\": lpl_loss_value.item(),\n                            \"lpl_ratio\": lpl_mask.float().mean().item() if lpl_mask.any() else 0.0,\n                        }\n                    )\n                progress_bar.set_postfix(**progress_bar_logs)\n\n                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at 
_most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                logger.info(\n                    f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n\n                # create pipeline\n                vae = AutoencoderKL.from_pretrained(\n                    vae_path,\n                    subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n                    revision=args.revision,\n                    variant=args.variant,\n                )\n                pipeline = StableDiffusionXLPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    unet=accelerator.unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                if args.prediction_type is not None:\n                    scheduler_args = {\"prediction_type\": args.prediction_type}\n                    pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n                pipeline = pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = (\n                    torch.Generator(device=accelerator.device).manual_seed(args.seed)\n                    if args.seed is not None\n                    else None\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n\n                with autocast_ctx:\n                    images = [\n                        pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]\n                        for _ 
in range(args.num_validation_images)\n                    ]\n\n                for tracker in accelerator.trackers:\n                    if tracker.name == \"tensorboard\":\n                        np_images = np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                \"validation\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                                ]\n                            }\n                        )\n\n                del pipeline\n                if is_torch_npu_available():\n                    torch_npu.npu.empty_cache()\n                elif torch.cuda.is_available():\n                    torch.cuda.empty_cache()\n\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_unet.restore(unet.parameters())\n\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        # Serialize pipeline.\n        vae = AutoencoderKL.from_pretrained(\n            vae_path,\n            subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipeline = StableDiffusionXLPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=unet,\n            vae=vae,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        if args.prediction_type is 
not None:\n            scheduler_args = {\"prediction_type\": args.prediction_type}\n            pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n        pipeline.save_pretrained(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline = pipeline.to(accelerator.device)\n            generator = (\n                torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n            )\n\n            with autocast_ctx:\n                images = [\n                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n                    for _ in range(args.num_validation_images)\n                ]\n\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    tracker.log(\n                        {\n                            \"test\": [\n                                wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n                            ]\n                        }\n                    )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id=repo_id,\n                images=images,\n                validation_prompt=args.validation_prompt,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n                vae_path=args.pretrained_vae_model_name_or_path,\n            )\n            upload_folder(\n                repo_id=repo_id,\n  
              folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/multi_subject_dreambooth/README.md",
    "content": "# Multi Subject DreamBooth training\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.\nThis `train_multi_subject_dreambooth.py` script shows how to implement the training procedure for one or more subjects and adapt it for stable diffusion. Note that this code is based off of the `examples/dreambooth/train_dreambooth.py` script as of 01/06/2022.\n\nThis script was added by @kopsahlong, and is not actively maintained. However, if you come across anything that could use fixing, feel free to open an issue and tag @kopsahlong.\n\n## Running locally with PyTorch\n### Installing the dependencies\n\nBefore running the script, make sure to install the library's training dependencies:\n\nTo start, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd into the folder `diffusers/examples/research_projects/multi_subject_dreambooth` and run the following:\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. 
a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\n### Multi Subject Training Example\nTo have the model learn multiple concepts at once, pass the data directories and prompts of all subjects to `instance_data_dir` and `instance_prompt` (and to `class_data_dir` and `class_prompt` if `--with_prior_preservation` is specified), each as a single comma-separated string.\n\nThe example below trains on 2 subjects, one dog and one person:\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\n# Subject 1\nexport INSTANCE_DIR_1=\"path-to-instance-images-concept-1\"\nexport INSTANCE_PROMPT_1=\"a photo of a sks dog\"\nexport CLASS_DIR_1=\"path-to-class-images-dog\"\nexport CLASS_PROMPT_1=\"a photo of a dog\"\n\n# Subject 2\nexport INSTANCE_DIR_2=\"path-to-instance-images-concept-2\"\nexport INSTANCE_PROMPT_2=\"a photo of a t@y person\"\nexport CLASS_DIR_2=\"path-to-class-images-person\"\nexport CLASS_PROMPT_2=\"a photo of a person\"\n\naccelerate launch train_multi_subject_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=\"$INSTANCE_DIR_1,$INSTANCE_DIR_2\" \\\n  --output_dir=$OUTPUT_DIR \\\n  --train_text_encoder \\\n  --instance_prompt=\"$INSTANCE_PROMPT_1,$INSTANCE_PROMPT_2\" \\\n  --with_prior_preservation \\\n  --prior_loss_weight=1.0 \\\n  --class_data_dir=\"$CLASS_DIR_1,$CLASS_DIR_2\" \\\n  --class_prompt=\"$CLASS_PROMPT_1,$CLASS_PROMPT_2\" \\\n  --num_class_images=50 \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=1500\n```\n\nThis example shows training for 2 subjects, but please note that the model can be trained on any number of new concepts. 
To do so, keep adding the directories and prompts of each new subject to the corresponding comma-separated strings.\n\nNote also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed to be good when more intuitive words were used.\n\n**Important**: The script also has parameters for validating training progress by generating images at specified steps.\nIn addition, because commas are very common in regular prompt text, a comma-separated list is a poor way to pass\nseveral prompts. The script therefore provides the `concepts_list` parameter, which takes the path to a JSON file\ndefining the configuration for each subject that you want to train.\n\nAn example of how to generate the file:\n```python\nimport json\n\n# here we are using parameters for prior-preservation and validation as well.\nconcepts_list = [\n    {\n        \"instance_prompt\":      \"drawing of a t@y meme\",\n        \"class_prompt\":         \"drawing of a meme\",\n        \"instance_data_dir\":    \"/some_folder/meme_toy\",\n        \"class_data_dir\":       \"/data/meme\",\n        \"validation_prompt\":    \"drawing of a t@y meme about football in Uruguay\",\n        \"validation_negative_prompt\": \"black and white\"\n    },\n    {\n        \"instance_prompt\":      \"drawing of a sks sir\",\n        \"class_prompt\":         \"drawing of a sir\",\n        \"instance_data_dir\":    \"/some_other_folder/sir_sks\",\n        \"class_data_dir\":       \"/data/sir\",\n        \"validation_prompt\":    \"drawing of a sks sir with the Uruguayan sun in his chest\",\n        \"validation_negative_prompt\": \"an old man\",\n        
\"validation_guidance_scale\": 20,\n        \"validation_number_images\": 3,\n        \"validation_inference_steps\": 10\n    }\n]\n\nwith open(\"concepts_list.json\", \"w\") as f:\n    json.dump(concepts_list, f, indent=4)\n```\nAnd then just point to the file when executing the script:\n\n```bash\n# exports...\n# Append the remaining training parameters after --concepts_list.\naccelerate launch train_multi_subject_dreambooth.py \\\n  --concepts_list=\"concepts_list.json\"\n```\n\nRun the script with `--help` to get a better sense of each parameter.\n\n### Inference\n\nOnce you have trained a model using the above command, inference can be done with the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `sks` in the above example) in your prompt.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_id = \"path-to-your-trained-model\"\npipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A photo of a t@y person petting an sks dog\"\nimage = pipe(prompt, num_inference_steps=200, guidance_scale=7.5).images[0]\n\nimage.save(\"person-petting-dog.png\")\n```\n\n### Inference from a training checkpoint\n\nYou can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.\n\n## Additional Dreambooth documentation\nBecause the `train_multi_subject_dreambooth.py` script here was forked from an original version of `train_dreambooth.py` in the `examples/dreambooth` folder, I've included the original applicable training documentation for single subject examples below.\n\nThis should explain how to play with training variables such as prior preservation, fine tuning the text encoder, etc., which are still applicable to our multi subject training code. 
Note also that the examples below, which are single subject examples, also work with `train_multi_subject_dreambooth.py`, as this script supports 1 (or more) subjects.\n\n### Single subject dog toy example\n\nLet's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.\n\nAnd launch the training using\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=400\n```\n\n### Training with prior-preservation loss\n\nPrior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.\nAccording to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. 
You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` images are present in `class_data_dir` at training time.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n\n### Training on a 16GB GPU:\n\nWith the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train DreamBooth on a 16GB GPU.\n\nTo install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=2 --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --learning_rate=5e-6 \\\n  
--lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n### Training on an 8 GB GPU:\n\nBy using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some\ntensors from VRAM to either CPU or NVMe, which allows training with less VRAM.\n\nDeepSpeed needs to be enabled with `accelerate config`. During configuration,\nanswer yes to \"Do you want to use DeepSpeed?\". With DeepSpeed stage 2, fp16\nmixed precision, and offloading of both parameters and optimizer state to CPU, it's\npossible to train with under 8 GB of VRAM, at the cost of requiring significantly\nmore RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.\n\nChanging the default Adam optimizer to DeepSpeed's special version of Adam,\n`deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling\nit requires a CUDA toolchain with the same version as PyTorch. The 8-bit optimizer\ndoes not seem to be compatible with DeepSpeed at the moment.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch --mixed_precision=\"fp16\" train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --sample_batch_size=1 \\\n  --gradient_accumulation_steps=1 --gradient_checkpointing \\\n  --learning_rate=5e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n### Fine-tune text encoder with the UNet.\n\nThe script also allows you to 
fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results, especially on faces.\nPass the `--train_text_encoder` argument to the script to enable training `text_encoder`.\n\n___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. It needs at least 24GB VRAM.___\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport INSTANCE_DIR=\"path-to-instance-images\"\nexport CLASS_DIR=\"path-to-class-images\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\naccelerate launch train_dreambooth.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --train_text_encoder \\\n  --instance_data_dir=$INSTANCE_DIR \\\n  --class_data_dir=$CLASS_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --with_prior_preservation --prior_loss_weight=1.0 \\\n  --instance_prompt=\"a photo of sks dog\" \\\n  --class_prompt=\"a photo of dog\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --use_8bit_adam \\\n  --gradient_checkpointing \\\n  --learning_rate=2e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --num_class_images=200 \\\n  --max_train_steps=800\n```\n\n### Using DreamBooth for other pipelines than Stable Diffusion\n\nAltDiffusion also supports DreamBooth. The command is basically the same as above: simply change `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion) by replacing `MODEL_NAME` like this:\n\n```\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\" --> export MODEL_NAME=\"BAAI/AltDiffusion-m9\"\nor\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\" --> export MODEL_NAME=\"BAAI/AltDiffusion\"\n```\n\n### Training with xformers:\nYou can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and 
passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.\n\nYou can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint)."
  },
  {
    "path": "examples/research_projects/multi_subject_dreambooth/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2"
  },
  {
    "path": "examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py",
    "content": "import argparse\nimport itertools\nimport json\nimport logging\nimport math\nimport uuid\nimport warnings\nfrom os import environ, listdir, makedirs\nfrom os.path import basename, join\nfrom pathlib import Path\nfrom typing import List\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom PIL import Image\nfrom torch import dtype\nfrom torch.nn import Module\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.13.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef log_validation_images_to_tracker(\n    images: List[np.array], label: str, validation_prompt: str, accelerator: Accelerator, epoch: int\n):\n    logger.info(f\"Logging images to tracker for validation prompt: {validation_prompt}.\")\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{label}_{epoch}_{i}: {validation_prompt}\")\n                        for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n\n# TODO: Add `prompt_embeds` and `negative_prompt_embeds` parameters to the function when `pre_compute_text_embeddings`\n#  argument is implemented.\ndef generate_validation_images(\n    text_encoder: Module,\n    tokenizer: Module,\n    unet: Module,\n    vae: Module,\n    arguments: argparse.Namespace,\n    accelerator: Accelerator,\n    weight_dtype: dtype,\n):\n    logger.info(\"Running validation images.\")\n\n    pipeline_args = {}\n\n    if text_encoder is not None:\n        pipeline_args[\"text_encoder\"] = accelerator.unwrap_model(text_encoder)\n\n    if vae is not None:\n        pipeline_args[\"vae\"] = vae\n\n    # create pipeline (note: unet and vae are loaded again in float32)\n    pipeline = DiffusionPipeline.from_pretrained(\n        arguments.pretrained_model_name_or_path,\n        tokenizer=tokenizer,\n        unet=accelerator.unwrap_model(unet),\n        revision=arguments.revision,\n        torch_dtype=weight_dtype,\n        **pipeline_args,\n    )\n\n    # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the\n    # scheduler to ignore it\n    scheduler_args = {}\n\n    if \"variance_type\" in pipeline.scheduler.config:\n        variance_type = pipeline.scheduler.config.variance_type\n\n        if variance_type in [\"learned\", \"learned_range\"]:\n            variance_type = \"fixed_small\"\n\n        scheduler_args[\"variance_type\"] = variance_type\n\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    generator = (\n        None if arguments.seed is None else torch.Generator(device=accelerator.device).manual_seed(arguments.seed)\n    )\n\n    images_sets = []\n    for vp, nvi, vnp, vis, vgs in zip(\n        arguments.validation_prompt,\n        arguments.validation_number_images,\n        arguments.validation_negative_prompt,\n        arguments.validation_inference_steps,\n        arguments.validation_guidance_scale,\n    ):\n        images = []\n        if vp is not None:\n            logger.info(\n                f\"Generating {nvi} images with prompt: '{vp}', negative prompt: '{vnp}', inference steps: {vis}, \"\n                f\"guidance scale: {vgs}.\"\n            )\n\n            pipeline_args = {\"prompt\": vp, \"negative_prompt\": vnp, \"num_inference_steps\": vis, \"guidance_scale\": vgs}\n\n            # run inference\n            # TODO: it would be good to measure whether it's faster to run inference on all images at once, one at a\n            #  time or in small batches\n            for _ in range(nvi):\n                with torch.autocast(\"cuda\"):\n                    image = pipeline(**pipeline_args, num_images_per_prompt=1, generator=generator).images[0]\n                images.append(image)\n\n        images_sets.append(images)\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    
return images_sets\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    
parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\"--train_text_encoder\", action=\"store_true\", help=\"Whether to train the text encoder\")\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=None,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt(s) `validation_prompt` \"\n            \"multiple times (`validation_number_images`) and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning. You can use commas to \"\n        \"define multiple negative prompts. This parameter can be defined also within the file given by \"\n        \"`concepts_list` parameter in the respective subject.\",\n    )\n    parser.add_argument(\n        \"--validation_number_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with the validation parameters given. 
This \"\n        \"can be defined within the file given by `concepts_list` parameter in the respective subject.\",\n    )\n    parser.add_argument(\n        \"--validation_negative_prompt\",\n        type=str,\n        default=None,\n        help=\"A negative prompt that is used during validation to verify that the model is learning. You can use commas\"\n        \" to define multiple negative prompts, each one corresponding to a validation prompt. This parameter can \"\n        \"be defined also within the file given by `concepts_list` parameter in the respective subject.\",\n    )\n    parser.add_argument(\n        \"--validation_inference_steps\",\n        type=int,\n        default=25,\n        help=\"Number of inference steps (denoising steps) to run during validation. This can be defined within the \"\n        \"file given by `concepts_list` parameter in the respective subject.\",\n    )\n    parser.add_argument(\n        \"--validation_guidance_scale\",\n        type=float,\n        default=7.5,\n        help=\"To control how much the image generation process follows the text prompt. This can be defined within the \"\n        \"file given by `concepts_list` parameter in the respective subject.\",\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--concepts_list\",\n        type=str,\n        default=None,\n        help=\"Path to json file containing a list of multiple concepts, will overwrite parameters like instance_prompt,\"\n        \" class_prompt, etc.\",\n    )\n\n    if input_args:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if not args.concepts_list and (not args.instance_data_dir or not args.instance_prompt):\n        raise ValueError(\n            \"You must specify either instance parameters (data directory, prompt, etc.) 
or use \"\n            \"the `concepts_list` parameter and specify them within the file.\"\n        )\n\n    if args.concepts_list:\n        if args.instance_prompt:\n            raise ValueError(\"If you are using `concepts_list` parameter, define the instance prompt within the file.\")\n        if args.instance_data_dir:\n            raise ValueError(\n                \"If you are using `concepts_list` parameter, define the instance data directory within the file.\"\n            )\n        if args.validation_steps and (args.validation_prompt or args.validation_negative_prompt):\n            raise ValueError(\n                \"If you are using `concepts_list` parameter, define validation parameters for \"\n                \"each subject within the file:\\n - `validation_prompt`.\"\n                \"\\n - `validation_negative_prompt`.\\n - `validation_guidance_scale`.\"\n                \"\\n - `validation_number_images`.\"\n                \"\\n - `validation_inference_steps`.\\nThe `validation_steps` parameter is the only one \"\n                \"that needs to be defined outside the file.\"\n            )\n\n    env_local_rank = int(environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if not args.concepts_list:\n            if not args.class_data_dir:\n                raise ValueError(\"You must specify a data directory for class images.\")\n            if not args.class_prompt:\n                raise ValueError(\"You must specify a prompt for class images.\")\n        else:\n            if args.class_data_dir:\n                raise ValueError(\n                    \"If you are using `concepts_list` parameter, define the class data directory within the file.\"\n                )\n            if args.class_prompt:\n                raise ValueError(\n                    \"If you are using 
`concepts_list` parameter, define the class prompt within the file.\"\n                )\n    else:\n        # logger is not available yet\n        # Warn only when a class parameter was passed without `with_prior_preservation`.\n        if args.class_data_dir:\n            warnings.warn(\n                \"Ignoring `class_data_dir` parameter, you need to use it together with `with_prior_preservation`.\"\n            )\n        if args.class_prompt:\n            warnings.warn(\n                \"Ignoring `class_prompt` parameter, you need to use it together with `with_prior_preservation`.\"\n            )\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and then tokenizes prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        size=512,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n\n        self.instance_data_root = []\n        self.instance_images_path = []\n        self.num_instance_images = []\n        self.instance_prompt = []\n        self.class_data_root = [] if class_data_root is not None else None\n        self.class_images_path = []\n        self.num_class_images = []\n        self.class_prompt = []\n        self._length = 0\n\n        for i in range(len(instance_data_root)):\n            self.instance_data_root.append(Path(instance_data_root[i]))\n            if not self.instance_data_root[i].exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            self.instance_images_path.append(list(Path(instance_data_root[i]).iterdir()))\n            self.num_instance_images.append(len(self.instance_images_path[i]))\n            self.instance_prompt.append(instance_prompt[i])\n            self._length += 
self.num_instance_images[i]\n\n            if class_data_root is not None:\n                self.class_data_root.append(Path(class_data_root[i]))\n                self.class_data_root[i].mkdir(parents=True, exist_ok=True)\n                self.class_images_path.append(list(self.class_data_root[i].iterdir()))\n                # Count the images of subject i, not the number of subjects processed so far.\n                self.num_class_images.append(len(self.class_images_path[i]))\n                if self.num_class_images[i] > self.num_instance_images[i]:\n                    self._length -= self.num_instance_images[i]\n                    self._length += self.num_class_images[i]\n                self.class_prompt.append(class_prompt[i])\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        for i in range(len(self.instance_images_path)):\n            instance_image = Image.open(self.instance_images_path[i][index % self.num_instance_images[i]])\n            if not instance_image.mode == \"RGB\":\n                instance_image = instance_image.convert(\"RGB\")\n            example[f\"instance_images_{i}\"] = self.image_transforms(instance_image)\n            example[f\"instance_prompt_ids_{i}\"] = self.tokenizer(\n                self.instance_prompt[i],\n                truncation=True,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                return_tensors=\"pt\",\n            ).input_ids\n\n        if self.class_data_root:\n            for i in range(len(self.class_data_root)):\n                class_image = Image.open(self.class_images_path[i][index % 
self.num_class_images[i]])\n                if not class_image.mode == \"RGB\":\n                    class_image = class_image.convert(\"RGB\")\n                example[f\"class_images_{i}\"] = self.image_transforms(class_image)\n                example[f\"class_prompt_ids_{i}\"] = self.tokenizer(\n                    self.class_prompt[i],\n                    truncation=True,\n                    padding=\"max_length\",\n                    max_length=self.tokenizer.model_max_length,\n                    return_tensors=\"pt\",\n                ).input_ids\n\n        return example\n\n\ndef collate_fn(num_instances, examples, with_prior_preservation=False):\n    input_ids = []\n    pixel_values = []\n\n    for i in range(num_instances):\n        input_ids += [example[f\"instance_prompt_ids_{i}\"] for example in examples]\n        pixel_values += [example[f\"instance_images_{i}\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        for i in range(num_instances):\n            input_ids += [example[f\"class_prompt_ids_{i}\"] for example in examples]\n            pixel_values += [example[f\"class_images_{i}\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.cat(input_ids, dim=0)\n\n    batch = {\n        \"input_ids\": input_ids,\n        \"pixel_values\": pixel_values,\n    }\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n    
    example[\"index\"] = index\n        return example\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. \"\n            \"Please set gradient_accumulation_steps to 1. 
This feature will be supported in the future.\"\n        )\n\n    instance_data_dir = []\n    instance_prompt = []\n    class_data_dir = [] if args.with_prior_preservation else None\n    class_prompt = [] if args.with_prior_preservation else None\n    if args.concepts_list:\n        with open(args.concepts_list, \"r\") as f:\n            concepts_list = json.load(f)\n\n        if args.validation_steps:\n            args.validation_prompt = []\n            args.validation_number_images = []\n            args.validation_negative_prompt = []\n            args.validation_inference_steps = []\n            args.validation_guidance_scale = []\n\n        for concept in concepts_list:\n            instance_data_dir.append(concept[\"instance_data_dir\"])\n            instance_prompt.append(concept[\"instance_prompt\"])\n\n            if args.with_prior_preservation:\n                try:\n                    class_data_dir.append(concept[\"class_data_dir\"])\n                    class_prompt.append(concept[\"class_prompt\"])\n                except KeyError:\n                    raise KeyError(\n                        \"`class_data_dir` or `class_prompt` not found in concepts_list while using \"\n                        \"`with_prior_preservation`.\"\n                    )\n            else:\n                if \"class_data_dir\" in concept:\n                    warnings.warn(\n                        \"Ignoring `class_data_dir` key, to use it you need to enable `with_prior_preservation`.\"\n                    )\n                if \"class_prompt\" in concept:\n                    warnings.warn(\n                        \"Ignoring `class_prompt` key, to use it you need to enable `with_prior_preservation`.\"\n                    )\n\n            if args.validation_steps:\n                args.validation_prompt.append(concept.get(\"validation_prompt\", None))\n                args.validation_number_images.append(concept.get(\"validation_number_images\", 4))\n                
args.validation_negative_prompt.append(concept.get(\"validation_negative_prompt\", None))\n                args.validation_inference_steps.append(concept.get(\"validation_inference_steps\", 25))\n                args.validation_guidance_scale.append(concept.get(\"validation_guidance_scale\", 7.5))\n    else:\n        # Parse instance and class inputs, and double check that lengths match\n        instance_data_dir = args.instance_data_dir.split(\",\")\n        instance_prompt = args.instance_prompt.split(\",\")\n        assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (\n            \"Instance data dir and prompt inputs are not of the same length.\"\n        )\n\n        if args.with_prior_preservation:\n            class_data_dir = args.class_data_dir.split(\",\")\n            class_prompt = args.class_prompt.split(\",\")\n            assert all(\n                x == len(instance_data_dir)\n                for x in [len(instance_data_dir), len(instance_prompt), len(class_data_dir), len(class_prompt)]\n            ), \"Instance & class data dir or prompt inputs are not of the same length.\"\n\n        if args.validation_steps:\n            validation_prompts = args.validation_prompt.split(\",\")\n            num_of_validation_prompts = len(validation_prompts)\n            args.validation_prompt = validation_prompts\n            args.validation_number_images = [args.validation_number_images] * num_of_validation_prompts\n\n            negative_validation_prompts = [None] * num_of_validation_prompts\n            if args.validation_negative_prompt:\n                negative_validation_prompts = args.validation_negative_prompt.split(\",\")\n                while len(negative_validation_prompts) < num_of_validation_prompts:\n                    negative_validation_prompts.append(None)\n            args.validation_negative_prompt = negative_validation_prompts\n\n            assert num_of_validation_prompts == 
len(negative_validation_prompts), (\n                \"The length of negative prompts for validation is greater than the number of validation prompts.\"\n            )\n            args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts\n            args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        for i in range(len(class_data_dir)):\n            class_images_dir = Path(class_data_dir[i])\n            if not class_images_dir.exists():\n                class_images_dir.mkdir(parents=True)\n            cur_class_images = len(list(class_images_dir.iterdir()))\n\n            if cur_class_images < args.num_class_images:\n                torch_dtype = torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n                if args.prior_generation_precision == \"fp32\":\n                    torch_dtype = torch.float32\n                elif args.prior_generation_precision == \"fp16\":\n                    torch_dtype = torch.float16\n                elif 
args.prior_generation_precision == \"bf16\":\n                    torch_dtype = torch.bfloat16\n                pipeline = DiffusionPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    torch_dtype=torch_dtype,\n                    safety_checker=None,\n                    revision=args.revision,\n                )\n                pipeline.set_progress_bar_config(disable=True)\n\n                num_new_images = args.num_class_images - cur_class_images\n                logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n                sample_dataset = PromptDataset(class_prompt[i], num_new_images)\n                sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n                sample_dataloader = accelerator.prepare(sample_dataloader)\n                pipeline.to(accelerator.device)\n\n                for example in tqdm(\n                    sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n                ):\n                    images = pipeline(example[\"prompt\"]).images\n\n                    for ii, image in enumerate(images):\n                        hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                        image_filename = (\n                            class_images_dir / f\"{example['index'][ii] + cur_class_images}-{hash_image}.jpg\"\n                        )\n                        image.save(image_filename)\n\n                # Clean up the memory deleting one-time-use variables.\n                del pipeline\n                del sample_dataloader\n                del sample_dataset\n                if torch.cuda.is_available():\n                    torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            makedirs(args.output_dir, exist_ok=True)\n\n  
      if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    tokenizer = None\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n\n    vae.requires_grad_(False)\n    if not args.train_text_encoder:\n        text_encoder.requires_grad_(False)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = (\n        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()\n    )\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=instance_data_dir,\n        instance_prompt=instance_prompt,\n        class_data_root=class_data_dir,\n        class_prompt=class_prompt,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        
shuffle=True,\n        collate_fn=lambda examples: collate_fn(len(instance_data_dir), examples, args.with_prior_preservation),\n        num_workers=1,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and text_encoder to device and cast to weight_dtype\n    vae.to(accelerator.device, dtype=weight_dtype)\n    if not args.train_text_encoder:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training 
dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"dreambooth\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if 
args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                time_steps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device\n                )\n                time_steps = time_steps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, time_steps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Predict the noise residual\n                model_pred = unet(noisy_latents, time_steps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, time_steps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk 
the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute instance loss\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                    # Compute prior loss\n                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction=\"mean\")\n\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n                else:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet.parameters(), text_encoder.parameters())\n                        if args.train_text_encoder\n                        else unet.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        save_path = join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if (\n                        
args.validation_steps\n                        and any(args.validation_prompt)\n                        and global_step % args.validation_steps == 0\n                    ):\n                        images_set = generate_validation_images(\n                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype\n                        )\n                        for images, validation_prompt in zip(images_set, args.validation_prompt):\n                            if len(images) > 0:\n                                label = str(uuid.uuid1())[:8]  # generate an id for different set of images\n                                log_validation_images_to_tracker(\n                                    images, label, validation_prompt, accelerator, global_step\n                                )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=accelerator.unwrap_model(unet),\n            text_encoder=accelerator.unwrap_model(text_encoder),\n            revision=args.revision,\n        )\n        pipeline.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/multi_subject_dreambooth_inpainting/README.md",
    "content": "# Multi Subject Dreambooth for Inpainting Models\n\nPlease note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.\n\n[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like Stable Diffusion given just a few (3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requires prompt-image-mask pairs. The UNet of inpainting models has 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).\n\n**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how to make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))\n\n**The second part**, the `train_multi_subject_dreambooth_inpainting.py` training script, demonstrates how to implement a training procedure for one or more subjects and adapt it for Stable Diffusion for inpainting.\n\n## 1. Data Collection: Make Prompt-Image-Mask Pairs\n\nEarlier training scripts have provided approaches like random masking for the training images. This project provides a notebook for more precise mask setting.\n\nThe notebook can be found here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JNEASI_B7pLW1srxhgln6nM0HoGAQT32?usp=sharing)\n\nThe `multi_inpaint_dataset.ipynb` notebook takes training & validation images, on which the user draws masks and provides prompts to make prompt-image-mask pairs. This ensures that during training, the loss is computed on the area masking the object of interest, rather than on random areas. 
Moreover, the `multi_inpaint_dataset.ipynb` notebook allows you to build a validation dataset with corresponding masks for monitoring the training process. Example below:\n\n![train_val_pairs](https://drive.google.com/uc?id=1PzwH8E3icl_ubVmA19G0HZGLImFX3x5I)\n\nYou can build a separate dataset for each subject and upload them to the 🤗 hub. Later, when launching the training script, you can indicate the paths of the datasets on which you would like to fine-tune Stable Diffusion for inpainting.\n\n## 2. Train Multi Subject Dreambooth for Inpainting\n\n### 2.1. Setting The Training Configuration\n\nBefore launching the training script, make sure to select the target inpainting model, the output directory and the 🤗 datasets.\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-inpainting\"\nexport OUTPUT_DIR=\"path-to-save-model\"\n\nexport DATASET_1=\"gzguevara/mr_potato_head_masked\"\nexport DATASET_2=\"gzguevara/cat_toy_masked\"\n# ... Further paths to 🤗 datasets\n```\n\n### 2.2. Launching The Training Script\n\n```bash\naccelerate launch train_multi_subject_dreambooth_inpainting.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir $DATASET_1 $DATASET_2 \\\n  --output_dir=$OUTPUT_DIR \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=2 \\\n  --learning_rate=3e-6 \\\n  --max_train_steps=500 \\\n  --report_to_wandb\n```\n\n### 2.3. Fine-tune text encoder with the UNet.\n\nThe script also allows you to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results, especially on faces.\nPass the `--train_text_encoder` argument to the script to enable training `text_encoder`.\n\n___Note: Training the text encoder requires more memory; with this option the training won't fit on a 16GB GPU. 
It needs at least 24GB VRAM.___\n\n```bash\naccelerate launch train_multi_subject_dreambooth_inpainting.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir $DATASET_1 $DATASET_2 \\\n  --output_dir=$OUTPUT_DIR \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=2 \\\n  --learning_rate=2e-6 \\\n  --max_train_steps=500 \\\n  --report_to_wandb \\\n  --train_text_encoder\n```\n\n## 3. Results\n\nA [![Weights & Biases](https://img.shields.io/badge/Weights%20&%20Biases-Report-blue)](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided showing the training progress every 50 steps. Note that the reported Weights & Biases run was performed on an A100 GPU with the following settings:\n\n```bash\naccelerate launch train_multi_subject_dreambooth_inpainting.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME  \\\n  --instance_data_dir $DATASET_1 $DATASET_2 \\\n  --output_dir=$OUTPUT_DIR \\\n  --resolution=512 \\\n  --train_batch_size=10 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-6 \\\n  --max_train_steps=500 \\\n  --report_to_wandb \\\n  --train_text_encoder\n```\nHere you can see the target objects on my desk and next to my plant:\n\n![Results](https://drive.google.com/uc?id=1kQisOiiF5cj4rOYjdq8SCZenNsUP2aK0)\n"
  },
  {
    "path": "examples/research_projects/multi_subject_dreambooth_inpainting/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\ndatasets>=2.16.0\nwandb>=0.16.1\nftfy\ntensorboard\nJinja2"
  },
  {
    "path": "examples/research_projects/multi_subject_dreambooth_inpainting/train_multi_subject_dreambooth_inpainting.py",
    "content": "import argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import concatenate_datasets, load_dataset\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionInpaintPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.13.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\"--instance_data_dir\", nargs=\"+\", help=\"Instance data directories\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The 
resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_text_encoder\", default=False, action=\"store_true\", help=\"Whether to train the text encoder\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=1000,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint and are suitable for resuming training\"\n            \" using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpointing_from\",\n        type=int,\n        default=1000,\n        help=(\"Start to checkpoint from step\"),\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=50,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_from\",\n        type=int,\n        default=0,\n        help=(\"Start to validate from step\"),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_project_name\",\n        type=str,\n        default=None,\n        help=\"The w&b name.\",\n    )\n    parser.add_argument(\n        \"--report_to_wandb\", default=False, action=\"store_true\", help=\"Whether to report to weights and biases\"\n    )\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef prepare_mask_and_masked_image(image, mask):\n    image = np.array(image.convert(\"RGB\"))\n    image = image[None].transpose(0, 3, 1, 2)\n    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0\n\n    mask = np.array(mask.convert(\"L\"))\n    mask = mask.astype(np.float32) / 255.0\n    mask = mask[None, None]\n    mask[mask < 0.5] = 0\n    mask[mask >= 0.5] = 1\n    mask = torch.from_numpy(mask)\n\n    masked_image = image * (mask < 0.5)\n\n    return mask, masked_image\n\n\nclass DreamBoothDataset(Dataset):\n    def __init__(\n        self,\n        tokenizer,\n        datasets_paths,\n    ):\n        self.tokenizer = tokenizer\n        self.datasets_paths = (datasets_paths,)\n        self.datasets = [load_dataset(dataset_path) for dataset_path in self.datasets_paths[0]]\n        self.train_data = concatenate_datasets([dataset[\"train\"] for dataset in self.datasets])\n        self.test_data = concatenate_datasets([dataset[\"test\"] for dataset in self.datasets])\n\n        self.image_normalize = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def set_image(self, img, switch):\n        if img.mode not in [\"RGB\", \"L\"]:\n            img = img.convert(\"RGB\")\n\n        if switch:\n            img = img.transpose(Image.FLIP_LEFT_RIGHT)\n\n        img = img.resize((512, 512), Image.BILINEAR)\n\n        return img\n\n    def 
__len__(self):\n        return len(self.train_data)\n\n    def __getitem__(self, index):\n        # Lettings\n        example = {}\n        img_idx = index % len(self.train_data)\n        switch = random.choice([True, False])\n\n        # Load image\n        image = self.set_image(self.train_data[img_idx][\"image\"], switch)\n\n        # Normalize image\n        image_norm = self.image_normalize(image)\n\n        # Tokenise prompt\n        tokenized_prompt = self.tokenizer(\n            self.train_data[img_idx][\"prompt\"],\n            padding=\"do_not_pad\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n        ).input_ids\n\n        # Load masks for image\n        masks = [\n            self.set_image(self.train_data[img_idx][key], switch) for key in self.train_data[img_idx] if \"mask\" in key\n        ]\n\n        # Build example\n        example[\"PIL_image\"] = image\n        example[\"instance_image\"] = image_norm\n        example[\"instance_prompt_id\"] = tokenized_prompt\n        example[\"instance_masks\"] = masks\n\n        return example\n\n\ndef weighted_mask(masks):\n    # Convert each mask to a NumPy array and ensure it's binary\n    mask_arrays = [np.array(mask) / 255 for mask in masks]  # Normalizing to 0-1 range\n\n    # Generate random weights and apply them to each mask\n    weights = [random.random() for _ in masks]\n    weights = [weight / sum(weights) for weight in weights]\n    weighted_masks = [mask * weight for mask, weight in zip(mask_arrays, weights)]\n\n    # Sum the weighted masks\n    summed_mask = np.sum(weighted_masks, axis=0)\n\n    # Apply a threshold to create the final mask\n    threshold = 0.5  # This threshold can be adjusted\n    result_mask = summed_mask >= threshold\n\n    # Convert the result back to a PIL image\n    return Image.fromarray(result_mask.astype(np.uint8) * 255)\n\n\ndef collate_fn(examples, tokenizer):\n    input_ids = [example[\"instance_prompt_id\"] for example in 
examples]\n    pixel_values = [example[\"instance_image\"] for example in examples]\n\n    masks, masked_images = [], []\n\n    for example in examples:\n        # generate a random mask\n        mask = weighted_mask(example[\"instance_masks\"])\n\n        # prepare mask and masked image\n        mask, masked_image = prepare_mask_and_masked_image(example[\"PIL_image\"], mask)\n\n        masks.append(mask)\n        masked_images.append(masked_image)\n\n    pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()\n    masks = torch.stack(masks)\n    masked_images = torch.stack(masked_images)\n    input_ids = tokenizer.pad({\"input_ids\": input_ids}, padding=True, return_tensors=\"pt\").input_ids\n\n    batch = {\"input_ids\": input_ids, \"pixel_values\": pixel_values, \"masks\": masks, \"masked_images\": masked_images}\n\n    return batch\n\n\ndef log_validation(pipeline, text_encoder, unet, val_pairs, accelerator):\n    # update pipeline (note: unet and vae are loaded again in float32)\n    pipeline.text_encoder = accelerator.unwrap_model(text_encoder)\n    pipeline.unet = accelerator.unwrap_model(unet)\n\n    with torch.autocast(\"cuda\"):\n        val_results = [{\"data_or_path\": pipeline(**pair).images[0], \"caption\": pair[\"prompt\"]} for pair in val_pairs]\n\n    torch.cuda.empty_cache()\n\n    wandb.log({\"validation\": [wandb.Image(**val_result) for val_result in val_results]})\n\n\ndef checkpoint(args, global_step, accelerator):\n    save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n    accelerator.save_state(save_path)\n    logger.info(f\"Saved state to {save_path}\")\n\n\ndef main():\n    args = parse_args()\n\n    project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit,\n        project_dir=args.output_dir,\n        logging_dir=Path(args.output_dir, args.logging_dir),\n    )\n\n    accelerator = Accelerator(\n        
gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        project_config=project_config,\n        log_with=\"wandb\" if args.report_to_wandb else None,\n    )\n\n    if args.report_to_wandb and not is_wandb_available():\n        raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n\n    # Load the tokenizer & models and create wrapper for stable diffusion\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\"\n    ).requires_grad_(args.train_text_encoder)\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\").requires_grad_(False)\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    optimizer = torch.optim.AdamW(\n        params=itertools.chain(unet.parameters(), text_encoder.parameters())\n        if args.train_text_encoder\n        else unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, 
subfolder=\"scheduler\")\n\n    train_dataset = DreamBoothDataset(\n        tokenizer=tokenizer,\n        datasets_paths=args.instance_data_dir,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, tokenizer),\n    )\n\n    # Scheduler and math around the number of training steps.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    if args.train_text_encoder:\n        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    accelerator.register_for_checkpointing(lr_scheduler)\n\n    if args.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif args.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    else:\n        weight_dtype = torch.float32\n\n    # Move text_encode and vae to gpu.\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    vae.to(accelerator.device, dtype=weight_dtype)\n    if not args.train_text_encoder:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = 
math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n\n    # Afterwards we calculate our number of training epochs\n    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = vars(copy.deepcopy(args))\n        accelerator.init_trackers(args.validation_project_name, config=tracker_config)\n\n    # create validation pipeline (note: unet and vae are loaded again in float32)\n    val_pipeline = StableDiffusionInpaintPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        unet=unet,\n        vae=vae,\n        torch_dtype=weight_dtype,\n        safety_checker=None,\n    )\n    val_pipeline.set_progress_bar_config(disable=True)\n\n    # prepare validation dataset\n    val_pairs = [\n        {\n            \"image\": example[\"image\"],\n            \"mask_image\": mask,\n            \"prompt\": example[\"prompt\"],\n        }\n        for example in train_dataset.test_data\n        for mask in [example[key] for key in example if \"mask\" in key]\n    ]\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            for model in models:\n                sub_dir = \"unet\" if isinstance(model, type(accelerator.unwrap_model(unet))) else \"text_encoder\"\n                model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n\n    print()\n\n    # Train!\n    total_batch_size = 
args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    global_step = 0\n    first_epoch = 0\n\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, num_train_epochs):\n        unet.train()\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Convert masked images to latent space\n                masked_latents = vae.encode(\n                    batch[\"masked_images\"].reshape(batch[\"pixel_values\"].shape).to(dtype=weight_dtype)\n                ).latent_dist.sample()\n                masked_latents = masked_latents * vae.config.scaling_factor\n\n                masks = batch[\"masks\"]\n                # resize the mask to latents shape as we concatenate the mask to the latents\n                mask = torch.stack(\n               
     [\n                        torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))\n                        for mask in masks\n                    ]\n                )\n                mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # concatenate the noised latents with the mask and the masked latents\n                latent_model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Predict the noise residual\n                noise_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                loss = F.mse_loss(noise_pred.float(), target.float(), reduction=\"mean\")\n\n                
accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet.parameters(), text_encoder.parameters())\n                        if args.train_text_encoder\n                        else unet.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if (\n                        global_step % args.validation_steps == 0\n                        and global_step >= args.validation_from\n                        and args.report_to_wandb\n                    ):\n                        log_validation(\n                            val_pipeline,\n                            text_encoder,\n                            unet,\n                            val_pairs,\n                            accelerator,\n                        )\n\n                    if global_step % args.checkpointing_steps == 0 and global_step >= args.checkpointing_from:\n                        checkpoint(\n                            args,\n                            global_step,\n                            accelerator,\n                        )\n\n            # Step logging\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        accelerator.wait_for_everyone()\n\n    # Terminate training\n    accelerator.end_training()\n\n\nif 
__name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/multi_token_textual_inversion/README.md",
    "content": "## [Deprecated] Multi Token Textual Inversion\n\n**IMPORTANT: This research project is deprecated. Multi Token Textual Inversion is now supported natively in [the official textual inversion example](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#running-locally-with-pytorch).**\n\nThe author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issues and PRs as well as @patrickvonplaten.\n\nWe add multi token support to textual inversion. I added\n1. `num_vec_per_token` for the number of vectors used to reference that token\n2. `progressive_tokens` for progressively training the token from 1 token to 2 tokens, and so on\n3. `progressive_tokens_max_steps` for the max number of steps until we start full training\n4. `vector_shuffle` to shuffle vectors\n\nFeel free to add these options to your training! In practice, `num_vec_per_token` around 10 combined with `vector_shuffle` works great!\n\n## Textual Inversion fine-tuning example\n\n[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.\nThe `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.\n\n## Running on Colab\n\nColab for training\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)\n\nColab for inference\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)\n\n## Running locally with PyTorch\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of 
the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd into the example folder and run:\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n\n### Cat toy example\n\nYou need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.\n\nYou have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).\n\nRun the following command to authenticate your token:\n\n```bash\nhf auth login\n```\n\nIf you have already cloned the repo, then you won't need to go through these steps.\n\n<br>\n\nNow let's get our dataset. Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. 
This will be our training data.\n\nThen launch the training using:\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport DATA_DIR=\"path-to-dir-containing-images\"\n\naccelerate launch textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 --scale_lr \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=\"textual_inversion_cat\"\n```\n\nA full training run takes ~1 hour on one V100 GPU.\n\n### Inference\n\nOnce you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.\n\n```python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\nmodel_id = \"path-to-your-trained-model\"\npipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A <cat-toy> backpack\"\n\nimage = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\n\nimage.save(\"cat-backpack.png\")\n```\n\n\n## Training with Flax/JAX\n\nFor faster training on TPUs and GPUs you can leverage the flax training example. 
Follow the instructions above to get the model and dataset before running the script.\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n```bash\npip install -U -r requirements_flax.txt\n```\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport DATA_DIR=\"path-to-dir-containing-images\"\n\npython textual_inversion_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 --scale_lr \\\n  --output_dir=\"textual_inversion_cat\"\n```\nIt should be at least 70% faster than the PyTorch script with the same configuration.\n\n### Training with xformers:\nYou can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.\n"
  },
  {
    "path": "examples/research_projects/multi_token_textual_inversion/multi_token_clip.py",
    "content": "\"\"\"\nThe main idea for this code is to provide a way for users to not need to bother with the hassle of multiple tokens for a concept by typing\na photo of <concept>_0 <concept>_1 ... and so on\nand instead just do\na photo of <concept>\nwhich gets translated to the above. This needs to work for both inference and training.\nFor inference,\nthe tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with\nit's underlying vectors\nFor training,\nwe would want to abstract away some logic like\n1. Adding tokens\n2. Updating gradient mask\n3. Saving embeddings\nto our Util class here.\nso\nTODO:\n1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x\n2. have mechanism for adding tokens x\n3. have mech for saving emebeddings x\n4. get mask to update x\n5. Loading tokens from embedding x\n6. Integrate to training x\n7. Test\n\"\"\"\n\nimport copy\nimport random\n\nfrom transformers import CLIPTokenizer\n\n\nclass MultiTokenCLIPTokenizer(CLIPTokenizer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.token_map = {}\n\n    def try_adding_tokens(self, placeholder_token, *args, **kwargs):\n        num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)\n        if num_added_tokens == 0:\n            raise ValueError(\n                f\"The tokenizer already contains the token {placeholder_token}. 
Please pass a different\"\n                \" `placeholder_token` that is not already in the tokenizer.\"\n            )\n\n    def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):\n        output = []\n        if num_vec_per_token == 1:\n            self.try_adding_tokens(placeholder_token, *args, **kwargs)\n            output.append(placeholder_token)\n        else:\n            output = []\n            for i in range(num_vec_per_token):\n                ith_token = placeholder_token + f\"_{i}\"\n                self.try_adding_tokens(ith_token, *args, **kwargs)\n                output.append(ith_token)\n        # handle cases where there is a new placeholder token that contains the current placeholder token but is larger\n        for token in self.token_map:\n            if token in placeholder_token:\n                raise ValueError(\n                    f\"The tokenizer already has placeholder token {token} that can get confused with\"\n                    f\" {placeholder_token}. Keep placeholder tokens independent.\"\n                )\n        self.token_map[placeholder_token] = output\n\n    def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):\n        \"\"\"\n        Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder\n        can encode them.\n        vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119\n        where shuffling tokens was found to force the model to learn the concepts more descriptively.\n        \"\"\"\n        if isinstance(text, list):\n            output = []\n            for i in range(len(text)):\n                # forward both options so list inputs behave the same as single strings\n                output.append(\n                    self.replace_placeholder_tokens_in_text(\n                        text[i], vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load\n                    )\n                )\n            return output\n        for placeholder_token in self.token_map:\n            if placeholder_token in text:\n                tokens = 
self.token_map[placeholder_token]\n                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]\n                if vector_shuffle:\n                    tokens = copy.copy(tokens)\n                    random.shuffle(tokens)\n                text = text.replace(placeholder_token, \" \".join(tokens))\n        return text\n\n    def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):\n        return super().__call__(\n            self.replace_placeholder_tokens_in_text(\n                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load\n            ),\n            *args,\n            **kwargs,\n        )\n\n    def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):\n        return super().encode(\n            self.replace_placeholder_tokens_in_text(\n                text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load\n            ),\n            *args,\n            **kwargs,\n        )\n"
  },
  {
    "path": "examples/research_projects/multi_token_textual_inversion/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/research_projects/multi_token_textual_inversion/requirements_flax.txt",
    "content": "transformers>=4.25.1\nflax\noptax\ntorch\ntorchvision\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/research_projects/multi_token_textual_inversion/textual_inversion.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport numpy as np\nimport PIL\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom multi_token_clip import MultiTokenCLIPTokenizer\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": 
PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.14.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):\n    \"\"\"\n    Add tokens to the tokenizer and set the initial value of token embeddings\n    \"\"\"\n    tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)\n    text_encoder.resize_token_embeddings(len(tokenizer))\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)\n    if initializer_token:\n        token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)\n        for i, placeholder_token_id in enumerate(placeholder_token_ids):\n            token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]\n    else:\n        for i, placeholder_token_id in enumerate(placeholder_token_ids):\n            token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])\n    return placeholder_token\n\n\ndef save_progress(tokenizer, text_encoder, accelerator, save_path):\n    for placeholder_token in tokenizer.token_map:\n        placeholder_token_ids = tokenizer.encode(placeholder_token, 
add_special_tokens=False)\n        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_ids]\n        if len(placeholder_token_ids) == 1:\n            learned_embeds = learned_embeds[None]\n        learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}\n        torch.save(learned_embeds_dict, save_path)\n\n\ndef load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict):\n    for placeholder_token in learned_embeds_dict:\n        placeholder_embeds = learned_embeds_dict[placeholder_token]\n        num_vec_per_token = placeholder_embeds.shape[0]\n        placeholder_embeds = placeholder_embeds.to(dtype=text_encoder.dtype)\n        add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=num_vec_per_token)\n        placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)\n        token_embeds = text_encoder.get_input_embeddings().weight.data\n        for i, placeholder_token_id in enumerate(placeholder_token_ids):\n            token_embeds[placeholder_token_id] = placeholder_embeds[i]\n\n\ndef load_multitoken_tokenizer_from_automatic(tokenizer, text_encoder, automatic_dict, placeholder_token):\n    \"\"\"\n    Automatic1111's tokens have format\n    {'string_to_token': {'*': 265}, 'string_to_param': {'*': tensor([[ 0.0833,  0.0030,  0.0057,  ..., -0.0264, -0.0616, -0.0529],\n        [ 0.0058, -0.0190, -0.0584,  ..., -0.0025, -0.0945, -0.0490],\n        [ 0.0916,  0.0025,  0.0365,  ..., -0.0685, -0.0124,  0.0728],\n        [ 0.0812, -0.0199, -0.0100,  ..., -0.0581, -0.0780,  0.0254]],\n       requires_grad=True)}, 'name': 'FloralMarble-400', 'step': 399, 'sd_checkpoint': '4bdfc29c', 'sd_checkpoint_name': 'SD2.1-768'}\n    \"\"\"\n    learned_embeds_dict = {}\n    learned_embeds_dict[placeholder_token] = automatic_dict[\"string_to_param\"][\"*\"]\n    load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)\n\n\ndef 
get_mask(tokenizer, accelerator):\n    # Get the mask of the weights that won't change\n    mask = torch.ones(len(tokenizer)).to(accelerator.device, dtype=torch.bool)\n    for placeholder_token in tokenizer.token_map:\n        placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)\n        for i in range(len(placeholder_token_ids)):\n            mask = mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]).to(accelerator.device)\n    return mask\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--progressive_tokens_max_steps\",\n        type=int,\n        default=2000,\n        help=\"The number of steps until all tokens will be used.\",\n    )\n    parser.add_argument(\n        \"--progressive_tokens\",\n        action=\"store_true\",\n        help=\"Progressively train the tokens. For example, first train for 1 token, then 2 tokens and so on.\",\n    )\n    parser.add_argument(\"--vector_shuffle\", action=\"store_true\", help=\"Shuffle tokens during training.\")\n    parser.add_argument(\n        \"--num_vec_per_token\",\n        type=int,\n        default=1,\n        help=(\n            \"The number of vectors used to represent the placeholder token. The higher the number, the better the\"\n            \" result at the cost of editability. 
This can be fixed by prompt editing.\"\n        ),\n    )\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n        \"--only_save_embeds\",\n        action=\"store_true\",\n        default=False,\n        help=\"Save only the embeddings for the new concept.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\", type=str, default=None, required=True, help=\"A folder containing the training data.\"\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will 
be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler 
type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run validation every X epochs. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.train_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up 
painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n        vector_shuffle=False,\n        progressive_tokens=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n        self.vector_shuffle = vector_shuffle\n        self.progressive_tokens = progressive_tokens\n        self.prop_tokens_to_load = 0\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else 
imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer.encode(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n            vector_shuffle=self.vector_shuffle,\n            prop_tokens_to_load=self.prop_tokens_to_load if self.progressive_tokens else 1.0,\n        )[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (\n                h,\n                w,\n            ) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to 
authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load tokenizer\n    if args.tokenizer_name:\n        
tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n    if is_xformers_available():\n        try:\n            unet.enable_xformers_memory_efficient_attention()\n        except Exception as e:\n            logger.warning(\n                \"Could not enable memory efficient attention. 
Make sure xformers is installed\"\n                f\" correctly and a GPU is available: {e}\"\n            )\n    add_tokens(tokenizer, text_encoder, args.placeholder_token, args.num_vec_per_token, args.initializer_token)\n\n    # Freeze vae and unet\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    # Freeze all parameters except for the token embeddings in text encoder\n    text_encoder.text_model.encoder.requires_grad_(False)\n    text_encoder.text_model.final_layer_norm.requires_grad_(False)\n    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)\n\n    if args.gradient_checkpointing:\n        # Keep unet in train mode if we are using gradient checkpointing to save memory.\n        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.\n        unet.train()\n        text_encoder.gradient_checkpointing_enable()\n        unet.enable_gradient_checkpointing()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=args.placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare 
everything with our `accelerator`.\n    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        text_encoder, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # For mixed precision training we cast the unet and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and unet to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    # keep original embeddings as reference\n    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        text_encoder.train()\n        for step, 
batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n            if args.progressive_tokens:\n                train_dataset.prop_tokens_to_load = float(global_step) / args.progressive_tokens_max_steps\n\n            with accelerator.accumulate(text_encoder):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample().detach()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0].to(dtype=weight_dtype)\n\n                # Predict the noise residual\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = 
noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # Let's make sure we don't update any embedding weights besides the newly added token\n                index_no_updates = get_mask(tokenizer, accelerator)\n                with torch.no_grad():\n                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (\n                        orig_embeds_params[index_no_updates]\n                    )\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                if global_step % args.save_steps == 0:\n                    save_path = os.path.join(args.output_dir, f\"learned_embeds-steps-{global_step}.bin\")\n                    save_progress(tokenizer, text_encoder, accelerator, save_path)\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process and args.validation_prompt 
is not None and epoch % args.validation_epochs == 0:\n            logger.info(\n                f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n                f\" {args.validation_prompt}.\"\n            )\n            # create pipeline (note: unet and vae are loaded again in float32)\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                text_encoder=accelerator.unwrap_model(text_encoder),\n                tokenizer=tokenizer,\n                unet=unet,\n                vae=vae,\n                revision=args.revision,\n                torch_dtype=weight_dtype,\n            )\n            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n            pipeline = pipeline.to(accelerator.device)\n            pipeline.set_progress_bar_config(disable=True)\n\n            # run inference\n            generator = (\n                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n            )\n            images = []\n            for _ in range(args.num_validation_images):\n                with torch.autocast(\"cuda\"):\n                    image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n                images.append(image)\n\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    tracker.log(\n                        {\n                            \"validation\": [\n                                wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n 
                           ]\n                        }\n                    )\n\n            del pipeline\n            torch.cuda.empty_cache()\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if args.push_to_hub and args.only_save_embeds:\n            logger.warning(\"Enabling full model saving because --push_to_hub=True was specified.\")\n            save_full_model = True\n        else:\n            save_full_model = not args.only_save_embeds\n        if save_full_model:\n            pipeline = StableDiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                text_encoder=accelerator.unwrap_model(text_encoder),\n                vae=vae,\n                unet=unet,\n                tokenizer=tokenizer,\n            )\n            pipeline.save_pretrained(args.output_dir)\n        # Save the newly trained embeddings\n        save_path = os.path.join(args.output_dir, \"learned_embeds.bin\")\n        save_progress(tokenizer, text_encoder, accelerator, save_path)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py",
    "content": "import argparse\nimport logging\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport PIL\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom flax import jax_utils\nfrom flax.training import train_state\nfrom flax.training.common_utils import shard\nfrom huggingface_hub import create_repo, upload_folder\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed\n\nfrom diffusers import (\n    FlaxAutoencoderKL,\n    FlaxDDPMScheduler,\n    FlaxPNDMScheduler,\n    FlaxStableDiffusionPipeline,\n    FlaxUNet2DConditionModel,\n)\nfrom diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker\nfrom diffusers.utils import check_min_version\n\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.14.0.dev0\")\n\nlogger = logging.getLogger(__name__)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\", type=str, default=None, required=True, help=\"A folder containing the training data.\"\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=42, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    
parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=True,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\n        \"--use_auth_token\",\n        action=\"store_true\",\n        help=(\n            \"Will use the token generated when running `hf auth login` (necessary to use this script with\"\n            \" private models).\"\n        ),\n    )\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.train_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped 
painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == 
\"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (\n                h,\n                w,\n            ) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):\n    if model.config.vocab_size == new_num_tokens or new_num_tokens is None:\n        return\n    model.config.vocab_size = new_num_tokens\n\n    params = model.params\n    old_embeddings = params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"]\n    old_num_tokens, emb_dim = old_embeddings.shape\n\n    initializer = jax.nn.initializers.normal()\n\n    new_embeddings = initializer(rng, (new_num_tokens, emb_dim))\n    new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)\n    new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])\n    
params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"] = new_embeddings\n\n    model.params = params\n    return model\n\n\ndef get_params_to_save(params):\n    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))\n\n\ndef main():\n    args = parse_args()\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    if jax.process_index() == 0:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    # Setup logging, we only want one process per machine to log things on the screen.\n    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)\n    if jax.process_index() == 0:\n        transformers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n\n    # Load the tokenizer and add the placeholder token as an additional special token\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Add the placeholder token in tokenizer\n    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)\n    if num_added_tokens == 0:\n        raise ValueError(\n            f\"The tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n\n    # Convert the initializer_token, placeholder_token to ids\n    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)\n    # Check if initializer_token is a single token or a sequence of tokens\n    if len(token_ids) > 1:\n        raise ValueError(\"The initializer token must be a single token.\")\n\n    initializer_token_id = token_ids[0]\n    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)\n\n    # Load models and create wrapper for stable diffusion\n    text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"text_encoder\")\n    vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\")\n    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"unet\")\n\n    # Create sampling rng\n    rng = jax.random.PRNGKey(args.seed)\n    rng, _ = jax.random.split(rng)\n    # Resize the token embeddings as we are adding new special tokens to the tokenizer\n    text_encoder = resize_token_embeddings(\n        text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng\n    )\n    original_token_embeds = text_encoder.params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"]\n\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=args.placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n\n     
   batch = {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n        batch = {k: v.numpy() for k, v in batch.items()}\n\n        return batch\n\n    total_train_batch_size = args.train_batch_size * jax.local_device_count()\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn\n    )\n\n    # Optimization\n    if args.scale_lr:\n        args.learning_rate = args.learning_rate * total_train_batch_size\n\n    constant_scheduler = optax.constant_schedule(args.learning_rate)\n\n    optimizer = optax.adamw(\n        learning_rate=constant_scheduler,\n        b1=args.adam_beta1,\n        b2=args.adam_beta2,\n        eps=args.adam_epsilon,\n        weight_decay=args.adam_weight_decay,\n    )\n\n    def create_mask(params, label_fn):\n        def _map(params, mask, label_fn):\n            for k in params:\n                if label_fn(k):\n                    mask[k] = \"token_embedding\"\n                else:\n                    if isinstance(params[k], dict):\n                        mask[k] = {}\n                        _map(params[k], mask[k], label_fn)\n                    else:\n                        mask[k] = \"zero\"\n\n        mask = {}\n        _map(params, mask, label_fn)\n        return mask\n\n    def zero_grads():\n        # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491\n        def init_fn(_):\n            return ()\n\n        def update_fn(updates, state, params=None):\n            return jax.tree_util.tree_map(jnp.zeros_like, updates), ()\n\n        return optax.GradientTransformation(init_fn, update_fn)\n\n    # Zero out gradients of layers other than the token embedding layer\n    tx = optax.multi_transform(\n        {\"token_embedding\": optimizer, \"zero\": zero_grads()},\n        create_mask(text_encoder.params, lambda s: s == \"token_embedding\"),\n    )\n\n    state = 
train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)\n\n    noise_scheduler = FlaxDDPMScheduler(\n        beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000\n    )\n    noise_scheduler_state = noise_scheduler.create_state()\n\n    # Initialize our training\n    train_rngs = jax.random.split(rng, jax.local_device_count())\n\n    # Define gradient train step fn\n    def train_step(state, vae_params, unet_params, batch, train_rng):\n        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)\n\n        def compute_loss(params):\n            vae_outputs = vae.apply(\n                {\"params\": vae_params}, batch[\"pixel_values\"], deterministic=True, method=vae.encode\n            )\n            latents = vae_outputs.latent_dist.sample(sample_rng)\n            # (NHWC) -> (NCHW)\n            latents = jnp.transpose(latents, (0, 3, 1, 2))\n            latents = latents * vae.config.scaling_factor\n\n            noise_rng, timestep_rng = jax.random.split(sample_rng)\n            noise = jax.random.normal(noise_rng, latents.shape)\n            bsz = latents.shape[0]\n            timesteps = jax.random.randint(\n                timestep_rng,\n                (bsz,),\n                0,\n                noise_scheduler.config.num_train_timesteps,\n            )\n            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)\n            encoder_hidden_states = state.apply_fn(\n                batch[\"input_ids\"], params=params, dropout_rng=dropout_rng, train=True\n            )[0]\n            # Predict the noise residual and compute loss\n            model_pred = unet.apply(\n                {\"params\": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False\n            ).sample\n\n            # Get the target for loss depending on the prediction type\n            if 
noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n            loss = (target - model_pred) ** 2\n            loss = loss.mean()\n\n            return loss\n\n        grad_fn = jax.value_and_grad(compute_loss)\n        loss, grad = grad_fn(state.params)\n        grad = jax.lax.pmean(grad, \"batch\")\n        new_state = state.apply_gradients(grads=grad)\n\n        # Keep the token embeddings fixed except the newly added embeddings for the concept,\n        # as we only want to optimize the concept embeddings\n        token_embeds = original_token_embeds.at[placeholder_token_id].set(\n            new_state.params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"][placeholder_token_id]\n        )\n        new_state.params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"] = token_embeds\n\n        metrics = {\"loss\": loss}\n        metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n        return new_state, metrics, new_train_rng\n\n    # Create parallel version of the train and eval step\n    p_train_step = jax.pmap(train_step, \"batch\", donate_argnums=(0,))\n\n    # Replicate the train state on each device\n    state = jax_utils.replicate(state)\n    vae_params = jax_utils.replicate(vae_params)\n    unet_params = jax_utils.replicate(unet_params)\n\n    # Train!\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader))\n\n    # Scheduler and math around the number of training steps.\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n\n    args.num_train_epochs = math.ceil(args.max_train_steps / 
num_update_steps_per_epoch)\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    global_step = 0\n\n    epochs = tqdm(range(args.num_train_epochs), desc=f\"Epoch ... (1/{args.num_train_epochs})\", position=0)\n    for epoch in epochs:\n        # ======================== Training ================================\n\n        train_metrics = []\n\n        steps_per_epoch = len(train_dataset) // total_train_batch_size\n        train_step_progress_bar = tqdm(total=steps_per_epoch, desc=\"Training...\", position=1, leave=False)\n        # train\n        for batch in train_dataloader:\n            batch = shard(batch)\n            state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)\n            train_metrics.append(train_metric)\n\n            train_step_progress_bar.update(1)\n            global_step += 1\n\n            if global_step >= args.max_train_steps:\n                break\n\n        train_metric = jax_utils.unreplicate(train_metric)\n\n        train_step_progress_bar.close()\n        epochs.write(f\"Epoch... 
({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})\")\n\n    # Create the pipeline using the trained modules and save it.\n    if jax.process_index() == 0:\n        scheduler = FlaxPNDMScheduler(\n            beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", skip_prk_steps=True\n        )\n        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(\n            \"CompVis/stable-diffusion-safety-checker\", from_pt=True\n        )\n        pipeline = FlaxStableDiffusionPipeline(\n            text_encoder=text_encoder,\n            vae=vae,\n            unet=unet,\n            tokenizer=tokenizer,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=CLIPImageProcessor.from_pretrained(\"openai/clip-vit-base-patch32\"),\n        )\n\n        pipeline.save_pretrained(\n            args.output_dir,\n            params={\n                \"text_encoder\": get_params_to_save(state.params),\n                \"vae\": get_params_to_save(vae_params),\n                \"unet\": get_params_to_save(unet_params),\n                \"safety_checker\": safety_checker.params,\n            },\n        )\n\n        # Also save the newly trained embeddings\n        learned_embeds = get_params_to_save(state.params)[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"][\n            placeholder_token_id\n        ]\n        learned_embeds_dict = {args.placeholder_token: learned_embeds}\n        jnp.save(os.path.join(args.output_dir, \"learned_embeds.npy\"), learned_embeds_dict)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/onnxruntime/README.md",
    "content": "## Diffusers examples with ONNXRuntime optimizations\n\n**This research project is not actively maintained by the diffusers team. For any questions or comments, please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.**\n\nThis aims to provide diffusers examples with ONNXRuntime optimizations for training/fine-tuning unconditional image generation, text to image, and textual inversion. Please see individual directories for more details on how to run each task using ONNXRuntime.\n"
  },
  {
    "path": "examples/research_projects/onnxruntime/text_to_image/README.md",
    "content": "# Stable Diffusion text-to-image fine-tuning\n\nThe `train_text_to_image.py` script shows how to fine-tune stable diffusion model on your own dataset.\n\n___Note___:\n\n___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___\n\n\n## Running locally with PyTorch\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd in the example folder  and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n### Naruto example\n\nYou need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.\n\nYou have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. 
For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).\n\nRun the following command to authenticate with your token:\n\n```bash\nhf auth login\n```\n\nIf you have already cloned the repo, then you won't need to go through these steps.\n\n<br>\n\n## Use ONNXRuntime to accelerate training\nTo use ONNXRuntime to accelerate training, use `train_text_to_image.py`.\n\nThe following command fine-tunes the Stable Diffusion UNet (`UNet2DConditionModel`) on the Naruto dataset with ONNXRuntime:\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport dataset_name=\"lambdalabs/naruto-blip-captions\"\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$dataset_name \\\n  --use_ema \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --output_dir=\"sd-naruto-model\"\n```\n\nPlease contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub with any questions."
  },
  {
    "path": "examples/research_projects/onnxruntime/text_to_image/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\ndatasets\nftfy\ntensorboard\nmodelcards\n"
  },
  {
    "path": "examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer\nfrom onnxruntime.training.ortmodule import ORTModule\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom transformers.utils import ContextManagers\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, deprecate, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif is_wandb_available():\n    import wandb\n\n\n# Will error 
if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.17.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):\n    logger.info(\"Running validation... \")\n\n    pipeline = StableDiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=accelerator.unwrap_model(vae),\n        text_encoder=accelerator.unwrap_model(text_encoder),\n        tokenizer=tokenizer,\n        unet=accelerator.unwrap_model(unet),\n        safety_checker=None,\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    images = []\n    for i in range(len(args.validation_prompts)):\n        with torch.autocast(\"cuda\"):\n            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompts[i]}\")\n                        for i, image in enumerate(images)\n                    ]\n             
   }\n            )\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--input_pertubation\", type=float, default=0, help=\"The scale of input pretubation. Recommended 0.1.\"\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompts\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\"A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--non_ema_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or\"\n            \" remote repository specified with --pretrained_model_name_or_path.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=5,\n        help=\"Run validation every X epochs.\",\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    # default to using the same revision for the non-ema model if not specified\n    if args.non_ema_revision is None:\n        args.non_ema_revision = args.revision\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if args.non_ema_revision is not None:\n        deprecate(\n            \"non_ema_revision!=None\",\n            \"0.15.0\",\n    
        message=(\n                \"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to\"\n                \" use `--variant=non_ema` instead.\"\n            ),\n        )\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            
).repo_id\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        returns either a context list that includes one that will disable zero.Init or an empty context list\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n        if deepspeed_plugin is None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.\n    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate\n    # will try to assign the same optimizer with the same weights to all models during\n    # `deepspeed.initialize`, which of course doesn't work.\n    #\n    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2\n    # frozen models from being partitioned during `zero.Init` which gets called during\n    # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding\n    # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        text_encoder = CLIPTextModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n        )\n        vae = AutoencoderKL.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision\n        )\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, 
subfolder=\"unet\", revision=args.non_ema_revision\n    )\n\n    # Freeze vae and text_encoder\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = UNet2DConditionModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n        )\n        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n    
    torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    optimizer = ORT_FP16_Optimizer(optimizer)\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # 
Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    # Preprocessing the datasets.\n    
train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"input_ids\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = 
args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    unet = ORTModule(unet)\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move text_encode and vae to gpu and cast to weight_dtype\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        tracker_config.pop(\"validation_prompts\")\n   
     accelerator.init_trackers(args.tracker_project_name, tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device\n                    )\n                if args.input_pertubation:\n                    new_noise = 
noise + args.input_pertubation * torch.randn_like(noise)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                if args.input_pertubation:\n                    noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)\n                else:\n                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Predict the noise residual and compute loss\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = 
compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n            
            accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n                log_validation(\n                    vae,\n                    text_encoder,\n                    tokenizer,\n                    unet,\n                    args,\n                    accelerator,\n                    weight_dtype,\n                    global_step,\n                )\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_unet.restore(unet.parameters())\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = accelerator.unwrap_model(unet)\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        pipeline = StableDiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            text_encoder=text_encoder,\n            vae=vae,\n            unet=unet,\n            revision=args.revision,\n        )\n        pipeline.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", 
\"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/onnxruntime/textual_inversion/README.md",
    "content": "## Textual Inversion fine-tuning example\n\n[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text-to-image models like Stable Diffusion on your own images using just 3-5 examples.\nThe `textual_inversion.py` script shows how to implement the training procedure and adapt it for Stable Diffusion.\n\n## Running on Colab\n\nColab for training\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)\n\nColab for inference\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)\n\n## Running locally with PyTorch\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen `cd` into the example folder and run:\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n\n### Cat toy example\n\nYou need to accept the model license before downloading or using the weights. 
In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.\n\nYou have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).\n\nRun the following command to authenticate your token:\n\n```bash\nhf auth login\n```\n\nIf you have already cloned the repo, then you won't need to go through these steps.\n\n<br>\n\nNow let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .\n\nLet's first download it locally:\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./cat\"\nsnapshot_download(\"diffusers/cat_toy_example\", local_dir=local_dir, repo_type=\"dataset\", ignore_patterns=\".gitattributes\")\n```\n\nThis will be our training data.\nNow we can launch the training using the command in the next section.\n\n## Use ONNXRuntime to accelerate training\nIn order to leverage onnxruntime to accelerate training, please use `textual_inversion.py`.\n\nThe command to train on custom data with onnxruntime:\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport DATA_DIR=\"path-to-dir-containing-images\"\n\naccelerate launch textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 --scale_lr \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --output_dir=\"textual_inversion_cat\"\n```\n\nPlease contact Prathik Rao (prathikr), Sunghoon Choi 
(hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on GitHub with any questions."
  },
  {
    "path": "examples/research_projects/onnxruntime/textual_inversion/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nmodelcards\n"
  },
  {
    "path": "examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport PIL\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer\nfrom onnxruntime.training.ortmodule import ORTModule\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif is_wandb_available():\n    
import wandb\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.17.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(repo_id: str, images=None, base_model: str = None, repo_folder=None):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- stable-diffusion\n- stable-diffusion-diffusers\n- text-to-image\n- diffusers\n- textual_inversion\n- diffusers-training\n- onnxruntime\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# Textual inversion text2image fine-tuning - {repo_id}\nThese are textual inversion adaptation weights for {base_model}. You can find some example images in the following. \\n\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    # create pipeline (note: unet and vae are loaded again in float32)\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        text_encoder=accelerator.unwrap_model(text_encoder),\n        tokenizer=tokenizer,\n        unet=unet,\n        vae=vae,\n        safety_checker=None,\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n    images = []\n    for _ in range(args.num_validation_images):\n        with torch.autocast(\"cuda\"):\n            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n    return images\n\n\ndef save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):\n    logger.info(\"Saving embeddings\")\n    learned_embeds = (\n        accelerator.unwrap_model(text_encoder)\n        .get_input_embeddings()\n        .weight[min(placeholder_token_ids) : 
max(placeholder_token_ids) + 1]\n    )\n    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}\n    torch.save(learned_embeds_dict, save_path)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n        \"--save_as_full_pipeline\",\n        action=\"store_true\",\n        help=\"Save the complete stable diffusion pipeline.\",\n    )\n    parser.add_argument(\n        \"--num_vectors\",\n        type=int,\n        default=1,\n        help=\"How many textual inversion vectors shall be used to learn the concept.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\", type=str, default=None, required=True, help=\"A folder containing the training data.\"\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    
parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=None,\n        help=(\n            \"Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.train_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n   
 \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        
self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (\n                h,\n                w,\n            ) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 
1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository 
creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load tokenizer\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n\n    # Add the placeholder token in tokenizer\n    placeholder_tokens = [args.placeholder_token]\n\n    if args.num_vectors < 1:\n        raise ValueError(f\"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}\")\n\n    # add dummy tokens for multi-vector\n    additional_tokens = []\n    for i in range(1, args.num_vectors):\n        additional_tokens.append(f\"{args.placeholder_token}_{i}\")\n    placeholder_tokens += additional_tokens\n\n    num_added_tokens = tokenizer.add_tokens(placeholder_tokens)\n    if num_added_tokens != args.num_vectors:\n        raise ValueError(\n            f\"The tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n\n    # Convert the initializer_token, placeholder_token to ids\n    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)\n    # Check if initializer_token is a single token or a sequence of tokens\n    if len(token_ids) > 1:\n        raise ValueError(\"The initializer token must be a single token.\")\n\n    initializer_token_id = token_ids[0]\n    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)\n\n    # Resize the token embeddings as we are adding new special tokens to the tokenizer\n    text_encoder.resize_token_embeddings(len(tokenizer))\n\n    # Initialise the newly added placeholder token with the embeddings of the initializer token\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    with torch.no_grad():\n        for token_id in placeholder_token_ids:\n            token_embeds[token_id] = token_embeds[initializer_token_id].clone()\n\n    # Freeze vae and unet\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    # Freeze all parameters except for the token embeddings in text encoder\n    text_encoder.text_model.encoder.requires_grad_(False)\n    text_encoder.text_model.final_layer_norm.requires_grad_(False)\n    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)\n\n    if args.gradient_checkpointing:\n        # Keep unet in train mode if we are using gradient checkpointing to save memory.\n        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.\n        unet.train()\n        text_encoder.gradient_checkpointing_enable()\n        unet.enable_gradient_checkpointing()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == 
version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    optimizer = ORT_FP16_Optimizer(optimizer)\n\n    # Dataset and DataLoaders creation:\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=args.placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n    )\n    if args.validation_epochs is not None:\n        warnings.warn(\n            f\"FutureWarning: You are doing logging 
with validation_epochs={args.validation_epochs}.\"\n            \" Deprecated validation_epochs in favor of `validation_steps`\"\n            f\"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}\",\n            FutureWarning,\n            stacklevel=2,\n        )\n        args.validation_steps = args.validation_epochs * len(train_dataset)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare everything with our `accelerator`.\n    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        text_encoder, optimizer, train_dataloader, lr_scheduler\n    )\n\n    text_encoder = ORTModule(text_encoder)\n    unet = ORTModule(unet)\n    vae = ORTModule(vae)\n\n    # For mixed precision training we cast the unet and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and unet to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    # keep original embeddings as reference\n    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        text_encoder.train()\n        for step, 
batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(text_encoder):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample().detach()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0].to(dtype=weight_dtype)\n\n                # Predict the noise residual\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type 
{noise_scheduler.config.prediction_type}\")\n\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # Let's make sure we don't update any embedding weights besides the newly added token\n                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)\n                index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False\n\n                with torch.no_grad():\n                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (\n                        orig_embeds_params[index_no_updates]\n                    )\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                images = []\n                progress_bar.update(1)\n                global_step += 1\n                if global_step % args.save_steps == 0:\n                    save_path = os.path.join(args.output_dir, f\"learned_embeds-steps-{global_step}.bin\")\n                    save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        images = log_validation(\n                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch\n                        )\n\n            logs = 
{\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if args.push_to_hub and not args.save_as_full_pipeline:\n            logger.warning(\"Enabling full model saving because --push_to_hub=True was specified.\")\n            save_full_model = True\n        else:\n            save_full_model = args.save_as_full_pipeline\n        if save_full_model:\n            pipeline = StableDiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                text_encoder=accelerator.unwrap_model(text_encoder),\n                vae=vae,\n                unet=unet,\n                tokenizer=tokenizer,\n            )\n            pipeline.save_pretrained(args.output_dir)\n        # Save the newly trained embeddings\n        save_path = os.path.join(args.output_dir, \"learned_embeds.bin\")\n        save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/onnxruntime/unconditional_image_generation/README.md",
    "content": "## Training examples\n\nCreating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd in the example folder  and run\n```bash\npip install -r requirements.txt\n```\n\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n#### Use ONNXRuntime to accelerate training\n\nIn order to leverage onnxruntime to accelerate training, please use train_unconditional_ort.py\n\nThe command to train a DDPM UNet model on the Oxford Flowers dataset with onnxruntime:\n\n```bash\naccelerate launch train_unconditional.py \\\n  --dataset_name=\"huggan/flowers-102-categories\" \\\n  --resolution=64 --center_crop --random_flip \\\n  --output_dir=\"ddpm-ema-flowers-64\" \\\n  --use_ema \\\n  --train_batch_size=16 \\\n  --num_epochs=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-4 \\\n  --lr_warmup_steps=500 \\\n  --mixed_precision=fp16\n  ```\n\nPlease contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.\n"
  },
  {
    "path": "examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ndatasets\ntensorboard"
  },
  {
    "path": "examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py",
    "content": "import argparse\nimport inspect\nimport logging\nimport math\nimport os\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer\nfrom onnxruntime.training.ortmodule import ORTModule\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel\nfrom diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.17.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef _extract_into_tensor(arr, timesteps, broadcast_shape):\n    \"\"\"\n    Extract values from a 1-D numpy array for a batch of indices.\n\n    :param arr: the 1-D numpy array.\n    :param timesteps: a tensor of indices into the array to extract.\n    :param broadcast_shape: a larger shape of K dimensions with the batch\n                            dimension equal to the length of timesteps.\n    :return: a tensor of shape [batch_size, 1, ...] 
where the shape has K dims.\n    \"\"\"\n    if not isinstance(arr, torch.Tensor):\n        arr = torch.from_numpy(arr)\n    res = arr[timesteps].float().to(timesteps.device)\n    while len(res.shape) < len(broadcast_shape):\n        res = res[..., None]\n    return res.expand(broadcast_shape)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that HF Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--model_config_name_or_path\",\n        type=str,\n        default=None,\n        help=\"The config of the UNet model to train, leave as None to use standard DDPM configuration.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"ddpm-model-64\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--overwrite_output_dir\", action=\"store_true\")\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=64,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--eval_batch_size\", type=int, default=16, help=\"The number of images to generate for evaluation.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"The number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main\"\n            \" process.\"\n        ),\n    )\n    parser.add_argument(\"--num_epochs\", type=int, default=100)\n    parser.add_argument(\"--save_images_epochs\", type=int, default=10, help=\"How often to save images during training.\")\n    parser.add_argument(\n        \"--save_model_epochs\", type=int, default=10, help=\"How often to save the model during training.\"\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"cosine\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.95, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\n        \"--adam_weight_decay\", type=float, default=1e-6, help=\"Weight decay magnitude for the Adam optimizer.\"\n    )\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer.\")\n    parser.add_argument(\n        \"--use_ema\",\n        action=\"store_true\",\n        help=\"Whether to use Exponential Moving Average for the final model weights.\",\n    
)\n    parser.add_argument(\"--ema_inv_gamma\", type=float, default=1.0, help=\"The inverse gamma value for the EMA decay.\")\n    parser.add_argument(\"--ema_power\", type=float, default=3 / 4, help=\"The power value for the EMA decay.\")\n    parser.add_argument(\"--ema_max_decay\", type=float, default=0.9999, help=\"The maximum decay magnitude for EMA.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--hub_private_repo\", action=\"store_true\", help=\"Whether or not to create a private repository.\"\n    )\n    parser.add_argument(\n        \"--logger\",\n        type=str,\n        default=\"tensorboard\",\n        choices=[\"tensorboard\", \"wandb\"],\n        help=(\n            \"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)\"\n            \" for experiment tracking and logging of model metrics and model checkpoints\"\n        ),\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. 
Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=\"epsilon\",\n        choices=[\"epsilon\", \"sample\"],\n        help=\"Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.\",\n    )\n    parser.add_argument(\"--ddpm_num_steps\", type=int, default=1000)\n    parser.add_argument(\"--ddpm_num_inference_steps\", type=int, default=1000)\n    parser.add_argument(\"--ddpm_beta_schedule\", type=str, default=\"linear\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more docs\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"You must specify either a dataset name from the hub or a train data directory.\")\n\n    return args\n\n\ndef main(args):\n    # The tracker is selected with `--logger`; this script defines no `--report_to` argument.\n    if args.logger == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --logger=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.logger,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.logger == \"tensorboard\":\n        if not is_tensorboard_available():\n            raise ImportError(\"Make sure to install tensorboard if you want to use it for logging during training.\")\n\n    elif args.logger == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for 
logging during training.\")\n        import wandb\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_model.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DModel)\n                ema_model.load_state_dict(load_model.state_dict())\n                ema_model.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    
logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Initialize the model\n    if args.model_config_name_or_path is None:\n        model = UNet2DModel(\n            sample_size=args.resolution,\n            in_channels=3,\n            out_channels=3,\n            layers_per_block=2,\n            block_out_channels=(128, 128, 256, 256, 512, 512),\n            down_block_types=(\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"AttnDownBlock2D\",\n                \"DownBlock2D\",\n            ),\n            up_block_types=(\n                \"UpBlock2D\",\n                \"AttnUpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n            ),\n        )\n    else:\n        config = UNet2DModel.load_config(args.model_config_name_or_path)\n        model = UNet2DModel.from_config(config)\n\n    # Create EMA for the model.\n    if args.use_ema:\n        ema_model = EMAModel(\n            model.parameters(),\n            decay=args.ema_max_decay,\n            use_ema_warmup=True,\n            inv_gamma=args.ema_inv_gamma,\n            power=args.ema_power,\n            model_cls=UNet2DModel,\n            
model_config=model.config,\n        )\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            model.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Initialize the scheduler\n    accepts_prediction_type = \"prediction_type\" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())\n    if accepts_prediction_type:\n        noise_scheduler = DDPMScheduler(\n            num_train_timesteps=args.ddpm_num_steps,\n            beta_schedule=args.ddpm_beta_schedule,\n            prediction_type=args.prediction_type,\n        )\n    else:\n        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        model.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    optimizer = ORT_FP16_Optimizer(optimizer)\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if 
args.dataset_name is not None:\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            split=\"train\",\n        )\n    else:\n        dataset = load_dataset(\"imagefolder\", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split=\"train\")\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets and DataLoaders creation.\n    augmentations = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def transform_images(examples):\n        images = [augmentations(image.convert(\"RGB\")) for image in examples[\"image\"]]\n        return {\"input\": images}\n\n    logger.info(f\"Dataset size: {len(dataset)}\")\n\n    dataset.set_transform(transform_images)\n    train_dataloader = torch.utils.data.DataLoader(\n        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n    )\n\n    # Initialize the learning rate scheduler\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=(len(train_dataloader) * args.num_epochs),\n    )\n\n    # Prepare everything with our `accelerator`.\n    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        
ema_model.to(accelerator.device)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        run = os.path.split(__file__)[-1].split(\".\")[0]\n        accelerator.init_trackers(run)\n\n    model = ORTModule(model)\n\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    max_train_steps = args.num_epochs * num_update_steps_per_epoch\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {max_train_steps}\")\n\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Train!\n    for epoch in range(first_epoch, args.num_epochs):\n        model.train()\n        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)\n        progress_bar.set_description(f\"Epoch {epoch}\")\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            clean_images = batch[\"input\"]\n            # Sample noise that we'll add to the images\n            noise = torch.randn(\n                clean_images.shape, dtype=(torch.float32 if args.mixed_precision == \"no\" else torch.float16)\n            ).to(clean_images.device)\n            bsz = clean_images.shape[0]\n            # Sample a random timestep for each image\n            timesteps = torch.randint(\n                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device\n            ).long()\n\n            # Add noise to the clean images according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n\n            with 
accelerator.accumulate(model):\n                # Predict the noise residual\n                model_output = model(noisy_images, timesteps, return_dict=False)[0]\n\n                if args.prediction_type == \"epsilon\":\n                    loss = F.mse_loss(model_output, noise)  # this could have different weights!\n                elif args.prediction_type == \"sample\":\n                    alpha_t = _extract_into_tensor(\n                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)\n                    )\n                    snr_weights = alpha_t / (1 - alpha_t)\n                    loss = snr_weights * F.mse_loss(\n                        model_output, clean_images, reduction=\"none\"\n                    )  # use SNR weighting from distillation paper\n                    loss = loss.mean()\n                else:\n                    raise ValueError(f\"Unsupported prediction type: {args.prediction_type}\")\n\n                accelerator.backward(loss)\n\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(model.parameters(), 1.0)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_model.step(model.parameters())\n                progress_bar.update(1)\n                global_step += 1\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], 
\"step\": global_step}\n            if args.use_ema:\n                logs[\"ema_decay\"] = ema_model.cur_decay_value\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n        progress_bar.close()\n\n        accelerator.wait_for_everyone()\n\n        # Generate sample images for visual inspection\n        if accelerator.is_main_process:\n            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:\n                unet = accelerator.unwrap_model(model)\n\n                if args.use_ema:\n                    ema_model.store(unet.parameters())\n                    ema_model.copy_to(unet.parameters())\n\n                pipeline = DDPMPipeline(\n                    unet=unet,\n                    scheduler=noise_scheduler,\n                )\n\n                generator = torch.Generator(device=pipeline.device).manual_seed(0)\n                # run pipeline in inference (sample random noise and denoise)\n                images = pipeline(\n                    generator=generator,\n                    batch_size=args.eval_batch_size,\n                    num_inference_steps=args.ddpm_num_inference_steps,\n                    output_type=\"np\",\n                ).images\n\n                if args.use_ema:\n                    ema_model.restore(unet.parameters())\n\n                # denormalize the images and save to tensorboard\n                images_processed = (images * 255).round().astype(\"uint8\")\n\n                if args.logger == \"tensorboard\":\n                    if is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n                        tracker = accelerator.get_tracker(\"tensorboard\", unwrap=True)\n                    else:\n                        tracker = accelerator.get_tracker(\"tensorboard\")\n                    tracker.add_images(\"test_samples\", images_processed.transpose(0, 3, 1, 2), epoch)\n                elif args.logger == \"wandb\":\n                    # 
Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files\n                    accelerator.get_tracker(\"wandb\").log(\n                        {\"test_samples\": [wandb.Image(img) for img in images_processed], \"epoch\": epoch},\n                        step=global_step,\n                    )\n\n            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:\n                # save the model\n                unet = accelerator.unwrap_model(model)\n\n                if args.use_ema:\n                    ema_model.store(unet.parameters())\n                    ema_model.copy_to(unet.parameters())\n\n                pipeline = DDPMPipeline(\n                    unet=unet,\n                    scheduler=noise_scheduler,\n                )\n\n                pipeline.save_pretrained(args.output_dir)\n\n                if args.use_ema:\n                    ema_model.restore(unet.parameters())\n\n                if args.push_to_hub:\n                    upload_folder(\n                        repo_id=repo_id,\n                        folder_path=args.output_dir,\n                        commit_message=f\"Epoch {epoch}\",\n                        ignore_patterns=[\"step_*\", \"epoch_*\"],\n                    )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/pixart/.gitignore",
    "content": "images/\noutput/"
  },
  {
    "path": "examples/research_projects/pixart/controlnet_pixart_alpha.py",
    "content": "from typing import Any, Dict, Optional\n\nimport torch\nfrom torch import nn\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models import PixArtTransformer2DModel\nfrom diffusers.models.attention import BasicTransformerBlock\nfrom diffusers.models.modeling_outputs import Transformer2DModelOutput\nfrom diffusers.models.modeling_utils import ModelMixin\n\n\nclass PixArtControlNetAdapterBlock(nn.Module):\n    def __init__(\n        self,\n        block_index,\n        # taken from PixArtTransformer2DModel\n        num_attention_heads: int = 16,\n        attention_head_dim: int = 72,\n        dropout: float = 0.0,\n        cross_attention_dim: Optional[int] = 1152,\n        attention_bias: bool = True,\n        activation_fn: str = \"gelu-approximate\",\n        num_embeds_ada_norm: Optional[int] = 1000,\n        upcast_attention: bool = False,\n        norm_type: str = \"ada_norm_single\",\n        norm_elementwise_affine: bool = False,\n        norm_eps: float = 1e-6,\n        attention_type: str | None = \"default\",\n    ):\n        super().__init__()\n\n        self.block_index = block_index\n        self.inner_dim = num_attention_heads * attention_head_dim\n\n        # the first block has a zero before layer\n        if self.block_index == 0:\n            self.before_proj = nn.Linear(self.inner_dim, self.inner_dim)\n            nn.init.zeros_(self.before_proj.weight)\n            nn.init.zeros_(self.before_proj.bias)\n\n        self.transformer_block = BasicTransformerBlock(\n            self.inner_dim,\n            num_attention_heads,\n            attention_head_dim,\n            dropout=dropout,\n            cross_attention_dim=cross_attention_dim,\n            activation_fn=activation_fn,\n            num_embeds_ada_norm=num_embeds_ada_norm,\n            attention_bias=attention_bias,\n            upcast_attention=upcast_attention,\n            norm_type=norm_type,\n            
norm_elementwise_affine=norm_elementwise_affine,\n            norm_eps=norm_eps,\n            attention_type=attention_type,\n        )\n\n        self.after_proj = nn.Linear(self.inner_dim, self.inner_dim)\n        nn.init.zeros_(self.after_proj.weight)\n        nn.init.zeros_(self.after_proj.bias)\n\n    def train(self, mode: bool = True):\n        self.transformer_block.train(mode)\n\n        if self.block_index == 0:\n            self.before_proj.train(mode)\n\n        self.after_proj.train(mode)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        controlnet_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        added_cond_kwargs: Dict[str, torch.Tensor] = None,\n        cross_attention_kwargs: Dict[str, Any] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n    ):\n        if self.block_index == 0:\n            controlnet_states = self.before_proj(controlnet_states)\n            controlnet_states = hidden_states + controlnet_states\n\n        controlnet_states_down = self.transformer_block(\n            hidden_states=controlnet_states,\n            encoder_hidden_states=encoder_hidden_states,\n            timestep=timestep,\n            added_cond_kwargs=added_cond_kwargs,\n            cross_attention_kwargs=cross_attention_kwargs,\n            attention_mask=attention_mask,\n            encoder_attention_mask=encoder_attention_mask,\n            class_labels=None,\n        )\n\n        controlnet_states_left = self.after_proj(controlnet_states_down)\n\n        return controlnet_states_left, controlnet_states_down\n\n\nclass PixArtControlNetAdapterModel(ModelMixin, ConfigMixin):\n    # N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer\n    @register_to_config\n    def __init__(self, num_layers=13) -> None:\n        
super().__init__()\n\n        self.num_layers = num_layers\n\n        self.controlnet_blocks = nn.ModuleList(\n            [PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)]\n        )\n\n    @classmethod\n    def from_transformer(cls, transformer: PixArtTransformer2DModel):\n        control_net = PixArtControlNetAdapterModel()\n\n        # copy the weights of the first `num_layers` blocks from the transformer\n        for depth in range(control_net.num_layers):\n            control_net.controlnet_blocks[depth].transformer_block.load_state_dict(\n                transformer.transformer_blocks[depth].state_dict()\n            )\n\n        return control_net\n\n    def train(self, mode: bool = True):\n        for block in self.controlnet_blocks:\n            block.train(mode)\n\n\nclass PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):\n    def __init__(\n        self,\n        transformer: PixArtTransformer2DModel,\n        controlnet: PixArtControlNetAdapterModel,\n        blocks_num=13,\n        init_from_transformer=False,\n        training=False,\n    ):\n        super().__init__()\n\n        self.blocks_num = blocks_num\n        self.gradient_checkpointing = False\n        self.register_to_config(**transformer.config)\n        self.training = training\n\n        if init_from_transformer:\n            # copy the weights of the first `blocks_num` transformer blocks into the controlnet blocks, in place;\n            # `from_transformer` takes only the transformer and returns a new model, so it cannot be used here\n            for depth in range(self.blocks_num):\n                controlnet.controlnet_blocks[depth].transformer_block.load_state_dict(\n                    transformer.transformer_blocks[depth].state_dict()\n                )\n\n        self.transformer = transformer\n        self.controlnet = controlnet\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        timestep: Optional[torch.LongTensor] = None,\n        controlnet_cond: Optional[torch.Tensor] = None,\n        added_cond_kwargs: Dict[str, torch.Tensor] = None,\n        cross_attention_kwargs: Dict[str, Any] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        
encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n    ):\n        if self.transformer.use_additional_conditions and added_cond_kwargs is None:\n            raise ValueError(\"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.\")\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.\n        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.\n        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None and attention_mask.ndim == 2:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 1. 
Input\n        batch_size = hidden_states.shape[0]\n        height, width = (\n            hidden_states.shape[-2] // self.transformer.config.patch_size,\n            hidden_states.shape[-1] // self.transformer.config.patch_size,\n        )\n        hidden_states = self.transformer.pos_embed(hidden_states)\n\n        timestep, embedded_timestep = self.transformer.adaln_single(\n            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype\n        )\n\n        if self.transformer.caption_projection is not None:\n            encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)\n            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])\n\n        controlnet_states_down = None\n        if controlnet_cond is not None:\n            controlnet_states_down = self.transformer.pos_embed(controlnet_cond)\n\n        # 2. Blocks\n        for block_index, block in enumerate(self.transformer.transformer_blocks):\n            if torch.is_grad_enabled() and self.gradient_checkpointing:\n                # rc todo: for training and gradient checkpointing\n                raise NotImplementedError(\n                    \"Gradient checkpointing is not yet supported for the controlnet transformer model.\"\n                )\n            else:\n                # the control nets are only used for the blocks 1 to self.blocks_num\n                if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None:\n                    controlnet_states_left, controlnet_states_down = 
self.controlnet.controlnet_blocks[\n                        block_index - 1\n                    ](\n                        hidden_states=hidden_states,  # used only in the first block\n                        controlnet_states=controlnet_states_down,\n                        encoder_hidden_states=encoder_hidden_states,\n                        timestep=timestep,\n                        added_cond_kwargs=added_cond_kwargs,\n                        cross_attention_kwargs=cross_attention_kwargs,\n                        attention_mask=attention_mask,\n                        encoder_attention_mask=encoder_attention_mask,\n                    )\n\n                    hidden_states = hidden_states + controlnet_states_left\n\n                hidden_states = block(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    encoder_hidden_states=encoder_hidden_states,\n                    encoder_attention_mask=encoder_attention_mask,\n                    timestep=timestep,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    class_labels=None,\n                )\n\n        # 3. 
Output\n        shift, scale = (\n            self.transformer.scale_shift_table[None]\n            + embedded_timestep[:, None].to(self.transformer.scale_shift_table.device)\n        ).chunk(2, dim=1)\n        hidden_states = self.transformer.norm_out(hidden_states)\n        # Modulation\n        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)\n        hidden_states = self.transformer.proj_out(hidden_states)\n        hidden_states = hidden_states.squeeze(1)\n\n        # unpatchify\n        hidden_states = hidden_states.reshape(\n            shape=(\n                -1,\n                height,\n                width,\n                self.transformer.config.patch_size,\n                self.transformer.config.patch_size,\n                self.transformer.out_channels,\n            )\n        )\n        hidden_states = torch.einsum(\"nhwpqc->nchpwq\", hidden_states)\n        output = hidden_states.reshape(\n            shape=(\n                -1,\n                self.transformer.out_channels,\n                height * self.transformer.config.patch_size,\n                width * self.transformer.config.patch_size,\n            )\n        )\n\n        if not return_dict:\n            return (output,)\n\n        return Transformer2DModelOutput(sample=output)\n"
  },
  {
    "path": "examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py",
    "content": "# Copyright 2025 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport html\nimport inspect\nimport re\nimport urllib.parse as ul\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL\nimport torch\nfrom controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel\nfrom transformers import T5EncoderModel, T5Tokenizer\n\nfrom diffusers.image_processor import PipelineImageInput, PixArtImageProcessor\nfrom diffusers.models import AutoencoderKL, PixArtTransformer2DModel\nfrom diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput\nfrom diffusers.schedulers import DPMSolverMultistepScheduler\nfrom diffusers.utils import (\n    BACKENDS_MAPPING,\n    deprecate,\n    is_bs4_available,\n    is_ftfy_available,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nif is_bs4_available():\n    from bs4 import BeautifulSoup\n\nif is_ftfy_available():\n    import ftfy\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n        >>> from diffusers import PixArtAlphaPipeline\n\n        >>> # You can replace the checkpoint id with \"PixArt-alpha/PixArt-XL-2-512x512\" too.\n        >>> pipe = 
PixArtAlphaPipeline.from_pretrained(\"PixArt-alpha/PixArt-XL-2-1024-MS\", torch_dtype=torch.float16)\n        >>> # Enable memory optimizations.\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> prompt = \"A small cactus with a happy face in the Sahara desert.\"\n        >>> image = pipe(prompt).images[0]\n        ```\n\"\"\"\n\nASPECT_RATIO_1024_BIN = {\n    \"0.25\": [512.0, 2048.0],\n    \"0.28\": [512.0, 1856.0],\n    \"0.32\": [576.0, 1792.0],\n    \"0.33\": [576.0, 1728.0],\n    \"0.35\": [576.0, 1664.0],\n    \"0.4\": [640.0, 1600.0],\n    \"0.42\": [640.0, 1536.0],\n    \"0.48\": [704.0, 1472.0],\n    \"0.5\": [704.0, 1408.0],\n    \"0.52\": [704.0, 1344.0],\n    \"0.57\": [768.0, 1344.0],\n    \"0.6\": [768.0, 1280.0],\n    \"0.68\": [832.0, 1216.0],\n    \"0.72\": [832.0, 1152.0],\n    \"0.78\": [896.0, 1152.0],\n    \"0.82\": [896.0, 1088.0],\n    \"0.88\": [960.0, 1088.0],\n    \"0.94\": [960.0, 1024.0],\n    \"1.0\": [1024.0, 1024.0],\n    \"1.07\": [1024.0, 960.0],\n    \"1.13\": [1088.0, 960.0],\n    \"1.21\": [1088.0, 896.0],\n    \"1.29\": [1152.0, 896.0],\n    \"1.38\": [1152.0, 832.0],\n    \"1.46\": [1216.0, 832.0],\n    \"1.67\": [1280.0, 768.0],\n    \"1.75\": [1344.0, 768.0],\n    \"2.0\": [1408.0, 704.0],\n    \"2.09\": [1472.0, 704.0],\n    \"2.4\": [1536.0, 640.0],\n    \"2.5\": [1600.0, 640.0],\n    \"3.0\": [1728.0, 576.0],\n    \"4.0\": [2048.0, 512.0],\n}\n\nASPECT_RATIO_512_BIN = {\n    \"0.25\": [256.0, 1024.0],\n    \"0.28\": [256.0, 928.0],\n    \"0.32\": [288.0, 896.0],\n    \"0.33\": [288.0, 864.0],\n    \"0.35\": [288.0, 832.0],\n    \"0.4\": [320.0, 800.0],\n    \"0.42\": [320.0, 768.0],\n    \"0.48\": [352.0, 736.0],\n    \"0.5\": [352.0, 704.0],\n    \"0.52\": [352.0, 672.0],\n    \"0.57\": [384.0, 672.0],\n    \"0.6\": [384.0, 640.0],\n    \"0.68\": [416.0, 608.0],\n    \"0.72\": [416.0, 576.0],\n    \"0.78\": [448.0, 576.0],\n    \"0.82\": [448.0, 544.0],\n    \"0.88\": [480.0, 544.0],\n    \"0.94\": [480.0, 
512.0],\n    \"1.0\": [512.0, 512.0],\n    \"1.07\": [512.0, 480.0],\n    \"1.13\": [544.0, 480.0],\n    \"1.21\": [544.0, 448.0],\n    \"1.29\": [576.0, 448.0],\n    \"1.38\": [576.0, 416.0],\n    \"1.46\": [608.0, 416.0],\n    \"1.67\": [640.0, 384.0],\n    \"1.75\": [672.0, 384.0],\n    \"2.0\": [704.0, 352.0],\n    \"2.09\": [736.0, 352.0],\n    \"2.4\": [768.0, 320.0],\n    \"2.5\": [800.0, 320.0],\n    \"3.0\": [864.0, 288.0],\n    \"4.0\": [1024.0, 256.0],\n}\n\nASPECT_RATIO_256_BIN = {\n    \"0.25\": [128.0, 512.0],\n    \"0.28\": [128.0, 464.0],\n    \"0.32\": [144.0, 448.0],\n    \"0.33\": [144.0, 432.0],\n    \"0.35\": [144.0, 416.0],\n    \"0.4\": [160.0, 400.0],\n    \"0.42\": [160.0, 384.0],\n    \"0.48\": [176.0, 368.0],\n    \"0.5\": [176.0, 352.0],\n    \"0.52\": [176.0, 336.0],\n    \"0.57\": [192.0, 336.0],\n    \"0.6\": [192.0, 320.0],\n    \"0.68\": [208.0, 304.0],\n    \"0.72\": [208.0, 288.0],\n    \"0.78\": [224.0, 288.0],\n    \"0.82\": [224.0, 272.0],\n    \"0.88\": [240.0, 272.0],\n    \"0.94\": [240.0, 256.0],\n    \"1.0\": [256.0, 256.0],\n    \"1.07\": [256.0, 240.0],\n    \"1.13\": [272.0, 240.0],\n    \"1.21\": [272.0, 224.0],\n    \"1.29\": [288.0, 224.0],\n    \"1.38\": [288.0, 208.0],\n    \"1.46\": [304.0, 208.0],\n    \"1.67\": [320.0, 192.0],\n    \"1.75\": [336.0, 192.0],\n    \"2.0\": [352.0, 176.0],\n    \"2.09\": [368.0, 176.0],\n    \"2.4\": [384.0, 160.0],\n    \"2.5\": [400.0, 160.0],\n    \"3.0\": [432.0, 144.0],\n    \"4.0\": [512.0, 128.0],\n}\n\n\ndef get_closest_hw(width, height, image_size):\n    if image_size == 1024:\n        aspect_ratio_bin = ASPECT_RATIO_1024_BIN\n    elif image_size == 512:\n        aspect_ratio_bin = ASPECT_RATIO_512_BIN\n    else:\n        raise ValueError(\"Invalid image size\")\n\n    height, width = PixArtImageProcessor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)\n\n    return width, height\n\n\n# Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass PixArtAlphaControlnetPipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline for text-to-image generation using PixArt-Alpha.\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`T5EncoderModel`]):\n            Frozen text-encoder. PixArt-Alpha uses\n            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the\n            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.\n        tokenizer (`T5Tokenizer`):\n            Tokenizer of class\n            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).\n        transformer ([`PixArtTransformer2DModel`]):\n            A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n    \"\"\"\n\n    bad_punct_regex = re.compile(\n        r\"[\"\n        + \"#®•©™&@·º½¾¿¡§~\"\n        + r\"\\)\"\n        + r\"\\(\"\n        + r\"\\]\"\n        + r\"\\[\"\n        + r\"\\}\"\n        + r\"\\{\"\n        + r\"\\|\"\n        + \"\\\\\"\n        + r\"\\/\"\n        + r\"\\*\"\n        + r\"]{1,}\"\n    )  # noqa\n\n    _optional_components = [\"tokenizer\", \"text_encoder\"]\n    model_cpu_offload_seq = \"text_encoder->transformer->vae\"\n\n    def __init__(\n        self,\n        tokenizer: T5Tokenizer,\n        text_encoder: T5EncoderModel,\n        vae: AutoencoderKL,\n        transformer: PixArtTransformer2DModel,\n        controlnet: PixArtControlNetAdapterModel,\n        scheduler: DPMSolverMultistepScheduler,\n    ):\n        super().__init__()\n\n        # change to the controlnet transformer model\n        transformer = 
PixArtControlNetTransformerModel(transformer=transformer, controlnet=controlnet)\n\n        self.register_modules(\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            vae=vae,\n            transformer=transformer,\n            scheduler=scheduler,\n            controlnet=controlnet,\n        )\n\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        do_classifier_free_guidance: bool = True,\n        negative_prompt: str = \"\",\n        num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        clean_caption: bool = False,\n        max_sequence_length: int = 120,\n        **kwargs,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`\n                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
For\n                PixArt-Alpha, this should be \"\".\n            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):\n                whether to use classifier free guidance or not\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                number of images that should be generated per prompt\n            device (`torch.device`, *optional*):\n                torch device to place the resulting embeddings on\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. For PixArt-Alpha, it should be the embeddings of the \"\"\n                string.\n            clean_caption (`bool`, defaults to `False`):\n                If `True`, the function will preprocess and clean the provided caption before encoding.\n            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.\n        \"\"\"\n\n        if \"mask_feature\" in kwargs:\n            deprecation_message = \"The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version.\"\n            deprecate(\"mask_feature\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        if device is None:\n            device = self._execution_device\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        # See Section 3.1. 
of the paper.\n        max_length = max_sequence_length\n\n        if prompt_embeds is None:\n            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                add_special_tokens=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])\n                logger.warning(\n                    \"The following part of your input was truncated because T5 can only handle sequences up to\"\n                    f\" {max_length} tokens: {removed_text}\"\n                )\n\n            prompt_attention_mask = text_inputs.attention_mask\n            prompt_attention_mask = prompt_attention_mask.to(device)\n\n            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)\n            prompt_embeds = prompt_embeds[0]\n\n        if self.text_encoder is not None:\n            dtype = self.text_encoder.dtype\n        elif self.transformer is not None:\n            dtype = self.transformer.controlnet.dtype\n        else:\n            dtype = None\n\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed 
* num_images_per_prompt, seq_len, -1)\n        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)\n        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens = [negative_prompt] * batch_size\n            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_attention_mask=True,\n                add_special_tokens=True,\n                return_tensors=\"pt\",\n            )\n            negative_prompt_attention_mask = uncond_input.attention_mask\n            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)\n            negative_prompt_attention_mask = 
negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)\n        else:\n            negative_prompt_embeds = None\n            negative_prompt_attention_mask = None\n\n        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        height,\n        width,\n        negative_prompt,\n        callback_steps,\n        image=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        prompt_attention_mask=None,\n        negative_prompt_attention_mask=None,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                
f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and prompt_attention_mask is None:\n            raise ValueError(\"Must provide `prompt_attention_mask` when specifying `prompt_embeds`.\")\n\n        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:\n            raise ValueError(\"Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.\")\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:\n                raise ValueError(\n                    \"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`\"\n                    f\" {negative_prompt_attention_mask.shape}.\"\n                )\n\n        if image is not None:\n            self.check_image(image, prompt, prompt_embeds)\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) 
and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing\n    def _text_preprocessing(self, text, clean_caption=False):\n        if clean_caption and not is_bs4_available():\n            logger.warning(BACKENDS_MAPPING[\"bs4\"][-1].format(\"Setting `clean_caption=True`\"))\n            logger.warning(\"Setting `clean_caption` to False...\")\n            clean_caption = False\n\n        if clean_caption and not is_ftfy_available():\n            logger.warning(BACKENDS_MAPPING[\"ftfy\"][-1].format(\"Setting `clean_caption=True`\"))\n            logger.warning(\"Setting `clean_caption` to False...\")\n            clean_caption = False\n\n        if not isinstance(text, (tuple, list)):\n            text = [text]\n\n        def process(text: str):\n            if clean_caption:\n                text = self._clean_caption(text)\n                text = self._clean_caption(text)\n            else:\n                text = text.lower().strip()\n            return text\n\n        return [process(t) for t in text]\n\n    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption\n    def _clean_caption(self, caption):\n        caption = str(caption)\n        caption = ul.unquote_plus(caption)\n        caption = caption.strip().lower()\n        caption = re.sub(\"<person>\", \"person\", caption)\n        # urls:\n        caption = re.sub(\n            r\"\\b((?:https?:(?:\\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\\w/-]*\\b\\/?(?!@)))\",  # noqa\n            \"\",\n            caption,\n        )  # regex for urls\n        caption = re.sub(\n            r\"\\b((?:www:(?:\\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\\w/-]*\\b\\/?(?!@)))\",  # noqa\n            \"\",\n            caption,\n        )  # regex for urls\n        # html:\n        caption = 
BeautifulSoup(caption, features=\"html.parser\").text\n\n        # @<nickname>\n        caption = re.sub(r\"@[\\w\\d]+\\b\", \"\", caption)\n\n        # 31C0—31EF CJK Strokes\n        # 31F0—31FF Katakana Phonetic Extensions\n        # 3200—32FF Enclosed CJK Letters and Months\n        # 3300—33FF CJK Compatibility\n        # 3400—4DBF CJK Unified Ideographs Extension A\n        # 4DC0—4DFF Yijing Hexagram Symbols\n        # 4E00—9FFF CJK Unified Ideographs\n        caption = re.sub(r\"[\\u31c0-\\u31ef]+\", \"\", caption)\n        caption = re.sub(r\"[\\u31f0-\\u31ff]+\", \"\", caption)\n        caption = re.sub(r\"[\\u3200-\\u32ff]+\", \"\", caption)\n        caption = re.sub(r\"[\\u3300-\\u33ff]+\", \"\", caption)\n        caption = re.sub(r\"[\\u3400-\\u4dbf]+\", \"\", caption)\n        caption = re.sub(r\"[\\u4dc0-\\u4dff]+\", \"\", caption)\n        caption = re.sub(r\"[\\u4e00-\\u9fff]+\", \"\", caption)\n        #######################################################\n\n        # все виды тире / all types of dash --> \"-\"\n        caption = re.sub(\n            r\"[\\u002D\\u058A\\u05BE\\u1400\\u1806\\u2010-\\u2015\\u2E17\\u2E1A\\u2E3A\\u2E3B\\u2E40\\u301C\\u3030\\u30A0\\uFE31\\uFE32\\uFE58\\uFE63\\uFF0D]+\",  # noqa\n            \"-\",\n            caption,\n        )\n\n        # кавычки к одному стандарту\n        caption = re.sub(r\"[`´«»“”¨]\", '\"', caption)\n        caption = re.sub(r\"[‘’]\", \"'\", caption)\n\n        # &quot;\n        caption = re.sub(r\"&quot;?\", \"\", caption)\n        # &amp\n        caption = re.sub(r\"&amp\", \"\", caption)\n\n        # ip addresses:\n        caption = re.sub(r\"\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\", \" \", caption)\n\n        # article ids:\n        caption = re.sub(r\"\\d:\\d\\d\\s+$\", \"\", caption)\n\n        # \\n\n        caption = re.sub(r\"\\\\n\", \" \", caption)\n\n        # \"#123\"\n        caption = re.sub(r\"#\\d{1,3}\\b\", \"\", caption)\n        # \"#12345..\"\n        caption = 
re.sub(r\"#\\d{5,}\\b\", \"\", caption)\n        # \"123456..\"\n        caption = re.sub(r\"\\b\\d{6,}\\b\", \"\", caption)\n        # filenames:\n        caption = re.sub(r\"[\\S]+\\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)\", \"\", caption)\n\n        #\n        caption = re.sub(r\"[\\\"\\']{2,}\", r'\"', caption)  # \"\"\"AUSVERKAUFT\"\"\"\n        caption = re.sub(r\"[\\.]{2,}\", r\" \", caption)  # \"\"\"AUSVERKAUFT\"\"\"\n\n        caption = re.sub(self.bad_punct_regex, r\" \", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT\n        caption = re.sub(r\"\\s+\\.\\s+\", r\" \", caption)  # \" . \"\n\n        # this-is-my-cute-cat / this_is_my_cute_cat\n        regex2 = re.compile(r\"(?:\\-|\\_)\")\n        if len(re.findall(regex2, caption)) > 3:\n            caption = re.sub(regex2, \" \", caption)\n\n        caption = ftfy.fix_text(caption)\n        caption = html.unescape(html.unescape(caption))\n\n        caption = re.sub(r\"\\b[a-zA-Z]{1,3}\\d{3,15}\\b\", \"\", caption)  # jc6640\n        caption = re.sub(r\"\\b[a-zA-Z]+\\d+[a-zA-Z]+\\b\", \"\", caption)  # jc6640vc\n        caption = re.sub(r\"\\b\\d+[a-zA-Z]+\\d+\\b\", \"\", caption)  # 6640vc231\n\n        caption = re.sub(r\"(worldwide\\s+)?(free\\s+)?shipping\", \"\", caption)\n        caption = re.sub(r\"(free\\s)?download(\\sfree)?\", \"\", caption)\n        caption = re.sub(r\"\\bclick\\b\\s(?:for|on)\\s\\w+\", \"\", caption)\n        caption = re.sub(r\"\\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\\simage[s]?)?\", \"\", caption)\n        caption = re.sub(r\"\\bpage\\s+\\d+\\b\", \"\", caption)\n\n        caption = re.sub(r\"\\b\\d*[a-zA-Z]+\\d+[a-zA-Z]+\\d+[a-zA-Z\\d]*\\b\", r\" \", caption)  # j2d1a2a...\n\n        caption = re.sub(r\"\\b\\d+\\.?\\d*[xх×]\\d+\\.?\\d*\\b\", \"\", caption)\n\n        caption = re.sub(r\"\\b\\s+\\:\\s+\", r\": \", caption)\n        caption = re.sub(r\"(\\D[,\\./])\\b\", r\"\\1 \", caption)\n        caption = re.sub(r\"\\s+\", \" \", caption)\n\n        
caption.strip()\n\n        caption = re.sub(r\"^[\\\"\\']([\\w\\W]+)[\\\"\\']$\", r\"\\1\", caption)\n        caption = re.sub(r\"^[\\'\\_,\\-\\:;]\", r\"\", caption)\n        caption = re.sub(r\"[\\'\\_,\\-\\:\\-\\+]$\", r\"\", caption)\n        caption = re.sub(r\"^\\.\\S+$\", \"\", caption)\n\n        return caption.strip()\n\n    def prepare_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # based on pipeline_pixart_inpaiting.py\n    def prepare_image_latents(self, image, device, dtype):\n        image = image.to(device=device, dtype=dtype)\n\n        image_latents = self.vae.encode(image).latent_dist.sample()\n        image_latents = image_latents * self.vae.config.scaling_factor\n        return image_latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n           
 raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        negative_prompt: str = \"\",\n        num_inference_steps: int = 20,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        guidance_scale: float = 4.5,\n        num_images_per_prompt: Optional[int] = 1,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        prompt_attention_mask: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_attention_mask: Optional[torch.Tensor] = None,\n        # rc todo: controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        # rc todo: control_guidance_start = 0.0,\n        # rc todo: control_guidance_end = 1.0,\n        output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: int = 1,\n        clean_caption: bool = True,\n        use_resolution_binning: bool = 
True,\n        max_sequence_length: int = 120,\n        **kwargs,\n    ) -> Union[ImagePipelineOutput, Tuple]:\n        \"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.\n                instead.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            num_inference_steps (`int`, *optional*, defaults to 100):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            guidance_scale (`float`, *optional*, defaults to 4.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). 
Guidance scale is enabled by setting `guidance_scale >\n                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size):\n                The width in pixels of the generated image.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be \"\". If not\n                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.\n            negative_prompt_attention_mask (`torch.Tensor`, *optional*):\n                Pre-generated attention mask for negative text embeddings.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generate image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. If not specified, the callback will be\n                called at every step.\n            clean_caption (`bool`, *optional*, defaults to `True`):\n                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to\n                be installed. 
If the dependencies are not installed, the embeddings will be created from the raw\n                prompt.\n            use_resolution_binning (`bool` defaults to `True`):\n                If set to `True`, the requested height and width are first mapped to the closest resolutions using\n                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to\n                the requested resolution. Useful for generating non-square images.\n            max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.ImagePipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is\n                returned where the first element is a list with the generated images\n        \"\"\"\n        if \"mask_feature\" in kwargs:\n            deprecation_message = \"The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version.\"\n            deprecate(\"mask_feature\", \"1.0.0\", deprecation_message, standard_warn=False)\n        # 1. Check inputs. 
Raise error if not correct\n        height = height or self.transformer.config.sample_size * self.vae_scale_factor\n        width = width or self.transformer.config.sample_size * self.vae_scale_factor\n        if use_resolution_binning:\n            if self.transformer.config.sample_size == 128:\n                aspect_ratio_bin = ASPECT_RATIO_1024_BIN\n            elif self.transformer.config.sample_size == 64:\n                aspect_ratio_bin = ASPECT_RATIO_512_BIN\n            elif self.transformer.config.sample_size == 32:\n                aspect_ratio_bin = ASPECT_RATIO_256_BIN\n            else:\n                raise ValueError(\"Invalid sample size\")\n            orig_height, orig_width = height, width\n            height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)\n\n        self.check_inputs(\n            prompt,\n            height,\n            width,\n            negative_prompt,\n            callback_steps,\n            image,\n            prompt_embeds,\n            negative_prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_attention_mask,\n        )\n\n        # 2. Default height and width to transformer\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. 
Encode input prompt\n        (\n            prompt_embeds,\n            prompt_attention_mask,\n            negative_prompt_embeds,\n            negative_prompt_attention_mask,\n        ) = self.encode_prompt(\n            prompt,\n            do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            device=device,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            prompt_attention_mask=prompt_attention_mask,\n            negative_prompt_attention_mask=negative_prompt_attention_mask,\n            clean_caption=clean_caption,\n            max_sequence_length=max_sequence_length,\n        )\n        if do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)\n            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)\n\n        # 4. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(\n            self.scheduler, num_inference_steps, device, timesteps, sigmas\n        )\n\n        # 4.1 Prepare image\n        image_latents = None\n        if image is not None:\n            image = self.prepare_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=self.transformer.controlnet.dtype,\n                do_classifier_free_guidance=do_classifier_free_guidance,\n            )\n\n            image_latents = self.prepare_image_latents(image, device, self.transformer.controlnet.dtype)\n\n        # 5. 
Prepare latents.\n        latent_channels = self.transformer.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            latent_channels,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 6.1 Prepare micro-conditions.\n        added_cond_kwargs = {\"resolution\": None, \"aspect_ratio\": None}\n        if self.transformer.config.sample_size == 128:\n            resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)\n            aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)\n            resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)\n            aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)\n\n            if do_classifier_free_guidance:\n                resolution = torch.cat([resolution, resolution], dim=0)\n                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)\n\n            added_cond_kwargs = {\"resolution\": resolution, \"aspect_ratio\": aspect_ratio}\n\n        # 7. 
Denoising loop\n        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                current_timestep = t\n                if not torch.is_tensor(current_timestep):\n                    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can\n                    # This would be a good case for the `match` statement (Python 3.10+)\n                    is_mps = latent_model_input.device.type == \"mps\"\n                    is_npu = latent_model_input.device.type == \"npu\"\n                    if isinstance(current_timestep, float):\n                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64\n                    else:\n                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64\n                    current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)\n                elif len(current_timestep.shape) == 0:\n                    current_timestep = current_timestep[None].to(latent_model_input.device)\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                current_timestep = current_timestep.expand(latent_model_input.shape[0])\n\n                # predict noise model_output\n                noise_pred = self.transformer(\n                    latent_model_input,\n                    encoder_hidden_states=prompt_embeds,\n                    encoder_attention_mask=prompt_attention_mask,\n                    timestep=current_timestep,\n                    controlnet_cond=image_latents,\n                    # rc todo: 
controlnet_conditioning_scale=1.0,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # learned sigma\n                if self.transformer.config.out_channels // 2 == latent_channels:\n                    noise_pred = noise_pred.chunk(2, dim=1)[0]\n                else:\n                    noise_pred = noise_pred\n\n                # compute previous image: x_t -> x_t-1\n                if num_inference_steps == 1:\n                    # For DMD one step sampling: https://huggingface.co/papers/2311.18828\n                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample\n                else:\n                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n            if use_resolution_binning:\n                image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)\n        else:\n            image = latents\n\n        if not output_type == \"latent\":\n            image = self.image_processor.postprocess(image, 
output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return ImagePipelineOutput(images=image)\n"
  },
  {
    "path": "examples/research_projects/pixart/requirements.txt",
    "content": "transformers\nSentencePiece\ntorchvision\ncontrolnet-aux\ndatasets\n# wandb"
  },
  {
    "path": "examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py",
    "content": "import torch\nimport torchvision.transforms as T\nfrom controlnet_aux import HEDdetector\n\nfrom diffusers.utils import load_image\nfrom examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel\nfrom examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline\n\n\ncontrolnet_repo_id = \"raulc0399/pixart-alpha-hed-controlnet\"\n\nweight_dtype = torch.float16\nimage_size = 1024\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ntorch.manual_seed(0)\n\n# load controlnet\ncontrolnet = PixArtControlNetAdapterModel.from_pretrained(\n    controlnet_repo_id,\n    torch_dtype=weight_dtype,\n    use_safetensors=True,\n).to(device)\n\npipe = PixArtAlphaControlnetPipeline.from_pretrained(\n    \"PixArt-alpha/PixArt-XL-2-1024-MS\",\n    controlnet=controlnet,\n    torch_dtype=weight_dtype,\n    use_safetensors=True,\n).to(device)\n\nimages_path = \"images\"\ncontrol_image_file = \"0_7.jpg\"\n\n# prompt = \"cinematic photo of superman in action . 35mm photograph, film, bokeh, professional, 4k, highly detailed\"\n# prompt = \"yellow modern car, city in background, beautiful rainy day\"\n# prompt = \"modern villa, clear sky, suny day . 35mm photograph, film, bokeh, professional, 4k, highly detailed\"\n# prompt = \"robot dog toy in park . 
35mm photograph, film, bokeh, professional, 4k, highly detailed\"\n# prompt = \"purple car, on highway, beautiful sunny day\"\n# prompt = \"realistic photo of a loving couple standing in the open kitchen of the living room, cooking .\"\nprompt = \"battleship in space, galaxy in background\"\n\ncontrol_image_name = control_image_file.split(\".\")[0]\n\ncontrol_image = load_image(f\"{images_path}/{control_image_file}\")\nprint(control_image.size)\n# PIL reports size as (width, height)\nwidth, height = control_image.size\n\nhed = HEDdetector.from_pretrained(\"lllyasviel/Annotators\")\n\ncondition_transform = T.Compose(\n    [\n        T.Lambda(lambda img: img.convert(\"RGB\")),\n        T.CenterCrop([image_size, image_size]),\n    ]\n)\n\ncontrol_image = condition_transform(control_image)\nhed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)\n\nhed_edge.save(f\"{images_path}/{control_image_name}_hed.jpg\")\n\n# run pipeline\nwith torch.no_grad():\n    out = pipe(\n        prompt=prompt,\n        image=hed_edge,\n        num_inference_steps=14,\n        guidance_scale=4.5,\n        height=image_size,\n        width=image_size,\n    )\n\n    out.images[0].save(f\"{images_path}/{control_image_name}_output.jpg\")\n"
  },
  {
    "path": "examples/research_projects/pixart/train_controlnet_hf_diffusers.sh",
    "content": "#!/bin/bash\n\n# run\n# accelerate config\n\n# check with\n# accelerate env\n\nexport MODEL_DIR=\"PixArt-alpha/PixArt-XL-2-512x512\"\nexport OUTPUT_DIR=\"output/pixart-controlnet-hf-diffusers-test\"\n\naccelerate launch ./train_pixart_controlnet_hf.py --mixed_precision=\"fp16\" \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --resolution=512 \\\n --learning_rate=1e-5 \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --report_to=\"wandb\" \\\n --seed=42 \\\n --dataloader_num_workers=8\n#  --lr_scheduler=\"cosine\" --lr_warmup_steps=0 \\\n"
  },
  {
    "path": "examples/research_projects/pixart/train_pixart_controlnet_hf.py",
    "content": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion for text2image with HuggingFace diffusers.\"\"\"\n\nimport argparse\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import T5EncoderModel, T5Tokenizer\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler\nfrom diffusers.models import PixArtTransformer2DModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom 
diffusers.utils.torch_utils import is_compiled_module\nfrom examples.research_projects.pixart.controlnet_pixart_alpha import (\n    PixArtControlNetAdapterModel,\n    PixArtControlNetTransformerModel,\n)\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.29.2\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef log_validation(\n    vae,\n    transformer,\n    controlnet,\n    tokenizer,\n    scheduler,\n    text_encoder,\n    args,\n    accelerator,\n    weight_dtype,\n    step,\n    is_final_validation=False,\n):\n    if weight_dtype == torch.float16 or weight_dtype == torch.bfloat16:\n        raise ValueError(\n            \"Validation is not supported with mixed precision training, disable validation and use the validation script, that will generate images from the saved checkpoints.\"\n        )\n\n    if not is_final_validation:\n        logger.info(f\"Running validation step {step} ... \")\n\n        controlnet = accelerator.unwrap_model(controlnet)\n        pipeline = PixArtAlphaControlnetPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=vae,\n            transformer=transformer,\n            scheduler=scheduler,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            controlnet=controlnet,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n    else:\n        logger.info(\"Running validation - final ... 
\")\n\n        controlnet = PixArtControlNetAdapterModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)\n\n        pipeline = PixArtAlphaControlnetPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            controlnet=controlnet,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n        validation_image = validation_image.resize((args.resolution, args.resolution))\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            image = pipeline(\n                prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator\n            ).images[0]\n            
images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = [np.asarray(validation_image)]\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({tracker_key: formatted_images})\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n        del pipeline\n        gc.collect()\n        torch.cuda.empty_cache()\n\n        logger.info(\"Validation done!!\")\n\n        return image_logs\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, dataset_name=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\\n\"\n      
  for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i})](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# controlnet-{repo_id}\n\nThese are controlnet weights trained on {base_model} with new type of conditioning.\n{img_str}\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"openrail++\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"pixart-alpha\",\n        \"pixart-alpha-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"controlnet\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model 
files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained controlnet model or model identifier from huggingface.co/models.\"\n        \" If not specified controlnet weights are initialized from the transformer.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        nargs=\"+\",\n        default=None,\n        help=\"One or more prompts to be evaluated every `--validation_steps`.\"\n        \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n        \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\",\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. 
Provide either a matching number of `--validation_prompt`s,\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run fine-tuning validation every X steps. The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"pixart-controlnet\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch 
size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. 
\"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    # ----Diffusion Training Arguments----\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"pixart_controlnet\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` 
to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # See Section 3.1. 
of the paper.\n    max_length = 120\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\", torch_dtype=weight_dtype\n    )\n    tokenizer = T5Tokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision, torch_dtype=weight_dtype\n    )\n\n    text_encoder = T5EncoderModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, torch_dtype=weight_dtype\n    )\n    text_encoder.requires_grad_(False)\n    text_encoder.to(accelerator.device)\n\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    vae.requires_grad_(False)\n    vae.to(accelerator.device)\n\n    transformer = PixArtTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"transformer\")\n    transformer.to(accelerator.device)\n    transformer.requires_grad_(False)\n\n    if args.controlnet_model_name_or_path:\n        logger.info(\"Loading existing controlnet weights\")\n        controlnet = PixArtControlNetAdapterModel.from_pretrained(args.controlnet_model_name_or_path)\n    else:\n        logger.info(\"Initializing controlnet weights from transformer.\")\n        controlnet = PixArtControlNetAdapterModel.from_transformer(transformer)\n\n    
transformer.to(dtype=weight_dtype)\n\n    controlnet.to(accelerator.device)\n    controlnet.train()\n\n    def unwrap_model(model, keep_fp32_wrapper=True):\n        model = accelerator.unwrap_model(model, keep_fp32_wrapper=keep_fp32_wrapper)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # 10. Handle saving and loading of checkpoints\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                for _, model in enumerate(models):\n                    if isinstance(model, PixArtControlNetTransformerModel):\n                        print(f\"Saving model {model.__class__.__name__} to {output_dir}\")\n                        model.controlnet.save_pretrained(os.path.join(output_dir, \"controlnet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            # rc todo: test and load the controlenet adapter and transformer\n            raise ValueError(\"load model hook not tested\")\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                if isinstance(model, PixArtControlNetTransformerModel):\n                    load_model = PixArtControlNetAdapterModel.from_pretrained(input_dir, subfolder=\"controlnet\")\n                    model.register_to_config(**load_model.config)\n\n                    model.load_state_dict(load_model.state_dict())\n                    del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        
accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            transformer.enable_xformers_memory_efficient_attention()\n            controlnet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if unwrap_model(controlnet).dtype != torch.float32:\n        raise ValueError(\n            f\"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. The trainable parameters should be in torch.float32.\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n        controlnet.enable_gradient_checkpointing()\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    params_to_optimize = controlnet.parameters()\n    optimizer = optimizer_cls(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True, proportion_empty_prompts=0.0, max_length=120):\n        captions = []\n        for caption in examples[caption_column]:\n            if random.random() < proportion_empty_prompts:\n                captions.append(\"\")\n            elif isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(captions, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n        return inputs.input_ids, inputs.attention_mask\n\n    # Preprocessing the datasets.\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    conditioning_image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n\n        conditioning_images = [image.convert(\"RGB\") for image in 
examples[args.conditioning_image_column]]\n        examples[\"conditioning_pixel_values\"] = [conditioning_image_transforms(image) for image in conditioning_images]\n\n        examples[\"input_ids\"], examples[\"prompt_attention_mask\"] = tokenize_captions(\n            examples, proportion_empty_prompts=args.proportion_empty_prompts, max_length=max_length\n        )\n\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n        conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n        conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        prompt_attention_mask = torch.stack([example[\"prompt_attention_mask\"] for example in examples])\n\n        return {\n            \"pixel_values\": pixel_values,\n            \"conditioning_pixel_values\": conditioning_pixel_values,\n            \"input_ids\": input_ids,\n            \"prompt_attention_mask\": prompt_attention_mask,\n        }\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch 
= math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare everything with our `accelerator`.\n    controlnet_transformer = PixArtControlNetTransformerModel(transformer, controlnet, training=True)\n    controlnet_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        controlnet_transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(args.tracker_project_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train 
batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    latent_channels = transformer.config.in_channels\n    for epoch in range(first_epoch, args.num_train_epochs):\n        controlnet_transformer.controlnet.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with 
accelerator.accumulate(controlnet):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Convert control images to latent space\n                controlnet_image_latents = vae.encode(\n                    batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\n                ).latent_dist.sample()\n                controlnet_image_latents = controlnet_image_latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device\n                    )\n\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                prompt_embeds = text_encoder(batch[\"input_ids\"], attention_mask=batch[\"prompt_attention_mask\"])[0]\n                prompt_attention_mask = batch[\"prompt_attention_mask\"]\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    
noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Prepare micro-conditions.\n                added_cond_kwargs = {\"resolution\": None, \"aspect_ratio\": None}\n                if getattr(transformer, \"module\", transformer).config.sample_size == 128:\n                    resolution = torch.tensor([args.resolution, args.resolution]).repeat(bsz, 1)\n                    aspect_ratio = torch.tensor([float(args.resolution / args.resolution)]).repeat(bsz, 1)\n                    resolution = resolution.to(dtype=weight_dtype, device=latents.device)\n                    aspect_ratio = aspect_ratio.to(dtype=weight_dtype, device=latents.device)\n                    added_cond_kwargs = {\"resolution\": resolution, \"aspect_ratio\": aspect_ratio}\n\n                # Predict the noise residual and compute loss\n                model_pred = controlnet_transformer(\n                    noisy_latents,\n                    encoder_hidden_states=prompt_embeds,\n                    encoder_attention_mask=prompt_attention_mask,\n                    timestep=timesteps,\n                    controlnet_cond=controlnet_image_latents,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                if transformer.config.out_channels // 2 == latent_channels:\n                    model_pred = model_pred.chunk(2, dim=1)[0]\n                else:\n                    model_pred = model_pred\n\n                if args.snr_gamma is None:\n                    loss = 
F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    if noise_scheduler.config.prediction_type == \"v_prediction\":\n                        # Velocity objective requires that we add one to SNR values before we divide by them.\n                        snr = snr + 1\n                    mse_loss_weights = (\n                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr\n                    )\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = controlnet_transformer.controlnet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n  
              accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % 
args.validation_steps == 0:\n                        log_validation(\n                            vae,\n                            transformer,\n                            controlnet_transformer.controlnet,\n                            tokenizer,\n                            noise_scheduler,\n                            text_encoder,\n                            args,\n                            accelerator,\n                            weight_dtype,\n                            global_step,\n                            is_final_validation=False,\n                        )\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Save the trained ControlNet weights\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        controlnet = unwrap_model(controlnet_transformer.controlnet, keep_fp32_wrapper=False)\n        controlnet.save_pretrained(os.path.join(args.output_dir, \"controlnet\"))\n\n        image_logs = None\n        if args.validation_prompt is not None:\n            image_logs = log_validation(\n                vae,\n                transformer,\n                controlnet,\n                tokenizer,\n                noise_scheduler,\n                text_encoder,\n                args,\n                accelerator,\n                weight_dtype,\n                global_step,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                
commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/promptdiffusion/README.md",
    "content": "# PromptDiffusion Pipeline\n\nFrom the project [page](https://zhendong-wang.github.io/prompt-diffusion.github.io/)\n\n\"With a prompt consisting of a task-specific example pair of images and text guidance, and a new query image, Prompt Diffusion can comprehend the desired task and generate the corresponding output image on both seen (trained) and unseen (new) task types.\"\n\nFor any usage questions, please refer to the [paper](https://huggingface.co/papers/2305.01115).\n\nPrepare models by converting them from the [checkpoint](https://huggingface.co/zhendongw/prompt-diffusion)\n\nTo convert the controlnet, use cldm_v15.yaml from the [repository](https://github.com/Zhendong-Wang/Prompt-Diffusion/tree/main/models/):\n\n```bash\npython convert_original_promptdiffusion_to_diffusers.py --checkpoint_path path-to-network-step04999.ckpt --original_config_file path-to-cldm_v15.yaml --dump_path path-to-output-directory\n```\n\nTo learn about how to convert the fine-tuned stable diffusion model, see the [Load different Stable Diffusion formats guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/other-formats).\n\n\n```py\nimport torch\nfrom diffusers import UniPCMultistepScheduler\nfrom diffusers.utils import load_image\nfrom promptdiffusioncontrolnet import PromptDiffusionControlNetModel\nfrom pipeline_prompt_diffusion import PromptDiffusionPipeline\n\n\nfrom PIL import ImageOps\n\nimage_a = ImageOps.invert(load_image(\"https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house_line.png?raw=true\"))\n\nimage_b = load_image(\"https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house.png?raw=true\")\nquery = ImageOps.invert(load_image(\"https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/new_01.png?raw=true\"))\n\n# load prompt diffusion controlnet and prompt diffusion\n\ncontrolnet = PromptDiffusionControlNetModel.from_pretrained(\"iczaw/prompt-diffusion-diffusers\", 
subfolder=\"controlnet\", torch_dtype=torch.float16)\npipe = PromptDiffusionPipeline.from_pretrained(\"iczaw/prompt-diffusion-diffusers\", subfolder=\"base\", controlnet=controlnet, torch_dtype=torch.float16, variant=\"fp16\")\n\n# speed up the diffusion process with a faster scheduler and memory optimization\npipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n# remove the following line if xformers is not installed\npipe.enable_xformers_memory_efficient_attention()\npipe.enable_model_cpu_offload()\n# generate image\ngenerator = torch.manual_seed(0)\nimage = pipe(\"a tortoise\", num_inference_steps=20, generator=generator, image_pair=[image_a, image_b], image=query).images[0]\n```\n"
  },
  {
    "path": "examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py",
    "content": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Conversion script for stable diffusion checkpoints which _only_ contain a controlnet.\"\"\"\n\nimport argparse\nimport re\nfrom contextlib import nullcontext\nfrom io import BytesIO\nfrom typing import Dict, Optional, Union\n\nimport requests\nimport torch\nimport yaml\nfrom promptdiffusioncontrolnet import PromptDiffusionControlNetModel\nfrom transformers import (\n    AutoFeatureExtractor,\n    BertTokenizerFast,\n    CLIPImageProcessor,\n    CLIPTextConfig,\n    CLIPTextModel,\n    CLIPTextModelWithProjection,\n    CLIPTokenizer,\n    CLIPVisionConfig,\n    CLIPVisionModelWithProjection,\n)\n\nfrom diffusers.models import (\n    AutoencoderKL,\n    ControlNetModel,\n    PriorTransformer,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel\nfrom diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer\nfrom diffusers.schedulers import (\n    DDIMScheduler,\n    DDPMScheduler,\n    DPMSolverMultistepScheduler,\n    EulerAncestralDiscreteScheduler,\n    
EulerDiscreteScheduler,\n    HeunDiscreteScheduler,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    UnCLIPScheduler,\n)\nfrom diffusers.utils import is_accelerate_available, logging\nfrom diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT\n\n\nif is_accelerate_available():\n    from accelerate import init_empty_weights\n    from accelerate.utils import set_module_tensor_to_device\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef shave_segments(path, n_shave_prefix_segments=1):\n    \"\"\"\n    Removes segments. Positive values shave the first segments, negative shave the last segments.\n    \"\"\"\n    if n_shave_prefix_segments >= 0:\n        return \".\".join(path.split(\".\")[n_shave_prefix_segments:])\n    else:\n        return \".\".join(path.split(\".\")[:n_shave_prefix_segments])\n\n\ndef renew_resnet_paths(old_list, n_shave_prefix_segments=0):\n    \"\"\"\n    Updates paths inside resnets to the new naming scheme (local renaming)\n    \"\"\"\n    mapping = []\n    for old_item in old_list:\n        new_item = old_item.replace(\"in_layers.0\", \"norm1\")\n        new_item = new_item.replace(\"in_layers.2\", \"conv1\")\n\n        new_item = new_item.replace(\"out_layers.0\", \"norm2\")\n        new_item = new_item.replace(\"out_layers.3\", \"conv2\")\n\n        new_item = new_item.replace(\"emb_layers.1\", \"time_emb_proj\")\n        new_item = new_item.replace(\"skip_connection\", \"conv_shortcut\")\n\n        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)\n\n        mapping.append({\"old\": old_item, \"new\": new_item})\n\n    return mapping\n\n\ndef renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):\n    \"\"\"\n    Updates paths inside resnets to the new naming scheme (local renaming)\n    \"\"\"\n    mapping = []\n    for old_item in old_list:\n        new_item = old_item\n\n        new_item = new_item.replace(\"nin_shortcut\", \"conv_shortcut\")\n        new_item 
= shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)\n\n        mapping.append({\"old\": old_item, \"new\": new_item})\n\n    return mapping\n\n\ndef renew_attention_paths(old_list, n_shave_prefix_segments=0):\n    \"\"\"\n    Updates paths inside attentions to the new naming scheme (local renaming)\n    \"\"\"\n    mapping = []\n    for old_item in old_list:\n        new_item = old_item\n\n        #         new_item = new_item.replace('norm.weight', 'group_norm.weight')\n        #         new_item = new_item.replace('norm.bias', 'group_norm.bias')\n\n        #         new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')\n        #         new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')\n\n        #         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)\n\n        mapping.append({\"old\": old_item, \"new\": new_item})\n\n    return mapping\n\n\ndef renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):\n    \"\"\"\n    Updates paths inside attentions to the new naming scheme (local renaming)\n    \"\"\"\n    mapping = []\n    for old_item in old_list:\n        new_item = old_item\n\n        new_item = new_item.replace(\"norm.weight\", \"group_norm.weight\")\n        new_item = new_item.replace(\"norm.bias\", \"group_norm.bias\")\n\n        new_item = new_item.replace(\"q.weight\", \"to_q.weight\")\n        new_item = new_item.replace(\"q.bias\", \"to_q.bias\")\n\n        new_item = new_item.replace(\"k.weight\", \"to_k.weight\")\n        new_item = new_item.replace(\"k.bias\", \"to_k.bias\")\n\n        new_item = new_item.replace(\"v.weight\", \"to_v.weight\")\n        new_item = new_item.replace(\"v.bias\", \"to_v.bias\")\n\n        new_item = new_item.replace(\"proj_out.weight\", \"to_out.0.weight\")\n        new_item = new_item.replace(\"proj_out.bias\", \"to_out.0.bias\")\n\n        new_item = shave_segments(new_item, 
n_shave_prefix_segments=n_shave_prefix_segments)\n\n        mapping.append({\"old\": old_item, \"new\": new_item})\n\n    return mapping\n\n\ndef assign_to_checkpoint(\n    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None\n):\n    \"\"\"\n    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits\n    attention layers, and takes into account additional replacements that may arise.\n\n    Assigns the weights to the new checkpoint.\n    \"\"\"\n    assert isinstance(paths, list), \"Paths should be a list of dicts containing 'old' and 'new' keys.\"\n\n    # Splits the attention layers into three variables.\n    if attention_paths_to_split is not None:\n        for path, path_map in attention_paths_to_split.items():\n            old_tensor = old_checkpoint[path]\n            channels = old_tensor.shape[0] // 3\n\n            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)\n\n            num_heads = old_tensor.shape[0] // config[\"num_head_channels\"] // 3\n\n            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])\n            query, key, value = old_tensor.split(channels // num_heads, dim=1)\n\n            checkpoint[path_map[\"query\"]] = query.reshape(target_shape)\n            checkpoint[path_map[\"key\"]] = key.reshape(target_shape)\n            checkpoint[path_map[\"value\"]] = value.reshape(target_shape)\n\n    for path in paths:\n        new_path = path[\"new\"]\n\n        # These have already been assigned\n        if attention_paths_to_split is not None and new_path in attention_paths_to_split:\n            continue\n\n        # Global renaming happens here\n        new_path = new_path.replace(\"middle_block.0\", \"mid_block.resnets.0\")\n        new_path = new_path.replace(\"middle_block.1\", \"mid_block.attentions.0\")\n        new_path = 
new_path.replace(\"middle_block.2\", \"mid_block.resnets.1\")\n\n        if additional_replacements is not None:\n            for replacement in additional_replacements:\n                new_path = new_path.replace(replacement[\"old\"], replacement[\"new\"])\n\n        # proj_attn.weight has to be converted from conv 1D to linear\n        is_attn_weight = \"proj_attn.weight\" in new_path or (\"attentions\" in new_path and \"to_\" in new_path)\n        shape = old_checkpoint[path[\"old\"]].shape\n        if is_attn_weight and len(shape) == 3:\n            checkpoint[new_path] = old_checkpoint[path[\"old\"]][:, :, 0]\n        elif is_attn_weight and len(shape) == 4:\n            checkpoint[new_path] = old_checkpoint[path[\"old\"]][:, :, 0, 0]\n        else:\n            checkpoint[new_path] = old_checkpoint[path[\"old\"]]\n\n\ndef conv_attn_to_linear(checkpoint):\n    keys = list(checkpoint.keys())\n    attn_keys = [\"query.weight\", \"key.weight\", \"value.weight\"]\n    for key in keys:\n        if \".\".join(key.split(\".\")[-2:]) in attn_keys:\n            if checkpoint[key].ndim > 2:\n                checkpoint[key] = checkpoint[key][:, :, 0, 0]\n        elif \"proj_attn.weight\" in key:\n            if checkpoint[key].ndim > 2:\n                checkpoint[key] = checkpoint[key][:, :, 0]\n\n\ndef create_unet_diffusers_config(original_config, image_size: int, controlnet=False):\n    \"\"\"\n    Creates a config for the diffusers based on the config of the LDM model.\n    \"\"\"\n    if controlnet:\n        unet_params = original_config[\"model\"][\"params\"][\"control_stage_config\"][\"params\"]\n    else:\n        if (\n            \"unet_config\" in original_config[\"model\"][\"params\"]\n            and original_config[\"model\"][\"params\"][\"unet_config\"] is not None\n        ):\n            unet_params = original_config[\"model\"][\"params\"][\"unet_config\"][\"params\"]\n        else:\n            unet_params = 
original_config[\"model\"][\"params\"][\"network_config\"][\"params\"]\n\n    vae_params = original_config[\"model\"][\"params\"][\"first_stage_config\"][\"params\"][\"ddconfig\"]\n\n    block_out_channels = [unet_params[\"model_channels\"] * mult for mult in unet_params[\"channel_mult\"]]\n\n    down_block_types = []\n    resolution = 1\n    for i in range(len(block_out_channels)):\n        block_type = \"CrossAttnDownBlock2D\" if resolution in unet_params[\"attention_resolutions\"] else \"DownBlock2D\"\n        down_block_types.append(block_type)\n        if i != len(block_out_channels) - 1:\n            resolution *= 2\n\n    up_block_types = []\n    for i in range(len(block_out_channels)):\n        block_type = \"CrossAttnUpBlock2D\" if resolution in unet_params[\"attention_resolutions\"] else \"UpBlock2D\"\n        up_block_types.append(block_type)\n        resolution //= 2\n\n    if unet_params[\"transformer_depth\"] is not None:\n        transformer_layers_per_block = (\n            unet_params[\"transformer_depth\"]\n            if isinstance(unet_params[\"transformer_depth\"], int)\n            else list(unet_params[\"transformer_depth\"])\n        )\n    else:\n        transformer_layers_per_block = 1\n\n    vae_scale_factor = 2 ** (len(vae_params[\"ch_mult\"]) - 1)\n\n    head_dim = unet_params[\"num_heads\"] if \"num_heads\" in unet_params else None\n    use_linear_projection = (\n        unet_params[\"use_linear_in_transformer\"] if \"use_linear_in_transformer\" in unet_params else False\n    )\n    if use_linear_projection:\n        # stable diffusion 2-base-512 and 2-768\n        if head_dim is None:\n            head_dim_mult = unet_params[\"model_channels\"] // unet_params[\"num_head_channels\"]\n            head_dim = [head_dim_mult * c for c in list(unet_params[\"channel_mult\"])]\n\n    class_embed_type = None\n    addition_embed_type = None\n    addition_time_embed_dim = None\n    projection_class_embeddings_input_dim = None\n    context_dim = 
None\n\n    if unet_params[\"context_dim\"] is not None:\n        context_dim = (\n            unet_params[\"context_dim\"]\n            if isinstance(unet_params[\"context_dim\"], int)\n            else unet_params[\"context_dim\"][0]\n        )\n\n    if \"num_classes\" in unet_params:\n        if unet_params[\"num_classes\"] == \"sequential\":\n            if context_dim in [2048, 1280]:\n                # SDXL\n                addition_embed_type = \"text_time\"\n                addition_time_embed_dim = 256\n            else:\n                class_embed_type = \"projection\"\n            assert \"adm_in_channels\" in unet_params\n            projection_class_embeddings_input_dim = unet_params[\"adm_in_channels\"]\n\n    config = {\n        \"sample_size\": image_size // vae_scale_factor,\n        \"in_channels\": unet_params[\"in_channels\"],\n        \"down_block_types\": tuple(down_block_types),\n        \"block_out_channels\": tuple(block_out_channels),\n        \"layers_per_block\": unet_params[\"num_res_blocks\"],\n        \"cross_attention_dim\": context_dim,\n        \"attention_head_dim\": head_dim,\n        \"use_linear_projection\": use_linear_projection,\n        \"class_embed_type\": class_embed_type,\n        \"addition_embed_type\": addition_embed_type,\n        \"addition_time_embed_dim\": addition_time_embed_dim,\n        \"projection_class_embeddings_input_dim\": projection_class_embeddings_input_dim,\n        \"transformer_layers_per_block\": transformer_layers_per_block,\n    }\n\n    if \"disable_self_attentions\" in unet_params:\n        config[\"only_cross_attention\"] = unet_params[\"disable_self_attentions\"]\n\n    if \"num_classes\" in unet_params and isinstance(unet_params[\"num_classes\"], int):\n        config[\"num_class_embeds\"] = unet_params[\"num_classes\"]\n\n    if controlnet:\n        config[\"conditioning_channels\"] = unet_params[\"hint_channels\"]\n    else:\n        config[\"out_channels\"] = 
unet_params[\"out_channels\"]\n        config[\"up_block_types\"] = tuple(up_block_types)\n\n    return config\n\n\ndef create_vae_diffusers_config(original_config, image_size: int):\n    \"\"\"\n    Creates a config for the diffusers based on the config of the LDM model.\n    \"\"\"\n    vae_params = original_config[\"model\"][\"params\"][\"first_stage_config\"][\"params\"][\"ddconfig\"]\n    _ = original_config[\"model\"][\"params\"][\"first_stage_config\"][\"params\"][\"embed_dim\"]\n\n    block_out_channels = [vae_params[\"ch\"] * mult for mult in vae_params[\"ch_mult\"]]\n    down_block_types = [\"DownEncoderBlock2D\"] * len(block_out_channels)\n    up_block_types = [\"UpDecoderBlock2D\"] * len(block_out_channels)\n\n    config = {\n        \"sample_size\": image_size,\n        \"in_channels\": vae_params[\"in_channels\"],\n        \"out_channels\": vae_params[\"out_ch\"],\n        \"down_block_types\": tuple(down_block_types),\n        \"up_block_types\": tuple(up_block_types),\n        \"block_out_channels\": tuple(block_out_channels),\n        \"latent_channels\": vae_params[\"z_channels\"],\n        \"layers_per_block\": vae_params[\"num_res_blocks\"],\n    }\n    return config\n\n\ndef create_diffusers_schedular(original_config):\n    schedular = DDIMScheduler(\n        num_train_timesteps=original_config[\"model\"][\"params\"][\"timesteps\"],\n        beta_start=original_config[\"model\"][\"params\"][\"linear_start\"],\n        beta_end=original_config[\"model\"][\"params\"][\"linear_end\"],\n        beta_schedule=\"scaled_linear\",\n    )\n    return schedular\n\n\ndef create_ldm_bert_config(original_config):\n    bert_params = original_config[\"model\"][\"params\"][\"cond_stage_config\"][\"params\"]\n    config = LDMBertConfig(\n        d_model=bert_params.n_embed,\n        encoder_layers=bert_params.n_layer,\n        encoder_ffn_dim=bert_params.n_embed * 4,\n    )\n    return config\n\n\ndef convert_ldm_unet_checkpoint(\n    checkpoint,\n    config,\n 
   path=None,\n    extract_ema=False,\n    controlnet=False,\n    skip_extract_state_dict=False,\n    promptdiffusion=False,\n):\n    \"\"\"\n    Takes a state dict and a config, and returns a converted checkpoint.\n    \"\"\"\n\n    if skip_extract_state_dict:\n        unet_state_dict = checkpoint\n    else:\n        # extract state_dict for UNet\n        unet_state_dict = {}\n        keys = list(checkpoint.keys())\n\n        if controlnet:\n            unet_key = \"control_model.\"\n        else:\n            unet_key = \"model.diffusion_model.\"\n\n        # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA\n        if sum(k.startswith(\"model_ema\") for k in keys) > 100 and extract_ema:\n            logger.warning(f\"Checkpoint {path} has both EMA and non-EMA weights.\")\n            logger.warning(\n                \"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA\"\n                \" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag.\"\n            )\n            for key in keys:\n                if key.startswith(\"model.diffusion_model\"):\n                    flat_ema_key = \"model_ema.\" + \"\".join(key.split(\".\")[1:])\n                    unet_state_dict[key.replace(unet_key, \"\")] = checkpoint.pop(flat_ema_key)\n        else:\n            if sum(k.startswith(\"model_ema\") for k in keys) > 100:\n                logger.warning(\n                    \"In this conversion only the non-EMA weights are extracted. 
If you want to instead extract the EMA\"\n                    \" weights (usually better for inference), please make sure to add the `--extract_ema` flag.\"\n                )\n\n            for key in keys:\n                if key.startswith(unet_key):\n                    unet_state_dict[key.replace(unet_key, \"\")] = checkpoint.pop(key)\n\n    new_checkpoint = {}\n\n    new_checkpoint[\"time_embedding.linear_1.weight\"] = unet_state_dict[\"time_embed.0.weight\"]\n    new_checkpoint[\"time_embedding.linear_1.bias\"] = unet_state_dict[\"time_embed.0.bias\"]\n    new_checkpoint[\"time_embedding.linear_2.weight\"] = unet_state_dict[\"time_embed.2.weight\"]\n    new_checkpoint[\"time_embedding.linear_2.bias\"] = unet_state_dict[\"time_embed.2.bias\"]\n\n    if config[\"class_embed_type\"] is None:\n        # No parameters to port\n        ...\n    elif config[\"class_embed_type\"] == \"timestep\" or config[\"class_embed_type\"] == \"projection\":\n        new_checkpoint[\"class_embedding.linear_1.weight\"] = unet_state_dict[\"label_emb.0.0.weight\"]\n        new_checkpoint[\"class_embedding.linear_1.bias\"] = unet_state_dict[\"label_emb.0.0.bias\"]\n        new_checkpoint[\"class_embedding.linear_2.weight\"] = unet_state_dict[\"label_emb.0.2.weight\"]\n        new_checkpoint[\"class_embedding.linear_2.bias\"] = unet_state_dict[\"label_emb.0.2.bias\"]\n    else:\n        raise NotImplementedError(f\"Not implemented `class_embed_type`: {config['class_embed_type']}\")\n\n    if config[\"addition_embed_type\"] == \"text_time\":\n        new_checkpoint[\"add_embedding.linear_1.weight\"] = unet_state_dict[\"label_emb.0.0.weight\"]\n        new_checkpoint[\"add_embedding.linear_1.bias\"] = unet_state_dict[\"label_emb.0.0.bias\"]\n        new_checkpoint[\"add_embedding.linear_2.weight\"] = unet_state_dict[\"label_emb.0.2.weight\"]\n        new_checkpoint[\"add_embedding.linear_2.bias\"] = unet_state_dict[\"label_emb.0.2.bias\"]\n\n    # Relevant to 
StableDiffusionUpscalePipeline\n    if \"num_class_embeds\" in config:\n        if (config[\"num_class_embeds\"] is not None) and (\"label_emb.weight\" in unet_state_dict):\n            new_checkpoint[\"class_embedding.weight\"] = unet_state_dict[\"label_emb.weight\"]\n\n    new_checkpoint[\"conv_in.weight\"] = unet_state_dict[\"input_blocks.0.0.weight\"]\n    new_checkpoint[\"conv_in.bias\"] = unet_state_dict[\"input_blocks.0.0.bias\"]\n\n    if not controlnet:\n        new_checkpoint[\"conv_norm_out.weight\"] = unet_state_dict[\"out.0.weight\"]\n        new_checkpoint[\"conv_norm_out.bias\"] = unet_state_dict[\"out.0.bias\"]\n        new_checkpoint[\"conv_out.weight\"] = unet_state_dict[\"out.2.weight\"]\n        new_checkpoint[\"conv_out.bias\"] = unet_state_dict[\"out.2.bias\"]\n\n    # Retrieves the keys for the input blocks only\n    num_input_blocks = len({\".\".join(layer.split(\".\")[:2]) for layer in unet_state_dict if \"input_blocks\" in layer})\n    input_blocks = {\n        layer_id: [key for key in unet_state_dict if f\"input_blocks.{layer_id}\" in key]\n        for layer_id in range(num_input_blocks)\n    }\n\n    # Retrieves the keys for the middle blocks only\n    num_middle_blocks = len({\".\".join(layer.split(\".\")[:2]) for layer in unet_state_dict if \"middle_block\" in layer})\n    middle_blocks = {\n        layer_id: [key for key in unet_state_dict if f\"middle_block.{layer_id}\" in key]\n        for layer_id in range(num_middle_blocks)\n    }\n\n    # Retrieves the keys for the output blocks only\n    num_output_blocks = len({\".\".join(layer.split(\".\")[:2]) for layer in unet_state_dict if \"output_blocks\" in layer})\n    output_blocks = {\n        layer_id: [key for key in unet_state_dict if f\"output_blocks.{layer_id}\" in key]\n        for layer_id in range(num_output_blocks)\n    }\n\n    for i in range(1, num_input_blocks):\n        block_id = (i - 1) // (config[\"layers_per_block\"] + 1)\n        layer_in_block_id = (i - 1) % 
(config[\"layers_per_block\"] + 1)\n\n        resnets = [\n            key for key in input_blocks[i] if f\"input_blocks.{i}.0\" in key and f\"input_blocks.{i}.0.op\" not in key\n        ]\n        attentions = [key for key in input_blocks[i] if f\"input_blocks.{i}.1\" in key]\n\n        if f\"input_blocks.{i}.0.op.weight\" in unet_state_dict:\n            new_checkpoint[f\"down_blocks.{block_id}.downsamplers.0.conv.weight\"] = unet_state_dict.pop(\n                f\"input_blocks.{i}.0.op.weight\"\n            )\n            new_checkpoint[f\"down_blocks.{block_id}.downsamplers.0.conv.bias\"] = unet_state_dict.pop(\n                f\"input_blocks.{i}.0.op.bias\"\n            )\n\n        paths = renew_resnet_paths(resnets)\n        meta_path = {\"old\": f\"input_blocks.{i}.0\", \"new\": f\"down_blocks.{block_id}.resnets.{layer_in_block_id}\"}\n        assign_to_checkpoint(\n            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config\n        )\n\n        if len(attentions):\n            paths = renew_attention_paths(attentions)\n\n            meta_path = {\"old\": f\"input_blocks.{i}.1\", \"new\": f\"down_blocks.{block_id}.attentions.{layer_in_block_id}\"}\n            assign_to_checkpoint(\n                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config\n            )\n\n    resnet_0 = middle_blocks[0]\n    attentions = middle_blocks[1]\n    resnet_1 = middle_blocks[2]\n\n    resnet_0_paths = renew_resnet_paths(resnet_0)\n    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)\n\n    resnet_1_paths = renew_resnet_paths(resnet_1)\n    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)\n\n    attentions_paths = renew_attention_paths(attentions)\n    meta_path = {\"old\": \"middle_block.1\", \"new\": \"mid_block.attentions.0\"}\n    assign_to_checkpoint(\n        attentions_paths, new_checkpoint, unet_state_dict, 
additional_replacements=[meta_path], config=config\n    )\n\n    for i in range(num_output_blocks):\n        block_id = i // (config[\"layers_per_block\"] + 1)\n        layer_in_block_id = i % (config[\"layers_per_block\"] + 1)\n        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]\n        output_block_list = {}\n\n        for layer in output_block_layers:\n            layer_id, layer_name = layer.split(\".\")[0], shave_segments(layer, 1)\n            if layer_id in output_block_list:\n                output_block_list[layer_id].append(layer_name)\n            else:\n                output_block_list[layer_id] = [layer_name]\n\n        if len(output_block_list) > 1:\n            resnets = [key for key in output_blocks[i] if f\"output_blocks.{i}.0\" in key]\n            attentions = [key for key in output_blocks[i] if f\"output_blocks.{i}.1\" in key]\n\n            resnet_0_paths = renew_resnet_paths(resnets)\n            paths = renew_resnet_paths(resnets)\n\n            meta_path = {\"old\": f\"output_blocks.{i}.0\", \"new\": f\"up_blocks.{block_id}.resnets.{layer_in_block_id}\"}\n            assign_to_checkpoint(\n                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config\n            )\n\n            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}\n            if [\"conv.bias\", \"conv.weight\"] in output_block_list.values():\n                index = list(output_block_list.values()).index([\"conv.bias\", \"conv.weight\"])\n                new_checkpoint[f\"up_blocks.{block_id}.upsamplers.0.conv.weight\"] = unet_state_dict[\n                    f\"output_blocks.{i}.{index}.conv.weight\"\n                ]\n                new_checkpoint[f\"up_blocks.{block_id}.upsamplers.0.conv.bias\"] = unet_state_dict[\n                    f\"output_blocks.{i}.{index}.conv.bias\"\n                ]\n\n                # Clear attentions as they have been attributed above.\n   
             if len(attentions) == 2:\n                    attentions = []\n\n            if len(attentions):\n                paths = renew_attention_paths(attentions)\n                meta_path = {\n                    \"old\": f\"output_blocks.{i}.1\",\n                    \"new\": f\"up_blocks.{block_id}.attentions.{layer_in_block_id}\",\n                }\n                assign_to_checkpoint(\n                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config\n                )\n        else:\n            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)\n            for path in resnet_0_paths:\n                old_path = \".\".join([\"output_blocks\", str(i), path[\"old\"]])\n                new_path = \".\".join([\"up_blocks\", str(block_id), \"resnets\", str(layer_in_block_id), path[\"new\"]])\n\n                new_checkpoint[new_path] = unet_state_dict[old_path]\n\n    if controlnet and not promptdiffusion:\n        # conditioning embedding\n\n        orig_index = 0\n\n        new_checkpoint[\"controlnet_cond_embedding.conv_in.weight\"] = unet_state_dict.pop(\n            f\"input_hint_block.{orig_index}.weight\"\n        )\n        new_checkpoint[\"controlnet_cond_embedding.conv_in.bias\"] = unet_state_dict.pop(\n            f\"input_hint_block.{orig_index}.bias\"\n        )\n\n        orig_index += 2\n\n        diffusers_index = 0\n\n        while diffusers_index < 6:\n            new_checkpoint[f\"controlnet_cond_embedding.blocks.{diffusers_index}.weight\"] = unet_state_dict.pop(\n                f\"input_hint_block.{orig_index}.weight\"\n            )\n            new_checkpoint[f\"controlnet_cond_embedding.blocks.{diffusers_index}.bias\"] = unet_state_dict.pop(\n                f\"input_hint_block.{orig_index}.bias\"\n            )\n            diffusers_index += 1\n            orig_index += 2\n\n        new_checkpoint[\"controlnet_cond_embedding.conv_out.weight\"] = 
unet_state_dict.pop(\n            f\"input_hint_block.{orig_index}.weight\"\n        )\n        new_checkpoint[\"controlnet_cond_embedding.conv_out.bias\"] = unet_state_dict.pop(\n            f\"input_hint_block.{orig_index}.bias\"\n        )\n\n        # down blocks\n        for i in range(num_input_blocks):\n            new_checkpoint[f\"controlnet_down_blocks.{i}.weight\"] = unet_state_dict.pop(f\"zero_convs.{i}.0.weight\")\n            new_checkpoint[f\"controlnet_down_blocks.{i}.bias\"] = unet_state_dict.pop(f\"zero_convs.{i}.0.bias\")\n\n        # mid block\n        new_checkpoint[\"controlnet_mid_block.weight\"] = unet_state_dict.pop(\"middle_block_out.0.weight\")\n        new_checkpoint[\"controlnet_mid_block.bias\"] = unet_state_dict.pop(\"middle_block_out.0.bias\")\n\n    if promptdiffusion:\n        # conditioning embedding\n\n        orig_index = 0\n\n        new_checkpoint[\"controlnet_cond_embedding.conv_in.weight\"] = unet_state_dict.pop(\n            f\"input_hint_block.{orig_index}.weight\"\n        )\n        new_checkpoint[\"controlnet_cond_embedding.conv_in.bias\"] = unet_state_dict.pop(\n            f\"input_hint_block.{orig_index}.bias\"\n        )\n\n        new_checkpoint[\"controlnet_query_cond_embedding.conv_in.weight\"] = unet_state_dict.pop(\n            f\"input_cond_block.{orig_index}.weight\"\n        )\n        new_checkpoint[\"controlnet_query_cond_embedding.conv_in.bias\"] = unet_state_dict.pop(\n            f\"input_cond_block.{orig_index}.bias\"\n        )\n        orig_index += 2\n\n        diffusers_index = 0\n\n        while diffusers_index < 6:\n            new_checkpoint[f\"controlnet_cond_embedding.blocks.{diffusers_index}.weight\"] = unet_state_dict.pop(\n                f\"input_hint_block.{orig_index}.weight\"\n            )\n            new_checkpoint[f\"controlnet_cond_embedding.blocks.{diffusers_index}.bias\"] = unet_state_dict.pop(\n                f\"input_hint_block.{orig_index}.bias\"\n            )\n            
new_checkpoint[f\"controlnet_query_cond_embedding.blocks.{diffusers_index}.weight\"] = unet_state_dict.pop(\n                f\"input_cond_block.{orig_index}.weight\"\n            )\n            new_checkpoint[f\"controlnet_query_cond_embedding.blocks.{diffusers_index}.bias\"] = unet_state_dict.pop(\n                f\"input_cond_block.{orig_index}.bias\"\n            )\n            diffusers_index += 1\n            orig_index += 2\n\n        new_checkpoint[\"controlnet_cond_embedding.conv_out.weight\"] = unet_state_dict.pop(\n            f\"input_hint_block.{orig_index}.weight\"\n        )\n        new_checkpoint[\"controlnet_cond_embedding.conv_out.bias\"] = unet_state_dict.pop(\n            f\"input_hint_block.{orig_index}.bias\"\n        )\n\n        new_checkpoint[\"controlnet_query_cond_embedding.conv_out.weight\"] = unet_state_dict.pop(\n            f\"input_cond_block.{orig_index}.weight\"\n        )\n        new_checkpoint[\"controlnet_query_cond_embedding.conv_out.bias\"] = unet_state_dict.pop(\n            f\"input_cond_block.{orig_index}.bias\"\n        )\n        # down blocks\n        for i in range(num_input_blocks):\n            new_checkpoint[f\"controlnet_down_blocks.{i}.weight\"] = unet_state_dict.pop(f\"zero_convs.{i}.0.weight\")\n            new_checkpoint[f\"controlnet_down_blocks.{i}.bias\"] = unet_state_dict.pop(f\"zero_convs.{i}.0.bias\")\n\n        # mid block\n        new_checkpoint[\"controlnet_mid_block.weight\"] = unet_state_dict.pop(\"middle_block_out.0.weight\")\n        new_checkpoint[\"controlnet_mid_block.bias\"] = unet_state_dict.pop(\"middle_block_out.0.bias\")\n\n    return new_checkpoint\n\n\ndef convert_ldm_vae_checkpoint(checkpoint, config):\n    # extract state dict for VAE\n    vae_state_dict = {}\n    keys = list(checkpoint.keys())\n    vae_key = \"first_stage_model.\" if any(k.startswith(\"first_stage_model.\") for k in keys) else \"\"\n    for key in keys:\n        if key.startswith(vae_key):\n            
vae_state_dict[key.replace(vae_key, \"\")] = checkpoint.get(key)\n\n    new_checkpoint = {}\n\n    new_checkpoint[\"encoder.conv_in.weight\"] = vae_state_dict[\"encoder.conv_in.weight\"]\n    new_checkpoint[\"encoder.conv_in.bias\"] = vae_state_dict[\"encoder.conv_in.bias\"]\n    new_checkpoint[\"encoder.conv_out.weight\"] = vae_state_dict[\"encoder.conv_out.weight\"]\n    new_checkpoint[\"encoder.conv_out.bias\"] = vae_state_dict[\"encoder.conv_out.bias\"]\n    new_checkpoint[\"encoder.conv_norm_out.weight\"] = vae_state_dict[\"encoder.norm_out.weight\"]\n    new_checkpoint[\"encoder.conv_norm_out.bias\"] = vae_state_dict[\"encoder.norm_out.bias\"]\n\n    new_checkpoint[\"decoder.conv_in.weight\"] = vae_state_dict[\"decoder.conv_in.weight\"]\n    new_checkpoint[\"decoder.conv_in.bias\"] = vae_state_dict[\"decoder.conv_in.bias\"]\n    new_checkpoint[\"decoder.conv_out.weight\"] = vae_state_dict[\"decoder.conv_out.weight\"]\n    new_checkpoint[\"decoder.conv_out.bias\"] = vae_state_dict[\"decoder.conv_out.bias\"]\n    new_checkpoint[\"decoder.conv_norm_out.weight\"] = vae_state_dict[\"decoder.norm_out.weight\"]\n    new_checkpoint[\"decoder.conv_norm_out.bias\"] = vae_state_dict[\"decoder.norm_out.bias\"]\n\n    new_checkpoint[\"quant_conv.weight\"] = vae_state_dict[\"quant_conv.weight\"]\n    new_checkpoint[\"quant_conv.bias\"] = vae_state_dict[\"quant_conv.bias\"]\n    new_checkpoint[\"post_quant_conv.weight\"] = vae_state_dict[\"post_quant_conv.weight\"]\n    new_checkpoint[\"post_quant_conv.bias\"] = vae_state_dict[\"post_quant_conv.bias\"]\n\n    # Retrieves the keys for the encoder down blocks only\n    num_down_blocks = len({\".\".join(layer.split(\".\")[:3]) for layer in vae_state_dict if \"encoder.down\" in layer})\n    down_blocks = {\n        layer_id: [key for key in vae_state_dict if f\"down.{layer_id}\" in key] for layer_id in range(num_down_blocks)\n    }\n\n    # Retrieves the keys for the decoder up blocks only\n    num_up_blocks = 
len({\".\".join(layer.split(\".\")[:3]) for layer in vae_state_dict if \"decoder.up\" in layer})\n    up_blocks = {\n        layer_id: [key for key in vae_state_dict if f\"up.{layer_id}\" in key] for layer_id in range(num_up_blocks)\n    }\n\n    for i in range(num_down_blocks):\n        resnets = [key for key in down_blocks[i] if f\"down.{i}\" in key and f\"down.{i}.downsample\" not in key]\n\n        if f\"encoder.down.{i}.downsample.conv.weight\" in vae_state_dict:\n            new_checkpoint[f\"encoder.down_blocks.{i}.downsamplers.0.conv.weight\"] = vae_state_dict.pop(\n                f\"encoder.down.{i}.downsample.conv.weight\"\n            )\n            new_checkpoint[f\"encoder.down_blocks.{i}.downsamplers.0.conv.bias\"] = vae_state_dict.pop(\n                f\"encoder.down.{i}.downsample.conv.bias\"\n            )\n\n        paths = renew_vae_resnet_paths(resnets)\n        meta_path = {\"old\": f\"down.{i}.block\", \"new\": f\"down_blocks.{i}.resnets\"}\n        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n\n    mid_resnets = [key for key in vae_state_dict if \"encoder.mid.block\" in key]\n    num_mid_res_blocks = 2\n    for i in range(1, num_mid_res_blocks + 1):\n        resnets = [key for key in mid_resnets if f\"encoder.mid.block_{i}\" in key]\n\n        paths = renew_vae_resnet_paths(resnets)\n        meta_path = {\"old\": f\"mid.block_{i}\", \"new\": f\"mid_block.resnets.{i - 1}\"}\n        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n\n    mid_attentions = [key for key in vae_state_dict if \"encoder.mid.attn\" in key]\n    paths = renew_vae_attention_paths(mid_attentions)\n    meta_path = {\"old\": \"mid.attn_1\", \"new\": \"mid_block.attentions.0\"}\n    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n    conv_attn_to_linear(new_checkpoint)\n\n    
for i in range(num_up_blocks):\n        block_id = num_up_blocks - 1 - i\n        resnets = [\n            key for key in up_blocks[block_id] if f\"up.{block_id}\" in key and f\"up.{block_id}.upsample\" not in key\n        ]\n\n        if f\"decoder.up.{block_id}.upsample.conv.weight\" in vae_state_dict:\n            new_checkpoint[f\"decoder.up_blocks.{i}.upsamplers.0.conv.weight\"] = vae_state_dict[\n                f\"decoder.up.{block_id}.upsample.conv.weight\"\n            ]\n            new_checkpoint[f\"decoder.up_blocks.{i}.upsamplers.0.conv.bias\"] = vae_state_dict[\n                f\"decoder.up.{block_id}.upsample.conv.bias\"\n            ]\n\n        paths = renew_vae_resnet_paths(resnets)\n        meta_path = {\"old\": f\"up.{block_id}.block\", \"new\": f\"up_blocks.{i}.resnets\"}\n        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n\n    mid_resnets = [key for key in vae_state_dict if \"decoder.mid.block\" in key]\n    num_mid_res_blocks = 2\n    for i in range(1, num_mid_res_blocks + 1):\n        resnets = [key for key in mid_resnets if f\"decoder.mid.block_{i}\" in key]\n\n        paths = renew_vae_resnet_paths(resnets)\n        meta_path = {\"old\": f\"mid.block_{i}\", \"new\": f\"mid_block.resnets.{i - 1}\"}\n        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n\n    mid_attentions = [key for key in vae_state_dict if \"decoder.mid.attn\" in key]\n    paths = renew_vae_attention_paths(mid_attentions)\n    meta_path = {\"old\": \"mid.attn_1\", \"new\": \"mid_block.attentions.0\"}\n    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n    conv_attn_to_linear(new_checkpoint)\n    return new_checkpoint\n\n\ndef convert_ldm_bert_checkpoint(checkpoint, config):\n    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):\n        hf_attn_layer.q_proj.weight.data 
= pt_attn_layer.to_q.weight\n        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight\n        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight\n\n        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight\n        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias\n\n    def _copy_linear(hf_linear, pt_linear):\n        hf_linear.weight = pt_linear.weight\n        hf_linear.bias = pt_linear.bias\n\n    def _copy_layer(hf_layer, pt_layer):\n        # copy layer norms\n        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])\n        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])\n\n        # copy attn\n        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])\n\n        # copy MLP\n        pt_mlp = pt_layer[1][1]\n        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])\n        _copy_linear(hf_layer.fc2, pt_mlp.net[2])\n\n    def _copy_layers(hf_layers, pt_layers):\n        for i, hf_layer in enumerate(hf_layers):\n            if i != 0:\n                i += i\n            pt_layer = pt_layers[i : i + 2]\n            _copy_layer(hf_layer, pt_layer)\n\n    hf_model = LDMBertModel(config).eval()\n\n    # copy  embeds\n    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight\n    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight\n\n    # copy layer norm\n    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)\n\n    # copy hidden layers\n    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)\n\n    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)\n\n    return hf_model\n\n\ndef convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):\n    if text_encoder is None:\n        config_name = \"openai/clip-vit-large-patch14\"\n        try:\n            config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)\n        except Exception:\n    
        raise ValueError(\n                f\"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'.\"\n            )\n\n        ctx = init_empty_weights if is_accelerate_available() else nullcontext\n        with ctx():\n            text_model = CLIPTextModel(config)\n    else:\n        text_model = text_encoder\n\n    keys = list(checkpoint.keys())\n\n    text_model_dict = {}\n\n    remove_prefixes = [\"cond_stage_model.transformer\", \"conditioner.embedders.0.transformer\"]\n\n    for key in keys:\n        for prefix in remove_prefixes:\n            if key.startswith(prefix):\n                text_model_dict[key[len(prefix + \".\") :]] = checkpoint[key]\n\n    if is_accelerate_available():\n        for param_name, param in text_model_dict.items():\n            set_module_tensor_to_device(text_model, param_name, \"cpu\", value=param)\n    else:\n        if not (hasattr(text_model, \"embeddings\") and hasattr(text_model.embeddings, \"position_ids\")):\n            text_model_dict.pop(\"text_model.embeddings.position_ids\", None)\n\n        text_model.load_state_dict(text_model_dict)\n\n    return text_model\n\n\ntextenc_conversion_lst = [\n    (\"positional_embedding\", \"text_model.embeddings.position_embedding.weight\"),\n    (\"token_embedding.weight\", \"text_model.embeddings.token_embedding.weight\"),\n    (\"ln_final.weight\", \"text_model.final_layer_norm.weight\"),\n    (\"ln_final.bias\", \"text_model.final_layer_norm.bias\"),\n    (\"text_projection\", \"text_projection.weight\"),\n]\ntextenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}\n\ntextenc_transformer_conversion_lst = [\n    # (stable-diffusion, HF Diffusers)\n    (\"resblocks.\", \"text_model.encoder.layers.\"),\n    (\"ln_1\", \"layer_norm1\"),\n    (\"ln_2\", \"layer_norm2\"),\n    (\".c_fc.\", \".fc1.\"),\n    (\".c_proj.\", \".fc2.\"),\n    (\".attn\", \".self_attn\"),\n    
(\"ln_final.\", \"transformer.text_model.final_layer_norm.\"),\n    (\"token_embedding.weight\", \"transformer.text_model.embeddings.token_embedding.weight\"),\n    (\"positional_embedding\", \"transformer.text_model.embeddings.position_embedding.weight\"),\n]\nprotected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}\ntextenc_pattern = re.compile(\"|\".join(protected.keys()))\n\n\ndef convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):\n    config = CLIPVisionConfig.from_pretrained(\"openai/clip-vit-large-patch14\", local_files_only=local_files_only)\n    model = PaintByExampleImageEncoder(config)\n\n    keys = list(checkpoint.keys())\n\n    text_model_dict = {}\n\n    for key in keys:\n        if key.startswith(\"cond_stage_model.transformer\"):\n            text_model_dict[key[len(\"cond_stage_model.transformer.\") :]] = checkpoint[key]\n\n    # load clip vision\n    model.model.load_state_dict(text_model_dict)\n\n    # load mapper\n    keys_mapper = {\n        k[len(\"cond_stage_model.mapper.res\") :]: v\n        for k, v in checkpoint.items()\n        if k.startswith(\"cond_stage_model.mapper\")\n    }\n\n    MAPPING = {\n        \"attn.c_qkv\": [\"attn1.to_q\", \"attn1.to_k\", \"attn1.to_v\"],\n        \"attn.c_proj\": [\"attn1.to_out.0\"],\n        \"ln_1\": [\"norm1\"],\n        \"ln_2\": [\"norm3\"],\n        \"mlp.c_fc\": [\"ff.net.0.proj\"],\n        \"mlp.c_proj\": [\"ff.net.2\"],\n    }\n\n    mapped_weights = {}\n    for key, value in keys_mapper.items():\n        prefix = key[: len(\"blocks.i\")]\n        suffix = key.split(prefix)[-1].split(\".\")[-1]\n        name = key.split(prefix)[-1].split(suffix)[0][1:-1]\n        mapped_names = MAPPING[name]\n\n        num_splits = len(mapped_names)\n        for i, mapped_name in enumerate(mapped_names):\n            new_name = \".\".join([prefix, mapped_name, suffix])\n            shape = value.shape[0] // num_splits\n            mapped_weights[new_name] = value[i * 
shape : (i + 1) * shape]\n\n    model.mapper.load_state_dict(mapped_weights)\n\n    # load final layer norm\n    model.final_layer_norm.load_state_dict(\n        {\n            \"bias\": checkpoint[\"cond_stage_model.final_ln.bias\"],\n            \"weight\": checkpoint[\"cond_stage_model.final_ln.weight\"],\n        }\n    )\n\n    # load final proj\n    model.proj_out.load_state_dict(\n        {\n            \"bias\": checkpoint[\"proj_out.bias\"],\n            \"weight\": checkpoint[\"proj_out.weight\"],\n        }\n    )\n\n    # load uncond vector\n    model.uncond_vector.data = torch.nn.Parameter(checkpoint[\"learnable_vector\"])\n    return model\n\n\ndef convert_open_clip_checkpoint(\n    checkpoint,\n    config_name,\n    prefix=\"cond_stage_model.model.\",\n    has_projection=False,\n    local_files_only=False,\n    **config_kwargs,\n):\n    # text_model = CLIPTextModel.from_pretrained(\"stabilityai/stable-diffusion-2\", subfolder=\"text_encoder\")\n    # text_model = CLIPTextModelWithProjection.from_pretrained(\n    #    \"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\", projection_dim=1280\n    # )\n    try:\n        config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)\n    except Exception:\n        raise ValueError(\n            f\"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'.\"\n        )\n\n    ctx = init_empty_weights if is_accelerate_available() else nullcontext\n    with ctx():\n        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)\n\n    keys = list(checkpoint.keys())\n\n    keys_to_ignore = []\n    if config_name == \"stabilityai/stable-diffusion-2\" and config.num_hidden_layers == 23:\n        # make sure to remove all keys > 22\n        keys_to_ignore += [k for k in keys if k.startswith(\"cond_stage_model.model.transformer.resblocks.23\")]\n        
keys_to_ignore += [\"cond_stage_model.model.text_projection\"]\n\n    text_model_dict = {}\n\n    if prefix + \"text_projection\" in checkpoint:\n        d_model = int(checkpoint[prefix + \"text_projection\"].shape[0])\n    else:\n        d_model = 1024\n\n    text_model_dict[\"text_model.embeddings.position_ids\"] = text_model.text_model.embeddings.get_buffer(\"position_ids\")\n\n    for key in keys:\n        if key in keys_to_ignore:\n            continue\n        if key[len(prefix) :] in textenc_conversion_map:\n            if key.endswith(\"text_projection\"):\n                value = checkpoint[key].T.contiguous()\n            else:\n                value = checkpoint[key]\n\n            text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value\n\n        if key.startswith(prefix + \"transformer.\"):\n            new_key = key[len(prefix + \"transformer.\") :]\n            if new_key.endswith(\".in_proj_weight\"):\n                new_key = new_key[: -len(\".in_proj_weight\")]\n                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)\n                text_model_dict[new_key + \".q_proj.weight\"] = checkpoint[key][:d_model, :]\n                text_model_dict[new_key + \".k_proj.weight\"] = checkpoint[key][d_model : d_model * 2, :]\n                text_model_dict[new_key + \".v_proj.weight\"] = checkpoint[key][d_model * 2 :, :]\n            elif new_key.endswith(\".in_proj_bias\"):\n                new_key = new_key[: -len(\".in_proj_bias\")]\n                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)\n                text_model_dict[new_key + \".q_proj.bias\"] = checkpoint[key][:d_model]\n                text_model_dict[new_key + \".k_proj.bias\"] = checkpoint[key][d_model : d_model * 2]\n                text_model_dict[new_key + \".v_proj.bias\"] = checkpoint[key][d_model * 2 :]\n            else:\n                new_key = textenc_pattern.sub(lambda m: 
protected[re.escape(m.group(0))], new_key)\n\n                text_model_dict[new_key] = checkpoint[key]\n\n    if is_accelerate_available():\n        for param_name, param in text_model_dict.items():\n            set_module_tensor_to_device(text_model, param_name, \"cpu\", value=param)\n    else:\n        if not (hasattr(text_model, \"embeddings\") and hasattr(text_model.embeddings, \"position_ids\")):\n            text_model_dict.pop(\"text_model.embeddings.position_ids\", None)\n\n        text_model.load_state_dict(text_model_dict)\n\n    return text_model\n\n\ndef stable_unclip_image_encoder(original_config, local_files_only=False):\n    \"\"\"\n    Returns the image processor and clip image encoder for the img2img unclip pipeline.\n\n    We currently know of two types of stable unclip models which separately use the clip and the openclip image\n    encoders.\n    \"\"\"\n\n    image_embedder_config = original_config[\"model\"][\"params\"][\"embedder_config\"]\n\n    sd_clip_image_embedder_class = image_embedder_config[\"target\"]\n    sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(\".\")[-1]\n\n    if sd_clip_image_embedder_class == \"ClipImageEmbedder\":\n        clip_model_name = image_embedder_config[\"params\"][\"model\"]\n\n        if clip_model_name == \"ViT-L/14\":\n            feature_extractor = CLIPImageProcessor()\n            image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n                \"openai/clip-vit-large-patch14\", local_files_only=local_files_only\n            )\n        else:\n            raise NotImplementedError(f\"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}\")\n\n    elif sd_clip_image_embedder_class == \"FrozenOpenCLIPImageEmbedder\":\n        feature_extractor = CLIPImageProcessor()\n        image_encoder = CLIPVisionModelWithProjection.from_pretrained(\n            \"laion/CLIP-ViT-H-14-laion2B-s32B-b79K\", local_files_only=local_files_only\n        )\n    else:\n        raise 
NotImplementedError(\n            f\"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}\"\n        )\n\n    return feature_extractor, image_encoder\n\n\ndef stable_unclip_image_noising_components(\n    original_config, clip_stats_path: str | None = None, device: str | None = None\n):\n    \"\"\"\n    Returns the noising components for the img2img and txt2img unclip pipelines.\n\n    Converts the stability noise augmentor into\n    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats\n    2. a `DDPMScheduler` for holding the noise schedule\n\n    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.\n    \"\"\"\n    noise_aug_config = original_config[\"model\"][\"params\"][\"noise_aug_config\"]\n    noise_aug_class = noise_aug_config[\"target\"]\n    noise_aug_class = noise_aug_class.split(\".\")[-1]\n\n    if noise_aug_class == \"CLIPEmbeddingNoiseAugmentation\":\n        noise_aug_config = noise_aug_config.params\n        embedding_dim = noise_aug_config.timestep_dim\n        max_noise_level = noise_aug_config.noise_schedule_config.timesteps\n        beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule\n\n        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)\n        image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)\n\n        if \"clip_stats_path\" in noise_aug_config:\n            if clip_stats_path is None:\n                raise ValueError(\"This stable unclip config requires a `clip_stats_path`\")\n\n            clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)\n            clip_mean = clip_mean[None, :]\n            clip_std = clip_std[None, :]\n\n            clip_stats_state_dict = {\n                \"mean\": clip_mean,\n                \"std\": clip_std,\n            }\n\n            
image_normalizer.load_state_dict(clip_stats_state_dict)\n    else:\n        raise NotImplementedError(f\"Unknown noise augmentor class: {noise_aug_class}\")\n\n    return image_normalizer, image_noising_scheduler\n\n\ndef convert_controlnet_checkpoint(\n    checkpoint,\n    original_config,\n    checkpoint_path,\n    image_size,\n    upcast_attention,\n    extract_ema,\n    use_linear_projection=None,\n    cross_attention_dim=None,\n):\n    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)\n    ctrlnet_config[\"upcast_attention\"] = upcast_attention\n\n    ctrlnet_config.pop(\"sample_size\")\n\n    if use_linear_projection is not None:\n        ctrlnet_config[\"use_linear_projection\"] = use_linear_projection\n\n    if cross_attention_dim is not None:\n        ctrlnet_config[\"cross_attention_dim\"] = cross_attention_dim\n\n    ctx = init_empty_weights if is_accelerate_available() else nullcontext\n    with ctx():\n        controlnet = ControlNetModel(**ctrlnet_config)\n\n    # Some controlnet ckpt files are distributed independently from the rest of the\n    # model components i.e. 
https://huggingface.co/thibaud/controlnet-sd21/\n    if \"time_embed.0.weight\" in checkpoint:\n        skip_extract_state_dict = True\n    else:\n        skip_extract_state_dict = False\n\n    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(\n        checkpoint,\n        ctrlnet_config,\n        path=checkpoint_path,\n        extract_ema=extract_ema,\n        controlnet=True,\n        skip_extract_state_dict=skip_extract_state_dict,\n    )\n\n    if is_accelerate_available():\n        for param_name, param in converted_ctrl_checkpoint.items():\n            set_module_tensor_to_device(controlnet, param_name, \"cpu\", value=param)\n    else:\n        controlnet.load_state_dict(converted_ctrl_checkpoint)\n\n    return controlnet\n\n\ndef convert_promptdiffusion_checkpoint(\n    checkpoint,\n    original_config,\n    checkpoint_path,\n    image_size,\n    upcast_attention,\n    extract_ema,\n    use_linear_projection=None,\n    cross_attention_dim=None,\n):\n    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)\n    ctrlnet_config[\"upcast_attention\"] = upcast_attention\n\n    ctrlnet_config.pop(\"sample_size\")\n\n    if use_linear_projection is not None:\n        ctrlnet_config[\"use_linear_projection\"] = use_linear_projection\n\n    if cross_attention_dim is not None:\n        ctrlnet_config[\"cross_attention_dim\"] = cross_attention_dim\n\n    ctx = init_empty_weights if is_accelerate_available() else nullcontext\n    with ctx():\n        controlnet = PromptDiffusionControlNetModel(**ctrlnet_config)\n\n    # Some controlnet ckpt files are distributed independently from the rest of the\n    # model components i.e. 
https://huggingface.co/thibaud/controlnet-sd21/\n    if \"time_embed.0.weight\" in checkpoint:\n        skip_extract_state_dict = True\n    else:\n        skip_extract_state_dict = False\n\n    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(\n        checkpoint,\n        ctrlnet_config,\n        path=checkpoint_path,\n        extract_ema=extract_ema,\n        promptdiffusion=True,\n        controlnet=True,\n        skip_extract_state_dict=skip_extract_state_dict,\n    )\n\n    if is_accelerate_available():\n        for param_name, param in converted_ctrl_checkpoint.items():\n            set_module_tensor_to_device(controlnet, param_name, \"cpu\", value=param)\n    else:\n        controlnet.load_state_dict(converted_ctrl_checkpoint)\n\n    return controlnet\n\n\ndef download_from_original_stable_diffusion_ckpt(\n    checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],\n    original_config_file: str = None,\n    image_size: Optional[int] = None,\n    prediction_type: str = None,\n    model_type: str = None,\n    extract_ema: bool = False,\n    scheduler_type: str = \"pndm\",\n    num_in_channels: Optional[int] = None,\n    upcast_attention: Optional[bool] = None,\n    device: str = None,\n    from_safetensors: bool = False,\n    stable_unclip: str | None = None,\n    stable_unclip_prior: str | None = None,\n    clip_stats_path: str | None = None,\n    controlnet: Optional[bool] = None,\n    adapter: Optional[bool] = None,\n    load_safety_checker: bool = True,\n    pipeline_class: DiffusionPipeline = None,\n    local_files_only=False,\n    vae_path=None,\n    vae=None,\n    text_encoder=None,\n    text_encoder_2=None,\n    tokenizer=None,\n    tokenizer_2=None,\n    config_files=None,\n) -> DiffusionPipeline:\n    \"\"\"\n    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`\n    config file.\n\n    Although many of the arguments can be automatically inferred, some of these rely on 
brittle checks against the\n    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is\n    recommended that you override the default values and/or supply an `original_config_file` wherever possible.\n\n    Args:\n        checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.\n        original_config_file (`str`):\n            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically\n            inferred by looking for a key that only exists in SD2.0 models.\n        image_size (`int`, *optional*, defaults to 512):\n            The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2\n            Base. Use 768 for Stable Diffusion v2.\n        prediction_type (`str`, *optional*):\n            The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable\n            Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.\n        num_in_channels (`int`, *optional*, defaults to None):\n            The number of input channels. If `None`, it will be automatically inferred.\n        scheduler_type (`str`, *optional*, defaults to 'pndm'):\n            Type of scheduler to use. Should be one of `[\"pndm\", \"lms\", \"heun\", \"euler\", \"euler-ancestral\", \"dpm\",\n            \"ddim\"]`.\n        model_type (`str`, *optional*, defaults to `None`):\n            The pipeline type. `None` to automatically infer, or one of `[\"FrozenOpenCLIPEmbedder\",\n            \"FrozenCLIPEmbedder\", \"PaintByExample\"]`.\n        is_img2img (`bool`, *optional*, defaults to `False`):\n            Whether the model should be loaded as an img2img pipeline.\n        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for\n            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. 
Defaults to\n            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for\n            inference. Non-EMA weights are usually better for continued fine-tuning.\n        upcast_attention (`bool`, *optional*, defaults to `None`):\n            Whether the attention computation should always be upcast. This is necessary when running Stable\n            Diffusion 2.1.\n        device (`str`, *optional*, defaults to `None`):\n            The device to use. Pass `None` to determine automatically.\n        from_safetensors (`bool`, *optional*, defaults to `False`):\n            If `checkpoint_path_or_dict` is a path to a `safetensors` file, load it with safetensors instead of PyTorch.\n        load_safety_checker (`bool`, *optional*, defaults to `True`):\n            Whether to load the safety checker.\n        pipeline_class (`DiffusionPipeline`, *optional*, defaults to `None`):\n            The pipeline class to use. Pass `None` to determine automatically.\n        local_files_only (`bool`, *optional*, defaults to `False`):\n            Whether or not to only look at local files (i.e., do not try to download the model).\n        vae (`AutoencoderKL`, *optional*, defaults to `None`):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If\n            this parameter is `None`, the function loads the VAE from `vae_path` when it is given, and otherwise\n            converts the VAE weights found in the checkpoint.\n        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):\n            An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)\n            to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)\n            variant. 
If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.\n        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):\n            An instance of\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)\n            to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if\n            needed.\n        config_files (`Dict[str, str]`, *optional*, defaults to `None`):\n            A dictionary mapping from config file names to their contents. If this parameter is `None`, the function\n            will load the config files by itself, if needed. Valid keys are:\n                - `v1`: Config file for Stable Diffusion v1\n                - `v2`: Config file for Stable Diffusion v2\n                - `xl`: Config file for Stable Diffusion XL\n                - `xl_refiner`: Config file for Stable Diffusion XL Refiner\n        return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.\n    \"\"\"\n\n    # import pipelines here to avoid circular import error when using from_single_file method\n    from diffusers import (\n        LDMTextToImagePipeline,\n        PaintByExamplePipeline,\n        StableDiffusionControlNetPipeline,\n        StableDiffusionInpaintPipeline,\n        StableDiffusionPipeline,\n        StableDiffusionUpscalePipeline,\n        StableDiffusionXLControlNetInpaintPipeline,\n        StableDiffusionXLImg2ImgPipeline,\n        StableDiffusionXLInpaintPipeline,\n        StableDiffusionXLPipeline,\n        StableUnCLIPImg2ImgPipeline,\n        StableUnCLIPPipeline,\n    )\n\n    if prediction_type == \"v-prediction\":\n        prediction_type = \"v_prediction\"\n\n    if isinstance(checkpoint_path_or_dict, str):\n        if from_safetensors:\n            from safetensors.torch import load_file as safe_load\n\n            checkpoint = 
safe_load(checkpoint_path_or_dict, device=\"cpu\")\n        else:\n            if device is None:\n                device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)\n            else:\n                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)\n    elif isinstance(checkpoint_path_or_dict, dict):\n        checkpoint = checkpoint_path_or_dict\n\n    # Sometimes models don't have the global_step item\n    if \"global_step\" in checkpoint:\n        global_step = checkpoint[\"global_step\"]\n    else:\n        logger.debug(\"global_step key not found in model\")\n        global_step = None\n\n    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional\n    # \"state_dict\" key https://huggingface.co/thibaud/controlnet-canny-sd21\n    while \"state_dict\" in checkpoint:\n        checkpoint = checkpoint[\"state_dict\"]\n\n    if original_config_file is None:\n        key_name_v2_1 = \"model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight\"\n        key_name_sd_xl_base = \"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias\"\n        key_name_sd_xl_refiner = \"conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias\"\n        is_upscale = pipeline_class == StableDiffusionUpscalePipeline\n\n        config_url = None\n\n        # model_type = \"v1\"\n        if config_files is not None and \"v1\" in config_files:\n            original_config_file = config_files[\"v1\"]\n        else:\n            config_url = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\"\n\n        if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:\n            # model_type = \"v2\"\n            if config_files is not None and \"v2\" in config_files:\n                original_config_file = 
config_files[\"v2\"]\n            else:\n                config_url = \"https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml\"\n            if global_step == 110000:\n                # v2.1 needs to upcast attention\n                upcast_attention = True\n        elif key_name_sd_xl_base in checkpoint:\n            # only base xl has two text embedders\n            if config_files is not None and \"xl\" in config_files:\n                original_config_file = config_files[\"xl\"]\n            else:\n                config_url = \"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml\"\n        elif key_name_sd_xl_refiner in checkpoint:\n            # only refiner xl has embedder and one text embedders\n            if config_files is not None and \"xl_refiner\" in config_files:\n                original_config_file = config_files[\"xl_refiner\"]\n            else:\n                config_url = \"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml\"\n\n        if is_upscale:\n            config_url = \"https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml\"\n\n        if config_url is not None:\n            original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)\n        else:\n            with open(original_config_file, \"r\") as f:\n                original_config_file = f.read()\n\n    original_config = yaml.safe_load(original_config_file)\n\n    # Convert the text model.\n    if (\n        model_type is None\n        and \"cond_stage_config\" in original_config[\"model\"][\"params\"]\n        and original_config[\"model\"][\"params\"][\"cond_stage_config\"] is not None\n    ):\n        model_type = original_config[\"model\"][\"params\"][\"cond_stage_config\"][\"target\"].split(\".\")[-1]\n        
logger.debug(f\"no `model_type` given, `model_type` inferred as: {model_type}\")\n    elif model_type is None and original_config[\"model\"][\"params\"][\"network_config\"] is not None:\n        if original_config[\"model\"][\"params\"][\"network_config\"][\"params\"][\"context_dim\"] == 2048:\n            model_type = \"SDXL\"\n        else:\n            model_type = \"SDXL-Refiner\"\n        if image_size is None:\n            image_size = 1024\n\n    if pipeline_class is None:\n        # Check if we have a SDXL or SD model and initialize default pipeline\n        if model_type not in [\"SDXL\", \"SDXL-Refiner\"]:\n            pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline\n        else:\n            pipeline_class = StableDiffusionXLPipeline if model_type == \"SDXL\" else StableDiffusionXLImg2ImgPipeline\n\n    if num_in_channels is None and pipeline_class in [\n        StableDiffusionInpaintPipeline,\n        StableDiffusionXLInpaintPipeline,\n        StableDiffusionXLControlNetInpaintPipeline,\n    ]:\n        num_in_channels = 9\n    if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:\n        num_in_channels = 7\n    elif num_in_channels is None:\n        num_in_channels = 4\n\n    if \"unet_config\" in original_config[\"model\"][\"params\"]:\n        original_config[\"model\"][\"params\"][\"unet_config\"][\"params\"][\"in_channels\"] = num_in_channels\n\n    if (\n        \"parameterization\" in original_config[\"model\"][\"params\"]\n        and original_config[\"model\"][\"params\"][\"parameterization\"] == \"v\"\n    ):\n        if prediction_type is None:\n            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type==\"epsilon\"`\n            # as it relies on a brittle global step parameter here\n            prediction_type = \"epsilon\" if global_step == 875000 else \"v_prediction\"\n        if image_size is None:\n            # NOTE: For 
stable diffusion 2 base one has to pass `image_size==512`\n            # as it relies on a brittle global step parameter here\n            image_size = 512 if global_step == 875000 else 768\n    else:\n        if prediction_type is None:\n            prediction_type = \"epsilon\"\n        if image_size is None:\n            image_size = 512\n\n    if controlnet is None and \"control_stage_config\" in original_config[\"model\"][\"params\"]:\n        path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else \"\"\n        controlnet = convert_controlnet_checkpoint(\n            checkpoint, original_config, path, image_size, upcast_attention, extract_ema\n        )\n\n    if \"timesteps\" in original_config[\"model\"][\"params\"]:\n        num_train_timesteps = original_config[\"model\"][\"params\"][\"timesteps\"]\n    else:\n        num_train_timesteps = 1000\n\n    if model_type in [\"SDXL\", \"SDXL-Refiner\"]:\n        scheduler_dict = {\n            \"beta_schedule\": \"scaled_linear\",\n            \"beta_start\": 0.00085,\n            \"beta_end\": 0.012,\n            \"interpolation_type\": \"linear\",\n            \"num_train_timesteps\": num_train_timesteps,\n            \"prediction_type\": \"epsilon\",\n            \"sample_max_value\": 1.0,\n            \"set_alpha_to_one\": False,\n            \"skip_prk_steps\": True,\n            \"steps_offset\": 1,\n            \"timestep_spacing\": \"leading\",\n        }\n        scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)\n        scheduler_type = \"euler\"\n    else:\n        if \"linear_start\" in original_config[\"model\"][\"params\"]:\n            beta_start = original_config[\"model\"][\"params\"][\"linear_start\"]\n        else:\n            beta_start = 0.02\n\n        if \"linear_end\" in original_config[\"model\"][\"params\"]:\n            beta_end = original_config[\"model\"][\"params\"][\"linear_end\"]\n        else:\n            beta_end = 0.085\n        scheduler 
= DDIMScheduler(\n            beta_end=beta_end,\n            beta_schedule=\"scaled_linear\",\n            beta_start=beta_start,\n            num_train_timesteps=num_train_timesteps,\n            steps_offset=1,\n            clip_sample=False,\n            set_alpha_to_one=False,\n            prediction_type=prediction_type,\n        )\n    # make sure scheduler works correctly with DDIM\n    scheduler.register_to_config(clip_sample=False)\n\n    if scheduler_type == \"pndm\":\n        config = dict(scheduler.config)\n        config[\"skip_prk_steps\"] = True\n        scheduler = PNDMScheduler.from_config(config)\n    elif scheduler_type == \"lms\":\n        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)\n    elif scheduler_type == \"heun\":\n        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)\n    elif scheduler_type == \"euler\":\n        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)\n    elif scheduler_type == \"euler-ancestral\":\n        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)\n    elif scheduler_type == \"dpm\":\n        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)\n    elif scheduler_type == \"ddim\":\n        scheduler = scheduler\n    else:\n        raise ValueError(f\"Scheduler of type {scheduler_type} doesn't exist!\")\n\n    if pipeline_class == StableDiffusionUpscalePipeline:\n        image_size = original_config[\"model\"][\"params\"][\"unet_config\"][\"params\"][\"image_size\"]\n\n    # Convert the UNet2DConditionModel model.\n    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)\n    unet_config[\"upcast_attention\"] = upcast_attention\n\n    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else \"\"\n    converted_unet_checkpoint = convert_ldm_unet_checkpoint(\n        checkpoint, unet_config, path=path, extract_ema=extract_ema\n    )\n\n    ctx = init_empty_weights if 
is_accelerate_available() else nullcontext\n    with ctx():\n        unet = UNet2DConditionModel(**unet_config)\n\n    if is_accelerate_available():\n        if model_type not in [\"SDXL\", \"SDXL-Refiner\"]:  # SBM Delay this.\n            for param_name, param in converted_unet_checkpoint.items():\n                set_module_tensor_to_device(unet, param_name, \"cpu\", value=param)\n    else:\n        unet.load_state_dict(converted_unet_checkpoint)\n\n    # Convert the VAE model.\n    if vae_path is None and vae is None:\n        vae_config = create_vae_diffusers_config(original_config, image_size=image_size)\n        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)\n\n        if (\n            \"model\" in original_config\n            and \"params\" in original_config[\"model\"]\n            and \"scale_factor\" in original_config[\"model\"][\"params\"]\n        ):\n            vae_scaling_factor = original_config[\"model\"][\"params\"][\"scale_factor\"]\n        else:\n            vae_scaling_factor = 0.18215  # default SD scaling factor\n\n        vae_config[\"scaling_factor\"] = vae_scaling_factor\n\n        ctx = init_empty_weights if is_accelerate_available() else nullcontext\n        with ctx():\n            vae = AutoencoderKL(**vae_config)\n\n        if is_accelerate_available():\n            for param_name, param in converted_vae_checkpoint.items():\n                set_module_tensor_to_device(vae, param_name, \"cpu\", value=param)\n        else:\n            vae.load_state_dict(converted_vae_checkpoint)\n    elif vae is None:\n        vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only)\n\n    if model_type == \"FrozenOpenCLIPEmbedder\":\n        config_name = \"stabilityai/stable-diffusion-2\"\n        config_kwargs = {\"subfolder\": \"text_encoder\"}\n\n        if text_encoder is None:\n            text_model = convert_open_clip_checkpoint(\n                checkpoint, config_name, 
local_files_only=local_files_only, **config_kwargs\n            )\n        else:\n            text_model = text_encoder\n\n        try:\n            tokenizer = CLIPTokenizer.from_pretrained(\n                \"stabilityai/stable-diffusion-2\", subfolder=\"tokenizer\", local_files_only=local_files_only\n            )\n        except Exception:\n            raise ValueError(\n                f\"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'.\"\n            )\n\n        if stable_unclip is None:\n            if controlnet:\n                pipe = pipeline_class(\n                    vae=vae,\n                    text_encoder=text_model,\n                    tokenizer=tokenizer,\n                    unet=unet,\n                    scheduler=scheduler,\n                    controlnet=controlnet,\n                    safety_checker=None,\n                    feature_extractor=None,\n                )\n                if hasattr(pipe, \"requires_safety_checker\"):\n                    pipe.requires_safety_checker = False\n\n            elif pipeline_class == StableDiffusionUpscalePipeline:\n                scheduler = DDIMScheduler.from_pretrained(\n                    \"stabilityai/stable-diffusion-x4-upscaler\", subfolder=\"scheduler\"\n                )\n                low_res_scheduler = DDPMScheduler.from_pretrained(\n                    \"stabilityai/stable-diffusion-x4-upscaler\", subfolder=\"low_res_scheduler\"\n                )\n\n                pipe = pipeline_class(\n                    vae=vae,\n                    text_encoder=text_model,\n                    tokenizer=tokenizer,\n                    unet=unet,\n                    scheduler=scheduler,\n                    low_res_scheduler=low_res_scheduler,\n                    safety_checker=None,\n                    feature_extractor=None,\n                )\n\n            else:\n                pipe = 
pipeline_class(\n                    vae=vae,\n                    text_encoder=text_model,\n                    tokenizer=tokenizer,\n                    unet=unet,\n                    scheduler=scheduler,\n                    safety_checker=None,\n                    feature_extractor=None,\n                )\n                if hasattr(pipe, \"requires_safety_checker\"):\n                    pipe.requires_safety_checker = False\n\n        else:\n            image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(\n                original_config, clip_stats_path=clip_stats_path, device=device\n            )\n\n            if stable_unclip == \"img2img\":\n                feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)\n\n                pipe = StableUnCLIPImg2ImgPipeline(\n                    # image encoding components\n                    feature_extractor=feature_extractor,\n                    image_encoder=image_encoder,\n                    # image noising components\n                    image_normalizer=image_normalizer,\n                    image_noising_scheduler=image_noising_scheduler,\n                    # regular denoising components\n                    tokenizer=tokenizer,\n                    text_encoder=text_model,\n                    unet=unet,\n                    scheduler=scheduler,\n                    # vae\n                    vae=vae,\n                )\n            elif stable_unclip == \"txt2img\":\n                if stable_unclip_prior is None or stable_unclip_prior == \"karlo\":\n                    karlo_model = \"kakaobrain/karlo-v1-alpha\"\n                    prior = PriorTransformer.from_pretrained(\n                        karlo_model, subfolder=\"prior\", local_files_only=local_files_only\n                    )\n\n                    try:\n                        prior_tokenizer = CLIPTokenizer.from_pretrained(\n                            
\"openai/clip-vit-large-patch14\", local_files_only=local_files_only\n                        )\n                    except Exception:\n                        raise ValueError(\n                            f\"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'.\"\n                        )\n                    prior_text_model = CLIPTextModelWithProjection.from_pretrained(\n                        \"openai/clip-vit-large-patch14\", local_files_only=local_files_only\n                    )\n\n                    prior_scheduler = UnCLIPScheduler.from_pretrained(\n                        karlo_model, subfolder=\"prior_scheduler\", local_files_only=local_files_only\n                    )\n                    prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)\n                else:\n                    raise NotImplementedError(f\"unknown prior for stable unclip model: {stable_unclip_prior}\")\n\n                pipe = StableUnCLIPPipeline(\n                    # prior components\n                    prior_tokenizer=prior_tokenizer,\n                    prior_text_encoder=prior_text_model,\n                    prior=prior,\n                    prior_scheduler=prior_scheduler,\n                    # image noising components\n                    image_normalizer=image_normalizer,\n                    image_noising_scheduler=image_noising_scheduler,\n                    # regular denoising components\n                    tokenizer=tokenizer,\n                    text_encoder=text_model,\n                    unet=unet,\n                    scheduler=scheduler,\n                    # vae\n                    vae=vae,\n                )\n            else:\n                raise NotImplementedError(f\"unknown `stable_unclip` type: {stable_unclip}\")\n    elif model_type == \"PaintByExample\":\n        vision_model = convert_paint_by_example_checkpoint(checkpoint)\n   
     try:\n            tokenizer = CLIPTokenizer.from_pretrained(\n                \"openai/clip-vit-large-patch14\", local_files_only=local_files_only\n            )\n        except Exception:\n            raise ValueError(\n                f\"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'.\"\n            )\n        try:\n            feature_extractor = AutoFeatureExtractor.from_pretrained(\n                \"CompVis/stable-diffusion-safety-checker\", local_files_only=local_files_only\n            )\n        except Exception:\n            raise ValueError(\n                f\"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'.\"\n            )\n        pipe = PaintByExamplePipeline(\n            vae=vae,\n            image_encoder=vision_model,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=None,\n            feature_extractor=feature_extractor,\n        )\n    elif model_type == \"FrozenCLIPEmbedder\":\n        text_model = convert_ldm_clip_checkpoint(\n            checkpoint, local_files_only=local_files_only, text_encoder=text_encoder\n        )\n        try:\n            tokenizer = (\n                CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", local_files_only=local_files_only)\n                if tokenizer is None\n                else tokenizer\n            )\n        except Exception:\n            raise ValueError(\n                f\"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'.\"\n            )\n\n        if load_safety_checker:\n            safety_checker = StableDiffusionSafetyChecker.from_pretrained(\n                \"CompVis/stable-diffusion-safety-checker\", 
local_files_only=local_files_only\n            )\n            feature_extractor = AutoFeatureExtractor.from_pretrained(\n                \"CompVis/stable-diffusion-safety-checker\", local_files_only=local_files_only\n            )\n        else:\n            safety_checker = None\n            feature_extractor = None\n\n        if controlnet:\n            pipe = pipeline_class(\n                vae=vae,\n                text_encoder=text_model,\n                tokenizer=tokenizer,\n                unet=unet,\n                controlnet=controlnet,\n                scheduler=scheduler,\n                safety_checker=safety_checker,\n                feature_extractor=feature_extractor,\n            )\n        else:\n            pipe = pipeline_class(\n                vae=vae,\n                text_encoder=text_model,\n                tokenizer=tokenizer,\n                unet=unet,\n                scheduler=scheduler,\n                safety_checker=safety_checker,\n                feature_extractor=feature_extractor,\n            )\n    elif model_type in [\"SDXL\", \"SDXL-Refiner\"]:\n        is_refiner = model_type == \"SDXL-Refiner\"\n\n        if (is_refiner is False) and (tokenizer is None):\n            try:\n                tokenizer = CLIPTokenizer.from_pretrained(\n                    \"openai/clip-vit-large-patch14\", local_files_only=local_files_only\n                )\n            except Exception:\n                raise ValueError(\n                    f\"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'.\"\n                )\n\n        if (is_refiner is False) and (text_encoder is None):\n            text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)\n\n        if tokenizer_2 is None:\n            try:\n                tokenizer_2 = CLIPTokenizer.from_pretrained(\n                    
\"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\", pad_token=\"!\", local_files_only=local_files_only\n                )\n            except Exception:\n                raise ValueError(\n                    f\"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'.\"\n                )\n\n        if text_encoder_2 is None:\n            config_name = \"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k\"\n            config_kwargs = {\"projection_dim\": 1280}\n            prefix = \"conditioner.embedders.0.model.\" if is_refiner else \"conditioner.embedders.1.model.\"\n\n            text_encoder_2 = convert_open_clip_checkpoint(\n                checkpoint,\n                config_name,\n                prefix=prefix,\n                has_projection=True,\n                local_files_only=local_files_only,\n                **config_kwargs,\n            )\n\n        if is_accelerate_available():  # SBM Now move model to cpu.\n            for param_name, param in converted_unet_checkpoint.items():\n                set_module_tensor_to_device(unet, param_name, \"cpu\", value=param)\n\n        if controlnet:\n            pipe = pipeline_class(\n                vae=vae,\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                text_encoder_2=text_encoder_2,\n                tokenizer_2=tokenizer_2,\n                unet=unet,\n                controlnet=controlnet,\n                scheduler=scheduler,\n                force_zeros_for_empty_prompt=True,\n            )\n        elif adapter:\n            pipe = pipeline_class(\n                vae=vae,\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                text_encoder_2=text_encoder_2,\n                tokenizer_2=tokenizer_2,\n                unet=unet,\n                adapter=adapter,\n                
scheduler=scheduler,\n                force_zeros_for_empty_prompt=True,\n            )\n\n        else:\n            pipeline_kwargs = {\n                \"vae\": vae,\n                \"text_encoder\": text_encoder,\n                \"tokenizer\": tokenizer,\n                \"text_encoder_2\": text_encoder_2,\n                \"tokenizer_2\": tokenizer_2,\n                \"unet\": unet,\n                \"scheduler\": scheduler,\n            }\n\n            if (pipeline_class == StableDiffusionXLImg2ImgPipeline) or (\n                pipeline_class == StableDiffusionXLInpaintPipeline\n            ):\n                pipeline_kwargs.update({\"requires_aesthetics_score\": is_refiner})\n\n            if is_refiner:\n                pipeline_kwargs.update({\"force_zeros_for_empty_prompt\": False})\n\n            pipe = pipeline_class(**pipeline_kwargs)\n    else:\n        text_config = create_ldm_bert_config(original_config)\n        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)\n        tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\", local_files_only=local_files_only)\n        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)\n\n    return pipe\n\n\ndef download_controlnet_from_original_ckpt(\n    checkpoint_path: str,\n    original_config_file: str,\n    image_size: int = 512,\n    extract_ema: bool = False,\n    num_in_channels: Optional[int] = None,\n    upcast_attention: Optional[bool] = None,\n    device: str = None,\n    from_safetensors: bool = False,\n    use_linear_projection: Optional[bool] = None,\n    cross_attention_dim: Optional[int] = None,\n) -> DiffusionPipeline:\n    if from_safetensors:\n        from safetensors import safe_open\n\n        checkpoint = {}\n        with safe_open(checkpoint_path, framework=\"pt\", device=\"cpu\") as f:\n            for key in f.keys():\n                checkpoint[key] = f.get_tensor(key)\n    else:\n        if 
device is None:\n            device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n            checkpoint = torch.load(checkpoint_path, map_location=device)\n        else:\n            checkpoint = torch.load(checkpoint_path, map_location=device)\n\n    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional\n    # \"state_dict\" key https://huggingface.co/thibaud/controlnet-canny-sd21\n    while \"state_dict\" in checkpoint:\n        checkpoint = checkpoint[\"state_dict\"]\n\n    with open(original_config_file) as f:\n        original_config = yaml.safe_load(f)\n\n    if num_in_channels is not None:\n        original_config[\"model\"][\"params\"][\"unet_config\"][\"params\"][\"in_channels\"] = num_in_channels\n\n    if \"control_stage_config\" not in original_config[\"model\"][\"params\"]:\n        raise ValueError(\"`control_stage_config` not present in original config\")\n\n    controlnet = convert_controlnet_checkpoint(\n        checkpoint,\n        original_config,\n        checkpoint_path,\n        image_size,\n        upcast_attention,\n        extract_ema,\n        use_linear_projection=use_linear_projection,\n        cross_attention_dim=cross_attention_dim,\n    )\n\n    return controlnet\n\n\ndef download_promptdiffusion_from_original_ckpt(\n    checkpoint_path: str,\n    original_config_file: str,\n    image_size: int = 512,\n    extract_ema: bool = False,\n    num_in_channels: Optional[int] = None,\n    upcast_attention: Optional[bool] = None,\n    device: str = None,\n    from_safetensors: bool = False,\n    use_linear_projection: Optional[bool] = None,\n    cross_attention_dim: Optional[int] = None,\n) -> DiffusionPipeline:\n    if from_safetensors:\n        from safetensors import safe_open\n\n        checkpoint = {}\n        with safe_open(checkpoint_path, framework=\"pt\", device=\"cpu\") as f:\n            for key in f.keys():\n                checkpoint[key] = f.get_tensor(key)\n    else:\n        if device is None:\n            device 
= \"cuda\" if torch.cuda.is_available() else \"cpu\"\n            checkpoint = torch.load(checkpoint_path, map_location=device)\n        else:\n            checkpoint = torch.load(checkpoint_path, map_location=device)\n\n    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional\n    # \"state_dict\" key https://huggingface.co/thibaud/controlnet-canny-sd21\n    while \"state_dict\" in checkpoint:\n        checkpoint = checkpoint[\"state_dict\"]\n\n    original_config = yaml.safe_load(open(original_config_file))\n\n    if num_in_channels is not None:\n        original_config[\"model\"][\"params\"][\"unet_config\"][\"params\"][\"in_channels\"] = num_in_channels\n    if \"control_stage_config\" not in original_config[\"model\"][\"params\"]:\n        raise ValueError(\"`control_stage_config` not present in original config\")\n\n    controlnet = convert_promptdiffusion_checkpoint(\n        checkpoint,\n        original_config,\n        checkpoint_path,\n        image_size,\n        upcast_attention,\n        extract_ema,\n        use_linear_projection=use_linear_projection,\n        cross_attention_dim=cross_attention_dim,\n    )\n\n    return controlnet\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--checkpoint_path\", default=None, type=str, required=True, help=\"Path to the checkpoint to convert.\"\n    )\n    parser.add_argument(\n        \"--original_config_file\",\n        type=str,\n        required=True,\n        help=\"The YAML config file corresponding to the original architecture.\",\n    )\n    parser.add_argument(\n        \"--num_in_channels\",\n        default=None,\n        type=int,\n        help=\"The number of input channels. 
If `None` number of input channels will be automatically inferred.\",\n    )\n    parser.add_argument(\n        \"--image_size\",\n        default=512,\n        type=int,\n        help=(\n            \"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2\"\n            \" Base. Use 768 for Stable Diffusion v2.\"\n        ),\n    )\n    parser.add_argument(\n        \"--extract_ema\",\n        action=\"store_true\",\n        help=(\n            \"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights\"\n            \" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield\"\n            \" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_attention\",\n        action=\"store_true\",\n        help=(\n            \"Whether the attention computation should always be upcasted. This is necessary when running stable\"\n            \" diffusion 2.1.\"\n        ),\n    )\n    parser.add_argument(\n        \"--from_safetensors\",\n        action=\"store_true\",\n        help=\"If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.\",\n    )\n    parser.add_argument(\n        \"--to_safetensors\",\n        action=\"store_true\",\n        help=\"Whether to store pipeline in safetensors format or not.\",\n    )\n    parser.add_argument(\"--dump_path\", default=None, type=str, required=True, help=\"Path to the output model.\")\n    parser.add_argument(\"--device\", type=str, help=\"Device to use (e.g. 
cpu, cuda:0, cuda:1, etc.)\")\n\n    # small workaround to get argparser to parse a boolean input as either true _or_ false\n    def parse_bool(string):\n        if string == \"True\":\n            return True\n        elif string == \"False\":\n            return False\n        else:\n            raise ValueError(f\"could not parse string as bool {string}\")\n\n    parser.add_argument(\n        \"--use_linear_projection\", help=\"Override for use linear projection\", required=False, type=parse_bool\n    )\n\n    parser.add_argument(\"--cross_attention_dim\", help=\"Override for cross attention_dim\", required=False, type=int)\n\n    args = parser.parse_args()\n\n    controlnet = download_promptdiffusion_from_original_ckpt(\n        checkpoint_path=args.checkpoint_path,\n        original_config_file=args.original_config_file,\n        image_size=args.image_size,\n        extract_ema=args.extract_ema,\n        num_in_channels=args.num_in_channels,\n        upcast_attention=args.upcast_attention,\n        from_safetensors=args.from_safetensors,\n        device=args.device,\n        use_linear_projection=args.use_linear_projection,\n        cross_attention_dim=args.cross_attention_dim,\n    )\n\n    controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)\n"
  },
  {
    "path": "examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# Based on [In-Context Learning Unlocked for Diffusion Models](https://huggingface.co/papers/2305.01115)\n# Authors: Zhendong Wang, Yifan Jiang, Yadong Lu, Yelong Shen, Pengcheng He, Weizhu Chen, Zhangyang Wang, Mingyuan Zhou\n# Project Page: https://zhendong-wang.github.io/prompt-diffusion.github.io/\n# Code: https://github.com/Zhendong-Wang/Prompt-Diffusion\n#\n# Adapted to Diffusers by [iczaw](https://github.com/iczaw).\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import 
StableDiffusionSafetyChecker\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install opencv-python transformers accelerate\n        >>> from promptdiffusioncontrolnet import PromptDiffusionControlNetModel\n        >>> from diffusers.utils import load_image\n        >>> import torch\n\n        >>> from diffusers.pipelines.pipeline_utils import DiffusionPipeline\n        >>> from diffusers import UniPCMultistepScheduler\n        >>> from PIL import ImageOps\n\n        >>> # download an image\n        >>> image_a = ImageOps.invert(load_image(\n        ...     \"https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house_line.png?raw=true\"\n        ... ))\n\n        >>> # download an image\n        >>> image_b = load_image(\n        ...     \"https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house.png?raw=true\"\n        ... )\n\n        >>> # download an image\n        >>> query = ImageOps.invert(load_image(\n        ...     \"https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/new_01.png?raw=true\"\n        ... 
))\n\n        >>> # load prompt diffusion control net and prompt diffusion\n        >>> controlnet = PromptDiffusionControlNetModel.from_pretrained(\"path-to-converted-promptdiffusion-controlnet\", torch_dtype=torch.float16)\n        >>> pipe = DiffusionPipeline.from_pretrained(model_id, controlnet=controlnet, torch_dtype=torch.float16, variant=\"fp16\", custom_pipeline=\"pipeline_prompt_diffusion\")\n\n        >>> # speed up diffusion process with faster scheduler and memory optimization\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>> # remove following line if xformers is not installed\n        >>> pipe.enable_xformers_memory_efficient_attention()\n\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> # generate image\n        >>> generator = torch.manual_seed(0)\n        >>> image = pipe(\n        ...     \"a tortoise\", num_inference_steps=20, generator=generator, image_pair=[image_a,image_b], image=query\n        ... ).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used,\n            `timesteps` must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default\n                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`\n                must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass PromptDiffusionPipeline(\n    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin\n):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.\n\n    This pipeline also adds experimental support for [Prompt Diffusion](https://huggingface.co/papers/2305.01115).\n\n    This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the `unet` during the denoising process. If you set multiple\n            ControlNets as a list, the outputs from each ControlNet are added together to create one combined\n            additional conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. 
Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)\n        self.control_image_processor = VaeImageProcessor(\n            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False\n        )\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. 
When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.vae.enable_slicing()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_slicing()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        depr_message = f\"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`.\"\n        deprecate(\n            \"enable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.enable_tiling()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        depr_message = f\"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. 
Please use `pipe.vae.disable_tiling()`.\"\n        deprecate(\n            \"disable_vae_tiling\",\n            \"0.40.0\",\n            depr_message,\n        )\n        self.vae.disable_tiling()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        
negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. 
A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because 
CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. 
The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n            # Retrieve the original scale by scaling back the LoRA layers\n            unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return 
prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if 
accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        image_pair,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        controlnet_conditioning_scale=1.0,\n        control_guidance_start=0.0,\n        control_guidance_end=1.0,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. 
The conditionings will be fixed across the prompts.\"\n                )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in image):\n                raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n            elif len(image) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets.\"\n                )\n\n            for image_ in image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `image_pair`\n        if len(image_pair) == 2:\n            for image in image_pair:\n                if (\n                    isinstance(self.controlnet, ControlNetModel)\n                    or is_compiled\n                    and isinstance(self.controlnet._orig_mod, ControlNetModel)\n                ):\n                    self.check_image(image, prompt, prompt_embeds)\n        
else:\n            raise ValueError(\n                f\"You have passed a list of images of length {len(image_pair)}. Make sure the list contains exactly two images.\"\n            )\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings is not supported at the moment.\")\n                # the length check must live inside the `list` branch, otherwise it is unreachable\n                elif len(controlnet_conditioning_scale) != len(self.controlnet.nets):\n                    raise ValueError(\n                        \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`,\"\n                        \" it must have the same length as the number of controlnets\"\n                    )\n        else:\n            assert False\n\n        if not isinstance(control_guidance_start, (tuple, list)):\n            control_guidance_start = [control_guidance_start]\n\n        if not isinstance(control_guidance_end, (tuple, list)):\n            control_guidance_end = [control_guidance_end]\n\n        if len(control_guidance_start) != len(control_guidance_end):\n            raise ValueError(\n                f\"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has 
{len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list.\"\n            )\n\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if len(control_guidance_start) != len(self.controlnet.nets):\n                raise ValueError(\n                    f\"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}.\"\n                )\n\n        for start, end in zip(control_guidance_start, control_guidance_end):\n            if start >= end:\n                raise ValueError(\n                    f\"control guidance start: {start} cannot be larger or equal to control guidance end: {end}.\"\n                )\n            if start < 0.0:\n                raise ValueError(f\"control guidance start: {start} can't be smaller than 0.\")\n            if end > 1.0:\n                raise ValueError(f\"control guidance end: {end} can't be larger than 1.0.\")\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image\n    def check_image(self, image, prompt, prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_np = isinstance(image, np.ndarray)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)\n\n        if (\n            not image_is_pil\n            and not image_is_tensor\n            and not image_is_np\n            and not image_is_pil_list\n            and not image_is_tensor_list\n            and not image_is_np_list\n        ):\n            raise TypeError(\n                
f\"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        else:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image\n    def prepare_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu\n    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):\n        r\"\"\"Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.\n\n        The suffixes after the scaling factors represent the stages where they are being applied.\n\n        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values\n        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.\n\n        Args:\n            s1 (`float`):\n                Scaling factor for stage 1 to attenuate the contributions of the skip features. 
This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            s2 (`float`):\n                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to\n                mitigate \"oversmoothing effect\" in the enhanced denoising process.\n            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.\n            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.\n        \"\"\"\n        if not hasattr(self, \"unet\"):\n            raise ValueError(\"The pipeline must have `unet` for using FreeU.\")\n        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu\n    def disable_freeu(self):\n        \"\"\"Disables the FreeU mechanism if enabled.\"\"\"\n        self.unet.disable_freeu()\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            timesteps (`torch.Tensor`):\n                generate embedding vectors at these timesteps\n            embedding_dim (`int`, *optional*, defaults to 512):\n                dimension of the embeddings to generate\n            dtype:\n                data type of the generated embeddings\n\n        Returns:\n            `torch.Tensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = 
torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        image_pair: List[PipelineImageInput] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n       
 output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        control_guidance_start: Union[float, List[float]] = 0.0,\n        control_guidance_end: Union[float, List[float]] = 1.0,\n        clip_skip: Optional[int] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:\n                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):\n                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is\n                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be\n                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height\n                and/or width are passed, `image` is resized accordingly. 
If multiple ControlNets are specified in\n                `init`, images must be passed as a list such that each element of the list can be correctly batched for\n                input to a single ControlNet.\n            image_pair (`List[PIL.Image.Image]`):\n                A pair of task-specific example images.\n            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. 
Ignored when not using guidance (`guidance_scale <= 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor is generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image (`PipelineImageInput`, *optional*):\n                Optional image input to work with IP Adapters.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. 
Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function called every `callback_steps` steps during inference. The function is called with the\n                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function is called. If not specified, the callback is called at\n                every step.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added\n                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set\n                the corresponding scale as a list.\n            guess_mode (`bool`, *optional*, defaults to `False`):\n                The ControlNet encoder tries to recognize the content of the input image even if you remove all\n                prompts. 
A `guidance_scale` value between 3.0 and 5.0 is recommended.\n            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):\n                The percentage of total steps at which the ControlNet starts applying.\n            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):\n                The percentage of total steps at which the ControlNet stops applying.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, *optional*):\n                A function called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`\",\n            )\n\n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # align format for control guidance\n        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):\n            control_guidance_start = len(control_guidance_end) * [control_guidance_start]\n        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):\n            control_guidance_end = len(control_guidance_start) * [control_guidance_end]\n        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):\n       
     mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1\n            control_guidance_start, control_guidance_end = (\n                mult * [control_guidance_start],\n                mult * [control_guidance_end],\n            )\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            image,\n            image_pair,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            controlnet_conditioning_scale,\n            control_guidance_start,\n            control_guidance_end,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):\n            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)\n\n        global_pool_conditions = (\n            controlnet.config.global_pool_conditions\n            if isinstance(controlnet, ControlNetModel)\n            else controlnet.nets[0].config.global_pool_conditions\n        )\n        guess_mode = guess_mode or global_pool_conditions\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        # 3.1 Prepare image pair\n\n        if isinstance(controlnet, ControlNetModel):\n            image_pair = torch.cat(\n                [\n                    self.prepare_image(\n                        image=im,\n                        width=width,\n                        height=height,\n                        batch_size=batch_size * num_images_per_prompt,\n                        num_images_per_prompt=num_images_per_prompt,\n                        device=device,\n                        dtype=controlnet.dtype,\n                        do_classifier_free_guidance=self.do_classifier_free_guidance,\n                        guess_mode=guess_mode,\n                    )\n                    for im in image_pair\n                ],\n                1,\n            )\n        # 4. 
Prepare image\n        if isinstance(controlnet, ControlNetModel):\n            image = self.prepare_image(\n                image=image,\n                width=width,\n                height=height,\n                batch_size=batch_size * num_images_per_prompt,\n                num_images_per_prompt=num_images_per_prompt,\n                device=device,\n                dtype=controlnet.dtype,\n                do_classifier_free_guidance=self.do_classifier_free_guidance,\n                guess_mode=guess_mode,\n            )\n            height, width = image.shape[-2:]\n        elif isinstance(controlnet, MultiControlNetModel):\n            images = []\n\n            for image_ in image:\n                image_ = self.prepare_image(\n                    image=image_,\n                    width=width,\n                    height=height,\n                    batch_size=batch_size * num_images_per_prompt,\n                    num_images_per_prompt=num_images_per_prompt,\n                    device=device,\n                    dtype=controlnet.dtype,\n                    do_classifier_free_guidance=self.do_classifier_free_guidance,\n                    guess_mode=guess_mode,\n                )\n\n                images.append(image_)\n\n            image = images\n            height, width = image[0].shape[-2:]\n        else:\n            assert False\n\n        # 5. Prepare timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)\n        self._num_timesteps = len(timesteps)\n\n        # 6. 
Prepare latent variables\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6.5 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.2 Create tensor stating which controlnets to keep\n        controlnet_keep = []\n        for i in range(len(timesteps)):\n            keeps = [\n                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)\n                for s, e in zip(control_guidance_start, control_guidance_end)\n            ]\n            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        is_unet_compiled = is_compiled_module(self.unet)\n        is_controlnet_compiled = is_compiled_module(self.controlnet)\n        is_torch_higher_equal_2_1 = is_torch_version(\">=\", \"2.1\")\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # Relevant thread:\n                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428\n                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:\n                    torch._inductor.cudagraph_mark_step_begin()\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # controlnet(s) inference\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Infer ControlNet only for the conditional batch.\n                    control_model_input = latents\n                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)\n                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]\n                else:\n                    control_model_input = latent_model_input\n                    controlnet_prompt_embeds = prompt_embeds\n\n                if isinstance(controlnet_keep[i], list):\n                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]\n                else:\n                    controlnet_cond_scale = controlnet_conditioning_scale\n                    if isinstance(controlnet_cond_scale, list):\n                        controlnet_cond_scale = controlnet_cond_scale[0]\n                    cond_scale = 
controlnet_cond_scale * controlnet_keep[i]\n\n                down_block_res_samples, mid_block_res_sample = self.controlnet(\n                    control_model_input,\n                    t,\n                    encoder_hidden_states=controlnet_prompt_embeds,\n                    controlnet_query_cond=image,\n                    controlnet_cond=image_pair,\n                    conditioning_scale=cond_scale,\n                    guess_mode=guess_mode,\n                    return_dict=False,\n                )\n\n                if guess_mode and self.do_classifier_free_guidance:\n                    # Inferred ControlNet only for the conditional batch.\n                    # To apply the output of ControlNet to both the unconditional and conditional batches,\n                    # add 0 to the unconditional batch to keep it unchanged.\n                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]\n                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    down_block_additional_residuals=down_block_res_samples,\n                    mid_block_additional_residual=mid_block_res_sample,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        # If we do sequential model offloading, let's offload unet and controlnet\n        # manually for max memory savings\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.unet.to(\"cpu\")\n            self.controlnet.to(\"cpu\")\n            torch.cuda.empty_cache()\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in 
has_nsfw_concept]\n\n        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport torch\n\nfrom diffusers.configuration_utils import register_to_config\nfrom diffusers.models.controlnet import (\n    ControlNetConditioningEmbedding,\n    ControlNetModel,\n    ControlNetOutput,\n)\nfrom diffusers.utils import logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass PromptDiffusionControlNetModel(ControlNetModel):\n    \"\"\"\n    A PromptDiffusionControlNet model.\n\n    Args:\n        in_channels (`int`, defaults to 4):\n            The number of channels in the input sample.\n        flip_sin_to_cos (`bool`, defaults to `True`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, defaults to 0):\n            The frequency shift to apply to the time embedding.\n        down_block_types (`tuple[str]`, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):\n        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, defaults to 2):\n            The number of layers per block.\n       
 downsample_padding (`int`, defaults to 1):\n            The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, defaults to 1):\n            The scale factor to use for the mid block.\n        act_fn (`str`, defaults to \"silu\"):\n            The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32):\n            The number of groups to use for the normalization. If None, normalization and activation layers is skipped\n            in post-processing.\n        norm_eps (`float`, defaults to 1e-5):\n            The epsilon to use for the normalization.\n        cross_attention_dim (`int`, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):\n            The dimension of the attention heads.\n        use_linear_projection (`bool`, defaults to `False`):\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. 
Choose from None,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". \"text\" will use the `TextTimeEmbedding` layer.\n        num_class_embeds (`int`, *optional*, defaults to 0):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        upcast_attention (`bool`, defaults to `False`):\n        resnet_time_scale_shift (`str`, defaults to `\"default\"`):\n            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.\n        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):\n            The dimension of the `class_labels` input when `class_embed_type=\"projection\"`. Required when\n            `class_embed_type=\"projection\"`.\n        controlnet_conditioning_channel_order (`str`, defaults to `\"rgb\"`):\n            The channel order of conditional image. 
Will convert to `rgb` if it's `bgr`.\n        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):\n            The tuple of output channel for each block in the `conditioning_embedding` layer.\n        global_pool_conditions (`bool`, defaults to `False`):\n            TODO(Patrick) - unused parameter.\n        addition_embed_type_num_heads (`int`, defaults to 64):\n            The number of heads to use for the `TextTimeEmbedding` layer.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        in_channels: int = 4,\n        conditioning_channels: int = 3,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str, ...] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: str | None = \"UNetMidBlock2DCrossAttn\",\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280),\n        layers_per_block: int = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: int = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: str | None = None,\n        attention_head_dim: Union[int, Tuple[int, ...]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,\n        use_linear_projection: bool = False,\n        class_embed_type: str | None = None,\n        addition_embed_type: str | None = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        controlnet_conditioning_channel_order: str = \"rgb\",\n        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),\n        global_pool_conditions: bool = False,\n        addition_embed_type_num_heads: int = 64,\n    ):\n        super().__init__(\n            in_channels,\n            conditioning_channels,\n            flip_sin_to_cos,\n            freq_shift,\n            down_block_types,\n            mid_block_type,\n            only_cross_attention,\n            block_out_channels,\n            layers_per_block,\n            downsample_padding,\n            mid_block_scale_factor,\n            act_fn,\n            norm_num_groups,\n            norm_eps,\n            cross_attention_dim,\n            transformer_layers_per_block,\n            encoder_hid_dim,\n            encoder_hid_dim_type,\n            attention_head_dim,\n            num_attention_heads,\n            use_linear_projection,\n            
class_embed_type,\n            addition_embed_type,\n            addition_time_embed_dim,\n            num_class_embeds,\n            upcast_attention,\n            resnet_time_scale_shift,\n            projection_class_embeddings_input_dim,\n            controlnet_conditioning_channel_order,\n            conditioning_embedding_out_channels,\n            global_pool_conditions,\n            addition_embed_type_num_heads,\n        )\n        self.controlnet_query_cond_embedding = ControlNetConditioningEmbedding(\n            conditioning_embedding_channels=block_out_channels[0],\n            block_out_channels=conditioning_embedding_out_channels,\n            conditioning_channels=3,\n        )\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        controlnet_cond: torch.Tensor,\n        controlnet_query_cond: torch.Tensor,\n        conditioning_scale: float = 1.0,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guess_mode: bool = False,\n        return_dict: bool = True,\n    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:\n        \"\"\"\n        The [`~PromptDiffusionControlNetModel`] forward method.\n\n        Args:\n            sample (`torch.Tensor`):\n                The noisy input tensor.\n            timestep (`Union[torch.Tensor, float, int]`):\n                The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.Tensor`):\n                The encoder hidden states.\n            controlnet_cond (`torch.Tensor`):\n                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.\n        
    controlnet_query_cond (`torch.Tensor`):\n                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.\n            conditioning_scale (`float`, defaults to `1.0`):\n                The scale factor for ControlNet outputs.\n            class_labels (`torch.Tensor`, *optional*, defaults to `None`):\n                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.\n            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):\n                Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the\n                timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep\n                embeddings.\n            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):\n                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask\n                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large\n                negative values to the attention scores corresponding to \"discard\" tokens.\n            added_cond_kwargs (`dict`):\n                Additional conditions for the Stable Diffusion XL UNet.\n            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):\n                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.\n            guess_mode (`bool`, defaults to `False`):\n                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if\n                you remove all prompts. 
A `guidance_scale` between 3.0 and 5.0 is recommended.\n            return_dict (`bool`, defaults to `True`):\n                Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.controlnets.controlnet.ControlNetOutput`] **or** `tuple`:\n                If `return_dict` is `True`, a [`~models.controlnets.controlnet.ControlNetOutput`] is returned, otherwise a tuple is\n                returned where the first element is the sample tensor.\n        \"\"\"\n        # check channel order\n        channel_order = self.config.controlnet_conditioning_channel_order\n\n        if channel_order == \"rgb\":\n            # in rgb order by default\n            ...\n        elif channel_order == \"bgr\":\n            controlnet_cond = torch.flip(controlnet_cond, dims=[1])\n        else:\n            raise ValueError(f\"unknown `controlnet_conditioning_channel_order`: {channel_order}\")\n\n        # prepare attention_mask\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            is_npu = sample.device.type == \"npu\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if (is_mps or is_npu) else torch.float64\n            else:\n                dtype = torch.int32 if (is_mps or is_npu) else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)\n            emb = emb + class_emb\n\n        if self.config.addition_embed_type is not None:\n            if self.config.addition_embed_type == \"text\":\n                aug_emb = self.add_embedding(encoder_hidden_states)\n\n            elif self.config.addition_embed_type == \"text_time\":\n                if \"text_embeds\" not in added_cond_kwargs:\n                    raise ValueError(\n                      
  f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                    )\n                text_embeds = added_cond_kwargs.get(\"text_embeds\")\n                if \"time_ids\" not in added_cond_kwargs:\n                    raise ValueError(\n                        f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                    )\n                time_ids = added_cond_kwargs.get(\"time_ids\")\n                time_embeds = self.add_time_proj(time_ids.flatten())\n                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n                add_embeds = add_embeds.to(emb.dtype)\n                aug_emb = self.add_embedding(add_embeds)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)\n        controlnet_query_cond = self.controlnet_query_cond_embedding(controlnet_query_cond)\n        sample = sample + controlnet_cond + controlnet_query_cond\n\n        # 3. 
down\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n            down_block_res_samples += res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            if hasattr(self.mid_block, \"has_cross_attention\") and self.mid_block.has_cross_attention:\n                sample = self.mid_block(\n                    sample,\n                    emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                )\n            else:\n                sample = self.mid_block(sample, emb)\n\n        # 5. Control net blocks\n\n        controlnet_down_block_res_samples = ()\n\n        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):\n            down_block_res_sample = controlnet_block(down_block_res_sample)\n            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)\n\n        down_block_res_samples = controlnet_down_block_res_samples\n\n        mid_block_res_sample = self.controlnet_mid_block(sample)\n\n        # 6. 
scaling\n        if guess_mode and not self.config.global_pool_conditions:\n            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0\n            scales = scales * conditioning_scale\n            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]\n            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one\n        else:\n            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]\n            mid_block_res_sample = mid_block_res_sample * conditioning_scale\n\n        if self.config.global_pool_conditions:\n            down_block_res_samples = [\n                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples\n            ]\n            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)\n\n        if not return_dict:\n            return (down_block_res_samples, mid_block_res_sample)\n\n        return ControlNetOutput(\n            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample\n        )\n"
  },
  {
    "path": "examples/research_projects/pytorch_xla/inference/flux/README.md",
    "content": "# Generating images using Flux and PyTorch/XLA\n\nThe `flux_inference` script shows how to generate images with Flux on TPU devices using PyTorch/XLA. For faster generation, it uses the Pallas flash attention kernel with custom block sizes tuned for [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPUs. No other TPU types have been tested.\n\n## Create TPU\n\nTo create a TPU on Google Cloud, follow [this guide](https://cloud.google.com/tpu/docs/v6e).\n\n## Setup TPU environment\n\nSSH into the VM and install PyTorch and PyTorch/XLA:\n\n```bash\npip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html\npip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html\n```\n\nVerify that PyTorch and PyTorch/XLA were installed correctly:\n\n```bash\npython3 -c \"import torch; import torch_xla;\"\n```\n\nClone the diffusers repo and install the dependencies:\n\n```bash\ngit clone https://github.com/huggingface/diffusers.git\ncd diffusers\npip install transformers accelerate sentencepiece structlog\npip install .\ncd examples/research_projects/pytorch_xla/inference/flux/\n```\n\n## Run the inference job\n\n### Authenticate\n\n**Gated Model**\n\nThe model is gated, so before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form, and accept the gate. Once you have access, log in so that your system knows you’ve accepted the gate. Use the command below to log in:\n\n```bash\nhf auth login\n```\n\nThen run:\n\n```bash\npython flux_inference.py\n```\n\nThe script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. 
The first time the script runs, compilation takes longer because the compiled programs are being stored in the cache. On subsequent runs, compilation is much faster, and the inference passes that follow it are the fastest.\n\nOn a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel):\n\n```bash\nWARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.\nLoading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.06it/s]\nLoading pipeline components...:  60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                | 3/5 [00:00<00:00,  6.80it/s]You set `add_prefix_space`. 
The tokenizer needs to be converted from the slow tokenizers\nLoading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.28it/s]\n2025-03-14 21:17:53 [info     ] loading flux from black-forest-labs/FLUX.1-dev\n2025-03-14 21:17:53 [info     ] loading flux from black-forest-labs/FLUX.1-dev\nLoading pipeline components...:   0%|                                                                                                                                                                                                                                                        | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:53 [info     ] loading flux from black-forest-labs/FLUX.1-dev\n2025-03-14 21:17:53 [info     ] loading flux from black-forest-labs/FLUX.1-dev\nLoading pipeline components...:   0%|                                                                                                                                                                                                                                                        | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:54 [info     ] loading flux from black-forest-labs/FLUX.1-dev\n2025-03-14 21:17:54 [info     ] loading flux from black-forest-labs/FLUX.1-dev\nLoading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.66it/s]\nLoading pipeline components...: 
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  4.48it/s]\nLoading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.32it/s]\nLoading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.69it/s]\nLoading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.74it/s]\nLoading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.10it/s]\n2025-03-14 21:17:56 [info     ] loading flux from black-forest-labs/FLUX.1-dev\nLoading pipeline components...:   0%|                                                                                                                                                                                                                                                        | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:56 [info     ] loading flux from black-forest-labs/FLUX.1-dev\nLoading pipeline 
components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.55it/s]\nLoading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.46it/s]\n2025-03-14 21:18:34 [info     ] starting compilation run...   \n2025-03-14 21:18:37 [info     ] starting compilation run...   \n2025-03-14 21:18:38 [info     ] starting compilation run...   \n2025-03-14 21:18:39 [info     ] starting compilation run...   \n2025-03-14 21:18:41 [info     ] starting compilation run...   \n2025-03-14 21:18:41 [info     ] starting compilation run...   \n2025-03-14 21:18:42 [info     ] starting compilation run...   \n2025-03-14 21:18:43 [info     ] starting compilation run...   
\n 82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]\n2025-03-14 21:36:38 [info     ] compilation took 1079.3314765350078 sec.\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]\n2025-03-14 21:36:38 [info     ] starting inference run...     
\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]\n2025-03-14 21:36:38 [info     ] compilation took 1081.89390801001 sec.\n2025-03-14 21:36:39 [info     ] starting inference run...     \n2025-03-14 21:36:39 [info     ] compilation took 1077.1543154849933 sec.\n2025-03-14 21:36:39 [info     ] compilation took 1075.7239800530078 sec.\n2025-03-14 21:36:39 [info     ] starting inference run...     \n2025-03-14 21:36:40 [info     ] starting inference run...     \n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]\n2025-03-14 21:36:50 [info     ] compilation took 1088.1632604240003 sec.\n2025-03-14 21:36:50 [info     ] starting inference run...     \n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]\n2025-03-14 21:36:55 [info     ] compilation took 1096.8027802760043 sec.\n2025-03-14 21:36:56 [info     ] starting inference run...     
\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]\n2025-03-14 21:37:08 [info     ] compilation took 1113.8591305939917 sec.\n2025-03-14 21:37:08 [info     ] starting inference run...     \n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]\n2025-03-14 21:37:22 [info     ] compilation took 1120.5590810020076 sec.\n2025-03-14 21:37:22 [info     ] starting inference run...     \n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.00it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00,  2.98it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  
4.08it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00,  2.82it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00,  3.34it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  4.22it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  4.09it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00,  2.41it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  
4.50it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.10it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.27it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  4.80it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.39it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.39it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.67it/s]\n 29%|█████████████████████████████████████████████████████████████████████████████▍                                                    
                                                                                                                                             | 8/28 [00:01<00:03,  6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast\n  images = (images * 255).round().astype(\"uint8\")\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.82it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.93it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.02it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.02it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.98it/s]\n 
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                             | 20/28 [00:03<00:01,  6.03it/s]2025-03-14 21:38:32 [info     ] inference time: 5.962021178987925\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.89it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.09it/s]\n2025-03-14 21:38:32 [info     ] avg. inference over 5 iterations took 7.2685392687970305 sec.\n2025-03-14 21:38:32 [info     ] avg. 
inference over 5 iterations took 7.402720856998348 sec.\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.01it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.89it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.96it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.06it/s]\n 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                             | 20/28 [00:03<00:01,  6.01it/s]2025-03-14 21:38:38 [info     ] inference time: 5.950578948002658\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  
5.87it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.09it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.00it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.86it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.99it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.05it/s]\n2025-03-14 21:38:43 [info     ] avg. 
inference over 5 iterations took 6.763298449796276 sec.\n 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                             | 20/28 [00:03<00:01,  6.04it/s]2025-03-14 21:38:44 [info     ] inference time: 5.949129879008979\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.92it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.10it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.02it/s]\n 39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                                                                    | 11/28 [00:01<00:02,  5.98it/s]2025-03-14 21:38:46 [info     ] avg. 
inference over 5 iterations took 7.221068455604836 sec.\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.96it/s]\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.08it/s]\n 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                   | 26/28 [00:04<00:00,  5.92it/s]2025-03-14 21:38:50 [info     ] inference time: 5.954778069004533\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.90it/s]\n 11%|█████████████████████████████                                                                                                                                                                                                                                                  | 3/28 [00:00<00:04,  6.03it/s]2025-03-14 21:38:50 [info     ] avg. 
inference over 5 iterations took 6.05970350120042 sec.\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.02it/s]\n 32%|███████████████████████████████████████████████████████████████████████████████████████                                                                                                                                                                                        | 9/28 [00:01<00:03,  5.99it/s]2025-03-14 21:38:51 [info     ] avg. inference over 5 iterations took 6.018543455796316 sec.\n 54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                             | 15/28 [00:02<00:02,  6.00it/s]2025-03-14 21:38:52 [info     ] avg. inference over 5 iterations took 5.9609976705978625 sec.\n100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.97it/s]\n2025-03-14 21:38:56 [info     ] inference time: 5.944058528999449\n2025-03-14 21:38:56 [info     ] avg. inference over 5 iterations took 5.952113320800708 sec.\n2025-03-14 21:38:56 [info     ] saved metric information as /tmp/metrics_report.txt\n```"
  },
  {
    "path": "examples/research_projects/pytorch_xla/inference/flux/flux_inference.py",
    "content": "from argparse import ArgumentParser\nfrom pathlib import Path\nfrom time import perf_counter\n\nimport structlog\nimport torch\nimport torch_xla.core.xla_model as xm\nimport torch_xla.debug.metrics as met\nimport torch_xla.debug.profiler as xp\nimport torch_xla.distributed.xla_multiprocessing as xmp\nimport torch_xla.runtime as xr\nfrom torch_xla.experimental.custom_kernel import FlashAttention\n\nfrom diffusers import FluxPipeline\n\n\nlogger = structlog.get_logger()\nmetrics_filepath = \"/tmp/metrics_report.txt\"\n\n\ndef _main(index, args, text_pipe, ckpt_id):\n    cache_path = Path(\"/tmp/data/compiler_cache_tRiLlium_eXp\")\n    cache_path.mkdir(parents=True, exist_ok=True)\n    xr.initialize_cache(str(cache_path), readonly=False)\n\n    profile_path = Path(\"/tmp/data/profiler_out_tRiLlium_eXp\")\n    profile_path.mkdir(parents=True, exist_ok=True)\n    profiler_port = 9012\n    profile_duration = args.profile_duration\n    if args.profile:\n        logger.info(f\"starting profiler on port {profiler_port}\")\n        _ = xp.start_server(profiler_port)\n    device0 = xm.xla_device()\n\n    logger.info(f\"loading flux from {ckpt_id}\")\n    flux_pipe = FluxPipeline.from_pretrained(\n        ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16\n    ).to(device0)\n    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=(\"data\", None, None, None), is_flux=True)\n    FlashAttention.DEFAULT_BLOCK_SIZES = {\n        \"block_q\": 1536,\n        \"block_k_major\": 1536,\n        \"block_k\": 1536,\n        \"block_b\": 1536,\n        \"block_q_major_dkv\": 1536,\n        \"block_k_major_dkv\": 1536,\n        \"block_q_dkv\": 1536,\n        \"block_k_dkv\": 1536,\n        \"block_q_dq\": 1536,\n        \"block_k_dq\": 1536,\n        \"block_k_major_dq\": 1536,\n    }\n\n    prompt = \"photograph of an electronics chip in the shape of a race car with trillium written on its side\"\n    
width = args.width\n    height = args.height\n    guidance = args.guidance\n    n_steps = 4 if args.schnell else 28\n\n    logger.info(\"starting compilation run...\")\n    ts = perf_counter()\n    with torch.no_grad():\n        prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(\n            prompt=prompt, prompt_2=None, max_sequence_length=512\n        )\n    prompt_embeds = prompt_embeds.to(device0)\n    pooled_prompt_embeds = pooled_prompt_embeds.to(device0)\n\n    image = flux_pipe(\n        prompt_embeds=prompt_embeds,\n        pooled_prompt_embeds=pooled_prompt_embeds,\n        num_inference_steps=28,\n        guidance_scale=guidance,\n        height=height,\n        width=width,\n    ).images[0]\n    logger.info(f\"compilation took {perf_counter() - ts} sec.\")\n    image.save(\"/tmp/compile_out.png\")\n\n    base_seed = 4096 if args.seed is None else args.seed\n    seed_range = 1000\n    unique_seed = base_seed + index * seed_range\n    xm.set_rng_state(seed=unique_seed, device=device0)\n    times = []\n    logger.info(\"starting inference run...\")\n    with torch.no_grad():\n        prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(\n            prompt=prompt, prompt_2=None, max_sequence_length=512\n        )\n    prompt_embeds = prompt_embeds.to(device0)\n    pooled_prompt_embeds = pooled_prompt_embeds.to(device0)\n    for _ in range(args.itters):\n        ts = perf_counter()\n\n        if args.profile:\n            xp.trace_detached(f\"localhost:{profiler_port}\", str(profile_path), duration_ms=profile_duration)\n        image = flux_pipe(\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            num_inference_steps=n_steps,\n            guidance_scale=guidance,\n            height=height,\n            width=width,\n        ).images[0]\n        inference_time = perf_counter() - ts\n        if index == 0:\n            logger.info(f\"inference time: 
{inference_time}\")\n        times.append(inference_time)\n    logger.info(f\"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.\")\n    image.save(f\"/tmp/inference_out-{index}.png\")\n    if index == 0:\n        metrics_report = met.metrics_report()\n        with open(metrics_filepath, \"w+\") as fout:\n            fout.write(metrics_report)\n        logger.info(f\"saved metric information as {metrics_filepath}\")\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--schnell\", action=\"store_true\", help=\"run flux schnell instead of dev\")\n    parser.add_argument(\"--width\", type=int, default=1024, help=\"width of the image to generate\")\n    parser.add_argument(\"--height\", type=int, default=1024, help=\"height of the image to generate\")\n    parser.add_argument(\"--guidance\", type=float, default=3.5, help=\"guidance strength for dev\")\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"seed for inference\")\n    parser.add_argument(\"--profile\", action=\"store_true\", help=\"enable profiling\")\n    parser.add_argument(\"--profile-duration\", type=int, default=10000, help=\"duration for profiling in msec.\")\n    parser.add_argument(\"--itters\", type=int, default=15, help=\"number of inference iterations to run when computing the avg time in sec.\")\n    args = parser.parse_args()\n    if args.schnell:\n        ckpt_id = \"black-forest-labs/FLUX.1-schnell\"\n    else:\n        ckpt_id = \"black-forest-labs/FLUX.1-dev\"\n    text_pipe = FluxPipeline.from_pretrained(ckpt_id, transformer=None, vae=None, torch_dtype=torch.bfloat16).to(\"cpu\")\n    xmp.spawn(_main, args=(args, text_pipe, ckpt_id))\n"
  },
  {
    "path": "examples/research_projects/pytorch_xla/training/text_to_image/README.md",
    "content": "# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA\n\nThe `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.\n\nIt has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host. \n\nThis script implements Distributed Data Parallel using GSPMD feature in XLA compiler\nwhere we shard the input batches over the TPU devices. \n\nAs of 10-31-2024, these are some expected step times.\n\n| accelerator | global batch size | step time (seconds) |\n| ----------- | ----------------- | --------- |\n| v5p-512 | 16384 | 1.01 |\n| v5p-256 | 8192 | 1.01 |\n| v5p-128 | 4096 | 1.0 |\n| v5p-64 | 2048 | 1.01 |\n\n## Create TPU\n\nTo create a TPU on Google Cloud first set these environment variables:\n\n```bash\nexport TPU_NAME=<tpu-name>\nexport PROJECT_ID=<project-id>\nexport ZONE=<google-cloud-zone>\nexport ACCELERATOR_TYPE=<accelerator type like v5p-8>\nexport RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p>\n```\n\nThen run the create TPU command:\n```bash\ngcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID} \n--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION} \n--reserved\n```\n\nYou can also use other ways to reserve TPUs like GKE or queued resources.\n\n## Setup TPU environment\n\nInstall PyTorch and PyTorch/XLA nightly versions:\n```bash\ngcloud compute tpus tpu-vm ssh ${TPU_NAME} \\\n--project=${PROJECT_ID} --zone=${ZONE} --worker=all \\\n--command='\npip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu\npip3 install \"torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl\" -f https://storage.googleapis.com/libtpu-releases/index.html\npip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f 
https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html\n'\n```\n\nVerify that PyTorch and PyTorch/XLA were installed correctly:\n\n```bash\ngcloud compute tpus tpu-vm ssh ${TPU_NAME} \\\n--project ${PROJECT_ID} --zone ${ZONE} --worker=all \\\n--command='python3 -c \"import torch; import torch_xla;\"'\n```\n\nInstall dependencies:\n```bash\ngcloud compute tpus tpu-vm ssh ${TPU_NAME} \\\n--project=${PROJECT_ID} --zone=${ZONE} --worker=all \\\n--command='\ngit clone https://github.com/huggingface/diffusers.git\ncd diffusers\ngit checkout main\ncd examples/research_projects/pytorch_xla/training/text_to_image\npip3 install -r requirements.txt\npip3 install pillow --upgrade\ncd ../../../../..\npip3 install .'\n```\n\n## Run the training job\n\n### Authenticate\n\nRun the following command to authenticate your token.\n\n```bash\nhf auth login\n```\n\nThis script only trains the UNet part of the network. The VAE and text encoder\nare frozen.\n\n```bash\ngcloud compute tpus tpu-vm ssh ${TPU_NAME} \\\n--project=${PROJECT_ID} --zone=${ZONE} --worker=all \\\n--command='\nexport XLA_DISABLE_FUNCTIONALIZATION=0\nexport PROFILE_DIR=/tmp/\nexport CACHE_DIR=/tmp/\nexport DATASET_NAME=lambdalabs/naruto-blip-captions\nexport PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p\nexport TRAIN_STEPS=50\nexport OUTPUT_DIR=/tmp/trained-model/\npython diffusers/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE  --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'\n```\n\nPass `--print_loss` if you would like to see the loss printed at every step. 
Be aware that printing the loss at every step disrupts the optimized execution flow, so the step time will be longer. \n\n### Environment Variables Explained\n\n*   `XLA_DISABLE_FUNCTIONALIZATION`: Optimizes performance for the AdamW optimizer.\n*   `PROFILE_DIR`: Directory where the profiling results are written.\n*   `CACHE_DIR`: Directory to store XLA compiled graphs for persistent caching.\n*   `DATASET_NAME`: Dataset used to train the model. \n*   `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For example, on a v5p-16 with 2 CPU hosts, the global batch size will be 2 x `PER_HOST_BATCH_SIZE`. The input batch is sharded along the batch axis.\n*   `TRAIN_STEPS`: Total number of training steps to run.\n*   `OUTPUT_DIR`: Directory to store the fine-tuned model.\n\n## Run inference using the output model\n\nTo run inference using the output, you can simply load the model and pass it\ninput prompts. The first pass compiles the graph and takes longer; subsequent passes run much faster.\n\n```bash\nexport CACHE_DIR=/tmp/\n```\n\n```python\nimport os\nfrom time import time\n\nimport torch\nimport torch_xla.core.xla_model as xm\nimport torch_xla.runtime as xr\n\nfrom diffusers import StableDiffusionPipeline\n\nCACHE_DIR = os.environ.get(\"CACHE_DIR\", None)\nif CACHE_DIR:\n    xr.initialize_cache(CACHE_DIR, readonly=False)\n\ndef main():\n    device = xm.xla_device()\n    model_path = \"jffacevedo/pxla_trained_model\"\n    pipe = StableDiffusionPipeline.from_pretrained(\n        model_path, \n        torch_dtype=torch.bfloat16\n    )\n    pipe.to(device)\n    prompt = [\"A naruto with green eyes and red legs.\"]\n    start = time()\n    print(\"compiling...\")\n    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]\n    print(f\"compile time: {time() - start}\")\n    print(\"generate...\")\n    start = time()\n    image = pipe(prompt, num_inference_steps=30, 
guidance_scale=7.5).images[0]\n    print(f\"generation time (after compile) : {time() - start}\")\n    image.save(\"naruto.png\")\n\nif __name__ == '__main__':\n    main()\n```\n\nExpected Results:\n\n```bash\ncompiling...\n100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [10:03<00:00, 20.10s/it]\ncompile time: 720.656970500946\ngenerate...\n100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.65it/s]\ngeneration time (after compile) : 1.8461642265319824"
  },
  {
    "path": "examples/research_projects/pytorch_xla/training/text_to_image/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\ndatasets>=2.19.1\nftfy\ntensorboard\nJinja2\npeft==0.7.0\n"
  },
  {
    "path": "examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py",
    "content": "import argparse\nimport os\nimport random\nimport time\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torch_xla.core.xla_model as xm\nimport torch_xla.debug.profiler as xp\nimport torch_xla.distributed.parallel_loader as pl\nimport torch_xla.distributed.spmd as xs\nimport torch_xla.runtime as xr\nfrom huggingface_hub import create_repo, upload_folder\nfrom torchvision import transforms\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.training_utils import compute_snr\nfrom diffusers.utils import is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\n\n\nif is_wandb_available():\n    pass\n\nPROFILE_DIR = os.environ.get(\"PROFILE_DIR\", None)\nCACHE_DIR = os.environ.get(\"CACHE_DIR\", None)\nif CACHE_DIR:\n    xr.initialize_cache(CACHE_DIR, readonly=False)\nxr.use_spmd()\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\nPORT = 9012\n\n\ndef save_model_card(\n    args,\n    repo_id: str,\n    repo_folder: str = None,\n):\n    model_description = f\"\"\"\n# Text-to-image finetuning - {repo_id}\n\nThis pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. 
\\n\n\n## Pipeline usage\n\nYou can use the pipeline like so:\n\n```python\nimport torch\nimport torch_xla.core.xla_model as xm\n\nfrom diffusers import StableDiffusionPipeline\n\ndef main():\n    device = xm.xla_device()\n    model_path = \"<output_dir>\"  # replace with the path to the fine-tuned model\n    pipe = StableDiffusionPipeline.from_pretrained(\n        model_path,\n        torch_dtype=torch.bfloat16\n    )\n    pipe.to(device)\n    prompt = [\"A naruto with green eyes and red legs.\"]\n    image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]\n    image.save(\"naruto.png\")\n\nif __name__ == '__main__':\n    main()\n```\n\n## Training info\n\nThese are the key hyperparameters used during training:\n\n* Steps: {args.max_train_steps}\n* Learning rate: {args.learning_rate}\n* Batch size: {args.train_batch_size}\n* Image resolution: {args.resolution}\n* Mixed-precision: {args.mixed_precision}\n\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=args.pretrained_model_name_or_path,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\"stable-diffusion\", \"stable-diffusion-diffusers\", \"text-to-image\", \"diffusers\", \"diffusers-training\"]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\nclass TrainSD:\n    def __init__(\n        self,\n        vae,\n        weight_dtype,\n        device,\n        noise_scheduler,\n        unet,\n        optimizer,\n        text_encoder,\n        dataloader,\n        args,\n    ):\n        self.vae = vae\n        self.weight_dtype = weight_dtype\n        self.device = device\n        self.noise_scheduler = noise_scheduler\n        self.unet = unet\n        self.optimizer = optimizer\n        self.text_encoder 
= text_encoder\n        self.args = args\n        self.mesh = xs.get_global_mesh()\n        self.dataloader = iter(dataloader)\n        self.global_step = 0\n\n    def run_optimizer(self):\n        self.optimizer.step()\n\n    def start_training(self):\n        dataloader_exception = False\n        measure_start_step = self.args.measure_start_step\n        assert measure_start_step < self.args.max_train_steps\n        total_time = 0\n        for step in range(0, self.args.max_train_steps):\n            try:\n                batch = next(self.dataloader)\n            except Exception as e:\n                dataloader_exception = True\n                print(e)\n                break\n            if step == measure_start_step:\n                # Start timing here even when profiling is disabled, so `last_time` is always defined.\n                xm.wait_device_ops()\n                if PROFILE_DIR is not None:\n                    xp.trace_detached(f\"localhost:{PORT}\", PROFILE_DIR, duration_ms=self.args.profile_duration)\n                last_time = time.time()\n            loss = self.step_fn(batch[\"pixel_values\"], batch[\"input_ids\"])\n            self.global_step += 1\n\n            def print_loss_closure(step, loss):\n                print(f\"Step: {step}, Loss: {loss}\")\n\n            if self.args.print_loss:\n                xm.add_step_closure(\n                    print_loss_closure,\n                    args=(\n                        self.global_step,\n                        loss,\n                    ),\n                )\n        xm.mark_step()\n        if not dataloader_exception:\n            xm.wait_device_ops()\n            total_time = time.time() - last_time\n            print(f\"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}\")\n        else:\n            print(\"dataloader exception happened, skipping result\")\n            return\n\n    def step_fn(\n        self,\n        pixel_values,\n        input_ids,\n    ):\n        with xp.Trace(\"model.forward\"):\n            self.optimizer.zero_grad()\n            latents = 
self.vae.encode(pixel_values).latent_dist.sample()\n            latents = latents * self.vae.config.scaling_factor\n            noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)\n            bsz = latents.shape[0]\n            timesteps = torch.randint(\n                0,\n                self.noise_scheduler.config.num_train_timesteps,\n                (bsz,),\n                device=latents.device,\n            )\n            timesteps = timesteps.long()\n\n            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)\n            encoder_hidden_states = self.text_encoder(input_ids, return_dict=False)[0]\n            if self.args.prediction_type is not None:\n                # set prediction_type of scheduler if defined\n                self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)\n\n            if self.noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif self.noise_scheduler.config.prediction_type == \"v_prediction\":\n                target = self.noise_scheduler.get_velocity(latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {self.noise_scheduler.config.prediction_type}\")\n            model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]\n        with xp.Trace(\"model.backward\"):\n            if self.args.snr_gamma is None:\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n            else:\n                # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                # This is discussed in Section 4.2 of the same paper.\n                snr = compute_snr(self.noise_scheduler, timesteps)\n                mse_loss_weights = 
torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                    dim=1\n                )[0]\n                if self.noise_scheduler.config.prediction_type == \"epsilon\":\n                    mse_loss_weights = mse_loss_weights / snr\n                elif self.noise_scheduler.config.prediction_type == \"v_prediction\":\n                    mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                loss = loss.mean()\n            loss.backward()\n        with xp.Trace(\"optimizer_step\"):\n            self.run_optimizer()\n        return loss\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\"--profile_duration\", type=int, default=10000, help=\"Profile duration in ms\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--non_ema_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or\"\n            \" remote repository specified with --pretrained_model_name_or_path.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--loader_prefetch_size\",\n        type=int,\n        default=1,\n        help=(\"Maximum number of batches queued on the host (CPU) by the thread reading from the dataloader.\"),\n    )\n    parser.add_argument(\n        \"--loader_prefetch_factor\",\n        type=int,\n        default=2,\n        help=(\"Number of batches loaded in advance by each worker.\"),\n    )\n    parser.add_argument(\n        \"--device_prefetch_size\",\n        type=int,\n        default=1,\n        help=(\"Maximum number of batches queued per device after being transferred from the host to the TPU.\"),\n    )\n    parser.add_argument(\n        \"--measure_start_step\", type=int, default=10, help=\"Step at which profiling and step-time measurement start.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"bf16\"],\n        help=(\"Whether to use mixed precision. 
Bf16 requires PyTorch >= 1.10\"),\n    )\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--print_loss\",\n        default=False,\n        action=\"store_true\",\n        help=(\"Print loss at every step.\"),\n    )\n\n    args = parser.parse_args()\n\n    # default to using the same revision for the non-ema model if not specified\n    if args.non_ema_revision is None:\n        args.non_ema_revision = args.revision\n\n    return args\n\n\ndef setup_optimizer(unet, args):\n    optimizer_cls = torch.optim.AdamW\n    return optimizer_cls(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n        foreach=True,\n    )\n\n\ndef load_dataset(args):\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = datasets.load_dataset(\n            args.dataset_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = datasets.load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n    return dataset\n\n\ndef get_column_names(dataset, args):\n    column_names = dataset[\"train\"].column_names\n\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if 
args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    return image_column, caption_column\n\n\ndef main(args):\n    args = parse_args()\n\n    _ = xp.start_server(PORT)\n\n    num_devices = xr.global_runtime_device_count()\n    mesh = xs.get_1d_mesh(\"data\")\n    xs.set_global_mesh(mesh)\n\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"unet\",\n        revision=args.non_ema_revision,\n    )\n\n    if xm.is_master_ordinal() and args.push_to_hub:\n        repo_id = create_repo(\n            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n        ).repo_id\n\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        
args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear\n\n    unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)\n    unet.enable_xla_flash_attention(partition_spec=(\"data\", None, None, None))\n\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    unet.train()\n\n    # For mixed precision training we cast all non-trainable weights (vae,\n    # non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full\n    # precision is not required.\n    weight_dtype = torch.float32\n    if args.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    device = xm.xla_device()\n\n    # Move text_encoder, vae and unet to device and cast to weight_dtype\n    text_encoder = text_encoder.to(device, dtype=weight_dtype)\n    vae = vae.to(device, dtype=weight_dtype)\n    unet = unet.to(device, dtype=weight_dtype)\n    optimizer = setup_optimizer(unet, args)\n\n    dataset = load_dataset(args)\n    image_column, caption_column = get_column_names(dataset, args)\n\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions,\n            
max_length=tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        return inputs.input_ids\n\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            (transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)),\n            (transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"input_ids\"] = tokenize_captions(examples)\n        return examples\n\n    train_dataset = dataset[\"train\"]\n    train_dataset.set_format(\"torch\")\n    train_dataset.set_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype)\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n\n    g = torch.Generator()\n    g.manual_seed(xr.host_index())\n    sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g)\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        sampler=sampler,\n        collate_fn=collate_fn,\n        num_workers=args.dataloader_num_workers,\n        batch_size=args.train_batch_size,\n        prefetch_factor=args.loader_prefetch_factor,\n    )\n\n    train_dataloader = pl.MpDeviceLoader(\n        
train_dataloader,\n        device,\n        input_sharding={\n            \"pixel_values\": xs.ShardingSpec(mesh, (\"data\", None, None, None), minibatch=True),\n            \"input_ids\": xs.ShardingSpec(mesh, (\"data\", None), minibatch=True),\n        },\n        loader_prefetch_size=args.loader_prefetch_size,\n        device_prefetch_size=args.device_prefetch_size,\n    )\n\n    num_hosts = xr.process_count()\n    num_devices_per_host = num_devices // num_hosts\n    if xm.is_master_ordinal():\n        print(\"***** Running training *****\")\n        print(f\"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}\")\n        print(\n            f\"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}\"\n        )\n        print(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    trainer = TrainSD(\n        vae=vae,\n        weight_dtype=weight_dtype,\n        device=device,\n        noise_scheduler=noise_scheduler,\n        unet=unet,\n        optimizer=optimizer,\n        text_encoder=text_encoder,\n        dataloader=train_dataloader,\n        args=args,\n    )\n\n    trainer.start_training()\n    unet = trainer.unet.to(\"cpu\")\n    vae = trainer.vae.to(\"cpu\")\n    text_encoder = trainer.text_encoder.to(\"cpu\")\n\n    pipeline = StableDiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        text_encoder=text_encoder,\n        vae=vae,\n        unet=unet,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    pipeline.save_pretrained(args.output_dir)\n\n    if xm.is_master_ordinal() and args.push_to_hub:\n        save_model_card(args, repo_id, repo_folder=args.output_dir)\n        upload_folder(\n            repo_id=repo_id,\n            folder_path=args.output_dir,\n            commit_message=\"End of training\",\n            ignore_patterns=[\"step_*\", \"epoch_*\"],\n        )\n\n\nif __name__ == 
\"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/rdm/README.md",
    "content": "## Diffusers examples with ONNXRuntime optimizations\n\n**This research project is not actively maintained by the diffusers team. For any questions or comments, please contact Isamu Isozaki(isamu-isozaki) on github with any questions.**\n\nThe aim of this project is to provide retrieval augmented diffusion models to diffusers!"
  },
  {
    "path": "examples/research_projects/rdm/pipeline_rdm.py",
    "content": "import inspect\nfrom typing import Callable, List, Optional, Union\n\nimport torch\nfrom PIL import Image\nfrom retriever import Retriever, normalize_images, preprocess_images\nfrom transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDIMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    EulerAncestralDiscreteScheduler,\n    EulerDiscreteScheduler,\n    ImagePipelineOutput,\n    LMSDiscreteScheduler,\n    PNDMScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.pipelines.pipeline_utils import StableDiffusionMixin\nfrom diffusers.utils import logging\nfrom diffusers.utils.torch_utils import randn_tensor\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass RDMPipeline(DiffusionPipeline, StableDiffusionMixin):\n    r\"\"\"\n    Pipeline for text-to-image generation using Retrieval Augmented Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        clip ([`CLIPModel`]):\n            Frozen CLIP model. 
Retrieval Augmented Diffusion uses the CLIP model, specifically the\n            [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        clip: CLIPModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: Union[\n            DDIMScheduler,\n            PNDMScheduler,\n            LMSDiscreteScheduler,\n            EulerDiscreteScheduler,\n            EulerAncestralDiscreteScheduler,\n            DPMSolverMultistepScheduler,\n        ],\n        feature_extractor: CLIPImageProcessor,\n        retriever: Optional[Retriever] = None,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            clip=clip,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n        )\n        # Copy from statement here and all the methods we take from stable_diffusion_pipeline\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, \"vae\", None) else 8\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.retriever 
= retriever\n\n    def _encode_prompt(self, prompt):\n        # get prompt text embeddings\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n\n        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:\n            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n            )\n            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]\n        prompt_embeds = self.clip.get_text_features(text_input_ids.to(self.device))\n        prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)\n        prompt_embeds = prompt_embeds[:, None, :]\n        return prompt_embeds\n\n    def _encode_image(self, retrieved_images, batch_size):\n        if len(retrieved_images[0]) == 0:\n            return None\n        for i in range(len(retrieved_images)):\n            retrieved_images[i] = normalize_images(retrieved_images[i])\n            retrieved_images[i] = preprocess_images(retrieved_images[i], self.feature_extractor).to(\n                self.clip.device, dtype=self.clip.dtype\n            )\n        _, c, h, w = retrieved_images[0].shape\n\n        retrieved_images = torch.reshape(torch.cat(retrieved_images, dim=0), (-1, c, h, w))\n        image_embeddings = self.clip.get_image_features(retrieved_images)\n        image_embeddings = image_embeddings / torch.linalg.norm(image_embeddings, dim=-1, keepdim=True)\n        _, d = image_embeddings.shape\n        image_embeddings = torch.reshape(image_embeddings, (batch_size, -1, 
d))\n        return image_embeddings\n\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            int(height) // self.vae_scale_factor,\n            int(width) // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def retrieve_images(self, retrieved_images, prompt_embeds, knn=10):\n        if self.retriever is not None:\n            additional_images = self.retriever.retrieve_imgs_batch(prompt_embeds[:, 0].cpu(), knn).total_examples\n            for i in range(len(retrieved_images)):\n                retrieved_images[i] += additional_images[i][self.retriever.config.image_column]\n        return retrieved_images\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]],\n        retrieved_images: Optional[List[Image.Image]] = None,\n        height: int = 768,\n        width: int = 768,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: torch.Generator | None = None,\n        latents: Optional[torch.Tensor] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n       
 output_type: str | None = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,\n        callback_steps: Optional[int] = 1,\n        knn: Optional[int] = 10,\n        **kwargs,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`):\n                The prompt or prompts to guide the image generation.\n            retrieved_images (`List[Image.Image]`, *optional*):\n                Images to condition the generation on, in addition to any images fetched by the retriever.\n            height (`int`, *optional*, defaults to 768):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to 768):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2. of [Imagen\n                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text\n                `prompt`, usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. 
Only applies to\n                [`schedulers.DDIMScheduler`], will be ignored for others.\n            generator (`torch.Generator`, *optional*):\n                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n                deterministic.\n            latents (`torch.Tensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called. 
If not specified, the callback will be\n                called at every step.\n            knn (`int`, *optional*, defaults to 10):\n                The number of nearest-neighbor images to fetch from the retriever for each prompt, if a retriever is\n                set.\n\n        Returns:\n            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if\n            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the\n            generated images.\n        \"\"\"\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n        if isinstance(prompt, str):\n            batch_size = 1\n        elif isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        if retrieved_images is not None:\n            retrieved_images = [retrieved_images for _ in range(batch_size)]\n        else:\n            retrieved_images = [[] for _ in range(batch_size)]\n        device = self._execution_device\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n        if prompt_embeds is None:\n            prompt_embeds = self._encode_prompt(prompt)\n        retrieved_images = self.retrieve_images(retrieved_images, prompt_embeds, knn=knn)\n        image_embeddings = self._encode_image(retrieved_images, batch_size)\n        if image_embeddings is not None:\n            prompt_embeds = torch.cat([prompt_embeds, image_embeddings], dim=1)\n\n        # duplicate text embeddings for each generation per prompt, 
using mps friendly method\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance:\n            uncond_embeddings = torch.zeros_like(prompt_embeds).to(prompt_embeds.device)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([uncond_embeddings, prompt_embeds])\n        # get the initial random noise unless the user supplied it\n        num_channels_latents = self.unet.config.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # set timesteps\n        self.scheduler.set_timesteps(num_inference_steps)\n\n        # Some schedulers like PNDM have timesteps as arrays\n        # It's more optimized to move all timesteps to correct device beforehand\n        timesteps_tensor = self.scheduler.timesteps.to(self.device)\n\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η 
in DDIM paper: https://huggingface.co/papers/2010.02502\n        # and should be between [0, 1]\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n\n        for i, t in enumerate(self.progress_bar(timesteps_tensor)):\n            # expand the latents if we are doing classifier free guidance\n            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n            # predict the noise residual\n            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample\n\n            # perform guidance\n            if do_classifier_free_guidance:\n                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n            # compute the previous noisy sample x_t -> x_t-1\n            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n            # call the callback, if provided\n            if callback is not None and i % callback_steps == 0:\n                step_idx = i // getattr(self.scheduler, \"order\", 1)\n                callback(step_idx, t, latents)\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]\n        else:\n            image = latents\n\n        image = self.image_processor.postprocess(\n            image, output_type=output_type, do_denormalize=[True] * 
image.shape[0]\n        )\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image,)\n\n        return ImagePipelineOutput(images=image)\n"
  },
  {
    "path": "examples/research_projects/rdm/retriever.py",
    "content": "import os\nfrom typing import List\n\nimport faiss\nimport numpy as np\nimport torch\nfrom datasets import Dataset, load_dataset\nfrom PIL import Image\nfrom transformers import CLIPImageProcessor, CLIPModel, PretrainedConfig\n\nfrom diffusers import logging\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef normalize_images(images: List[Image.Image]):\n    images = [np.array(image) for image in images]\n    images = [image / 127.5 - 1 for image in images]\n    return images\n\n\ndef preprocess_images(images: List[np.array], feature_extractor: CLIPImageProcessor) -> torch.Tensor:\n    \"\"\"\n    Preprocesses a list of images into a batch of tensors.\n\n    Args:\n        images (:obj:`List[Image.Image]`):\n            A list of images to preprocess.\n\n    Returns:\n        :obj:`torch.Tensor`: A batch of tensors.\n    \"\"\"\n    images = [np.array(image) for image in images]\n    images = [(image + 1.0) / 2.0 for image in images]\n    images = feature_extractor(images, return_tensors=\"pt\").pixel_values\n    return images\n\n\nclass IndexConfig(PretrainedConfig):\n    def __init__(\n        self,\n        clip_name_or_path=\"openai/clip-vit-large-patch14\",\n        dataset_name=\"Isamu136/oxford_pets_with_l14_emb\",\n        image_column=\"image\",\n        index_name=\"embeddings\",\n        index_path=None,\n        dataset_set=\"train\",\n        metric_type=faiss.METRIC_L2,\n        faiss_device=-1,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.clip_name_or_path = clip_name_or_path\n        self.dataset_name = dataset_name\n        self.image_column = image_column\n        self.index_name = index_name\n        self.index_path = index_path\n        self.dataset_set = dataset_set\n        self.metric_type = metric_type\n        self.faiss_device = faiss_device\n\n\nclass Index:\n    \"\"\"\n    Each index for a retrieval model is specific to the clip model used and the dataset 
used.\n    \"\"\"\n\n    def __init__(self, config: IndexConfig, dataset: Dataset):\n        self.config = config\n        self.dataset = dataset\n        self.index_initialized = False\n        self.index_name = config.index_name\n        self.index_path = config.index_path\n        self.init_index()\n\n    def set_index_name(self, index_name: str):\n        self.index_name = index_name\n\n    def init_index(self):\n        if not self.index_initialized:\n            if self.index_path and self.index_name:\n                try:\n                    self.dataset.add_faiss_index(\n                        column=self.index_name, metric_type=self.config.metric_type, device=self.config.faiss_device\n                    )\n                    self.index_initialized = True\n                except Exception as e:\n                    print(e)\n                    logger.info(\"Index not initialized\")\n            if self.index_name in self.dataset.features:\n                self.dataset.add_faiss_index(column=self.index_name)\n                self.index_initialized = True\n\n    def build_index(\n        self,\n        model=None,\n        feature_extractor: CLIPImageProcessor = None,\n        torch_dtype=torch.float32,\n    ):\n        if not self.index_initialized:\n            model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)\n            feature_extractor = feature_extractor or CLIPImageProcessor.from_pretrained(self.config.clip_name_or_path)\n            self.dataset = get_dataset_with_emb_from_clip_model(\n                self.dataset,\n                model,\n                feature_extractor,\n                image_column=self.config.image_column,\n                index_name=self.config.index_name,\n            )\n            self.init_index()\n\n    def retrieve_imgs(self, vec, k: int = 20):\n        vec = np.array(vec).astype(np.float32)\n        return self.dataset.get_nearest_examples(self.index_name, vec, 
k=k)\n\n    def retrieve_imgs_batch(self, vec, k: int = 20):\n        vec = np.array(vec).astype(np.float32)\n        return self.dataset.get_nearest_examples_batch(self.index_name, vec, k=k)\n\n    def retrieve_indices(self, vec, k: int = 20):\n        vec = np.array(vec).astype(np.float32)\n        return self.dataset.search(self.index_name, vec, k=k)\n\n    def retrieve_indices_batch(self, vec, k: int = 20):\n        vec = np.array(vec).astype(np.float32)\n        return self.dataset.search_batch(self.index_name, vec, k=k)\n\n\nclass Retriever:\n    def __init__(\n        self,\n        config: IndexConfig,\n        index: Index = None,\n        dataset: Dataset = None,\n        model=None,\n        feature_extractor: CLIPImageProcessor = None,\n    ):\n        self.config = config\n        self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        retriever_name_or_path: str,\n        index: Index = None,\n        dataset: Dataset = None,\n        model=None,\n        feature_extractor: CLIPImageProcessor = None,\n        **kwargs,\n    ):\n        config = kwargs.pop(\"config\", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)\n        return cls(config, index=index, dataset=dataset, model=model, feature_extractor=feature_extractor)\n\n    @staticmethod\n    def _build_index(\n        config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPImageProcessor = None\n    ):\n        dataset = dataset or load_dataset(config.dataset_name)\n        dataset = dataset[config.dataset_set]\n        index = Index(config, dataset)\n        index.build_index(model=model, feature_extractor=feature_extractor)\n        return index\n\n    def save_pretrained(self, save_directory):\n        os.makedirs(save_directory, exist_ok=True)\n        if self.config.index_path is None:\n            index_path = 
os.path.join(save_directory, \"hf_dataset_index.faiss\")\n            self.index.dataset.get_index(self.config.index_name).save(index_path)\n            self.config.index_path = index_path\n        self.config.save_pretrained(save_directory)\n\n    def init_retrieval(self):\n        logger.info(\"initializing retrieval\")\n        self.index.init_index()\n\n    def retrieve_imgs(self, embeddings: np.ndarray, k: int):\n        return self.index.retrieve_imgs(embeddings, k)\n\n    def retrieve_imgs_batch(self, embeddings: np.ndarray, k: int):\n        return self.index.retrieve_imgs_batch(embeddings, k)\n\n    def retrieve_indices(self, embeddings: np.ndarray, k: int):\n        return self.index.retrieve_indices(embeddings, k)\n\n    def retrieve_indices_batch(self, embeddings: np.ndarray, k: int):\n        return self.index.retrieve_indices_batch(embeddings, k)\n\n    def __call__(\n        self,\n        embeddings,\n        k: int = 20,\n    ):\n        return self.index.retrieve_imgs(embeddings, k)\n\n\ndef map_txt_to_clip_feature(clip_model, tokenizer, prompt):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer.model_max_length,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n\n    if text_input_ids.shape[-1] > tokenizer.model_max_length:\n        removed_text = tokenizer.batch_decode(text_input_ids[:, tokenizer.model_max_length :])\n        logger.warning(\n            \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n            f\" {tokenizer.model_max_length} tokens: {removed_text}\"\n        )\n        text_input_ids = text_input_ids[:, : tokenizer.model_max_length]\n    text_embeddings = clip_model.get_text_features(text_input_ids.to(clip_model.device))\n    text_embeddings = text_embeddings / torch.linalg.norm(text_embeddings, dim=-1, keepdim=True)\n    text_embeddings = text_embeddings[:, None, :]\n    return 
text_embeddings[0][0].cpu().detach().numpy()\n\n\ndef map_img_to_model_feature(model, feature_extractor, imgs, device):\n    for i, image in enumerate(imgs):\n        if not image.mode == \"RGB\":\n            imgs[i] = image.convert(\"RGB\")\n    imgs = normalize_images(imgs)\n    retrieved_images = preprocess_images(imgs, feature_extractor).to(device)\n    image_embeddings = model(retrieved_images)\n    image_embeddings = image_embeddings / torch.linalg.norm(image_embeddings, dim=-1, keepdim=True)\n    image_embeddings = image_embeddings[None, ...]\n    return image_embeddings.cpu().detach().numpy()[0][0]\n\n\ndef get_dataset_with_emb_from_model(dataset, model, feature_extractor, image_column=\"image\", index_name=\"embeddings\"):\n    return dataset.map(\n        lambda example: {\n            index_name: map_img_to_model_feature(model, feature_extractor, [example[image_column]], model.device)\n        }\n    )\n\n\ndef get_dataset_with_emb_from_clip_model(\n    dataset, clip_model, feature_extractor, image_column=\"image\", index_name=\"embeddings\"\n):\n    return dataset.map(\n        lambda example: {\n            index_name: map_img_to_model_feature(\n                clip_model.get_image_features, feature_extractor, [example[image_column]], clip_model.device\n            )\n        }\n    )\n"
  },
  {
    "path": "examples/research_projects/realfill/README.md",
    "content": "# RealFill\n\n[RealFill](https://huggingface.co/papers/2309.16668) is a method to personalize text2image inpainting models like stable diffusion inpainting given just a few(1~5) images of a scene.\nThe `train_realfill.py` script shows how to implement the training procedure for stable diffusion inpainting.\n\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\ncd to the realfill folder and run\n```bash\ncd realfill\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell e.g. a notebook\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\n\n### Toy example\n\nNow let's fill the real. 
For this example, we will use some images of the flower girl example from the paper.\n\nWe already provide some images for testing in [this link](https://github.com/thuanz123/realfill/tree/main/data/flowerwoman)\n\nYou only have to launch the training using:\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-2-inpainting\"\nexport TRAIN_DIR=\"data/flowerwoman\"\nexport OUTPUT_DIR=\"flowerwoman-model\"\n\naccelerate launch train_realfill.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$TRAIN_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --resolution=512 \\\n  --train_batch_size=16 \\\n  --gradient_accumulation_steps=1 \\\n  --unet_learning_rate=2e-4 \\\n  --text_encoder_learning_rate=4e-5 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=2000 \\\n  --lora_rank=8 \\\n  --lora_dropout=0.1 \\\n  --lora_alpha=16 \\\n```\n\n### Training on a low-memory GPU:\n\nIt is possible to run realfill on a low-memory GPU by using the following optimizations:\n- [gradient checkpointing and the 8-bit optimizer](#training-with-gradient-checkpointing-and-8-bit-optimizers)\n- [xformers](#training-with-xformers)\n- [setting grads to none](#set-grads-to-none)\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-2-inpainting\"\nexport TRAIN_DIR=\"data/flowerwoman\"\nexport OUTPUT_DIR=\"flowerwoman-model\"\n\naccelerate launch train_realfill.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$TRAIN_DIR \\\n  --output_dir=$OUTPUT_DIR \\\n  --resolution=512 \\\n  --train_batch_size=16 \\\n  --gradient_accumulation_steps=1 --gradient_checkpointing \\\n  --use_8bit_adam \\\n  --enable_xformers_memory_efficient_attention \\\n  --set_grads_to_none \\\n  --unet_learning_rate=2e-4 \\\n  --text_encoder_learning_rate=4e-5 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=100 \\\n  --max_train_steps=2000 \\\n  --lora_rank=8 \\\n  --lora_dropout=0.1 \\\n  --lora_alpha=16 \\\n```\n\n### Training 
with gradient checkpointing and 8-bit optimizers:\n\nWith the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train realfill on a 16GB GPU.\n\nTo install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).\n\n### Training with xformers:\nYou can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.\n\n### Set grads to none\n\nTo save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.\n\nMore info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\n\n## Acknowledgements\nThis repo is built upon the DreamBooth code from diffusers, and we thank the developers for their great work and their efforts to release the source code. Furthermore, a special \"thank you\" to RealFill's authors for publishing such amazing work.\n"
  },
  {
    "path": "examples/research_projects/realfill/infer.py",
    "content": "import argparse\nimport os\n\nimport torch\nfrom PIL import Image, ImageFilter\nfrom transformers import CLIPTextModel\n\nfrom diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel\n\n\nparser = argparse.ArgumentParser(description=\"Inference\")\nparser.add_argument(\n    \"--model_path\",\n    type=str,\n    default=None,\n    required=True,\n    help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n)\nparser.add_argument(\n    \"--validation_image\",\n    type=str,\n    default=None,\n    required=True,\n    help=\"The directory of the validation image\",\n)\nparser.add_argument(\n    \"--validation_mask\",\n    type=str,\n    default=None,\n    required=True,\n    help=\"The directory of the validation mask\",\n)\nparser.add_argument(\n    \"--output_dir\",\n    type=str,\n    default=\"./test-infer/\",\n    help=\"The output directory where predictions are saved\",\n)\nparser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible inference.\")\n\nargs = parser.parse_args()\n\nif __name__ == \"__main__\":\n    os.makedirs(args.output_dir, exist_ok=True)\n    generator = None\n\n    # create & load model\n    pipe = StableDiffusionInpaintPipeline.from_pretrained(\n        \"stabilityai/stable-diffusion-2-inpainting\", torch_dtype=torch.float32, revision=None\n    )\n\n    pipe.unet = UNet2DConditionModel.from_pretrained(\n        args.model_path,\n        subfolder=\"unet\",\n        revision=None,\n    )\n    pipe.text_encoder = CLIPTextModel.from_pretrained(\n        args.model_path,\n        subfolder=\"text_encoder\",\n        revision=None,\n    )\n    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n    pipe = pipe.to(\"cuda\")\n\n    if args.seed is not None:\n        generator = torch.Generator(device=\"cuda\").manual_seed(args.seed)\n\n    image = Image.open(args.validation_image)\n    mask_image = 
Image.open(args.validation_mask)\n\n    results = pipe(\n        [\"a photo of sks\"] * 16,\n        image=image,\n        mask_image=mask_image,\n        num_inference_steps=25,\n        guidance_scale=5,\n        generator=generator,\n    ).images\n\n    erode_kernel = ImageFilter.MaxFilter(3)\n    mask_image = mask_image.filter(erode_kernel)\n\n    blur_kernel = ImageFilter.BoxBlur(1)\n    mask_image = mask_image.filter(blur_kernel)\n\n    for idx, result in enumerate(results):\n        result = Image.composite(result, image, mask_image)\n        result.save(f\"{args.output_dir}/{idx}.png\")\n\n    del pipe\n    torch.cuda.empty_cache()\n"
  },
  {
    "path": "examples/research_projects/realfill/requirements.txt",
    "content": "diffusers==0.20.1\naccelerate==0.23.0\ntransformers==4.38.0\npeft==0.5.0\ntorch==2.2.0\ntorchvision>=0.16\nftfy==6.1.1\ntensorboard==2.14.0\nJinja2==3.1.6\n"
  },
  {
    "path": "examples/research_projects/realfill/train_realfill.py",
    "content": "import argparse\nimport copy\nimport itertools\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision.transforms.v2 as transforms_v2\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, PeftModel, get_peft_model\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, CLIPTextModel\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DPMSolverMultistepScheduler,\n    StableDiffusionInpaintPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.20.1\")\n\nlogger = get_logger(__name__)\n\n\ndef make_mask(images, resolution, times=30):\n    mask, times = torch.ones_like(images[0:1, :, :]), np.random.randint(1, times)\n    min_size, max_size, margin = np.array([0.03, 0.25, 0.01]) * resolution\n    max_size = min(max_size, resolution - margin * 2)\n\n    for _ in range(times):\n        width = np.random.randint(int(min_size), int(max_size))\n        height = np.random.randint(int(min_size), int(max_size))\n\n        x_start = np.random.randint(int(margin), resolution - int(margin) - width + 1)\n        y_start = np.random.randint(int(margin), resolution - int(margin) - height + 1)\n        mask[:, y_start : y_start + height, x_start : x_start + width] = 0\n\n    mask = 1 - mask if random.random() < 0.5 else mask\n    return mask\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    repo_folder=None,\n):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\nprompt: \"a photo of sks\"\ntags:\n- stable-diffusion-inpainting\n- stable-diffusion-inpainting-diffusers\n- text-to-image\n- diffusers\n- realfill\n- diffusers-training\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# RealFill - {repo_id}\n\nThis is a realfill model derived from {base_model}. The weights were trained using [RealFill](https://realfill.github.io/).\nYou can find some example images in the following. \\n\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef log_validation(\n    text_encoder,\n    tokenizer,\n    unet,\n    args,\n    accelerator,\n    weight_dtype,\n    epoch,\n):\n    logger.info(f\"Running validation... 
\\nGenerating {args.num_validation_images} images\")\n\n    # create pipeline (note: the pipeline components are loaded again, in `weight_dtype`)\n    pipeline = StableDiffusionInpaintPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        tokenizer=tokenizer,\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n\n    # set `keep_fp32_wrapper` to True because we do not want to remove\n    # mixed precision hooks while we are still training\n    pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)\n    pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    target_dir = Path(args.train_data_dir) / \"target\"\n    target_image, target_mask = target_dir / \"target.png\", target_dir / \"mask.png\"\n    image, mask_image = Image.open(target_image), Image.open(target_mask)\n\n    if image.mode != \"RGB\":\n        image = image.convert(\"RGB\")\n\n    images = []\n    for _ in range(args.num_validation_images):\n        # keep the target image as the input of every sample instead of overwriting it with the previous output\n        result = pipeline(\n            prompt=\"a photo of sks\",\n            image=image,\n            mask_image=mask_image,\n            num_inference_steps=25,\n            guidance_scale=5,\n            generator=generator,\n        ).images[0]\n        images.append(result)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log({\"validation\": [wandb.Image(image, caption=str(i)) for i, 
image in enumerate(images)]})\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of images.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_conditioning`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run realfill validation every X steps. 
RealFill validation consists of running the conditioning\"\n            \" `args.validation_conditioning` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"realfill-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--unet_learning_rate\",\n        type=float,\n        default=2e-4,\n        help=\"Learning rate to use for unet.\",\n    )\n    parser.add_argument(\n        \"--text_encoder_learning_rate\",\n        type=float,\n        default=4e-5,\n        help=\"Learning rate to use for text encoder.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--wandb_key\",\n        type=str,\n        default=None,\n        help=(\"If `--report_to` is set to wandb, the API key used to log in to wandb.\"),\n    )\n    parser.add_argument(\n        \"--wandb_project_name\",\n        type=str,\n        default=None,\n        help=(\"If `--report_to` is set to wandb, the wandb project name used for log tracking.\"),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--lora_rank\",\n        type=int,\n        default=16,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_alpha\",\n        type=int,\n        default=27,\n        help=(\"The alpha constant of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--lora_dropout\",\n        type=float,\n        default=0.0,\n        help=\"The dropout rate of the LoRA update matrices.\",\n    )\n    parser.add_argument(\n        \"--lora_bias\",\n        type=str,\n        default=\"none\",\n        help=\"The bias type of the LoRA update matrices. 
Must be 'none', 'all' or 'lora_only'.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\nclass RealFillDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the training and conditioning images and\n    the masks with the dummy prompt for fine-tuning the model.\n    It pre-processes the images, masks and tokenizes the prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        train_data_root,\n        tokenizer,\n        size=512,\n    ):\n        self.size = size\n        self.tokenizer = tokenizer\n\n        self.ref_data_root = Path(train_data_root) / \"ref\"\n        self.target_image = Path(train_data_root) / \"target\" / \"target.png\"\n        self.target_mask = Path(train_data_root) / \"target\" / \"mask.png\"\n        if not (self.ref_data_root.exists() and self.target_image.exists() and self.target_mask.exists()):\n            raise ValueError(\"Train images root doesn't exist.\")\n\n        self.train_images_path = list(self.ref_data_root.iterdir()) + [self.target_image]\n        self.num_train_images = len(self.train_images_path)\n        self.train_prompt = \"a photo of sks\"\n\n        self.transform = transforms_v2.Compose(\n            [\n                transforms_v2.ToImage(),\n                transforms_v2.RandomResize(size, int(1.125 * size)),\n                transforms_v2.RandomCrop(size),\n                transforms_v2.ToDtype(torch.float32, scale=True),\n                transforms_v2.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self.num_train_images\n\n    def __getitem__(self, index):\n        example = {}\n\n        image = Image.open(self.train_images_path[index])\n        image = 
exif_transpose(image)\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        if index < len(self) - 1:\n            weighting = Image.new(\"L\", image.size)\n        else:\n            weighting = Image.open(self.target_mask)\n            weighting = exif_transpose(weighting)\n\n        image, weighting = self.transform(image, weighting)\n        example[\"images\"], example[\"weightings\"] = image, weighting < 0\n\n        if random.random() < 0.1:\n            example[\"masks\"] = torch.ones_like(example[\"images\"][0:1, :, :])\n        else:\n            example[\"masks\"] = make_mask(example[\"images\"], self.size)\n\n        example[\"conditioning_images\"] = example[\"images\"] * (example[\"masks\"] < 0.5)\n\n        train_prompt = \"\" if random.random() < 0.1 else self.train_prompt\n        example[\"prompt_ids\"] = self.tokenizer(\n            train_prompt,\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        return example\n\n\ndef collate_fn(examples):\n    input_ids = [example[\"prompt_ids\"] for example in examples]\n    images = [example[\"images\"] for example in examples]\n\n    masks = [example[\"masks\"] for example in examples]\n    weightings = [example[\"weightings\"] for example in examples]\n    conditioning_images = [example[\"conditioning_images\"] for example in examples]\n\n    images = torch.stack(images)\n    images = images.to(memory_format=torch.contiguous_format).float()\n\n    masks = torch.stack(masks)\n    masks = masks.to(memory_format=torch.contiguous_format).float()\n\n    weightings = torch.stack(weightings)\n    weightings = weightings.to(memory_format=torch.contiguous_format).float()\n\n    conditioning_images = torch.stack(conditioning_images)\n    conditioning_images = conditioning_images.to(memory_format=torch.contiguous_format).float()\n\n    
input_ids = torch.cat(input_ids, dim=0)\n\n    batch = {\n        \"input_ids\": input_ids,\n        \"images\": images,\n        \"masks\": masks,\n        \"weightings\": weightings,\n        \"conditioning_images\": conditioning_images,\n    }\n    return batch\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_dir=logging_dir,\n    )\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n        wandb.login(key=args.wandb_key)\n        wandb.init(project=args.wandb_project_name)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            
os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n\n    config = LoraConfig(\n        r=args.lora_rank,\n        lora_alpha=args.lora_alpha,\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"key\", \"query\", \"value\"],\n        lora_dropout=args.lora_dropout,\n        bias=args.lora_bias,\n    )\n    unet = get_peft_model(unet, config)\n\n    config = LoraConfig(\n        r=args.lora_rank,\n        lora_alpha=args.lora_alpha,\n        target_modules=[\"k_proj\", \"q_proj\", \"v_proj\"],\n        lora_dropout=args.lora_dropout,\n        bias=args.lora_bias,\n    )\n    text_encoder = get_peft_model(text_encoder, config)\n\n    vae.requires_grad_(False)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n        
    xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        text_encoder.gradient_checkpointing_enable()\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            for model in models:\n                sub_dir = (\n                    \"unet\"\n                    if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))\n                    else \"text_encoder\"\n                )\n                model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n    def load_model_hook(models, input_dir):\n        while len(models) > 0:\n            # pop models so that they are not loaded again\n            model = models.pop()\n\n            sub_dir = (\n                \"unet\"\n                if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))\n                else \"text_encoder\"\n            )\n            model_cls = (\n                UNet2DConditionModel\n                if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))\n              
  else CLIPTextModel\n            )\n\n            load_model = model_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder=sub_dir)\n            load_model = PeftModel.from_pretrained(load_model, input_dir, subfolder=sub_dir)\n\n            model.load_state_dict(load_model.state_dict())\n            del load_model\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.unet_learning_rate = (\n            args.unet_learning_rate\n            * args.gradient_accumulation_steps\n            * args.train_batch_size\n            * accelerator.num_processes\n        )\n\n        args.text_encoder_learning_rate = (\n            args.text_encoder_learning_rate\n            * args.gradient_accumulation_steps\n            * args.train_batch_size\n            * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    optimizer = optimizer_class(\n        [\n            {\"params\": unet.parameters(), \"lr\": args.unet_learning_rate},\n            {\"params\": text_encoder.parameters(), \"lr\": args.text_encoder_learning_rate},\n        ],\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        
eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = RealFillDataset(\n        train_data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=collate_fn,\n        num_workers=1,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(\n        unet, text_encoder, optimizer, train_dataloader\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae to device and cast to weight_dtype\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have 
changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = vars(copy.deepcopy(args))\n        accelerator.init_trackers(\"realfill\", config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        text_encoder.train()\n\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet, text_encoder):\n                # Convert images to latent space\n           
     latents = vae.encode(batch[\"images\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * 0.18215\n\n                # Convert masked images to latent space\n                conditionings = vae.encode(batch[\"conditioning_images\"].to(dtype=weight_dtype)).latent_dist.sample()\n                conditionings = conditionings * 0.18215\n\n                # Downsample mask and weighting so that they match with the latents\n                masks, size = batch[\"masks\"].to(dtype=weight_dtype), latents.shape[2:]\n                masks = F.interpolate(masks, size=size)\n\n                weightings = batch[\"weightings\"].to(dtype=weight_dtype)\n                weightings = F.interpolate(weightings, size=size)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Concatenate noisy latents, masks and conditionings to get inputs to unet\n                inputs = torch.cat([noisy_latents, masks, conditionings], dim=1)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Predict the noise residual\n                model_pred = unet(inputs, timesteps, encoder_hidden_states).sample\n\n                # Compute the diffusion loss\n                assert noise_scheduler.config.prediction_type == \"epsilon\"\n                loss = (weightings * 
F.mse_loss(model_pred.float(), noise.float(), reduction=\"none\")).mean()\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                if args.report_to == \"wandb\":\n                    accelerator.print(progress_bar)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing 
{len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if global_step % args.validation_steps == 0:\n                        log_validation(\n                            text_encoder,\n                            tokenizer,\n                            unet,\n                            args,\n                            accelerator,\n                            weight_dtype,\n                            global_step,\n                        )\n\n            logs = {\"loss\": loss.detach().item()}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        pipeline = StableDiffusionInpaintPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).merge_and_unload(),\n            text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).merge_and_unload(),\n            revision=args.revision,\n        )\n\n        pipeline.save_pretrained(args.output_dir)\n\n        # Final inference\n        images = log_validation(\n            text_encoder,\n            tokenizer,\n            unet,\n            
args,\n            accelerator,\n            weight_dtype,\n            global_step,\n        )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/sana/README.md",
    "content": "# Training SANA Sprint Diffuser\n\nThis README explains how to use the provided bash script commands to download a pre-trained teacher diffuser model and train it on a specific dataset, following the [SANA Sprint methodology](https://huggingface.co/papers/2503.09641).\n\n\n## Setup\n\n### 1. Define the local paths\n\nSet a variable for your desired output directory. This directory will store the downloaded model and the training checkpoints/results.\n\n```bash\nyour_local_path='output' # Or any other path you prefer\nmkdir -p $your_local_path # Create the directory if it doesn't exist\n```\n\n### 2. Download the pre-trained model\n\nDownload the SANA Sprint teacher model from Hugging Face Hub. The script uses the 1.6B parameter model.\n\n```bash\nhf download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers\n```\n\n*(Optional: You can also download the 0.6B model by replacing the model name: `Efficient-Large-Model/Sana_Sprint_0.6B_1024px_teacher_diffusers`)*\n\n### 3. Acquire the dataset shards\n\nThe training script in this example uses specific `.parquet` shards from a randomly selected `brivangl/midjourney-v6-llava` dataset instead of downloading the entire dataset automatically via `dataset_name`.\n\nThe script specifically uses these three files:\n*   `data/train_000.parquet`\n*   `data/train_001.parquet`\n*   `data/train_002.parquet`\n\n\n\nYou can either:\n\nLet the script download the dataset automatically during first run\n\nOr download it manually\n\n**Note:** The full `brivangl/midjourney-v6-llava` dataset is much larger and contains many more shards. 
This script example explicitly trains *only* on the three specified shards.\n\n## Usage\n\nOnce the model is downloaded, you can run the training script.\n\n```bash\nyour_local_path='output' # Ensure this variable is set\n\npython train_sana_sprint_diffusers.py \\\n    --pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \\\n    --output_dir=$your_local_path \\\n    --mixed_precision=bf16 \\\n    --resolution=1024 \\\n    --learning_rate=1e-6 \\\n    --max_train_steps=30000 \\\n    --dataloader_num_workers=8 \\\n    --dataset_name='brivangl/midjourney-v6-llava' \\\n    --file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \\\n    --checkpointing_steps=500 --checkpoints_total_limit=10 \\\n    --train_batch_size=1 \\\n    --gradient_accumulation_steps=1 \\\n    --seed=453645634 \\\n    --train_largest_timestep \\\n    --misaligned_pairs_D \\\n    --gradient_checkpointing \\\n    --resume_from_checkpoint=\"latest\"\n```\n\n### Explanation of parameters\n\n*   `--pretrained_model_name_or_path`: Path to the downloaded pre-trained model directory.\n*   `--output_dir`: Directory where training logs, checkpoints, and the final model will be saved.\n*   `--mixed_precision`: Use BF16 mixed precision for training, which can save memory and speed up training on compatible hardware.\n*   `--resolution`: The image resolution used for training (1024x1024).\n*   `--learning_rate`: The learning rate for the optimizer.\n*   `--max_train_steps`: The total number of training steps to perform.\n*   `--dataloader_num_workers`: Number of worker processes for loading data. 
Increase for faster data loading if your CPU and disk can handle it.\n*   `--dataset_name`: The name of the dataset on Hugging Face Hub (`brivangl/midjourney-v6-llava`).\n*   `--file_path`: **Specifies the local paths to the dataset shards to be used for training.** In this case, `data/train_000.parquet`, `data/train_001.parquet`, and `data/train_002.parquet`.\n*   `--checkpointing_steps`: Save a training checkpoint every X steps.\n*   `--checkpoints_total_limit`: Maximum number of checkpoints to keep. Older checkpoints will be deleted.\n*   `--train_batch_size`: The batch size per GPU.\n*   `--gradient_accumulation_steps`: Number of steps to accumulate gradients before performing an optimizer step.\n*   `--seed`: Random seed for reproducibility.\n*   `--train_largest_timestep`: A specific training strategy focusing on larger timesteps.\n*   `--misaligned_pairs_D`: Another specific training strategy to add misaligned image-text pairs as fake data for GAN.\n*   `--gradient_checkpointing`: Enable gradient checkpointing to save GPU memory.\n*   `--resume_from_checkpoint`: Allows resuming training from the latest saved checkpoint in the `--output_dir`.\n\n\n"
  },
  {
    "path": "examples/research_projects/sana/train_sana_sprint_diffusers.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 Sana-Sprint team. All rights reserved.\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport io\nimport logging\nimport math\nimport os\nimport shutil\nfrom pathlib import Path\nfrom typing import Callable, Optional\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision.transforms as T\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom safetensors.torch import load_file\nfrom torch.nn.utils.spectral_norm import SpectralNorm\nfrom torch.utils.data import DataLoader, Dataset\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, Gemma2Model\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderDC,\n    SanaPipeline,\n    SanaSprintPipeline,\n    SanaTransformer2DModel,\n)\nfrom diffusers.models.attention_processor import Attention\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    free_memory,\n)\nfrom diffusers.utils import (\n    
check_min_version,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.33.0.dev0\")\n\nlogger = get_logger(__name__)\n\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\nCOMPLEX_HUMAN_INSTRUCTION = [\n    \"Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:\",\n    \"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.\",\n    \"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\",\n    \"Here are examples of how to transform or refine prompts:\",\n    \"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\",\n    \"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\",\n    \"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:\",\n    \"User Prompt: \",\n]\n\n\nclass SanaVanillaAttnProcessor:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention to support JVP calculation during training.\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    @staticmethod\n    def scaled_dot_product_attention(\n  
      query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None\n    ) -> torch.Tensor:\n        B, H, L, S = *query.size()[:-1], key.size(-2)\n        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)\n\n        if attn_mask is not None:\n            if attn_mask.dtype == torch.bool:\n                attn_bias.masked_fill_(attn_mask.logical_not(), float(\"-inf\"))\n            else:\n                attn_bias += attn_mask\n        attn_weight = query @ key.transpose(-2, -1) * scale_factor\n        attn_weight += attn_bias\n        attn_weight = torch.softmax(attn_weight, dim=-1)\n        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)\n        return attn_weight @ value\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n\n        if attn.norm_q is not None:\n            query = attn.norm_q(query)\n        if attn.norm_k is not None:\n            key = 
attn.norm_k(key)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        hidden_states = self.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\n\nclass Text2ImageDataset(Dataset):\n    \"\"\"\n    A PyTorch Dataset class for loading text-image pairs from a HuggingFace dataset.\n    This dataset is designed for text-to-image generation tasks.\n\n    Args:\n        hf_dataset (datasets.Dataset):\n            A HuggingFace dataset containing 'image' (bytes) and 'llava' (text) fields. Note that 'llava' is the field name for text descriptions in this specific dataset - you may need to adjust this key if using a different HuggingFace dataset with a different text field name.\n        resolution (int, optional): Target resolution for image resizing. 
Defaults to 1024.\n    Returns:\n        dict: A dictionary containing:\n            - 'text': The text description (str)\n            - 'image': The processed image tensor (torch.Tensor) of shape [3, resolution, resolution]\n    \"\"\"\n\n    def __init__(self, hf_dataset, resolution=1024):\n        self.dataset = hf_dataset\n        self.transform = T.Compose(\n            [\n                T.Lambda(lambda img: img.convert(\"RGB\")),\n                T.Resize(resolution),  # Image.BICUBIC\n                T.CenterCrop(resolution),\n                T.ToTensor(),\n                T.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, idx):\n        item = self.dataset[idx]\n        text = item[\"llava\"]\n        image_bytes = item[\"image\"]\n\n        # Convert bytes to PIL Image\n        image = Image.open(io.BytesIO(image_bytes))\n\n        image_tensor = self.transform(image)\n\n        return {\"text\": text, \"image\": image_tensor}\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# Sana Sprint - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} Sana Sprint weights for {base_model}.\n\nThe weights were trained using [Sana-Sprint](https://nvlabs.github.io/Sana/Sprint/).\n\n## License\n\nTODO\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"other\",\n        base_model=base_model,\n        
model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"sana-sprint\",\n        \"sana-sprint-diffusers\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    if args.enable_vae_tiling:\n        pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)\n\n    pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n\n    images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return images\n\n\ndef 
parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. 
By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=300,\n        help=\"Maximum sequence length to use with the Gemma model\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sana-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    # ----Image Processing----\n    parser.add_argument(\"--file_path\", nargs=\"+\", required=True, help=\"List of parquet files (space-separated)\")\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_fix_crop_and_size\",\n        action=\"store_true\",\n        help=\"Whether or not to use the fixed crop and size for the teacher model.\",\n        default=False,\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.2, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.6, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_mean_discriminator\", type=float, default=-0.6, help=\"Logit mean for discriminator timestep sampling\"\n    )\n    parser.add_argument(\n        \"--logit_std_discriminator\", type=float, default=1.0, help=\"Logit std for discriminator timestep sampling\"\n    )\n    parser.add_argument(\"--ladd_multi_scale\", action=\"store_true\", help=\"Whether to use multi-scale discriminator\")\n    parser.add_argument(\n        \"--head_block_ids\",\n        type=int,\n        nargs=\"+\",\n        default=[2, 8, 14, 19],\n        help=\"Specify which transformer blocks to use for discriminator heads\",\n    )\n    parser.add_argument(\"--adv_lambda\", type=float, default=0.5, help=\"Weighting coefficient for adversarial loss\")\n    parser.add_argument(\"--scm_lambda\", type=float, default=1.0, help=\"Weighting coefficient for SCM loss\")\n    parser.add_argument(\"--gradient_clip\", type=float, default=0.1, help=\"Threshold for gradient clipping\")\n    parser.add_argument(\n        \"--sigma_data\", type=float, default=0.5, help=\"Standard deviation of data distribution is supposed to be 0.5\"\n    )\n    parser.add_argument(\n        \"--tangent_warmup_steps\", type=int, default=4000, help=\"Number of warmup steps for tangent vectors\"\n    )\n    parser.add_argument(\n        \"--guidance_embeds_scale\", type=float, default=0.1, help=\"Scaling factor for guidance embeddings\"\n    )\n    parser.add_argument(\n        \"--scm_cfg_scale\", type=float, nargs=\"+\", default=[4, 4.5, 5], help=\"Range for classifier-free guidance scale\"\n    )\n    
parser.add_argument(\n        \"--train_largest_timestep\", action=\"store_true\", help=\"Whether to enable special training for large timesteps\"\n    )\n    parser.add_argument(\"--largest_timestep\", type=float, default=1.57080, help=\"Maximum timestep value\")\n    parser.add_argument(\n        \"--largest_timestep_prob\", type=float, default=0.5, help=\"Sampling probability for large timesteps\"\n    )\n    parser.add_argument(\n        \"--misaligned_pairs_D\", action=\"store_true\", help=\"Add misaligned sample pairs for discriminator\"\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for model params\")\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. 
Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--cache_latents\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). 
\"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\n        \"--offload\",\n        action=\"store_true\",\n        help=\"Whether to offload the VAE and the text encoder to CPU when they are not used.\",\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\"--enable_vae_tiling\", action=\"store_true\", help=\"Enable VAE tiling in log validation\")\n    parser.add_argument(\"--enable_npu_flash_attention\", action=\"store_true\", help=\"Enable Flash Attention for NPU\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, fn: Callable):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return (self.fn(x) + x) / np.sqrt(2)\n\n\nclass SpectralConv1d(nn.Conv1d):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        SpectralNorm.apply(self, name=\"weight\", n_power_iterations=1, dim=0, eps=1e-12)\n\n\nclass BatchNormLocal(nn.Module):\n    def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 8, eps: float = 1e-5):\n        super().__init__()\n        self.virtual_bs = virtual_bs\n        self.eps = eps\n        self.affine = affine\n\n        if self.affine:\n            self.weight = nn.Parameter(torch.ones(num_features))\n            self.bias = nn.Parameter(torch.zeros(num_features))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        shape = x.size()\n\n        # Reshape batch into groups.\n        G = np.ceil(x.size(0) / 
self.virtual_bs).astype(int)\n        x = x.view(G, -1, x.size(-2), x.size(-1))\n\n        # Calculate stats.\n        mean = x.mean([1, 3], keepdim=True)\n        var = x.var([1, 3], keepdim=True, unbiased=False)\n        x = (x - mean) / (torch.sqrt(var + self.eps))\n\n        if self.affine:\n            x = x * self.weight[None, :, None] + self.bias[None, :, None]\n\n        return x.view(shape)\n\n\ndef make_block(channels: int, kernel_size: int) -> nn.Module:\n    return nn.Sequential(\n        SpectralConv1d(\n            channels,\n            channels,\n            kernel_size=kernel_size,\n            padding=kernel_size // 2,\n            padding_mode=\"circular\",\n        ),\n        BatchNormLocal(channels),\n        nn.LeakyReLU(0.2, True),\n    )\n\n\n# Adapted from https://github.com/autonomousvision/stylegan-t/blob/main/networks/discriminator.py\nclass DiscHead(nn.Module):\n    def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):\n        super().__init__()\n        self.channels = channels\n        self.c_dim = c_dim\n        self.cmap_dim = cmap_dim\n\n        self.main = nn.Sequential(\n            make_block(channels, kernel_size=1), ResidualBlock(make_block(channels, kernel_size=9))\n        )\n\n        if self.c_dim > 0:\n            self.cmapper = nn.Linear(self.c_dim, cmap_dim)\n            self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)\n        else:\n            self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)\n\n    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:\n        h = self.main(x)\n        out = self.cls(h)\n\n        if self.c_dim > 0:\n            cmap = self.cmapper(c).unsqueeze(-1)\n            out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))\n\n        return out\n\n\nclass SanaMSCMDiscriminator(nn.Module):\n    def __init__(self, pretrained_model, is_multiscale=False, head_block_ids=None):\n        super().__init__()\n        
self.transformer = pretrained_model\n        self.transformer.requires_grad_(False)\n\n        if head_block_ids is None or len(head_block_ids) == 0:\n            self.block_hooks = {2, 8, 14, 20, 27} if is_multiscale else {self.transformer.depth - 1}\n        else:\n            self.block_hooks = head_block_ids\n\n        heads = []\n        for i in range(len(self.block_hooks)):\n            heads.append(DiscHead(self.transformer.hidden_size, 0, 0))\n        self.heads = nn.ModuleList(heads)\n\n    def get_head_inputs(self):\n        return self.head_inputs\n\n    def forward(self, hidden_states, timestep, encoder_hidden_states=None, **kwargs):\n        feat_list = []\n        self.head_inputs = []\n\n        def get_features(module, input, output):\n            feat_list.append(output)\n            return output\n\n        hooks = []\n        for i, block in enumerate(self.transformer.transformer_blocks):\n            if i in self.block_hooks:\n                hooks.append(block.register_forward_hook(get_features))\n\n        self.transformer(\n            hidden_states=hidden_states,\n            timestep=timestep,\n            encoder_hidden_states=encoder_hidden_states,\n            return_logvar=False,\n            **kwargs,\n        )\n\n        for hook in hooks:\n            hook.remove()\n\n        res_list = []\n        for feat, head in zip(feat_list, self.heads):\n            B, N, C = feat.shape\n            feat = feat.transpose(1, 2)  # [B, C, N]\n            self.head_inputs.append(feat)\n            res_list.append(head(feat, None).reshape(feat.shape[0], -1))\n\n        concat_res = torch.cat(res_list, dim=1)\n\n        return concat_res\n\n    @property\n    def model(self):\n        return self.transformer\n\n    def save_pretrained(self, path):\n        torch.save(self.state_dict(), path)\n\n\nclass DiscHeadModel:\n    def __init__(self, disc):\n        self.disc = disc\n\n    def state_dict(self):\n        return {name: param for name, param 
in self.disc.state_dict().items() if not name.startswith(\"transformer.\")}\n\n    def __getattr__(self, name):\n        return getattr(self.disc, name)\n\n\nclass SanaTrigFlow(SanaTransformer2DModel):\n    def __init__(self, original_model, guidance=False):\n        self.__dict__ = original_model.__dict__\n        self.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim\n        self.guidance = guidance\n        if self.guidance:\n            hidden_size = self.config.num_attention_heads * self.config.attention_head_dim\n            self.logvar_linear = torch.nn.Linear(hidden_size, 1)\n            torch.nn.init.xavier_uniform_(self.logvar_linear.weight)\n            torch.nn.init.constant_(self.logvar_linear.bias, 0)\n\n    def forward(\n        self, hidden_states, encoder_hidden_states, timestep, guidance=None, jvp=False, return_logvar=False, **kwargs\n    ):\n        batch_size = hidden_states.shape[0]\n        latents = hidden_states\n        prompt_embeds = encoder_hidden_states\n        t = timestep\n\n        # TrigFlow --> Flow Transformation\n        timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)\n        latents_model_input = latents\n\n        flow_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))\n\n        flow_timestep_expanded = flow_timestep.view(-1, 1, 1, 1)\n        latent_model_input = latents_model_input * torch.sqrt(\n            flow_timestep_expanded**2 + (1 - flow_timestep_expanded) ** 2\n        )\n        latent_model_input = latent_model_input.to(prompt_embeds.dtype)\n\n        # forward in original flow\n\n        if jvp and self.gradient_checkpointing:\n            self.gradient_checkpointing = False\n            model_out = super().forward(\n                hidden_states=latent_model_input,\n                encoder_hidden_states=prompt_embeds,\n                timestep=flow_timestep,\n                guidance=guidance,\n                **kwargs,\n            )[0]\n   
         self.gradient_checkpointing = True\n        else:\n            model_out = super().forward(\n                hidden_states=latent_model_input,\n                encoder_hidden_states=prompt_embeds,\n                timestep=flow_timestep,\n                guidance=guidance,\n                **kwargs,\n            )[0]\n\n        # Flow --> TrigFlow Transformation\n        trigflow_model_out = (\n            (1 - 2 * flow_timestep_expanded) * latent_model_input\n            + (1 - 2 * flow_timestep_expanded + 2 * flow_timestep_expanded**2) * model_out\n        ) / torch.sqrt(flow_timestep_expanded**2 + (1 - flow_timestep_expanded) ** 2)\n\n        if self.guidance and guidance is not None:\n            timestep, embedded_timestep = self.time_embed(\n                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype\n            )\n        else:\n            timestep, embedded_timestep = self.time_embed(\n                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype\n            )\n\n        if return_logvar:\n            logvar = self.logvar_linear(embedded_timestep)\n            return trigflow_model_out, logvar\n\n        return (trigflow_model_out,)\n\n\ndef compute_density_for_timestep_sampling_scm(batch_size: int, logit_mean: float = None, logit_std: float = None):\n    \"\"\"Compute the density for sampling the timesteps when doing Sana-Sprint training.\"\"\"\n    sigma = torch.randn(batch_size, device=\"cpu\")\n    sigma = (sigma * logit_std + logit_mean).exp()\n    u = torch.atan(sigma / 0.5)  # TODO: 0.5 should be a hyper-parameter\n\n    return u\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and 
args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if 
args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load the tokenizer\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n\n    # Load scheduler and models\n    text_encoder = Gemma2Model.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    vae = AutoencoderDC.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n\n    ori_transformer = SanaTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        guidance_embeds=True,\n    )\n    ori_transformer.set_attn_processor(SanaVanillaAttnProcessor())\n\n    ori_transformer_no_guide = SanaTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"transformer\",\n        revision=args.revision,\n        variant=args.variant,\n        guidance_embeds=False,\n    )\n\n    original_state_dict = load_file(\n        f\"{args.pretrained_model_name_or_path}/transformer/diffusion_pytorch_model.safetensors\"\n    )\n\n    param_mapping = {\n        \"time_embed.emb.timestep_embedder.linear_1.weight\": \"time_embed.timestep_embedder.linear_1.weight\",\n        \"time_embed.emb.timestep_embedder.linear_1.bias\": \"time_embed.timestep_embedder.linear_1.bias\",\n        \"time_embed.emb.timestep_embedder.linear_2.weight\": \"time_embed.timestep_embedder.linear_2.weight\",\n        \"time_embed.emb.timestep_embedder.linear_2.bias\": \"time_embed.timestep_embedder.linear_2.bias\",\n    }\n\n    for src_key, dst_key in 
param_mapping.items():\n        if src_key in original_state_dict:\n            ori_transformer.load_state_dict({dst_key: original_state_dict[src_key]}, strict=False, assign=True)\n\n    guidance_embedder_module = ori_transformer.time_embed.guidance_embedder\n\n    zero_state_dict = {}\n\n    target_device = accelerator.device\n    param_w1 = guidance_embedder_module.linear_1.weight\n    zero_state_dict[\"linear_1.weight\"] = torch.zeros(param_w1.shape, device=target_device)\n    param_b1 = guidance_embedder_module.linear_1.bias\n    zero_state_dict[\"linear_1.bias\"] = torch.zeros(param_b1.shape, device=target_device)\n    param_w2 = guidance_embedder_module.linear_2.weight\n    zero_state_dict[\"linear_2.weight\"] = torch.zeros(param_w2.shape, device=target_device)\n    param_b2 = guidance_embedder_module.linear_2.bias\n    zero_state_dict[\"linear_2.bias\"] = torch.zeros(param_b2.shape, device=target_device)\n    guidance_embedder_module.load_state_dict(zero_state_dict, strict=False, assign=True)\n\n    transformer = SanaTrigFlow(ori_transformer, guidance=True).train()\n    pretrained_model = SanaTrigFlow(ori_transformer_no_guide, guidance=False).eval()\n\n    disc = SanaMSCMDiscriminator(\n        pretrained_model,\n        is_multiscale=args.ladd_multi_scale,\n        head_block_ids=args.head_block_ids,\n    ).train()\n\n    transformer.requires_grad_(True)\n    pretrained_model.requires_grad_(False)\n    disc.model.requires_grad_(False)\n    disc.heads.requires_grad_(True)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype 
= torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    # VAE should always be kept in fp32 for SANA (?)\n    vae.to(accelerator.device, dtype=torch.float32)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    pretrained_model.to(accelerator.device, dtype=weight_dtype)\n    disc.to(accelerator.device, dtype=weight_dtype)\n    # because Gemma2 is particularly suited for bfloat16.\n    text_encoder.to(dtype=torch.bfloat16)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            for block in transformer.transformer_blocks:\n                block.attn2.set_use_npu_flash_attention(True)\n            for block in pretrained_model.transformer_blocks:\n                block.attn2.set_use_npu_flash_attention(True)\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only on npu device \")\n\n    # Initialize a text encoding pipeline and keep it to CPU for now.\n    text_encoding_pipeline = SanaPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=None,\n        transformer=None,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n        torch_dtype=torch.bfloat16,\n    )\n    text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    if version.parse(accelerate.__version__) >= 
version.parse(\"0.16.0\"):\n\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                for model in models:\n                    unwrapped_model = unwrap_model(model)\n                    # Handle transformer model\n                    if isinstance(unwrapped_model, type(unwrap_model(transformer))):\n                        model = unwrapped_model\n                        model.save_pretrained(os.path.join(output_dir, \"transformer\"))\n                    # Handle discriminator model (only save heads)\n                    elif isinstance(unwrapped_model, type(unwrap_model(disc))):\n                        # Save only the heads\n                        torch.save(unwrapped_model.heads.state_dict(), os.path.join(output_dir, \"disc_heads.pt\"))\n                    else:\n                        raise ValueError(f\"unexpected save model: {unwrapped_model.__class__}\")\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    if weights:\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            transformer_ = None\n            disc_ = None\n\n            if not accelerator.distributed_type == DistributedType.DEEPSPEED:\n                while len(models) > 0:\n                    model = models.pop()\n                    unwrapped_model = unwrap_model(model)\n\n                    if isinstance(unwrapped_model, type(unwrap_model(transformer))):\n                        transformer_ = model  # noqa: F841\n                    elif isinstance(unwrapped_model, type(unwrap_model(disc))):\n                        # Load only the heads\n                        heads_state_dict = torch.load(os.path.join(input_dir, \"disc_heads.pt\"))\n                        unwrapped_model.heads.load_state_dict(heads_state_dict)\n                        disc_ = model  # noqa: F841\n                    else:\n                     
   raise ValueError(f\"unexpected save model: {unwrapped_model.__class__}\")\n\n            else:\n                # DeepSpeed case\n                transformer_ = SanaTransformer2DModel.from_pretrained(input_dir, subfolder=\"transformer\")  # noqa: F841\n                disc_heads_state_dict = torch.load(os.path.join(input_dir, \"disc_heads.pt\"))  # noqa: F841\n                # You'll need to handle how to load the heads in DeepSpeed case\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimization parameters\n    optimizer_G = optimizer_class(\n        transformer.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    optimizer_D = optimizer_class(\n        disc.heads.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    
)\n\n    hf_dataset = load_dataset(\n        args.dataset_name,\n        data_files=args.file_path,\n        split=\"train\",\n    )\n\n    train_dataset = Text2ImageDataset(\n        hf_dataset=hf_dataset,\n        resolution=args.resolution,\n    )\n\n    train_dataloader = DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n        pin_memory=True,\n        persistent_workers=True,\n        shuffle=True,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer_G,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, pretrained_model, disc, optimizer_G, optimizer_D, train_dataloader, lr_scheduler = (\n        accelerator.prepare(\n            transformer, pretrained_model, disc, optimizer_G, optimizer_D, train_dataloader, lr_scheduler\n        )\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We 
need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"sana-sprint\"\n        config = {\n            k: str(v) if not isinstance(v, (int, float, str, bool, torch.Tensor)) else v for k, v in vars(args).items()\n        }\n        accelerator.init_trackers(tracker_name, config=config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    phase = \"G\"\n    vae_config_scaling_factor = vae.config.scaling_factor\n    sigma_data = args.sigma_data\n    negative_prompt = [\"\"] * args.train_batch_size\n    negative_prompt_embeds, negative_prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(\n        prompt=negative_prompt,\n        complex_human_instruction=False,\n        do_classifier_free_guidance=False,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n        disc.train()\n\n        for step, batch in enumerate(train_dataloader):\n            # text encoding\n            prompts = batch[\"text\"]\n            with torch.no_grad():\n                prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(\n                    prompts, complex_human_instruction=COMPLEX_HUMAN_INSTRUCTION, do_classifier_free_guidance=False\n                )\n\n            # Convert images to latent space\n            vae = vae.to(accelerator.device)\n            pixel_values = batch[\"image\"].to(dtype=vae.dtype)\n            model_input = vae.encode(pixel_values).latent\n            model_input = model_input * vae_config_scaling_factor * sigma_data\n            model_input = 
model_input.to(dtype=weight_dtype)\n\n            # Sample noise that we'll add to the latents\n            noise = torch.randn_like(model_input) * sigma_data\n            bsz = model_input.shape[0]\n\n            # Sample a random timestep for each image\n            # for weighting schemes where we sample timesteps non-uniformly\n            u = compute_density_for_timestep_sampling_scm(\n                batch_size=bsz,\n                logit_mean=args.logit_mean,\n                logit_std=args.logit_std,\n            ).to(accelerator.device)\n\n            # Add noise according to TrigFlow.\n            # zt = cos(t) * x + sin(t) * noise\n            t = u.view(-1, 1, 1, 1)\n            noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise\n\n            scm_cfg_scale = torch.tensor(\n                np.random.choice(args.scm_cfg_scale, size=bsz, replace=True),\n                device=accelerator.device,\n            )\n\n            def model_wrapper(scaled_x_t, t):\n                pred, logvar = accelerator.unwrap_model(transformer)(\n                    hidden_states=scaled_x_t,\n                    timestep=t.flatten(),\n                    encoder_hidden_states=prompt_embeds,\n                    encoder_attention_mask=prompt_attention_mask,\n                    guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale),\n                    jvp=True,\n                    return_logvar=True,\n                )\n                return pred, logvar\n\n            if phase == \"G\":\n                transformer.train()\n                disc.eval()\n                models_to_accumulate = [transformer]\n                with accelerator.accumulate(models_to_accumulate):\n                    with torch.no_grad():\n                        cfg_x_t = torch.cat([noisy_model_input, noisy_model_input], dim=0)\n                        cfg_t = torch.cat([t, t], dim=0)\n                        cfg_y = torch.cat([negative_prompt_embeds, 
prompt_embeds], dim=0)\n                        cfg_y_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)\n\n                        cfg_pretrain_pred = pretrained_model(\n                            hidden_states=(cfg_x_t / sigma_data),\n                            timestep=cfg_t.flatten(),\n                            encoder_hidden_states=cfg_y,\n                            encoder_attention_mask=cfg_y_mask,\n                        )[0]\n\n                        cfg_dxt_dt = sigma_data * cfg_pretrain_pred\n\n                        dxt_dt_uncond, dxt_dt = cfg_dxt_dt.chunk(2)\n\n                        scm_cfg_scale = scm_cfg_scale.view(-1, 1, 1, 1)\n                        dxt_dt = dxt_dt_uncond + scm_cfg_scale * (dxt_dt - dxt_dt_uncond)\n\n                    v_x = torch.cos(t) * torch.sin(t) * dxt_dt / sigma_data\n                    v_t = torch.cos(t) * torch.sin(t)\n\n                    # Adapt from https://github.com/xandergos/sCM-mnist/blob/master/train_consistency.py\n                    with torch.no_grad():\n                        F_theta, F_theta_grad, logvar = torch.func.jvp(\n                            model_wrapper, (noisy_model_input / sigma_data, t), (v_x, v_t), has_aux=True\n                        )\n\n                    F_theta, logvar = transformer(\n                        hidden_states=(noisy_model_input / sigma_data),\n                        timestep=t.flatten(),\n                        encoder_hidden_states=prompt_embeds,\n                        encoder_attention_mask=prompt_attention_mask,\n                        guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale),\n                        return_logvar=True,\n                    )\n\n                    logvar = logvar.view(-1, 1, 1, 1)\n                    F_theta_grad = F_theta_grad.detach()\n                    F_theta_minus = F_theta.detach()\n\n                    # Warmup steps\n                    r = min(1, global_step / 
args.tangent_warmup_steps)\n\n                    # Calculate gradient g using JVP rearrangement\n                    g = -torch.cos(t) * torch.cos(t) * (sigma_data * F_theta_minus - dxt_dt)\n                    second_term = -r * (torch.cos(t) * torch.sin(t) * noisy_model_input + sigma_data * F_theta_grad)\n                    g = g + second_term\n\n                    # Tangent normalization\n                    g_norm = torch.linalg.vector_norm(g, dim=(1, 2, 3), keepdim=True)\n                    g = g / (g_norm + 0.1)  # 0.1 is the constant c, can be modified but 0.1 was used in the paper\n\n                    sigma = torch.tan(t) * sigma_data\n                    weight = 1 / sigma\n\n                    l2_loss = torch.square(F_theta - F_theta_minus - g)\n\n                    # Calculate loss with normalization factor\n                    loss = (weight / torch.exp(logvar)) * l2_loss + logvar\n\n                    loss = loss.mean()\n\n                    loss_no_logvar = weight * torch.square(F_theta - F_theta_minus - g)\n                    loss_no_logvar = loss_no_logvar.mean()\n                    g_norm = g_norm.mean()\n\n                    pred_x_0 = torch.cos(t) * noisy_model_input - torch.sin(t) * F_theta * sigma_data\n\n                    if args.train_largest_timestep:\n                        pred_x_0.detach()\n                        u = compute_density_for_timestep_sampling_scm(\n                            batch_size=bsz,\n                            logit_mean=args.logit_mean,\n                            logit_std=args.logit_std,\n                        ).to(accelerator.device)\n                        t_new = u.view(-1, 1, 1, 1)\n\n                        random_mask = torch.rand_like(t_new) < args.largest_timestep_prob\n\n                        t_new = torch.where(random_mask, torch.full_like(t_new, args.largest_timestep), t_new)\n                        z_new = torch.randn_like(model_input) * sigma_data\n                        
x_t_new = torch.cos(t_new) * model_input + torch.sin(t_new) * z_new\n\n                        F_theta = transformer(\n                            hidden_states=(x_t_new / sigma_data),\n                            timestep=t_new.flatten(),\n                            encoder_hidden_states=prompt_embeds,\n                            encoder_attention_mask=prompt_attention_mask,\n                            guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale),\n                            return_logvar=False,\n                            jvp=False,\n                        )[0]\n\n                        pred_x_0 = torch.cos(t_new) * x_t_new - torch.sin(t_new) * F_theta * sigma_data\n\n                    # Sample timesteps for discriminator\n                    timesteps_D = compute_density_for_timestep_sampling_scm(\n                        batch_size=bsz,\n                        logit_mean=args.logit_mean_discriminator,\n                        logit_std=args.logit_std_discriminator,\n                    ).to(accelerator.device)\n                    t_D = timesteps_D.view(-1, 1, 1, 1)\n\n                    # Add noise to predicted x0\n                    z_D = torch.randn_like(model_input) * sigma_data\n                    noised_predicted_x0 = torch.cos(t_D) * pred_x_0 + torch.sin(t_D) * z_D\n\n                    # Calculate adversarial loss\n                    pred_fake = disc(\n                        hidden_states=(noised_predicted_x0 / sigma_data),\n                        timestep=t_D.flatten(),\n                        encoder_hidden_states=prompt_embeds,\n                        encoder_attention_mask=prompt_attention_mask,\n                    )\n                    adv_loss = -torch.mean(pred_fake)\n\n                    # Total loss = sCM loss + LADD loss\n\n                    total_loss = args.scm_lambda * loss + adv_loss * args.adv_lambda\n\n                    total_loss = total_loss / args.gradient_accumulation_steps\n\n             
       accelerator.backward(total_loss)\n\n                    if accelerator.sync_gradients:\n                        grad_norm = accelerator.clip_grad_norm_(transformer.parameters(), args.gradient_clip)\n                        if torch.logical_or(grad_norm.isnan(), grad_norm.isinf()):\n                            optimizer_G.zero_grad(set_to_none=True)\n                            optimizer_D.zero_grad(set_to_none=True)\n                            logger.warning(\"NaN or Inf detected in grad_norm, skipping iteration...\")\n                            continue\n\n                        # switch phase to D\n                        phase = \"D\"\n\n                        optimizer_G.step()\n                        lr_scheduler.step()\n                        optimizer_G.zero_grad(set_to_none=True)\n\n            elif phase == \"D\":\n                transformer.eval()\n                disc.train()\n                models_to_accumulate = [disc]\n                with accelerator.accumulate(models_to_accumulate):\n                    with torch.no_grad():\n                        scm_cfg_scale = torch.tensor(\n                            np.random.choice(args.scm_cfg_scale, size=bsz, replace=True),\n                            device=accelerator.device,\n                        )\n\n                        if args.train_largest_timestep:\n                            random_mask = torch.rand_like(t) < args.largest_timestep_prob\n                            t = torch.where(random_mask, torch.full_like(t, args.largest_timestep), t)\n\n                            z_new = torch.randn_like(model_input) * sigma_data\n                            noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * z_new\n                        # here\n                        F_theta = transformer(\n                            hidden_states=(noisy_model_input / sigma_data),\n                            timestep=t.flatten(),\n                            
encoder_hidden_states=prompt_embeds,\n                            encoder_attention_mask=prompt_attention_mask,\n                            guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale),\n                            return_logvar=False,\n                            jvp=False,\n                        )[0]\n                        pred_x_0 = torch.cos(t) * noisy_model_input - torch.sin(t) * F_theta * sigma_data\n\n                    # Sample timesteps for fake and real samples\n                    timestep_D_fake = compute_density_for_timestep_sampling_scm(\n                        batch_size=bsz,\n                        logit_mean=args.logit_mean_discriminator,\n                        logit_std=args.logit_std_discriminator,\n                    ).to(accelerator.device)\n                    timesteps_D_real = timestep_D_fake\n\n                    t_D_fake = timestep_D_fake.view(-1, 1, 1, 1)\n                    t_D_real = timesteps_D_real.view(-1, 1, 1, 1)\n\n                    # Add noise to predicted x0 and real x0\n                    z_D_fake = torch.randn_like(model_input) * sigma_data\n                    z_D_real = torch.randn_like(model_input) * sigma_data\n                    noised_predicted_x0 = torch.cos(t_D_fake) * pred_x_0 + torch.sin(t_D_fake) * z_D_fake\n                    noised_latents = torch.cos(t_D_real) * model_input + torch.sin(t_D_real) * z_D_real\n\n                    # Add misaligned pairs if enabled and batch size > 1\n                    if args.misaligned_pairs_D and bsz > 1:\n                        # Create shifted pairs\n                        shifted_x0 = torch.roll(model_input, 1, 0)\n                        timesteps_D_shifted = compute_density_for_timestep_sampling_scm(\n                            batch_size=bsz,\n                            logit_mean=args.logit_mean_discriminator,\n                            logit_std=args.logit_std_discriminator,\n                        ).to(accelerator.device)\n   
                     t_D_shifted = timesteps_D_shifted.view(-1, 1, 1, 1)\n\n                        # Add noise to shifted pairs\n                        z_D_shifted = torch.randn_like(shifted_x0) * sigma_data\n                        noised_shifted_x0 = torch.cos(t_D_shifted) * shifted_x0 + torch.sin(t_D_shifted) * z_D_shifted\n\n                        # Concatenate with original noised samples\n                        noised_predicted_x0 = torch.cat([noised_predicted_x0, noised_shifted_x0], dim=0)\n                        t_D_fake = torch.cat([t_D_fake, t_D_shifted], dim=0)\n                        prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)\n                        prompt_attention_mask = torch.cat([prompt_attention_mask, prompt_attention_mask], dim=0)\n\n                    # Calculate D loss\n\n                    pred_fake = disc(\n                        hidden_states=(noised_predicted_x0 / sigma_data),\n                        timestep=t_D_fake.flatten(),\n                        encoder_hidden_states=prompt_embeds,\n                        encoder_attention_mask=prompt_attention_mask,\n                    )\n                    pred_true = disc(\n                        hidden_states=(noised_latents / sigma_data),\n                        timestep=t_D_real.flatten(),\n                        encoder_hidden_states=prompt_embeds,\n                        encoder_attention_mask=prompt_attention_mask,\n                    )\n\n                    # hinge loss\n                    loss_real = torch.mean(F.relu(1.0 - pred_true))\n                    loss_gen = torch.mean(F.relu(1.0 + pred_fake))\n                    loss_D = 0.5 * (loss_real + loss_gen)\n\n                    loss_D = loss_D / args.gradient_accumulation_steps\n\n                    accelerator.backward(loss_D)\n\n                    if accelerator.sync_gradients:\n                        grad_norm = accelerator.clip_grad_norm_(disc.parameters(), args.gradient_clip)\n        
                if torch.logical_or(grad_norm.isnan(), grad_norm.isinf()):\n                            optimizer_G.zero_grad(set_to_none=True)\n                            optimizer_D.zero_grad(set_to_none=True)\n                            logger.warning(\"NaN or Inf detected in grad_norm, skipping iteration...\")\n                            continue\n\n                        # switch back to phase G and add global step by one.\n                        phase = \"G\"\n\n                        optimizer_D.step()\n                        optimizer_D.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                               
 )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\n                \"scm_loss\": loss.detach().item(),\n                \"adv_loss\": adv_loss.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = SanaSprintPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=torch.float32,\n                )\n                pipeline_args = {\n                    \"prompt\": args.validation_prompt,\n                    \"complex_human_instruction\": COMPLEX_HUMAN_INSTRUCTION,\n                }\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                )\n                free_memory()\n\n        
        images = None\n                del pipeline\n\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        if args.upcast_before_saving:\n            transformer.to(torch.float32)\n        else:\n            transformer = transformer.to(weight_dtype)\n\n        # Save discriminator heads\n        disc = unwrap_model(disc)\n        disc_heads_state_dict = disc.heads.state_dict()\n\n        # Save transformer model\n        transformer.save_pretrained(os.path.join(args.output_dir, \"transformer\"))\n\n        # Save discriminator heads\n        torch.save(disc_heads_state_dict, os.path.join(args.output_dir, \"disc_heads.pt\"))\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = SanaSprintPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            transformer=accelerator.unwrap_model(transformer),\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=torch.float32,\n        )\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\n                \"prompt\": args.validation_prompt,\n                \"complex_human_instruction\": COMPLEX_HUMAN_INSTRUCTION,\n            }\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n            )\n  
          upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        images = None\n        del pipeline\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/sana/train_sana_sprint_diffusers.sh",
    "content": "your_local_path='output'\n\nhf download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers  --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers\n\n# or Sana_Sprint_0.6B_1024px_teacher_diffusers\n\npython train_sana_sprint_diffusers.py \\\n    --pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \\\n    --output_dir=$your_local_path \\\n    --mixed_precision=bf16 \\\n    --resolution=1024 \\\n    --learning_rate=1e-6 \\\n    --max_train_steps=30000 \\\n    --dataloader_num_workers=8 \\\n    --dataset_name='brivangl/midjourney-v6-llava' \\\n    --file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \\\n    --checkpointing_steps=500 --checkpoints_total_limit=10 \\\n    --train_batch_size=1 \\\n    --gradient_accumulation_steps=1 \\\n    --seed=453645634 \\\n    --train_largest_timestep \\\n    --misaligned_pairs_D \\\n    --gradient_checkpointing \\\n    --resume_from_checkpoint=\"latest\"\n"
  },
  {
    "path": "examples/research_projects/scheduled_huber_loss_training/README.md",
    "content": "# Scheduled Pseudo-Huber Loss for Diffusers\n\nThese are modifications of the Diffusers example training scripts that add the option of training text2image models with the Scheduled Pseudo-Huber loss, introduced in https://huggingface.co/papers/2403.16728. (https://github.com/kabachuha/SPHL-for-stable-diffusion)\n\n## Why might this be useful?\n\n- If you suspect that part of the training dataset might be corrupted, and you don't want these outliers to distort the model's supposed output.\n\n- If you want to improve the aesthetic quality of pictures by helping the model disentangle concepts and be less influenced by other sorts of pictures.\n\nSee https://github.com/huggingface/diffusers/issues/7488 for the detailed description.\n\n## Instructions\n\nUsage is the same as for the corresponding vanilla Diffusers scripts: https://github.com/huggingface/diffusers/tree/main/examples\n"
  },
  {
    "path": "examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport gc\nimport importlib\nimport itertools\nimport logging\nimport math\nimport os\nimport shutil\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, model_info, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom packaging import version\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom 
diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.28.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    base_model: str = None,\n    train_text_encoder=False,\n    prompt: str = None,\n    repo_folder: str = None,\n    pipeline: DiffusionPipeline = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# DreamBooth - {repo_id}\n\nThis is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).\nYou can find some example images in the following. \\n\n{img_str}\n\nDreamBooth for the text encoder was enabled: {train_text_encoder}.\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        prompt=prompt,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\"text-to-image\", \"dreambooth\", \"diffusers-training\"]\n    if isinstance(pipeline, StableDiffusionPipeline):\n        tags.extend([\"stable-diffusion\", \"stable-diffusion-diffusers\"])\n    else:\n        tags.extend([\"if\", \"if-diffusers\"])\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    text_encoder,\n    tokenizer,\n    unet,\n    vae,\n    args,\n    accelerator,\n    weight_dtype,\n    global_step,\n    prompt_embeds,\n    negative_prompt_embeds,\n):\n    logger.info(\n        f\"Running 
validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n\n    pipeline_args = {}\n\n    if vae is not None:\n        pipeline_args[\"vae\"] = vae\n\n    # create pipeline (note: unet and vae are loaded again in float32)\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        unet=unet,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n        **pipeline_args,\n    )\n\n    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if \"variance_type\" in pipeline.scheduler.config:\n        variance_type = pipeline.scheduler.config.variance_type\n\n        if variance_type in [\"learned\", \"learned_range\"]:\n            variance_type = \"fixed_small\"\n\n        scheduler_args[\"variance_type\"] = variance_type\n\n    module = importlib.import_module(\"diffusers\")\n    scheduler_class = getattr(module, args.validation_scheduler)\n    pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.pre_compute_text_embeddings:\n        pipeline_args = {\n            \"prompt_embeds\": prompt_embeds,\n            \"negative_prompt_embeds\": negative_prompt_embeds,\n        }\n    else:\n        pipeline_args = {\"prompt\": args.validation_prompt}\n\n    # run inference\n    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n    images = []\n    if args.validation_images is None:\n        for _ in range(args.num_validation_images):\n            with torch.autocast(\"cuda\"):\n                image = pipeline(**pipeline_args, 
num_inference_steps=25, generator=generator).images[0]\n            images.append(image)\n    else:\n        for image in args.validation_images:\n            image = Image.open(image)\n            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, global_step, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n  
      \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        
\"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"dreambooth-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. \"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. \"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step \"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more details\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. 
More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--offset_noise\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Fine-tuning against a modified noise\"\n            \" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information.\"\n        ),\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--pre_compute_text_embeddings\",\n        action=\"store_true\",\n        help=\"Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_max_length\",\n        type=int,\n        default=None,\n        required=False,\n        help=\"The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.\",\n    )\n    parser.add_argument(\n        \"--text_encoder_use_attention_mask\",\n        action=\"store_true\",\n        required=False,\n        help=\"Whether to use attention mask for the text encoder\",\n    )\n    parser.add_argument(\n        \"--skip_save_text_encoder\", action=\"store_true\", required=False, help=\"Set to not save text encoder\"\n    )\n    parser.add_argument(\n        \"--validation_images\",\n        required=False,\n        default=None,\n        nargs=\"+\",\n        help=\"Optional set of images to use for validation. 
Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.\",\n    )\n    parser.add_argument(\n        \"--class_labels_conditioning\",\n        required=False,\n        default=None,\n        help=\"The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.\",\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\", \"smooth_l1\"],\n        help=\"The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.\",\n    )\n    parser.add_argument(\n        \"--huber_schedule\",\n        type=str,\n        default=\"snr\",\n        choices=[\"constant\", \"exponential\", \"snr\"],\n        help=\"The schedule to use for the huber losses parameter\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.1,\n        help=\"The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.\",\n    )\n    parser.add_argument(\n        \"--validation_scheduler\",\n        type=str,\n        default=\"DPMSolverMultistepScheduler\",\n        choices=[\"DPMSolverMultistepScheduler\", \"DDPMScheduler\"],\n        help=\"Select which scheduler to use for validation. 
DDPMScheduler is recommended for DeepFloyd IF.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    if args.train_text_encoder and args.pre_compute_text_embeddings:\n        raise ValueError(\"`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and the tokenizes prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        class_num=None,\n        size=512,\n        center_crop=False,\n        encoder_hidden_states=None,\n        class_prompt_encoder_hidden_states=None,\n        tokenizer_max_length=None,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n        self.encoder_hidden_states = encoder_hidden_states\n        self.class_prompt_encoder_hidden_states = 
class_prompt_encoder_hidden_states\n        self.tokenizer_max_length = tokenizer_max_length\n\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(f\"Instance images root {self.instance_data_root} doesn't exist.\")\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])\n        instance_image = exif_transpose(instance_image)\n\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        example[\"instance_images\"] = 
self.image_transforms(instance_image)\n\n        if self.encoder_hidden_states is not None:\n            example[\"instance_prompt_ids\"] = self.encoder_hidden_states\n        else:\n            text_inputs = tokenize_prompt(\n                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length\n            )\n            example[\"instance_prompt_ids\"] = text_inputs.input_ids\n            example[\"instance_attention_mask\"] = text_inputs.attention_mask\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n\n            if self.class_prompt_encoder_hidden_states is not None:\n                example[\"class_prompt_ids\"] = self.class_prompt_encoder_hidden_states\n            else:\n                class_text_inputs = tokenize_prompt(\n                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length\n                )\n                example[\"class_prompt_ids\"] = class_text_inputs.input_ids\n                example[\"class_attention_mask\"] = class_text_inputs.attention_mask\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    has_attention_mask = \"instance_attention_mask\" in examples[0]\n\n    input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n    pixel_values = [example[\"instance_images\"] for example in examples]\n\n    if has_attention_mask:\n        attention_mask = [example[\"instance_attention_mask\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        input_ids += 
[example[\"class_prompt_ids\"] for example in examples]\n        pixel_values += [example[\"class_images\"] for example in examples]\n\n        if has_attention_mask:\n            attention_mask += [example[\"class_attention_mask\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.cat(input_ids, dim=0)\n\n    batch = {\n        \"input_ids\": input_ids,\n        \"pixel_values\": pixel_values,\n    }\n\n    if has_attention_mask:\n        attention_mask = torch.cat(attention_mask, dim=0)\n        batch[\"attention_mask\"] = attention_mask\n\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef model_has_vae(args):\n    config_file_name = os.path.join(\"vae\", AutoencoderKL.config_name)\n    if os.path.isdir(args.pretrained_model_name_or_path):\n        config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)\n        return os.path.isfile(config_file_name)\n    else:\n        files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings\n        return any(file.rfilename == config_file_name for file in files_in_repo)\n\n\ndef tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):\n    if tokenizer_max_length is not None:\n        max_length = tokenizer_max_length\n    else:\n        max_length = tokenizer.model_max_length\n\n    text_inputs = tokenizer(\n        prompt,\n        truncation=True,\n        
padding=\"max_length\",\n        max_length=max_length,\n        return_tensors=\"pt\",\n    )\n\n    return text_inputs\n\n\ndef encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):\n    text_input_ids = input_ids.to(text_encoder.device)\n\n    if text_encoder_use_attention_mask:\n        attention_mask = attention_mask.to(text_encoder.device)\n    else:\n        attention_mask = None\n\n    prompt_embeds = text_encoder(\n        text_input_ids,\n        attention_mask=attention_mask,\n        return_dict=False,\n    )\n    prompt_embeds = prompt_embeds[0]\n\n    return prompt_embeds\n\n\n# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already\ndef conditional_loss(\n    model_pred: torch.Tensor,\n    target: torch.Tensor,\n    reduction: str = \"mean\",\n    loss_type: str = \"l2\",\n    huber_c: float = 0.1,\n):\n    if loss_type == \"l2\":\n        loss = F.mse_loss(model_pred, target, reduction=reduction)\n    elif loss_type == \"huber\":\n        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    elif loss_type == \"smooth_l1\":\n        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    else:\n        raise NotImplementedError(f\"Unsupported Loss Type {loss_type}\")\n    return loss\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = 
Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. \"\n            \"Please set gradient_accumulation_steps to 1. 
This feature will be supported in the future.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                safety_checker=None,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            
logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, 
args.revision)\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n\n    if model_has_vae(args):\n        vae = AutoencoderKL.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n        )\n    else:\n        vae = None\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            for model in models:\n                sub_dir = \"unet\" if isinstance(model, type(unwrap_model(unet))) else \"text_encoder\"\n                model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n    def load_model_hook(models, input_dir):\n        while len(models) > 0:\n            # pop models so that they are not loaded again\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(text_encoder))):\n                # load transformers style into model\n                load_model = text_encoder_cls.from_pretrained(input_dir, subfolder=\"text_encoder\")\n                model.config = load_model.config\n            else:\n                # load diffusers style into model\n               
 load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n            model.load_state_dict(load_model.state_dict())\n            del load_model\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if vae is not None:\n        vae.requires_grad_(False)\n\n    if not args.train_text_encoder:\n        text_encoder.requires_grad_(False)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \"Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training. copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(f\"Unet loaded as datatype {unwrap_model(unet).dtype}. 
{low_precision_error_string}\")\n\n    if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:\n        raise ValueError(\n            f\"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = (\n        itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()\n    )\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    if args.pre_compute_text_embeddings:\n\n        def compute_text_embeddings(prompt):\n            with torch.no_grad():\n                text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)\n                prompt_embeds = encode_prompt(\n                    text_encoder,\n                    text_inputs.input_ids,\n                    text_inputs.attention_mask,\n                
    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                )\n\n            return prompt_embeds\n\n        pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)\n        validation_prompt_negative_prompt_embeds = compute_text_embeddings(\"\")\n\n        if args.validation_prompt is not None:\n            validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)\n        else:\n            validation_prompt_encoder_hidden_states = None\n\n        if args.class_prompt is not None:\n            pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)\n        else:\n            pre_computed_class_prompt_encoder_hidden_states = None\n\n        text_encoder = None\n        tokenizer = None\n\n        gc.collect()\n        torch.cuda.empty_cache()\n    else:\n        pre_computed_encoder_hidden_states = None\n        validation_prompt_encoder_hidden_states = None\n        validation_prompt_negative_prompt_embeds = None\n        pre_computed_class_prompt_encoder_hidden_states = None\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        class_num=args.num_class_images,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n        encoder_hidden_states=pre_computed_encoder_hidden_states,\n        class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,\n        tokenizer_max_length=args.tokenizer_max_length,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: 
collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and text_encoder to device and cast to weight_dtype\n    if vae is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    if not args.train_text_encoder and text_encoder is not None:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate 
our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = vars(copy.deepcopy(args))\n        tracker_config.pop(\"validation_images\")\n        accelerator.init_trackers(\"dreambooth\", config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                pixel_values = 
batch[\"pixel_values\"].to(dtype=weight_dtype)\n\n                if vae is not None:\n                    # Convert images to latent space\n                    model_input = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                    model_input = model_input * vae.config.scaling_factor\n                else:\n                    model_input = pixel_values\n\n                # Sample noise that we'll add to the model input\n                if args.offset_noise:\n                    noise = torch.randn_like(model_input) + 0.1 * torch.randn(\n                        model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device\n                    )\n                else:\n                    noise = torch.randn_like(model_input)\n                bsz, channels, height, width = model_input.shape\n                # Sample a random timestep for each image\n                if args.loss_type == \"huber\" or args.loss_type == \"smooth_l1\":\n                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=\"cpu\")\n                    timestep = timesteps.item()\n\n                    if args.huber_schedule == \"exponential\":\n                        alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps\n                        huber_c = math.exp(-alpha * timestep)\n                    elif args.huber_schedule == \"snr\":\n                        alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]\n                        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5\n                        huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c\n                    elif args.huber_schedule == \"constant\":\n                        huber_c = args.huber_c\n                    else:\n                        raise NotImplementedError(f\"Unknown Huber loss schedule {args.huber_schedule}!\")\n\n                    timesteps = 
timesteps.repeat(bsz).to(model_input.device)\n                elif args.loss_type == \"l2\":\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                    )\n                    huber_c = 1  # may be anything, as it's not used\n                else:\n                    raise NotImplementedError(f\"Unknown loss type {args.loss_type}\")\n\n                timesteps = timesteps.long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                if args.pre_compute_text_embeddings:\n                    encoder_hidden_states = batch[\"input_ids\"]\n                else:\n                    encoder_hidden_states = encode_prompt(\n                        text_encoder,\n                        batch[\"input_ids\"],\n                        batch[\"attention_mask\"],\n                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                    )\n\n                if unwrap_model(unet).config.in_channels == channels * 2:\n                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)\n\n                if args.class_labels_conditioning == \"timesteps\":\n                    class_labels = timesteps\n                else:\n                    class_labels = None\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False\n                )[0]\n\n                if model_pred.shape[1] == 6:\n                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)\n\n                # Get 
the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n                    # Compute prior loss\n                    prior_loss = conditional_loss(\n                        model_pred_prior.float(),\n                        target_prior.float(),\n                        reduction=\"mean\",\n                        loss_type=args.loss_type,\n                        huber_c=huber_c,\n                    )\n\n                # Compute instance loss\n                if args.snr_gamma is None:\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"mean\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    base_weight = (\n                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr\n                    )\n\n         
           if noise_scheduler.config.prediction_type == \"v_prediction\":\n                        # Velocity objective needs to be floored to an SNR weight of one.\n                        mse_loss_weights = base_weight + 1\n                    else:\n                        # Epsilon and sample both use the same loss weights.\n                        mse_loss_weights = base_weight\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"none\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet.parameters(), text_encoder.parameters())\n                        if args.train_text_encoder\n                        else unet.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n      
                      checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    images = []\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        images = log_validation(\n                            unwrap_model(text_encoder) if text_encoder is not None else text_encoder,\n                            tokenizer,\n                            unwrap_model(unet),\n                            vae,\n                            args,\n                            accelerator,\n             
               weight_dtype,\n                            global_step,\n                            validation_prompt_encoder_hidden_states,\n                            validation_prompt_negative_prompt_embeds,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        pipeline_args = {}\n\n        if text_encoder is not None:\n            pipeline_args[\"text_encoder\"] = unwrap_model(text_encoder)\n\n        if args.skip_save_text_encoder:\n            pipeline_args[\"text_encoder\"] = None\n\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=unwrap_model(unet),\n            revision=args.revision,\n            variant=args.variant,\n            **pipeline_args,\n        )\n\n        # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it\n        scheduler_args = {}\n\n        if \"variance_type\" in pipeline.scheduler.config:\n            variance_type = pipeline.scheduler.config.variance_type\n\n            if variance_type in [\"learned\", \"learned_range\"]:\n                variance_type = \"fixed_small\"\n\n            scheduler_args[\"variance_type\"] = variance_type\n\n        pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n        pipeline.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                prompt=args.instance_prompt,\n                repo_folder=args.output_dir,\n                pipeline=pipeline,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport gc\nimport logging\nimport math\nimport os\nimport shutil\nimport warnings\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom packaging import version\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict, set_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params\nfrom diffusers.utils import (\n    
check_min_version,\n    convert_state_dict_to_diffusers,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.28.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    prompt: str = None,\n    repo_folder=None,\n    pipeline: DiffusionPipeline = None,\n):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# LoRA DreamBooth - {repo_id}\n\nThese are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. 
\\n\n{img_str}\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        prompt=prompt,\n        model_description=model_description,\n        inference=True,\n    )\n    tags = [\"text-to-image\", \"diffusers\", \"lora\", \"diffusers-training\"]\n    if isinstance(pipeline, StableDiffusionPipeline):\n        tags.extend([\"stable-diffusion\", \"stable-diffusion-diffusers\"])\n    else:\n        tags.extend([\"if\", \"if-diffusers\"])\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if \"variance_type\" in pipeline.scheduler.config:\n        variance_type = pipeline.scheduler.config.variance_type\n\n        if variance_type in [\"learned\", \"learned_range\"]:\n            variance_type = \"fixed_small\"\n\n        scheduler_args[\"variance_type\"] = variance_type\n\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n\n    if args.validation_images is None:\n        images = []\n        for _ in range(args.num_validation_images):\n            with torch.cuda.amp.autocast():\n                image = pipeline(**pipeline_args, generator=generator).images[0]\n                images.append(image)\n    else:\n        images = []\n        for image in args.validation_images:\n            image = Image.open(image)\n            with torch.cuda.amp.autocast():\n                image = pipeline(**pipeline_args, image=image, generator=generator).images[0]\n            images.append(image)\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    
return images\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A folder containing the training data of instance images.\",\n    )\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. 
Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lora-dreambooth-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--pre_compute_text_embeddings\",\n        action=\"store_true\",\n        help=\"Whether or not to pre-compute text embeddings. 
If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_max_length\",\n        type=int,\n        default=None,\n        required=False,\n        help=\"The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.\",\n    )\n    parser.add_argument(\n        \"--text_encoder_use_attention_mask\",\n        action=\"store_true\",\n        required=False,\n        help=\"Whether to use attention mask for the text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_images\",\n        required=False,\n        default=None,\n        nargs=\"+\",\n        help=\"Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.\",\n    )\n    parser.add_argument(\n        \"--class_labels_conditioning\",\n        required=False,\n        default=None,\n        help=\"The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.\",\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\", \"smooth_l1\"],\n        help=\"The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.\",\n    )\n    parser.add_argument(\n        \"--huber_schedule\",\n        type=str,\n        default=\"snr\",\n        choices=[\"constant\", \"exponential\", \"snr\"],\n        help=\"The schedule to use for the huber losses parameter\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.1,\n        help=\"The huber loss parameter. 
Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.\",\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify a prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    if args.train_text_encoder and args.pre_compute_text_embeddings:\n        raise ValueError(\"`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images and tokenizes the prompts.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        tokenizer,\n        class_data_root=None,\n        class_prompt=None,\n        class_num=None,\n        size=512,\n        center_crop=False,\n        encoder_hidden_states=None,\n        class_prompt_encoder_hidden_states=None,\n        tokenizer_max_length=None,\n    ):\n        self.size = 
size\n        self.center_crop = center_crop\n        self.tokenizer = tokenizer\n        self.encoder_hidden_states = encoder_hidden_states\n        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states\n        self.tokenizer_max_length = tokenizer_max_length\n\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(\"Instance images root doesn't exist.\")\n\n        self.instance_images_path = list(Path(instance_data_root).iterdir())\n        self.num_instance_images = len(self.instance_images_path)\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n            self.class_prompt = class_prompt\n        else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])\n        instance_image = 
exif_transpose(instance_image)\n\n        if not instance_image.mode == \"RGB\":\n            instance_image = instance_image.convert(\"RGB\")\n        example[\"instance_images\"] = self.image_transforms(instance_image)\n\n        if self.encoder_hidden_states is not None:\n            example[\"instance_prompt_ids\"] = self.encoder_hidden_states\n        else:\n            text_inputs = tokenize_prompt(\n                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length\n            )\n            example[\"instance_prompt_ids\"] = text_inputs.input_ids\n            example[\"instance_attention_mask\"] = text_inputs.attention_mask\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n\n            if self.class_prompt_encoder_hidden_states is not None:\n                example[\"class_prompt_ids\"] = self.class_prompt_encoder_hidden_states\n            else:\n                class_text_inputs = tokenize_prompt(\n                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length\n                )\n                example[\"class_prompt_ids\"] = class_text_inputs.input_ids\n                example[\"class_attention_mask\"] = class_text_inputs.attention_mask\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    has_attention_mask = \"instance_attention_mask\" in examples[0]\n\n    input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n    pixel_values = [example[\"instance_images\"] for example in examples]\n\n    if has_attention_mask:\n        attention_mask = [example[\"instance_attention_mask\"] for example in examples]\n\n    
# Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        input_ids += [example[\"class_prompt_ids\"] for example in examples]\n        pixel_values += [example[\"class_images\"] for example in examples]\n        if has_attention_mask:\n            attention_mask += [example[\"class_attention_mask\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.cat(input_ids, dim=0)\n\n    batch = {\n        \"input_ids\": input_ids,\n        \"pixel_values\": pixel_values,\n    }\n\n    if has_attention_mask:\n        batch[\"attention_mask\"] = attention_mask\n\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):\n    if tokenizer_max_length is not None:\n        max_length = tokenizer_max_length\n    else:\n        max_length = tokenizer.model_max_length\n\n    text_inputs = tokenizer(\n        prompt,\n        truncation=True,\n        padding=\"max_length\",\n        max_length=max_length,\n        return_tensors=\"pt\",\n    )\n\n    return text_inputs\n\n\ndef encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):\n    text_input_ids = input_ids.to(text_encoder.device)\n\n    if text_encoder_use_attention_mask:\n        attention_mask = attention_mask.to(text_encoder.device)\n    else:\n        
attention_mask = None\n\n    prompt_embeds = text_encoder(\n        text_input_ids,\n        attention_mask=attention_mask,\n        return_dict=False,\n    )\n    prompt_embeds = prompt_embeds[0]\n\n    return prompt_embeds\n\n\n# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already\ndef conditional_loss(\n    model_pred: torch.Tensor,\n    target: torch.Tensor,\n    reduction: str = \"mean\",\n    loss_type: str = \"l2\",\n    huber_c: float = 0.1,\n):\n    if loss_type == \"l2\":\n        loss = F.mse_loss(model_pred, target, reduction=reduction)\n    elif loss_type == \"huber\":\n        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    elif loss_type == \"smooth_l1\":\n        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    else:\n        raise NotImplementedError(f\"Unsupported Loss Type {loss_type}\")\n    return loss\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if args.report_to 
== \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. \"\n            \"Please set gradient_accumulation_steps to 1. This feature will be supported in the future.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = 
torch.float16 if accelerator.device.type == \"cuda\" else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                safety_checker=None,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            
os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    try:\n        vae = AutoencoderKL.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n        )\n    except OSError:\n        # IF does not have a VAE so let's just set it to None\n        # We don't have to error out here\n        vae = None\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    if vae is not None:\n        vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as 
these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    if vae is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    # now we will add new LoRA weights to the attention layers\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\", \"add_k_proj\", \"add_v_proj\"],\n    )\n    unet.add_adapter(unet_lora_config)\n\n    # The text encoder comes from 🤗 transformers, we will also attach adapters to it.\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.rank,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here. 
Either there are just the unet attn processor layers\n            # or there are the unet and text encoder attn layers\n            unet_lora_layers_to_save = None\n            text_encoder_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                elif isinstance(model, type(unwrap_model(text_encoder))):\n                    text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionLoraLoaderMixin.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n        text_encoder_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(unet))):\n                unet_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder))):\n                text_encoder_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, 
unet_state_dict, adapter_name=\"default\")\n\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        if args.train_text_encoder:\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_)\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [unet_]\n            if args.train_text_encoder:\n                models.append(text_encoder_)\n\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models, dtype=torch.float32)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [unet]\n        if args.train_text_encoder:\n            models.append(text_encoder)\n\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, 
dtype=torch.float32)\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))\n    if args.train_text_encoder:\n        params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))\n\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    if args.pre_compute_text_embeddings:\n\n        def compute_text_embeddings(prompt):\n            with torch.no_grad():\n                text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)\n                prompt_embeds = encode_prompt(\n                    text_encoder,\n                    text_inputs.input_ids,\n                    text_inputs.attention_mask,\n                    text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                )\n\n            return prompt_embeds\n\n        pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)\n        validation_prompt_negative_prompt_embeds = compute_text_embeddings(\"\")\n\n        if args.validation_prompt is not None:\n            validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)\n        else:\n            validation_prompt_encoder_hidden_states = None\n\n        if args.class_prompt is not None:\n           
 pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)\n        else:\n            pre_computed_class_prompt_encoder_hidden_states = None\n\n        text_encoder = None\n        tokenizer = None\n\n        gc.collect()\n        torch.cuda.empty_cache()\n    else:\n        pre_computed_encoder_hidden_states = None\n        validation_prompt_encoder_hidden_states = None\n        validation_prompt_negative_prompt_embeds = None\n        pre_computed_class_prompt_encoder_hidden_states = None\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_prompt=args.class_prompt,\n        class_num=args.num_class_images,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        center_crop=args.center_crop,\n        encoder_hidden_states=pre_computed_encoder_hidden_states,\n        class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,\n        tokenizer_max_length=args.tokenizer_max_length,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = vars(copy.deepcopy(args))\n        tracker_config.pop(\"validation_images\")\n        accelerator.init_trackers(\"dreambooth-lora\", config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total 
train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n               
 pixel_values = batch[\"pixel_values\"].to(dtype=weight_dtype)\n\n                if vae is not None:\n                    # Convert images to latent space\n                    model_input = vae.encode(pixel_values).latent_dist.sample()\n                    model_input = model_input * vae.config.scaling_factor\n                else:\n                    model_input = pixel_values\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz, channels, height, width = model_input.shape\n                # Sample a random timestep for each image\n                if args.loss_type == \"huber\" or args.loss_type == \"smooth_l1\":\n                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=\"cpu\")\n                    timestep = timesteps.item()\n\n                    if args.huber_schedule == \"exponential\":\n                        alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps\n                        huber_c = math.exp(-alpha * timestep)\n                    elif args.huber_schedule == \"snr\":\n                        alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]\n                        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5\n                        huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c\n                    elif args.huber_schedule == \"constant\":\n                        huber_c = args.huber_c\n                    else:\n                        raise NotImplementedError(f\"Unknown Huber loss schedule {args.huber_schedule}!\")\n\n                    timesteps = timesteps.repeat(bsz).to(model_input.device)\n                elif args.loss_type == \"l2\":\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                    )\n                    huber_c 
= 1  # may be anything, as it's not used\n                else:\n                    raise NotImplementedError(f\"Unknown loss type {args.loss_type}\")\n\n                timesteps = timesteps.long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                if args.pre_compute_text_embeddings:\n                    encoder_hidden_states = batch[\"input_ids\"]\n                else:\n                    encoder_hidden_states = encode_prompt(\n                        text_encoder,\n                        batch[\"input_ids\"],\n                        batch[\"attention_mask\"],\n                        text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,\n                    )\n\n                if unwrap_model(unet).config.in_channels == channels * 2:\n                    noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)\n\n                if args.class_labels_conditioning == \"timesteps\":\n                    class_labels = timesteps\n                else:\n                    class_labels = None\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    encoder_hidden_states,\n                    class_labels=class_labels,\n                    return_dict=False,\n                )[0]\n\n                # if model predicts variance, throw away the prediction. we will only train on the\n                # simplified training objective. 
This means that all schedulers using the fine tuned\n                # model must be configured to use one of the fixed variance variance types.\n                if model_pred.shape[1] == 6:\n                    model_pred, _ = torch.chunk(model_pred, 2, dim=1)\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute instance loss\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"mean\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n\n                    # Compute prior loss\n                    prior_loss = conditional_loss(\n                        model_pred_prior.float(),\n                        target_prior.float(),\n                        reduction=\"mean\",\n                        loss_type=args.loss_type,\n                        huber_c=huber_c,\n                    )\n\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n                else:\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"mean\", loss_type=args.loss_type, 
huber_c=huber_c\n                    )\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = 
os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = DiffusionPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    unet=unwrap_model(unet),\n                    text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                if args.pre_compute_text_embeddings:\n                    pipeline_args = {\n                        \"prompt_embeds\": validation_prompt_encoder_hidden_states,\n                        \"negative_prompt_embeds\": validation_prompt_negative_prompt_embeds,\n                    }\n                else:\n                    pipeline_args = {\"prompt\": args.validation_prompt}\n\n                images = log_validation(\n                    pipeline,\n                    args,\n                    accelerator,\n                    pipeline_args,\n                    epoch,\n                )\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        unet = 
unet.to(torch.float32)\n\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        if args.train_text_encoder:\n            text_encoder = unwrap_model(text_encoder)\n            text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))\n        else:\n            text_encoder_state_dict = None\n\n        StableDiffusionLoraLoaderMixin.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            text_encoder_lora_layers=text_encoder_state_dict,\n        )\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype\n        )\n\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir, weight_name=\"pytorch_lora_weights.safetensors\")\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt, \"num_inference_steps\": 25}\n            images = log_validation(\n                pipeline,\n                args,\n                accelerator,\n                pipeline_args,\n                epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                prompt=args.instance_prompt,\n                repo_folder=args.output_dir,\n                pipeline=pipeline,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of 
training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport contextlib\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom pathlib import Path\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, hf_hub_download, upload_folder\nfrom huggingface_hub.utils import insecure_hashlib\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom safetensors.torch import load_file, save_file\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DPMSolverMultistepScheduler,\n    EDMEulerScheduler,\n    EulerDiscreteScheduler,\n    StableDiffusionXLPipeline,\n    
UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr\nfrom diffusers.utils import (\n    check_min_version,\n    convert_all_state_dict_to_peft,\n    convert_state_dict_to_diffusers,\n    convert_state_dict_to_kohya,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.28.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef determine_scheduler_type(pretrained_model_name_or_path, revision):\n    model_index_filename = \"model_index.json\"\n    if os.path.isdir(pretrained_model_name_or_path):\n        model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)\n    else:\n        model_index = hf_hub_download(\n            repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision\n        )\n\n    with open(model_index, \"r\") as f:\n        scheduler_type = json.load(f)[\"scheduler\"][1]\n    return scheduler_type\n\n\ndef save_model_card(\n    repo_id: str,\n    use_dora: bool,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n    vae_path=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": 
{\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# {\"SDXL\" if \"playground\" not in base_model else \"Playground\"} LoRA DreamBooth - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} LoRA adaption weights for {base_model}.\n\nThe weights were trained  using [DreamBooth](https://dreambooth.github.io/).\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\nSpecial VAE used for training: {vae_path}.\n\n## Trigger words\n\nYou should use {instance_prompt} to trigger the image generation.\n\n## Download model\n\nWeights for this model are available in Safetensors format.\n\n[Download]({repo_id}/tree/main) them in the Files & versions tab.\n\n\"\"\"\n    if \"playground\" in base_model:\n        model_description += \"\"\"\\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"openrail++\" if \"playground\" not in base_model else \"playground-v2dot5-community\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\" if not use_dora else \"dora\",\n        \"template:sd-lora\",\n    ]\n    if \"playground\" in base_model:\n        tags.extend([\"playground\", \"playground-diffusers\"])\n    else:\n        tags.extend([\"stable-diffusion-xl\", \"stable-diffusion-xl-diffusers\"])\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    
is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n\n    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it\n    scheduler_args = {}\n\n    if not args.do_edm_style_training:\n        if \"variance_type\" in pipeline.scheduler.config:\n            variance_type = pipeline.scheduler.config.variance_type\n\n            if variance_type in [\"learned\", \"learned_range\"]:\n                variance_type = \"fixed_small\"\n\n            scheduler_args[\"variance_type\"] = variance_type\n\n        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference; compare against None so that an explicit seed of 0 still yields a seeded generator\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better\n    # way to condition it. 
Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051\n    inference_ctx = (\n        contextlib.nullcontext() if \"playground\" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()\n    )\n\n    with inference_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n    
    type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. 
\"),\n    )\n\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\n        \"--image_column\",\n        type=str,\n        default=\"image\",\n        help=\"The column of the dataset containing the target image. By \"\n        \"default, the standard Image Dataset maps out 'file_name' \"\n        \"to 'image'.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=None,\n        help=\"The column of the dataset containing the instance prompt for each image\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--class_data_dir\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"A folder containing the training data of class images.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--class_prompt\",\n        type=str,\n        default=None,\n        help=\"The prompt to specify images in the same class as provided instance images.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--do_edm_style_training\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to conduct training using the EDM formulation as introduced in https://huggingface.co/papers/2206.00364.\",\n    )\n    parser.add_argument(\n        \"--with_prior_preservation\",\n        default=False,\n        action=\"store_true\",\n        help=\"Flag to add prior preservation loss.\",\n    )\n    parser.add_argument(\"--prior_loss_weight\", type=float, default=1.0, help=\"The weight of prior preservation loss.\")\n    parser.add_argument(\n        \"--num_class_images\",\n        type=int,\n        default=100,\n        help=(\n            \"Minimal class images for prior preservation loss. 
If there are not enough images already present in\"\n            \" class_data_dir, additional images will be sampled with class_prompt.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"lora-dreambooth-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--output_kohya_format\",\n        action=\"store_true\",\n        help=\"Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. 
If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--text_encoder_lr\",\n        type=float,\n        default=5e-6,\n        help=\"Text encoder learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. 
\"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. 
Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    parser.add_argument(\n        \"--adam_weight_decay_text_encoder\", type=float, default=1e-03, help=\"Weight decay to use for text_encoder\"\n    )\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. 
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--use_dora\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://huggingface.co/papers/2402.09353. \"\n            \"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`\"\n        ),\n    )\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\", \"smooth_l1\"],\n        help=\"The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.\",\n    )\n    parser.add_argument(\n        \"--huber_schedule\",\n        type=str,\n        default=\"snr\",\n        choices=[\"constant\", \"exponential\", \"snr\"],\n        help=\"The schedule to use for the huber losses parameter\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.1,\n        help=\"The huber loss parameter. 
Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.instance_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--instance_data_dir`\")\n\n    if args.dataset_name is not None and args.instance_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--instance_data_dir`\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.with_prior_preservation:\n        if args.class_data_dir is None:\n            raise ValueError(\"You must specify a data directory for class images.\")\n        if args.class_prompt is None:\n            raise ValueError(\"You must specify prompt for class images.\")\n    else:\n        # logger is not available yet\n        if args.class_data_dir is not None:\n            warnings.warn(\"You need not use --class_data_dir without --with_prior_preservation.\")\n        if args.class_prompt is not None:\n            warnings.warn(\"You need not use --class_prompt without --with_prior_preservation.\")\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        instance_data_root,\n        instance_prompt,\n        class_prompt,\n        class_data_root=None,\n        class_num=None,\n        size=1024,\n        repeats=1,\n        center_crop=False,\n    ):\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.custom_instance_prompts = None\n        
self.class_prompt = class_prompt\n\n        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,\n        # we load the training data using load_dataset\n        if args.dataset_name is not None:\n            try:\n                from datasets import load_dataset\n            except ImportError:\n                raise ImportError(\n                    \"You are trying to load your data using the datasets library. If you wish to train using custom \"\n                    \"captions please install the datasets library: `pip install datasets`. If you wish to load a \"\n                    \"local folder containing images only, specify --instance_data_dir instead.\"\n                )\n            # Downloading and loading a dataset from the hub.\n            # See more about loading custom images at\n            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n            dataset = load_dataset(\n                args.dataset_name,\n                args.dataset_config_name,\n                cache_dir=args.cache_dir,\n            )\n            # Preprocessing the datasets.\n            column_names = dataset[\"train\"].column_names\n\n            # 6. Get the column names for input/target.\n            if args.image_column is None:\n                image_column = column_names[0]\n                logger.info(f\"image column defaulting to {image_column}\")\n            else:\n                image_column = args.image_column\n                if image_column not in column_names:\n                    raise ValueError(\n                        f\"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n                    )\n            instance_images = dataset[\"train\"][image_column]\n\n            if args.caption_column is None:\n                logger.info(\n                    \"No caption column provided, defaulting to instance_prompt for all images. If your dataset \"\n                    \"contains captions/prompts for the images, make sure to specify the \"\n                    \"column as --caption_column\"\n                )\n                self.custom_instance_prompts = None\n            else:\n                if args.caption_column not in column_names:\n                    raise ValueError(\n                        f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n                    )\n                custom_instance_prompts = dataset[\"train\"][args.caption_column]\n                # create final list of captions according to --repeats\n                self.custom_instance_prompts = []\n                for caption in custom_instance_prompts:\n                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))\n        else:\n            self.instance_data_root = Path(instance_data_root)\n            if not self.instance_data_root.exists():\n                raise ValueError(\"Instance images root doesn't exist.\")\n\n            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n            self.custom_instance_prompts = None\n\n        self.instance_images = []\n        for img in instance_images:\n            self.instance_images.extend(itertools.repeat(img, repeats))\n\n        # image processing to prepare for using SD-XL micro-conditioning\n        self.original_sizes = []\n        self.crop_top_lefts = []\n        self.pixel_values = []\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in self.instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            self.original_sizes.append((image.height, image.width))\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            crop_top_left = (y1, x1)\n            self.crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            self.pixel_values.append(image)\n\n        self.num_instance_images = len(self.instance_images)\n        self._length = self.num_instance_images\n\n        if class_data_root is not None:\n            self.class_data_root = Path(class_data_root)\n            self.class_data_root.mkdir(parents=True, exist_ok=True)\n            self.class_images_path = list(self.class_data_root.iterdir())\n            if class_num is not None:\n                self.num_class_images = min(len(self.class_images_path), class_num)\n            else:\n                self.num_class_images = len(self.class_images_path)\n            self._length = max(self.num_class_images, self.num_instance_images)\n        
else:\n            self.class_data_root = None\n\n        self.image_transforms = transforms.Compose(\n            [\n                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),\n                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        original_size = self.original_sizes[index % self.num_instance_images]\n        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]\n        example[\"instance_images\"] = instance_image\n        example[\"original_size\"] = original_size\n        example[\"crop_top_left\"] = crop_top_left\n\n        if self.custom_instance_prompts:\n            caption = self.custom_instance_prompts[index % self.num_instance_images]\n            if caption:\n                example[\"instance_prompt\"] = caption\n            else:\n                example[\"instance_prompt\"] = self.instance_prompt\n\n        else:  # no custom prompts were provided, so every image uses the shared instance prompt\n            example[\"instance_prompt\"] = self.instance_prompt\n\n        if self.class_data_root:\n            class_image = Image.open(self.class_images_path[index % self.num_class_images])\n            class_image = exif_transpose(class_image)\n\n            if not class_image.mode == \"RGB\":\n                class_image = class_image.convert(\"RGB\")\n            example[\"class_images\"] = self.image_transforms(class_image)\n            example[\"class_prompt\"] = self.class_prompt\n\n        return example\n\n\ndef collate_fn(examples, with_prior_preservation=False):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    
prompts = [example[\"instance_prompt\"] for example in examples]\n    original_sizes = [example[\"original_size\"] for example in examples]\n    crop_top_lefts = [example[\"crop_top_left\"] for example in examples]\n\n    # Concat class and instance examples for prior preservation.\n    # We do this to avoid doing two forward passes.\n    if with_prior_preservation:\n        pixel_values += [example[\"class_images\"] for example in examples]\n        prompts += [example[\"class_prompt\"] for example in examples]\n        original_sizes += [example[\"original_size\"] for example in examples]\n        crop_top_lefts += [example[\"crop_top_left\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    batch = {\n        \"pixel_values\": pixel_values,\n        \"prompts\": prompts,\n        \"original_sizes\": original_sizes,\n        \"crop_top_lefts\": crop_top_lefts,\n    }\n    return batch\n\n\nclass PromptDataset(Dataset):\n    \"\"\"A simple dataset to prepare the prompts to generate class images on multiple GPUs.\"\"\"\n\n    def __init__(self, prompt, num_samples):\n        self.prompt = prompt\n        self.num_samples = num_samples\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, index):\n        example = {}\n        example[\"prompt\"] = self.prompt\n        example[\"index\"] = index\n        return example\n\n\ndef tokenize_prompt(tokenizer, prompt):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):\n    prompt_embeds_list = []\n\n    for i, 
text_encoder in enumerate(text_encoders):\n        if tokenizers is not None:\n            tokenizer = tokenizers[i]\n            text_input_ids = tokenize_prompt(tokenizer, prompt)\n        else:\n            assert text_input_ids_list is not None\n            text_input_ids = text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False\n        )\n\n        # We only ever use the pooled output of the final text encoder; earlier values are overwritten on each iteration\n        pooled_prompt_embeds = prompt_embeds[0]\n        prompt_embeds = prompt_embeds[-1][-2]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\n# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already\ndef conditional_loss(\n    model_pred: torch.Tensor,\n    target: torch.Tensor,\n    reduction: str = \"mean\",\n    loss_type: str = \"l2\",\n    huber_c: float = 0.1,\n    weighting: Optional[torch.Tensor] = None,\n):\n    if loss_type == \"l2\":\n        if weighting is not None:\n            loss = torch.mean(\n                (weighting * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                1,\n            )\n            if reduction == \"mean\":\n                loss = torch.mean(loss)\n            elif reduction == \"sum\":\n                loss = torch.sum(loss)\n        else:\n            loss = F.mse_loss(model_pred.float(), target.float(), reduction=reduction)\n\n    elif loss_type == \"huber\":\n        if weighting is not None:\n            loss = torch.mean(\n                (\n                    2\n                    * huber_c\n                  
  * (\n                        torch.sqrt(weighting.float() * (model_pred.float() - target.float()) ** 2 + huber_c**2)\n                        - huber_c\n                    )\n                ).reshape(target.shape[0], -1),\n                1,\n            )\n            if reduction == \"mean\":\n                loss = torch.mean(loss)\n            elif reduction == \"sum\":\n                loss = torch.sum(loss)\n        else:\n            loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n            if reduction == \"mean\":\n                loss = torch.mean(loss)\n            elif reduction == \"sum\":\n                loss = torch.sum(loss)\n    elif loss_type == \"smooth_l1\":\n        if weighting is not None:\n            loss = torch.mean(\n                (\n                    2\n                    * (\n                        torch.sqrt(weighting.float() * (model_pred.float() - target.float()) ** 2 + huber_c**2)\n                        - huber_c\n                    )\n                ).reshape(target.shape[0], -1),\n                1,\n            )\n            if reduction == \"mean\":\n                loss = torch.mean(loss)\n            elif reduction == \"sum\":\n                loss = torch.sum(loss)\n        else:\n            loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n            if reduction == \"mean\":\n                loss = torch.mean(loss)\n            elif reduction == \"sum\":\n                loss = torch.sum(loss)\n    else:\n        raise NotImplementedError(f\"Unsupported Loss Type {loss_type}\")\n    return loss\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if 
args.do_edm_style_training and args.snr_gamma is not None:\n        raise ValueError(\"Min-SNR formulation is not supported when conducting EDM-style training.\")\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Generate class images if prior preservation is enabled.\n    if args.with_prior_preservation:\n        class_images_dir = Path(args.class_data_dir)\n        if not class_images_dir.exists():\n            class_images_dir.mkdir(parents=True)\n        cur_class_images = len(list(class_images_dir.iterdir()))\n\n        if cur_class_images < args.num_class_images:\n            torch_dtype = torch.float16 
if accelerator.device.type == \"cuda\" else torch.float32\n            if args.prior_generation_precision == \"fp32\":\n                torch_dtype = torch.float32\n            elif args.prior_generation_precision == \"fp16\":\n                torch_dtype = torch.float16\n            elif args.prior_generation_precision == \"bf16\":\n                torch_dtype = torch.bfloat16\n            pipeline = StableDiffusionXLPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                torch_dtype=torch_dtype,\n                revision=args.revision,\n                variant=args.variant,\n            )\n            pipeline.set_progress_bar_config(disable=True)\n\n            num_new_images = args.num_class_images - cur_class_images\n            logger.info(f\"Number of class images to sample: {num_new_images}.\")\n\n            sample_dataset = PromptDataset(args.class_prompt, num_new_images)\n            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)\n\n            sample_dataloader = accelerator.prepare(sample_dataloader)\n            pipeline.to(accelerator.device)\n\n            for example in tqdm(\n                sample_dataloader, desc=\"Generating class images\", disable=not accelerator.is_local_main_process\n            ):\n                images = pipeline(example[\"prompt\"]).images\n\n                for i, image in enumerate(images):\n                    hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()\n                    image_filename = class_images_dir / f\"{example['index'][i] + cur_class_images}-{hash_image}.jpg\"\n                    image.save(image_filename)\n\n            del pipeline\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n     
   if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)\n    if \"EDM\" in scheduler_type:\n        args.do_edm_style_training = True\n        noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n        logger.info(\"Performing EDM-style training!\")\n    elif args.do_edm_style_training:\n        noise_scheduler = EulerDiscreteScheduler.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n        )\n        logger.info(\"Performing EDM-style training!\")\n    else:\n        noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        
args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    latents_mean = latents_std = None\n    if hasattr(vae.config, \"latents_mean\") and vae.config.latents_mean is not None:\n        latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)\n    if hasattr(vae.config, \"latents_std\") and vae.config.latents_std is not None:\n        latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    # The VAE is always in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n\n    
text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, \"\n                    \"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n            text_encoder_two.gradient_checkpointing_enable()\n\n    # now we will add new LoRA weights to the attention layers\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        use_dora=args.use_dora,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    unet.add_adapter(unet_lora_config)\n\n    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.\n    # So, instead, we monkey-patch the forward calls of its attention-blocks.\n    if args.train_text_encoder:\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            use_dora=args.use_dora,\n            lora_alpha=args.rank,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        
text_encoder_one.add_adapter(text_lora_config)\n        text_encoder_two.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here. Either there are just the unet attn processor layers\n            # or there are the unet and text encoder attn layers\n            unet_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            text_encoder_two_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionXLPipeline.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n          
      text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n        text_encoder_one_ = None\n        text_encoder_two_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(unet))):\n                unet_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                text_encoder_one_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                text_encoder_two_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        if args.train_text_encoder:\n            # Do we need to call `scale_lora_layers()` here?\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n            _set_state_dict_into_text_encoder(\n                lora_state_dict, prefix=\"text_encoder_2.\", text_encoder=text_encoder_two_\n            )\n\n        # Make sure the trainable params are in float32. 
This is again needed since the base models\n        # are in `weight_dtype`. More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [unet_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_, text_encoder_two_])\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [unet]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one, text_encoder_two])\n\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))\n\n    if args.train_text_encoder:\n        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))\n\n    # Optimization parameters\n    unet_lora_parameters_with_lr = {\"params\": unet_lora_parameters, \"lr\": args.learning_rate}\n    if args.train_text_encoder:\n        # different learning rate for text encoder and unet\n        text_lora_parameters_one_with_lr = {\n            \"params\": 
text_lora_parameters_one,\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        text_lora_parameters_two_with_lr = {\n            \"params\": text_lora_parameters_two,\n            \"weight_decay\": args.adam_weight_decay_text_encoder,\n            \"lr\": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,\n        }\n        params_to_optimize = [\n            unet_lora_parameters_with_lr,\n            text_lora_parameters_one_with_lr,\n            text_lora_parameters_two_with_lr,\n        ]\n    else:\n        params_to_optimize = [unet_lora_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. \"\n            \"Defaulting to adamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        if args.train_text_encoder and args.text_encoder_lr:\n            logger.warning(\n                f\"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:\"\n                f\" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. 
\"\n                f\"When using prodigy only learning_rate is used as the initial learning rate.\"\n            )\n            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be\n            # --learning_rate\n            params_to_optimize[1][\"lr\"] = args.learning_rate\n            params_to_optimize[2][\"lr\"] = args.learning_rate\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        class_prompt=args.class_prompt,\n        class_data_root=args.class_data_dir if args.with_prior_preservation else None,\n        class_num=args.num_class_images,\n        size=args.resolution,\n        repeats=args.repeats,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Computes additional embeddings/ids required by the SDXL UNet.\n    # regular text embeddings (when `train_text_encoder` is not True)\n    # pooled text embeddings\n    # time ids\n\n    def compute_time_ids(original_size, crops_coords_top_left):\n        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        target_size = (args.resolution, args.resolution)\n        add_time_ids = 
list(original_size + crops_coords_top_left + target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n        return add_time_ids\n\n    if not args.train_text_encoder:\n        tokenizers = [tokenizer_one, tokenizer_two]\n        text_encoders = [text_encoder_one, text_encoder_two]\n\n        def compute_text_embeddings(prompt, text_encoders, tokenizers):\n            with torch.no_grad():\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)\n                prompt_embeds = prompt_embeds.to(accelerator.device)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n            return prompt_embeds, pooled_prompt_embeds\n\n    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT\n    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid\n    # the redundant encoding.\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(\n            args.instance_prompt, text_encoders, tokenizers\n        )\n\n    # Handle class prompt for prior-preservation.\n    if args.with_prior_preservation:\n        if not args.train_text_encoder:\n            class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(\n                args.class_prompt, text_encoders, tokenizers\n            )\n\n    # Clear the memory here\n    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:\n        del tokenizers, text_encoders\n        gc.collect()\n        torch.cuda.empty_cache()\n\n    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),\n    # pack the statically computed variables appropriately here. 
This is so that we don't\n    # have to pass them to the dataloader.\n\n    if not train_dataset.custom_instance_prompts:\n        if not args.train_text_encoder:\n            prompt_embeds = instance_prompt_hidden_states\n            unet_add_text_embeds = instance_pooled_prompt_embeds\n            if args.with_prior_preservation:\n                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)\n                unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)\n        # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the\n        # batch prompts on all training steps\n        else:\n            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)\n            tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)\n            if args.with_prior_preservation:\n                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)\n                class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)\n                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)\n                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare 
everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = (\n            \"dreambooth-lora-sd-xl\"\n            if \"playground\" not in args.pretrained_model_name_or_path\n            else \"dreambooth-lora-playground\"\n        )\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n\n        
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n            text_encoder_two.train()\n\n            # set requires_grad = True on the top-level embeddings so that gradient checkpointing works\n            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)\n            accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)\n\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n                prompts = batch[\"prompts\"]\n\n                # encode batch prompts when custom prompts are provided for each image\n                if train_dataset.custom_instance_prompts:\n                    if not args.train_text_encoder:\n                        prompt_embeds, unet_add_text_embeds = compute_text_embeddings(\n                            prompts, text_encoders, tokenizers\n                        )\n                    else:\n                        tokens_one = tokenize_prompt(tokenizer_one, prompts)\n                        tokens_two = tokenize_prompt(tokenizer_two, prompts)\n\n                # Convert images to latent space\n                model_input = vae.encode(pixel_values).latent_dist.sample()\n\n                if latents_mean is None and latents_std is None:\n                    model_input = model_input * vae.config.scaling_factor\n                    if args.pretrained_vae_model_name_or_path is None:\n                        model_input = model_input.to(weight_dtype)\n                else:\n                    latents_mean = 
latents_mean.to(device=model_input.device, dtype=model_input.dtype)\n                    latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)\n                    model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std\n                    model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                if not args.do_edm_style_training:\n                    if args.loss_type == \"huber\" or args.loss_type == \"smooth_l1\":\n                        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=\"cpu\")\n                        timestep = timesteps.item()\n\n                        if args.huber_schedule == \"exponential\":\n                            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps\n                            huber_c = math.exp(-alpha * timestep)\n                        elif args.huber_schedule == \"snr\":\n                            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]\n                            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5\n                            huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c\n                        elif args.huber_schedule == \"constant\":\n                            huber_c = args.huber_c\n                        else:\n                            raise NotImplementedError(f\"Unknown Huber loss schedule {args.huber_schedule}!\")\n\n                        timesteps = timesteps.repeat(bsz).to(model_input.device)\n                    elif args.loss_type == \"l2\":\n                        timesteps = torch.randint(\n                            0, noise_scheduler.config.num_train_timesteps, (bsz,), 
device=model_input.device\n                        )\n                        huber_c = 1  # may be anything, as it's not used\n                    else:\n                        raise NotImplementedError(f\"Unknown loss type {args.loss_type}\")\n                    timesteps = timesteps.long()\n                else:\n                    if \"huber\" in args.loss_type or \"l1\" in args.loss_type:\n                        raise NotImplementedError(\"Huber loss is not implemented for EDM training yet!\")\n                    # in EDM formulation, the model is conditioned on the pre-conditioned noise levels\n                    # instead of discrete timesteps, so here we sample indices to get the noise levels\n                    # from `scheduler.timesteps`\n                    indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))\n                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n                # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.\n                # We then precondition the final model inputs based on these sigmas instead of the timesteps.\n                # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                if args.do_edm_style_training:\n                    sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)\n                    if \"EDM\" in scheduler_type:\n                        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)\n                    else:\n                        inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)\n\n                # time ids\n                
add_time_ids = torch.cat(\n                    [\n                        compute_time_ids(original_size=s, crops_coords_top_left=c)\n                        for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])\n                    ]\n                )\n\n                # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.\n                if not train_dataset.custom_instance_prompts:\n                    elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz\n                else:\n                    elems_to_repeat_text_embeds = 1\n\n                # Predict the noise residual\n                if not args.train_text_encoder:\n                    unet_added_conditions = {\n                        \"time_ids\": add_time_ids,\n                        \"text_embeds\": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),\n                    }\n                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)\n                    model_pred = unet(\n                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,\n                        timesteps,\n                        prompt_embeds_input,\n                        added_cond_kwargs=unet_added_conditions,\n                        return_dict=False,\n                    )[0]\n                else:\n                    unet_added_conditions = {\"time_ids\": add_time_ids}\n                    prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                        text_encoders=[text_encoder_one, text_encoder_two],\n                        tokenizers=None,\n                        prompt=None,\n                        text_input_ids_list=[tokens_one, tokens_two],\n                    )\n                    unet_added_conditions.update(\n                        {\"text_embeds\": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}\n                 
   )\n                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)\n                    model_pred = unet(\n                        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,\n                        timesteps,\n                        prompt_embeds_input,\n                        added_cond_kwargs=unet_added_conditions,\n                        return_dict=False,\n                    )[0]\n\n                weighting = None\n                if args.do_edm_style_training:\n                    # Similar to the input preconditioning, the model predictions are also preconditioned\n                    # on noised model inputs (before preconditioning) and the sigmas.\n                    # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                    if \"EDM\" in scheduler_type:\n                        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)\n                    else:\n                        if noise_scheduler.config.prediction_type == \"epsilon\":\n                            model_pred = model_pred * (-sigmas) + noisy_model_input\n                        elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                            model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (\n                                noisy_model_input / (sigmas**2 + 1)\n                            )\n                    # We are not doing weighting here because it tends result in numerical problems.\n                    # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051\n                    # There might be other alternatives for weighting as well:\n                    # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686\n                    if \"EDM\" not in scheduler_type:\n                        weighting = (sigmas**-2.0).float()\n\n                # Get the target 
for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = model_input if args.do_edm_style_training else noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = (\n                        model_input\n                        if args.do_edm_style_training\n                        else noise_scheduler.get_velocity(model_input, noise, timesteps)\n                    )\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.with_prior_preservation:\n                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.\n                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)\n                    target, target_prior = torch.chunk(target, 2, dim=0)\n\n                    # Compute prior loss\n                    prior_loss = conditional_loss(\n                        model_pred_prior,\n                        target_prior,\n                        reduction=\"mean\",\n                        loss_type=args.loss_type,\n                        huber_c=huber_c,\n                        weighting=weighting,\n                    )\n\n                if args.snr_gamma is None:\n                    loss = conditional_loss(\n                        model_pred,\n                        target,\n                        reduction=\"mean\",\n                        loss_type=args.loss_type,\n                        huber_c=huber_c,\n                        weighting=weighting,\n                    )\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # 
This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    base_weight = (\n                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr\n                    )\n\n                    if noise_scheduler.config.prediction_type == \"v_prediction\":\n                        # Velocity objective needs to be floored to an SNR weight of one.\n                        mse_loss_weights = base_weight + 1\n                    else:\n                        # Epsilon and sample both use the same loss weights.\n                        mse_loss_weights = base_weight\n\n                    loss = conditional_loss(\n                        model_pred, target, reduction=\"none\", loss_type=args.loss_type, huber_c=huber_c, weighting=None\n                    )\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                if args.with_prior_preservation:\n                    # Add the prior loss to the instance loss.\n                    loss = loss + args.prior_loss_weight * prior_loss\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)\n                        if args.train_text_encoder\n                        else unet_lora_parameters\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n   
             if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n     
       if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if not args.train_text_encoder:\n                    text_encoder_one = text_encoder_cls_one.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"text_encoder\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                    text_encoder_two = text_encoder_cls_two.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"text_encoder_2\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                pipeline = StableDiffusionXLPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=accelerator.unwrap_model(text_encoder_one),\n                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),\n                    unet=accelerator.unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n\n                images = log_validation(\n                    pipeline,\n                    args,\n                    accelerator,\n                    pipeline_args,\n                    epoch,\n                )\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        unet = unet.to(torch.float32)\n        unet_lora_layers = 
convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_lora_layers = convert_state_dict_to_diffusers(\n                get_peft_model_state_dict(text_encoder_one.to(torch.float32))\n            )\n            text_encoder_two = unwrap_model(text_encoder_two)\n            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(\n                get_peft_model_state_dict(text_encoder_two.to(torch.float32))\n            )\n        else:\n            text_encoder_lora_layers = None\n            text_encoder_2_lora_layers = None\n\n        StableDiffusionXLPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_layers,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n            text_encoder_2_lora_layers=text_encoder_2_lora_layers,\n        )\n        if args.output_kohya_format:\n            lora_state_dict = load_file(f\"{args.output_dir}/pytorch_lora_weights.safetensors\")\n            peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)\n            kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)\n            save_file(kohya_state_dict, f\"{args.output_dir}/pytorch_lora_weights_kohya.safetensors\")\n\n        # Final inference\n        # Load previous pipeline\n        vae = AutoencoderKL.from_pretrained(\n            vae_path,\n            subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipeline = StableDiffusionXLPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=vae,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n\n        # load attention 
processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt, \"num_inference_steps\": 25}\n            images = log_validation(\n                pipeline,\n                args,\n                accelerator,\n                pipeline_args,\n                epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                use_dora=args.use_dora,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                train_text_encoder=args.train_text_encoder,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n                vae_path=args.pretrained_vae_model_name_or_path,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom transformers.utils import ContextManagers\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    
import wandb\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.28.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef save_model_card(\n    args,\n    repo_id: str,\n    images: list = None,\n    repo_folder: str = None,\n):\n    img_str = \"\"\n    if len(images) > 0:\n        image_grid = make_image_grid(images, 1, len(args.validation_prompts))\n        image_grid.save(os.path.join(repo_folder, \"val_imgs_grid.png\"))\n        img_str += \"![val_imgs_grid](./val_imgs_grid.png)\\n\"\n\n    model_description = f\"\"\"\n# Text-to-image finetuning - {repo_id}\n\nThis pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \\n\n{img_str}\n\n## Pipeline usage\n\nYou can use the pipeline like so:\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"{repo_id}\", torch_dtype=torch.float16)\nprompt = \"{args.validation_prompts[0]}\"\nimage = pipeline(prompt).images[0]\nimage.save(\"my_image.png\")\n```\n\n## Training info\n\nThese are the key hyperparameters used during training:\n\n* Epochs: {args.num_train_epochs}\n* Learning rate: {args.learning_rate}\n* Batch size: {args.train_batch_size}\n* Gradient accumulation steps: {args.gradient_accumulation_steps}\n* Image resolution: {args.resolution}\n* Mixed-precision: {args.mixed_precision}\n\n\"\"\"\n    wandb_info = \"\"\n    # Initialize before the availability check so the name is always bound.\n    wandb_run_url = None\n    if is_wandb_available():\n        if wandb.run is not None:\n            wandb_run_url = wandb.run.url\n\n    if wandb_run_url is not None:\n        wandb_info = f\"\"\"\nMore information on all the CLI arguments and the environment is available on 
your [`wandb` run page]({wandb_run_url}).\n\"\"\"\n\n    model_description += wandb_info\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=args.pretrained_model_name_or_path,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\"stable-diffusion\", \"stable-diffusion-diffusers\", \"text-to-image\", \"diffusers\", \"diffusers-training\"]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):\n    logger.info(\"Running validation... \")\n\n    pipeline = StableDiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=accelerator.unwrap_model(vae),\n        text_encoder=accelerator.unwrap_model(text_encoder),\n        tokenizer=tokenizer,\n        unet=accelerator.unwrap_model(unet),\n        safety_checker=None,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    images = []\n    for i in range(len(args.validation_prompts)):\n        with torch.autocast(\"cuda\"):\n            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n           
 tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompts[i]}\")\n                        for i, image in enumerate(images)\n                    ]\n                }\n            )\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--input_perturbation\", type=float, default=0, help=\"The scale of input perturbation. Recommended 0.1.\"\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompts\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\"A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n       
 type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--non_ema_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or\"\n            \" remote repository specified with --pretrained_model_name_or_path.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\", \"smooth_l1\"],\n        help=\"The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.\",\n    )\n    parser.add_argument(\n        \"--huber_schedule\",\n        type=str,\n        default=\"snr\",\n        choices=[\"constant\", \"exponential\", \"snr\"],\n        help=\"The schedule to use for the huber losses parameter\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.1,\n        help=\"The huber loss parameter. 
Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=5,\n        help=\"Run validation every X epochs.\",\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    # default to using the same revision for the non-ema model if not specified\n    if args.non_ema_revision is None:\n        args.non_ema_revision = args.revision\n\n    return args\n\n\n# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already\ndef conditional_loss(\n    model_pred: torch.Tensor,\n    target: torch.Tensor,\n    reduction: str = \"mean\",\n    loss_type: str = \"l2\",\n    huber_c: float = 0.1,\n):\n    if loss_type == \"l2\":\n        loss = F.mse_loss(model_pred, target, reduction=reduction)\n    elif loss_type == \"huber\":\n        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    elif loss_type == \"smooth_l1\":\n        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == 
\"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    else:\n        raise NotImplementedError(f\"Unsupported Loss Type {loss_type}\")\n    return loss\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if args.non_ema_revision is not None:\n        deprecate(\n            \"non_ema_revision!=None\",\n            \"0.15.0\",\n            message=(\n                \"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to\"\n                \" use `--variant=non_ema` instead.\"\n            ),\n        )\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n     
   diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        returns either a context list that includes one that will disable zero.Init or an empty context list\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n        if deepspeed_plugin is None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.\n    # For this to work properly all models must be run through `accelerate.prepare`. 
But accelerate\n    # will try to assign the same optimizer with the same weights to all models during\n    # `deepspeed.initialize`, which of course doesn't work.\n    #\n    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2\n    # frozen models from being partitioned during `zero.Init` which gets called during\n    # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding\n    # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        text_encoder = CLIPTextModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n        )\n        vae = AutoencoderKL.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n        )\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.non_ema_revision\n    )\n\n    # Freeze vae and text_encoder and set unet to trainable\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    unet.train()\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = UNet2DConditionModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n        )\n        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some 
GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for _ in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        
accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not 
None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n         
           f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    # Preprocessing the datasets.\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"input_ids\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        
batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n        args.mixed_precision = accelerator.mixed_precision\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n        args.mixed_precision = accelerator.mixed_precision\n\n    # Move text_encode and vae to gpu and cast to weight_dtype\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        tracker_config.pop(\"validation_prompts\")\n        accelerator.init_trackers(args.tracker_project_name, tracker_config)\n\n    # Function for unwrapping if model was compiled with `torch.compile`.\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = 
vae.encode(batch[\"pixel_values\"].to(weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device\n                    )\n                if args.input_perturbation:\n                    new_noise = noise + args.input_perturbation * torch.randn_like(noise)\n                bsz = latents.shape[0]\n\n                # Sample a random timestep for each image\n                if args.loss_type == \"huber\" or args.loss_type == \"smooth_l1\":\n                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=\"cpu\")\n                    timestep = timesteps.item()\n\n                    if args.huber_schedule == \"exponential\":\n                        alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps\n                        huber_c = math.exp(-alpha * timestep)\n                    elif args.huber_schedule == \"snr\":\n                        alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]\n                        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5\n                        huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c\n                    elif args.huber_schedule == \"constant\":\n                        huber_c = args.huber_c\n                    else:\n                        raise NotImplementedError(f\"Unknown Huber loss schedule {args.huber_schedule}!\")\n\n                    timesteps = timesteps.repeat(bsz).to(latents.device)\n                elif args.loss_type == \"l2\":\n                    timesteps = 
torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device\n                    )\n                    huber_c = 1  # may be anything, as it's not used\n                else:\n                    raise NotImplementedError(f\"Unknown loss type {args.loss_type}\")\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                if args.input_perturbation:\n                    noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)\n                else:\n                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"], return_dict=False)[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Predict the noise residual and compute loss\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]\n\n                if args.snr_gamma is None:\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"mean\", 
loss_type=args.loss_type, huber_c=huber_c\n                    )\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"none\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if 
accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                  
      accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n                log_validation(\n                    vae,\n                    text_encoder,\n                    tokenizer,\n                    unet,\n                    args,\n                    accelerator,\n                    weight_dtype,\n                    global_step,\n                )\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_unet.restore(unet.parameters())\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        pipeline = StableDiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            text_encoder=text_encoder,\n            vae=vae,\n            unet=unet,\n            revision=args.revision,\n            variant=args.variant,\n        )\n        pipeline.save_pretrained(args.output_dir)\n\n        # Run a final round of inference.\n        images = []\n        if args.validation_prompts is not None:\n            logger.info(\"Running inference for collecting generated images...\")\n            pipeline = 
pipeline.to(accelerator.device)\n            pipeline.torch_dtype = weight_dtype\n            pipeline.set_progress_bar_config(disable=True)\n\n            if args.enable_xformers_memory_efficient_attention:\n                pipeline.enable_xformers_memory_efficient_attention()\n\n            if args.seed is None:\n                generator = None\n            else:\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n            for i in range(len(args.validation_prompts)):\n                with torch.autocast(\"cuda\"):\n                    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n                images.append(image)\n\n        if args.push_to_hub:\n            save_model_card(args, repo_id, images, repo_folder=args.output_dir)\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion for text2image with support for LoRA.\"\"\"\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import cast_training_params, compute_snr\nfrom diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import 
is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.28.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    base_model: str = None,\n    dataset_name: str = None,\n    repo_folder: str = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# LoRA text2image fine-tuning - {repo_id}\nThese are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \\n\n{img_str}\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"diffusers-training\",\n        \"lora\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        
help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\", type=str, default=None, help=\"A prompt that is sampled during training for inference.\"\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        
default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\", \"smooth_l1\"],\n        help=\"The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.\",\n    )\n    parser.add_argument(\n        \"--huber_schedule\",\n        type=str,\n        default=\"snr\",\n        choices=[\"constant\", \"exponential\", \"snr\"],\n        help=\"The schedule to use for the huber losses parameter\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.1,\n        help=\"The huber loss parameter. 
Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.\",\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\n# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already\ndef conditional_loss(\n    model_pred: torch.Tensor,\n    target: torch.Tensor,\n    reduction: str = \"mean\",\n    loss_type: str = \"l2\",\n    huber_c: float = 0.1,\n):\n    if loss_type == \"l2\":\n        loss = F.mse_loss(model_pred, target, reduction=reduction)\n    elif loss_type == \"huber\":\n        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    elif loss_type == \"smooth_l1\":\n        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    else:\n        raise NotImplementedError(f\"Unsupported Loss Type {loss_type}\")\n    return loss\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of 
exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = 
DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n    # freeze parameters of models to save more memory\n    unet.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Freeze the unet parameters before adding adapters\n    for param in unet.parameters():\n        param.requires_grad_(False)\n\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Add adapter and make sure the trainable params are in float32.\n    
unet.add_adapter(unet_lora_config)\n    if args.mixed_precision == \"fp16\":\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(unet, dtype=torch.float32)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        lora_layers,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    # Preprocessing the datasets.\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n 
           transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"input_ids\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = 
vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device\n                    )\n\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                if args.loss_type == \"huber\" or args.loss_type == \"smooth_l1\":\n                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=\"cpu\")\n                    timestep = timesteps.item()\n\n                    if args.huber_schedule == \"exponential\":\n                        alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps\n                        huber_c = math.exp(-alpha * timestep)\n                    elif args.huber_schedule == \"snr\":\n                        alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]\n                        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5\n                        huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c\n                    elif args.huber_schedule == \"constant\":\n                        huber_c = args.huber_c\n                    else:\n                        raise NotImplementedError(f\"Unknown Huber loss schedule {args.huber_schedule}!\")\n\n                    timesteps = timesteps.repeat(bsz).to(latents.device)\n                elif args.loss_type == \"l2\":\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device\n                    )\n     
               huber_c = 1  # may be anything, as it's not used\n                else:\n                    raise NotImplementedError(f\"Unknown loss type {args.loss_type}\")\n\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"], return_dict=False)[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Predict the noise residual and compute loss\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]\n\n                if args.snr_gamma is None:\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"mean\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                  
  # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"none\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = lora_layers\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if 
accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n\n                        unwrapped_unet = unwrap_model(unet)\n                        unet_lora_state_dict = convert_state_dict_to_diffusers(\n                            get_peft_model_state_dict(unwrapped_unet)\n                        )\n\n                        StableDiffusionPipeline.save_lora_weights(\n                            
save_directory=save_path,\n                            unet_lora_layers=unet_lora_state_dict,\n                            safe_serialization=True,\n                        )\n\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                logger.info(\n                    f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                # create pipeline\n                pipeline = DiffusionPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    unet=unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline = pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = torch.Generator(device=accelerator.device)\n                if args.seed is not None:\n                    generator = generator.manual_seed(args.seed)\n                images = []\n                with torch.cuda.amp.autocast():\n                    for _ in range(args.num_validation_images):\n                        images.append(\n                            pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]\n                        )\n\n                for tracker in accelerator.trackers:\n                    if tracker.name == \"tensorboard\":\n                        np_images = 
np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                \"validation\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                                ]\n                            }\n                        )\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unet.to(torch.float32)\n\n        unwrapped_unet = unwrap_model(unet)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))\n        StableDiffusionPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            safe_serialization=True,\n        )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n        # Final inference\n        # Load previous pipeline\n        if args.validation_prompt is not None:\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                revision=args.revision,\n                variant=args.variant,\n         
       torch_dtype=weight_dtype,\n            )\n            pipeline = pipeline.to(accelerator.device)\n\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            generator = torch.Generator(device=accelerator.device)\n            if args.seed is not None:\n                generator = generator.manual_seed(args.seed)\n            images = []\n            with torch.cuda.amp.autocast():\n                for _ in range(args.num_validation_images):\n                    images.append(\n                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]\n                    )\n\n            for tracker in accelerator.trackers:\n                if len(images) != 0:\n                    if tracker.name == \"tensorboard\":\n                        np_images = np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                \"test\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                                ]\n                            }\n                        )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA.\"\"\"\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionXLPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr\nfrom diffusers.utils import 
(\n    check_min_version,\n    convert_state_dict_to_diffusers,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.28.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    base_model: str = None,\n    dataset_name: str = None,\n    train_text_encoder: bool = False,\n    repo_folder: str = None,\n    vae_path: str = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# LoRA text2image fine-tuning - {repo_id}\n\nThese are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. 
\\n\n{img_str}\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\nSpecial VAE used for training: {vae_path}.\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"diffusers-training\",\n        \"lora\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical 
stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n       
 \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\", \"smooth_l1\"],\n        help=\"The type of loss to use and whether it's timestep-scheduled. 
See Issue #7488 for more info.\",\n    )\n    parser.add_argument(\n        \"--huber_schedule\",\n        type=str,\n        default=\"snr\",\n        choices=[\"constant\", \"exponential\", \"snr\"],\n        help=\"The schedule to use for the huber losses parameter\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.1,\n        help=\"The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.\",\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--debug_loss\",\n        action=\"store_true\",\n        help=\"Log the loss for each image, if filenames are available in the dataset.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef tokenize_prompt(tokenizer, prompt):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):\n    prompt_embeds_list = []\n\n    for i, text_encoder in enumerate(text_encoders):\n        if 
tokenizers is not None:\n            tokenizer = tokenizers[i]\n            text_input_ids = tokenize_prompt(tokenizer, prompt)\n        else:\n            assert text_input_ids_list is not None\n            text_input_ids = text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False\n        )\n\n        # We are always interested only in the pooled output of the final text encoder\n        pooled_prompt_embeds = prompt_embeds[0]\n        prompt_embeds = prompt_embeds[-1][-2]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\n# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already\ndef conditional_loss(\n    model_pred: torch.Tensor,\n    target: torch.Tensor,\n    reduction: str = \"mean\",\n    loss_type: str = \"l2\",\n    huber_c: float = 0.1,\n):\n    if loss_type == \"l2\":\n        loss = F.mse_loss(model_pred, target, reduction=reduction)\n    elif loss_type == \"huber\" or loss_type == \"huber_scheduled\":\n        loss = huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    else:\n        raise NotImplementedError(f\"Unsupported Loss Type {loss_type}\")\n    return loss\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to 
authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the 
tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    
text_encoder_two.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    if args.pretrained_vae_model_name_or_path is None:\n        vae.to(accelerator.device, dtype=torch.float32)\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    # now we will add new LoRA weights to the attention layers\n    # Set correct lora layers\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n\n    unet.add_adapter(unet_lora_config)\n\n    # The text encoder comes from 🤗 transformers, we will also attach adapters to it.\n    if args.train_text_encoder:\n        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.rank,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder_one.add_adapter(text_lora_config)\n        text_encoder_two.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # there are only two options here. 
Either there are just the unet attn processor layers\n            # or there are the unet and text encoder attn layers\n            unet_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            text_encoder_two_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(unwrap_model(model), type(unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):\n                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):\n                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                if weights:\n                    weights.pop()\n\n            StableDiffusionXLPipeline.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n        text_encoder_one_ = None\n        text_encoder_two_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(unet))):\n                unet_ = model\n            elif isinstance(model, 
type(unwrap_model(text_encoder_one))):\n                text_encoder_one_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                text_encoder_two_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        if args.train_text_encoder:\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n            _set_state_dict_into_text_encoder(\n                lora_state_dict, prefix=\"text_encoder_2.\", text_encoder=text_encoder_two_\n            )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [unet_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_, text_encoder_two_])\n            cast_training_params(models, dtype=torch.float32)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n            text_encoder_two.gradient_checkpointing_enable()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [unet]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one, text_encoder_two])\n        cast_training_params(models, dtype=torch.float32)\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))\n 
   if args.train_text_encoder:\n        params_to_optimize = (\n            params_to_optimize\n            + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n            + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))\n        )\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        tokens_one = tokenize_prompt(tokenizer_one, captions)\n        tokens_two = tokenize_prompt(tokenizer_two, captions)\n        return tokens_one, tokens_two\n\n    # Preprocessing the datasets.\n    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)\n    train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else 
transforms.RandomCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    train_transforms = transforms.Compose(\n        [\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        # image aug\n        original_sizes = []\n        all_images = []\n        crop_top_lefts = []\n        for image in images:\n            original_sizes.append((image.height, image.width))\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            crop_top_left = (y1, x1)\n            crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            all_images.append(image)\n\n        examples[\"original_sizes\"] = original_sizes\n        examples[\"crop_top_lefts\"] = crop_top_lefts\n        examples[\"pixel_values\"] = all_images\n        tokens_one, tokens_two = tokenize_captions(examples)\n        examples[\"input_ids_one\"] = tokens_one\n        examples[\"input_ids_two\"] = tokens_two\n        if args.debug_loss:\n            fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]\n            if fnames:\n                examples[\"filenames\"] = fnames\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            
dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train, output_all_columns=True)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        original_sizes = [example[\"original_sizes\"] for example in examples]\n        crop_top_lefts = [example[\"crop_top_lefts\"] for example in examples]\n        input_ids_one = torch.stack([example[\"input_ids_one\"] for example in examples])\n        input_ids_two = torch.stack([example[\"input_ids_two\"] for example in examples])\n        result = {\n            \"pixel_values\": pixel_values,\n            \"input_ids_one\": input_ids_one,\n            \"input_ids_two\": input_ids_two,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n        filenames = [example[\"filenames\"] for example in examples if \"filenames\" in example]\n        if filenames:\n            result[\"filenames\"] = filenames\n        return result\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n            text_encoder_two.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            
with accelerator.accumulate(unet):\n                # Convert images to latent space\n                if args.pretrained_vae_model_name_or_path is not None:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=weight_dtype)\n                else:\n                    pixel_values = batch[\"pixel_values\"]\n\n                model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = model_input * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    model_input = model_input.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device\n                    )\n\n                bsz = model_input.shape[0]\n                # Sample a random timestep for each image\n                if args.loss_type == \"huber\" or args.loss_type == \"smooth_l1\":\n                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=\"cpu\")\n                    timestep = timesteps.item()\n\n                    if args.huber_schedule == \"exponential\":\n                        alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps\n                        huber_c = math.exp(-alpha * timestep)\n                    elif args.huber_schedule == \"snr\":\n                        alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]\n                        sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5\n                        huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c\n                    elif args.huber_schedule == \"constant\":\n   
                     huber_c = args.huber_c\n                    else:\n                        raise NotImplementedError(f\"Unknown Huber loss schedule {args.huber_schedule}!\")\n\n                    timesteps = timesteps.repeat(bsz).to(model_input.device)\n                elif args.loss_type == \"l2\":\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                    )\n                    huber_c = 1  # may be anything, as it's not used\n                else:\n                    raise NotImplementedError(f\"Unknown loss type {args.loss_type}\")\n\n                timesteps = timesteps.long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n\n                # time ids\n                def compute_time_ids(original_size, crops_coords_top_left):\n                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n                    target_size = (args.resolution, args.resolution)\n                    add_time_ids = list(original_size + crops_coords_top_left + target_size)\n                    add_time_ids = torch.tensor([add_time_ids])\n                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n                    return add_time_ids\n\n                add_time_ids = torch.cat(\n                    [compute_time_ids(s, c) for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])]\n                )\n\n                # Predict the noise residual\n                unet_added_conditions = {\"time_ids\": add_time_ids}\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                    text_encoders=[text_encoder_one, text_encoder_two],\n                    
tokenizers=None,\n                    prompt=None,\n                    text_input_ids_list=[batch[\"input_ids_one\"], batch[\"input_ids_two\"]],\n                )\n                unet_added_conditions.update({\"text_embeds\": pooled_prompt_embeds})\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    prompt_embeds,\n                    added_cond_kwargs=unet_added_conditions,\n                    return_dict=False,\n                )[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.snr_gamma is None:\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"mean\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        
dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"none\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n                if args.debug_loss and \"filenames\" in batch:\n                    for fname in batch[\"filenames\"]:\n                        accelerator.log({\"loss_for_\" + fname: loss}, step=global_step)\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save 
would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                
logger.info(\n                    f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                # create pipeline\n                pipeline = StableDiffusionXLPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=unwrap_model(text_encoder_one),\n                    text_encoder_2=unwrap_model(text_encoder_two),\n                    unet=unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                pipeline = pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n                pipeline_args = {\"prompt\": args.validation_prompt}\n\n                with torch.cuda.amp.autocast():\n                    images = [\n                        pipeline(**pipeline_args, generator=generator).images[0]\n                        for _ in range(args.num_validation_images)\n                    ]\n\n                for tracker in accelerator.trackers:\n                    if tracker.name == \"tensorboard\":\n                        np_images = np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                \"validation\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                         
       ]\n                            }\n                        )\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_two = unwrap_model(text_encoder_two)\n\n            text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))\n            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))\n        else:\n            text_encoder_lora_layers = None\n            text_encoder_2_lora_layers = None\n\n        StableDiffusionXLPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n            text_encoder_2_lora_layers=text_encoder_2_lora_layers,\n        )\n\n        del unet\n        del text_encoder_one\n        del text_encoder_two\n        del text_encoder_lora_layers\n        del text_encoder_2_lora_layers\n        torch.cuda.empty_cache()\n\n        # Final inference\n        # Make sure vae.dtype is consistent with the unet.dtype\n        if args.mixed_precision == \"fp16\":\n            vae.to(weight_dtype)\n        # Load previous pipeline\n        pipeline = StableDiffusionXLPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=vae,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipeline = pipeline.to(accelerator.device)\n\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run 
inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n            images = [\n                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n                for _ in range(args.num_validation_images)\n            ]\n\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    tracker.log(\n                        {\n                            \"test\": [\n                                wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n                            ]\n                        }\n                    )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                train_text_encoder=args.train_text_encoder,\n                repo_folder=args.output_dir,\n                vae_path=args.pretrained_vae_model_name_or_path,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion XL for text2image.\"\"\"\n\nimport argparse\nimport functools\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import concatenate_datasets, load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils 
import is_compiled_module\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.28.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    validation_prompt: str = None,\n    base_model: str = None,\n    dataset_name: str = None,\n    repo_folder: str = None,\n    vae_path: str = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# Text-to-image finetuning - {repo_id}\n\nThis pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \\n\n{img_str}\n\nSpecial VAE used for training: {vae_path}.\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = 
text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. 
The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sdxl-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--timestep_bias_strategy\",\n        type=str,\n        default=\"none\",\n        choices=[\"earlier\", \"later\", \"range\", \"none\"],\n        help=(\n            \"The timestep bias strategy, which may help direct the model toward learning low or high frequency details.\"\n            \" Choices: ['earlier', 'later', 'range', 'none'].\"\n            \" The default is 'none', which means no bias is applied, and training proceeds normally.\"\n            \" The value of 'later' will increase the frequency of the model's final training timesteps.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_multiplier\",\n        type=float,\n        default=1.0,\n        help=(\n            \"The multiplier for the bias. 
Defaults to 1.0, which means no bias is applied.\"\n            \" A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_begin\",\n        type=int,\n        default=0,\n        help=(\n            \"When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias.\"\n            \" Defaults to zero, which equates to having no specific bias.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_end\",\n        type=int,\n        default=1000,\n        help=(\n            \"When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias.\"\n            \" Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_portion\",\n        type=float,\n        default=0.25,\n        help=(\n            \"The portion of timesteps to bias. Defaults to 0.25, which means 25% of timesteps will be biased.\"\n            \" A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines\"\n            \" whether the biased portions are in the earlier or later timesteps.\"\n        ),\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"l2\",\n        choices=[\"l2\", \"huber\", \"smooth_l1\"],\n        help=\"The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.\",\n    )\n    parser.add_argument(\n        \"--huber_schedule\",\n        type=str,\n        default=\"snr\",\n        choices=[\"constant\", \"exponential\", \"snr\"],\n        help=\"The schedule to use for the huber losses parameter\",\n    )\n    parser.add_argument(\n        \"--huber_c\",\n        type=float,\n        default=0.1,\n        help=\"The huber loss parameter. 
Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    return args\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):\n    prompt_embeds_list = []\n    prompt_batch = batch[caption_column]\n\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n                
return_dict=False,\n            )\n\n            # We are always only interested in the pooled output of the final text encoder\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds[-1][-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return {\"prompt_embeds\": prompt_embeds.cpu(), \"pooled_prompt_embeds\": pooled_prompt_embeds.cpu()}\n\n\ndef compute_vae_encodings(batch, vae):\n    images = batch.pop(\"pixel_values\")\n    pixel_values = torch.stack(list(images))\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)\n\n    with torch.no_grad():\n        model_input = vae.encode(pixel_values).latent_dist.sample()\n    model_input = model_input * vae.config.scaling_factor\n    return {\"model_input\": model_input.cpu()}\n\n\ndef generate_timestep_weights(args, num_timesteps):\n    weights = torch.ones(num_timesteps)\n\n    # Determine the indices to bias\n    num_to_bias = int(args.timestep_bias_portion * num_timesteps)\n\n    if args.timestep_bias_strategy == \"later\":\n        bias_indices = slice(-num_to_bias, None)\n    elif args.timestep_bias_strategy == \"earlier\":\n        bias_indices = slice(0, num_to_bias)\n    elif args.timestep_bias_strategy == \"range\":\n        # Out of the possible 1000 timesteps, we might want to focus on eg. 
200-500.\n        range_begin = args.timestep_bias_begin\n        range_end = args.timestep_bias_end\n        if range_begin < 0:\n            raise ValueError(\n                \"When using the range strategy for timestep bias, you must provide a beginning timestep greater than or equal to zero.\"\n            )\n        if range_end > num_timesteps:\n            raise ValueError(\n                \"When using the range strategy for timestep bias, you must provide an ending timestep less than or equal to the number of timesteps.\"\n            )\n        bias_indices = slice(range_begin, range_end)\n    else:  # 'none' or any other string\n        return weights\n    if args.timestep_bias_multiplier <= 0:\n        raise ValueError(\n            \"The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps.\"\n            \" If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead.\"\n            \" A timestep bias multiplier less than or equal to 0 is not allowed.\"\n        )\n\n    # Apply the bias\n    weights[bias_indices] *= args.timestep_bias_multiplier\n\n    # Normalize\n    weights /= weights.sum()\n\n    return weights\n\n\n# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already\ndef conditional_loss(\n    model_pred: torch.Tensor,\n    target: torch.Tensor,\n    reduction: str = \"mean\",\n    loss_type: str = \"l2\",\n    huber_c: float = 0.1,\n):\n    if loss_type == \"l2\":\n        loss = F.mse_loss(model_pred, target, reduction=reduction)\n    elif loss_type == \"huber\":\n        loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    elif loss_type == \"smooth_l1\":\n        loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)\n        if 
reduction == \"mean\":\n            loss = torch.mean(loss)\n        elif reduction == \"sum\":\n            loss = torch.sum(loss)\n    else:\n        raise NotImplementedError(f\"Unsupported Loss Type {loss_type}\")\n    return loss\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        
set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    # Check for terminal SNR in combination with SNR Gamma\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n  
      subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # Freeze vae and text encoders.\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    # Set unet as trainable.\n    unet.train()\n\n    # For mixed precision training we cast all non-trainable weights to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = UNet2DConditionModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n        )\n        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. 
If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for _ in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n  
      accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = unet.parameters()\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if 
args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)\n    train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        # 
image aug\n        original_sizes = []\n        all_images = []\n        crop_top_lefts = []\n        for image in images:\n            original_sizes.append((image.height, image.width))\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            crop_top_left = (y1, x1)\n            crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            all_images.append(image)\n\n        examples[\"original_sizes\"] = original_sizes\n        examples[\"crop_top_lefts\"] = crop_top_lefts\n        examples[\"pixel_values\"] = all_images\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    # Let's first compute all the embeddings so that we can free up the text encoders\n    # from memory. 
We will pre-compute the VAE encodings too.\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n    compute_embeddings_fn = functools.partial(\n        encode_prompt,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n        proportion_empty_prompts=args.proportion_empty_prompts,\n        caption_column=caption_column,\n    )\n    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)\n    with accelerator.main_process_first():\n        from datasets.fingerprint import Hasher\n\n        # fingerprint used by the cache for the other processes to load the result\n        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401\n        new_fingerprint = Hasher.hash(args)\n        new_fingerprint_for_vae = Hasher.hash(vae_path)\n        train_dataset_with_embeddings = train_dataset.map(\n            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint\n        )\n        train_dataset_with_vae = train_dataset.map(\n            compute_vae_encodings_fn,\n            batched=True,\n            batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,\n            new_fingerprint=new_fingerprint_for_vae,\n        )\n        precomputed_dataset = concatenate_datasets(\n            [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns([\"image\", \"text\"])], axis=1\n        )\n        precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)\n\n    del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two\n    del text_encoders, tokenizers, vae\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    def collate_fn(examples):\n        model_input = torch.stack([torch.tensor(example[\"model_input\"]) for example in examples])\n        original_sizes = [example[\"original_sizes\"] for example in examples]\n        crop_top_lefts = 
[example[\"crop_top_lefts\"] for example in examples]\n        prompt_embeds = torch.stack([torch.tensor(example[\"prompt_embeds\"]) for example in examples])\n        pooled_prompt_embeds = torch.stack([torch.tensor(example[\"pooled_prompt_embeds\"]) for example in examples])\n\n        return {\n            \"model_input\": model_input,\n            \"prompt_embeds\": prompt_embeds,\n            \"pooled_prompt_embeds\": pooled_prompt_embeds,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        precomputed_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune-sdxl\", config=vars(args))\n\n    # Function for unwrapping if torch.compile() was used in accelerate.\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(precomputed_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Sample noise that we'll add to the latents\n                model_input = 
batch[\"model_input\"].to(accelerator.device)\n                noise = torch.randn_like(model_input)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device\n                    )\n\n                bsz = model_input.shape[0]\n                if args.timestep_bias_strategy == \"none\":\n                    # Sample a random timestep for each image\n                    if args.loss_type == \"huber\" or args.loss_type == \"smooth_l1\":\n                        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=\"cpu\")\n                        timestep = timesteps.item()\n\n                        if args.huber_schedule == \"exponential\":\n                            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps\n                            huber_c = math.exp(-alpha * timestep)\n                        elif args.huber_schedule == \"snr\":\n                            alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]\n                            sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5\n                            huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c\n                        elif args.huber_schedule == \"constant\":\n                            huber_c = args.huber_c\n                        else:\n                            raise NotImplementedError(f\"Unknown Huber loss schedule {args.huber_schedule}!\")\n\n                        timesteps = timesteps.repeat(bsz).to(model_input.device)\n                    elif args.loss_type == \"l2\":\n                        timesteps = torch.randint(\n                            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                        
)\n                        huber_c = 1  # may be anything, as it's not used\n                    else:\n                        raise NotImplementedError(f\"Unknown loss type {args.loss_type}\")\n\n                    timesteps = timesteps.long()\n\n                else:\n                    if \"huber_scheduled\" in args.loss_type:\n                        raise NotImplementedError(\n                            \"Randomly weighted timesteps not implemented yet for scheduled huber loss!\"\n                        )\n                    else:\n                        huber_c = args.huber_c\n                    # Sample a random timestep for each image, potentially biased by the timestep weights.\n                    # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.\n                    weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(\n                        model_input.device\n                    )\n                    timesteps = torch.multinomial(weights, bsz, replacement=True).long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n\n                # time ids\n                def compute_time_ids(original_size, crops_coords_top_left):\n                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n                    target_size = (args.resolution, args.resolution)\n                    add_time_ids = list(original_size + crops_coords_top_left + target_size)\n                    add_time_ids = torch.tensor([add_time_ids])\n                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n                    return add_time_ids\n\n                add_time_ids = torch.cat(\n                    [compute_time_ids(s, c) for s, c in 
zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])]\n                )\n\n                # Predict the noise residual\n                unet_added_conditions = {\"time_ids\": add_time_ids}\n                prompt_embeds = batch[\"prompt_embeds\"].to(accelerator.device)\n                pooled_prompt_embeds = batch[\"pooled_prompt_embeds\"].to(accelerator.device)\n                unet_added_conditions.update({\"text_embeds\": pooled_prompt_embeds})\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    prompt_embeds,\n                    added_cond_kwargs=unet_added_conditions,\n                    return_dict=False,\n                )[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                elif noise_scheduler.config.prediction_type == \"sample\":\n                    # We set the target to latents here, but the model_pred will return the noise sample prediction.\n                    target = model_input\n                    # We will have to subtract the noise residual from the prediction to get the target sample.\n                    model_pred = model_pred - noise\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.snr_gamma is None:\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), 
reduction=\"mean\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = conditional_loss(\n                        model_pred.float(), target.float(), reduction=\"none\", loss_type=args.loss_type, huber_c=huber_c\n                    )\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = unet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an 
optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = 
os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                logger.info(\n                    f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n\n                # create pipeline\n                vae = AutoencoderKL.from_pretrained(\n                    vae_path,\n                    subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n                    revision=args.revision,\n                    variant=args.variant,\n                )\n                pipeline = StableDiffusionXLPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    unet=accelerator.unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                if args.prediction_type is not None:\n                    scheduler_args = {\"prediction_type\": args.prediction_type}\n                    pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n                pipeline 
= pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n                pipeline_args = {\"prompt\": args.validation_prompt}\n\n                with torch.cuda.amp.autocast():\n                    images = [\n                        pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]\n                        for _ in range(args.num_validation_images)\n                    ]\n\n                for tracker in accelerator.trackers:\n                    if tracker.name == \"tensorboard\":\n                        np_images = np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                \"validation\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                                ]\n                            }\n                        )\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        # Serialize pipeline.\n        vae = AutoencoderKL.from_pretrained(\n            vae_path,\n            subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipeline = StableDiffusionXLPipeline.from_pretrained(\n            
args.pretrained_model_name_or_path,\n            unet=unet,\n            vae=vae,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        if args.prediction_type is not None:\n            scheduler_args = {\"prediction_type\": args.prediction_type}\n            pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n        pipeline.save_pretrained(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline = pipeline.to(accelerator.device)\n            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n            with torch.cuda.amp.autocast():\n                images = [\n                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n                    for _ in range(args.num_validation_images)\n                ]\n\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    tracker.log(\n                        {\n                            \"test\": [\n                                wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n                            ]\n                        }\n                    )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id=repo_id,\n                images=images,\n                validation_prompt=args.validation_prompt,\n                base_model=args.pretrained_model_name_or_path,\n                
dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n                vae_path=args.pretrained_vae_model_name_or_path,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/sd3_lora_colab/README.md",
    "content": "# Running Stable Diffusion 3 DreamBooth LoRA training under 16GB\n\nThis is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA training for [Stable Diffusion 3 (SD3)](ttps://huggingface.co/papers/2403.03206) under 16GB GPU VRAM. This means you can successfully try out this project using a [free-tier Colab Notebook](https://colab.research.google.com/github/huggingface/diffusers/blob/main/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb) instance. 🤗\n\n> [!NOTE]\n> SD3 is gated, so you need to make sure you agree to [share your contact info](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) to access the model before using it with Diffusers. Once you have access, you need to log in so your system knows you’re authorized. Use the command below to log in:\n\n```bash\nhf auth login\n```\n\nThis will also allow us to push the trained model parameters to the Hugging Face Hub platform.\n\nFor setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above.\n\n## How\n\nWe make use of several techniques to make this possible:\n\n* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://huggingface.co/papers/2208.07339)) T5 to reduce memory requirements to ~10.5GB.\n* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of:\n  * 8bit Adam for optimization through the `bitsandbytes` library.\n  * Gradient checkpointing and gradient accumulation.\n  * FP16 precision.\n  * Flash attention through `F.scaled_dot_product_attention()`.\n\nComputing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. 
With FP16, we are down to 12GB.\n\n\n## Gotchas\n\nThis project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of:\n\n* Training of text encoders is purposefully disabled.\n* Techniques such as prior preservation are unsupported.\n* Custom instance captions for instance images are unsupported, but this should be relatively easy to integrate.\n\nHopefully, this project gives you a template to extend it further to suit your needs.\n"
  },
  {
    "path": "examples/research_projects/sd3_lora_colab/compute_embeddings.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport glob\nimport hashlib\n\nimport pandas as pd\nimport torch\nfrom transformers import T5EncoderModel\n\nfrom diffusers import StableDiffusion3Pipeline\n\n\nPROMPT = \"a photo of sks dog\"\nMAX_SEQ_LENGTH = 77\nLOCAL_DATA_DIR = \"dog\"\nOUTPUT_PATH = \"sample_embeddings.parquet\"\n\n\ndef bytes_to_giga_bytes(bytes):\n    return bytes / 1024 / 1024 / 1024\n\n\ndef generate_image_hash(image_path):\n    with open(image_path, \"rb\") as f:\n        img_data = f.read()\n    return hashlib.sha256(img_data).hexdigest()\n\n\ndef load_sd3_pipeline():\n    id = \"stabilityai/stable-diffusion-3-medium-diffusers\"\n    text_encoder = T5EncoderModel.from_pretrained(id, subfolder=\"text_encoder_3\", load_in_8bit=True, device_map=\"auto\")\n    pipeline = StableDiffusion3Pipeline.from_pretrained(\n        id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map=\"balanced\"\n    )\n    return pipeline\n\n\n@torch.no_grad()\ndef compute_embeddings(pipeline, prompt, max_sequence_length):\n    (\n        prompt_embeds,\n        negative_prompt_embeds,\n        pooled_prompt_embeds,\n        negative_pooled_prompt_embeds,\n    ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)\n\n    print(\n        
f\"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape=}\"\n    )\n\n    max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())\n    print(f\"Max memory allocated: {max_memory:.3f} GB\")\n    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds\n\n\ndef run(args):\n    pipeline = load_sd3_pipeline()\n    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = compute_embeddings(\n        pipeline, args.prompt, args.max_sequence_length\n    )\n\n    # Assumes that the images within `args.local_data_dir` have a JPEG extension. Change\n    # as needed.\n    image_paths = glob.glob(f\"{args.local_data_dir}/*.jpeg\")\n    data = []\n    for image_path in image_paths:\n        img_hash = generate_image_hash(image_path)\n        data.append(\n            (img_hash, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)\n        )\n\n    # Create a DataFrame\n    embedding_cols = [\n        \"prompt_embeds\",\n        \"negative_prompt_embeds\",\n        \"pooled_prompt_embeds\",\n        \"negative_pooled_prompt_embeds\",\n    ]\n    df = pd.DataFrame(\n        data,\n        columns=[\"image_hash\"] + embedding_cols,\n    )\n\n    # Flatten each embedding tensor into a plain list (for proper storage in parquet)\n    for col in embedding_cols:\n        df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())\n\n    # Save the dataframe to a parquet file\n    df.to_parquet(args.output_path)\n    print(f\"Data successfully serialized to {args.output_path}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--prompt\", type=str, default=PROMPT, help=\"The instance prompt.\")\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=MAX_SEQ_LENGTH,\n        help=\"Maximum sequence 
length to use for computing the embeddings. Longer sequences incur higher computational cost.\",\n    )\n    parser.add_argument(\n        \"--local_data_dir\", type=str, default=LOCAL_DATA_DIR, help=\"Path to the directory containing instance images.\"\n    )\n    parser.add_argument(\"--output_path\", type=str, default=OUTPUT_PATH, help=\"Path to serialize the parquet file.\")\n    args = parser.parse_args()\n\n    run(args)\n"
  },
  {
    "path": "examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"a6xLZDgOajbd\"\n      },\n      \"source\": [\n        \"# Running Stable Diffusion 3 (SD3) DreamBooth LoRA training under 16GB GPU VRAM\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"0jPZpMTwafua\"\n      },\n      \"source\": [\n        \"## Install Dependencies\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"lIYdn1woOS1n\",\n        \"outputId\": \"6d4a6332-d1f5-46e2-ad2b-c9e51b9f279a\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!pip install -q -U git+https://github.com/huggingface/diffusers\\n\",\n        \"!pip install -q -U \\\\\\n\",\n        \"    transformers \\\\\\n\",\n        \"    accelerate \\\\\\n\",\n        \"    wandb \\\\\\n\",\n        \"    bitsandbytes \\\\\\n\",\n        \"    peft\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"5qUNciw6aov2\"\n      },\n      \"source\": [\n        \"As SD3 is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. 
Use the command below to log in:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"Bpk5FleeK1NR\",\n        \"outputId\": \"54d8e774-514e-46fe-b9a7-0185e0bcf211\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!hf auth login\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"tcF7gl4FasJV\"\n      },\n      \"source\": [\n        \"## Clone `diffusers`\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"QgSOJYglJKiM\",\n        \"outputId\": \"be51f30f-8848-4a79-ae91-c4fb89c244ba\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!git clone https://github.com/huggingface/diffusers\\n\",\n        \"%cd diffusers/examples/research_projects/sd3_lora_colab\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"X9dBawr6ayRY\"\n      },\n      \"source\": [\n        \"## Download instance data images\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\",\n          \"height\": 351,\n          \"referenced_widgets\": [\n            \"8720a1f0a3b043dba02b6aab0afb861a\",\n            \"0e70a30146ef4b30b014179bd4bfd131\",\n            \"c39072b8cfff4a11ba283a9ae3155e52\",\n            \"1e834badd9c74b95bda30456a585fc06\",\n            \"d7c7c83b341b4471ad8a0ca1fe76d9ff\",\n            \"5ab639bd765f4824818a53ab84f690a8\",\n            \"cd94205b05d54e4c96c9c475e13abe83\",\n            \"be260274fdb04798af6fce6169646ff2\",\n            \"b9912757b9f9477186c171ecb2551d3a\",\n 
           \"a1f88f8e27894cdfab54cad04871f74e\",\n            \"19026c269dce47d585506f734fa2981a\",\n            \"50237341e55e4da0ba5cdbf652f30115\",\n            \"1d006f25b17e4cd8aaa5f66d58940dc7\",\n            \"4c673aa247ff4d65b82c4c64ca2e72da\",\n            \"92698388b667476ea1daf5cacb2fdb07\",\n            \"f6b3aa0f980e450289ee15cea7cb3ed7\",\n            \"0690a95eb8c3403e90d5b023aaadb22c\",\n            \"4a20ceca22724ab082d013c20c758d31\",\n            \"c80f3825981646a8a3e178192e338962\",\n            \"5673d0ca1f1247dd924874355eadecd4\",\n            \"7774ac850ab2451ea380bf80f3be5a86\",\n            \"22e57b8c83fa48489d6a327f1bbb756b\",\n            \"dd2debcf1c774181bef97efab0f3d6e1\",\n            \"633d7df9a17e4bf6951249aca83a9e96\",\n            \"6469e9991d7b41a0b83a7b443b9eebe5\",\n            \"0b9c72fa39c241ba9dd22dd67c2436fe\",\n            \"99e707cfe1454757aad4014230f6dae8\",\n            \"5a4ec2d031fa438eb4d0492329b28f00\",\n            \"6c0d4d6d84704f88b46a9b5cf94e3836\",\n            \"e1fb8dec23c04d6f8d1217242f8a495c\",\n            \"4b35f9d8d6444d0397a8bafcf3d73e8f\",\n            \"0f3279a4e6a247a7b69ff73bc06acfe0\",\n            \"b5ac4ab9256e4d5092ba6e449bc3cdd3\",\n            \"2840e90c518d4666b3f5a935c90569a7\",\n            \"adb012e95d7d442a820680e61e615e3c\",\n            \"be4fd10d940d49cf8e916904da8192ab\",\n            \"fd93adba791f46c1b0a25ff692426149\",\n            \"cdee3b61ca6a487c8ec8e7e884eb8b07\",\n            \"190a7fbe2b554104a6d5b2caa3b0a08e\",\n            \"975922b877e143edb09cdb888cb7cae8\",\n            \"d7365b62df59406dbd38677299cce1c8\",\n            \"67f0f5f1179140b4bdaa74c5583e3958\",\n            \"e560f25c3e334cf2a4c748981ac38da6\",\n            \"65173381a80b40748b7b2800fdb89151\",\n            \"7a4c5c0acd2d400e91da611e91ff5306\",\n            \"2a02c69a19a741b4a032dedbc21ad088\",\n            \"c6211ddb71e64f9e92c70158da2f7ef1\",\n            \"c219a1e791894469aa1452045b0f74b5\",\n            
\"8e881fd3a17e4a5d95e71f6411ed8167\",\n            \"5350f001bf774b5fb7f3e362f912bec3\",\n            \"a893c93bcbc444a4931d1ddc6e342786\",\n            \"03047a13f06744fcac17c77cb03bca62\",\n            \"4d77a9c44d1c47b18022f8a51b29e20d\",\n            \"0b5bb94394fc447282d2c44780303f15\",\n            \"01bfa49325a8403b808ad1662465996e\",\n            \"3c0f67144f974aea85c7088f482e830d\",\n            \"0996e9f698dc4d6ab3c4319c96186619\",\n            \"9c663933ece544b193531725f4dc873d\",\n            \"b48dc06ca1654fe39bf8a77352a921f2\",\n            \"641f5a361a584cc0b71fd71d5f786958\",\n            \"66a952451b4c43cab54312fe886df5e6\",\n            \"42757924240c4abeb39add0c26687ab3\",\n            \"7f26ae5417cf4c80921cce830e60f72b\",\n            \"b77093ec9ffd40e0b2c1d9d1bdc063f5\",\n            \"7d8e3510c1e34524993849b8fce52758\",\n            \"b15011804460483ab84904a52da754b7\"\n          ]\n        },\n        \"id\": \"La1rBYWFNjEP\",\n        \"outputId\": \"e8567843-193e-4653-86b8-be26390700df\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from huggingface_hub import snapshot_download\\n\",\n        \"\\n\",\n        \"local_dir = \\\"./dog\\\"\\n\",\n        \"snapshot_download(\\n\",\n        \"    \\\"diffusers/dog-example\\\",\\n\",\n        \"    local_dir=local_dir, repo_type=\\\"dataset\\\",\\n\",\n        \"    ignore_patterns=\\\".gitattributes\\\",\\n\",\n        \")\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"hbsIzdjbOzgi\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!rm -rf dog/.cache\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"88sOTn2ga07q\"\n      },\n      \"source\": [\n        \"## Compute embeddings\\n\",\n        \"\\n\",\n        \"Here we are using the default instance prompt \\\"a photo of sks dog\\\". But you can configure this. 
Refer to the `compute_embeddings.py` script for details on other supported arguments.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"ha6hPLpHLM8c\",\n        \"outputId\": \"82843eb0-473e-4d6b-d11d-1f79bcd1d11a\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!python compute_embeddings.py\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"10iMo-RUa_yv\"\n      },\n      \"source\": [\n        \"## Clear memory\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"-YltRmPgMuNa\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import torch\\n\",\n        \"import gc\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"def flush():\\n\",\n        \"    torch.cuda.empty_cache()\\n\",\n        \"    gc.collect()\\n\",\n        \"\\n\",\n        \"flush()\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"UO5oEtOJbBS9\"\n      },\n      \"source\": [\n        \"## Train!\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"base_uri\": \"https://localhost:8080/\"\n        },\n        \"id\": \"HuJ6hdm2M4Aw\",\n        \"outputId\": \"0b2d8ca3-c65f-4bb4-af9a-809b77116510\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"!accelerate launch train_dreambooth_lora_sd3_miniature.py \\\\\\n\",\n        \"  --pretrained_model_name_or_path=\\\"stabilityai/stable-diffusion-3-medium-diffusers\\\"  \\\\\\n\",\n        \"  --instance_data_dir=\\\"dog\\\" \\\\\\n\",\n        \"  --data_df_path=\\\"sample_embeddings.parquet\\\" \\\\\\n\",\n        \"  
--output_dir=\\\"trained-sd3-lora-miniature\\\" \\\\\\n\",\n        \"  --mixed_precision=\\\"fp16\\\" \\\\\\n\",\n        \"  --instance_prompt=\\\"a photo of sks dog\\\" \\\\\\n\",\n        \"  --resolution=1024 \\\\\\n\",\n        \"  --train_batch_size=1 \\\\\\n\",\n        \"  --gradient_accumulation_steps=4 --gradient_checkpointing \\\\\\n\",\n        \"  --use_8bit_adam \\\\\\n\",\n        \"  --learning_rate=1e-4 \\\\\\n\",\n        \"  --report_to=\\\"wandb\\\" \\\\\\n\",\n        \"  --lr_scheduler=\\\"constant\\\" \\\\\\n\",\n        \"  --lr_warmup_steps=0 \\\\\\n\",\n        \"  --max_train_steps=500 \\\\\\n\",\n        \"  --seed=\\\"0\\\"\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"itS-dsJ0gjy3\"\n      },\n      \"source\": [\n        \"Training will take about an hour to complete depending on the size of your dataset.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"BpOuL7S1bI6j\"\n      },\n      \"source\": [\n        \"## Inference\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"clfMv4jKfQzb\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"flush()\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"np03SXHkbKpG\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from diffusers import DiffusionPipeline\\n\",\n        \"import torch\\n\",\n        \"\\n\",\n        \"pipeline = DiffusionPipeline.from_pretrained(\\n\",\n        \"    \\\"stabilityai/stable-diffusion-3-medium-diffusers\\\",\\n\",\n        \"    torch_dtype=torch.float16\\n\",\n        \")\\n\",\n        \"lora_output_path = \\\"trained-sd3-lora-miniature\\\"\\n\",\n        \"pipeline.load_lora_weights(lora_output_path)\\n\",\n        \"\\n\",\n        
\"pipeline.enable_sequential_cpu_offload()\\n\",\n        \"\\n\",\n        \"image = pipeline(\\\"a photo of sks dog in a bucket\\\").images[0]\\n\",\n        \"image.save(\\\"bucket_dog.png\\\")\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"HDfrY2opjGjD\"\n      },\n      \"source\": [\n        \"Note that inference will be very slow in this case because we're loading and unloading individual components of the models and that introduces significant data movement overhead. Refer to [this resource](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more memory optimization related techniques.\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"accelerator\": \"GPU\",\n    \"colab\": {\n      \"gpuType\": \"T4\",\n      \"provenance\": []\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    },\n    \"widgets\": {\n      \"application/vnd.jupyter.widget-state+json\": {\n        \"01bfa49325a8403b808ad1662465996e\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"03047a13f06744fcac17c77cb03bca62\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n      
      \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"0690a95eb8c3403e90d5b023aaadb22c\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n       
     \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"0996e9f698dc4d6ab3c4319c96186619\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": 
\"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_66a952451b4c43cab54312fe886df5e6\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_42757924240c4abeb39add0c26687ab3\",\n            \"value\": \"alvan-nee-bQaAJCbNq3g-unsplash.jpeg: 100%\"\n          }\n        },\n        \"0b5bb94394fc447282d2c44780303f15\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            
\"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"0b9c72fa39c241ba9dd22dd67c2436fe\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_0f3279a4e6a247a7b69ff73bc06acfe0\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_b5ac4ab9256e4d5092ba6e449bc3cdd3\",\n            \"value\": \" 1.16M/1.16M [00:00&lt;00:00, 10.3MB/s]\"\n          }\n        },\n        \"0e70a30146ef4b30b014179bd4bfd131\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            
\"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_5ab639bd765f4824818a53ab84f690a8\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_cd94205b05d54e4c96c9c475e13abe83\",\n            \"value\": \"Fetching 5 files: 100%\"\n          }\n        },\n        \"0f3279a4e6a247a7b69ff73bc06acfe0\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n       
     \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"19026c269dce47d585506f734fa2981a\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"190a7fbe2b554104a6d5b2caa3b0a08e\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n         
   \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"1d006f25b17e4cd8aaa5f66d58940dc7\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_0690a95eb8c3403e90d5b023aaadb22c\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_4a20ceca22724ab082d013c20c758d31\",\n            \"value\": \"alvan-nee-9M0tSjb-cpA-unsplash.jpeg: 100%\"\n          }\n        },\n        \"1e834badd9c74b95bda30456a585fc06\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          
\"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_a1f88f8e27894cdfab54cad04871f74e\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_19026c269dce47d585506f734fa2981a\",\n            \"value\": \" 5/5 [00:01&lt;00:00,  3.44it/s]\"\n          }\n        },\n        \"22e57b8c83fa48489d6a327f1bbb756b\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"2840e90c518d4666b3f5a935c90569a7\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": 
\"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_adb012e95d7d442a820680e61e615e3c\",\n              \"IPY_MODEL_be4fd10d940d49cf8e916904da8192ab\",\n              \"IPY_MODEL_fd93adba791f46c1b0a25ff692426149\"\n            ],\n            \"layout\": \"IPY_MODEL_cdee3b61ca6a487c8ec8e7e884eb8b07\"\n          }\n        },\n        \"2a02c69a19a741b4a032dedbc21ad088\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_5350f001bf774b5fb7f3e362f912bec3\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_a893c93bcbc444a4931d1ddc6e342786\",\n            \"value\": \"alvan-nee-eoqnr8ikwFE-unsplash.jpeg: 100%\"\n          }\n        },\n        \"3c0f67144f974aea85c7088f482e830d\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            
\"children\": [\n              \"IPY_MODEL_0996e9f698dc4d6ab3c4319c96186619\",\n              \"IPY_MODEL_9c663933ece544b193531725f4dc873d\",\n              \"IPY_MODEL_b48dc06ca1654fe39bf8a77352a921f2\"\n            ],\n            \"layout\": \"IPY_MODEL_641f5a361a584cc0b71fd71d5f786958\"\n          }\n        },\n        \"42757924240c4abeb39add0c26687ab3\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"4a20ceca22724ab082d013c20c758d31\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"4b35f9d8d6444d0397a8bafcf3d73e8f\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            
\"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"4c673aa247ff4d65b82c4c64ca2e72da\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_c80f3825981646a8a3e178192e338962\",\n            \"max\": 677407,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_5673d0ca1f1247dd924874355eadecd4\",\n            \"value\": 677407\n          }\n        },\n        \"4d77a9c44d1c47b18022f8a51b29e20d\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": 
null,\n            \"description_width\": \"\"\n          }\n        },\n        \"50237341e55e4da0ba5cdbf652f30115\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_1d006f25b17e4cd8aaa5f66d58940dc7\",\n              \"IPY_MODEL_4c673aa247ff4d65b82c4c64ca2e72da\",\n              \"IPY_MODEL_92698388b667476ea1daf5cacb2fdb07\"\n            ],\n            \"layout\": \"IPY_MODEL_f6b3aa0f980e450289ee15cea7cb3ed7\"\n          }\n        },\n        \"5350f001bf774b5fb7f3e362f912bec3\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n       
     \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"5673d0ca1f1247dd924874355eadecd4\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"5a4ec2d031fa438eb4d0492329b28f00\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            
\"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"5ab639bd765f4824818a53ab84f690a8\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            
\"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"633d7df9a17e4bf6951249aca83a9e96\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": 
\"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_5a4ec2d031fa438eb4d0492329b28f00\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_6c0d4d6d84704f88b46a9b5cf94e3836\",\n            \"value\": \"alvan-nee-Id1DBHv4fbg-unsplash.jpeg: 100%\"\n          }\n        },\n        \"641f5a361a584cc0b71fd71d5f786958\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            
\"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"6469e9991d7b41a0b83a7b443b9eebe5\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_e1fb8dec23c04d6f8d1217242f8a495c\",\n            \"max\": 1163467,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_4b35f9d8d6444d0397a8bafcf3d73e8f\",\n            \"value\": 1163467\n          }\n        },\n        \"65173381a80b40748b7b2800fdb89151\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            
\"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"66a952451b4c43cab54312fe886df5e6\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            
\"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"67f0f5f1179140b4bdaa74c5583e3958\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"6c0d4d6d84704f88b46a9b5cf94e3836\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"7774ac850ab2451ea380bf80f3be5a86\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n        
    \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"7a4c5c0acd2d400e91da611e91ff5306\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            
\"children\": [\n              \"IPY_MODEL_2a02c69a19a741b4a032dedbc21ad088\",\n              \"IPY_MODEL_c6211ddb71e64f9e92c70158da2f7ef1\",\n              \"IPY_MODEL_c219a1e791894469aa1452045b0f74b5\"\n            ],\n            \"layout\": \"IPY_MODEL_8e881fd3a17e4a5d95e71f6411ed8167\"\n          }\n        },\n        \"7d8e3510c1e34524993849b8fce52758\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            
\"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"7f26ae5417cf4c80921cce830e60f72b\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": 
null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"8720a1f0a3b043dba02b6aab0afb861a\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_0e70a30146ef4b30b014179bd4bfd131\",\n              \"IPY_MODEL_c39072b8cfff4a11ba283a9ae3155e52\",\n              \"IPY_MODEL_1e834badd9c74b95bda30456a585fc06\"\n            ],\n            \"layout\": \"IPY_MODEL_d7c7c83b341b4471ad8a0ca1fe76d9ff\"\n          }\n        },\n        \"8e881fd3a17e4a5d95e71f6411ed8167\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n       
     \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"92698388b667476ea1daf5cacb2fdb07\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_7774ac850ab2451ea380bf80f3be5a86\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_22e57b8c83fa48489d6a327f1bbb756b\",\n            \"value\": \" 677k/677k [00:00&lt;00:00, 9.50MB/s]\"\n          }\n        
},\n        \"975922b877e143edb09cdb888cb7cae8\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"99e707cfe1454757aad4014230f6dae8\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n     
       \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"9c663933ece544b193531725f4dc873d\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_7f26ae5417cf4c80921cce830e60f72b\",\n            \"max\": 1396297,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_b77093ec9ffd40e0b2c1d9d1bdc063f5\",\n            \"value\": 1396297\n          }\n        },\n        \"a1f88f8e27894cdfab54cad04871f74e\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": 
\"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"a893c93bcbc444a4931d1ddc6e342786\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n  
          \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"adb012e95d7d442a820680e61e615e3c\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_190a7fbe2b554104a6d5b2caa3b0a08e\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_975922b877e143edb09cdb888cb7cae8\",\n            \"value\": \"alvan-nee-brFsZ7qszSY-unsplash.jpeg: 100%\"\n          }\n        },\n        \"b15011804460483ab84904a52da754b7\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"b48dc06ca1654fe39bf8a77352a921f2\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n  
        \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_7d8e3510c1e34524993849b8fce52758\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_b15011804460483ab84904a52da754b7\",\n            \"value\": \" 1.40M/1.40M [00:00&lt;00:00, 10.5MB/s]\"\n          }\n        },\n        \"b5ac4ab9256e4d5092ba6e449bc3cdd3\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"b77093ec9ffd40e0b2c1d9d1bdc063f5\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": 
\"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"b9912757b9f9477186c171ecb2551d3a\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"ProgressStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"ProgressStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"bar_color\": null,\n            \"description_width\": \"\"\n          }\n        },\n        \"be260274fdb04798af6fce6169646ff2\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            
\"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"be4fd10d940d49cf8e916904da8192ab\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_d7365b62df59406dbd38677299cce1c8\",\n            \"max\": 1186464,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_67f0f5f1179140b4bdaa74c5583e3958\",\n            \"value\": 1186464\n          }\n        },\n        \"c219a1e791894469aa1452045b0f74b5\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          
\"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_0b5bb94394fc447282d2c44780303f15\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_01bfa49325a8403b808ad1662465996e\",\n            \"value\": \" 1.17M/1.17M [00:00&lt;00:00, 15.7MB/s]\"\n          }\n        },\n        \"c39072b8cfff4a11ba283a9ae3155e52\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_be260274fdb04798af6fce6169646ff2\",\n            \"max\": 5,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_b9912757b9f9477186c171ecb2551d3a\",\n            \"value\": 5\n          }\n        },\n        \"c6211ddb71e64f9e92c70158da2f7ef1\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n    
      \"model_name\": \"FloatProgressModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"FloatProgressModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"ProgressView\",\n            \"bar_style\": \"success\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_03047a13f06744fcac17c77cb03bca62\",\n            \"max\": 1167042,\n            \"min\": 0,\n            \"orientation\": \"horizontal\",\n            \"style\": \"IPY_MODEL_4d77a9c44d1c47b18022f8a51b29e20d\",\n            \"value\": 1167042\n          }\n        },\n        \"c80f3825981646a8a3e178192e338962\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            
\"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"cd94205b05d54e4c96c9c475e13abe83\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"DescriptionStyleModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"DescriptionStyleModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"StyleView\",\n            \"description_width\": \"\"\n          }\n        },\n        \"cdee3b61ca6a487c8ec8e7e884eb8b07\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n   
         \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"d7365b62df59406dbd38677299cce1c8\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n    
        \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"d7c7c83b341b4471ad8a0ca1fe76d9ff\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            
\"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"dd2debcf1c774181bef97efab0f3d6e1\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HBoxModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HBoxModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HBoxView\",\n            \"box_style\": \"\",\n            \"children\": [\n              \"IPY_MODEL_633d7df9a17e4bf6951249aca83a9e96\",\n              
\"IPY_MODEL_6469e9991d7b41a0b83a7b443b9eebe5\",\n              \"IPY_MODEL_0b9c72fa39c241ba9dd22dd67c2436fe\"\n            ],\n            \"layout\": \"IPY_MODEL_99e707cfe1454757aad4014230f6dae8\"\n          }\n        },\n        \"e1fb8dec23c04d6f8d1217242f8a495c\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n       
     \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        \"e560f25c3e334cf2a4c748981ac38da6\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": 
null,\n            \"width\": null\n          }\n        },\n        \"f6b3aa0f980e450289ee15cea7cb3ed7\": {\n          \"model_module\": \"@jupyter-widgets/base\",\n          \"model_module_version\": \"1.2.0\",\n          \"model_name\": \"LayoutModel\",\n          \"state\": {\n            \"_model_module\": \"@jupyter-widgets/base\",\n            \"_model_module_version\": \"1.2.0\",\n            \"_model_name\": \"LayoutModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/base\",\n            \"_view_module_version\": \"1.2.0\",\n            \"_view_name\": \"LayoutView\",\n            \"align_content\": null,\n            \"align_items\": null,\n            \"align_self\": null,\n            \"border\": null,\n            \"bottom\": null,\n            \"display\": null,\n            \"flex\": null,\n            \"flex_flow\": null,\n            \"grid_area\": null,\n            \"grid_auto_columns\": null,\n            \"grid_auto_flow\": null,\n            \"grid_auto_rows\": null,\n            \"grid_column\": null,\n            \"grid_gap\": null,\n            \"grid_row\": null,\n            \"grid_template_areas\": null,\n            \"grid_template_columns\": null,\n            \"grid_template_rows\": null,\n            \"height\": null,\n            \"justify_content\": null,\n            \"justify_items\": null,\n            \"left\": null,\n            \"margin\": null,\n            \"max_height\": null,\n            \"max_width\": null,\n            \"min_height\": null,\n            \"min_width\": null,\n            \"object_fit\": null,\n            \"object_position\": null,\n            \"order\": null,\n            \"overflow\": null,\n            \"overflow_x\": null,\n            \"overflow_y\": null,\n            \"padding\": null,\n            \"right\": null,\n            \"top\": null,\n            \"visibility\": null,\n            \"width\": null\n          }\n        },\n        
\"fd93adba791f46c1b0a25ff692426149\": {\n          \"model_module\": \"@jupyter-widgets/controls\",\n          \"model_module_version\": \"1.5.0\",\n          \"model_name\": \"HTMLModel\",\n          \"state\": {\n            \"_dom_classes\": [],\n            \"_model_module\": \"@jupyter-widgets/controls\",\n            \"_model_module_version\": \"1.5.0\",\n            \"_model_name\": \"HTMLModel\",\n            \"_view_count\": null,\n            \"_view_module\": \"@jupyter-widgets/controls\",\n            \"_view_module_version\": \"1.5.0\",\n            \"_view_name\": \"HTMLView\",\n            \"description\": \"\",\n            \"description_tooltip\": null,\n            \"layout\": \"IPY_MODEL_e560f25c3e334cf2a4c748981ac38da6\",\n            \"placeholder\": \"​\",\n            \"style\": \"IPY_MODEL_65173381a80b40748b7b2800fdb89151\",\n            \"value\": \" 1.19M/1.19M [00:00&lt;00:00, 12.7MB/s]\"\n          }\n        }\n      }\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}"
  },
  {
    "path": "examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport copy\nimport gc\nimport hashlib\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom PIL import Image\nfrom PIL.ImageOps import exif_transpose\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    SD3Transformer2DModel,\n    StableDiffusion3Pipeline,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    convert_unet_state_dict_to_peft,\n    
is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.30.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(\n    repo_id: str,\n    images=None,\n    base_model: str = None,\n    train_text_encoder=False,\n    instance_prompt=None,\n    validation_prompt=None,\n    repo_folder=None,\n):\n    widget_dict = []\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            widget_dict.append(\n                {\"text\": validation_prompt if validation_prompt else \" \", \"output\": {\"url\": f\"image_{i}.png\"}}\n            )\n\n    model_description = f\"\"\"\n# SD3 DreamBooth LoRA - {repo_id}\n\n<Gallery />\n\n## Model description\n\nThese are {repo_id} DreamBooth weights for {base_model}.\n\nThe weights were trained using [DreamBooth](https://dreambooth.github.io/).\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\n## Trigger words\n\nYou should use {instance_prompt} to trigger the image generation.\n\n## Download model\n\n[Download]({repo_id}/tree/main) them in the Files & versions tab.\n\n## License\n\nPlease adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"openrail++\",\n        base_model=base_model,\n        prompt=instance_prompt,\n        model_description=model_description,\n        widget=widget_dict,\n    )\n    tags = [\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n        \"lora\",\n        
\"sd3\",\n        \"sd3-diffusers\",\n        \"template:sd-lora\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline.enable_model_cpu_offload()\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n    autocast_ctx = nullcontext()\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        
required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--instance_data_dir\",\n        type=str,\n        default=None,\n        help=(\"A folder containing the training data. \"),\n    )\n    parser.add_argument(\n        \"--data_df_path\",\n        type=str,\n        default=None,\n        help=(\"Path to the parquet file serialized with compute_embeddings.py.\"),\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=77,\n        help=\"Maximum sequence length to use with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=50,\n        help=(\n            \"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd3-dreambooth-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. 
If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"logit_normal\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\"],\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. 
Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n\n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer.\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. 
Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  
Default to  fp16 if a GPU is available else fp32.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.instance_data_dir is None:\n        raise ValueError(\"Specify `instance_data_dir`.\")\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\nclass DreamBoothDataset(Dataset):\n    \"\"\"\n    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.\n    It pre-processes the images.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_df_path,\n        instance_data_root,\n        instance_prompt,\n        size=1024,\n        center_crop=False,\n    ):\n        # Logistics\n        self.size = size\n        self.center_crop = center_crop\n\n        self.instance_prompt = instance_prompt\n        self.instance_data_root = Path(instance_data_root)\n        if not self.instance_data_root.exists():\n            raise ValueError(\"Instance images root doesn't exists.\")\n\n        # Load images.\n        instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]\n        image_hashes = [self.generate_image_hash(path) for path in list(Path(instance_data_root).iterdir())]\n        self.instance_images = instance_images\n        self.image_hashes = image_hashes\n\n        # Image transformations\n        self.pixel_values = self.apply_image_transformations(\n            instance_images=instance_images, size=size, center_crop=center_crop\n        )\n\n        # Map hashes to embeddings.\n        self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)\n\n        self.num_instance_images = len(instance_images)\n        
self._length = self.num_instance_images\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, index):\n        example = {}\n        instance_image = self.pixel_values[index % self.num_instance_images]\n        image_hash = self.image_hashes[index % self.num_instance_images]\n        prompt_embeds, pooled_prompt_embeds = self.data_dict[image_hash]\n        example[\"instance_images\"] = instance_image\n        example[\"prompt_embeds\"] = prompt_embeds\n        example[\"pooled_prompt_embeds\"] = pooled_prompt_embeds\n        return example\n\n    def apply_image_transformations(self, instance_images, size, center_crop):\n        pixel_values = []\n\n        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)\n        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n        train_flip = transforms.RandomHorizontalFlip(p=1.0)\n        train_transforms = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        for image in instance_images:\n            image = exif_transpose(image)\n            if not image.mode == \"RGB\":\n                image = image.convert(\"RGB\")\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            image = train_transforms(image)\n            pixel_values.append(image)\n\n        return pixel_values\n\n    
def convert_to_torch_tensor(self, embeddings: list):\n        prompt_embeds = embeddings[0]\n        pooled_prompt_embeds = embeddings[1]\n        prompt_embeds = np.array(prompt_embeds).reshape(154, 4096)\n        pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(2048)\n        return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds)\n\n    def map_image_hash_embedding(self, data_df_path):\n        hashes_df = pd.read_parquet(data_df_path)\n        data_dict = {}\n        for i, row in hashes_df.iterrows():\n            embeddings = [row[\"prompt_embeds\"], row[\"pooled_prompt_embeds\"]]\n            prompt_embeds, pooled_prompt_embeds = self.convert_to_torch_tensor(embeddings=embeddings)\n            data_dict.update({row[\"image_hash\"]: (prompt_embeds, pooled_prompt_embeds)})\n        return data_dict\n\n    def generate_image_hash(self, image_path):\n        with open(image_path, \"rb\") as f:\n            img_data = f.read()\n        return hashlib.sha256(img_data).hexdigest()\n\n\ndef collate_fn(examples):\n    pixel_values = [example[\"instance_images\"] for example in examples]\n    prompt_embeds = [example[\"prompt_embeds\"] for example in examples]\n    pooled_prompt_embeds = [example[\"pooled_prompt_embeds\"] for example in examples]\n\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    prompt_embeds = torch.stack(prompt_embeds)\n    pooled_prompt_embeds = torch.stack(pooled_prompt_embeds)\n\n    batch = {\n        \"pixel_values\": pixel_values,\n        \"prompt_embeds\": prompt_embeds,\n        \"pooled_prompt_embeds\": pooled_prompt_embeds,\n    }\n    return batch\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth 
login` to authenticate with the Hub.\"\n        )\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if 
args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n            ).repo_id\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    noise_scheduler_copy = copy.deepcopy(noise_scheduler)\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = SD3Transformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    transformer.requires_grad_(False)\n    vae.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=torch.float32)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n\n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n\n    # now we will add new LoRA weights to the attention layers\n    transformer_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n    transformer.add_adapter(transformer_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            transformer_lora_layers_to_save = None\n            for model in models:\n                if isinstance(model, type(unwrap_model(transformer))):\n                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusion3Pipeline.save_lora_weights(\n                output_dir,\n                transformer_lora_layers=transformer_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        transformer_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(unwrap_model(transformer))):\n                transformer_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n      
  lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)\n\n        transformer_state_dict = {\n            f\"{k.replace('transformer.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"transformer.\")\n        }\n        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)\n        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [transformer_]\n            # only upcast trainable parameters (LoRA) into fp32\n            cast_training_params(models)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    # Optimization parameters\n    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n    transformer_parameters_with_lr = {\"params\": transformer_lora_parameters, \"lr\": args.learning_rate}\n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW].\"\n            \"Defaulting to adamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. 
Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = DreamBoothDataset(\n        data_df_path=args.data_df_path,\n        instance_data_root=args.instance_data_dir,\n        instance_prompt=args.instance_prompt,\n        size=args.resolution,\n        center_crop=args.center_crop,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        
power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"dreambooth-sd3-lora-miniature\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the mos recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n    
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n\n        for step, batch in enumerate(train_dataloader):\n            models_to_accumulate = [transformer]\n            with accelerator.accumulate(models_to_accumulate):\n                pixel_values = batch[\"pixel_values\"].to(dtype=vae.dtype)\n\n                # Convert images to latent space\n                model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = model_input * vae.config.scaling_factor\n                model_input = model_input.to(dtype=weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                bsz = model_input.shape[0]\n\n                # Sample a random timestep for each image\n                # for weighting schemes where we sample timesteps non-uniformly\n                u = compute_density_for_timestep_sampling(\n                    weighting_scheme=args.weighting_scheme,\n                    batch_size=bsz,\n                    logit_mean=args.logit_mean,\n                    logit_std=args.logit_std,\n                    mode_scale=args.mode_scale,\n                )\n                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()\n                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)\n\n                # Add noise according to flow matching.\n                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input\n\n                # Predict the noise residual\n                prompt_embeds, pooled_prompt_embeds = 
batch[\"prompt_embeds\"], batch[\"pooled_prompt_embeds\"]\n                prompt_embeds = prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)\n                pooled_prompt_embeds = pooled_prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)\n                model_pred = transformer(\n                    hidden_states=noisy_model_input,\n                    timestep=timesteps,\n                    encoder_hidden_states=prompt_embeds,\n                    pooled_projections=pooled_prompt_embeds,\n                    return_dict=False,\n                )[0]\n\n                # Follow: Section 5 of https://huggingface.co/papers/2206.00364.\n                # Preconditioning of the model outputs.\n                model_pred = model_pred * (-sigmas) + noisy_model_input\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = model_input\n\n                # Compute regular loss.\n                loss = torch.mean(\n                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    1,\n                )\n                loss = loss.mean()\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = transformer_lora_parameters\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if 
accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if 
global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                pipeline = StableDiffusion3Pipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n                    accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                )\n                torch.cuda.empty_cache()\n                gc.collect()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        transformer = transformer.to(torch.float32)\n        transformer_lora_layers = get_peft_model_state_dict(transformer)\n\n        StableDiffusion3Pipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            transformer_lora_layers=transformer_lora_layers,\n        )\n\n        # Final inference\n        # Load previous pipeline\n        pipeline = StableDiffusion3Pipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": 
args.validation_prompt}\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                instance_prompt=args.instance_prompt,\n                validation_prompt=args.validation_prompt,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/research_projects/sdxl_flax/README.md",
    "content": "# Stable Diffusion XL for JAX + TPUv5e\n\n[TPU v5e](https://cloud.google.com/blog/products/compute/how-cloud-tpu-v5e-accelerates-large-scale-ai-inference) is a new generation of TPUs from Google Cloud. It is the most cost-effective, versatile, and scalable Cloud TPU to date. This makes them ideal for serving and scaling large diffusion models.\n\n[JAX](https://github.com/google/jax) is a high-performance numerical computation library that is well-suited to develop and deploy diffusion models:\n\n- **High performance**. All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) - the Accelerated Linear Algebra compiler\n\n- **Compilation**. JAX uses just-in-time (jit) compilation of JAX Python functions so it can be executed efficiently in XLA. In order to get the best performance, we must use static shapes for jitted functions, this is because JAX transforms work by tracing a function and to determine its effect on inputs of a specific shape and type. When a new shape is introduced to an already compiled function, it retriggers compilation on the new shape, which can greatly reduce performance. **Note**: JIT compilation is particularly well-suited for text-to-image generation because all inputs and outputs (image input / output sizes) are static.\n\n- **Parallelization**. Workloads can be scaled across multiple devices using JAX's [pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html), which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function will compile a function with XLA, then execute in parallel on XLA devices. 
For text-to-image generation workloads this means that increasing the number of images rendered simultaneously is straightforward to implement and doesn't compromise performance.\n\n👉 Try it out for yourself:\n\n[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/google/sdxl)\n\n## Stable Diffusion XL pipeline in JAX\n\nOnce you have access to a TPU VM (TPUs higher than version 3), first install\na TPU-compatible version of JAX:\n```sh\npip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\n```\n\nNext, we can install [flax](https://github.com/google/flax) and the diffusers library:\n\n```sh\npip install flax diffusers transformers\n```\n\nIn [sdxl_single.py](./sdxl_single.py) we give a simple example of how to write a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](stabilityai/stable-diffusion-xl-base-1.0).\n\nLet's explain it step-by-step:\n\n**Imports and Setup**\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom diffusers import FlaxStableDiffusionXLPipeline\n\nfrom jax.experimental.compilation_cache import compilation_cache as cc\ncc.initialize_cache(\"/tmp/sdxl_cache\")\nimport time\n\nNUM_DEVICES = jax.device_count()\n```\n\nFirst, we import the necessary libraries:\n- `jax` provides the primitives for TPU operations.\n- `flax.jax_utils` contains some useful utility functions for `Flax`, a neural network library built on top of JAX.\n- `diffusers` has all the code that is relevant for SDXL.\n- We also initialize a cache to speed up the JAX model compilation.\n- We automatically determine the number of available TPU devices.\n\n**1. 
Downloading Model and Loading Pipeline**\n\n```python\npipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", revision=\"refs/pr/95\", split_head_dim=True\n)\n```\nHere, the pre-trained model `stable-diffusion-xl-base-1.0` from the namespace `stabilityai` is loaded. `from_pretrained` returns a pipeline for inference and its parameters.\n\n**2. Casting Parameter Types**\n\n```python\nscheduler_state = params.pop(\"scheduler\")\nparams = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)\nparams[\"scheduler\"] = scheduler_state\n```\nThis section adjusts the data types of the model parameters.\nWe convert all parameters to `bfloat16` to speed up the computation with model weights.\n**Note** that the scheduler parameters are **not** converted to `bfloat16` because the loss\nin precision degrades the pipeline's results too significantly.\n\n**3. Define Inputs to Pipeline**\n\n```python\ndefault_prompt = ...\ndefault_neg_prompt = ...\ndefault_seed = 33\ndefault_guidance_scale = 5.0\ndefault_num_steps = 25\n```\nHere, various default inputs for the pipeline are set, including the prompt, negative prompt, random seed, guidance scale, and the number of inference steps.\n\n**4. Tokenizing Inputs**\n\n```python\ndef tokenize_prompt(prompt, neg_prompt):\n    prompt_ids = pipeline.prepare_inputs(prompt)\n    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)\n    return prompt_ids, neg_prompt_ids\n```\nThis function tokenizes the given prompts. It's essential because the text encoders of SDXL don't understand raw text; they work with numbers. Tokenization converts text to integer token IDs.\n\n**5. Parallelization and Replication**\n\n```python\np_params = replicate(params)\n\ndef replicate_all(prompt_ids, neg_prompt_ids, seed):\n    ...\n```\nTo utilize JAX's parallel capabilities, the parameters and input tensors are duplicated across devices. 
The `replicate_all` function also ensures that every device produces a different image by creating a unique random seed for each device.\n\n**6. Putting Everything Together**\n\n```python\ndef generate(...):\n    ...\n```\nThis function integrates all the steps to produce the desired outputs from the model. It takes in prompts, tokenizes them, replicates them across devices, runs them through the pipeline, and converts the images to a format that's more interpretable (PIL format).\n\n**7. Compilation Step**\n\n```python\nstart = time.time()\nprint(f\"Compiling ...\")\ngenerate(default_prompt, default_neg_prompt)\nprint(f\"Compiled in {time.time() - start}\")\n```\nThe initial run of the `generate` function will be slow because JAX compiles the function during this call. By running it once here, subsequent calls will be much faster. This section measures and prints the compilation time.\n\n**8. Fast Inference**\n\n```python\nstart = time.time()\nprompt = ...\nneg_prompt = ...\nimages = generate(prompt, neg_prompt)\nprint(f\"Inference in {time.time() - start}\")\n```\nNow that the function is compiled, this section shows how to use it for fast inference. It measures and prints the inference time.\n\nIn summary, the code demonstrates how to load a pre-trained model using Flax and JAX, prepare it for inference, and run it efficiently using JAX's capabilities.\n\n## Ahead of Time (AOT) Compilation\n\nFlaxStableDiffusionXLPipeline takes care of parallelization across multiple devices using jit. Now let's build parallelization ourselves.\n\nFor this we will be using a JAX feature called [Ahead of Time](https://jax.readthedocs.io/en/latest/aot.html) (AOT) lowering and compilation. 
AOT allows to fully compile prior to execution time and have control over different parts of the compilation process.\n\nIn [sdxl_single_aot.py](./sdxl_single_aot.py) we give a simple example of how to write our own parallelization logic for text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](stabilityai/stable-diffusion-xl-base-1.0)\n\nWe add a `aot_compile` function that compiles the `pipeline._generate` function\ntelling JAX which input arguments are static, that is, arguments that\nare known at compile time and won't change. In our case, it is num_inference_steps,\nheight, width and return_latents.\n\nOnce the function is compiled, these parameters are omitted from future calls and\ncannot be changed without modifying the code and recompiling.\n\n```python\ndef aot_compile(\n        prompt=default_prompt,\n        negative_prompt=default_neg_prompt,\n        seed=default_seed,\n        guidance_scale=default_guidance_scale,\n        num_inference_steps=default_num_steps\n):\n    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)\n    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)\n    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)\n    g = g[:, None]\n\n    return pmap(\n        pipeline._generate,static_broadcasted_argnums=[3, 4, 5, 9]\n        ).lower(\n            prompt_ids,\n            p_params,\n            rng,\n            num_inference_steps, # num_inference_steps\n            height, # height\n            width, # width\n            g,\n            None,\n            neg_prompt_ids,\n            False # return_latents\n            ).compile()\n````\n\nNext we can compile the generate function by executing `aot_compile`.\n\n```python\nstart = time.time()\nprint(\"Compiling ...\")\np_generate = aot_compile()\nprint(f\"Compiled in {time.time() - start}\")\n```\nAnd again we put everything together in a `generate` function.\n\n```python\ndef 
generate(\n    prompt,\n    negative_prompt,\n    seed=default_seed,\n    guidance_scale=default_guidance_scale\n):\n    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)\n    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)\n    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)\n    g = g[:, None]\n    images = p_generate(\n        prompt_ids,\n        p_params,\n        rng,\n        g,\n        None,\n        neg_prompt_ids)\n\n    # convert the images to PIL\n    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])\n    return pipeline.numpy_to_pil(np.array(images))\n```\n\nThe first forward pass after AOT compilation still takes longer than\nsubsequent passes. This is because on the first pass JAX uses Python dispatch, which\nfills the C++ dispatch cache.\nWhen using jit, this extra step is done automatically, but when using AOT compilation,\nit doesn't happen until the function call is made.\n\n```python\nstart = time.time()\nprompt = \"photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang\"\nneg_prompt = \"cartoon, illustration, animation. face. male, female\"\nimages = generate(prompt, neg_prompt)\nprint(f\"First inference in {time.time() - start}\")\n```\n\nFrom this point forward, any call to `generate` should result in a faster inference\ntime that no longer changes.\n\n```python\nstart = time.time()\nprompt = \"photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang\"\nneg_prompt = \"cartoon, illustration, animation. face. male, female\"\nimages = generate(prompt, neg_prompt)\nprint(f\"Inference in {time.time() - start}\")\n```\n"
  },
  {
    "path": "examples/research_projects/sdxl_flax/sdxl_single.py",
    "content": "# Show best practices for SDXL JAX\nimport time\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax.jax_utils import replicate\n\n# Let's cache the model compilation, so that it doesn't take as long the next time around.\nfrom jax.experimental.compilation_cache import compilation_cache as cc\n\nfrom diffusers import FlaxStableDiffusionXLPipeline\n\n\ncc.initialize_cache(\"/tmp/sdxl_cache\")\n\n\nNUM_DEVICES = jax.device_count()\n\n# 1. Let's start by downloading the model and loading it into our pipeline class\n# Adhering to JAX's functional approach, the model's parameters are returned separately and\n# will have to be passed to the pipeline during inference\npipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", revision=\"refs/pr/95\", split_head_dim=True\n)\n\n# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in\n# float32 to keep maximal precision\nscheduler_state = params.pop(\"scheduler\")\nparams = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)\nparams[\"scheduler\"] = scheduler_state\n\n# 3. Next, we define the different inputs to the pipeline\ndefault_prompt = \"a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart\"\ndefault_neg_prompt = \"fog, grainy, purple\"\ndefault_seed = 33\ndefault_guidance_scale = 5.0\ndefault_num_steps = 25\n\n\n# 4. In order to be able to compile the pipeline\n# all inputs have to be tensors or strings\n# Let's tokenize the prompt and negative prompt\ndef tokenize_prompt(prompt, neg_prompt):\n    prompt_ids = pipeline.prepare_inputs(prompt)\n    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)\n    return prompt_ids, neg_prompt_ids\n\n\n# 5. 
To make full use of JAX's parallelization capabilities\n# the parameters and input tensors are duplicated across devices\n# To make sure every device generates a different image, we create\n# different seeds for each image. The model parameters won't change\n# during inference so we do not wrap them into a function\np_params = replicate(params)\n\n\ndef replicate_all(prompt_ids, neg_prompt_ids, seed):\n    p_prompt_ids = replicate(prompt_ids)\n    p_neg_prompt_ids = replicate(neg_prompt_ids)\n    rng = jax.random.PRNGKey(seed)\n    rng = jax.random.split(rng, NUM_DEVICES)\n    return p_prompt_ids, p_neg_prompt_ids, rng\n\n\n# 6. Let's now put it all together in a generate function\ndef generate(\n    prompt,\n    negative_prompt,\n    seed=default_seed,\n    guidance_scale=default_guidance_scale,\n    num_inference_steps=default_num_steps,\n):\n    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)\n    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)\n    images = pipeline(\n        prompt_ids,\n        p_params,\n        rng,\n        num_inference_steps=num_inference_steps,\n        neg_prompt_ids=neg_prompt_ids,\n        guidance_scale=guidance_scale,\n        jit=True,\n    ).images\n\n    # convert the images to PIL\n    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])\n    return pipeline.numpy_to_pil(np.array(images))\n\n\n# 7. Remember that the first call will compile the function and hence be very slow. Let's run generate once\n# so that the pipeline call is compiled\nstart = time.time()\nprint(\"Compiling ...\")\ngenerate(default_prompt, default_neg_prompt)\nprint(f\"Compiled in {time.time() - start}\")\n\n# 8. 
Now the model forward pass will run very quickly, let's try it again\nstart = time.time()\nprompt = \"photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang\"\nneg_prompt = \"cartoon, illustration, animation. face. male, female\"\nimages = generate(prompt, neg_prompt)\nprint(f\"Inference in {time.time() - start}\")\n\nfor i, image in enumerate(images):\n    image.save(f\"castle_{i}.png\")\n"
  },
  {
    "path": "examples/research_projects/sdxl_flax/sdxl_single_aot.py",
    "content": "import time\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax.jax_utils import replicate\nfrom jax import pmap\n\n# Let's cache the model compilation, so that it doesn't take as long the next time around.\nfrom jax.experimental.compilation_cache import compilation_cache as cc\n\nfrom diffusers import FlaxStableDiffusionXLPipeline\n\n\ncc.initialize_cache(\"/tmp/sdxl_cache\")\n\n\nNUM_DEVICES = jax.device_count()\n\n# 1. Let's start by downloading the model and loading it into our pipeline class\n# Adhering to JAX's functional approach, the model's parameters are returned separately and\n# will have to be passed to the pipeline during inference\npipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(\n    \"stabilityai/stable-diffusion-xl-base-1.0\", revision=\"refs/pr/95\", split_head_dim=True\n)\n\n# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in\n# float32 to keep maximal precision\nscheduler_state = params.pop(\"scheduler\")\nparams = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)\nparams[\"scheduler\"] = scheduler_state\n\n# 3. Next, we define the different inputs to the pipeline\ndefault_prompt = \"a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart\"\ndefault_neg_prompt = \"fog, grainy, purple\"\ndefault_seed = 33\ndefault_guidance_scale = 5.0\ndefault_num_steps = 25\nwidth = 1024\nheight = 1024\n\n\n# 4. In order to be able to compile the pipeline\n# all inputs have to be tensors or strings\n# Let's tokenize the prompt and negative prompt\ndef tokenize_prompt(prompt, neg_prompt):\n    prompt_ids = pipeline.prepare_inputs(prompt)\n    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)\n    return prompt_ids, neg_prompt_ids\n\n\n# 5. 
To make full use of JAX's parallelization capabilities\n# the parameters and input tensors are duplicated across devices\n# To make sure every device generates a different image, we create\n# different seeds for each image. The model parameters won't change\n# during inference so we do not wrap them into a function\np_params = replicate(params)\n\n\ndef replicate_all(prompt_ids, neg_prompt_ids, seed):\n    p_prompt_ids = replicate(prompt_ids)\n    p_neg_prompt_ids = replicate(neg_prompt_ids)\n    rng = jax.random.PRNGKey(seed)\n    rng = jax.random.split(rng, NUM_DEVICES)\n    return p_prompt_ids, p_neg_prompt_ids, rng\n\n\n# 6. To compile the pipeline._generate function, we must pass all parameters\n# to the function and tell JAX which are static arguments, that is, arguments that\n# are known at compile time and won't change. In our case, it is num_inference_steps,\n# height, width and return_latents.\n# Once the function is compiled, these parameters are omitted from future calls and\n# cannot be changed without modifying the code and recompiling.\ndef aot_compile(\n    prompt=default_prompt,\n    negative_prompt=default_neg_prompt,\n    seed=default_seed,\n    guidance_scale=default_guidance_scale,\n    num_inference_steps=default_num_steps,\n):\n    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)\n    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)\n    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)\n    g = g[:, None]\n\n    return (\n        pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9])\n        .lower(\n            prompt_ids,\n            p_params,\n            rng,\n            num_inference_steps,  # num_inference_steps\n            height,  # height\n            width,  # width\n            g,\n            None,\n            neg_prompt_ids,\n            False,  # return_latents\n        )\n        .compile()\n    )\n\n\nstart = 
time.time()\nprint(\"Compiling ...\")\np_generate = aot_compile()\nprint(f\"Compiled in {time.time() - start}\")\n\n\n# 7. Let's now put it all together in a generate function.\ndef generate(prompt, negative_prompt, seed=default_seed, guidance_scale=default_guidance_scale):\n    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)\n    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)\n    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)\n    g = g[:, None]\n    images = p_generate(prompt_ids, p_params, rng, g, None, neg_prompt_ids)\n\n    # convert the images to PIL\n    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])\n    return pipeline.numpy_to_pil(np.array(images))\n\n\n# 8. The first forward pass after AOT compilation still takes longer than\n# subsequent passes. This is because on the first pass JAX uses Python dispatch, which\n# fills the C++ dispatch cache.\n# When using jit, this extra step is done automatically, but when using AOT compilation,\n# it doesn't happen until the function call is made.\nstart = time.time()\nprompt = \"photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang\"\nneg_prompt = \"cartoon, illustration, animation. face. male, female\"\nimages = generate(prompt, neg_prompt)\nprint(f\"First inference in {time.time() - start}\")\n\n# 9. From this point forward, any call to generate should result in a faster inference\n# time that no longer changes.\nstart = time.time()\nprompt = \"photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang\"\nneg_prompt = \"cartoon, illustration, animation. face. male, female\"\nimages = generate(prompt, neg_prompt)\nprint(f\"Inference in {time.time() - start}\")\n\nfor i, image in enumerate(images):\n    image.save(f\"castle_{i}.png\")\n"
  },
  {
    "path": "examples/research_projects/vae/README.md",
    "content": "# VAE\n\n`vae_roundtrip.py` Demonstrates the use of a VAE by roundtripping an image through the encoder and decoder. Original and reconstructed images are displayed side by side.\n\n```\ncd examples/research_projects/vae\npython vae_roundtrip.py \\\n    --pretrained_model_name_or_path=\"stable-diffusion-v1-5/stable-diffusion-v1-5\" \\\n    --subfolder=\"vae\" \\\n    --input_image=\"/path/to/your/input.png\"\n```\n"
  },
  {
    "path": "examples/research_projects/vae/vae_roundtrip.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport typing\nfrom typing import Optional, Union\n\nimport torch\nfrom PIL import Image\nfrom torchvision import transforms  # type: ignore\n\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models.autoencoders.autoencoder_kl import (\n    AutoencoderKL,\n    AutoencoderKLOutput,\n)\nfrom diffusers.models.autoencoders.autoencoder_tiny import (\n    AutoencoderTiny,\n    AutoencoderTinyOutput,\n)\nfrom diffusers.models.autoencoders.vae import DecoderOutput\n\n\nSupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]\n\n\ndef load_vae_model(\n    *,\n    device: torch.device,\n    model_name_or_path: str,\n    revision: str | None,\n    variant: str | None,\n    # NOTE: use subfolder=\"vae\" if the pointed model is for stable diffusion as a whole instead of just the VAE\n    subfolder: str | None,\n    use_tiny_nn: bool,\n) -> SupportedAutoencoder:\n    if use_tiny_nn:\n        # NOTE: These scaling factors don't have to be the same as each other.\n        down_scale = 2\n        up_scale = 2\n        vae = AutoencoderTiny.from_pretrained(  # type: ignore\n            model_name_or_path,\n            subfolder=subfolder,\n            revision=revision,\n            variant=variant,\n            downscaling_scaling_factor=down_scale,\n            
upsampling_scaling_factor=up_scale,\n        )\n        assert isinstance(vae, AutoencoderTiny)\n    else:\n        vae = AutoencoderKL.from_pretrained(  # type: ignore\n            model_name_or_path,\n            subfolder=subfolder,\n            revision=revision,\n            variant=variant,\n        )\n        assert isinstance(vae, AutoencoderKL)\n    vae = vae.to(device)\n    vae.eval()  # Set the model to inference mode\n    return vae\n\n\ndef pil_to_nhwc(\n    *,\n    device: torch.device,\n    image: Image.Image,\n) -> torch.Tensor:\n    assert image.mode == \"RGB\"\n    transform = transforms.ToTensor()\n    nhwc = transform(image).unsqueeze(0).to(device)  # type: ignore\n    assert isinstance(nhwc, torch.Tensor)\n    return nhwc\n\n\ndef nhwc_to_pil(\n    *,\n    nhwc: torch.Tensor,\n) -> Image.Image:\n    assert nhwc.shape[0] == 1\n    hwc = nhwc.squeeze(0).cpu()\n    return transforms.ToPILImage()(hwc)  # type: ignore\n\n\ndef concatenate_images(\n    *,\n    left: Image.Image,\n    right: Image.Image,\n    vertical: bool = False,\n) -> Image.Image:\n    width1, height1 = left.size\n    width2, height2 = right.size\n    if vertical:\n        total_height = height1 + height2\n        max_width = max(width1, width2)\n        new_image = Image.new(\"RGB\", (max_width, total_height))\n        new_image.paste(left, (0, 0))\n        new_image.paste(right, (0, height1))\n    else:\n        total_width = width1 + width2\n        max_height = max(height1, height2)\n        new_image = Image.new(\"RGB\", (total_width, max_height))\n        new_image.paste(left, (0, 0))\n        new_image.paste(right, (width1, 0))\n    return new_image\n\n\ndef to_latent(\n    *,\n    rgb_nchw: torch.Tensor,\n    vae: SupportedAutoencoder,\n) -> torch.Tensor:\n    rgb_nchw = VaeImageProcessor.normalize(rgb_nchw)  # type: ignore\n    encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))\n    if isinstance(encoding_nchw, AutoencoderKLOutput):\n        latent = 
encoding_nchw.latent_dist.sample()  # type: ignore\n        assert isinstance(latent, torch.Tensor)\n    elif isinstance(encoding_nchw, AutoencoderTinyOutput):\n        latent = encoding_nchw.latents\n        do_internal_vae_scaling = False  # Is this needed?\n        if do_internal_vae_scaling:\n            latent = vae.scale_latents(latent).mul(255).round().byte()  # type: ignore\n            latent = vae.unscale_latents(latent / 255.0)  # type: ignore\n            assert isinstance(latent, torch.Tensor)\n    else:\n        assert False, f\"Unknown encoding type: {type(encoding_nchw)}\"\n    return latent\n\n\ndef from_latent(\n    *,\n    latent_nchw: torch.Tensor,\n    vae: SupportedAutoencoder,\n) -> torch.Tensor:\n    decoding_nchw = vae.decode(latent_nchw)  # type: ignore\n    assert isinstance(decoding_nchw, DecoderOutput)\n    rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample)  # type: ignore\n    assert isinstance(rgb_nchw, torch.Tensor)\n    return rgb_nchw\n\n\ndef main_kwargs(\n    *,\n    device: torch.device,\n    input_image_path: str,\n    pretrained_model_name_or_path: str,\n    revision: str | None,\n    variant: str | None,\n    subfolder: str | None,\n    use_tiny_nn: bool,\n) -> None:\n    vae = load_vae_model(\n        device=device,\n        model_name_or_path=pretrained_model_name_or_path,\n        revision=revision,\n        variant=variant,\n        subfolder=subfolder,\n        use_tiny_nn=use_tiny_nn,\n    )\n    original_pil = Image.open(input_image_path).convert(\"RGB\")\n    original_image = pil_to_nhwc(\n        device=device,\n        image=original_pil,\n    )\n    print(f\"Original image shape: {original_image.shape}\")\n    reconstructed_image: Optional[torch.Tensor] = None\n\n    with torch.no_grad():\n        latent_image = to_latent(rgb_nchw=original_image, vae=vae)\n        print(f\"Latent shape: {latent_image.shape}\")\n        reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)\n        
reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)\n    combined_image = concatenate_images(\n        left=original_pil,\n        right=reconstructed_pil,\n        vertical=False,\n    )\n    combined_image.show(\"Original | Reconstruction\")\n    print(f\"Reconstructed image shape: {reconstructed_image.shape}\")\n\n\ndef parse_args() -> argparse.Namespace:\n    parser = argparse.ArgumentParser(description=\"Inference with VAE\")\n    parser.add_argument(\n        \"--input_image\",\n        type=str,\n        required=True,\n        help=\"Path to the input image for inference.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        required=True,\n        help=\"Path to pretrained VAE model.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        help=\"Model version.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Model file variant, e.g., 'fp16'.\",\n    )\n    parser.add_argument(\n        \"--subfolder\",\n        type=str,\n        default=None,\n        help=\"Subfolder in the model file.\",\n    )\n    parser.add_argument(\n        \"--use_cuda\",\n        action=\"store_true\",\n        help=\"Use CUDA if available.\",\n    )\n    parser.add_argument(\n        \"--use_tiny_nn\",\n        action=\"store_true\",\n        help=\"Use tiny neural network.\",\n    )\n    return parser.parse_args()\n\n\n# EXAMPLE USAGE:\n#\n# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path \"stable-diffusion-v1-5/stable-diffusion-v1-5\" --subfolder \"vae\" --input_image \"foo.png\"\n#\n# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path \"madebyollin/taesd\" --use_tiny_nn --input_image \"foo.png\"\n#\ndef main_cli() -> None:\n    args = parse_args()\n\n    input_image_path = args.input_image\n    assert isinstance(input_image_path, str)\n\n    
pretrained_model_name_or_path = args.pretrained_model_name_or_path\n    assert isinstance(pretrained_model_name_or_path, str)\n\n    revision = args.revision\n    assert isinstance(revision, (str, type(None)))\n\n    variant = args.variant\n    assert isinstance(variant, (str, type(None)))\n\n    subfolder = args.subfolder\n    assert isinstance(subfolder, (str, type(None)))\n\n    use_cuda = args.use_cuda\n    assert isinstance(use_cuda, bool)\n\n    use_tiny_nn = args.use_tiny_nn\n    assert isinstance(use_tiny_nn, bool)\n\n    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n\n    main_kwargs(\n        device=device,\n        input_image_path=input_image_path,\n        pretrained_model_name_or_path=pretrained_model_name_or_path,\n        revision=revision,\n        variant=variant,\n        subfolder=subfolder,\n        use_tiny_nn=use_tiny_nn,\n    )\n\n\nif __name__ == \"__main__\":\n    main_cli()\n"
  },
  {
    "path": "examples/research_projects/wuerstchen/text_to_image/README.md",
    "content": "# Würstchen text-to-image fine-tuning\n\n## Running locally with PyTorch\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd into the example folder and run\n```bash\ncd examples/wuerstchen/text_to_image\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\nFor this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run:\n```bash\nhf auth login\n```\n\n## Prior training\n\nYou can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. 
Note that we currently support `--gradient_checkpointing` for prior model fine-tuning, so you can use it on setups with tighter GPU memory constraints.\n\n<br>\n\n<!-- accelerate_snippet_start -->\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch train_text_to_image_prior.py \\\n  --mixed_precision=\"fp16\" \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=768 \\\n  --train_batch_size=4 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --dataloader_num_workers=4 \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --checkpoints_total_limit=3 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --validation_prompts=\"A robot naruto, 4k photo\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"wuerstchen-prior-naruto-model\"\n```\n<!-- accelerate_snippet_end -->\n\n## Training with LoRA\n\nLow-Rank Adaptation of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.\n\nIn a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. 
This has several advantages:\n\n- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).\n- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers let you control the extent to which the model is adapted toward new training images via a `scale` parameter.\n\n\n### Prior Training\n\nFirst, you need to set up your development environment as explained in the [installation](#Running-locally-with-PyTorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Naruto captions dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).\n\n```bash\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch train_text_to_image_lora_prior.py \\\n  --mixed_precision=\"fp16\" \\\n  --dataset_name=$DATASET_NAME --caption_column=\"text\" \\\n  --resolution=768 \\\n  --train_batch_size=8 \\\n  --num_train_epochs=100 --checkpointing_steps=5000 \\\n  --learning_rate=1e-04 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --seed=42 \\\n  --rank=4 \\\n  --validation_prompts=\"cute dragon creature\" \\\n  --report_to=\"wandb\" \\\n  --push_to_hub \\\n  --output_dir=\"wuerstchen-prior-naruto-lora\"\n```\n"
  },
  {
    "path": "examples/research_projects/wuerstchen/text_to_image/__init__.py",
    "content": ""
  },
  {
    "path": "examples/research_projects/wuerstchen/text_to_image/modeling_efficient_net_encoder.py",
    "content": "import torch.nn as nn\nfrom torchvision.models import efficientnet_v2_l, efficientnet_v2_s\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\n\n\nclass EfficientNetEncoder(ModelMixin, ConfigMixin):\n    @register_to_config\n    def __init__(self, c_latent=16, c_cond=1280, effnet=\"efficientnet_v2_s\"):\n        super().__init__()\n\n        if effnet == \"efficientnet_v2_s\":\n            self.backbone = efficientnet_v2_s(weights=\"DEFAULT\").features\n        else:\n            self.backbone = efficientnet_v2_l(weights=\"DEFAULT\").features\n        self.mapper = nn.Sequential(\n            nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),\n            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1\n        )\n\n    def forward(self, x):\n        return self.mapper(self.backbone(x))\n"
  },
  {
    "path": "examples/research_projects/wuerstchen/text_to_image/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nwandb\nbitsandbytes\ndeepspeed\npeft>=0.6.0\n"
  },
  {
    "path": "examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState, is_initialized\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, hf_hub_download, upload_folder\nfrom modeling_efficient_net_encoder import EfficientNetEncoder\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict\nfrom torchvision import transforms\nfrom tqdm import tqdm\nfrom transformers import CLIPTextModel, PreTrainedTokenizerFast\nfrom transformers.utils import ContextManagers\n\nfrom diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom diffusers.utils.logging import set_verbosity_error, set_verbosity_info\n\n\nif is_wandb_available():\n    import wandb\n\n\n# Will error if the minimal version of diffusers is not 
installed. Remove at your own risks.\ncheck_min_version(\"0.32.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef save_model_card(\n    args,\n    repo_id: str,\n    images=None,\n    repo_folder=None,\n):\n    img_str = \"\"\n    if len(images) > 0:\n        image_grid = make_image_grid(images, 1, len(args.validation_prompts))\n        image_grid.save(os.path.join(repo_folder, \"val_imgs_grid.png\"))\n        img_str += \"![val_imgs_grid](./val_imgs_grid.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: mit\nbase_model: {args.pretrained_prior_model_name_or_path}\ndatasets:\n- {args.dataset_name}\ntags:\n- wuerstchen\n- text-to-image\n- diffusers\n- diffusers-training\n- lora\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# LoRA Finetuning - {repo_id}\n\nThis pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. 
Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \\n\n{img_str}\n\n## Pipeline usage\n\nYou can use the pipeline like so:\n\n```python\nfrom diffusers import AutoPipelineForText2Image\nimport torch\n\npipeline = AutoPipelineForText2Image.from_pretrained(\n                \"{args.pretrained_decoder_model_name_or_path}\", torch_dtype={args.weight_dtype}\n            )\n# load lora weights from folder:\npipeline.prior_pipe.load_lora_weights(\"{repo_id}\", torch_dtype={args.weight_dtype})\n\nimage = pipeline(prompt=prompt).images[0]\nimage.save(\"my_image.png\")\n```\n\n## Training info\n\nThese are the key hyperparameters used during training:\n\n* LoRA rank: {args.rank}\n* Epochs: {args.num_train_epochs}\n* Learning rate: {args.learning_rate}\n* Batch size: {args.train_batch_size}\n* Gradient accumulation steps: {args.gradient_accumulation_steps}\n* Image resolution: {args.resolution}\n* Mixed-precision: {args.mixed_precision}\n\n\"\"\"\n    wandb_info = \"\"\n    # Initialize before the availability check so the name is always bound\n    wandb_run_url = None\n    if is_wandb_available():\n        if wandb.run is not None:\n            wandb_run_url = wandb.run.url\n\n    if wandb_run_url is not None:\n        wandb_info = f\"\"\"\nMore information on all the CLI arguments and the environment is available on your [`wandb` run page]({wandb_run_url}).\n\"\"\"\n\n    model_card += wandb_info\n\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):\n    logger.info(\"Running validation... 
\")\n\n    pipeline = AutoPipelineForText2Image.from_pretrained(\n        args.pretrained_decoder_model_name_or_path,\n        prior=accelerator.unwrap_model(prior),\n        prior_text_encoder=accelerator.unwrap_model(text_encoder),\n        prior_tokenizer=tokenizer,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    images = []\n    for i in range(len(args.validation_prompts)):\n        with torch.cuda.amp.autocast():\n            image = pipeline(\n                args.validation_prompts[i],\n                prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,\n                generator=generator,\n                height=args.resolution,\n                width=args.resolution,\n            ).images[0]\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompts[i]}\")\n                        for i, image in enumerate(images)\n                    ]\n                }\n            )\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of finetuning Würstchen Prior.\")\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA 
update matrices.\"),\n    )\n    parser.add_argument(\n        \"--pretrained_decoder_model_name_or_path\",\n        type=str,\n        default=\"warp-ai/wuerstchen\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_prior_model_name_or_path\",\n        type=str,\n        default=\"warp-ai/wuerstchen-prior\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompts\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\"A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"wuerstchen-model-finetuned-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=1, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        
\"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate to use.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\n        \"--adam_weight_decay\",\n        type=float,\n        default=0.0,\n        required=False,\n        help=\"Weight decay to use.\",\n    )\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=5,\n        help=\"Run validation every X epochs.\",\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        
accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load scheduler, effnet, tokenizer, clip_model\n    noise_scheduler = DDPMWuerstchenScheduler()\n    tokenizer = PreTrainedTokenizerFast.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"tokenizer\"\n    )\n\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        returns either a context list that includes one that will disable zero.Init or an empty context list\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None\n        if deepspeed_plugin is None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif 
accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        pretrained_checkpoint_file = hf_hub_download(\"dome272/wuerstchen\", filename=\"model_v2_stage_b.pt\")\n        state_dict = torch.load(pretrained_checkpoint_file, map_location=\"cpu\")\n        image_encoder = EfficientNetEncoder()\n        image_encoder.load_state_dict(state_dict[\"effnet_state_dict\"])\n        image_encoder.eval()\n\n        text_encoder = CLIPTextModel.from_pretrained(\n            args.pretrained_prior_model_name_or_path, subfolder=\"text_encoder\", torch_dtype=weight_dtype\n        ).eval()\n\n    # Freeze text_encoder, cast to weight_dtype and image_encoder and move to device\n    text_encoder.requires_grad_(False)\n    image_encoder.requires_grad_(False)\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # load prior model, cast to weight_dtype and move to device\n    prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\n    prior.to(accelerator.device, dtype=weight_dtype)\n\n    # lora attn processor\n    prior_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\", \"add_k_proj\", \"add_v_proj\"],\n    )\n    # Add adapter and make sure the trainable params are in float32.\n    prior.add_adapter(prior_lora_config)\n    if args.mixed_precision == \"fp16\":\n        for param in prior.parameters():\n            # only upcast trainable parameters (LoRA) into fp32\n            if param.requires_grad:\n                param.data = param.to(torch.float32)\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n      
      prior_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(accelerator.unwrap_model(prior))):\n                    prior_lora_layers_to_save = get_peft_model_state_dict(model)\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            WuerstchenPriorPipeline.save_lora_weights(\n                output_dir,\n                unet_lora_layers=prior_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        prior_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if isinstance(model, type(accelerator.unwrap_model(prior))):\n                prior_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)\n        WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)\n        WuerstchenPriorPipeline.load_lora_into_text_encoder(\n            lora_state_dict,\n            network_alphas=network_alphas,\n        )\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n    params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))\n    optimizer = optimizer_cls(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column 
not in column_names:\n            raise ValueError(\n                f\"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        text_input_ids = inputs.input_ids\n        text_mask = inputs.attention_mask.bool()\n        return text_input_ids, text_mask\n\n    effnet_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n        ]\n    )\n\n    def preprocess_train(examples):\n  
      images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"effnet_pixel_values\"] = [effnet_transforms(image) for image in images]\n        examples[\"text_input_ids\"], examples[\"text_mask\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        effnet_pixel_values = torch.stack([example[\"effnet_pixel_values\"] for example in examples])\n        effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()\n        text_input_ids = torch.stack([example[\"text_input_ids\"] for example in examples])\n        text_mask = torch.stack([example[\"text_mask\"] for example in examples])\n        return {\"effnet_pixel_values\": effnet_pixel_values, \"text_input_ids\": text_input_ids, \"text_mask\": text_mask}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        prior, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        tracker_config.pop(\"validation_prompts\")\n        accelerator.init_trackers(args.tracker_project_name, tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        prior.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if 
args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(prior):\n                # Convert images to latent space\n                text_input_ids, text_mask, effnet_images = (\n                    batch[\"text_input_ids\"],\n                    batch[\"text_mask\"],\n                    batch[\"effnet_pixel_values\"].to(weight_dtype),\n                )\n\n                with torch.no_grad():\n                    text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)\n                    prompt_embeds = text_encoder_output.last_hidden_state\n                    image_embeds = image_encoder(effnet_images)\n                    # scale\n                    image_embeds = image_embeds.add(1.0).div(42.0)\n\n                    # Sample noise that we'll add to the image_embeds\n                    noise = torch.randn_like(image_embeds)\n                    bsz = image_embeds.shape[0]\n\n                    # Sample a random timestep for each image\n                    timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)\n\n                    # add noise to latent\n                    noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)\n\n                # Predict the noise residual and compute loss\n                pred_noise = prior(noisy_latents, timesteps, prompt_embeds)\n\n                # vanilla loss\n                loss = F.mse_loss(pred_noise.float(), noise.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # 
Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for 
removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:\n                log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        prior = accelerator.unwrap_model(prior)\n        prior = prior.to(torch.float32)\n\n        prior_lora_state_dict = get_peft_model_state_dict(prior)\n\n        WuerstchenPriorPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=prior_lora_state_dict,\n        )\n\n        # Run a final round of inference.\n        images = []\n        if args.validation_prompts is not None:\n            logger.info(\"Running inference for collecting generated images...\")\n            pipeline = AutoPipelineForText2Image.from_pretrained(\n                args.pretrained_decoder_model_name_or_path,\n                prior_text_encoder=accelerator.unwrap_model(text_encoder),\n                prior_tokenizer=tokenizer,\n                torch_dtype=weight_dtype,\n            )\n            pipeline = pipeline.to(accelerator.device)\n\n            # load lora weights\n            
pipeline.prior_pipe.load_lora_weights(args.output_dir, weight_name=\"pytorch_lora_weights.safetensors\")\n            pipeline.set_progress_bar_config(disable=True)\n\n            if args.seed is None:\n                generator = None\n            else:\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n            for i in range(len(args.validation_prompts)):\n                with torch.cuda.amp.autocast():\n                    image = pipeline(\n                        args.validation_prompts[i],\n                        prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,\n                        generator=generator,\n                        width=args.resolution,\n                        height=args.resolution,\n                    ).images[0]\n                images.append(image)\n\n        if args.push_to_hub:\n            save_model_card(args, repo_id, images, repo_folder=args.output_dir)\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py",
    "content": "# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState, is_initialized\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, hf_hub_download, upload_folder\nfrom modeling_efficient_net_encoder import EfficientNetEncoder\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm import tqdm\nfrom transformers import CLIPTextModel, PreTrainedTokenizerFast\nfrom transformers.utils import ContextManagers\n\nfrom diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior\nfrom diffusers.training_utils import EMAModel\nfrom diffusers.utils import check_min_version, is_wandb_available, make_image_grid\nfrom diffusers.utils.logging import set_verbosity_error, set_verbosity_info\n\n\nif is_wandb_available():\n    import wandb\n\n\n# Will error if the minimal version of diffusers is not 
installed. Remove at your own risks.\ncheck_min_version(\"0.32.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef save_model_card(\n    args,\n    repo_id: str,\n    images=None,\n    repo_folder=None,\n):\n    img_str = \"\"\n    if len(images) > 0:\n        image_grid = make_image_grid(images, 1, len(args.validation_prompts))\n        image_grid.save(os.path.join(repo_folder, \"val_imgs_grid.png\"))\n        img_str += \"![val_imgs_grid](./val_imgs_grid.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: mit\nbase_model: {args.pretrained_prior_model_name_or_path}\ndatasets:\n- {args.dataset_name}\ntags:\n- wuerstchen\n- text-to-image\n- diffusers\n- diffusers-training\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# Finetuning - {repo_id}\n\nThis pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \\n\n{img_str}\n\n## Pipeline usage\n\nYou can use the pipeline like so:\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipe_prior = DiffusionPipeline.from_pretrained(\"{repo_id}\", torch_dtype={args.weight_dtype})\npipe_t2i = DiffusionPipeline.from_pretrained(\"{args.pretrained_decoder_model_name_or_path}\", torch_dtype={args.weight_dtype})\nprompt = \"{args.validation_prompts[0]}\"\n(image_embeds,) = pipe_prior(prompt).to_tuple()\nimage = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0]\nimage.save(\"my_image.png\")\n```\n\n## Training info\n\nThese are the key hyperparameters used during training:\n\n* Epochs: {args.num_train_epochs}\n* Learning rate: {args.learning_rate}\n* Batch size: {args.train_batch_size}\n* Gradient accumulation steps: {args.gradient_accumulation_steps}\n* Image resolution: {args.resolution}\n* 
Mixed-precision: {args.mixed_precision}\n\n\"\"\"\n    wandb_info = \"\"\n    wandb_run_url = None\n    if is_wandb_available():\n        if wandb.run is not None:\n            wandb_run_url = wandb.run.url\n\n    if wandb_run_url is not None:\n        wandb_info = f\"\"\"\nMore information on all the CLI arguments and the environment is available on your [`wandb` run page]({wandb_run_url}).\n\"\"\"\n\n    model_card += wandb_info\n\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):\n    logger.info(\"Running validation... \")\n\n    pipeline = AutoPipelineForText2Image.from_pretrained(\n        args.pretrained_decoder_model_name_or_path,\n        prior_prior=accelerator.unwrap_model(prior),\n        prior_text_encoder=accelerator.unwrap_model(text_encoder),\n        prior_tokenizer=tokenizer,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    images = []\n    for i in range(len(args.validation_prompts)):\n        with torch.autocast(\"cuda\"):\n            image = pipeline(\n                args.validation_prompts[i],\n                prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,\n                generator=generator,\n                height=args.resolution,\n                width=args.resolution,\n            ).images[0]\n\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            
tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompts[i]}\")\n                        for i, image in enumerate(images)\n                    ]\n                }\n            )\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of finetuning Würstchen Prior.\")\n    parser.add_argument(\n        \"--pretrained_decoder_model_name_or_path\",\n        type=str,\n        default=\"warp-ai/wuerstchen\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_prior_model_name_or_path\",\n        type=str,\n        default=\"warp-ai/wuerstchen-prior\",\n        required=False,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. 
Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompts\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\"A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"wuerstchen-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n 
       \"--train_batch_size\", type=int, default=1, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"learning rate\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\n        \"--adam_weight_decay\",\n        type=float,\n        default=0.0,\n        required=False,\n        help=\"weight decay_to_use\",\n    )\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=5,\n        help=\"Run validation every X epochs.\",\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir\n    )\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        
accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load scheduler, effnet, tokenizer, clip_model\n    noise_scheduler = DDPMWuerstchenScheduler()\n    tokenizer = PreTrainedTokenizerFast.from_pretrained(\n        args.pretrained_prior_model_name_or_path, subfolder=\"tokenizer\"\n    )\n\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        returns either a context list that includes one that will disable zero.Init or an empty context list\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None\n        if deepspeed_plugin is None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif 
accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        pretrained_checkpoint_file = hf_hub_download(\"dome272/wuerstchen\", filename=\"model_v2_stage_b.pt\")\n        state_dict = torch.load(pretrained_checkpoint_file, map_location=\"cpu\")\n        image_encoder = EfficientNetEncoder()\n        image_encoder.load_state_dict(state_dict[\"effnet_state_dict\"])\n        image_encoder.eval()\n\n        text_encoder = CLIPTextModel.from_pretrained(\n            args.pretrained_prior_model_name_or_path, subfolder=\"text_encoder\", torch_dtype=weight_dtype\n        ).eval()\n\n    # Freeze text_encoder and image_encoder\n    text_encoder.requires_grad_(False)\n    image_encoder.requires_grad_(False)\n\n    # load prior model\n    prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\n\n    # Create EMA for the prior\n    if args.use_ema:\n        ema_prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder=\"prior\")\n        ema_prior = EMAModel(ema_prior.parameters(), model_cls=WuerstchenPrior, model_config=ema_prior.config)\n        ema_prior.to(accelerator.device)\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if args.use_ema:\n                ema_prior.save_pretrained(os.path.join(output_dir, \"prior_ema\"))\n\n            for i, model in enumerate(models):\n                model.save_pretrained(os.path.join(output_dir, \"prior\"))\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n        def load_model_hook(models, 
input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"prior_ema\"), WuerstchenPrior)\n                ema_prior.load_state_dict(load_model.state_dict())\n                ema_prior.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder=\"prior\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        prior.enable_gradient_checkpointing()\n\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n    optimizer = optimizer_cls(\n        prior.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                
f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        text_input_ids = inputs.input_ids\n        text_mask = inputs.attention_mask.bool()\n        return text_input_ids, text_mask\n\n    effnet_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in 
examples[image_column]]\n        examples[\"effnet_pixel_values\"] = [effnet_transforms(image) for image in images]\n        examples[\"text_input_ids\"], examples[\"text_mask\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        effnet_pixel_values = torch.stack([example[\"effnet_pixel_values\"] for example in examples])\n        effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()\n        text_input_ids = torch.stack([example[\"text_input_ids\"] for example in examples])\n        text_mask = torch.stack([example[\"text_mask\"] for example in examples])\n        return {\"effnet_pixel_values\": effnet_pixel_values, \"text_input_ids\": text_input_ids, \"text_mask\": text_mask}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    
)\n\n    prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        prior, optimizer, train_dataloader, lr_scheduler\n    )\n    image_encoder.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        tracker_config.pop(\"validation_prompts\")\n        accelerator.init_trackers(args.tracker_project_name, tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        prior.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if 
args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(prior):\n                # Convert images to latent space\n                text_input_ids, text_mask, effnet_images = (\n                    batch[\"text_input_ids\"],\n                    batch[\"text_mask\"],\n                    batch[\"effnet_pixel_values\"].to(weight_dtype),\n                )\n\n                with torch.no_grad():\n                    text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)\n                    prompt_embeds = text_encoder_output.last_hidden_state\n                    image_embeds = image_encoder(effnet_images)\n                    # scale\n                    image_embeds = image_embeds.add(1.0).div(42.0)\n\n                    # Sample noise that we'll add to the image_embeds\n                    noise = torch.randn_like(image_embeds)\n                    bsz = image_embeds.shape[0]\n\n                    # Sample a random timestep for each image\n                    timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)\n\n                    # add noise to latent\n                    noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)\n\n                # Predict the noise residual and compute loss\n                pred_noise = prior(noisy_latents, timesteps, prompt_embeds)\n\n                # vanilla loss\n                loss = F.mse_loss(pred_noise.float(), noise.float(), reduction=\"mean\")\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # 
Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_prior.step(prior.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: 
{', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_prior.store(prior.parameters())\n                    ema_prior.copy_to(prior.parameters())\n                log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_prior.restore(prior.parameters())\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        prior = accelerator.unwrap_model(prior)\n        if args.use_ema:\n            ema_prior.copy_to(prior.parameters())\n\n        pipeline = AutoPipelineForText2Image.from_pretrained(\n            args.pretrained_decoder_model_name_or_path,\n            prior_prior=prior,\n            prior_text_encoder=accelerator.unwrap_model(text_encoder),\n            prior_tokenizer=tokenizer,\n        )\n        
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, \"prior_pipeline\"))\n\n        # Run a final round of inference.\n        images = []\n        if args.validation_prompts is not None:\n            logger.info(\"Running inference for collecting generated images...\")\n            pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)\n            pipeline.set_progress_bar_config(disable=True)\n\n            if args.seed is None:\n                generator = None\n            else:\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n            for i in range(len(args.validation_prompts)):\n                with torch.autocast(\"cuda\"):\n                    image = pipeline(\n                        args.validation_prompts[i],\n                        prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,\n                        generator=generator,\n                        width=args.resolution,\n                        height=args.resolution,\n                    ).images[0]\n                images.append(image)\n\n        if args.push_to_hub:\n            save_model_card(args, repo_id, images, repo_folder=args.output_dir)\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/server/README.md",
    "content": "\n# Create a server\n\nDiffusers' pipelines can be used as an inference engine for a server. It supports concurrent and multithreaded requests to generate images that may be requested by multiple users at the same time.\n\nThis guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.\n\n\nStart by navigating to the `examples/server` folder and installing all of the dependencies.\n\n```py\npip install diffusers\npip install -r requirements.txt\n```\n\nLaunch the server with the following command.\n\n```py\npython server.py\n```\n\nThe server is accessed at http://localhost:8000. You can curl this model with the following command.\n```\ncurl -X POST -H \"Content-Type: application/json\" --data '{\"model\": \"something\", \"prompt\": \"a kitten in front of a fireplace\"}' http://localhost:8000/v1/images/generations\n```\n\nIf you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.\n\n```\nuv pip compile requirements.in -o requirements.txt\n```\n\n\nThe server is built with [FastAPI](https://fastapi.tiangolo.com/async/). 
The endpoint for `v1/images/generations` is shown below.\n```py\n@app.post(\"/v1/images/generations\")\nasync def generate_image(image_input: TextToImageInput):\n    try:\n        loop = asyncio.get_event_loop()\n        scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)\n        pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)\n        generator = torch.Generator(device=shared_pipeline.device)\n        generator.manual_seed(random.randint(0, 10000000))\n        output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))\n        logger.info(f\"output: {output}\")\n        image_url = save_image(output.images[0])\n        return {\"data\": [{\"url\": image_url}]}\n    except Exception as e:\n        if isinstance(e, HTTPException):\n            raise e\n        elif hasattr(e, \"message\"):\n            raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())\n        raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())\n```\nThe `generate_image` function is declared with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows it won't necessarily return a result right away. When the function reaches a point where it must await another [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the main thread goes back to answering other HTTP requests. 
This is shown in the code below with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword.\n```py\noutput = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))\n```\nAt this point, the execution of the pipeline function is handed to a [worker thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor) from the default executor, and the main thread handles other requests until the `pipeline` returns a result.\n\nAnother important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. This avoids loading the underlying model onto the GPU more than once while still giving each request, which runs on a separate thread, its own generator and scheduler. The scheduler, in particular, is not thread-safe: sharing one scheduler across multiple threads causes errors like `IndexError: index 21 is out of bounds for dimension 0 with size 21`.\n"
  },
  {
    "path": "examples/server/requirements.in",
    "content": "torch~=2.7.0\ntransformers==4.46.1\nsentencepiece\naiohttp\npy-consul\nprometheus_client >= 0.18.0\nprometheus-fastapi-instrumentator >= 7.0.0\nfastapi\nuvicorn\naccelerate\n"
  },
  {
    "path": "examples/server/requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile requirements.in -o requirements.txt\naiohappyeyeballs==2.6.1\n    # via aiohttp\naiohttp==3.12.14\n    # via -r requirements.in\naiosignal==1.4.0\n    # via aiohttp\nannotated-types==0.7.0\n    # via pydantic\nanyio==4.6.2.post1\n    # via starlette\nasync-timeout==4.0.3\n    # via aiohttp\nattrs==24.2.0\n    # via aiohttp\ncertifi==2024.8.30\n    # via requests\ncharset-normalizer==3.4.0\n    # via requests\nclick==8.1.7\n    # via uvicorn\nexceptiongroup==1.3.0\n    # via anyio\nfastapi==0.115.3\n    # via -r requirements.in\nfilelock==3.16.1\n    # via\n    #   huggingface-hub\n    #   torch\n    #   transformers\nfrozenlist==1.5.0\n    # via\n    #   aiohttp\n    #   aiosignal\nfsspec==2024.10.0\n    # via\n    #   huggingface-hub\n    #   torch\nh11==0.14.0\n    # via uvicorn\nhuggingface-hub==0.35.0\n    # via\n    #   tokenizers\n    #   transformers\nidna==3.10\n    # via\n    #   anyio\n    #   requests\n    #   yarl\njinja2==3.1.4\n    # via torch\nmarkupsafe==3.0.2\n    # via jinja2\nmpmath==1.3.0\n    # via sympy\nmultidict==6.1.0\n    # via\n    #   aiohttp\n    #   yarl\nnetworkx==3.2.1\n    # via torch\nnumpy==2.0.2\n    # via transformers\nnvidia-cublas-cu12==12.6.4.1\n    # via\n    #   nvidia-cudnn-cu12\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cuda-cupti-cu12==12.6.80\n    # via torch\nnvidia-cuda-nvrtc-cu12==12.6.77\n    # via torch\nnvidia-cuda-runtime-cu12==12.6.77\n    # via torch\nnvidia-cudnn-cu12==9.5.1.17\n    # via torch\nnvidia-cufft-cu12==11.3.0.4\n    # via torch\nnvidia-cufile-cu12==1.11.1.6\n    # via torch\nnvidia-curand-cu12==10.3.7.77\n    # via torch\nnvidia-cusolver-cu12==11.7.1.2\n    # via torch\nnvidia-cusparse-cu12==12.5.4.2\n    # via\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cusparselt-cu12==0.6.3\n    # via torch\nnvidia-nccl-cu12==2.26.2\n    # via torch\nnvidia-nvjitlink-cu12==12.6.85\n    # via\n   
 #   nvidia-cufft-cu12\n    #   nvidia-cusolver-cu12\n    #   nvidia-cusparse-cu12\n    #   torch\nnvidia-nvtx-cu12==12.6.77\n    # via torch\npackaging==24.1\n    # via\n    #   huggingface-hub\n    #   transformers\nprometheus-client==0.21.0\n    # via\n    #   -r requirements.in\n    #   prometheus-fastapi-instrumentator\nprometheus-fastapi-instrumentator==7.0.0\n    # via -r requirements.in\npropcache==0.2.0\n    # via\n    #   aiohttp\n    #   yarl\npy-consul==1.5.3\n    # via -r requirements.in\npydantic==2.9.2\n    # via fastapi\npydantic-core==2.23.4\n    # via pydantic\npyyaml==6.0.2\n    # via\n    #   huggingface-hub\n    #   transformers\nregex==2024.9.11\n    # via transformers\nrequests==2.32.3\n    # via\n    #   huggingface-hub\n    #   py-consul\n    #   transformers\nsafetensors==0.4.5\n    # via transformers\nsentencepiece==0.2.0\n    # via -r requirements.in\nsniffio==1.3.1\n    # via anyio\nstarlette==0.41.0\n    # via\n    #   fastapi\n    #   prometheus-fastapi-instrumentator\nsympy==1.13.3\n    # via torch\ntokenizers==0.20.1\n    # via transformers\ntorch==2.7.0\n    # via -r requirements.in\ntqdm==4.66.5\n    # via\n    #   huggingface-hub\n    #   transformers\ntransformers==4.46.1\n    # via -r requirements.in\ntriton==3.3.0\n    # via torch\ntyping-extensions==4.12.2\n    # via\n    #   aiosignal\n    #   anyio\n    #   exceptiongroup\n    #   fastapi\n    #   huggingface-hub\n    #   multidict\n    #   pydantic\n    #   pydantic-core\n    #   starlette\n    #   torch\n    #   uvicorn\nurllib3==2.5.0\n    # via requests\nuvicorn==0.32.0\n    # via -r requirements.in\nyarl==1.18.3\n    # via aiohttp\n"
  },
  {
    "path": "examples/server/server.py",
    "content": "import asyncio\nimport logging\nimport os\nimport random\nimport tempfile\nimport traceback\nimport uuid\n\nimport aiohttp\nimport torch\nfrom fastapi import FastAPI, HTTPException\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom fastapi.staticfiles import StaticFiles\nfrom pydantic import BaseModel\n\nfrom diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass TextToImageInput(BaseModel):\n    model: str\n    prompt: str\n    size: str | None = None\n    n: int | None = None\n\n\nclass HttpClient:\n    session: aiohttp.ClientSession = None\n\n    def start(self):\n        self.session = aiohttp.ClientSession()\n\n    async def stop(self):\n        await self.session.close()\n        self.session = None\n\n    def __call__(self) -> aiohttp.ClientSession:\n        assert self.session is not None\n        return self.session\n\n\nclass TextToImagePipeline:\n    pipeline: StableDiffusion3Pipeline = None\n    device: str = None\n\n    def start(self):\n        if torch.cuda.is_available():\n            model_path = os.getenv(\"MODEL_PATH\", \"stabilityai/stable-diffusion-3.5-large\")\n            logger.info(\"Loading CUDA\")\n            self.device = \"cuda\"\n            self.pipeline = StableDiffusion3Pipeline.from_pretrained(\n                model_path,\n                torch_dtype=torch.bfloat16,\n            ).to(device=self.device)\n        elif torch.backends.mps.is_available():\n            model_path = os.getenv(\"MODEL_PATH\", \"stabilityai/stable-diffusion-3.5-medium\")\n            logger.info(\"Loading MPS for Mac M Series\")\n            self.device = \"mps\"\n            self.pipeline = StableDiffusion3Pipeline.from_pretrained(\n                model_path,\n                torch_dtype=torch.bfloat16,\n            ).to(device=self.device)\n        else:\n            raise Exception(\"No CUDA or MPS device available\")\n\n\napp = FastAPI()\nservice_url = 
os.getenv(\"SERVICE_URL\", \"http://localhost:8000\")\nimage_dir = os.path.join(tempfile.gettempdir(), \"images\")\nif not os.path.exists(image_dir):\n    os.makedirs(image_dir)\napp.mount(\"/images\", StaticFiles(directory=image_dir), name=\"images\")\nhttp_client = HttpClient()\nshared_pipeline = TextToImagePipeline()\n\n# Configure CORS settings\napp.add_middleware(\n    CORSMiddleware,\n    allow_origins=[\"*\"],  # Allows all origins\n    allow_credentials=True,\n    allow_methods=[\"*\"],  # Allows all methods, e.g., GET, POST, OPTIONS, etc.\n    allow_headers=[\"*\"],  # Allows all headers\n)\n\n\n@app.on_event(\"startup\")\ndef startup():\n    http_client.start()\n    shared_pipeline.start()\n\n\ndef save_image(image):\n    filename = \"draw\" + str(uuid.uuid4()).split(\"-\")[0] + \".png\"\n    image_path = os.path.join(image_dir, filename)\n    # write image to disk at image_path\n    logger.info(f\"Saving image to {image_path}\")\n    image.save(image_path)\n    return os.path.join(service_url, \"images\", filename)\n\n\n@app.get(\"/\")\n@app.post(\"/\")\n@app.options(\"/\")\nasync def base():\n    return \"Welcome to Diffusers! 
Where you can use diffusion models to generate images\"\n\n\n@app.post(\"/v1/images/generations\")\nasync def generate_image(image_input: TextToImageInput):\n    try:\n        loop = asyncio.get_event_loop()\n        scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)\n        pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)\n        generator = torch.Generator(device=shared_pipeline.device)\n        generator.manual_seed(random.randint(0, 10000000))\n        output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))\n        logger.info(f\"output: {output}\")\n        image_url = save_image(output.images[0])\n        return {\"data\": [{\"url\": image_url}]}\n    except Exception as e:\n        if isinstance(e, HTTPException):\n            raise e\n        elif hasattr(e, \"message\"):\n            raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())\n        raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())\n\n\nif __name__ == \"__main__\":\n    import uvicorn\n\n    uvicorn.run(app, host=\"0.0.0.0\", port=8000)\n"
  },
  {
    "path": "examples/server-async/Pipelines.py",
    "content": "import logging\nimport os\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nimport torch\nfrom pydantic import BaseModel\n\nfrom diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass TextToImageInput(BaseModel):\n    model: str\n    prompt: str\n    size: str | None = None\n    n: int | None = None\n\n\n@dataclass\nclass PresetModels:\n    SD3: List[str] = field(default_factory=lambda: [\"stabilityai/stable-diffusion-3-medium\"])\n    SD3_5: List[str] = field(\n        default_factory=lambda: [\n            \"stabilityai/stable-diffusion-3.5-large\",\n            \"stabilityai/stable-diffusion-3.5-large-turbo\",\n            \"stabilityai/stable-diffusion-3.5-medium\",\n        ]\n    )\n\n\nclass TextToImagePipelineSD3:\n    def __init__(self, model_path: str | None = None):\n        self.model_path = model_path or os.getenv(\"MODEL_PATH\")\n        self.pipeline: StableDiffusion3Pipeline | None = None\n        self.device: str | None = None\n\n    def start(self):\n        if torch.cuda.is_available():\n            model_path = self.model_path or \"stabilityai/stable-diffusion-3.5-large\"\n            logger.info(\"Loading CUDA\")\n            self.device = \"cuda\"\n            self.pipeline = StableDiffusion3Pipeline.from_pretrained(\n                model_path,\n                torch_dtype=torch.float16,\n            ).to(device=self.device)\n        elif torch.backends.mps.is_available():\n            model_path = self.model_path or \"stabilityai/stable-diffusion-3.5-medium\"\n            logger.info(\"Loading MPS for Mac M Series\")\n            self.device = \"mps\"\n            self.pipeline = StableDiffusion3Pipeline.from_pretrained(\n                model_path,\n                torch_dtype=torch.bfloat16,\n            ).to(device=self.device)\n        else:\n            raise Exception(\"No CUDA or MPS device 
available\")\n\n\nclass ModelPipelineInitializer:\n    def __init__(self, model: str = \"\", type_models: str = \"t2im\"):\n        self.model = model\n        self.type_models = type_models\n        self.pipeline = None\n        self.device = \"cuda\" if torch.cuda.is_available() else \"mps\"\n        self.model_type = None\n\n    def initialize_pipeline(self):\n        if not self.model:\n            raise ValueError(\"Model name not provided\")\n\n        # Check if model exists in PresetModels\n        preset_models = PresetModels()\n\n        # Determine which model type we're dealing with\n        if self.model in preset_models.SD3:\n            self.model_type = \"SD3\"\n        elif self.model in preset_models.SD3_5:\n            self.model_type = \"SD3_5\"\n\n        # Create appropriate pipeline based on model type and type_models\n        if self.type_models == \"t2im\":\n            if self.model_type in [\"SD3\", \"SD3_5\"]:\n                self.pipeline = TextToImagePipelineSD3(self.model)\n            else:\n                raise ValueError(f\"Model type {self.model_type} not supported for text-to-image\")\n        else:\n            # Only text-to-image (\"t2im\") is implemented; reject every other value instead of returning None\n            raise ValueError(f\"Unsupported type_models: {self.type_models}\")\n\n        return self.pipeline\n"
  },
  {
    "path": "examples/server-async/README.md",
    "content": "# Asynchronous server and parallel execution of models\n\n> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.\n> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)\n\n## ⚠️ IMPORTANT\n\n* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.\n\n## Necessary components\n\nAll the components needed to create the inference server are in the current directory:\n\n```\nserver-async/\n├── utils/\n├─────── __init__.py\n├─────── scheduler.py              # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences\n├─────── requestscopedpipeline.py  # RequestScoped Pipeline for inference with a single in-memory model\n├─────── utils.py                  # Image/video saving utilities and service configuration\n├── Pipelines.py                   # pipeline loader classes (SD3)\n├── serverasync.py                 # FastAPI app with lifespan management and async inference endpoints\n├── test.py                        # Client test script for inference requests\n├── requirements.txt               # Dependencies\n└── README.md                      # This documentation\n```\n\n## What `diffusers-async` adds / Why we needed it\n\nCore problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline 
per-request.\n\n`diffusers-async` / this example addresses that by:\n\n* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.\n* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.\n* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.\n* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.\n* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.\n* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).\n\n## How the server works (high-level flow)\n\n1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.\n2. 
On each HTTP inference request:\n\n   * The server uses `RequestScopedPipeline.generate(...)` which:\n\n     * automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),\n     * obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),\n     * does `local_pipe = copy.copy(base_pipe)` (shallow copy),\n     * sets `local_pipe.scheduler = local_scheduler` (if possible),\n     * clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,\n     * wraps tokenizers with thread-safe locks to prevent race conditions,\n     * optionally enters a `model_cpu_offload_context()` for memory offload hooks,\n     * calls the pipeline on the local view (`local_pipe(...)`).\n3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).\n4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.\n\n## How to set up and run the server\n\n### 1) Install dependencies\n\nRecommended: create a virtualenv / conda environment.\n\n```bash\npip install diffusers\npip install -r requirements.txt\n```\n\n### 2) Start the server\n\nUsing the `serverasync.py` file that already has everything you need:\n\n```bash\npython serverasync.py\n```\n\nThe server will start on `http://localhost:8500` by default with the following features:\n- FastAPI application with async lifespan management\n- Automatic model loading and pipeline initialization\n- Request counting and active inference tracking\n- Memory cleanup after each inference\n- CORS middleware for cross-origin requests\n\n### 3) Test the server\n\nUse the included test script:\n\n```bash\npython test.py\n```\n\nOr send a manual request:\n\n`POST /api/diffusers/inference` with JSON body:\n\n```json\n{\n  \"prompt\": \"A futuristic cityscape, vibrant colors\",\n  \"num_inference_steps\": 30,\n  \"num_images_per_prompt\": 1\n}\n```\n\nResponse 
example:\n\n```json\n{\n  \"response\": [\"http://localhost:8500/images/img123.png\"]\n}\n```\n\n### 4) Server endpoints\n\n- `GET /` - Welcome message\n- `POST /api/diffusers/inference` - Main inference endpoint\n- `GET /images/{filename}` - Serve generated images\n- `GET /api/status` - Server status and memory info\n\n## Advanced Configuration\n\n### RequestScopedPipeline Parameters\n\n```python\nRequestScopedPipeline(\n    pipeline,                        # Base pipeline to wrap\n    mutable_attrs=None,             # Custom list of attributes to clone\n    auto_detect_mutables=True,      # Enable automatic detection of mutable attributes\n    tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning\n    tokenizer_lock=None,            # Custom threading lock for tokenizers\n    wrap_scheduler=True             # Auto-wrap scheduler in BaseAsyncScheduler\n)\n```\n\n### BaseAsyncScheduler Features\n\n* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`\n* `clone_for_request()` method for safe per-request scheduler cloning\n* Enhanced debugging with `__repr__` and `__str__` methods\n* Full compatibility with existing scheduler APIs\n\n### Server Configuration\n\nThe server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:\n\n```python\n@dataclass\nclass ServerConfigModels:\n    model: str = 'stabilityai/stable-diffusion-3.5-medium'  \n    type_models: str = 't2im'  \n    host: str = '0.0.0.0' \n    port: int = 8500\n```\n\n## Troubleshooting (quick)\n\n* `Already borrowed` — previously a Rust tokenizer concurrency error.\n  ✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.\n\n* `can't set attribute 'components'` — pipeline exposes read-only `components`.\n  ✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.\n\n* Scheduler issues:\n  * If the 
scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log the failure and fall back; prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.\n  ✅ Note: `async_retrieve_timesteps` is fully backward compatible: if you don't pass `return_scheduler=True`, the behavior is unchanged.\n\n* Memory issues with large tensors:\n  ✅ The system now has a configurable `tensor_numel_threshold` that prevents cloning of large tensors while still cloning small mutable ones.\n\n* Automatic tokenizer detection:\n  ✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers."
  },
  {
    "path": "examples/server-async/requirements.txt",
    "content": "torch \ntorchvision \ntransformers \nsentencepiece \nfastapi \nuvicorn \nftfy\naccelerate\nxformers\nprotobuf"
  },
  {
    "path": "examples/server-async/serverasync.py",
    "content": "import asyncio\nimport gc\nimport logging\nimport os\nimport random\nimport threading\nfrom contextlib import asynccontextmanager\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Type\n\nimport torch\nfrom fastapi import FastAPI, HTTPException, Request\nfrom fastapi.concurrency import run_in_threadpool\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom fastapi.responses import FileResponse\nfrom Pipelines import ModelPipelineInitializer\nfrom pydantic import BaseModel\n\nfrom utils import RequestScopedPipeline, Utils\n\n\n@dataclass\nclass ServerConfigModels:\n    model: str = \"stabilityai/stable-diffusion-3.5-medium\"\n    type_models: str = \"t2im\"\n    constructor_pipeline: Optional[Type] = None\n    custom_pipeline: Optional[Type] = None\n    components: Optional[Dict[str, Any]] = None\n    torch_dtype: Optional[torch.dtype] = None\n    host: str = \"0.0.0.0\"\n    port: int = 8500\n\n\nserver_config = ServerConfigModels()\n\n\n@asynccontextmanager\nasync def lifespan(app: FastAPI):\n    logging.basicConfig(level=logging.INFO)\n    app.state.logger = logging.getLogger(\"diffusers-server\")\n    os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:128,expandable_segments:True\"\n    os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"0\"\n\n    app.state.total_requests = 0\n    app.state.active_inferences = 0\n    app.state.metrics_lock = asyncio.Lock()\n    app.state.metrics_task = None\n\n    app.state.utils_app = Utils(\n        host=server_config.host,\n        port=server_config.port,\n    )\n\n    async def metrics_loop():\n        try:\n            while True:\n                async with app.state.metrics_lock:\n                    total = app.state.total_requests\n                    active = app.state.active_inferences\n                app.state.logger.info(f\"[METRICS] total_requests={total} active_inferences={active}\")\n                await asyncio.sleep(5)\n        except asyncio.CancelledError:\n   
         app.state.logger.info(\"Metrics loop cancelled\")\n            raise\n\n    app.state.metrics_task = asyncio.create_task(metrics_loop())\n\n    try:\n        yield\n    finally:\n        task = app.state.metrics_task\n        if task:\n            task.cancel()\n            try:\n                await task\n            except asyncio.CancelledError:\n                pass\n\n        try:\n            stop_fn = getattr(model_pipeline, \"stop\", None) or getattr(model_pipeline, \"close\", None)\n            if callable(stop_fn):\n                await run_in_threadpool(stop_fn)\n        except Exception as e:\n            app.state.logger.warning(f\"Error during pipeline shutdown: {e}\")\n\n        app.state.logger.info(\"Lifespan shutdown complete\")\n\n\napp = FastAPI(lifespan=lifespan)\n\nlogger = logging.getLogger(\"DiffusersServer.Pipelines\")\n\n\ninitializer = ModelPipelineInitializer(\n    model=server_config.model,\n    type_models=server_config.type_models,\n)\nmodel_pipeline = initializer.initialize_pipeline()\nmodel_pipeline.start()\n\nrequest_pipe = RequestScopedPipeline(model_pipeline.pipeline)\npipeline_lock = threading.Lock()\n\nlogger.info(f\"Pipeline initialized and ready to receive requests (model ={server_config.model})\")\n\napp.state.MODEL_INITIALIZER = initializer\napp.state.MODEL_PIPELINE = model_pipeline\napp.state.REQUEST_PIPE = request_pipe\napp.state.PIPELINE_LOCK = pipeline_lock\n\n\nclass JSONBodyQueryAPI(BaseModel):\n    model: str | None = None\n    prompt: str\n    negative_prompt: str | None = None\n    num_inference_steps: int = 28\n    num_images_per_prompt: int = 1\n\n\n@app.middleware(\"http\")\nasync def count_requests_middleware(request: Request, call_next):\n    async with app.state.metrics_lock:\n        app.state.total_requests += 1\n    response = await call_next(request)\n    return response\n\n\n@app.get(\"/\")\nasync def root():\n    return {\"message\": \"Welcome to the Diffusers 
Server\"}\n\n\n@app.post(\"/api/diffusers/inference\")\nasync def api(json: JSONBodyQueryAPI):\n    prompt = json.prompt\n    negative_prompt = json.negative_prompt or \"\"\n    num_steps = json.num_inference_steps\n    num_images_per_prompt = json.num_images_per_prompt\n\n    wrapper = app.state.MODEL_PIPELINE\n    initializer = app.state.MODEL_INITIALIZER\n\n    utils_app = app.state.utils_app\n\n    if not wrapper or not wrapper.pipeline:\n        raise HTTPException(500, \"Model not initialized correctly\")\n    if not prompt.strip():\n        raise HTTPException(400, \"No prompt provided\")\n\n    def make_generator():\n        g = torch.Generator(device=initializer.device)\n        return g.manual_seed(random.randint(0, 10_000_000))\n\n    req_pipe = app.state.REQUEST_PIPE\n\n    def infer():\n        gen = make_generator()\n        return req_pipe.generate(\n            prompt=prompt,\n            negative_prompt=negative_prompt,\n            generator=gen,\n            num_inference_steps=num_steps,\n            num_images_per_prompt=num_images_per_prompt,\n            device=initializer.device,\n            output_type=\"pil\",\n        )\n\n    try:\n        async with app.state.metrics_lock:\n            app.state.active_inferences += 1\n\n        output = await run_in_threadpool(infer)\n\n        async with app.state.metrics_lock:\n            app.state.active_inferences = max(0, app.state.active_inferences - 1)\n\n        urls = [utils_app.save_image(img) for img in output.images]\n        return {\"response\": urls}\n\n    except Exception as e:\n        async with app.state.metrics_lock:\n            app.state.active_inferences = max(0, app.state.active_inferences - 1)\n        logger.error(f\"Error during inference: {e}\")\n        raise HTTPException(500, f\"Error in processing: {e}\")\n\n    finally:\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n            torch.cuda.empty_cache()\n            
torch.cuda.reset_peak_memory_stats()\n            torch.cuda.ipc_collect()\n        gc.collect()\n\n\n@app.get(\"/images/{filename}\")\nasync def serve_image(filename: str):\n    utils_app = app.state.utils_app\n    file_path = os.path.join(utils_app.image_dir, filename)\n    if not os.path.isfile(file_path):\n        raise HTTPException(status_code=404, detail=\"Image not found\")\n    return FileResponse(file_path, media_type=\"image/png\")\n\n\n@app.get(\"/api/status\")\nasync def get_status():\n    memory_info = {}\n    if torch.cuda.is_available():\n        memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB\n        memory_reserved = torch.cuda.memory_reserved() / 1024**3  # GB\n        memory_info = {\n            \"memory_allocated_gb\": round(memory_allocated, 2),\n            \"memory_reserved_gb\": round(memory_reserved, 2),\n            \"device\": torch.cuda.get_device_name(0),\n        }\n\n    return {\"current_model\": server_config.model, \"type_models\": server_config.type_models, \"memory\": memory_info}\n\n\napp.add_middleware(\n    CORSMiddleware,\n    allow_origins=[\"*\"],\n    allow_credentials=True,\n    allow_methods=[\"*\"],\n    allow_headers=[\"*\"],\n)\n\nif __name__ == \"__main__\":\n    import uvicorn\n\n    uvicorn.run(app, host=server_config.host, port=server_config.port)\n"
  },
  {
    "path": "examples/server-async/test.py",
    "content": "import os\nimport time\nimport urllib.parse\n\nimport requests\n\n\nSERVER_URL = \"http://localhost:8500/api/diffusers/inference\"\nBASE_URL = \"http://localhost:8500\"\nDOWNLOAD_FOLDER = \"generated_images\"\nWAIT_BEFORE_DOWNLOAD = 2  # seconds\n\nos.makedirs(DOWNLOAD_FOLDER, exist_ok=True)\n\n\ndef save_from_url(url: str) -> str:\n    \"\"\"Download the given URL (relative or absolute) and save it locally.\"\"\"\n    if url.startswith(\"/\"):\n        direct = BASE_URL.rstrip(\"/\") + url\n    else:\n        direct = url\n    resp = requests.get(direct, timeout=60)\n    resp.raise_for_status()\n    filename = os.path.basename(urllib.parse.urlparse(direct).path) or f\"img_{int(time.time())}.png\"\n    path = os.path.join(DOWNLOAD_FOLDER, filename)\n    with open(path, \"wb\") as f:\n        f.write(resp.content)\n    return path\n\n\ndef main():\n    payload = {\n        \"prompt\": \"The T-800 Terminator Robot Returning From The Future, Anime Style\",\n        \"num_inference_steps\": 30,\n        \"num_images_per_prompt\": 1,\n    }\n\n    print(\"Sending request...\")\n    try:\n        r = requests.post(SERVER_URL, json=payload, timeout=480)\n        r.raise_for_status()\n    except Exception as e:\n        print(f\"Request failed: {e}\")\n        return\n\n    body = r.json().get(\"response\", [])\n    # Normalize to a list\n    urls = body if isinstance(body, list) else [body] if body else []\n    if not urls:\n        print(\"No URLs found in the response. Check the server output.\")\n        return\n\n    print(f\"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...\")\n    time.sleep(WAIT_BEFORE_DOWNLOAD)\n\n    for u in urls:\n        try:\n            path = save_from_url(u)\n            print(f\"Image saved to: {path}\")\n        except Exception as e:\n            print(f\"Error downloading {u}: {e}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/server-async/utils/__init__.py",
    "content": "from .requestscopedpipeline import RequestScopedPipeline\nfrom .utils import Utils\n"
  },
  {
    "path": "examples/server-async/utils/requestscopedpipeline.py",
    "content": "import copy\nimport threading\nfrom typing import Any, Iterable, List, Optional\n\nimport torch\n\nfrom diffusers.utils import logging\n\nfrom .scheduler import BaseAsyncScheduler, async_retrieve_timesteps\nfrom .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass RequestScopedPipeline:\n    DEFAULT_MUTABLE_ATTRS = [\n        \"_all_hooks\",\n        \"_offload_device\",\n        \"_progress_bar_config\",\n        \"_progress_bar\",\n        \"_rng_state\",\n        \"_last_seed\",\n        \"latents\",\n    ]\n\n    def __init__(\n        self,\n        pipeline: Any,\n        mutable_attrs: Optional[Iterable[str]] = None,\n        auto_detect_mutables: bool = True,\n        tensor_numel_threshold: int = 1_000_000,\n        tokenizer_lock: Optional[threading.Lock] = None,\n        wrap_scheduler: bool = True,\n    ):\n        self._base = pipeline\n\n        self.unet = getattr(pipeline, \"unet\", None)\n        self.vae = getattr(pipeline, \"vae\", None)\n        self.text_encoder = getattr(pipeline, \"text_encoder\", None)\n        self.components = getattr(pipeline, \"components\", None)\n\n        self.transformer = getattr(pipeline, \"transformer\", None)\n\n        if wrap_scheduler and hasattr(pipeline, \"scheduler\") and pipeline.scheduler is not None:\n            if not isinstance(pipeline.scheduler, BaseAsyncScheduler):\n                pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)\n\n        self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)\n\n        self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()\n\n        self._vae_lock = threading.Lock()\n        self._image_lock = threading.Lock()\n\n        self._auto_detect_mutables = bool(auto_detect_mutables)\n        self._tensor_numel_threshold = int(tensor_numel_threshold)\n      
  self._auto_detected_attrs: List[str] = []\n\n    def _detect_kernel_pipeline(self, pipeline) -> bool:\n        kernel_indicators = [\n            \"text_encoding_cache\",\n            \"memory_manager\",\n            \"enable_optimizations\",\n            \"_create_request_context\",\n            \"get_optimization_stats\",\n        ]\n\n        return any(hasattr(pipeline, attr) for attr in kernel_indicators)\n\n    def _make_local_scheduler(self, num_inference_steps: int, device: str | None = None, **clone_kwargs):\n        base_sched = getattr(self._base, \"scheduler\", None)\n        if base_sched is None:\n            return None\n\n        if not isinstance(base_sched, BaseAsyncScheduler):\n            wrapped_scheduler = BaseAsyncScheduler(base_sched)\n        else:\n            wrapped_scheduler = base_sched\n\n        try:\n            return wrapped_scheduler.clone_for_request(\n                num_inference_steps=num_inference_steps, device=device, **clone_kwargs\n            )\n        except Exception as e:\n            logger.debug(f\"clone_for_request failed: {e}; trying shallow copy fallback\")\n            try:\n                if hasattr(wrapped_scheduler, \"scheduler\"):\n                    try:\n                        copied_scheduler = copy.copy(wrapped_scheduler.scheduler)\n                        return BaseAsyncScheduler(copied_scheduler)\n                    except Exception:\n                        return wrapped_scheduler\n                else:\n                    copied_scheduler = copy.copy(wrapped_scheduler)\n                    return BaseAsyncScheduler(copied_scheduler)\n            except Exception as e2:\n                logger.warning(\n                    f\"Shallow copy of scheduler also failed: {e2}. 
Using original scheduler (*thread-unsafe but functional*).\"\n                )\n                return wrapped_scheduler\n\n    def _autodetect_mutables(self, max_attrs: int = 40):\n        if not self._auto_detect_mutables:\n            return []\n\n        if self._auto_detected_attrs:\n            return self._auto_detected_attrs\n\n        candidates: List[str] = []\n        seen = set()\n\n        for name in dir(self._base):\n            if name.startswith(\"__\"):\n                continue\n            if name in self._mutable_attrs:\n                continue\n            if name in (\"to\", \"save_pretrained\", \"from_pretrained\"):\n                continue\n\n            try:\n                val = getattr(self._base, name)\n            except Exception:\n                continue\n\n            import types\n\n            if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):\n                continue\n\n            if isinstance(val, (dict, list, set, tuple, bytearray)):\n                candidates.append(name)\n                seen.add(name)\n            else:\n                # try Tensor detection\n                try:\n                    if isinstance(val, torch.Tensor):\n                        if val.numel() <= self._tensor_numel_threshold:\n                            candidates.append(name)\n                            seen.add(name)\n                        else:\n                            logger.debug(f\"Ignoring large tensor attr '{name}', numel={val.numel()}\")\n                except Exception:\n                    continue\n\n            if len(candidates) >= max_attrs:\n                break\n\n        self._auto_detected_attrs = candidates\n        logger.debug(f\"Autodetected mutable attrs to clone: {self._auto_detected_attrs}\")\n        return self._auto_detected_attrs\n\n    def _is_readonly_property(self, base_obj, attr_name: str) -> bool:\n        try:\n            cls = type(base_obj)\n   
         descriptor = getattr(cls, attr_name, None)\n            if isinstance(descriptor, property):\n                return descriptor.fset is None\n            if hasattr(descriptor, \"__set__\") is False and descriptor is not None:\n                return False\n        except Exception:\n            pass\n        return False\n\n    def _clone_mutable_attrs(self, base, local):\n        attrs_to_clone = list(self._mutable_attrs)\n        attrs_to_clone.extend(self._autodetect_mutables())\n\n        EXCLUDE_ATTRS = {\n            \"components\",\n        }\n\n        for attr in attrs_to_clone:\n            if attr in EXCLUDE_ATTRS:\n                logger.debug(f\"Skipping excluded attr '{attr}'\")\n                continue\n            if not hasattr(base, attr):\n                continue\n            if self._is_readonly_property(base, attr):\n                logger.debug(f\"Skipping read-only property '{attr}'\")\n                continue\n\n            try:\n                val = getattr(base, attr)\n            except Exception as e:\n                logger.debug(f\"Could not getattr('{attr}') on base pipeline: {e}\")\n                continue\n\n            try:\n                if isinstance(val, dict):\n                    setattr(local, attr, dict(val))\n                elif isinstance(val, (list, tuple, set)):\n                    setattr(local, attr, list(val))\n                elif isinstance(val, bytearray):\n                    setattr(local, attr, bytearray(val))\n                else:\n                    # small tensors or atomic values\n                    if isinstance(val, torch.Tensor):\n                        if val.numel() <= self._tensor_numel_threshold:\n                            setattr(local, attr, val.clone())\n                        else:\n                            # don't clone big tensors, keep reference\n                            setattr(local, attr, val)\n                    else:\n                        try:\n          
                  setattr(local, attr, copy.copy(val))\n                        except Exception:\n                            setattr(local, attr, val)\n            except (AttributeError, TypeError) as e:\n                logger.debug(f\"Skipping cloning attribute '{attr}' because it is not settable: {e}\")\n                continue\n            except Exception as e:\n                logger.debug(f\"Unexpected error cloning attribute '{attr}': {e}\")\n                continue\n\n    def _is_tokenizer_component(self, component) -> bool:\n        if component is None:\n            return False\n\n        tokenizer_methods = [\"encode\", \"decode\", \"tokenize\", \"__call__\"]\n        has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)\n\n        class_name = component.__class__.__name__.lower()\n        has_tokenizer_in_name = \"tokenizer\" in class_name\n\n        tokenizer_attrs = [\"vocab_size\", \"pad_token\", \"eos_token\", \"bos_token\"]\n        has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)\n\n        return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)\n\n    def _should_wrap_tokenizers(self) -> bool:\n        return True\n\n    def generate(self, *args, num_inference_steps: int = 50, device: str | None = None, **kwargs):\n        local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)\n\n        try:\n            local_pipe = copy.copy(self._base)\n        except Exception as e:\n            logger.warning(f\"copy.copy(self._base) failed: {e}. 
Falling back to deepcopy (may increase memory).\")\n            local_pipe = copy.deepcopy(self._base)\n\n        try:\n            if (\n                hasattr(local_pipe, \"vae\")\n                and local_pipe.vae is not None\n                and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)\n            ):\n                local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)\n\n            if (\n                hasattr(local_pipe, \"image_processor\")\n                and local_pipe.image_processor is not None\n                and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)\n            ):\n                local_pipe.image_processor = ThreadSafeImageProcessorWrapper(\n                    local_pipe.image_processor, self._image_lock\n                )\n        except Exception as e:\n            logger.debug(f\"Could not wrap vae/image_processor: {e}\")\n\n        if local_scheduler is not None:\n            try:\n                timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(\n                    local_scheduler.scheduler,\n                    num_inference_steps=num_inference_steps,\n                    device=device,\n                    return_scheduler=True,\n                    **{k: v for k, v in kwargs.items() if k in [\"timesteps\", \"sigmas\"]},\n                )\n\n                final_scheduler = BaseAsyncScheduler(configured_scheduler)\n                setattr(local_pipe, \"scheduler\", final_scheduler)\n            except Exception:\n                logger.warning(\"Could not set scheduler on local pipe; proceeding without replacing scheduler.\")\n\n        self._clone_mutable_attrs(self._base, local_pipe)\n\n        original_tokenizers = {}\n\n        if self._should_wrap_tokenizers():\n            try:\n                for name in dir(local_pipe):\n                    if \"tokenizer\" in name and not name.startswith(\"_\"):\n                        tok = 
getattr(local_pipe, name, None)\n                        if tok is not None and self._is_tokenizer_component(tok):\n                            if not isinstance(tok, ThreadSafeTokenizerWrapper):\n                                original_tokenizers[name] = tok\n                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)\n                                setattr(local_pipe, name, wrapped_tokenizer)\n\n                if hasattr(local_pipe, \"components\") and isinstance(local_pipe.components, dict):\n                    for key, val in local_pipe.components.items():\n                        if val is None:\n                            continue\n\n                        if self._is_tokenizer_component(val):\n                            if not isinstance(val, ThreadSafeTokenizerWrapper):\n                                original_tokenizers[f\"components[{key}]\"] = val\n                                wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)\n                                local_pipe.components[key] = wrapped_tokenizer\n\n            except Exception as e:\n                logger.debug(f\"Tokenizer wrapping step encountered an error: {e}\")\n\n        result = None\n        cm = getattr(local_pipe, \"model_cpu_offload_context\", None)\n\n        try:\n            if callable(cm):\n                try:\n                    with cm():\n                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)\n                except TypeError:\n                    try:\n                        with cm:\n                            result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)\n                    except Exception as e:\n                        logger.debug(f\"model_cpu_offload_context usage failed: {e}. 
Proceeding without it.\")\n                        result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)\n            else:\n                result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)\n\n            return result\n\n        finally:\n            try:\n                for name, tok in original_tokenizers.items():\n                    if name.startswith(\"components[\"):\n                        key = name[len(\"components[\") : -1]\n                        if hasattr(local_pipe, \"components\") and isinstance(local_pipe.components, dict):\n                            local_pipe.components[key] = tok\n                    else:\n                        setattr(local_pipe, name, tok)\n            except Exception as e:\n                logger.debug(f\"Error restoring original tokenizers: {e}\")\n"
  },
  {
    "path": "examples/server-async/utils/scheduler.py",
    "content": "import copy\nimport inspect\nfrom typing import Any, List, Optional, Union\n\nimport torch\n\n\nclass BaseAsyncScheduler:\n    def __init__(self, scheduler: Any):\n        self.scheduler = scheduler\n\n    def __getattr__(self, name: str):\n        if hasattr(self.scheduler, name):\n            return getattr(self.scheduler, name)\n        raise AttributeError(f\"'{self.__class__.__name__}' object has no attribute '{name}'\")\n\n    def __setattr__(self, name: str, value):\n        if name == \"scheduler\":\n            super().__setattr__(name, value)\n        else:\n            if hasattr(self, \"scheduler\") and hasattr(self.scheduler, name):\n                setattr(self.scheduler, name, value)\n            else:\n                super().__setattr__(name, value)\n\n    def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):\n        local = copy.deepcopy(self.scheduler)\n        local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)\n        cloned = self.__class__(local)\n        return cloned\n\n    def __repr__(self):\n        return f\"BaseAsyncScheduler({repr(self.scheduler)})\"\n\n    def __str__(self):\n        return f\"BaseAsyncScheduler wrapping: {str(self.scheduler)}\"\n\n\ndef async_retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.\n    Handles custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Backwards compatible: by default the function behaves exactly as before and returns\n        (timesteps_tensor, num_inference_steps)\n\n    If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed\n    scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)\n    or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:\n        (timesteps_tensor, num_inference_steps, scheduler_in_use)\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Optional kwargs:\n        return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)\n            where `scheduler_in_use` is a scheduler instance that already has timesteps set.\n            This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.\n\n    Returns:\n        `(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or\n        `(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.\n    \"\"\"\n    # pop our optional control kwarg (keeps compatibility)\n    return_scheduler = bool(kwargs.pop(\"return_scheduler\", False))\n\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n\n    # choose scheduler to call set_timesteps on\n    scheduler_in_use = scheduler\n    if return_scheduler:\n        # Do not mutate the provided scheduler: prefer to clone if possible\n        if hasattr(scheduler, \"clone_for_request\"):\n            try:\n                # clone_for_request may accept num_inference_steps or other kwargs; be permissive\n                scheduler_in_use = scheduler.clone_for_request(\n                    num_inference_steps=num_inference_steps or 0, device=device\n                )\n            except Exception:\n                scheduler_in_use = copy.deepcopy(scheduler)\n        else:\n            # fallback deepcopy (scheduler tends to be smallish - acceptable)\n            scheduler_in_use = copy.deepcopy(scheduler)\n\n    # helper to test if set_timesteps supports a particular kwarg\n    def _accepts(param_name: str) -> bool:\n        try:\n            return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())\n        except 
(ValueError, TypeError):\n            # if signature introspection fails, be permissive and attempt the call later\n            return False\n\n    # now call set_timesteps on the chosen scheduler_in_use (may be original or clone)\n    if timesteps is not None:\n        accepts_timesteps = _accepts(\"timesteps\")\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps_out = scheduler_in_use.timesteps\n        num_inference_steps = len(timesteps_out)\n    elif sigmas is not None:\n        accept_sigmas = _accepts(\"sigmas\")\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps_out = scheduler_in_use.timesteps\n        num_inference_steps = len(timesteps_out)\n    else:\n        # default path\n        scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps_out = scheduler_in_use.timesteps\n\n    if return_scheduler:\n        return timesteps_out, num_inference_steps, scheduler_in_use\n    return timesteps_out, num_inference_steps\n"
  },
  {
    "path": "examples/server-async/utils/utils.py",
    "content": "import gc\nimport logging\nimport os\nimport tempfile\nimport uuid\n\nimport torch\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass Utils:\n    def __init__(self, host: str = \"0.0.0.0\", port: int = 8500):\n        self.service_url = f\"http://{host}:{port}\"\n        # exist_ok=True avoids a race between the existence check and the directory creation\n        self.image_dir = os.path.join(tempfile.gettempdir(), \"images\")\n        os.makedirs(self.image_dir, exist_ok=True)\n\n        self.video_dir = os.path.join(tempfile.gettempdir(), \"videos\")\n        os.makedirs(self.video_dir, exist_ok=True)\n\n    def save_image(self, image):\n        if hasattr(image, \"to\"):\n            try:\n                image = image.to(\"cpu\")\n            except Exception:\n                pass\n\n        if isinstance(image, torch.Tensor):\n            from torchvision import transforms\n\n            to_pil = transforms.ToPILImage()\n            image = to_pil(image.squeeze(0).clamp(0, 1))\n\n        filename = \"img\" + str(uuid.uuid4()).split(\"-\")[0] + \".png\"\n        image_path = os.path.join(self.image_dir, filename)\n        logger.info(f\"Saving image to {image_path}\")\n\n        image.save(image_path, format=\"PNG\", optimize=True)\n\n        del image\n        gc.collect()\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n        # Build the URL with \"/\" explicitly; os.path.join would insert backslashes on Windows\n        return f\"{self.service_url}/images/{filename}\"\n"
  },
  {
    "path": "examples/server-async/utils/wrappers.py",
    "content": "class ThreadSafeTokenizerWrapper:\n    def __init__(self, tokenizer, lock):\n        self._tokenizer = tokenizer\n        self._lock = lock\n\n        self._thread_safe_methods = {\n            \"__call__\",\n            \"encode\",\n            \"decode\",\n            \"tokenize\",\n            \"encode_plus\",\n            \"batch_encode_plus\",\n            \"batch_decode\",\n        }\n\n    def __getattr__(self, name):\n        attr = getattr(self._tokenizer, name)\n\n        if name in self._thread_safe_methods and callable(attr):\n\n            def wrapped_method(*args, **kwargs):\n                with self._lock:\n                    return attr(*args, **kwargs)\n\n            return wrapped_method\n\n        return attr\n\n    def __call__(self, *args, **kwargs):\n        with self._lock:\n            return self._tokenizer(*args, **kwargs)\n\n    def __setattr__(self, name, value):\n        if name.startswith(\"_\"):\n            super().__setattr__(name, value)\n        else:\n            setattr(self._tokenizer, name, value)\n\n    def __dir__(self):\n        return dir(self._tokenizer)\n\n\nclass ThreadSafeVAEWrapper:\n    def __init__(self, vae, lock):\n        self._vae = vae\n        self._lock = lock\n\n    def __getattr__(self, name):\n        attr = getattr(self._vae, name)\n        if name in {\"decode\", \"encode\", \"forward\"} and callable(attr):\n\n            def wrapped(*args, **kwargs):\n                with self._lock:\n                    return attr(*args, **kwargs)\n\n            return wrapped\n        return attr\n\n    def __setattr__(self, name, value):\n        if name.startswith(\"_\"):\n            super().__setattr__(name, value)\n        else:\n            setattr(self._vae, name, value)\n\n\nclass ThreadSafeImageProcessorWrapper:\n    def __init__(self, proc, lock):\n        self._proc = proc\n        self._lock = lock\n\n    def __getattr__(self, name):\n        attr = getattr(self._proc, name)\n        if 
name in {\"postprocess\", \"preprocess\"} and callable(attr):\n\n            def wrapped(*args, **kwargs):\n                with self._lock:\n                    return attr(*args, **kwargs)\n\n            return wrapped\n        return attr\n\n    def __setattr__(self, name, value):\n        if name.startswith(\"_\"):\n            super().__setattr__(name, value)\n        else:\n            setattr(self._proc, name, value)\n"
  },
  {
    "path": "examples/t2i_adapter/README.md",
    "content": "We don't yet support training T2I-Adapters on Stable Diffusion. For training T2I-Adapters on Stable Diffusion XL, refer to [this guide](./README_sdxl.md)."
  },
  {
    "path": "examples/t2i_adapter/README_sdxl.md",
    "content": "# T2I-Adapter training example for Stable Diffusion XL (SDXL)\n\nThe `train_t2i_adapter_sdxl.py` script shows how to implement the [T2I-Adapter training procedure](https://hf.co/papers/2302.08453) for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/t2i_adapter` folder and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\n\n## Circle filling dataset\n\nThe original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.\n\n## Training\n\nOur training examples use two test conditioning images. 
They can be downloaded by running:\n\n```sh\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png\n\nwget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png\n```\n\nThen run `hf auth login` to log into your Hugging Face account. This is needed to push the trained T2IAdapter parameters to the Hugging Face Hub.\n\n```bash\nexport MODEL_DIR=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport OUTPUT_DIR=\"path to save model\"\n\naccelerate launch train_t2i_adapter_sdxl.py \\\n --pretrained_model_name_or_path=$MODEL_DIR \\\n --output_dir=$OUTPUT_DIR \\\n --dataset_name=fusing/fill50k \\\n --mixed_precision=\"fp16\" \\\n --resolution=1024 \\\n --learning_rate=1e-5 \\\n --max_train_steps=15000 \\\n --validation_image \"./conditioning_image_1.png\" \"./conditioning_image_2.png\" \\\n --validation_prompt \"red circle with blue background\" \"cyan circle with brown floral background\" \\\n --validation_steps=100 \\\n --train_batch_size=1 \\\n --gradient_accumulation_steps=4 \\\n --report_to=\"wandb\" \\\n --seed=42 \\\n --push_to_hub\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `report_to=\"wandb\"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.\n* `validation_image`, `validation_prompt`, and `validation_steps` allow the script to do a few validation inference runs. 
This allows us to qualitatively check if the training is progressing as expected.\n\nOur experiments were conducted on a single 40GB A100 GPU.\n\n### Inference\n\nOnce training is done, we can perform inference like so:\n\n```python\nfrom diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler\nfrom diffusers.utils import load_image\nimport torch\n\nbase_model_path = \"stabilityai/stable-diffusion-xl-base-1.0\"\nadapter_path = \"path to adapter\"\n\nadapter = T2IAdapter.from_pretrained(adapter_path, torch_dtype=torch.float16)\npipe = StableDiffusionXLAdapterPipeline.from_pretrained(\n    base_model_path, adapter=adapter, torch_dtype=torch.float16\n)\n\n# speed up diffusion process with faster scheduler and memory optimization\npipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)\n# remove following line if xformers is not installed or when using Torch 2.0.\npipe.enable_xformers_memory_efficient_attention()\n# memory optimization.\npipe.enable_model_cpu_offload()\n\ncontrol_image = load_image(\"./conditioning_image_1.png\")\nprompt = \"pale golden rod circle with old lace background\"\n\n# generate image\ngenerator = torch.manual_seed(0)\nimage = pipe(\n    prompt, num_inference_steps=20, generator=generator, image=control_image\n).images[0]\nimage.save(\"./output.png\")\n```\n\n## Notes\n\n### Specifying a better VAE\n\nSDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).\n"
  },
  {
    "path": "examples/t2i_adapter/requirements.txt",
    "content": "transformers>=4.25.1\naccelerate>=0.16.0\nsafetensors\ndatasets\ntorchvision\nftfy\ntensorboard\nwandb"
  },
  {
    "path": "examples/t2i_adapter/test_t2i_adapter.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass T2IAdapter(ExamplesTestsAccelerate):\n    def test_t2i_adapter_sdxl(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n            examples/t2i_adapter/train_t2i_adapter_sdxl.py\n            --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe\n            --adapter_model_name_or_path=hf-internal-testing/tiny-adapter\n            --dataset_name=hf-internal-testing/fill10\n            --output_dir={tmpdir}\n            --resolution=64\n            --train_batch_size=1\n            --gradient_accumulation_steps=1\n            --max_train_steps=9\n            --checkpointing_steps=2\n            \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"diffusion_pytorch_model.safetensors\")))\n"
  },
  {
    "path": "examples/t2i_adapter/train_t2i_adapter_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport functools\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    EulerDiscreteScheduler,\n    StableDiffusionXLAdapterPipeline,\n    T2IAdapter,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nMAX_SEQ_LENGTH = 77\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows * cols\n\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\n\ndef log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):\n    logger.info(\"Running validation... \")\n\n    adapter = accelerator.unwrap_model(adapter)\n\n    pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        unet=unet,\n        adapter=adapter,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n\n    for validation_prompt, validation_image in zip(validation_prompts, 
validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n        validation_image = validation_image.resize((args.resolution, args.resolution))\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with torch.autocast(\"cuda\"):\n                image = pipeline(\n                    prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator\n                ).images[0]\n            images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = [np.asarray(validation_image)]\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"adapter conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({\"validation\": formatted_images})\n        else:\n            logger.warning(f\"image logging not 
implemented for {tracker.name}\")\n\n    # Clean up and return once, after every tracker has logged its images.\n    del pipeline\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    return image_logs\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef save_model_card(repo_id: str, image_logs: dict = None, base_model: str = None, repo_folder: str = None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i}](./images_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# t2iadapter-{repo_id}\n\nThese are t2iadapter weights trained on {base_model} with new type of conditioning.\n{img_str}\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        
base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"t2iadapter\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--adapter_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained adapter model or model identifier from huggingface.co/models.\"\n        \" If not specified adapter weights are initialized w.r.t the configurations of SDXL.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be\"\n            \" float32 precision.\"\n        ),\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"t2iadapter-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--detection_resolution\",\n        type=int,\n        default=None,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_h\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--crops_coords_top_left_w\",\n        type=int,\n        default=0,\n        help=(\"Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet.\"),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n   
     \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=3,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=1,\n        help=(\"Number of subprocesses to use for data loading.\"),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        
\"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. 
More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the adapter conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the t2iadapter conditioning image be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. 
Provide either a matching number of `--validation_prompt`s, a\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"sd_xl_train_t2iadapter\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Specify either `--dataset_name` or `--train_data_dir`\")\n\n    if args.dataset_name is not None and args.train_data_dir is not None:\n        raise ValueError(\"Specify only one of `--dataset_name` or `--train_data_dir`\")\n\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    if args.validation_prompt is not None and args.validation_image is None:\n        raise ValueError(\"`--validation_image` must be set 
if `--validation_prompt` is set\")\n\n    if args.validation_prompt is None and args.validation_image is not None:\n        raise ValueError(\"`--validation_prompt` must be set if `--validation_image` is set\")\n\n    if (\n        args.validation_image is not None\n        and args.validation_prompt is not None\n        and len(args.validation_image) != 1\n        and len(args.validation_prompt) != 1\n        and len(args.validation_image) != len(args.validation_prompt)\n    ):\n        raise ValueError(\n            \"Must provide either 1 `--validation_image`, 1 `--validation_prompt`,\"\n            \" or the same number of `--validation_prompt`s and `--validation_image`s\"\n        )\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the t2iadapter encoder.\"\n        )\n\n    return args\n\n\ndef get_train_dataset(args, accelerator):\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n        )\n    else:\n        if args.train_data_dir is not None:\n            dataset = load_dataset(\n                args.train_data_dir,\n                cache_dir=args.cache_dir,\n            )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = 
dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    if args.image_column is None:\n        image_column = column_names[0]\n        logger.info(f\"image column defaulting to {image_column}\")\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.caption_column is None:\n        caption_column = column_names[1]\n        logger.info(f\"caption column defaulting to {caption_column}\")\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    if args.conditioning_image_column is None:\n        conditioning_image_column = column_names[2]\n        logger.info(f\"conditioning image column defaulting to {conditioning_image_column}\")\n    else:\n        conditioning_image_column = args.conditioning_image_column\n        if conditioning_image_column not in column_names:\n            raise ValueError(\n                f\"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}\"\n            )\n\n    with accelerator.main_process_first():\n        train_dataset = dataset[\"train\"].shuffle(seed=args.seed)\n        if args.max_train_samples is not None:\n            train_dataset = train_dataset.select(range(args.max_train_samples))\n    return train_dataset\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):\n    prompt_embeds_list = []\n\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n            )\n\n            # We are only ALWAYS interested in the pooled output of the final text encoder\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds.hidden_states[-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = 
pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef prepare_train_dataset(dataset, accelerator):\n    image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    conditioning_image_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution),\n            transforms.ToTensor(),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[args.image_column]]\n        images = [image_transforms(image) for image in images]\n\n        conditioning_images = [image.convert(\"RGB\") for image in examples[args.conditioning_image_column]]\n        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]\n\n        examples[\"pixel_values\"] = images\n        examples[\"conditioning_pixel_values\"] = conditioning_images\n\n        return examples\n\n    with accelerator.main_process_first():\n        dataset = dataset.with_transform(preprocess_train)\n\n    return dataset\n\n\ndef collate_fn(examples):\n    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    conditioning_pixel_values = torch.stack([example[\"conditioning_pixel_values\"] for example in examples])\n    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    prompt_ids = torch.stack([torch.tensor(example[\"prompt_embeds\"]) for example in examples])\n\n    add_text_embeds = 
torch.stack([torch.tensor(example[\"text_embeds\"]) for example in examples])\n    add_time_ids = torch.stack([torch.tensor(example[\"time_ids\"]) for example in examples])\n\n    return {\n        \"pixel_values\": pixel_values,\n        \"conditioning_pixel_values\": conditioning_pixel_values,\n        \"prompt_ids\": prompt_ids,\n        \"unet_added_conditions\": {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids},\n    }\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        
set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name,\n                exist_ok=True,\n                token=args.hub_token,\n                private=True,\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n      
  vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    if args.adapter_model_name_or_path:\n        logger.info(\"Loading existing adapter weights.\")\n        t2iadapter = T2IAdapter.from_pretrained(args.adapter_model_name_or_path)\n    else:\n        logger.info(\"Initializing t2iadapter weights.\")\n        t2iadapter = T2IAdapter(\n            in_channels=3,\n            channels=(320, 640, 1280, 1280),\n            num_res_blocks=2,\n            downscale_factor=16,\n            adapter_type=\"full_adapter_xl\",\n        )\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            i = len(weights) - 1\n\n            while len(weights) > 0:\n                weights.pop()\n                model = models[i]\n\n                sub_dir = \"t2iadapter\"\n                model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                i -= 1\n\n        def load_model_hook(models, input_dir):\n            while len(models) > 0:\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = T2IAdapter.from_pretrained(os.path.join(input_dir, \"t2iadapter\"))\n\n                if args.control_type != \"style\":\n                    model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n      
  accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    t2iadapter.train()\n    unet.train()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n\n    if unwrap_model(t2iadapter).dtype != torch.float32:\n        raise ValueError(\n            f\"Controlnet loaded as datatype {unwrap_model(t2iadapter).dtype}. 
{low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = t2iadapter.parameters()\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae, unet and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    if args.pretrained_vae_model_name_or_path is not None:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    else:\n        vae.to(accelerator.device, dtype=torch.float32)\n    unet.to(accelerator.device, 
dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # Here, we compute not just the text embeddings but also the additional embeddings\n    # needed for the SD XL UNet to operate.\n    def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True):\n        original_size = (args.resolution, args.resolution)\n        target_size = (args.resolution, args.resolution)\n        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)\n        prompt_batch = batch[args.caption_column]\n\n        prompt_embeds, pooled_prompt_embeds = encode_prompt(\n            prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train\n        )\n        add_text_embeds = pooled_prompt_embeds\n\n        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n\n        prompt_embeds = prompt_embeds.to(accelerator.device)\n        add_text_embeds = add_text_embeds.to(accelerator.device)\n        add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)\n        unet_added_cond_kwargs = {\"text_embeds\": add_text_embeds, \"time_ids\": add_time_ids}\n\n        return {\"prompt_embeds\": prompt_embeds, **unet_added_cond_kwargs}\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < 
n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    # Let's first compute all the embeddings so that we can free up the text encoders\n    # from memory.\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n    train_dataset = get_train_dataset(args, accelerator)\n    compute_embeddings_fn = functools.partial(\n        compute_embeddings,\n        proportion_empty_prompts=args.proportion_empty_prompts,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n    )\n    with accelerator.main_process_first():\n        from datasets.fingerprint import Hasher\n\n        # fingerprint used by the cache for the other processes to load the result\n        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401\n        new_fingerprint = Hasher.hash(args)\n        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)\n\n    # Then get the training dataset ready to be passed to the dataloader.\n    train_dataset = prepare_train_dataset(train_dataset, accelerator)\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps,\n        num_training_steps=args.max_train_steps,\n        num_cycles=args.lr_num_cycles,\n        
power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    t2iadapter, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        t2iadapter, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n\n        # tensorboard cannot handle list types for config\n        tracker_config.pop(\"validation_prompt\")\n        tracker_config.pop(\"validation_image\")\n\n        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    image_logs = None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(t2iadapter):\n                if args.pretrained_vae_model_name_or_path is not None:\n                    
pixel_values = batch[\"pixel_values\"].to(dtype=weight_dtype)\n                else:\n                    pixel_values = batch[\"pixel_values\"]\n\n                # encode pixel values with batch size of at most 8 to avoid OOM\n                latents = []\n                for i in range(0, pixel_values.shape[0], 8):\n                    latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())\n                latents = torch.cat(latents, dim=0)\n                latents = latents * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    latents = latents.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n\n                # Cubic sampling to sample a random timestep for each image.\n                # For more details about why cubic sampling is used, refer to section 3.4 of https://huggingface.co/papers/2302.08453\n                timesteps = torch.rand((bsz,), device=latents.device)\n                timesteps = (1 - timesteps**3) * noise_scheduler.config.num_train_timesteps\n                timesteps = timesteps.long().to(noise_scheduler.timesteps.dtype)\n                timesteps = timesteps.clamp(0, noise_scheduler.config.num_train_timesteps - 1)\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Scale the noisy latents for the UNet\n                sigmas = get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)\n                inp_noisy_latents = noisy_latents / ((sigmas**2 + 1) ** 0.5)\n\n                # Adapter conditioning.\n                t2iadapter_image = batch[\"conditioning_pixel_values\"].to(dtype=weight_dtype)\n                
down_block_additional_residuals = t2iadapter(t2iadapter_image)\n                down_block_additional_residuals = [\n                    sample.to(dtype=weight_dtype) for sample in down_block_additional_residuals\n                ]\n\n                # Predict the noise residual\n                model_pred = unet(\n                    inp_noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=batch[\"prompt_ids\"],\n                    added_cond_kwargs=batch[\"unet_added_conditions\"],\n                    down_block_additional_residuals=down_block_additional_residuals,\n                    return_dict=False,\n                )[0]\n\n                # Denoise the latents\n                denoised_latents = model_pred * (-sigmas) + noisy_latents\n                weighing = sigmas**-2.0\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = latents  # we are computing loss against denoise latents\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # MSE loss\n                loss = torch.mean(\n                    (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),\n                    dim=1,\n                )\n                loss = loss.mean()\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = t2iadapter.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                
optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        
accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        image_logs = log_validation(\n                            vae,\n                            unet,\n                            t2iadapter,\n                            args,\n                            accelerator,\n                            weight_dtype,\n                            global_step,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        t2iadapter = unwrap_model(t2iadapter)\n        t2iadapter.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                image_logs=image_logs,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/test_examples_utils.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport shutil\nimport subprocess\nimport tempfile\nimport unittest\nfrom typing import List\n\nfrom accelerate.utils import write_basic_config\n\n\n# These utils relate to ensuring the right error message is received when running scripts\nclass SubprocessCallException(Exception):\n    pass\n\n\ndef run_command(command: List[str], return_stdout=False):\n    \"\"\"\n    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. 
Will also properly capture\n    if an error occurred while running `command`\n    \"\"\"\n    try:\n        output = subprocess.check_output(command, stderr=subprocess.STDOUT)\n        if return_stdout:\n            if hasattr(output, \"decode\"):\n                output = output.decode(\"utf-8\")\n            return output\n    except subprocess.CalledProcessError as e:\n        raise SubprocessCallException(\n            f\"Command `{' '.join(command)}` failed with the following error:\\n\\n{e.output.decode()}\"\n        ) from e\n\n\nclass ExamplesTestsAccelerate(unittest.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        super().setUpClass()\n        cls._tmpdir = tempfile.mkdtemp()\n        cls.configPath = os.path.join(cls._tmpdir, \"default_config.yml\")\n\n        write_basic_config(save_location=cls.configPath)\n        cls._launch_args = [\"accelerate\", \"launch\", \"--config_file\", cls.configPath]\n\n    @classmethod\n    def tearDownClass(cls):\n        super().tearDownClass()\n        shutil.rmtree(cls._tmpdir)\n"
  },
  {
    "path": "examples/text_to_image/README.md",
    "content": "# Stable Diffusion text-to-image fine-tuning\n\nThe `train_text_to_image.py` script shows how to fine-tune stable diffusion model on your own dataset.\n\n___Note___:\n\n___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___\n\n\n## Running locally with PyTorch\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd in the example folder  and run\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n### Naruto example\n\nYou need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.\n\nYou have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. 
For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).\n\nRun the following command to authenticate your token\n\n```bash\nhf auth login\n```\n\nIf you have already cloned the repo, then you won't need to go through these steps.\n\n<br>\n\n#### Hardware\nWith `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n<!-- accelerate_snippet_start -->\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\"  train_text_to_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --use_ema \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --output_dir=\"sd-naruto-model\"\n```\n<!-- accelerate_snippet_end -->\n\n\nTo run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).\nIf you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script.\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport TRAIN_DIR=\"path_to_your_dataset\"\n\naccelerate launch --mixed_precision=\"fp16\" train_text_to_image.py \\\n  
--pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$TRAIN_DIR \\\n  --use_ema \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --output_dir=\"sd-naruto-model\"\n```\n\n\nOnce the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-naruto-model`. To load the fine-tuned model for inference just pass that path to `StableDiffusionPipeline`\n\n```python\nimport torch\nfrom diffusers import StableDiffusionPipeline\n\nmodel_path = \"path_to_saved_model\"\npipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)\npipe.to(\"cuda\")\n\nimage = pipe(prompt=\"yoda\").images[0]\nimage.save(\"yoda-naruto.png\")\n```\n\nCheckpoints only save the unet, so to run inference from a checkpoint, just load the unet\n\n```python\nimport torch\nfrom diffusers import StableDiffusionPipeline, UNet2DConditionModel\n\nmodel_path = \"path_to_saved_model\"\nunet = UNet2DConditionModel.from_pretrained(model_path + \"/checkpoint-<N>/unet\", torch_dtype=torch.float16)\n\npipe = StableDiffusionPipeline.from_pretrained(\"<initial model>\", unet=unet, torch_dtype=torch.float16)\npipe.to(\"cuda\")\n\nimage = pipe(prompt=\"yoda\").images[0]\nimage.save(\"yoda-naruto.png\")\n```\n\n#### Training with multiple GPUs\n\n`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)\nfor running distributed training with `accelerate`. 
Here is an example command:\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu  train_text_to_image.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --use_ema \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --gradient_checkpointing \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --output_dir=\"sd-naruto-model\"\n```\n\n\n#### Training with Min-SNR weighting\n\nWe support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://huggingface.co/papers/2303.09556) which helps to achieve faster convergence\nby rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended\nvalue when using it is 5.0.\n\nYou can find [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) that compares the loss surfaces of the following setups:\n\n* Training without the Min-SNR weighting strategy\n* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)\n* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)\n\nFor our small Narutos dataset, the effects of Min-SNR weighting strategy might not appear to be pronounced, but for larger datasets, we believe the effects will be more pronounced.\n\nAlso, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.\n\n\n#### Training with EMA weights\n\nThrough the `EMAModel` class, we support a convenient method of tracking an exponential moving average of model parameters.  
This helps to smooth out noise in model parameter updates and generally improves model performance. If enabled with the `--use_ema` argument, the final model checkpoint that is saved at the end of training will use the EMA weights.\n\nEMA weights require an additional full-precision copy of the model parameters to be stored in memory, but otherwise have very little performance overhead. `--foreach_ema` can be used to further reduce the overhead. If you are short on VRAM and still want to use EMA weights, you can store them in CPU RAM by using the `--offload_ema` argument. This keeps the EMA weights in pinned CPU memory during the training step. Then, once per model parameter update, the EMA weights are transferred to the GPU, updated there, and sent back to the CPU. Both of these transfers are set up as non-blocking, so CUDA devices should be able to overlap them with other computations. With sufficient bandwidth between the host and device and a sufficiently long gap between model parameter updates, storing EMA weights in CPU RAM should add no performance overhead, as long as no other calls force synchronization.\n\n#### Training with DREAM\n\nWe support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://huggingface.co/papers/2312.00210). DREAM claims to increase model fidelity at the cost of an extra gradient-free UNet `forward` pass in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and defaults to 1, as in the paper.\n\n\n\n## Training with LoRA\n\nLow-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. 
Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.\n\nIn a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has several advantages:\n\n- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).\n- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers let you control the extent to which the model is adapted toward new training images via a `scale` parameter.\n\n[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.\n\nWith LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset\non consumer GPUs like the Tesla T4 or Tesla V100.\n\n### Training\n\nFirst, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Narutos dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. 
All you need to do is to run `pip install wandb` before training to automatically log images.___**\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n```\n\nFor this example we want to directly store the trained LoRA embeddings on the Hub, so\nwe need to be logged in and add the `--push_to_hub` flag.\n\n```bash\nhf auth login\n```\n\nNow we can start training!\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" train_text_to_image_lora.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME --caption_column=\"text\" \\\n  --resolution=512 --random_flip \\\n  --train_batch_size=1 \\\n  --num_train_epochs=100 --checkpointing_steps=5000 \\\n  --learning_rate=1e-04 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --seed=42 \\\n  --output_dir=\"sd-naruto-model-lora\" \\\n  --validation_prompt=\"cute dragon creature\" --report_to=\"wandb\"\n```\n\nThe above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.\n\n**___Note: When using LoRA we can use a much higher learning rate compared to non-LoRA fine-tuning. Here we use *1e-4* instead of the usual *1e-5*. Also, by using LoRA, it's possible to run `train_text_to_image_lora.py` in consumer GPUs like T4 or V100.___**\n\nThe final LoRA embedding weights have been uploaded to [sayakpaul/sd-model-finetuned-lora-t4](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4). 
**___Note: [The final weights](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**\n\nYou can check some inference samples that were logged during the course of the fine-tuning process [here](https://wandb.ai/sayakpaul/text2image-fine-tune/runs/q4lc0xsw).\n\n### Inference\n\nOnce you have trained a model using the above command, inference can be done using the `StableDiffusionPipeline` after loading the trained LoRA weights. You\nneed to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-naruto-model-lora`.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_path = \"sayakpaul/sd-model-finetuned-lora-t4\"\npipe = StableDiffusionPipeline.from_pretrained(\"CompVis/stable-diffusion-v1-4\", torch_dtype=torch.float16)\npipe.unet.load_attn_procs(model_path)\npipe.to(\"cuda\")\n\nprompt = \"A naruto with green eyes and red legs.\"\nimage = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]\nimage.save(\"naruto.png\")\n```\n\nIf you are loading the LoRA parameters from the Hub and the Hub repository has\na `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then\nyou can do:\n\n```py\nimport torch\nfrom diffusers import StableDiffusionPipeline\nfrom huggingface_hub.repocard import RepoCard\n\nlora_model_id = \"sayakpaul/sd-model-finetuned-lora-t4\"\ncard = RepoCard.load(lora_model_id)\nbase_model_id = card.data.to_dict()[\"base_model\"]\n\npipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)\n...\n```\n\n## Training with Flax/JAX\n\nFor faster training on TPUs and GPUs you can leverage the Flax training example. 
Follow the instructions above to get the model and dataset before running the script.\n\n**___Note: The Flax example doesn't yet support features like gradient checkpointing and gradient accumulation, so to use Flax for faster training you will need GPUs with >30GB of memory or a TPU v3.___**\n\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n```bash\npip install -U -r requirements_flax.txt\n```\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\npython train_text_to_image_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --mixed_precision=\"fp16\" \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --output_dir=\"sd-naruto-model\"\n```\n\nTo run on your own training files, prepare the dataset according to the format required by `datasets`. You can find instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata).\nIf you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport TRAIN_DIR=\"path_to_your_dataset\"\n\npython train_text_to_image_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$TRAIN_DIR \\\n  --resolution=512 --center_crop --random_flip \\\n  --train_batch_size=1 \\\n  --mixed_precision=\"fp16\" \\\n  --max_train_steps=15000 \\\n  --learning_rate=1e-05 \\\n  --max_grad_norm=1 \\\n  --output_dir=\"sd-naruto-model\"\n```\n\n### Training with xFormers:\n\nYou can enable memory-efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` 
argument to the script.\n\nxFormers training is not available for Flax/JAX.\n\n**Note**:\n\nAccording to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training on some GPUs. If you observe that problem, please install a development version as indicated in that comment.\n\n## Stable Diffusion XL\n\n* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).\n* We also support fine-tuning of the UNet and text encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).\n"
  },
  {
    "path": "examples/text_to_image/README_sdxl.md",
    "content": "# Stable Diffusion XL text-to-image fine-tuning\n\nThe `train_text_to_image_sdxl.py` script shows how to fine-tune Stable Diffusion XL (SDXL) on your own dataset.\n\n🚨 This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset. 🚨\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install -e .\n```\n\nThen cd in the `examples/text_to_image` folder and run\n```bash\npip install -r requirements_sdxl.txt\n```\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\nOr for a default accelerate configuration without answering questions about your environment\n\n```bash\naccelerate config default\n```\n\nOr if your environment doesn't support an interactive shell (e.g., a notebook)\n\n```python\nfrom accelerate.utils import write_basic_config\nwrite_basic_config()\n```\n\nWhen running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.\nNote also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.\n\n### Training\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport VAE_NAME=\"madebyollin/sdxl-vae-fp16-fix\"\nexport 
DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n\naccelerate launch train_text_to_image_sdxl.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --pretrained_vae_model_name_or_path=$VAE_NAME \\\n  --dataset_name=$DATASET_NAME \\\n  --enable_xformers_memory_efficient_attention \\\n  --resolution=512 --center_crop --random_flip \\\n  --proportion_empty_prompts=0.2 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 --gradient_checkpointing \\\n  --max_train_steps=10000 \\\n  --use_8bit_adam \\\n  --learning_rate=1e-06 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --mixed_precision=\"fp16\" \\\n  --report_to=\"wandb\" \\\n  --validation_prompt=\"a cute Sundar Pichai creature\" --validation_epochs 5 \\\n  --checkpointing_steps=5000 \\\n  --output_dir=\"sdxl-naruto-model\" \\\n  --push_to_hub\n```\n\n**Notes**:\n\n*  The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/naruto-blip-captions`](https://hf.co/datasets/lambdalabs/naruto-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.\n* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.\n* The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.\n* SDXL's VAE is known to suffer from numerical instability issues. 
This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).\n\n### Inference\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\nmodel_path = \"your-model-id-goes-here\" # <-- change this\npipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)\npipe.to(\"cuda\")\n\nprompt = \"A naruto with green eyes and red legs.\"\nimage = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]\nimage.save(\"naruto.png\")\n```\n\n### Inference in PyTorch XLA\n```python\nfrom time import time\n\nfrom diffusers import DiffusionPipeline\nimport torch\nimport torch_xla.core.xla_model as xm\n\nmodel_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\npipe = DiffusionPipeline.from_pretrained(model_id)\n\ndevice = xm.xla_device()\npipe.to(device)\n\nprompt = \"A naruto with green eyes and red legs.\"\ninference_steps = 30  # assumed value, matching the example above\nstart = time()\nimage = pipe(prompt, num_inference_steps=inference_steps).images[0]\nprint(f'Compilation time is {time()-start} sec')\nimage.save(\"naruto.png\")\n\nstart = time()\nimage = pipe(prompt, num_inference_steps=inference_steps).images[0]\nprint(f'Inference time is {time()-start} sec after compilation')\n```\n\nNote: There is a warmup step in PyTorch XLA. This takes longer because of\ncompilation and optimization. To see the real benefits of PyTorch XLA and\nspeedup, we need to call the pipe again on the input with the same length\nas the original prompt to reuse the optimized graph and get the performance\nboost.\n\n## LoRA training example for Stable Diffusion XL (SDXL)\n\nLow-Rank Adaptation of Large Language Models (LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. 
Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.\n\nIn a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has several advantages:\n\n- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).\n- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.\n- LoRA attention layers let you control the extent to which the model is adapted toward new training images via a `scale` parameter.\n\n[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.\n\nWith LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset\non consumer GPUs like the Tesla T4 or Tesla V100.\n\n### Training\n\nFirst, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Narutos dataset](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions).\n\n**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. 
All you need to do is to run `pip install wandb` before training to automatically log images.___**\n\n```bash\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport VAE_NAME=\"madebyollin/sdxl-vae-fp16-fix\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\n```\n\nFor this example we want to directly store the trained LoRA embeddings on the Hub, so\nwe need to be logged in and add the `--push_to_hub` flag.\n\n```bash\nhf auth login\n```\n\nNow we can start training!\n\n```bash\naccelerate launch train_text_to_image_lora_sdxl.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --pretrained_vae_model_name_or_path=$VAE_NAME \\\n  --dataset_name=$DATASET_NAME --caption_column=\"text\" \\\n  --resolution=1024 --random_flip \\\n  --train_batch_size=1 \\\n  --num_train_epochs=2 --checkpointing_steps=500 \\\n  --learning_rate=1e-04 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --mixed_precision=\"fp16\" \\\n  --seed=42 \\\n  --output_dir=\"sd-naruto-model-lora-sdxl\" \\\n  --validation_prompt=\"cute dragon creature\" --report_to=\"wandb\" \\\n  --push_to_hub\n```\n\nThe above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.\n\n**Notes**:\n\n* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).\n\n\n### Using DeepSpeed\nUsing DeepSpeed one can reduce the consumption of GPU memory, enabling the training of models on GPUs with smaller memory sizes. DeepSpeed is capable of offloading model parameters to the machine's memory, or it can distribute parameters, gradients, and optimizer states across multiple GPUs. 
This allows for the training of larger models under the same hardware configuration.\n\nFirst, either run the `accelerate config` command and select DeepSpeed, or edit the accelerate config file manually to set up DeepSpeed.\n\nHere is an example of a config file for using DeepSpeed. For more detailed explanations of the configuration, you can refer to this [link](https://huggingface.co/docs/accelerate/usage_guides/deepspeed).\n```yaml\ncompute_environment: LOCAL_MACHINE\ndebug: true\ndeepspeed_config:\n  gradient_accumulation_steps: 1\n  gradient_clipping: 1.0\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: fp16\nnum_machines: 1\nnum_processes: 1\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n```\nSave the configuration above as an `accelerate_config.yaml` file. Then set the `ACCELERATE_CONFIG_FILE` environment variable to the path of your `accelerate_config.yaml` file. This way you can use DeepSpeed to train your SDXL model with LoRA. 
Additionally, you can use DeepSpeed to train other SD models in this way.\n\n```shell\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport VAE_NAME=\"madebyollin/sdxl-vae-fp16-fix\"\nexport DATASET_NAME=\"lambdalabs/naruto-blip-captions\"\nexport ACCELERATE_CONFIG_FILE=\"your accelerate_config.yaml\"\n\naccelerate launch  --config_file $ACCELERATE_CONFIG_FILE train_text_to_image_lora_sdxl.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --pretrained_vae_model_name_or_path=$VAE_NAME \\\n  --dataset_name=$DATASET_NAME --caption_column=\"text\" \\\n  --resolution=1024  \\\n  --train_batch_size=1 \\\n  --num_train_epochs=2 \\\n  --checkpointing_steps=2 \\\n  --learning_rate=1e-04 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --mixed_precision=\"fp16\" \\\n  --max_train_steps=20 \\\n  --validation_epochs=20 \\\n  --seed=1234 \\\n  --output_dir=\"sd-naruto-model-lora-sdxl\" \\\n  --validation_prompt=\"cute dragon creature\"\n```\n\n\n### Finetuning the text encoder and UNet\n\nThe script also allows you to finetune the `text_encoder` along with the `unet`.\n\n🚨 Training the text encoder requires additional memory.\n\nPass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:\n\n```bash\naccelerate launch train_text_to_image_lora_sdxl.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --dataset_name=$DATASET_NAME --caption_column=\"text\" \\\n  --resolution=1024 --random_flip \\\n  --train_batch_size=1 \\\n  --num_train_epochs=2 --checkpointing_steps=500 \\\n  --learning_rate=1e-04 --lr_scheduler=\"constant\" --lr_warmup_steps=0 \\\n  --seed=42 \\\n  --output_dir=\"sd-naruto-model-lora-sdxl-txt\" \\\n  --train_text_encoder \\\n  --validation_prompt=\"cute dragon creature\" --report_to=\"wandb\" \\\n  --push_to_hub\n```\n\n### Inference\n\nOnce you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after 
loading the trained LoRA weights.  You\nneed to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-naruto-model-lora-sdxl`.\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\nmodel_path = \"takuoko/sd-naruto-model-lora-sdxl\"\npipe = DiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16)\npipe.to(\"cuda\")\npipe.load_lora_weights(model_path)\n\nprompt = \"A naruto with green eyes and red legs.\"\nimage = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]\nimage.save(\"naruto.png\")\n```\n"
  },
  {
    "path": "examples/text_to_image/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\ndatasets>=2.19.1\nftfy\ntensorboard\nJinja2\npeft>=0.17.0\n"
  },
  {
    "path": "examples/text_to_image/requirements_flax.txt",
    "content": "transformers>=4.25.1\ndatasets\nflax\noptax\ntorch\ntorchvision\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/text_to_image/requirements_sdxl.txt",
    "content": "accelerate>=0.22.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\ndatasets\npeft>=0.17.0"
  },
  {
    "path": "examples/text_to_image/test_text_to_image.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport shutil\nimport sys\nimport tempfile\n\nfrom diffusers import DiffusionPipeline, UNet2DConditionModel  # noqa: E402\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass TextToImage(ExamplesTestsAccelerate):\n    def test_text_to_image(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/text_to_image/train_text_to_image.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args 
+ test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"unet\", \"diffusion_pytorch_model.safetensors\")))\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"scheduler\", \"scheduler_config.json\")))\n\n    def test_text_to_image_checkpointing(self):\n        pretrained_model_name_or_path = \"hf-internal-testing/tiny-stable-diffusion-pipe\"\n        prompt = \"a prompt\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)\n            pipe(prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            # check can run an intermediate checkpoint\n            unet = UNet2DConditionModel.from_pretrained(tmpdir, 
subfolder=\"checkpoint-2/unet\")\n            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)\n            pipe(prompt, num_inference_steps=1)\n\n            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming\n            shutil.rmtree(os.path.join(tmpdir, \"checkpoint-2\"))\n\n            # Run training script for 2 total steps resuming from checkpoint 4\n\n            resume_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=1\n                --resume_from_checkpoint=checkpoint-4\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check can run new fully trained pipeline\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)\n            pipe(prompt, num_inference_steps=1)\n\n            # no checkpoint-2 -> check old checkpoints do not exist\n            # check new checkpoints exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-5\"},\n            )\n\n    def test_text_to_image_checkpointing_use_ema(self):\n        pretrained_model_name_or_path = \"hf-internal-testing/tiny-stable-diffusion-pipe\"\n        prompt = \"a prompt\"\n\n     
   with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --use_ema\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)\n            pipe(prompt, num_inference_steps=2)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            # check can run an intermediate checkpoint\n            unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder=\"checkpoint-2/unet\")\n            pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)\n            pipe(prompt, num_inference_steps=1)\n\n            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming\n            shutil.rmtree(os.path.join(tmpdir, \"checkpoint-2\"))\n\n            # Run training script for 2 total steps resuming from 
checkpoint 4\n\n            resume_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=1\n                --resume_from_checkpoint=checkpoint-4\n                --use_ema\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check can run new fully trained pipeline\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)\n            pipe(prompt, num_inference_steps=1)\n\n            # no checkpoint-2 -> check old checkpoints do not exist\n            # check new checkpoints exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-5\"},\n            )\n\n    def test_text_to_image_checkpointing_checkpoints_total_limit(self):\n        pretrained_model_name_or_path = \"hf-internal-testing/tiny-stable-diffusion-pipe\"\n        prompt = \"a prompt\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2\n            # Should create checkpoints at steps 2, 4, 6\n            # with checkpoint at step 2 deleted\n\n            initial_run_args = f\"\"\"\n                
examples/text_to_image/train_text_to_image.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)\n            pipe(prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            # checkpoint-2 should have been deleted\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-4\", \"checkpoint-6\"})\n\n    def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        pretrained_model_name_or_path = \"hf-internal-testing/tiny-stable-diffusion-pipe\"\n        prompt = \"a prompt\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n           
     --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)\n            pipe(prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            # resume and we should try to checkpoint at 6, where we'll have to remove\n            # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint\n\n            resume_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 8\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --resume_from_checkpoint=checkpoint-4\n                --checkpoints_total_limit=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(tmpdir, 
safety_checker=None)\n            pipe(prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-6\", \"checkpoint-8\"},\n            )\n\n\nclass TextToImageSDXL(ExamplesTestsAccelerate):\n    def test_text_to_image_sdxl(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/text_to_image/train_text_to_image_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"unet\", \"diffusion_pytorch_model.safetensors\")))\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"scheduler\", \"scheduler_config.json\")))\n"
  },
  {
    "path": "examples/text_to_image/test_text_to_image_lora.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\nimport safetensors\n\nfrom diffusers import DiffusionPipeline  # noqa: E402\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass TextToImageLoRA(ExamplesTestsAccelerate):\n    def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):\n        prompt = \"a prompt\"\n        pipeline_path = \"hf-internal-testing/tiny-stable-diffusion-xl-pipe\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2\n            # Should create checkpoints at steps 2, 4, 6\n            # with checkpoint at step 2 deleted\n\n            initial_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image_lora_sdxl.py\n                --pretrained_model_name_or_path {pipeline_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --train_batch_size 1\n                
--gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(pipeline_path)\n            pipe.load_lora_weights(tmpdir)\n            pipe(prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            # checkpoint-2 should have been deleted\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-4\", \"checkpoint-6\"})\n\n    def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):\n        pretrained_model_name_or_path = \"hf-internal-testing/tiny-stable-diffusion-pipe\"\n        prompt = \"a prompt\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2\n            # Should create checkpoints at steps 2, 4, 6\n            # with checkpoint at step 2 deleted\n\n            initial_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image_lora.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                
--output_dir {tmpdir}\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                --seed=0\n                --num_validation_images=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(\n                \"hf-internal-testing/tiny-stable-diffusion-pipe\", safety_checker=None\n            )\n            pipe.load_lora_weights(tmpdir)\n            pipe(prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            # checkpoint-2 should have been deleted\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-4\", \"checkpoint-6\"})\n\n    def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        pretrained_model_name_or_path = \"hf-internal-testing/tiny-stable-diffusion-pipe\"\n        prompt = \"a prompt\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image_lora.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --seed=0\n                --num_validation_images=0\n  
              \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(\n                \"hf-internal-testing/tiny-stable-diffusion-pipe\", safety_checker=None\n            )\n            pipe.load_lora_weights(tmpdir)\n            pipe(prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            # resume and we should try to checkpoint at 6, where we'll have to remove\n            # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint\n\n            resume_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image_lora.py\n                --pretrained_model_name_or_path {pretrained_model_name_or_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --center_crop\n                --random_flip\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 8\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --resume_from_checkpoint=checkpoint-4\n                --checkpoints_total_limit=2\n                --seed=0\n                --num_validation_images=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(\n                \"hf-internal-testing/tiny-stable-diffusion-pipe\", safety_checker=None\n            )\n            pipe.load_lora_weights(tmpdir)\n            pipe(prompt, num_inference_steps=1)\n\n            # 
check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-6\", \"checkpoint-8\"},\n            )\n\n\nclass TextToImageLoRASDXL(ExamplesTestsAccelerate):\n    def test_text_to_image_lora_sdxl(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/text_to_image/train_text_to_image_lora_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n    def test_text_to_image_lora_sdxl_with_text_encoder(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/text_to_image/train_text_to_image_lora_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                
--train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --train_text_encoder\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\")))\n\n            # make sure the state_dict has the correct naming in the parameters.\n            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, \"pytorch_lora_weights.safetensors\"))\n            is_lora = all(\"lora\" in k for k in lora_state_dict.keys())\n            self.assertTrue(is_lora)\n\n            # when training the text encoder, all the parameters in the state dict should start\n            # with `\"unet\"` or `\"text_encoder\"` or `\"text_encoder_2\"` in their names.\n            keys = lora_state_dict.keys()\n            starts_with_unet = all(\n                k.startswith(\"unet\") or k.startswith(\"text_encoder\") or k.startswith(\"text_encoder_2\") for k in keys\n            )\n            self.assertTrue(starts_with_unet)\n\n    def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):\n        prompt = \"a prompt\"\n        pipeline_path = \"hf-internal-testing/tiny-stable-diffusion-xl-pipe\"\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Run training script with checkpointing\n            # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2\n            # Should create checkpoints at steps 2, 4, 6\n            # with checkpoint at step 2 deleted\n\n            initial_run_args = f\"\"\"\n                examples/text_to_image/train_text_to_image_lora_sdxl.py\n                
--pretrained_model_name_or_path {pipeline_path}\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --train_text_encoder\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            pipe = DiffusionPipeline.from_pretrained(pipeline_path)\n            pipe.load_lora_weights(tmpdir)\n            pipe(prompt, num_inference_steps=1)\n\n            # check checkpoint directories exist\n            # checkpoint-2 should have been deleted\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-4\", \"checkpoint-6\"})\n"
  },
  {
    "path": "examples/text_to_image/train_text_to_image.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\nfrom transformers.utils import ContextManagers\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr\nfrom diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom 
diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef save_model_card(\n    args,\n    repo_id: str,\n    images: list = None,\n    repo_folder: str = None,\n):\n    img_str = \"\"\n    if len(images) > 0:\n        image_grid = make_image_grid(images, 1, len(args.validation_prompts))\n        image_grid.save(os.path.join(repo_folder, \"val_imgs_grid.png\"))\n        img_str += \"![val_imgs_grid](./val_imgs_grid.png)\\n\"\n\n    model_description = f\"\"\"\n# Text-to-image finetuning - {repo_id}\n\nThis pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \\n\n{img_str}\n\n## Pipeline usage\n\nYou can use the pipeline like so:\n\n```python\nfrom diffusers import DiffusionPipeline\nimport torch\n\npipeline = DiffusionPipeline.from_pretrained(\"{repo_id}\", torch_dtype=torch.float16)\nprompt = \"{args.validation_prompts[0]}\"\nimage = pipeline(prompt).images[0]\nimage.save(\"my_image.png\")\n```\n\n## Training info\n\nThese are the key hyperparameters used during training:\n\n* Epochs: {args.num_train_epochs}\n* Learning rate: {args.learning_rate}\n* Batch size: {args.train_batch_size}\n* Gradient accumulation steps: {args.gradient_accumulation_steps}\n* Image resolution: {args.resolution}\n* Mixed-precision: {args.mixed_precision}\n\n\"\"\"\n    wandb_info = \"\"\n    # Initialize before the availability check so the name is always bound.\n    wandb_run_url = None\n    if is_wandb_available() and wandb.run is not None:\n        wandb_run_url = wandb.run.url\n\n    if wandb_run_url is not None:\n        wandb_info = 
f\"\"\"\nMore information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).\n\"\"\"\n\n    model_description += wandb_info\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=args.pretrained_model_name_or_path,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\"stable-diffusion\", \"stable-diffusion-diffusers\", \"text-to-image\", \"diffusers\", \"diffusers-training\"]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):\n    logger.info(\"Running validation... \")\n\n    pipeline = StableDiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=accelerator.unwrap_model(vae),\n        text_encoder=accelerator.unwrap_model(text_encoder),\n        tokenizer=tokenizer,\n        unet=accelerator.unwrap_model(unet),\n        safety_checker=None,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    images = []\n    for i in range(len(args.validation_prompts)):\n        if torch.backends.mps.is_available():\n            autocast_ctx = nullcontext()\n        else:\n            autocast_ctx = torch.autocast(accelerator.device.type)\n\n        with autocast_ctx:\n            image = pipeline(args.validation_prompts[i], 
num_inference_steps=20, generator=generator).images[0]\n\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompts[i]}\")\n                        for i, image in enumerate(images)\n                    ]\n                }\n            )\n        else:\n            logger.warning(f\"image logging not implemented for {tracker.name}\")\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    return images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--input_perturbation\", type=float, default=0, help=\"The scale of input perturbation. Recommended 0.1.\"\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompts\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\"A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"Whether to randomly flip images horizontally.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--dream_training\",\n        action=\"store_true\",\n        help=(\n            \"Use the DREAM training method, which makes training more efficient and accurate at the \"\n            \"expense of doing an extra forward pass. See: https://huggingface.co/papers/2312.00210\"\n        ),\n    )\n    parser.add_argument(\n        \"--dream_detail_preservation\",\n        type=float,\n        default=1.0,\n        help=\"Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\"--offload_ema\", action=\"store_true\", help=\"Offload EMA model to CPU during training step.\")\n    parser.add_argument(\"--foreach_ema\", action=\"store_true\", help=\"Use faster foreach implementation of EMAModel.\")\n    parser.add_argument(\n        \"--non_ema_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or\"\n            \" remote repository specified with --pretrained_model_name_or_path.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=5,\n        help=\"Run validation every X epochs.\",\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"text2image-fine-tune\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    # default to using the same revision for the non-ema model if not specified\n    if args.non_ema_revision is None:\n        args.non_ema_revision = args.revision\n\n    return args\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        
raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    if args.non_ema_revision is not None:\n        deprecate(\n            \"non_ema_revision!=None\",\n            \"0.15.0\",\n            message=(\n                \"Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to\"\n                \" use `--variant=non_ema` instead.\"\n            ),\n        )\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if 
accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n\n    def deepspeed_zero_init_disabled_context_manager():\n        \"\"\"\n        returns either a context list that includes one that will disable zero.Init or an empty context list\n        \"\"\"\n        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n        if deepspeed_plugin is None:\n            return []\n\n        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]\n\n    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.\n    # For this to work properly all models must be run through `accelerate.prepare`. 
But accelerate\n    # will try to assign the same optimizer with the same weights to all models during\n    # `deepspeed.initialize`, which of course doesn't work.\n    #\n    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2\n    # frozen models from being partitioned during `zero.Init` which gets called during\n    # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding\n    # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.\n    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n        text_encoder = CLIPTextModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n        )\n        vae = AutoencoderKL.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n        )\n\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.non_ema_revision\n    )\n\n    # Freeze vae and text_encoder and set unet to trainable\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n    unet.train()\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = UNet2DConditionModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n        )\n        ema_unet = EMAModel(\n            ema_unet.parameters(),\n            model_cls=UNet2DConditionModel,\n            model_config=ema_unet.config,\n            foreach=args.foreach_ema,\n        )\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                
logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(\n                    os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel, foreach=args.foreach_ema\n                )\n                ema_unet.load_state_dict(load_model.state_dict())\n                if args.offload_ema:\n                    ema_unet.pin_memory()\n                else:\n                    ema_unet.to(accelerator.device)\n                del load_model\n\n            for _ in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = 
UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        unet.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    # Get the specified interpolation method from the args\n    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n\n    # Raise an error if 
the interpolation method is invalid\n    if interpolation is None:\n        raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode}.\")\n\n    # Data preprocessing transformations\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=interpolation),  # Use dynamic interpolation method\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"input_ids\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of 
training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        if args.offload_ema:\n            ema_unet.pin_memory()\n        else:\n            ema_unet.to(accelerator.device)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n        args.mixed_precision = accelerator.mixed_precision\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n        args.mixed_precision = accelerator.mixed_precision\n\n    # Move text_encode and vae to gpu and cast to weight_dtype\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n    
vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        tracker_config.pop(\"validation_prompts\")\n        accelerator.init_trackers(args.tracker_project_name, tracker_config)\n\n    # Function for unwrapping if model was compiled with `torch.compile`.\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous 
batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device\n                    )\n                if args.input_perturbation:\n                    new_noise = noise + args.input_perturbation * torch.randn_like(noise)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n             
   # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                if args.input_perturbation:\n                    noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)\n                else:\n                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"], return_dict=False)[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.dream_training:\n                    noisy_latents, target = compute_dream_and_update_latents(\n                        unet,\n                        noise_scheduler,\n                        timesteps,\n                        noise,\n                        noisy_latents,\n                        target,\n                        encoder_hidden_states,\n                        args.dream_detail_preservation,\n                    )\n\n                # Predict the noise residual and compute loss\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), 
target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    if args.offload_ema:\n                        
ema_unet.to(device=\"cuda\", non_blocking=True)\n                    ema_unet.step(unet.parameters())\n                    if args.offload_ema:\n                        ema_unet.to(device=\"cpu\", non_blocking=True)\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                
        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n                log_validation(\n                    vae,\n                    text_encoder,\n                    tokenizer,\n                    unet,\n                    args,\n                    accelerator,\n                    weight_dtype,\n                    global_step,\n                )\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_unet.restore(unet.parameters())\n\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        pipeline = StableDiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            text_encoder=text_encoder,\n            vae=vae,\n            unet=unet,\n            revision=args.revision,\n            variant=args.variant,\n        )\n        pipeline.save_pretrained(args.output_dir)\n\n        # Run a final round of inference.\n        images = []\n        if args.validation_prompts is not None:\n            
logger.info(\"Running inference for collecting generated images...\")\n            pipeline = pipeline.to(accelerator.device)\n            pipeline.torch_dtype = weight_dtype\n            pipeline.set_progress_bar_config(disable=True)\n\n            if args.enable_xformers_memory_efficient_attention:\n                pipeline.enable_xformers_memory_efficient_attention()\n\n            if args.seed is None:\n                generator = None\n            else:\n                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n            for i in range(len(args.validation_prompts)):\n                with torch.autocast(\"cuda\"):\n                    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n                images.append(image)\n\n        if args.push_to_hub:\n            save_model_card(args, repo_id, images, repo_folder=args.output_dir)\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/text_to_image/train_text_to_image_flax.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom datasets import load_dataset\nfrom flax import jax_utils\nfrom flax.training import train_state\nfrom flax.training.common_utils import shard\nfrom huggingface_hub import create_repo, upload_folder\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed\n\nfrom diffusers import (\n    FlaxAutoencoderKL,\n    FlaxDDPMScheduler,\n    FlaxPNDMScheduler,\n    FlaxStableDiffusionPipeline,\n    FlaxUNet2DConditionModel,\n)\nfrom diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker\nfrom diffusers.utils import check_min_version\n\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = logging.getLogger(__name__)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. 
Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose \"\n            \"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 \"\n            \"and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--from_pt\",\n        action=\"store_true\",\n        default=False,\n        help=\"Flag to indicate whether to convert models from PyTorch.\",\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\ndataset_name_mapping = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef get_params_to_save(params):\n    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))\n\n\ndef main():\n    args = parse_args()\n\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    # Setup logging, we only want one process per machine to log things on the screen.\n    logger.setLevel(logging.INFO if 
jax.process_index() == 0 else logging.ERROR)\n    if jax.process_index() == 0:\n        transformers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if jax.process_index() == 0:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding=\"do_not_pad\", truncation=True)\n        input_ids = inputs.input_ids\n        return input_ids\n\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution) 
if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"input_ids\"] = tokenize_captions(examples)\n\n        return examples\n\n    if args.max_train_samples is not None:\n        dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n    train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = [example[\"input_ids\"] for example in examples]\n\n        padded_tokens = tokenizer.pad(\n            {\"input_ids\": input_ids}, padding=\"max_length\", max_length=tokenizer.model_max_length, return_tensors=\"pt\"\n        )\n        batch = {\n            \"pixel_values\": pixel_values,\n            \"input_ids\": padded_tokens.input_ids,\n        }\n        batch = {k: v.numpy() for k, v in batch.items()}\n\n        return batch\n\n    total_train_batch_size = args.train_batch_size * jax.local_device_count()\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True\n    )\n\n    weight_dtype = jnp.float32\n    if args.mixed_precision == \"fp16\":\n        weight_dtype = jnp.float16\n    elif args.mixed_precision == \"bf16\":\n        weight_dtype = jnp.bfloat16\n\n    # Load models and create wrapper for stable 
diffusion\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        from_pt=args.from_pt,\n        revision=args.revision,\n        subfolder=\"tokenizer\",\n    )\n    text_encoder = FlaxCLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        from_pt=args.from_pt,\n        revision=args.revision,\n        subfolder=\"text_encoder\",\n        dtype=weight_dtype,\n    )\n    vae, vae_params = FlaxAutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        from_pt=args.from_pt,\n        revision=args.revision,\n        subfolder=\"vae\",\n        dtype=weight_dtype,\n    )\n    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path,\n        from_pt=args.from_pt,\n        revision=args.revision,\n        subfolder=\"unet\",\n        dtype=weight_dtype,\n    )\n\n    # Optimization\n    if args.scale_lr:\n        args.learning_rate = args.learning_rate * total_train_batch_size\n\n    constant_scheduler = optax.constant_schedule(args.learning_rate)\n\n    adamw = optax.adamw(\n        learning_rate=constant_scheduler,\n        b1=args.adam_beta1,\n        b2=args.adam_beta2,\n        eps=args.adam_epsilon,\n        weight_decay=args.adam_weight_decay,\n    )\n\n    optimizer = optax.chain(\n        optax.clip_by_global_norm(args.max_grad_norm),\n        adamw,\n    )\n\n    state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)\n\n    noise_scheduler = FlaxDDPMScheduler(\n        beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000\n    )\n    noise_scheduler_state = noise_scheduler.create_state()\n\n    # Initialize our training\n    rng = jax.random.PRNGKey(args.seed)\n    train_rngs = jax.random.split(rng, jax.local_device_count())\n\n    def train_step(state, text_encoder_params, vae_params, batch, train_rng):\n        dropout_rng, 
sample_rng, new_train_rng = jax.random.split(train_rng, 3)\n\n        def compute_loss(params):\n            # Convert images to latent space\n            vae_outputs = vae.apply(\n                {\"params\": vae_params}, batch[\"pixel_values\"], deterministic=True, method=vae.encode\n            )\n            latents = vae_outputs.latent_dist.sample(sample_rng)\n            # (NHWC) -> (NCHW)\n            latents = jnp.transpose(latents, (0, 3, 1, 2))\n            latents = latents * vae.config.scaling_factor\n\n            # Sample noise that we'll add to the latents\n            noise_rng, timestep_rng = jax.random.split(sample_rng)\n            noise = jax.random.normal(noise_rng, latents.shape)\n            # Sample a random timestep for each image\n            bsz = latents.shape[0]\n            timesteps = jax.random.randint(\n                timestep_rng,\n                (bsz,),\n                0,\n                noise_scheduler.config.num_train_timesteps,\n            )\n\n            # Add noise to the latents according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)\n\n            # Get the text embedding for conditioning\n            encoder_hidden_states = text_encoder(\n                batch[\"input_ids\"],\n                params=text_encoder_params,\n                train=False,\n            )[0]\n\n            # Predict the noise residual and compute loss\n            model_pred = unet.apply(\n                {\"params\": params}, noisy_latents, timesteps, encoder_hidden_states, train=True\n            ).sample\n\n            # Get the target for loss depending on the prediction type\n            if noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                target = 
noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n            loss = (target - model_pred) ** 2\n            loss = loss.mean()\n\n            return loss\n\n        grad_fn = jax.value_and_grad(compute_loss)\n        loss, grad = grad_fn(state.params)\n        grad = jax.lax.pmean(grad, \"batch\")\n\n        new_state = state.apply_gradients(grads=grad)\n\n        metrics = {\"loss\": loss}\n        metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n\n        return new_state, metrics, new_train_rng\n\n    # Create parallel version of the train step\n    p_train_step = jax.pmap(train_step, \"batch\", donate_argnums=(0,))\n\n    # Replicate the train state on each device\n    state = jax_utils.replicate(state)\n    text_encoder_params = jax_utils.replicate(text_encoder.params)\n    vae_params = jax_utils.replicate(vae_params)\n\n    # Train!\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader))\n\n    # Scheduler and math around the number of training steps.\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    global_step = 0\n\n    epochs = tqdm(range(args.num_train_epochs), desc=\"Epoch ... 
\", position=0)\n    for epoch in epochs:\n        # ======================== Training ================================\n\n        train_metrics = []\n\n        steps_per_epoch = len(train_dataset) // total_train_batch_size\n        train_step_progress_bar = tqdm(total=steps_per_epoch, desc=\"Training...\", position=1, leave=False)\n        # train\n        for batch in train_dataloader:\n            batch = shard(batch)\n            state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs)\n            train_metrics.append(train_metric)\n\n            train_step_progress_bar.update(1)\n\n            global_step += 1\n            if global_step >= args.max_train_steps:\n                break\n\n        train_metric = jax_utils.unreplicate(train_metric)\n\n        train_step_progress_bar.close()\n        epochs.write(f\"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})\")\n\n    # Create the pipeline using using the trained modules and save it.\n    if jax.process_index() == 0:\n        scheduler = FlaxPNDMScheduler(\n            beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", skip_prk_steps=True\n        )\n        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(\n            \"CompVis/stable-diffusion-safety-checker\", from_pt=True\n        )\n        pipeline = FlaxStableDiffusionPipeline(\n            text_encoder=text_encoder,\n            vae=vae,\n            unet=unet,\n            tokenizer=tokenizer,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=CLIPImageProcessor.from_pretrained(\"openai/clip-vit-base-patch32\"),\n        )\n\n        pipeline.save_pretrained(\n            args.output_dir,\n            params={\n                \"text_encoder\": get_params_to_save(text_encoder_params),\n                \"vae\": get_params_to_save(vae_params),\n                \"unet\": 
get_params_to_save(state.params),\n                \"safety_checker\": safety_checker.params,\n            },\n        )\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/text_to_image/train_text_to_image_lora.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion for text2image with support for LoRA.\"\"\"\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig\nfrom peft.utils import get_peft_model_state_dict, set_peft_model_state_dict\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import cast_training_params, compute_snr\nfrom diffusers.utils import (\n    check_min_version,\n    convert_state_dict_to_diffusers,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils 
import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    base_model: str = None,\n    dataset_name: str = None,\n    repo_folder: str = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# LoRA text2image fine-tuning - {repo_id}\nThese are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \\n\n{img_str}\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"diffusers-training\",\n        \"lora\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n    generator = torch.Generator(device=accelerator.device)\n    if args.seed is not None:\n        generator = generator.manual_seed(args.seed)\n    images = []\n    if torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    with autocast_ctx:\n        for _ in range(args.num_validation_images):\n            images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n    return images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        
type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\", type=str, default=None, help=\"A prompt that is sampled during training for inference.\"\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        
default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        
gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n    # Load scheduler, tokenizer and models.\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    tokenizer = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"tokenizer\", revision=args.revision\n    )\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = 
AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n    # freeze parameters of models to save more memory\n    unet.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Add adapter and make sure the trainable params are in float32.\n    unet.add_adapter(unet_lora_config)\n    if args.mixed_precision == \"fp16\":\n        # only upcast trainable parameters (LoRA) into fp32\n        cast_training_params(unet, dtype=torch.float32)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for 
training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    lora_layers = filter(lambda p: p.requires_grad, unet.parameters())\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        lora_layers,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        inputs = tokenizer(\n            captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        return inputs.input_ids\n\n    # Get the specified interpolation method from the args\n    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n\n    # Raise an error if 
the interpolation method is invalid\n    if interpolation is None:\n        raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode}.\")\n\n    # Data preprocessing transformations\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=interpolation),  # Use dynamic interpolation method\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = [train_transforms(image) for image in images]\n        examples[\"input_ids\"] = tokenize_captions(examples)\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n        return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        
collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    def save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            unet_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(model, type(unwrap_model(unet))):\n                    unet_lora_layers_to_save = get_peft_model_state_dict(model)\n                else:\n                    raise ValueError(f\"Unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n            StableDiffusionPipeline.save_lora_weights(\n                save_directory=output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                safe_serialization=True,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n            if isinstance(model, type(unwrap_model(unet))):\n                unet_ = model\n            else:\n                raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n        # returns a tuple of state dictionary and network alphas\n        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)\n\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            # throw warning if some unexpected keys are found and continue loading\n            if unexpected_keys:\n         
       logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        # Make sure the trainable params are in float32\n        if args.mixed_precision in [\"fp16\"]:\n            cast_training_params([unet_], dtype=torch.float32)\n\n    # Scheduler and math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # Register the hooks for efficient saving and loading of LoRA weights\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if 
args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = 
vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device\n                    )\n\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"], return_dict=False)[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                # Predict the noise residual and compute loss\n   
             model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = lora_layers\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n          
  # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n       
                 accelerator.save_state(save_path)\n\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = DiffusionPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    unet=unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                images = log_validation(pipeline, args, accelerator, epoch)\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unet.to(torch.float32)\n\n        unwrapped_unet = unwrap_model(unet)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))\n        StableDiffusionPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            safe_serialization=True,\n        )\n\n        # Final inference\n        # Load previous pipeline\n        if args.validation_prompt is not None:\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                revision=args.revision,\n                variant=args.variant,\n                torch_dtype=weight_dtype,\n            )\n\n            # load attention processors\n            pipeline.load_lora_weights(args.output_dir)\n\n            # run inference\n            images = 
log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/text_to_image/train_text_to_image_lora_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA.\"\"\"\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom peft import LoraConfig, set_peft_model_state_dict\nfrom peft.utils import get_peft_model_state_dict\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionXLPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.loaders import StableDiffusionLoraLoaderMixin\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import _set_state_dict_into_text_encoder, 
cast_training_params, compute_snr\nfrom diffusers.utils import (\n    check_min_version,\n    convert_state_dict_to_diffusers,\n    convert_unet_state_dict_to_peft,\n    is_wandb_available,\n)\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\nif is_torch_npu_available():\n    torch.npu.config.allow_internal_format = False\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    base_model: str = None,\n    dataset_name: str = None,\n    train_text_encoder: bool = False,\n    repo_folder: str = None,\n    vae_path: str = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# LoRA text2image fine-tuning - {repo_id}\n\nThese are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. 
\\n\n{img_str}\n\nLoRA for the text encoder was enabled: {train_text_encoder}.\n\nSpecial VAE used for training: {vae_path}.\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"diffusers-training\",\n        \"lora\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n    pipeline_args = {\"prompt\": args.validation_prompt}\n    if torch.backends.mps.is_available():\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            
tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sd-model-finetuned-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n       
 \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--enable_npu_flash_attention\", action=\"store_true\", help=\"Whether or not to use npu flash attention.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--rank\",\n        type=int,\n        default=4,\n        help=(\"The dimension of the LoRA update matrices.\"),\n    )\n    parser.add_argument(\n        \"--debug_loss\",\n        action=\"store_true\",\n        help=\"Log the loss for each image, if filenames are available in the dataset.\",\n    )\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n        
    f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    return args\n\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef tokenize_prompt(tokenizer, prompt):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer.model_max_length,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):\n    prompt_embeds_list = []\n\n    for i, text_encoder in enumerate(text_encoders):\n        if tokenizers is not None:\n            tokenizer = tokenizers[i]\n            text_input_ids = tokenize_prompt(tokenizer, prompt)\n        else:\n            assert text_input_ids_list is not None\n            text_input_ids = text_input_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False\n        )\n\n        # We are only ALWAYS interested in the pooled output of the final text encoder\n        pooled_prompt_embeds = prompt_embeds[0]\n        prompt_embeds = prompt_embeds[-1][-2]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        
datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        
args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # We only train the additional adapter LoRA layers\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    if args.pretrained_vae_model_name_or_path is None:\n        vae.to(accelerator.device, dtype=torch.float32)\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            unet.enable_npu_flash_attention()\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is supported only 
on npu devices.\")\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # now we will add new LoRA weights to the attention layers\n    # Set correct lora layers\n    unet_lora_config = LoraConfig(\n        r=args.rank,\n        lora_alpha=args.rank,\n        init_lora_weights=\"gaussian\",\n        target_modules=[\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"],\n    )\n\n    unet.add_adapter(unet_lora_config)\n\n    # The text encoder comes from 🤗 transformers, we will also attach adapters to it.\n    if args.train_text_encoder:\n        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16\n        text_lora_config = LoraConfig(\n            r=args.rank,\n            lora_alpha=args.rank,\n            init_lora_weights=\"gaussian\",\n            target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"],\n        )\n        text_encoder_one.add_adapter(text_lora_config)\n        text_encoder_two.add_adapter(text_lora_config)\n\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def 
save_model_hook(models, weights, output_dir):\n        if accelerator.is_main_process:\n            # There are only two options here: either just the unet attn processor layers,\n            # or the unet and text encoder attn layers.\n            unet_lora_layers_to_save = None\n            text_encoder_one_lora_layers_to_save = None\n            text_encoder_two_lora_layers_to_save = None\n\n            for model in models:\n                if isinstance(unwrap_model(model), type(unwrap_model(unet))):\n                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))\n                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):\n                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):\n                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(\n                        get_peft_model_state_dict(model)\n                    )\n                else:\n                    raise ValueError(f\"unexpected save model: {model.__class__}\")\n\n                # make sure to pop weight so that corresponding model is not saved again\n                if weights:\n                    weights.pop()\n\n            StableDiffusionXLPipeline.save_lora_weights(\n                output_dir,\n                unet_lora_layers=unet_lora_layers_to_save,\n                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,\n                text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,\n            )\n\n    def load_model_hook(models, input_dir):\n        unet_ = None\n        text_encoder_one_ = None\n        text_encoder_two_ = None\n\n        while len(models) > 0:\n            model = models.pop()\n\n            if 
isinstance(model, type(unwrap_model(unet))):\n                unet_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_one))):\n                text_encoder_one_ = model\n            elif isinstance(model, type(unwrap_model(text_encoder_two))):\n                text_encoder_two_ = model\n            else:\n                raise ValueError(f\"unexpected load model: {model.__class__}\")\n\n        lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)\n        unet_state_dict = {f\"{k.replace('unet.', '')}\": v for k, v in lora_state_dict.items() if k.startswith(\"unet.\")}\n        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)\n        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name=\"default\")\n        if incompatible_keys is not None:\n            # check only for unexpected keys\n            unexpected_keys = getattr(incompatible_keys, \"unexpected_keys\", None)\n            if unexpected_keys:\n                logger.warning(\n                    f\"Loading adapter weights from state_dict led to unexpected keys not found in the model: \"\n                    f\" {unexpected_keys}. \"\n                )\n\n        if args.train_text_encoder:\n            _set_state_dict_into_text_encoder(lora_state_dict, prefix=\"text_encoder.\", text_encoder=text_encoder_one_)\n\n            _set_state_dict_into_text_encoder(\n                lora_state_dict, prefix=\"text_encoder_2.\", text_encoder=text_encoder_two_\n            )\n\n        # Make sure the trainable params are in float32. This is again needed since the base models\n        # are in `weight_dtype`. 
More details:\n        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804\n        if args.mixed_precision == \"fp16\":\n            models = [unet_]\n            if args.train_text_encoder:\n                models.extend([text_encoder_one_, text_encoder_two_])\n            cast_training_params(models, dtype=torch.float32)\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder_one.gradient_checkpointing_enable()\n            text_encoder_two.gradient_checkpointing_enable()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [unet]\n        if args.train_text_encoder:\n            models.extend([text_encoder_one, text_encoder_two])\n        cast_training_params(models, dtype=torch.float32)\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))\n 
   if args.train_text_encoder:\n        params_to_optimize = (\n            params_to_optimize\n            + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))\n            + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))\n        )\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. 
Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    # We need to tokenize input captions and transform the images.\n    def tokenize_captions(examples, is_train=True):\n        captions = []\n        for caption in examples[caption_column]:\n            if isinstance(caption, str):\n                captions.append(caption)\n            elif isinstance(caption, (list, np.ndarray)):\n                # take a random caption if there are multiple\n                captions.append(random.choice(caption) if is_train else caption[0])\n            else:\n                raise ValueError(\n                    f\"Caption column `{caption_column}` should contain either strings or lists of strings.\"\n                )\n        tokens_one = tokenize_prompt(tokenizer_one, captions)\n        tokens_two = tokenize_prompt(tokenizer_two, captions)\n        return tokens_one, tokens_two\n\n    # Get the specified interpolation method from the args\n    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)\n\n    # Raise an error if the interpolation method is 
invalid\n    if interpolation is None:\n        raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode}.\")\n    # Preprocessing the datasets.\n    train_resize = transforms.Resize(args.resolution, interpolation=interpolation)  # Use dynamic interpolation method\n    train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    train_transforms = transforms.Compose(\n        [\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        # image aug\n        original_sizes = []\n        all_images = []\n        crop_top_lefts = []\n        for image in images:\n            original_sizes.append((image.height, image.width))\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            crop_top_left = (y1, x1)\n            crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            all_images.append(image)\n\n        examples[\"original_sizes\"] = original_sizes\n        examples[\"crop_top_lefts\"] = crop_top_lefts\n        examples[\"pixel_values\"] = all_images\n        tokens_one, tokens_two = tokenize_captions(examples)\n        examples[\"input_ids_one\"] = tokens_one\n        examples[\"input_ids_two\"] = 
tokens_two\n        if args.debug_loss:\n            fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]\n            if fnames:\n                examples[\"filenames\"] = fnames\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train, output_all_columns=True)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        original_sizes = [example[\"original_sizes\"] for example in examples]\n        crop_top_lefts = [example[\"crop_top_lefts\"] for example in examples]\n        input_ids_one = torch.stack([example[\"input_ids_one\"] for example in examples])\n        input_ids_two = torch.stack([example[\"input_ids_two\"] for example in examples])\n        result = {\n            \"pixel_values\": pixel_values,\n            \"input_ids_one\": input_ids_one,\n            \"input_ids_two\": input_ids_two,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n        filenames = [example[\"filenames\"] for example in examples if \"filenames\" in example]\n        if filenames:\n            result[\"filenames\"] = filenames\n        return result\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n   
 logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder_one.train()\n            text_encoder_two.train()\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                if args.pretrained_vae_model_name_or_path is not None:\n                    pixel_values = batch[\"pixel_values\"].to(dtype=weight_dtype)\n                else:\n                    pixel_values = batch[\"pixel_values\"]\n\n                model_input = vae.encode(pixel_values).latent_dist.sample()\n                model_input = model_input * vae.config.scaling_factor\n                if args.pretrained_vae_model_name_or_path is None:\n                    model_input = model_input.to(weight_dtype)\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(model_input)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device\n                    )\n\n                bsz = model_input.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                )\n                timesteps = timesteps.long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)\n\n                # time ids\n                def compute_time_ids(original_size, crops_coords_top_left):\n                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n                    target_size = (args.resolution, args.resolution)\n                    add_time_ids = list(original_size + crops_coords_top_left + target_size)\n                    add_time_ids = torch.tensor([add_time_ids])\n                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n                    return add_time_ids\n\n                add_time_ids = torch.cat(\n                    [compute_time_ids(s, c) for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])]\n                )\n\n                # Predict the noise residual\n                unet_added_conditions = {\"time_ids\": add_time_ids}\n                prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                    text_encoders=[text_encoder_one, text_encoder_two],\n                    tokenizers=None,\n                    prompt=None,\n                    text_input_ids_list=[batch[\"input_ids_one\"], batch[\"input_ids_two\"]],\n                )\n                unet_added_conditions.update({\"text_embeds\": pooled_prompt_embeds})\n                model_pred = unet(\n                    
noisy_model_input,\n                    timesteps,\n                    prompt_embeds,\n                    added_cond_kwargs=unet_added_conditions,\n                    return_dict=False,\n                )[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), 
target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n                if args.debug_loss and \"filenames\" in batch:\n                    for fname in batch[\"filenames\"]:\n                        accelerator.log({\"loss_for_\" + fname: loss}, step=global_step)\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                   
         checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                pipeline = StableDiffusionXLPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=unwrap_model(text_encoder_one),\n                    text_encoder_2=unwrap_model(text_encoder_two),\n       
             unet=unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n\n                images = log_validation(pipeline, args, accelerator, epoch)\n\n                del pipeline\n                torch.cuda.empty_cache()\n\n    # Save the lora layers\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))\n\n        if args.train_text_encoder:\n            text_encoder_one = unwrap_model(text_encoder_one)\n            text_encoder_two = unwrap_model(text_encoder_two)\n\n            text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))\n            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))\n        else:\n            text_encoder_lora_layers = None\n            text_encoder_2_lora_layers = None\n\n        StableDiffusionXLPipeline.save_lora_weights(\n            save_directory=args.output_dir,\n            unet_lora_layers=unet_lora_state_dict,\n            text_encoder_lora_layers=text_encoder_lora_layers,\n            text_encoder_2_lora_layers=text_encoder_2_lora_layers,\n        )\n\n        del unet\n        del text_encoder_one\n        del text_encoder_two\n        del text_encoder_lora_layers\n        del text_encoder_2_lora_layers\n        torch.cuda.empty_cache()\n\n        # Final inference\n        # Make sure vae.dtype is consistent with the unet.dtype\n        if args.mixed_precision == \"fp16\":\n            vae.to(weight_dtype)\n        # Load previous pipeline\n        pipeline = StableDiffusionXLPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            vae=vae,\n            revision=args.revision,\n            variant=args.variant,\n      
      torch_dtype=weight_dtype,\n        )\n\n        # load attention processors\n        pipeline.load_lora_weights(args.output_dir)\n\n        # run inference\n        if args.validation_prompt and args.num_validation_images > 0:\n            images = log_validation(pipeline, args, accelerator, epoch, is_final_validation=True)\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                train_text_encoder=args.train_text_encoder,\n                repo_folder=args.output_dir,\n                vae_path=args.pretrained_vae_model_name_or_path,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/text_to_image/train_text_to_image_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Fine-tuning script for Stable Diffusion XL for text2image.\"\"\"\n\nimport argparse\nimport functools\nimport gc\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom datasets import concatenate_datasets, load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel, compute_snr\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import 
is_torch_npu_available, is_xformers_available\nfrom diffusers.utils.torch_utils import is_compiled_module\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\nif is_torch_npu_available():\n    import torch_npu\n\n    torch.npu.config.allow_internal_format = False\n\nDATASET_NAME_MAPPING = {\n    \"lambdalabs/naruto-blip-captions\": (\"image\", \"text\"),\n}\n\n\ndef save_model_card(\n    repo_id: str,\n    images: list = None,\n    validation_prompt: str = None,\n    base_model: str = None,\n    dataset_name: str = None,\n    repo_folder: str = None,\n    vae_path: str = None,\n):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# Text-to-image finetuning - {repo_id}\n\nThis pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. 
Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \\n\n{img_str}\n\nSpecial VAE used for training: {vae_path}.\n\"\"\"\n\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers-training\",\n        \"diffusers\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"CLIPTextModelWithProjection\":\n        from transformers import CLIPTextModelWithProjection\n\n        return CLIPTextModelWithProjection\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--pretrained_vae_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained 
VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run fine-tuning validation every X epochs. The validation process consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"sdxl-model-finetuned\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--timestep_bias_strategy\",\n        type=str,\n        default=\"none\",\n        choices=[\"earlier\", \"later\", \"range\", \"none\"],\n        help=(\n            \"The timestep bias strategy, which may help direct the model toward learning low or high frequency details.\"\n            \" Choices: ['earlier', 'later', 'range', 'none'].\"\n            \" The default is 'none', which means no bias is applied, and training proceeds normally.\"\n            \" The value of 'later' will increase the frequency of the model's final training timesteps.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_multiplier\",\n        type=float,\n        default=1.0,\n        help=(\n            \"The multiplier for the bias. 
Defaults to 1.0, which means no bias is applied.\"\n            \" A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_begin\",\n        type=int,\n        default=0,\n        help=(\n            \"When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias.\"\n            \" Defaults to zero, which equates to having no specific bias.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_end\",\n        type=int,\n        default=1000,\n        help=(\n            \"When using `--timestep_bias_strategy=range`, the final timestep (exclusive) to bias.\"\n            \" Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on.\"\n        ),\n    )\n    parser.add_argument(\n        \"--timestep_bias_portion\",\n        type=float,\n        default=0.25,\n        help=(\n            \"The portion of timesteps to bias. Defaults to 0.25, which means 25% of timesteps will be biased.\"\n            \" A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines\"\n            \" whether the biased portions are in the earlier or later timesteps.\"\n        ),\n    )\n    parser.add_argument(\n        \"--snr_gamma\",\n        type=float,\n        default=None,\n        help=\"SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. \"\n        \"More details here: https://huggingface.co/papers/2303.09556.\",\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. 
If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_npu_flash_attention\", action=\"store_true\", help=\"Whether or not to use npu flash attention.\"\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\"--noise_offset\", type=float, default=0, help=\"The scale of noise offset.\")\n    parser.add_argument(\n        \"--image_interpolation_mode\",\n        type=str,\n        default=\"lanczos\",\n        choices=[\n            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith(\"__\") and not f.endswith(\"__\")\n        ],\n        help=\"The image interpolation method to use for resizing images.\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n    if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:\n        raise ValueError(\"`--proportion_empty_prompts` must be in the range [0, 1].\")\n\n    return args\n\n\n# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt\ndef encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):\n    prompt_embeds_list = []\n    prompt_batch = batch[caption_column]\n\n    captions = []\n    for caption in prompt_batch:\n        if random.random() < proportion_empty_prompts:\n            
captions.append(\"\")\n        elif isinstance(caption, str):\n            captions.append(caption)\n        elif isinstance(caption, (list, np.ndarray)):\n            # take a random caption if there are multiple\n            captions.append(random.choice(caption) if is_train else caption[0])\n\n    with torch.no_grad():\n        for tokenizer, text_encoder in zip(tokenizers, text_encoders):\n            text_inputs = tokenizer(\n                captions,\n                padding=\"max_length\",\n                max_length=tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            prompt_embeds = text_encoder(\n                text_input_ids.to(text_encoder.device),\n                output_hidden_states=True,\n                return_dict=False,\n            )\n\n            # We are always interested only in the pooled output of the final text encoder\n            pooled_prompt_embeds = prompt_embeds[0]\n            prompt_embeds = prompt_embeds[-1][-2]\n            bs_embed, seq_len, _ = prompt_embeds.shape\n            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n            prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return {\"prompt_embeds\": prompt_embeds.cpu(), \"pooled_prompt_embeds\": pooled_prompt_embeds.cpu()}\n\n\ndef compute_vae_encodings(batch, vae):\n    images = batch.pop(\"pixel_values\")\n    pixel_values = torch.stack(list(images))\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n    pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)\n\n    with torch.no_grad():\n        model_input = vae.encode(pixel_values).latent_dist.sample()\n    model_input = model_input * vae.config.scaling_factor\n\n    # There might be a slight performance 
improvement\n    # by changing model_input.cpu() to accelerator.gather(model_input)\n    return {\"model_input\": model_input.cpu()}\n\n\ndef generate_timestep_weights(args, num_timesteps):\n    weights = torch.ones(num_timesteps)\n\n    # Determine the indices to bias\n    num_to_bias = int(args.timestep_bias_portion * num_timesteps)\n\n    if args.timestep_bias_strategy == \"later\":\n        bias_indices = slice(-num_to_bias, None)\n    elif args.timestep_bias_strategy == \"earlier\":\n        bias_indices = slice(0, num_to_bias)\n    elif args.timestep_bias_strategy == \"range\":\n        # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.\n        range_begin = args.timestep_bias_begin\n        range_end = args.timestep_bias_end\n        if range_begin < 0:\n            raise ValueError(\n                \"When using the range strategy for timestep bias, you must provide a beginning timestep greater than or equal to zero.\"\n            )\n        if range_end > num_timesteps:\n            raise ValueError(\n                \"When using the range strategy for timestep bias, you must provide an ending timestep no greater than the number of timesteps.\"\n            )\n        bias_indices = slice(range_begin, range_end)\n    else:  # 'none' or any other string\n        return weights\n    if args.timestep_bias_multiplier <= 0:\n        raise ValueError(\n            \"The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps.\"\n            \" If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead.\"\n            \" A timestep bias multiplier less than or equal to 0 is not allowed.\"\n        )\n\n    # Apply the bias\n    weights[bias_indices] *= args.timestep_bias_multiplier\n\n    # Normalize\n    weights /= weights.sum()\n\n    return weights\n\n\ndef main(args):\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise 
ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed 
along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load the tokenizers\n    tokenizer_one = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n    tokenizer_two = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    # Check for terminal SNR in combination with SNR Gamma\n    text_encoder_one = text_encoder_cls_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = text_encoder_cls_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    vae_path = (\n        args.pretrained_model_name_or_path\n        if args.pretrained_vae_model_name_or_path is None\n        else 
args.pretrained_vae_model_name_or_path\n    )\n    vae = AutoencoderKL.from_pretrained(\n        vae_path,\n        subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n        revision=args.revision,\n        variant=args.variant,\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # Freeze vae and text encoders.\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n    # Set unet as trainable.\n    unet.train()\n\n    # For mixed precision training we cast all non-trainable weights to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move unet, vae and text_encoder to device and cast to weight_dtype\n    # The VAE is in float32 to avoid NaN losses.\n    vae.to(accelerator.device, dtype=torch.float32)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n\n    # Create EMA for the unet.\n    if args.use_ema:\n        ema_unet = UNet2DConditionModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n        )\n        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)\n    if args.enable_npu_flash_attention:\n        if is_torch_npu_available():\n            logger.info(\"npu flash attention enabled.\")\n            unet.enable_npu_flash_attention()\n        else:\n            raise ValueError(\"npu flash attention requires torch_npu extensions and is 
supported only on npu devices.\")\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_unet.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    if weights:\n                        weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DConditionModel)\n                ema_unet.load_state_dict(load_model.state_dict())\n                ema_unet.to(accelerator.device)\n                del load_model\n\n            for _ in range(len(models)):\n                # pop models so that they 
are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = unet.parameters()\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    # In 
distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # We need to tokenize inputs and targets.\n    column_names = dataset[\"train\"].column_names\n\n    # 6. Get the column names for input/target.\n    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)\n    if args.image_column is None:\n        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]\n    else:\n        image_column = args.image_column\n        if image_column not in column_names:\n            raise ValueError(\n                f\"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n    if args.caption_column is None:\n        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]\n    else:\n        caption_column = args.caption_column\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}\"\n            )\n\n    # Preprocessing the datasets.\n    interpolation = getattr(transforms.InterpolationMode, 
args.image_interpolation_mode.upper(), None)\n    if interpolation is None:\n        raise ValueError(f\"Unsupported interpolation mode {args.image_interpolation_mode=}.\")\n    train_resize = transforms.Resize(args.resolution, interpolation=interpolation)\n    train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)\n    train_flip = transforms.RandomHorizontalFlip(p=1.0)\n    train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        # image aug\n        original_sizes = []\n        all_images = []\n        crop_top_lefts = []\n        for image in images:\n            original_sizes.append((image.height, image.width))\n            image = train_resize(image)\n            if args.random_flip and random.random() < 0.5:\n                # flip\n                image = train_flip(image)\n            if args.center_crop:\n                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))\n                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))\n                image = train_crop(image)\n            else:\n                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))\n                image = crop(image, y1, x1, h, w)\n            crop_top_left = (y1, x1)\n            crop_top_lefts.append(crop_top_left)\n            image = train_transforms(image)\n            all_images.append(image)\n\n        examples[\"original_sizes\"] = original_sizes\n        examples[\"crop_top_lefts\"] = crop_top_lefts\n        examples[\"pixel_values\"] = all_images\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        # Set the 
training transforms\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    # Let's first compute all the embeddings so that we can free up the text encoders\n    # from memory. We will pre-compute the VAE encodings too.\n    text_encoders = [text_encoder_one, text_encoder_two]\n    tokenizers = [tokenizer_one, tokenizer_two]\n    compute_embeddings_fn = functools.partial(\n        encode_prompt,\n        text_encoders=text_encoders,\n        tokenizers=tokenizers,\n        proportion_empty_prompts=args.proportion_empty_prompts,\n        caption_column=args.caption_column,\n    )\n    compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)\n    with accelerator.main_process_first():\n        from datasets.fingerprint import Hasher\n\n        # fingerprint used by the cache for the other processes to load the result\n        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401\n        new_fingerprint = Hasher.hash(args)\n        new_fingerprint_for_vae = Hasher.hash((vae_path, args))\n        train_dataset_with_embeddings = train_dataset.map(\n            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint\n        )\n        train_dataset_with_vae = train_dataset.map(\n            compute_vae_encodings_fn,\n            batched=True,\n            batch_size=args.train_batch_size,\n            new_fingerprint=new_fingerprint_for_vae,\n        )\n        precomputed_dataset = concatenate_datasets(\n            [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns([\"image\", \"text\"])], axis=1\n        )\n        precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)\n\n    del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two\n    del text_encoders, tokenizers, vae\n    gc.collect()\n    if is_torch_npu_available():\n        torch_npu.npu.empty_cache()\n    elif torch.cuda.is_available():\n        
torch.cuda.empty_cache()\n\n    def collate_fn(examples):\n        model_input = torch.stack([torch.tensor(example[\"model_input\"]) for example in examples])\n        original_sizes = [example[\"original_sizes\"] for example in examples]\n        crop_top_lefts = [example[\"crop_top_lefts\"] for example in examples]\n        prompt_embeds = torch.stack([torch.tensor(example[\"prompt_embeds\"]) for example in examples])\n        pooled_prompt_embeds = torch.stack([torch.tensor(example[\"pooled_prompt_embeds\"]) for example in examples])\n\n        return {\n            \"model_input\": model_input,\n            \"prompt_embeds\": prompt_embeds,\n            \"pooled_prompt_embeds\": pooled_prompt_embeds,\n            \"original_sizes\": original_sizes,\n            \"crop_top_lefts\": crop_top_lefts,\n        }\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        precomputed_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n    )\n\n    # Prepare everything with our `accelerator`.\n    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        unet, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_unet.to(accelerator.device)\n\n    # We 
need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"text2image-fine-tune-sdxl\", config=vars(args))\n\n    # Function for unwrapping if torch.compile() was used in accelerate.\n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    if torch.backends.mps.is_available() or \"playground\" in args.pretrained_model_name_or_path:\n        autocast_ctx = nullcontext()\n    else:\n        autocast_ctx = torch.autocast(accelerator.device.type)\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(precomputed_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        train_loss = 0.0\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(unet):\n                # Sample noise that we'll add to the latents\n                model_input = 
batch[\"model_input\"].to(accelerator.device)\n                noise = torch.randn_like(model_input)\n                if args.noise_offset:\n                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise\n                    noise += args.noise_offset * torch.randn(\n                        (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device\n                    )\n\n                bsz = model_input.shape[0]\n                if args.timestep_bias_strategy == \"none\":\n                    # Sample a random timestep for each image without bias.\n                    timesteps = torch.randint(\n                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device\n                    )\n                else:\n                    # Sample a random timestep for each image, potentially biased by the timestep weights.\n                    # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.\n                    weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(\n                        model_input.device\n                    )\n                    timesteps = torch.multinomial(weights, bsz, replacement=True).long()\n\n                # Add noise to the model input according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)\n\n                # time ids\n                def compute_time_ids(original_size, crops_coords_top_left):\n                    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n                    target_size = (args.resolution, args.resolution)\n                    add_time_ids = list(original_size + crops_coords_top_left + target_size)\n                    add_time_ids = torch.tensor([add_time_ids], 
device=accelerator.device, dtype=weight_dtype)\n                    return add_time_ids\n\n                add_time_ids = torch.cat(\n                    [compute_time_ids(s, c) for s, c in zip(batch[\"original_sizes\"], batch[\"crop_top_lefts\"])]\n                )\n\n                # Predict the noise residual\n                unet_added_conditions = {\"time_ids\": add_time_ids}\n                prompt_embeds = batch[\"prompt_embeds\"].to(accelerator.device, dtype=weight_dtype)\n                pooled_prompt_embeds = batch[\"pooled_prompt_embeds\"].to(accelerator.device)\n                unet_added_conditions.update({\"text_embeds\": pooled_prompt_embeds})\n                model_pred = unet(\n                    noisy_model_input,\n                    timesteps,\n                    prompt_embeds,\n                    added_cond_kwargs=unet_added_conditions,\n                    return_dict=False,\n                )[0]\n\n                # Get the target for loss depending on the prediction type\n                if args.prediction_type is not None:\n                    # set prediction_type of scheduler if defined\n                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)\n\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(model_input, noise, timesteps)\n                elif noise_scheduler.config.prediction_type == \"sample\":\n                    # We set the target to latents here, but the model_pred will return the noise sample prediction.\n                    target = model_input\n                    # We will have to subtract the noise residual from the prediction to get the target sample.\n                    model_pred = model_pred - noise\n                else:\n                    raise ValueError(f\"Unknown prediction 
type {noise_scheduler.config.prediction_type}\")\n\n                if args.snr_gamma is None:\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n                else:\n                    # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.\n                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n                    # This is discussed in Section 4.2 of the same paper.\n                    snr = compute_snr(noise_scheduler, timesteps)\n                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(\n                        dim=1\n                    )[0]\n                    if noise_scheduler.config.prediction_type == \"epsilon\":\n                        mse_loss_weights = mse_loss_weights / snr\n                    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                        mse_loss_weights = mse_loss_weights / (snr + 1)\n\n                    loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights\n                    loss = loss.mean()\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                train_loss += avg_loss.item() / args.gradient_accumulation_steps\n\n                # Backpropagate\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = unet.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an 
optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_unet.step(unet.parameters())\n                progress_bar.update(1)\n                global_step += 1\n                accelerator.log({\"train_loss\": train_loss}, step=global_step)\n                train_loss = 0.0\n\n                # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.\n                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint 
= os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                logger.info(\n                    f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n                    f\" {args.validation_prompt}.\"\n                )\n                if args.use_ema:\n                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.\n                    ema_unet.store(unet.parameters())\n                    ema_unet.copy_to(unet.parameters())\n\n                # create pipeline\n                vae = AutoencoderKL.from_pretrained(\n                    vae_path,\n                    subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n                    revision=args.revision,\n                    variant=args.variant,\n                )\n                pipeline = StableDiffusionXLPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    unet=accelerator.unwrap_model(unet),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                if args.prediction_type is not None:\n                    scheduler_args = {\"prediction_type\": 
args.prediction_type}\n                    pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n\n                pipeline = pipeline.to(accelerator.device)\n                pipeline.set_progress_bar_config(disable=True)\n\n                # run inference\n                generator = (\n                    torch.Generator(device=accelerator.device).manual_seed(args.seed)\n                    if args.seed is not None\n                    else None\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n\n                with autocast_ctx:\n                    images = [\n                        pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]\n                        for _ in range(args.num_validation_images)\n                    ]\n\n                for tracker in accelerator.trackers:\n                    if tracker.name == \"tensorboard\":\n                        np_images = np.stack([np.asarray(img) for img in images])\n                        tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n                    if tracker.name == \"wandb\":\n                        tracker.log(\n                            {\n                                \"validation\": [\n                                    wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                    for i, image in enumerate(images)\n                                ]\n                            }\n                        )\n\n                del pipeline\n                if is_torch_npu_available():\n                    torch_npu.npu.empty_cache()\n                elif torch.cuda.is_available():\n                    torch.cuda.empty_cache()\n\n                if args.use_ema:\n                    # Switch back to the original UNet parameters.\n                    ema_unet.restore(unet.parameters())\n\n    
accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        unet = unwrap_model(unet)\n        if args.use_ema:\n            ema_unet.copy_to(unet.parameters())\n\n        # Serialize pipeline.\n        vae = AutoencoderKL.from_pretrained(\n            vae_path,\n            subfolder=\"vae\" if args.pretrained_vae_model_name_or_path is None else None,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        pipeline = StableDiffusionXLPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=unet,\n            vae=vae,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        if args.prediction_type is not None:\n            scheduler_args = {\"prediction_type\": args.prediction_type}\n            pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)\n        pipeline.save_pretrained(args.output_dir)\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline = pipeline.to(accelerator.device)\n            generator = (\n                torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None\n            )\n\n            with autocast_ctx:\n                images = [\n                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n                    for _ in range(args.num_validation_images)\n                ]\n\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\"test\", np_images, epoch, dataformats=\"NHWC\")\n                if tracker.name == \"wandb\":\n                    
tracker.log(\n                        {\n                            \"test\": [\n                                wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\")\n                                for i, image in enumerate(images)\n                            ]\n                        }\n                    )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id=repo_id,\n                images=images,\n                validation_prompt=args.validation_prompt,\n                base_model=args.pretrained_model_name_or_path,\n                dataset_name=args.dataset_name,\n                repo_folder=args.output_dir,\n                vae_path=args.pretrained_vae_model_name_or_path,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/textual_inversion/README.md",
    "content": "## Textual Inversion fine-tuning example\n\n[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.\nThe `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.\n\n## Running on Colab\n\nColab for training\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)\n\nColab for inference\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)\n\n## Running locally with PyTorch\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd in the example folder and run:\n```bash\npip install -r requirements.txt\n```\n\nAnd initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n### Cat toy example\n\nFirst, let's login so that we can upload the checkpoint to the Hub during training:\n\n```bash\nhf auth login\n```\n\nNow let's get our dataset. 
For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .\n\nLet's first download it locally:\n\n```py\nfrom huggingface_hub import snapshot_download\n\nlocal_dir = \"./cat\"\nsnapshot_download(\"diffusers/cat_toy_example\", local_dir=local_dir, repo_type=\"dataset\", ignore_patterns=\".gitattributes\")\n```\n\nThis will be our training data.\nNow we can launch the training using:\n\n**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**\n\n**___Note: Please follow the [README_sdxl.md](./README_sdxl.md) if you are using the [stable-diffusion-xl](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).___**\n\n```bash\nexport MODEL_NAME=\"stable-diffusion-v1-5/stable-diffusion-v1-5\"\nexport DATA_DIR=\"./cat\"\n\naccelerate launch textual_inversion.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" \\\n  --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 \\\n  --scale_lr \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --push_to_hub \\\n  --output_dir=\"textual_inversion_cat\"\n```\n\nA full training run takes ~1 hour on one V100 GPU.\n\n**Note**: As described in [the official paper](https://huggingface.co/papers/2208.01618)\nonly one embedding vector is used for the placeholder token, *e.g.* `\"<cat-toy>\"`.\nHowever, one can also add multiple embedding vectors for the placeholder token\nto increase the number of fine-tuneable parameters. This can help the model to learn\nmore complex details. 
To use multiple embedding vectors, set `--num_vectors`\nto a number larger than one, *e.g.*:\n```bash\n--num_vectors 5\n```\n\nThe saved textual inversion vectors will then be larger in size compared to the default case.\n\n### Inference\n\nOnce you have trained a model using the above command, you can run inference with the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.\n\n```python\nfrom diffusers import StableDiffusionPipeline\nimport torch\n\nmodel_id = \"path-to-your-trained-model\"\npipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\"cuda\")\n\nrepo_id_embeds = \"path-to-your-learned-embeds\"\npipe.load_textual_inversion(repo_id_embeds)\n\nprompt = \"A <cat-toy> backpack\"\n\nimage = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\n\nimage.save(\"cat-backpack.png\")\n```\n\n\n## Training with Flax/JAX\n\nFor faster training on TPUs and GPUs you can leverage the Flax training example. 
Follow the instructions above to get the model and dataset before running the script.\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n```bash\npip install -U -r requirements_flax.txt\n```\n\n```bash\nexport MODEL_NAME=\"duongna/stable-diffusion-v1-4-flax\"\nexport DATA_DIR=\"path-to-dir-containing-images\"\n\npython textual_inversion_flax.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" \\\n  --initializer_token=\"toy\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --max_train_steps=3000 \\\n  --learning_rate=5.0e-04 \\\n  --scale_lr \\\n  --output_dir=\"textual_inversion_cat\"\n```\nIt should be at least 70% faster than the PyTorch script with the same configuration.\n\n### Training with xformers:\nYou can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.\n"
  },
  {
    "path": "examples/textual_inversion/README_sdxl.md",
    "content": "## Textual Inversion fine-tuning example for SDXL\n\n```sh\nexport MODEL_NAME=\"stabilityai/stable-diffusion-xl-base-1.0\"\nexport DATA_DIR=\"./cat\"\n\naccelerate launch textual_inversion_sdxl.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_data_dir=$DATA_DIR \\\n  --learnable_property=\"object\" \\\n  --placeholder_token=\"<cat-toy>\" \\\n  --initializer_token=\"toy\" \\\n  --mixed_precision=\"bf16\" \\\n  --resolution=768 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=4 \\\n  --max_train_steps=500 \\\n  --learning_rate=5.0e-04 \\\n  --scale_lr \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --save_as_full_pipeline \\\n  --output_dir=\"./textual_inversion_cat_sdxl\"\n```\n\nTraining of both text encoders is supported.\n\n### Inference Example\n\nOnce you have trained a model using above command, the inference can be done simply using the `StableDiffusionXLPipeline`.\nMake sure to include the `placeholder_token` in your prompt.\n\n```python\nfrom diffusers import StableDiffusionXLPipeline\nimport torch\n\nmodel_id = \"./textual_inversion_cat_sdxl\"\npipe = StableDiffusionXLPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to(\"cuda\")\n\nprompt = \"A <cat-toy> backpack\"\n\nimage = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\nimage.save(\"cat-backpack.png\")\n\nimage = pipe(prompt=\"\", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]\nimage.save(\"cat-backpack-prompt_2.png\")\n```\n"
  },
  {
    "path": "examples/textual_inversion/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/textual_inversion/requirements_flax.txt",
    "content": "transformers>=4.25.1\nflax\noptax\ntorch\ntorchvision\nftfy\ntensorboard\nJinja2\n"
  },
  {
    "path": "examples/textual_inversion/test_textual_inversion.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass TextualInversion(ExamplesTestsAccelerate):\n    def test_textual_inversion(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/textual_inversion/textual_inversion.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --train_data_dir docs/source/en/imgs\n                --learnable_property object\n                --placeholder_token <cat-toy>\n                --initializer_token a\n                --save_steps 1\n                --num_vectors 2\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n          
  # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"learned_embeds.safetensors\")))\n\n    def test_textual_inversion_checkpointing(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/textual_inversion/textual_inversion.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --train_data_dir docs/source/en/imgs\n                --learnable_property object\n                --placeholder_token <cat-toy>\n                --initializer_token a\n                --save_steps 1\n                --num_vectors 2\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 3\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=1\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-3\"},\n            )\n\n    def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/textual_inversion/textual_inversion.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --train_data_dir docs/source/en/imgs\n                --learnable_property object\n                --placeholder_token <cat-toy>\n                --initializer_token a\n                
--save_steps 1\n                --num_vectors 2\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=1\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-1\", \"checkpoint-2\"},\n            )\n\n            resume_run_args = f\"\"\"\n                examples/textual_inversion/textual_inversion.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe\n                --train_data_dir docs/source/en/imgs\n                --learnable_property object\n                --placeholder_token <cat-toy>\n                --initializer_token a\n                --save_steps 1\n                --num_vectors 2\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=1\n                --resume_from_checkpoint=checkpoint-2\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-3\"},\n        
    )\n"
  },
  {
    "path": "examples/textual_inversion/test_textual_inversion_sdxl.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass TextualInversionSdxl(ExamplesTestsAccelerate):\n    def test_textual_inversion_sdxl(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/textual_inversion/textual_inversion_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe\n                --train_data_dir docs/source/en/imgs\n                --learnable_property object\n                --placeholder_token <cat-toy>\n                --initializer_token a\n                --save_steps 1\n                --num_vectors 2\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n        
    # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"learned_embeds.safetensors\")))\n\n    def test_textual_inversion_sdxl_checkpointing(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/textual_inversion/textual_inversion_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe\n                --train_data_dir docs/source/en/imgs\n                --learnable_property object\n                --placeholder_token <cat-toy>\n                --initializer_token a\n                --save_steps 1\n                --num_vectors 2\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 3\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=1\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-3\"},\n            )\n\n    def test_textual_inversion_sdxl_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/textual_inversion/textual_inversion_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe\n                --train_data_dir docs/source/en/imgs\n                --learnable_property object\n                --placeholder_token <cat-toy>\n                --initializer_token a\n                --save_steps 
1\n                --num_vectors 2\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=1\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-1\", \"checkpoint-2\"},\n            )\n\n            resume_run_args = f\"\"\"\n                examples/textual_inversion/textual_inversion_sdxl.py\n                --pretrained_model_name_or_path hf-internal-testing/tiny-sdxl-pipe\n                --train_data_dir docs/source/en/imgs\n                --learnable_property object\n                --placeholder_token <cat-toy>\n                --initializer_token a\n                --save_steps 1\n                --num_vectors 2\n                --resolution 64\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --output_dir {tmpdir}\n                --checkpointing_steps=1\n                --resume_from_checkpoint=checkpoint-2\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-3\"},\n            )\n"
  },
  {
    "path": "examples/textual_inversion/textual_inversion.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nimport warnings\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport PIL\nimport safetensors\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    StableDiffusionPipeline,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif 
is_wandb_available():\n    import wandb\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None):\n    img_str = \"\"\n    if images is not None:\n        for i, image in enumerate(images):\n            image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n            img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n    model_description = f\"\"\"\n# Textual inversion text2image fine-tuning - {repo_id}\nThese are textual inversion adaption weights for {base_model}. You can find some example images in the following. 
\\n\n{img_str}\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion\",\n        \"stable-diffusion-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"textual_inversion\",\n        \"diffusers-training\",\n    ]\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    # create pipeline (note: unet and vae are loaded again in float32)\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        text_encoder=accelerator.unwrap_model(text_encoder),\n        tokenizer=tokenizer,\n        unet=unet,\n        vae=vae,\n        safety_checker=None,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n    images = []\n    for _ in range(args.num_validation_images):\n        if torch.backends.mps.is_available():\n            autocast_ctx = nullcontext()\n        else:\n            autocast_ctx = torch.autocast(accelerator.device.type)\n\n        with autocast_ctx:\n            image = pipeline(args.validation_prompt, num_inference_steps=25, 
generator=generator).images[0]\n        images.append(image)\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n    return images\n\n\ndef save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):\n    logger.info(\"Saving embeddings\")\n    learned_embeds = (\n        accelerator.unwrap_model(text_encoder)\n        .get_input_embeddings()\n        .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]\n    )\n    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}\n\n    if safe_serialization:\n        safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={\"format\": \"pt\"})\n    else:\n        torch.save(learned_embeds_dict, save_path)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n        \"--save_as_full_pipeline\",\n        action=\"store_true\",\n        help=\"Save the complete stable diffusion pipeline.\",\n    )\n    parser.add_argument(\n        \"--num_vectors\",\n        type=int,\n        default=1,\n        help=\"How many textual inversion vectors shall be used to learn the concept.\",\n    )\n    parser.add_argument(\n        
\"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\", type=str, default=None, required=True, help=\"A folder containing the training data.\"\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n       
 help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU or Intel Gen 4 Xeon (and later).\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=None,\n        help=(\n            \"Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--no_safe_serialization\",\n        action=\"store_true\",\n        help=\"If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.\",\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.train_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean 
painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else 
imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (\n                h,\n                w,\n            ) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = 
ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load tokenizer\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # 
Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n    unet = UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # Add the placeholder token to the tokenizer\n    placeholder_tokens = [args.placeholder_token]\n\n    if args.num_vectors < 1:\n        raise ValueError(f\"--num_vectors has to be greater than or equal to 1, but is {args.num_vectors}\")\n\n    # add dummy tokens for multi-vector\n    additional_tokens = []\n    for i in range(1, args.num_vectors):\n        additional_tokens.append(f\"{args.placeholder_token}_{i}\")\n    placeholder_tokens += additional_tokens\n\n    num_added_tokens = tokenizer.add_tokens(placeholder_tokens)\n    if num_added_tokens != args.num_vectors:\n        raise ValueError(\n            f\"The tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n\n    # Convert the initializer_token, placeholder_token to ids\n    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)\n    # Check if initializer_token is a single token or a sequence of tokens\n    if len(token_ids) > 1:\n        raise ValueError(\"The initializer token must be a single token.\")\n\n    initializer_token_id = token_ids[0]\n    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)\n\n    # Resize the token embeddings as we are adding new special tokens to the tokenizer\n    text_encoder.resize_token_embeddings(len(tokenizer))\n\n    # Initialise the newly added placeholder token with the embeddings of the initializer token\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    with torch.no_grad():\n        for token_id in placeholder_token_ids:\n            token_embeds[token_id] = token_embeds[initializer_token_id].clone()\n\n    # Freeze vae and unet\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n    # Freeze all parameters except for the token embeddings in text encoder\n    text_encoder.text_model.encoder.requires_grad_(False)\n    text_encoder.text_model.final_layer_norm.requires_grad_(False)\n    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)\n\n    if args.gradient_checkpointing:\n        # Keep unet in train mode if we are using gradient checkpointing to save memory.\n        # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.\n        unet.train()\n        text_encoder.gradient_checkpointing_enable()\n        unet.enable_gradient_checkpointing()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == 
version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=(\" \".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n    )\n    if args.validation_epochs is not None:\n        warnings.warn(\n            f\"FutureWarning: You are doing logging with 
validation_epochs={args.validation_epochs}.\"\n            \" Deprecated validation_epochs in favor of `validation_steps`\"\n            f\"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}\",\n            FutureWarning,\n            stacklevel=2,\n        )\n        args.validation_steps = args.validation_epochs * len(train_dataset)\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n    )\n\n    text_encoder.train()\n    # Prepare everything with our `accelerator`.\n    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        text_encoder, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and unet to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n  
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    # keep original embeddings as reference\n    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        text_encoder.train()\n        for step, batch in 
enumerate(train_dataloader):\n            with accelerator.accumulate(text_encoder):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample().detach()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0].to(dtype=weight_dtype)\n\n                # Predict the noise residual\n                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # 
Let's make sure we don't update any embedding weights besides the newly added token\n                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)\n                index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False\n\n                with torch.no_grad():\n                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (\n                        orig_embeds_params[index_no_updates]\n                    )\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                images = []\n                progress_bar.update(1)\n                global_step += 1\n                if global_step % args.save_steps == 0:\n                    weight_name = (\n                        f\"learned_embeds-steps-{global_step}.bin\"\n                        if args.no_safe_serialization\n                        else f\"learned_embeds-steps-{global_step}.safetensors\"\n                    )\n                    save_path = os.path.join(args.output_dir, weight_name)\n                    save_progress(\n                        text_encoder,\n                        placeholder_token_ids,\n                        accelerator,\n                        args,\n                        save_path,\n                        safe_serialization=not args.no_safe_serialization,\n                    )\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = 
sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        images = log_validation(\n                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if 
args.push_to_hub and not args.save_as_full_pipeline:\n            logger.warning(\"Enabling full model saving because --push_to_hub=True was specified.\")\n            save_full_model = True\n        else:\n            save_full_model = args.save_as_full_pipeline\n        if save_full_model:\n            pipeline = StableDiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                text_encoder=accelerator.unwrap_model(text_encoder),\n                vae=vae,\n                unet=unet,\n                tokenizer=tokenizer,\n            )\n            pipeline.save_pretrained(args.output_dir)\n        # Save the newly trained embeddings\n        weight_name = \"learned_embeds.bin\" if args.no_safe_serialization else \"learned_embeds.safetensors\"\n        save_path = os.path.join(args.output_dir, weight_name)\n        save_progress(\n            text_encoder,\n            placeholder_token_ids,\n            accelerator,\n            args,\n            save_path,\n            safe_serialization=not args.no_safe_serialization,\n        )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/textual_inversion/textual_inversion_flax.py",
    "content": "import argparse\nimport logging\nimport math\nimport os\nimport random\nfrom pathlib import Path\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\nimport PIL\nimport torch\nimport torch.utils.checkpoint\nimport transformers\nfrom flax import jax_utils\nfrom flax.training import train_state\nfrom flax.training.common_utils import shard\nfrom huggingface_hub import create_repo, upload_folder\n\n# TODO: remove and import from diffusers.utils when the new version of diffusers is released\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel, set_seed\n\nfrom diffusers import (\n    FlaxAutoencoderKL,\n    FlaxDDPMScheduler,\n    FlaxPNDMScheduler,\n    FlaxStableDiffusionPipeline,\n    FlaxUNet2DConditionModel,\n)\nfrom diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker\nfrom diffusers.utils import check_min_version\n\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = logging.getLogger(__name__)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\", type=str, default=None, required=True, help=\"A folder containing the training data.\"\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=42, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    
parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=True,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\n        \"--use_auth_token\",\n        action=\"store_true\",\n        help=(\n            \"Will use the token generated when running `hf auth login` (necessary to use this script with\"\n            \" private models).\"\n        ),\n    )\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. 
Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.train_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped 
painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer = tokenizer\n        self.learnable_property = learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == 
\"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"input_ids\"] = self.tokenizer(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            (\n                h,\n                w,\n            ) = (\n                img.shape[0],\n                img.shape[1],\n            )\n            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]\n\n        image = Image.fromarray(img)\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):\n    if model.config.vocab_size == new_num_tokens or new_num_tokens is None:\n        return model\n    model.config.vocab_size = new_num_tokens\n\n    params = model.params\n    old_embeddings = params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"]\n    old_num_tokens, emb_dim = old_embeddings.shape\n\n    initializer = jax.nn.initializers.normal()\n\n    new_embeddings = initializer(rng, (new_num_tokens, emb_dim))\n    new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)\n    new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])\n    
params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"] = new_embeddings\n\n    model.params = params\n    return model\n\n\ndef get_params_to_save(params):\n    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))\n\n\ndef main():\n    args = parse_args()\n\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    if jax.process_index() == 0:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    # Set up logging; we only want one process per machine to log things on the screen.\n    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)\n    if jax.process_index() == 0:\n        transformers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n\n    # Load the tokenizer and add the placeholder token as an additional special token\n    if args.tokenizer_name:\n        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n\n    # Add the placeholder token to the tokenizer\n    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)\n    if num_added_tokens == 0:\n        raise ValueError(\n            f\"The tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n\n    # Convert the initializer_token, placeholder_token to ids\n    token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)\n    # Check if initializer_token is a single token or a sequence of tokens\n    if len(token_ids) > 1:\n        raise ValueError(\"The initializer token must be a single token.\")\n\n    initializer_token_id = token_ids[0]\n    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)\n\n    # Load models and create wrapper for stable diffusion\n    text_encoder = FlaxCLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae, vae_params = FlaxAutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision\n    )\n    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n    )\n\n    # Create sampling rng\n    rng = jax.random.PRNGKey(args.seed)\n    rng, _ = jax.random.split(rng)\n    # Resize the token embeddings as we are adding new special tokens to the tokenizer\n    text_encoder = resize_token_embeddings(\n        text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng\n    )\n    original_token_embeds = text_encoder.params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"]\n\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        placeholder_token=args.placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n\n    def collate_fn(examples):\n        pixel_values = 
torch.stack([example[\"pixel_values\"] for example in examples])\n        input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n\n        batch = {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n        batch = {k: v.numpy() for k, v in batch.items()}\n\n        return batch\n\n    total_train_batch_size = args.train_batch_size * jax.local_device_count()\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn\n    )\n\n    # Optimization\n    if args.scale_lr:\n        args.learning_rate = args.learning_rate * total_train_batch_size\n\n    constant_scheduler = optax.constant_schedule(args.learning_rate)\n\n    optimizer = optax.adamw(\n        learning_rate=constant_scheduler,\n        b1=args.adam_beta1,\n        b2=args.adam_beta2,\n        eps=args.adam_epsilon,\n        weight_decay=args.adam_weight_decay,\n    )\n\n    def create_mask(params, label_fn):\n        def _map(params, mask, label_fn):\n            for k in params:\n                if label_fn(k):\n                    mask[k] = \"token_embedding\"\n                else:\n                    if isinstance(params[k], dict):\n                        mask[k] = {}\n                        _map(params[k], mask[k], label_fn)\n                    else:\n                        mask[k] = \"zero\"\n\n        mask = {}\n        _map(params, mask, label_fn)\n        return mask\n\n    def zero_grads():\n        # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491\n        def init_fn(_):\n            return ()\n\n        def update_fn(updates, state, params=None):\n            return jax.tree_util.tree_map(jnp.zeros_like, updates), ()\n\n        return optax.GradientTransformation(init_fn, update_fn)\n\n    # Zero out gradients of layers other than the token embedding layer\n    tx = optax.multi_transform(\n        {\"token_embedding\": optimizer, 
\"zero\": zero_grads()},\n        create_mask(text_encoder.params, lambda s: s == \"token_embedding\"),\n    )\n\n    state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)\n\n    noise_scheduler = FlaxDDPMScheduler(\n        beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", num_train_timesteps=1000\n    )\n    noise_scheduler_state = noise_scheduler.create_state()\n\n    # Initialize our training\n    train_rngs = jax.random.split(rng, jax.local_device_count())\n\n    # Define gradient train step fn\n    def train_step(state, vae_params, unet_params, batch, train_rng):\n        dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)\n\n        def compute_loss(params):\n            vae_outputs = vae.apply(\n                {\"params\": vae_params}, batch[\"pixel_values\"], deterministic=True, method=vae.encode\n            )\n            latents = vae_outputs.latent_dist.sample(sample_rng)\n            # (NHWC) -> (NCHW)\n            latents = jnp.transpose(latents, (0, 3, 1, 2))\n            latents = latents * vae.config.scaling_factor\n\n            noise_rng, timestep_rng = jax.random.split(sample_rng)\n            noise = jax.random.normal(noise_rng, latents.shape)\n            bsz = latents.shape[0]\n            timesteps = jax.random.randint(\n                timestep_rng,\n                (bsz,),\n                0,\n                noise_scheduler.config.num_train_timesteps,\n            )\n            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)\n            encoder_hidden_states = state.apply_fn(\n                batch[\"input_ids\"], params=params, dropout_rng=dropout_rng, train=True\n            )[0]\n            # Predict the noise residual and compute loss\n            model_pred = unet.apply(\n                {\"params\": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False\n            
).sample\n\n            # Get the target for loss depending on the prediction type\n            if noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n            loss = (target - model_pred) ** 2\n            loss = loss.mean()\n\n            return loss\n\n        grad_fn = jax.value_and_grad(compute_loss)\n        loss, grad = grad_fn(state.params)\n        grad = jax.lax.pmean(grad, \"batch\")\n        new_state = state.apply_gradients(grads=grad)\n\n        # Keep the token embeddings fixed except the newly added embeddings for the concept,\n        # as we only want to optimize the concept embeddings\n        token_embeds = original_token_embeds.at[placeholder_token_id].set(\n            new_state.params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"][placeholder_token_id]\n        )\n        new_state.params[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"] = token_embeds\n\n        metrics = {\"loss\": loss}\n        metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n        return new_state, metrics, new_train_rng\n\n    # Create parallel version of the train and eval step\n    p_train_step = jax.pmap(train_step, \"batch\", donate_argnums=(0,))\n\n    # Replicate the train state on each device\n    state = jax_utils.replicate(state)\n    vae_params = jax_utils.replicate(vae_params)\n    unet_params = jax_utils.replicate(unet_params)\n\n    # Train!\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader))\n\n    # Scheduler and math around the number of training steps.\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * 
num_update_steps_per_epoch\n\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n\n    global_step = 0\n\n    epochs = tqdm(range(args.num_train_epochs), desc=f\"Epoch ... (1/{args.num_train_epochs})\", position=0)\n    for epoch in epochs:\n        # ======================== Training ================================\n\n        train_metrics = []\n\n        steps_per_epoch = len(train_dataset) // total_train_batch_size\n        train_step_progress_bar = tqdm(total=steps_per_epoch, desc=\"Training...\", position=1, leave=False)\n        # train\n        for batch in train_dataloader:\n            batch = shard(batch)\n            state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)\n            train_metrics.append(train_metric)\n\n            train_step_progress_bar.update(1)\n            global_step += 1\n\n            if global_step >= args.max_train_steps:\n                break\n            if global_step % args.save_steps == 0:\n                learned_embeds = get_params_to_save(state.params)[\"text_model\"][\"embeddings\"][\"token_embedding\"][\n                    \"embedding\"\n                ][placeholder_token_id]\n                learned_embeds_dict = {args.placeholder_token: learned_embeds}\n                jnp.save(\n                    os.path.join(args.output_dir, \"learned_embeds-\" + str(global_step) + \".npy\"), learned_embeds_dict\n                )\n\n        train_metric = jax_utils.unreplicate(train_metric)\n\n        
train_step_progress_bar.close()\n        epochs.write(f\"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})\")\n\n    # Create the pipeline using the trained modules and save it.\n    if jax.process_index() == 0:\n        scheduler = FlaxPNDMScheduler(\n            beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", skip_prk_steps=True\n        )\n        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(\n            \"CompVis/stable-diffusion-safety-checker\", from_pt=True\n        )\n        pipeline = FlaxStableDiffusionPipeline(\n            text_encoder=text_encoder,\n            vae=vae,\n            unet=unet,\n            tokenizer=tokenizer,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=CLIPImageProcessor.from_pretrained(\"openai/clip-vit-base-patch32\"),\n        )\n\n        pipeline.save_pretrained(\n            args.output_dir,\n            params={\n                \"text_encoder\": get_params_to_save(state.params),\n                \"vae\": get_params_to_save(vae_params),\n                \"unet\": get_params_to_save(unet_params),\n                \"safety_checker\": safety_checker.params,\n            },\n        )\n\n        # Also save the newly trained embeddings\n        learned_embeds = get_params_to_save(state.params)[\"text_model\"][\"embeddings\"][\"token_embedding\"][\"embedding\"][\n            placeholder_token_id\n        ]\n        learned_embeds_dict = {args.placeholder_token: learned_embeds}\n        jnp.save(os.path.join(args.output_dir, \"learned_embeds.npy\"), learned_embeds_dict)\n\n        if args.push_to_hub:\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/textual_inversion/textual_inversion_sdxl.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport numpy as np\nimport PIL\nimport safetensors\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    DPMSolverMultistepScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\nif is_wandb_available():\n    import wandb\n\nif version.parse(version.parse(PIL.__version__).base_version) >= version.parse(\"9.1.0\"):\n    PIL_INTERPOLATION = {\n    
    \"linear\": PIL.Image.Resampling.BILINEAR,\n        \"bilinear\": PIL.Image.Resampling.BILINEAR,\n        \"bicubic\": PIL.Image.Resampling.BICUBIC,\n        \"lanczos\": PIL.Image.Resampling.LANCZOS,\n        \"nearest\": PIL.Image.Resampling.NEAREST,\n    }\nelse:\n    PIL_INTERPOLATION = {\n        \"linear\": PIL.Image.LINEAR,\n        \"bilinear\": PIL.Image.BILINEAR,\n        \"bicubic\": PIL.Image.BICUBIC,\n        \"lanczos\": PIL.Image.LANCZOS,\n        \"nearest\": PIL.Image.NEAREST,\n    }\n# ------------------------------------------------------------------------------\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef save_model_card(repo_id: str, images=None, base_model: str = None, repo_folder=None):\n    img_str = \"\"\n    for i, image in enumerate(images):\n        image.save(os.path.join(repo_folder, f\"image_{i}.png\"))\n        img_str += f\"![img_{i}](./image_{i}.png)\\n\"\n\n    model_description = f\"\"\"\n# Textual inversion text2image fine-tuning - {repo_id}\nThese are textual inversion adaptation weights for {base_model}. You can find some example images in the following. 
\\n\n{img_str}\n\"\"\"\n    model_card = load_or_create_model_card(\n        repo_id_or_path=repo_id,\n        from_training=True,\n        license=\"creativeml-openrail-m\",\n        base_model=base_model,\n        model_description=model_description,\n        inference=True,\n    )\n\n    tags = [\n        \"stable-diffusion-xl\",\n        \"stable-diffusion-xl-diffusers\",\n        \"text-to-image\",\n        \"diffusers\",\n        \"diffusers-training\",\n        \"textual_inversion\",\n    ]\n\n    model_card = populate_model_card(model_card, tags=tags)\n\n    model_card.save(os.path.join(repo_folder, \"README.md\"))\n\n\ndef log_validation(\n    text_encoder_1,\n    text_encoder_2,\n    tokenizer_1,\n    tokenizer_2,\n    unet,\n    vae,\n    args,\n    accelerator,\n    weight_dtype,\n    epoch,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... \\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = DiffusionPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        text_encoder=accelerator.unwrap_model(text_encoder_1),\n        text_encoder_2=accelerator.unwrap_model(text_encoder_2),\n        tokenizer=tokenizer_1,\n        tokenizer_2=tokenizer_2,\n        unet=unet,\n        vae=vae,\n        safety_checker=None,\n        revision=args.revision,\n        variant=args.variant,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)\n    images = []\n    for _ in range(args.num_validation_images):\n        image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]\n        
images.append(image)\n\n    tracker_key = \"test\" if is_final_validation else \"validation\"\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(tracker_key, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    tracker_key: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    torch.cuda.empty_cache()\n    return images\n\n\ndef save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):\n    logger.info(\"Saving embeddings\")\n    learned_embeds = (\n        accelerator.unwrap_model(text_encoder)\n        .get_input_embeddings()\n        .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]\n    )\n    learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}\n\n    if safe_serialization:\n        safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={\"format\": \"pt\"})\n    else:\n        torch.save(learned_embeds_dict, save_path)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--save_steps\",\n        type=int,\n        default=500,\n        help=\"Save learned_embeds.bin every X updates steps.\",\n    )\n    parser.add_argument(\n        \"--save_as_full_pipeline\",\n        action=\"store_true\",\n        help=\"Save the complete stable diffusion pipeline.\",\n    )\n    parser.add_argument(\n        \"--num_vectors\",\n        type=int,\n        default=1,\n        help=\"How many textual inversion vectors shall be used to learn the concept.\",\n    )\n    
parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\", type=str, default=None, required=True, help=\"A folder containing the training data.\"\n    )\n    parser.add_argument(\n        \"--placeholder_token\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"A token to use as a placeholder for the concept.\",\n    )\n    parser.add_argument(\n        \"--initializer_token\", type=str, default=None, required=True, help=\"A token to use as initializer word.\"\n    )\n    parser.add_argument(\"--learnable_property\", type=str, default=\"object\", help=\"Choose between 'object' and 'style'\")\n    parser.add_argument(\"--repeats\", type=int, default=100, help=\"How many times to repeat the training data.\")\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" 
resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\", action=\"store_true\", help=\"Whether to center crop images before resizing to resolution.\"\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=5000,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) 
log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.train_data_dir is None:\n        raise ValueError(\"You must specify a train data directory.\")\n\n    return args\n\n\nimagenet_templates_small = [\n    \"a photo of a {}\",\n    \"a rendering of a {}\",\n    \"a cropped photo of the {}\",\n    \"the photo of a {}\",\n    \"a photo of a clean {}\",\n    \"a photo of a dirty {}\",\n    \"a dark photo of the {}\",\n    \"a photo of my {}\",\n    \"a photo of the cool {}\",\n    \"a close-up photo of a {}\",\n    \"a 
bright photo of the {}\",\n    \"a cropped photo of a {}\",\n    \"a photo of the {}\",\n    \"a good photo of the {}\",\n    \"a photo of one {}\",\n    \"a close-up photo of the {}\",\n    \"a rendition of the {}\",\n    \"a photo of the clean {}\",\n    \"a rendition of a {}\",\n    \"a photo of a nice {}\",\n    \"a good photo of a {}\",\n    \"a photo of the nice {}\",\n    \"a photo of the small {}\",\n    \"a photo of the weird {}\",\n    \"a photo of the large {}\",\n    \"a photo of a cool {}\",\n    \"a photo of a small {}\",\n]\n\nimagenet_style_templates_small = [\n    \"a painting in the style of {}\",\n    \"a rendering in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"the painting in the style of {}\",\n    \"a clean painting in the style of {}\",\n    \"a dirty painting in the style of {}\",\n    \"a dark painting in the style of {}\",\n    \"a picture in the style of {}\",\n    \"a cool painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a bright painting in the style of {}\",\n    \"a cropped painting in the style of {}\",\n    \"a good painting in the style of {}\",\n    \"a close-up painting in the style of {}\",\n    \"a rendition in the style of {}\",\n    \"a nice painting in the style of {}\",\n    \"a small painting in the style of {}\",\n    \"a weird painting in the style of {}\",\n    \"a large painting in the style of {}\",\n]\n\n\nclass TextualInversionDataset(Dataset):\n    def __init__(\n        self,\n        data_root,\n        tokenizer_1,\n        tokenizer_2,\n        learnable_property=\"object\",  # [object, style]\n        size=512,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"*\",\n        center_crop=False,\n    ):\n        self.data_root = data_root\n        self.tokenizer_1 = tokenizer_1\n        self.tokenizer_2 = tokenizer_2\n        self.learnable_property = 
learnable_property\n        self.size = size\n        self.placeholder_token = placeholder_token\n        self.center_crop = center_crop\n        self.flip_p = flip_p\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.interpolation = {\n            \"linear\": PIL_INTERPOLATION[\"linear\"],\n            \"bilinear\": PIL_INTERPOLATION[\"bilinear\"],\n            \"bicubic\": PIL_INTERPOLATION[\"bicubic\"],\n            \"lanczos\": PIL_INTERPOLATION[\"lanczos\"],\n        }[interpolation]\n\n        self.templates = imagenet_style_templates_small if learnable_property == \"style\" else imagenet_templates_small\n        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)\n        self.crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        text = random.choice(self.templates).format(placeholder_string)\n\n        example[\"original_size\"] = (image.height, image.width)\n\n        image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        if self.center_crop:\n            y1 = max(0, int(round((image.height - self.size) / 2.0)))\n            x1 = max(0, int(round((image.width - self.size) / 2.0)))\n            image = self.crop(image)\n        else:\n            y1, x1, h, w = self.crop.get_params(image, (self.size, self.size))\n            image = transforms.functional.crop(image, y1, x1, h, w)\n\n        
example[\"crop_top_left\"] = (y1, x1)\n\n        example[\"input_ids_1\"] = self.tokenizer_1(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer_1.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        example[\"input_ids_2\"] = self.tokenizer_2(\n            text,\n            padding=\"max_length\",\n            truncation=True,\n            max_length=self.tokenizer_2.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids[0]\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        image = Image.fromarray(img)\n\n        image = self.flip_transform(image)\n        image = np.array(image).astype(np.uint8)\n        image = (image / 127.5 - 1.0).astype(np.float32)\n\n        example[\"pixel_values\"] = torch.from_numpy(image).permute(2, 0, 1)\n        return example\n\n\ndef main():\n    args = parse_args()\n    if args.report_to == \"wandb\" and args.hub_token is not None:\n        raise ValueError(\n            \"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token.\"\n            \" Please use `hf auth login` to authenticate with the Hub.\"\n        )\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config,\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging 
during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Load tokenizer\n    tokenizer_1 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer\")\n    tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"tokenizer_2\")\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder_1 = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision, variant=args.variant\n    )\n    unet = 
UNet2DConditionModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, variant=args.variant\n    )\n\n    # Add the placeholder token in tokenizer_1\n    placeholder_tokens = [args.placeholder_token]\n\n    if args.num_vectors < 1:\n        raise ValueError(f\"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}\")\n\n    # add dummy tokens for multi-vector\n    additional_tokens = []\n    for i in range(1, args.num_vectors):\n        additional_tokens.append(f\"{args.placeholder_token}_{i}\")\n    placeholder_tokens += additional_tokens\n\n    num_added_tokens = tokenizer_1.add_tokens(placeholder_tokens)\n    if num_added_tokens != args.num_vectors:\n        raise ValueError(\n            f\"The tokenizer already contains the token {args.placeholder_token}. Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n    num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)\n    if num_added_tokens != args.num_vectors:\n        raise ValueError(\n            f\"The 2nd tokenizer already contains the token {args.placeholder_token}. 
Please pass a different\"\n            \" `placeholder_token` that is not already in the tokenizer.\"\n        )\n\n    # Convert the initializer_token, placeholder_token to ids\n    token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)\n    token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)\n\n    # Check if initializer_token is a single token or a sequence of tokens\n    if len(token_ids) > 1 or len(token_ids_2) > 1:\n        raise ValueError(\"The initializer token must be a single token.\")\n\n    initializer_token_id = token_ids[0]\n    placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)\n    initializer_token_id_2 = token_ids_2[0]\n    placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)\n\n    # Resize the token embeddings as we are adding new special tokens to the tokenizer\n    text_encoder_1.resize_token_embeddings(len(tokenizer_1))\n    text_encoder_2.resize_token_embeddings(len(tokenizer_2))\n\n    # Initialise the newly added placeholder token with the embeddings of the initializer token\n    token_embeds = text_encoder_1.get_input_embeddings().weight.data\n    token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data\n    with torch.no_grad():\n        for token_id in placeholder_token_ids:\n            token_embeds[token_id] = token_embeds[initializer_token_id].clone()\n        for token_id in placeholder_token_ids_2:\n            token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()\n\n    # Freeze vae and unet\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    # Freeze all parameters except for the token embeddings in text encoder\n    text_encoder_1.text_model.encoder.requires_grad_(False)\n    text_encoder_1.text_model.final_layer_norm.requires_grad_(False)\n    text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)\n    
text_encoder_2.text_model.encoder.requires_grad_(False)\n    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)\n    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)\n\n    if args.gradient_checkpointing:\n        text_encoder_1.gradient_checkpointing_enable()\n        text_encoder_2.gradient_checkpointing_enable()\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    optimizer = optimizer_class(\n        # only optimize the embeddings\n        [\n            text_encoder_1.text_model.embeddings.token_embedding.weight,\n            text_encoder_2.text_model.embeddings.token_embedding.weight,\n        ],\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    placeholder_token = \" \".join(tokenizer_1.convert_ids_to_tokens(placeholder_token_ids))\n    # Dataset and DataLoaders creation:\n    train_dataset = TextualInversionDataset(\n        data_root=args.train_data_dir,\n        tokenizer_1=tokenizer_1,\n        tokenizer_2=tokenizer_2,\n        size=args.resolution,\n        placeholder_token=placeholder_token,\n        repeats=args.repeats,\n        learnable_property=args.learnable_property,\n        center_crop=args.center_crop,\n        set=\"train\",\n    )\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n    )\n\n    # Scheduler and 
math around the number of training steps.\n    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.\n    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes\n    if args.max_train_steps is None:\n        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)\n        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)\n        num_training_steps_for_scheduler = (\n            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes\n        )\n    else:\n        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=num_warmup_steps_for_scheduler,\n        num_training_steps=num_training_steps_for_scheduler,\n        num_cycles=args.lr_num_cycles,\n    )\n\n    text_encoder_1.train()\n    text_encoder_2.train()\n    # Prepare everything with our `accelerator`.\n    text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and unet and text_encoder_2 to device and cast to weight_dtype\n    unet.to(accelerator.device, dtype=weight_dtype)\n    vae.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_2.to(accelerator.device, 
dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:\n            logger.warning(\n                f\"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match \"\n                f\"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. \"\n                f\"This inconsistency may result in the learning rate scheduler not functioning properly.\"\n            )\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initialize automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    # keep original embeddings as reference\n    orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()\n    orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()\n\n    for epoch in 
range(first_epoch, args.num_train_epochs):\n        text_encoder_1.train()\n        text_encoder_2.train()\n        for step, batch in enumerate(train_dataloader):\n            with accelerator.accumulate([text_encoder_1, text_encoder_2]):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(dtype=weight_dtype)).latent_dist.sample().detach()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states_1 = (\n                    text_encoder_1(batch[\"input_ids_1\"], output_hidden_states=True)\n                    .hidden_states[-2]\n                    .to(dtype=weight_dtype)\n                )\n                encoder_output_2 = text_encoder_2(batch[\"input_ids_2\"], output_hidden_states=True)\n                encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)\n                # Index with the actual batch size (bsz); the last batch can be smaller than args.train_batch_size.\n                original_size = [\n                    (batch[\"original_size\"][0][i].item(), batch[\"original_size\"][1][i].item())\n                    for i in range(bsz)\n                ]\n                crop_top_left = [\n                    (batch[\"crop_top_left\"][0][i].item(), batch[\"crop_top_left\"][1][i].item())\n                    for i in range(bsz)\n 
               ]\n                target_size = (args.resolution, args.resolution)\n                add_time_ids = torch.cat(\n                    [\n                        torch.tensor(original_size[i] + crop_top_left[i] + target_size)\n                        for i in range(bsz)\n                    ]\n                ).to(accelerator.device, dtype=weight_dtype)\n                added_cond_kwargs = {\"text_embeds\": encoder_output_2[0], \"time_ids\": add_time_ids}\n                encoder_hidden_states = torch.cat([encoder_hidden_states_1, encoder_hidden_states_2], dim=-1)\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs\n                ).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n                else:\n                    raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n                # Let's make sure we don't update any embedding weights besides the newly added token\n                index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)\n                index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False\n                index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)\n                
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False\n\n                with torch.no_grad():\n                    accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (\n                        orig_embeds_params[index_no_updates]\n                    )\n                    accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (\n                        orig_embeds_params_2[index_no_updates_2]\n                    )\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                images = []\n                progress_bar.update(1)\n                global_step += 1\n                if global_step % args.save_steps == 0:\n                    weight_name = f\"learned_embeds-steps-{global_step}.safetensors\"\n                    save_path = os.path.join(args.output_dir, weight_name)\n                    save_progress(\n                        text_encoder_1,\n                        placeholder_token_ids,\n                        accelerator,\n                        args,\n                        save_path,\n                        safe_serialization=True,\n                    )\n                    weight_name = f\"learned_embeds_2-steps-{global_step}.safetensors\"\n                    save_path = os.path.join(args.output_dir, weight_name)\n                    save_progress(\n                        text_encoder_2,\n                        placeholder_token_ids_2,\n                        accelerator,\n                        args,\n                        save_path,\n                        safe_serialization=True,\n                    )\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the 
`checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                        images = log_validation(\n                            text_encoder_1,\n                            text_encoder_2,\n                            tokenizer_1,\n                            tokenizer_2,\n                            unet,\n                            vae,\n     
                       args,\n                            accelerator,\n                            weight_dtype,\n                            epoch,\n                        )\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n    # Create the pipeline using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        if args.validation_prompt:\n            images = log_validation(\n                text_encoder_1,\n                text_encoder_2,\n                tokenizer_1,\n                tokenizer_2,\n                unet,\n                vae,\n                args,\n                accelerator,\n                weight_dtype,\n                epoch,\n                is_final_validation=True,\n            )\n\n        if args.push_to_hub and not args.save_as_full_pipeline:\n            logger.warning(\"Enabling full model saving because --push_to_hub=True was specified.\")\n            save_full_model = True\n        else:\n            save_full_model = args.save_as_full_pipeline\n        if save_full_model:\n            pipeline = DiffusionPipeline.from_pretrained(\n                args.pretrained_model_name_or_path,\n                text_encoder=accelerator.unwrap_model(text_encoder_1),\n                text_encoder_2=accelerator.unwrap_model(text_encoder_2),\n                vae=vae,\n                unet=unet,\n                tokenizer=tokenizer_1,\n                tokenizer_2=tokenizer_2,\n            )\n            pipeline.save_pretrained(args.output_dir)\n        # Save the newly trained embeddings\n        weight_name = \"learned_embeds.safetensors\"\n        save_path = os.path.join(args.output_dir, weight_name)\n        save_progress(\n            text_encoder_1,\n            
placeholder_token_ids,\n            accelerator,\n            args,\n            save_path,\n            safe_serialization=True,\n        )\n        weight_name = \"learned_embeds_2.safetensors\"\n        save_path = os.path.join(args.output_dir, weight_name)\n        save_progress(\n            text_encoder_2,\n            placeholder_token_ids_2,\n            accelerator,\n            args,\n            save_path,\n            safe_serialization=True,\n        )\n\n        if args.push_to_hub:\n            save_model_card(\n                repo_id,\n                images=images,\n                base_model=args.pretrained_model_name_or_path,\n                repo_folder=args.output_dir,\n            )\n            upload_folder(\n                repo_id=repo_id,\n                folder_path=args.output_dir,\n                commit_message=\"End of training\",\n                ignore_patterns=[\"step_*\", \"epoch_*\"],\n            )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/unconditional_image_generation/README.md",
    "content": "## Training an unconditional diffusion model\n\nCreating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd in the example folder  and run\n```bash\npip install -r requirements.txt\n```\n\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n### Unconditional Flowers\n\nThe command to train a DDPM UNet model on the Oxford Flowers dataset:\n\n```bash\naccelerate launch train_unconditional.py \\\n  --dataset_name=\"huggan/flowers-102-categories\" \\\n  --resolution=64 --center_crop --random_flip \\\n  --output_dir=\"ddpm-ema-flowers-64\" \\\n  --train_batch_size=16 \\\n  --num_epochs=100 \\\n  --gradient_accumulation_steps=1 \\\n  --use_ema \\\n  --learning_rate=1e-4 \\\n  --lr_warmup_steps=500 \\\n  --mixed_precision=no \\\n  --push_to_hub\n```\nAn example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64\n\nA full training run takes 2 hours on 4xV100 GPUs.\n\n<img src=\"https://user-images.githubusercontent.com/26864830/180248660-a0b143d0-b89a-42c5-8656-2ebf6ece7e52.png\" width=\"700\" />\n\n\n### Unconditional Pokemon\n\nThe command to train a DDPM UNet model on the Pokemon dataset:\n\n```bash\naccelerate launch train_unconditional.py \\\n  --dataset_name=\"huggan/pokemon\" \\\n  --resolution=64 --center_crop 
--random_flip \\\n  --output_dir=\"ddpm-ema-pokemon-64\" \\\n  --train_batch_size=16 \\\n  --num_epochs=100 \\\n  --gradient_accumulation_steps=1 \\\n  --use_ema \\\n  --learning_rate=1e-4 \\\n  --lr_warmup_steps=500 \\\n  --mixed_precision=no \\\n  --push_to_hub\n```\nAn example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64\n\nA full training run takes 2 hours on 4xV100 GPUs.\n\n<img src=\"https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png\" width=\"700\" />\n\n### Training with multiple GPUs\n\n`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)\nfor running distributed training with `accelerate`. Here is an example command:\n\n```bash\naccelerate launch --mixed_precision=\"fp16\" --multi_gpu train_unconditional.py \\\n  --dataset_name=\"huggan/pokemon\" \\\n  --resolution=64 --center_crop --random_flip \\\n  --output_dir=\"ddpm-ema-pokemon-64\" \\\n  --train_batch_size=16 \\\n  --num_epochs=100 \\\n  --gradient_accumulation_steps=1 \\\n  --use_ema \\\n  --learning_rate=1e-4 \\\n  --lr_warmup_steps=500 \\\n  --mixed_precision=\"fp16\" \\\n  --logger=\"wandb\"\n```\n\nTo be able to use Weights and Biases (`wandb`) as a logger you need to install the library: `pip install wandb`.\n\n### Using your own data\n\nTo use your own dataset, there are 2 ways:\n- you can either provide your own folder as `--train_data_dir`\n- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.\n\nIf your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. 
Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.\n\nBelow, we explain both in more detail.\n\n#### Provide the dataset as a folder\n\nIf you provide your own folders with images, the script expects the following directory structure:\n\n```bash\ndata_dir/xxx.png\ndata_dir/xxy.png\ndata_dir/[...]/xxz.png\n```\n\nIn other words, the script will take care of gathering all images inside the folder. You can then run the script like this:\n\n```bash\naccelerate launch train_unconditional.py \\\n    --train_data_dir <path-to-train-directory> \\\n    <other-arguments>\n```\n\nInternally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.\n\n#### Upload your data to the hub, as a (possibly private) repo\n\nIt's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. 
Simply do the following:\n\n```python\nfrom datasets import load_dataset\n\n# example 1: local folder\ndataset = load_dataset(\"imagefolder\", data_dir=\"path_to_your_folder\")\n\n# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)\ndataset = load_dataset(\"imagefolder\", data_files=\"path_to_zip_file\")\n\n# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)\ndataset = load_dataset(\"imagefolder\", data_files=\"https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip\")\n\n# example 4: providing several splits\ndataset = load_dataset(\"imagefolder\", data_files={\"train\": [\"path/to/file1\", \"path/to/file2\"], \"test\": [\"path/to/file3\", \"path/to/file4\"]})\n```\n\n`ImageFolder` will create an `image` column containing the PIL-encoded images.\n\nNext, push it to the hub!\n\n```python\n# assuming you have run the hf auth login command in a terminal\ndataset.push_to_hub(\"name_of_your_dataset\")\n\n# if you want to push to a private repo, simply pass private=True:\ndataset.push_to_hub(\"name_of_your_dataset\", private=True)\n```\n\nAnd that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.\n\nMore on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).\n"
  },
  {
    "path": "examples/unconditional_image_generation/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ndatasets\n"
  },
  {
    "path": "examples/unconditional_image_generation/test_unconditional.py",
    "content": "# coding=utf-8\n# Copyright 2025 HuggingFace Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport tempfile\n\n\nsys.path.append(\"..\")\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\nclass Unconditional(ExamplesTestsAccelerate):\n    def test_train_unconditional(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            test_args = f\"\"\"\n                examples/unconditional_image_generation/train_unconditional.py\n                --dataset_name hf-internal-testing/dummy_image_class_data\n                --model_config_name_or_path diffusers/ddpm_dummy\n                --resolution 64\n                --output_dir {tmpdir}\n                --train_batch_size 2\n                --num_epochs 1\n                --gradient_accumulation_steps 1\n                --ddpm_num_inference_steps 2\n                --learning_rate 1e-3\n                --lr_warmup_steps 5\n                \"\"\".split()\n\n            run_command(self._launch_args + test_args, return_stdout=True)\n            # save_pretrained smoke test\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"unet\", \"diffusion_pytorch_model.safetensors\")))\n            
self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"scheduler\", \"scheduler_config.json\")))\n\n    def test_unconditional_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            initial_run_args = f\"\"\"\n                examples/unconditional_image_generation/train_unconditional.py\n                --dataset_name hf-internal-testing/dummy_image_class_data\n                --model_config_name_or_path diffusers/ddpm_dummy\n                --resolution 64\n                --output_dir {tmpdir}\n                --train_batch_size 1\n                --num_epochs 1\n                --gradient_accumulation_steps 1\n                --ddpm_num_inference_steps 2\n                --learning_rate 1e-3\n                --lr_warmup_steps 5\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                # checkpoint-2 should have been deleted\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            initial_run_args = f\"\"\"\n                examples/unconditional_image_generation/train_unconditional.py\n                --dataset_name hf-internal-testing/dummy_image_class_data\n                --model_config_name_or_path diffusers/ddpm_dummy\n                --resolution 64\n                --output_dir {tmpdir}\n                --train_batch_size 1\n                --num_epochs 1\n                --gradient_accumulation_steps 1\n                --ddpm_num_inference_steps 1\n                --learning_rate 1e-3\n                --lr_warmup_steps 5\n  
              --checkpointing_steps=2\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n            resume_run_args = f\"\"\"\n                examples/unconditional_image_generation/train_unconditional.py\n                --dataset_name hf-internal-testing/dummy_image_class_data\n                --model_config_name_or_path diffusers/ddpm_dummy\n                --resolution 64\n                --output_dir {tmpdir}\n                --train_batch_size 1\n                --num_epochs 2\n                --gradient_accumulation_steps 1\n                --ddpm_num_inference_steps 1\n                --learning_rate 1e-3\n                --lr_warmup_steps 5\n                --resume_from_checkpoint=checkpoint-6\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-10\", \"checkpoint-12\"},\n            )\n"
  },
  {
    "path": "examples/unconditional_image_generation/train_unconditional.py",
    "content": "import argparse\nimport inspect\nimport logging\nimport math\nimport os\nimport shutil\nfrom datetime import timedelta\nfrom pathlib import Path\n\nimport accelerate\nimport datasets\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator, InitProcessGroupKwargs\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration\nfrom datasets import load_dataset\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\n\nimport diffusers\nfrom diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel\nfrom diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\ndef _extract_into_tensor(arr, timesteps, broadcast_shape):\n    \"\"\"\n    Extract values from a 1-D numpy array for a batch of indices.\n\n    :param arr: the 1-D numpy array.\n    :param timesteps: a tensor of indices into the array to extract.\n    :param broadcast_shape: a larger shape of K dimensions with the batch\n                            dimension equal to the length of timesteps.\n    :return: a tensor of shape [batch_size, 1, ...] 
where the shape has K dims.\n    \"\"\"\n    if not isinstance(arr, torch.Tensor):\n        arr = torch.from_numpy(arr)\n    res = arr[timesteps].float().to(timesteps.device)\n    while len(res.shape) < len(broadcast_shape):\n        res = res[..., None]\n    return res.expand(broadcast_shape)\n\n\ndef _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.\n    \"\"\"\n    if tensor.ndim == 2:\n        tensor = tensor.unsqueeze(0)\n    channels = tensor.shape[0]\n    if channels == 3:\n        return tensor\n    if channels == 1:\n        return tensor.repeat(3, 1, 1)\n    if channels == 2:\n        return torch.cat([tensor, tensor[:1]], dim=0)\n    if channels > 3:\n        return tensor[:3]\n    raise ValueError(f\"Unsupported number of channels: {channels}\")\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that HF Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--model_config_name_or_path\",\n        type=str,\n        default=None,\n        help=\"The config of the UNet model to train, leave as None to use standard DDPM configuration.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. No `metadata.jsonl` captions file\"\n            \" is needed for unconditional training. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"ddpm-model-64\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--overwrite_output_dir\", action=\"store_true\")\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=64,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to randomly flip images horizontally\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--eval_batch_size\", type=int, default=16, help=\"The number of images to generate for evaluation.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"The number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main\"\n            \" process.\"\n        ),\n    )\n    parser.add_argument(\"--num_epochs\", type=int, default=100)\n    parser.add_argument(\"--save_images_epochs\", type=int, default=10, help=\"How often to save images during training.\")\n    parser.add_argument(\n        \"--save_model_epochs\", type=int, default=10, help=\"How often to save the model during training.\"\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"cosine\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.95, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\n        \"--adam_weight_decay\", type=float, default=1e-6, help=\"Weight decay magnitude for the Adam optimizer.\"\n    )\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer.\")\n    parser.add_argument(\n        \"--use_ema\",\n        action=\"store_true\",\n        help=\"Whether to use Exponential Moving Average for the final model weights.\",\n    
)\n    parser.add_argument(\"--ema_inv_gamma\", type=float, default=1.0, help=\"The inverse gamma value for the EMA decay.\")\n    parser.add_argument(\"--ema_power\", type=float, default=3 / 4, help=\"The power value for the EMA decay.\")\n    parser.add_argument(\"--ema_max_decay\", type=float, default=0.9999, help=\"The maximum decay magnitude for EMA.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--hub_private_repo\", action=\"store_true\", help=\"Whether or not to create a private repository.\"\n    )\n    parser.add_argument(\n        \"--logger\",\n        type=str,\n        default=\"tensorboard\",\n        choices=[\"tensorboard\", \"wandb\"],\n        help=(\n            \"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)\"\n            \" for experiment tracking and logging of model metrics and model checkpoints\"\n        ),\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"no\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. 
Choose\"\n            \" between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10\"\n            \" and an Nvidia Ampere GPU.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=\"epsilon\",\n        choices=[\"epsilon\", \"sample\"],\n        help=\"Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.\",\n    )\n    parser.add_argument(\"--ddpm_num_steps\", type=int, default=1000)\n    parser.add_argument(\"--ddpm_num_inference_steps\", type=int, default=1000)\n    parser.add_argument(\"--ddpm_beta_schedule\", type=str, default=\"linear\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--preserve_input_precision\",\n        action=\"store_true\",\n        help=\"Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.\",\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"You must specify either a dataset name from the hub or a train data directory.\")\n\n    return args\n\n\ndef main(args):\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))  # a big number for high resolution or big dataset\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.logger,\n        project_config=accelerator_project_config,\n        kwargs_handlers=[kwargs],\n    )\n\n    if args.logger == \"tensorboard\":\n        if not is_tensorboard_available():\n            raise ImportError(\"Make sure to install tensorboard if you want to use it for logging during training.\")\n\n    elif args.logger == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n        import wandb\n\n    # `accelerate` 0.16.0 will have 
better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_model.save_pretrained(os.path.join(output_dir, \"unet_ema\"))\n\n                for i, model in enumerate(models):\n                    model.save_pretrained(os.path.join(output_dir, \"unet\"))\n\n                    # make sure to pop weight so that corresponding model is not saved again\n                    weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"unet_ema\"), UNet2DModel)\n                ema_model.load_state_dict(load_model.state_dict())\n                ema_model.to(accelerator.device)\n                del load_model\n\n            for i in range(len(models)):\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                # load diffusers style into model\n                load_model = UNet2DModel.from_pretrained(input_dir, subfolder=\"unet\")\n                model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n   
     datasets.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        datasets.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            repo_id = create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    # Initialize the model\n    if args.model_config_name_or_path is None:\n        model = UNet2DModel(\n            sample_size=args.resolution,\n            in_channels=3,\n            out_channels=3,\n            layers_per_block=2,\n            block_out_channels=(128, 128, 256, 256, 512, 512),\n            down_block_types=(\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"DownBlock2D\",\n                \"AttnDownBlock2D\",\n                \"DownBlock2D\",\n            ),\n            up_block_types=(\n                \"UpBlock2D\",\n                \"AttnUpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n                \"UpBlock2D\",\n            ),\n        )\n    else:\n        config = UNet2DModel.load_config(args.model_config_name_or_path)\n        model = UNet2DModel.from_config(config)\n\n    # Create EMA for the model.\n    if args.use_ema:\n        ema_model = EMAModel(\n            model.parameters(),\n            decay=args.ema_max_decay,\n            use_ema_warmup=True,\n            inv_gamma=args.ema_inv_gamma,\n            power=args.ema_power,\n            model_cls=UNet2DModel,\n            model_config=model.config,\n        )\n\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n      
  weight_dtype = torch.float16\n        args.mixed_precision = accelerator.mixed_precision\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n        args.mixed_precision = accelerator.mixed_precision\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warning(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            model.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\n    # Initialize the scheduler\n    accepts_prediction_type = \"prediction_type\" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())\n    if accepts_prediction_type:\n        noise_scheduler = DDPMScheduler(\n            num_train_timesteps=args.ddpm_num_steps,\n            beta_schedule=args.ddpm_beta_schedule,\n            prediction_type=args.prediction_type,\n        )\n    else:\n        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)\n\n    # Initialize the optimizer\n    optimizer = torch.optim.AdamW(\n        model.parameters(),\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Get the datasets: you can either provide your own training and evaluation files (see below)\n    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).\n\n    
# In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if args.dataset_name is not None:\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            split=\"train\",\n        )\n    else:\n        dataset = load_dataset(\"imagefolder\", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split=\"train\")\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets and DataLoaders creation.\n    spatial_augmentations = [\n        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n    ]\n\n    augmentations = transforms.Compose(\n        spatial_augmentations\n        + [\n            transforms.ToTensor(),\n            transforms.Normalize([0.5], [0.5]),\n        ]\n    )\n\n    precision_augmentations = transforms.Compose(\n        [\n            transforms.PILToTensor(),\n            transforms.Lambda(_ensure_three_channels),\n            transforms.ConvertImageDtype(torch.float32),\n        ]\n        + spatial_augmentations\n        + [transforms.Normalize([0.5], [0.5])]\n    )\n\n    def transform_images(examples):\n        processed = []\n        for image in examples[\"image\"]:\n            if not args.preserve_input_precision:\n                processed.append(augmentations(image.convert(\"RGB\")))\n            else:\n                precise_image = image\n                if precise_image.mode == \"P\":\n                    precise_image = precise_image.convert(\"RGB\")\n                
processed.append(precision_augmentations(precise_image))\n        return {\"input\": processed}\n\n    logger.info(f\"Dataset size: {len(dataset)}\")\n\n    dataset.set_transform(transform_images)\n    train_dataloader = torch.utils.data.DataLoader(\n        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers\n    )\n\n    # Initialize the learning rate scheduler\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=(len(train_dataloader) * args.num_epochs),\n    )\n\n    # Prepare everything with our `accelerator`.\n    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        model, optimizer, train_dataloader, lr_scheduler\n    )\n\n    if args.use_ema:\n        ema_model.to(accelerator.device)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        run = os.path.split(__file__)[-1].split(\".\")[0]\n        accelerator.init_trackers(run)\n\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    max_train_steps = args.num_epochs * num_update_steps_per_epoch\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {max_train_steps}\")\n\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Train!\n    for epoch in range(first_epoch, args.num_epochs):\n        model.train()\n        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)\n        progress_bar.set_description(f\"Epoch {epoch}\")\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if 
step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            clean_images = batch[\"input\"].to(weight_dtype)\n            # Sample noise that we'll add to the images\n            noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)\n            bsz = clean_images.shape[0]\n            # Sample a random timestep for each image\n            timesteps = torch.randint(\n                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device\n            ).long()\n\n            # Add noise to the clean images according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n\n            with accelerator.accumulate(model):\n                # Predict the noise residual\n                model_output = model(noisy_images, timesteps).sample\n\n                if args.prediction_type == \"epsilon\":\n                    loss = F.mse_loss(model_output.float(), noise.float())  # this could have different weights!\n                elif args.prediction_type == \"sample\":\n                    alpha_t = _extract_into_tensor(\n                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)\n                    )\n                    snr_weights = alpha_t / (1 - alpha_t)\n                    # use SNR weighting from distillation paper\n                    loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction=\"none\")\n                    loss = loss.mean()\n                else:\n                    raise ValueError(f\"Unsupported prediction type: {args.prediction_type}\")\n\n                accelerator.backward(loss)\n\n                if accelerator.sync_gradients:\n                    accelerator.clip_grad_norm_(model.parameters(), 1.0)\n                
optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                if args.use_ema:\n                    ema_model.step(model.parameters())\n                progress_bar.update(1)\n                global_step += 1\n\n                if accelerator.is_main_process:\n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        
save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n            if args.use_ema:\n                logs[\"ema_decay\"] = ema_model.cur_decay_value\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n        progress_bar.close()\n\n        accelerator.wait_for_everyone()\n\n        # Generate sample images for visual inspection\n        if accelerator.is_main_process:\n            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:\n                unet = accelerator.unwrap_model(model)\n\n                if args.use_ema:\n                    ema_model.store(unet.parameters())\n                    ema_model.copy_to(unet.parameters())\n\n                pipeline = DDPMPipeline(\n                    unet=unet,\n                    scheduler=noise_scheduler,\n                )\n\n                generator = torch.Generator(device=pipeline.device).manual_seed(0)\n                # run pipeline in inference (sample random noise and denoise)\n                images = pipeline(\n                    generator=generator,\n                    batch_size=args.eval_batch_size,\n                    num_inference_steps=args.ddpm_num_inference_steps,\n                    output_type=\"np\",\n                ).images\n\n                if args.use_ema:\n                    ema_model.restore(unet.parameters())\n\n                # denormalize the images and save to tensorboard\n                images_processed = (images * 255).round().astype(\"uint8\")\n\n                if args.logger == \"tensorboard\":\n                    if is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n                        tracker = accelerator.get_tracker(\"tensorboard\", 
unwrap=True)\n                    else:\n                        tracker = accelerator.get_tracker(\"tensorboard\")\n                    tracker.add_images(\"test_samples\", images_processed.transpose(0, 3, 1, 2), epoch)\n                elif args.logger == \"wandb\":\n                    # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files\n                    accelerator.get_tracker(\"wandb\").log(\n                        {\"test_samples\": [wandb.Image(img) for img in images_processed], \"epoch\": epoch},\n                        step=global_step,\n                    )\n\n            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:\n                # save the model\n                unet = accelerator.unwrap_model(model)\n\n                if args.use_ema:\n                    ema_model.store(unet.parameters())\n                    ema_model.copy_to(unet.parameters())\n\n                pipeline = DDPMPipeline(\n                    unet=unet,\n                    scheduler=noise_scheduler,\n                )\n\n                pipeline.save_pretrained(args.output_dir)\n\n                if args.use_ema:\n                    ema_model.restore(unet.parameters())\n\n                if args.push_to_hub:\n                    upload_folder(\n                        repo_id=repo_id,\n                        folder_path=args.output_dir,\n                        commit_message=f\"Epoch {epoch}\",\n                        ignore_patterns=[\"step_*\", \"epoch_*\"],\n                    )\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/vqgan/README.md",
    "content": "## Training a VQGAN VAE\nVQVAEs were first introduced in [Neural Discrete Representation Learning](https://huggingface.co/papers/1711.00937) and were combined with a GAN in the paper [Taming Transformers for High-Resolution Image Synthesis](https://huggingface.co/papers/2012.09841). The basic idea of a VQVAE is that it is a type of variational autoencoder whose latent space consists of discrete tokens, similar to the tokens used by LLMs. This script was adapted from a [PR to Hugging Face's open-muse project](https://github.com/huggingface/open-muse/pull/52), with the general code following [lucidrains' implementation of the VQGAN training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py); both of these implementations in turn follow the [taming-transformers repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file).\n\n\nCreating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. 
To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/diffusers\ncd diffusers\npip install .\n```\n\nThen cd into the example folder and run\n```bash\npip install -r requirements.txt\n```\n\n\nAnd initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:\n\n```bash\naccelerate config\n```\n\n### Training on CIFAR10\n\nThe command to train a VQGAN model on the cifar10 dataset:\n\n```bash\naccelerate launch train_vqgan.py \\\n  --dataset_name=cifar10 \\\n  --image_column=img \\\n  --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \\\n  --resolution=128 \\\n  --train_batch_size=2 \\\n  --gradient_accumulation_steps=8 \\\n  --report_to=wandb\n```\n\nAn example training run is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp) by @sayakpaul and a lower scale one [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images).\nThe simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is to increase the image resolution. Other options include, but are not limited to, lowering the compression by downsampling fewer times or increasing the vocabulary size, which should be at most around 16384. How to do this is shown below.\n\n### Modifying the architecture\n\nTo modify the architecture of the VQGAN model, you can save the config taken from [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/blob/main/movq/config.json) and then provide it to the script with the option `--model_config_name_or_path`. 
The config is shown below:\n```\n{\n  \"_class_name\": \"VQModel\",\n  \"_diffusers_version\": \"0.17.0.dev0\",\n  \"act_fn\": \"silu\",\n  \"block_out_channels\": [\n    128,\n    256,\n    256,\n    512\n  ],\n  \"down_block_types\": [\n    \"DownEncoderBlock2D\",\n    \"DownEncoderBlock2D\",\n    \"DownEncoderBlock2D\",\n    \"AttnDownEncoderBlock2D\"\n  ],\n  \"in_channels\": 3,\n  \"latent_channels\": 4,\n  \"layers_per_block\": 2,\n  \"norm_num_groups\": 32,\n  \"norm_type\": \"spatial\",\n  \"num_vq_embeddings\": 16384,\n  \"out_channels\": 3,\n  \"sample_size\": 32,\n  \"scaling_factor\": 0.18215,\n  \"up_block_types\": [\n    \"AttnUpDecoderBlock2D\",\n    \"UpDecoderBlock2D\",\n    \"UpDecoderBlock2D\",\n    \"UpDecoderBlock2D\"\n  ],\n  \"vq_embed_dim\": 4\n}\n```\nTo reduce the number of layers in a VQGAN, remove entries from `block_out_channels`, `down_block_types`, and `up_block_types`, as shown below:\n```\n{\n  \"_class_name\": \"VQModel\",\n  \"_diffusers_version\": \"0.17.0.dev0\",\n  \"act_fn\": \"silu\",\n  \"block_out_channels\": [\n    128,\n    256,\n    256\n  ],\n  \"down_block_types\": [\n    \"DownEncoderBlock2D\",\n    \"DownEncoderBlock2D\",\n    \"DownEncoderBlock2D\"\n  ],\n  \"in_channels\": 3,\n  \"latent_channels\": 4,\n  \"layers_per_block\": 2,\n  \"norm_num_groups\": 32,\n  \"norm_type\": \"spatial\",\n  \"num_vq_embeddings\": 16384,\n  \"out_channels\": 3,\n  \"sample_size\": 32,\n  \"scaling_factor\": 0.18215,\n  \"up_block_types\": [\n    \"UpDecoderBlock2D\",\n    \"UpDecoderBlock2D\",\n    \"UpDecoderBlock2D\"\n  ],\n  \"vq_embed_dim\": 4\n}\n```\nTo increase the vocabulary size, increase `num_vq_embeddings`. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representations of VQGANs start degrading beyond 2^14 = 16384 VQ embeddings, so going past that is not recommended.\n\n## Extra training tips/ideas\nDuring logging, make sure `data_time` is low. 
`data_time` is the time spent loading data, during which the GPU is idle, so it is essentially wasted time. The easiest way to lower it is to increase `--dataloader_num_workers` to a higher number such as 4. Due to a bug in PyTorch, this only works on Linux-based systems. For more details, see [this issue](https://github.com/huggingface/diffusers/issues/7646).\nSecondly, training can be considered done when both the discriminator and the generator losses converge.\nThirdly, another low-hanging fruit is enabling EMA with the `--use_ema` parameter. This tends to make the output images smoother. The downside is that you have to lower your batch size by 1, but it may be worth it.\nAnother, more experimental, low-hanging fruit is switching from vgg19 to a different model for the LPIPS loss using `--timm_model_backend`. If you do this, I recommend also changing the `timm_model_layers` parameter to the layer in your model that you think gives the best representation. However, be careful with the feature map norms, since they can easily dominate the loss."
  },
  {
    "path": "examples/vqgan/discriminator.py",
    "content": "\"\"\"\nPorted from Paella\n\"\"\"\n\nimport torch\nfrom torch import nn\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.models.modeling_utils import ModelMixin\n\n\n# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py\nclass Discriminator(ModelMixin, ConfigMixin):\n    @register_to_config\n    def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):\n        super().__init__()\n        d = max(depth - 3, 3)\n        layers = [\n            nn.utils.spectral_norm(\n                nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)\n            ),\n            nn.LeakyReLU(0.2),\n        ]\n        for i in range(depth - 1):\n            c_in = hidden_channels // (2 ** max((d - i), 0))\n            c_out = hidden_channels // (2 ** max((d - 1 - i), 0))\n            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))\n            layers.append(nn.InstanceNorm2d(c_out))\n            layers.append(nn.LeakyReLU(0.2))\n        self.encoder = nn.Sequential(*layers)\n        self.shuffle = nn.Conv2d(\n            (hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1\n        )\n        self.logits = nn.Sigmoid()\n\n    def forward(self, x, cond=None):\n        x = self.encoder(x)\n        if cond is not None:\n            cond = cond.view(\n                cond.size(0),\n                cond.size(1),\n                1,\n                1,\n            ).expand(-1, -1, x.size(-2), x.size(-1))\n            x = torch.cat([x, cond], dim=1)\n        x = self.shuffle(x)\n        x = self.logits(x)\n        return x\n"
  },
  {
    "path": "examples/vqgan/requirements.txt",
    "content": "accelerate>=0.16.0\ntorchvision\ntransformers>=4.25.1\ndatasets\ntimm\nnumpy\ntqdm\ntensorboard"
  },
  {
    "path": "examples/vqgan/test_vqgan.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport logging\nimport os\nimport shutil\nimport sys\nimport tempfile\n\nimport torch\n\nfrom diffusers import VQModel\n\n\n# Add parent directories to path to import from tests\nsys.path.append(\"..\")\nrepo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), \"../..\"))\nif repo_root not in sys.path:\n    sys.path.insert(0, repo_root)\n\nfrom test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402\n\nfrom tests.testing_utils import require_timm  # noqa\n\n\nlogging.basicConfig(level=logging.DEBUG)\n\nlogger = logging.getLogger()\nstream_handler = logging.StreamHandler(sys.stdout)\nlogger.addHandler(stream_handler)\n\n\n@require_timm\nclass TextToImage(ExamplesTestsAccelerate):\n    @property\n    def test_vqmodel_config(self):\n        return {\n            \"_class_name\": \"VQModel\",\n            \"_diffusers_version\": \"0.17.0.dev0\",\n            \"act_fn\": \"silu\",\n            \"block_out_channels\": [\n                32,\n            ],\n            \"down_block_types\": [\n                \"DownEncoderBlock2D\",\n            ],\n            \"in_channels\": 3,\n            \"latent_channels\": 4,\n            \"layers_per_block\": 2,\n            \"norm_num_groups\": 32,\n            \"norm_type\": \"spatial\",\n            
\"num_vq_embeddings\": 32,\n            \"out_channels\": 3,\n            \"sample_size\": 32,\n            \"scaling_factor\": 0.18215,\n            \"up_block_types\": [\n                \"UpDecoderBlock2D\",\n            ],\n            \"vq_embed_dim\": 4,\n        }\n\n    @property\n    def test_discriminator_config(self):\n        return {\n            \"_class_name\": \"Discriminator\",\n            \"_diffusers_version\": \"0.27.0.dev0\",\n            \"in_channels\": 3,\n            \"cond_channels\": 0,\n            \"hidden_channels\": 8,\n            \"depth\": 4,\n        }\n\n    def get_vq_and_discriminator_configs(self, tmpdir):\n        vqmodel_config_path = os.path.join(tmpdir, \"vqmodel.json\")\n        discriminator_config_path = os.path.join(tmpdir, \"discriminator.json\")\n        with open(vqmodel_config_path, \"w\") as fp:\n            json.dump(self.test_vqmodel_config, fp)\n        with open(discriminator_config_path, \"w\") as fp:\n            json.dump(self.test_discriminator_config, fp)\n        return vqmodel_config_path, discriminator_config_path\n\n    def test_vqmodel(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)\n            test_args = f\"\"\"\n                examples/vqgan/train_vqgan.py\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 32\n                --image_column image\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 2\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --model_config_name_or_path {vqmodel_config_path}\n                --discriminator_config_name_or_path {discriminator_config_path}\n                --output_dir {tmpdir}\n                
\"\"\".split()\n\n            run_command(self._launch_args + test_args)\n            # save_pretrained smoke test\n            self.assertTrue(\n                os.path.isfile(os.path.join(tmpdir, \"discriminator\", \"diffusion_pytorch_model.safetensors\"))\n            )\n            self.assertTrue(os.path.isfile(os.path.join(tmpdir, \"vqmodel\", \"diffusion_pytorch_model.safetensors\")))\n\n    def test_vqmodel_checkpointing(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                examples/vqgan/train_vqgan.py\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 32\n                --image_column image\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --model_config_name_or_path {vqmodel_config_path}\n                --discriminator_config_name_or_path {discriminator_config_path}\n                --checkpointing_steps=2\n                --output_dir {tmpdir}\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            # check can run an intermediate checkpoint\n            model = VQModel.from_pretrained(tmpdir, 
subfolder=\"checkpoint-2/vqmodel\")\n            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n            _ = model(image)\n\n            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming\n            shutil.rmtree(os.path.join(tmpdir, \"checkpoint-2\"))\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\"},\n            )\n\n            # Run training script for 2 total steps resuming from checkpoint 4\n\n            resume_run_args = f\"\"\"\n                examples/vqgan/train_vqgan.py\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 32\n                --image_column image\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --model_config_name_or_path {vqmodel_config_path}\n                --discriminator_config_name_or_path {discriminator_config_path}\n                --checkpointing_steps=1\n                --resume_from_checkpoint={os.path.join(tmpdir, \"checkpoint-4\")}\n                --output_dir {tmpdir}\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check can run new fully trained pipeline\n            model = VQModel.from_pretrained(tmpdir, subfolder=\"vqmodel\")\n            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n            _ = model(image)\n\n            # no checkpoint-2 -> check old checkpoints do not exist\n            # check new checkpoints exist\n            # In the current script, checkpointing_steps 1 is equivalent to 
checkpointing_steps 2 as after the generator gets trained for one step,\n            # the discriminator gets trained and loss and saving happens after that. Thus we do not expect to get a checkpoint-5\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_vqmodel_checkpointing_use_ema(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                examples/vqgan/train_vqgan.py\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 32\n                --image_column image\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --model_config_name_or_path {vqmodel_config_path}\n                --discriminator_config_name_or_path {discriminator_config_path}\n                --checkpointing_steps=2\n                --output_dir {tmpdir}\n                --use_ema\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            model = VQModel.from_pretrained(tmpdir, subfolder=\"vqmodel\")\n            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n            _ = model(image)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) 
if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            # check can run an intermediate checkpoint\n            model = VQModel.from_pretrained(tmpdir, subfolder=\"checkpoint-2/vqmodel\")\n            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n            _ = model(image)\n\n            # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming\n            shutil.rmtree(os.path.join(tmpdir, \"checkpoint-2\"))\n\n            # Run training script for 2 total steps resuming from checkpoint 4\n\n            resume_run_args = f\"\"\"\n                examples/vqgan/train_vqgan.py\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 32\n                --image_column image\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --model_config_name_or_path {vqmodel_config_path}\n                --discriminator_config_name_or_path {discriminator_config_path}\n                --checkpointing_steps=1\n                --resume_from_checkpoint={os.path.join(tmpdir, \"checkpoint-4\")}\n                --output_dir {tmpdir}\n                --use_ema\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            # check can run new fully trained pipeline\n            model = VQModel.from_pretrained(tmpdir, subfolder=\"vqmodel\")\n            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n            _ = model(image)\n\n            # no checkpoint-2 -> check old checkpoints do not exist\n            # check new checkpoints exist\n      
      self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-4\", \"checkpoint-6\"},\n            )\n\n    def test_vqmodel_checkpointing_checkpoints_total_limit(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)\n            # Run training script with checkpointing\n            # max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2\n            # Should create checkpoints at steps 2, 4, 6\n            # with checkpoint at step 2 deleted\n\n            initial_run_args = f\"\"\"\n                examples/vqgan/train_vqgan.py\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 32\n                --image_column image\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 6\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --model_config_name_or_path {vqmodel_config_path}\n                --discriminator_config_name_or_path {discriminator_config_path}\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --checkpoints_total_limit=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            model = VQModel.from_pretrained(tmpdir, subfolder=\"vqmodel\")\n            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n            _ = model(image)\n\n            # check checkpoint directories exist\n            # checkpoint-2 should have been deleted\n            self.assertEqual({x for x in os.listdir(tmpdir) if \"checkpoint\" in x}, {\"checkpoint-4\", 
\"checkpoint-6\"})\n\n    def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)\n            # Run training script with checkpointing\n            # max_train_steps == 4, checkpointing_steps == 2\n            # Should create checkpoints at steps 2, 4\n\n            initial_run_args = f\"\"\"\n                examples/vqgan/train_vqgan.py\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 32\n                --image_column image\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 4\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --model_config_name_or_path {vqmodel_config_path}\n                --discriminator_config_name_or_path {discriminator_config_path}\n                --checkpointing_steps=2\n                --output_dir {tmpdir}\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + initial_run_args)\n\n            model = VQModel.from_pretrained(tmpdir, subfolder=\"vqmodel\")\n            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n            _ = model(image)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-2\", \"checkpoint-4\"},\n            )\n\n            # resume and we should try to checkpoint at 6, where we'll have to remove\n            # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint\n\n            resume_run_args = f\"\"\"\n                
examples/vqgan/train_vqgan.py\n                --dataset_name hf-internal-testing/dummy_image_text_data\n                --resolution 32\n                --image_column image\n                --train_batch_size 1\n                --gradient_accumulation_steps 1\n                --max_train_steps 8\n                --learning_rate 5.0e-04\n                --scale_lr\n                --lr_scheduler constant\n                --lr_warmup_steps 0\n                --model_config_name_or_path {vqmodel_config_path}\n                --discriminator_config_name_or_path {discriminator_config_path}\n                --output_dir {tmpdir}\n                --checkpointing_steps=2\n                --resume_from_checkpoint={os.path.join(tmpdir, \"checkpoint-4\")}\n                --checkpoints_total_limit=2\n                --seed=0\n                \"\"\".split()\n\n            run_command(self._launch_args + resume_run_args)\n\n            model = VQModel.from_pretrained(tmpdir, subfolder=\"vqmodel\")\n            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)\n            _ = model(image)\n\n            # check checkpoint directories exist\n            self.assertEqual(\n                {x for x in os.listdir(tmpdir) if \"checkpoint\" in x},\n                {\"checkpoint-6\", \"checkpoint-8\"},\n            )\n"
  },
  {
    "path": "examples/vqgan/train_vqgan.py",
    "content": "# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport math\nimport os\nimport shutil\nimport time\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport PIL\nimport PIL.Image\nimport timm\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import DistributedType, ProjectConfiguration, set_seed\nfrom datasets import load_dataset\nfrom discriminator import Discriminator\nfrom huggingface_hub import create_repo\nfrom packaging import version\nfrom PIL import Image\nfrom timm.data import resolve_data_config\nfrom timm.data.transforms_factory import create_transform\nfrom torchvision import transforms\nfrom tqdm import tqdm\n\nfrom diffusers import VQModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import EMAModel\nfrom diffusers.utils import check_min_version, is_wandb_available\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risk.\ncheck_min_version(\"0.38.0.dev0\")\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef _map_layer_to_idx(backbone, layers, offset=0):\n    \"\"\"Maps a set of layer names to indices of the model's named children. Ported from anomalib\n\n    Returns:\n        List of indices of the requested layers, each shifted by `offset`\n    \"\"\"\n    idx = []\n    features = timm.create_model(\n        backbone,\n        pretrained=False,\n        features_only=False,\n        exportable=True,\n    )\n    for i in layers:\n        try:\n            idx.append(list(dict(features.named_children()).keys()).index(i) - offset)\n        except ValueError:\n            raise ValueError(\n                f\"Layer {i} not found in model {backbone}. Select layer from {list(dict(features.named_children()).keys())}. 
The network architecture is {features}\"\n            )\n    return idx\n\n\ndef get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution, timm_model_normalization):\n    img_timm_model_input = timm_model_normalization(F.interpolate(pixel_values, timm_model_resolution))\n    fmap_timm_model_input = timm_model_normalization(F.interpolate(fmap, timm_model_resolution))\n\n    if pixel_values.shape[1] == 1:\n        # handle grayscale for timm_model\n        img_timm_model_input, fmap_timm_model_input = (\n            t.repeat(1, 3, 1, 1) for t in (img_timm_model_input, fmap_timm_model_input)\n        )\n\n    img_timm_model_feats = timm_model(img_timm_model_input)\n    recon_timm_model_feats = timm_model(fmap_timm_model_input)\n    perceptual_loss = F.mse_loss(img_timm_model_feats[0], recon_timm_model_feats[0])\n    for i in range(1, len(img_timm_model_feats)):\n        perceptual_loss += F.mse_loss(img_timm_model_feats[i], recon_timm_model_feats[i])\n    perceptual_loss /= len(img_timm_model_feats)\n    return perceptual_loss\n\n\ndef grad_layer_wrt_loss(loss, layer):\n    return torch.autograd.grad(\n        outputs=loss,\n        inputs=layer,\n        grad_outputs=torch.ones_like(loss),\n        retain_graph=True,\n    )[0].detach()\n\n\ndef gradient_penalty(images, output, weight=10):\n    gradients = torch.autograd.grad(\n        outputs=output,\n        inputs=images,\n        grad_outputs=torch.ones(output.size(), device=images.device),\n        create_graph=True,\n        retain_graph=True,\n        only_inputs=True,\n    )[0]\n    bsz = gradients.shape[0]\n    gradients = torch.reshape(gradients, (bsz, -1))\n    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()\n\n\n@torch.no_grad()\ndef log_validation(model, args, validation_transform, accelerator, global_step):\n    logger.info(\"Generating images...\")\n    dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        dtype = torch.float16\n    elif 
accelerator.mixed_precision == \"bf16\":\n        dtype = torch.bfloat16\n    original_images = []\n    for image_path in args.validation_images:\n        image = PIL.Image.open(image_path)\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n        image = validation_transform(image).to(accelerator.device, dtype=dtype)\n        original_images.append(image[None])\n    # Generate images\n    model.eval()\n    images = []\n    for original_image in original_images:\n        image = accelerator.unwrap_model(model)(original_image).sample\n        images.append(image)\n    model.train()\n    original_images = torch.cat(original_images, dim=0)\n    images = torch.cat(images, dim=0)\n\n    # Convert to PIL images\n    images = torch.clamp(images, 0.0, 1.0)\n    original_images = torch.clamp(original_images, 0.0, 1.0)\n    images *= 255.0\n    original_images *= 255.0\n    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)\n    original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)\n    images = np.concatenate([original_images, images], axis=2)\n    images = [Image.fromarray(image) for image in images]\n\n    # Log images\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(\"validation\", np_images, global_step, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    \"validation\": [\n                        wandb.Image(image, caption=f\"{i}: Original, Generated\") for i, image in enumerate(images)\n                    ]\n                },\n                step=global_step,\n            )\n    torch.cuda.empty_cache()\n    return images\n\n\ndef log_grad_norm(model, accelerator, global_step):\n    for name, param in model.named_parameters():\n        if param.grad is not None:\n    
        grads = param.grad.detach().data\n            grad_norm = (grads.norm(p=2) / grads.numel()).item()\n            accelerator.log({\"grad_norm/\" + name: grad_norm}, step=global_step)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--log_grad_norm_steps\",\n        type=int,\n        default=500,\n        help=(\"Print logs of gradient norms every X steps.\"),\n    )\n    parser.add_argument(\n        \"--log_steps\",\n        type=int,\n        default=50,\n        help=(\"Print logs every X steps.\"),\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=100,\n        help=(\n            \"Run validation every X steps. Validation consists of running reconstruction on images in\"\n            \" `args.validation_images` and logging the reconstructed images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--vae_loss\",\n        type=str,\n        default=\"l2\",\n        help=\"The loss function for vae reconstruction loss.\",\n    )\n    parser.add_argument(\n        \"--timm_model_offset\",\n        type=int,\n        default=0,\n        help=\"Offset of timm layers to indices.\",\n    )\n    parser.add_argument(\n        \"--timm_model_layers\",\n        type=str,\n        default=\"head\",\n        help=\"The layers to get output from in the timm model.\",\n    )\n    parser.add_argument(\n        \"--timm_model_backend\",\n        type=str,\n        default=\"vgg19\",\n        help=\"Timm model used to get the lpips loss\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--model_config_name_or_path\",\n        type=str,\n        default=None,\n        help=\"The config of the Vq model to 
train, leave as None to use standard Vq model configuration.\",\n    )\n    parser.add_argument(\n        \"--discriminator_config_name_or_path\",\n        type=str,\n        default=None,\n        help=\"The config of the discriminator model to train, leave as None to use the default discriminator configuration.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--train_data_dir\",\n        type=str,\n        default=None,\n        help=(\n            \"A folder containing the training data. Folder contents must follow the structure described in\"\n            \" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file\"\n            \" must exist to provide the captions for the images. 
Ignored if `dataset_name` is specified.\"\n        ),\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing an image.\"\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_images\",\n        type=str,\n        default=None,\n        nargs=\"+\",\n        help=(\"A set of validation images evaluated every `--validation_steps` and logged to `--report_to`.\"),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"vqgan-output\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. 
The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--random_flip\",\n        action=\"store_true\",\n        help=\"Whether to randomly flip images horizontally.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=16, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=100)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform. If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--discr_learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use for the discriminator.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use for the VQ model.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--discr_lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\"--use_ema\", action=\"store_true\", help=\"Whether to use EMA model.\")\n    parser.add_argument(\n        \"--non_ema_revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or\"\n            \" remote repository specified with --pretrained_model_name_or_path.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--prediction_type\",\n        type=str,\n        default=None,\n        help=\"The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.\",\n    )\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). 
Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"vqgan-training\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    args = parser.parse_args()\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    # Sanity checks\n    if args.dataset_name is None and args.train_data_dir is None:\n        raise ValueError(\"Need either a dataset name or a training folder.\")\n\n    # default to using the same revision for the non-ema model if not specified\n    if args.non_ema_revision is None:\n        args.non_ema_revision = args.revision\n\n    return args\n\n\ndef main():\n    #########################\n    # SETUP Accelerator     #\n    #########################\n    args = parse_args()\n\n    # Enable TF32 on Ampere GPUs\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.backends.cudnn.benchmark = True\n        torch.backends.cudnn.deterministic = False\n\n    logging_dir = os.path.join(args.output_dir, args.logging_dir)\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        
project_config=accelerator_project_config,\n    )\n\n    if accelerator.distributed_type == DistributedType.DEEPSPEED:\n        accelerator.state.deepspeed_plugin.deepspeed_config[\"train_micro_batch_size_per_gpu\"] = args.train_batch_size\n\n    #####################################\n    # SETUP LOGGING, SEED and CONFIG    #\n    #####################################\n\n    if accelerator.is_main_process:\n        tracker_config = dict(vars(args))\n        tracker_config.pop(\"validation_images\")\n        accelerator.init_trackers(args.tracker_project_name, tracker_config)\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n        if args.push_to_hub:\n            create_repo(\n                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n            ).repo_id\n\n    #########################\n    # MODELS and OPTIMIZER  #\n    #########################\n    logger.info(\"Loading models and optimizer\")\n\n    if args.model_config_name_or_path is None and args.pretrained_model_name_or_path is None:\n        # Taken from config of movq at kandinsky-community/kandinsky-2-2-decoder but without the attention layers\n        model = VQModel(\n            act_fn=\"silu\",\n            block_out_channels=[\n                128,\n                256,\n                512,\n            ],\n            down_block_types=[\n                \"DownEncoderBlock2D\",\n                \"DownEncoderBlock2D\",\n                \"DownEncoderBlock2D\",\n            ],\n            in_channels=3,\n            latent_channels=4,\n            layers_per_block=2,\n            norm_num_groups=32,\n            norm_type=\"spatial\",\n            num_vq_embeddings=16384,\n            out_channels=3,\n        
    sample_size=32,\n            scaling_factor=0.18215,\n            up_block_types=[\"UpDecoderBlock2D\", \"UpDecoderBlock2D\", \"UpDecoderBlock2D\"],\n            vq_embed_dim=4,\n        )\n    elif args.pretrained_model_name_or_path is not None:\n        model = VQModel.from_pretrained(args.pretrained_model_name_or_path)\n    else:\n        config = VQModel.load_config(args.model_config_name_or_path)\n        model = VQModel.from_config(config)\n    if args.use_ema:\n        ema_model = EMAModel(model.parameters(), model_cls=VQModel, model_config=model.config)\n    if args.discriminator_config_name_or_path is None:\n        discriminator = Discriminator()\n    else:\n        config = Discriminator.load_config(args.discriminator_config_name_or_path)\n        discriminator = Discriminator.from_config(config)\n\n    idx = _map_layer_to_idx(args.timm_model_backend, args.timm_model_layers.split(\"|\"), args.timm_model_offset)\n\n    timm_model = timm.create_model(\n        args.timm_model_backend,\n        pretrained=True,\n        features_only=True,\n        exportable=True,\n        out_indices=idx,\n    )\n    timm_model = timm_model.to(accelerator.device)\n    timm_model.requires_grad_(False)\n    timm_model.eval()\n    timm_transform = create_transform(**resolve_data_config(timm_model.pretrained_cfg, model=timm_model))\n    try:\n        # Gets the resolution of the timm transformation after centercrop\n        timm_centercrop_transform = timm_transform.transforms[1]\n        assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (\n            f\"Timm model {timm_model} is currently incompatible with this script. 
Try vgg19.\"\n        )\n        timm_model_resolution = timm_centercrop_transform.size[0]\n        # Gets final normalization\n        timm_model_normalization = timm_transform.transforms[-1]\n        assert isinstance(timm_model_normalization, transforms.Normalize), (\n            f\"Timm model {timm_model} is currently incompatible with this script. Try vgg19.\"\n        )\n    except AssertionError as e:\n        raise NotImplementedError(e)\n    # Enable flash attention if asked\n    if args.enable_xformers_memory_efficient_attention:\n        model.enable_xformers_memory_efficient_attention()\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            if accelerator.is_main_process:\n                if args.use_ema:\n                    ema_model.save_pretrained(os.path.join(output_dir, \"vqmodel_ema\"))\n                vqmodel = models[0]\n                discriminator = models[1]\n                vqmodel.save_pretrained(os.path.join(output_dir, \"vqmodel\"))\n                discriminator.save_pretrained(os.path.join(output_dir, \"discriminator\"))\n                weights.pop()\n                weights.pop()\n\n        def load_model_hook(models, input_dir):\n            if args.use_ema:\n                load_model = EMAModel.from_pretrained(os.path.join(input_dir, \"vqmodel_ema\"), VQModel)\n                ema_model.load_state_dict(load_model.state_dict())\n                ema_model.to(accelerator.device)\n                del load_model\n            discriminator = models.pop()\n            load_model = Discriminator.from_pretrained(input_dir, subfolder=\"discriminator\")\n            discriminator.register_to_config(**load_model.config)\n            
discriminator.load_state_dict(load_model.state_dict())\n            del load_model\n            vqmodel = models.pop()\n            load_model = VQModel.from_pretrained(input_dir, subfolder=\"vqmodel\")\n            vqmodel.register_to_config(**load_model.config)\n            vqmodel.load_state_dict(load_model.state_dict())\n            del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    learning_rate = args.learning_rate\n    if args.scale_lr:\n        learning_rate = (\n            learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n        )\n\n    # Initialize the optimizer\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`\"\n            )\n\n        optimizer_cls = bnb.optim.AdamW8bit\n    else:\n        optimizer_cls = torch.optim.AdamW\n\n    optimizer = optimizer_cls(\n        list(model.parameters()),\n        lr=learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n    discr_optimizer = optimizer_cls(\n        list(discriminator.parameters()),\n        lr=args.discr_learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    ##################################\n    # DATALOADER and LR-SCHEDULER    #\n    ##################################\n    logger.info(\"Creating dataloaders and lr_scheduler\")\n\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    # DataLoaders creation:\n    if 
args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            args.dataset_name,\n            args.dataset_config_name,\n            cache_dir=args.cache_dir,\n            data_dir=args.train_data_dir,\n        )\n    else:\n        data_files = {}\n        if args.train_data_dir is not None:\n            data_files[\"train\"] = os.path.join(args.train_data_dir, \"**\")\n        dataset = load_dataset(\n            \"imagefolder\",\n            data_files=data_files,\n            cache_dir=args.cache_dir,\n        )\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n\n    # Preprocessing the datasets.\n    # Look up the available columns so the image column can be validated.\n    column_names = dataset[\"train\"].column_names\n\n    # Get the column name for the input images.\n    assert args.image_column is not None\n    image_column = args.image_column\n    if image_column not in column_names:\n        raise ValueError(f\"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}\")\n    # Image transforms for training and validation.\n    train_transforms = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),\n            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),\n            transforms.ToTensor(),\n        ]\n    )\n    validation_transform = transforms.Compose(\n        [\n            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.ToTensor(),\n        ]\n    )\n\n    def preprocess_train(examples):\n        images = [image.convert(\"RGB\") for image in examples[image_column]]\n        examples[\"pixel_values\"] = 
[train_transforms(image) for image in images]\n        return examples\n\n    with accelerator.main_process_first():\n        if args.max_train_samples is not None:\n            dataset[\"train\"] = dataset[\"train\"].shuffle(seed=args.seed).select(range(args.max_train_samples))\n        train_dataset = dataset[\"train\"].with_transform(preprocess_train)\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n        return {\"pixel_values\": pixel_values}\n\n    # DataLoaders creation:\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        shuffle=True,\n        collate_fn=collate_fn,\n        batch_size=args.train_batch_size,\n        num_workers=args.dataloader_num_workers,\n    )\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_training_steps=args.max_train_steps,\n        num_warmup_steps=args.lr_warmup_steps,\n    )\n    discr_lr_scheduler = get_scheduler(\n        args.discr_lr_scheduler,\n        optimizer=discr_optimizer,\n        num_training_steps=args.max_train_steps,\n        num_warmup_steps=args.lr_warmup_steps,\n    )\n\n    # Prepare everything with accelerator\n    logger.info(\"Preparing model, optimizer and dataloaders\")\n    # The dataloader are already aware of distributed training, so we don't need to prepare them.\n    model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler = accelerator.prepare(\n        model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler\n    )\n    if args.use_ema:\n        ema_model.to(accelerator.device)\n    # Train!\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  
Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n    # Scheduler and math around the number of training steps.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # Potentially load in the weights and states from a previous save\n    resume_from_checkpoint = args.resume_from_checkpoint\n    if resume_from_checkpoint:\n        if resume_from_checkpoint != \"latest\":\n            path = resume_from_checkpoint\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            # Leave path as None when no checkpoint directory exists yet\n            path = os.path.join(args.output_dir, dirs[-1]) if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(f\"Checkpoint '{resume_from_checkpoint}' does not exist. 
Starting a new training run.\")\n            resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(path)\n            accelerator.wait_for_everyone()\n            global_step = int(os.path.basename(path).split(\"-\")[1])\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    batch_time_m = AverageMeter()\n    data_time_m = AverageMeter()\n    end = time.time()\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to\n    # reuse the same training loop with other datasets/loaders.\n    avg_gen_loss, avg_discr_loss = None, None\n    for epoch in range(first_epoch, args.num_train_epochs):\n        model.train()\n        discriminator.train()\n        for i, batch in enumerate(train_dataloader):\n            pixel_values = batch[\"pixel_values\"]\n            pixel_values = pixel_values.to(accelerator.device, non_blocking=True)\n            data_time_m.update(time.time() - end)\n            generator_step = ((i // args.gradient_accumulation_steps) % 2) == 0\n            # Train Step\n            # The behavior of accelerator.accumulate is to\n            # 1. Check if gradients are synced(reached gradient-accumulation_steps)\n            # 2. 
If so sync gradients by stopping the not syncing process\n            if generator_step:\n                optimizer.zero_grad(set_to_none=True)\n            else:\n                discr_optimizer.zero_grad(set_to_none=True)\n            # encode images to the latent space and get the commit loss from vq tokenization\n            # Return commit loss\n            fmap, commit_loss = model(pixel_values, return_dict=False)\n\n            if generator_step:\n                with accelerator.accumulate(model):\n                    # reconstruction loss. Pixel level differences between input vs output\n                    if args.vae_loss == \"l2\":\n                        loss = F.mse_loss(pixel_values, fmap)\n                    else:\n                        loss = F.l1_loss(pixel_values, fmap)\n                    # perceptual loss. The high level feature mean squared error loss\n                    perceptual_loss = get_perceptual_loss(\n                        pixel_values,\n                        fmap,\n                        timm_model,\n                        timm_model_resolution=timm_model_resolution,\n                        timm_model_normalization=timm_model_normalization,\n                    )\n                    # generator loss\n                    gen_loss = -discriminator(fmap).mean()\n                    last_dec_layer = accelerator.unwrap_model(model).decoder.conv_out.weight\n                    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)\n                    norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)\n\n                    adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-8)\n                    adaptive_weight = adaptive_weight.clamp(max=1e4)\n                    loss += commit_loss\n                    loss += perceptual_loss\n                    loss += adaptive_weight * gen_loss\n                    # Gather the 
losses across all processes for logging (if we use distributed training).\n                    avg_gen_loss = accelerator.gather(loss.repeat(args.train_batch_size)).float().mean()\n                    accelerator.backward(loss)\n\n                    if args.max_grad_norm is not None and accelerator.sync_gradients:\n                        accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n\n                    optimizer.step()\n                    lr_scheduler.step()\n                    # log gradient norm before zeroing it\n                    if (\n                        accelerator.sync_gradients\n                        and global_step % args.log_grad_norm_steps == 0\n                        and accelerator.is_main_process\n                    ):\n                        log_grad_norm(model, accelerator, global_step)\n            else:\n                # Return discriminator loss\n                with accelerator.accumulate(discriminator):\n                    fmap.detach_()\n                    pixel_values.requires_grad_()\n                    real = discriminator(pixel_values)\n                    fake = discriminator(fmap)\n                    loss = (F.relu(1 + fake) + F.relu(1 - real)).mean()\n                    gp = gradient_penalty(pixel_values, real)\n                    loss += gp\n                    avg_discr_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()\n                    accelerator.backward(loss)\n\n                    if args.max_grad_norm is not None and accelerator.sync_gradients:\n                        accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm)\n\n                    discr_optimizer.step()\n                    discr_lr_scheduler.step()\n                    if (\n                        accelerator.sync_gradients\n                        and global_step % args.log_grad_norm_steps == 0\n                        and accelerator.is_main_process\n                    
):\n                        log_grad_norm(discriminator, accelerator, global_step)\n            batch_time_m.update(time.time() - end)\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                global_step += 1\n                progress_bar.update(1)\n                if args.use_ema:\n                    ema_model.step(model.parameters())\n            if accelerator.sync_gradients and not generator_step and accelerator.is_main_process:\n                # wait for both generator and discriminator to settle\n                # Log metrics\n                if global_step % args.log_steps == 0:\n                    samples_per_second_per_gpu = (\n                        args.gradient_accumulation_steps * args.train_batch_size / batch_time_m.val\n                    )\n                    logs = {\n                        \"step_discr_loss\": avg_discr_loss.item(),\n                        \"lr\": lr_scheduler.get_last_lr()[0],\n                        \"samples/sec/gpu\": samples_per_second_per_gpu,\n                        \"data_time\": data_time_m.val,\n                        \"batch_time\": batch_time_m.val,\n                    }\n                    if avg_gen_loss is not None:\n                        logs[\"step_gen_loss\"] = avg_gen_loss.item()\n                    accelerator.log(logs, step=global_step)\n\n                    # resetting batch / data time meters per log window\n                    batch_time_m.reset()\n                    data_time_m.reset()\n                # Save model checkpoint\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = 
os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                        save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                # Generate images\n                if global_step % args.validation_steps == 0:\n                    if args.use_ema:\n                        # Store the VQGAN parameters temporarily and load the EMA parameters to perform inference.\n                        ema_model.store(model.parameters())\n                        ema_model.copy_to(model.parameters())\n                    log_validation(model, args, validation_transform, accelerator, global_step)\n                    if args.use_ema:\n                        # 
Switch back to the original VQGAN parameters.\n                        ema_model.restore(model.parameters())\n            end = time.time()\n            # Stop training if max steps is reached\n            if global_step >= args.max_train_steps:\n                break\n        # End for\n\n    accelerator.wait_for_everyone()\n\n    # Save the final trained checkpoint\n    if accelerator.is_main_process:\n        model = accelerator.unwrap_model(model)\n        discriminator = accelerator.unwrap_model(discriminator)\n        if args.use_ema:\n            ema_model.copy_to(model.parameters())\n        model.save_pretrained(os.path.join(args.output_dir, \"vqmodel\"))\n        discriminator.save_pretrained(os.path.join(args.output_dir, \"discriminator\"))\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.ruff]\nline-length = 119\nextend-exclude = [\n    \"src/diffusers/pipelines/flux2/system_messages.py\",\n]\n\n[tool.ruff.lint]\n# Never enforce `E501` (line length violations).\nignore = [\"C901\", \"E501\", \"E721\", \"E741\", \"F402\", \"F823\"]\nselect = [\"C\", \"E\", \"F\", \"I\", \"W\"]\n\n# Ignore import violations in all `__init__.py` files.\n[tool.ruff.lint.per-file-ignores]\n\"__init__.py\" = [\"E402\", \"F401\", \"F403\", \"F811\"]\n\"src/diffusers/utils/dummy_*.py\" = [\"F401\"]\n\n[tool.ruff.lint.isort]\nlines-after-imports = 2\nknown-first-party = [\"diffusers\"]\n\n[tool.ruff.format]\n# Like Black, use double quotes for strings.\nquote-style = \"double\"\n\n# Like Black, indent with spaces, rather than tabs.\nindent-style = \"space\"\n\n# Like Black, respect magic trailing commas.\nskip-magic-trailing-comma = false\n\n# Like Black, automatically detect the appropriate line ending.\nline-ending = \"auto\"\n"
  },
  {
    "path": "scripts/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/conversion_ldm_uncond.py",
    "content": "import argparse\n\nimport torch\nimport yaml\n\nfrom diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel\n\n\ndef convert_ldm_original(checkpoint_path, config_path, output_path):\n    # Parse the YAML file at config_path (not the path string itself)\n    with open(config_path) as f:\n        config = yaml.safe_load(f)\n    state_dict = torch.load(checkpoint_path, map_location=\"cpu\")[\"model\"]\n    keys = list(state_dict.keys())\n\n    # extract state_dict for VQVAE\n    first_stage_dict = {}\n    first_stage_key = \"first_stage_model.\"\n    for key in keys:\n        if key.startswith(first_stage_key):\n            first_stage_dict[key.replace(first_stage_key, \"\")] = state_dict[key]\n\n    # extract state_dict for UNetLDM\n    unet_state_dict = {}\n    unet_key = \"model.diffusion_model.\"\n    for key in keys:\n        if key.startswith(unet_key):\n            unet_state_dict[key.replace(unet_key, \"\")] = state_dict[key]\n\n    vqvae_init_args = config[\"model\"][\"params\"][\"first_stage_config\"][\"params\"]\n    unet_init_args = config[\"model\"][\"params\"][\"unet_config\"][\"params\"]\n\n    vqvae = VQModel(**vqvae_init_args).eval()\n    vqvae.load_state_dict(first_stage_dict)\n\n    unet = UNetLDMModel(**unet_init_args).eval()\n    unet.load_state_dict(unet_state_dict)\n\n    noise_scheduler = DDIMScheduler(\n        timesteps=config[\"model\"][\"params\"][\"timesteps\"],\n        beta_schedule=\"scaled_linear\",\n        beta_start=config[\"model\"][\"params\"][\"linear_start\"],\n        beta_end=config[\"model\"][\"params\"][\"linear_end\"],\n        clip_sample=False,\n    )\n\n    pipeline = LDMPipeline(vqvae, unet, noise_scheduler)\n    pipeline.save_pretrained(output_path)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--checkpoint_path\", type=str, required=True)\n    parser.add_argument(\"--config_path\", type=str, required=True)\n    parser.add_argument(\"--output_path\", type=str, required=True)\n    args = parser.parse_args()\n\n    
convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)\n"
  },
  {
    "path": "scripts/convert_amused.py",
    "content": "import inspect\nimport os\nfrom argparse import ArgumentParser\n\nimport numpy as np\nimport torch\nfrom muse import MaskGiTUViT, VQGANModel\nfrom muse import PipelineMuse as OldPipelineMuse\nfrom transformers import CLIPTextModelWithProjection, CLIPTokenizer\n\nfrom diffusers import VQModel\nfrom diffusers.models.attention_processor import AttnProcessor\nfrom diffusers.models.unets.uvit_2d import UVit2DModel\nfrom diffusers.pipelines.amused.pipeline_amused import AmusedPipeline\nfrom diffusers.schedulers import AmusedScheduler\n\n\ntorch.backends.cuda.enable_flash_sdp(False)\ntorch.backends.cuda.enable_mem_efficient_sdp(False)\ntorch.backends.cuda.enable_math_sdp(True)\n\nos.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\nos.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\"\ntorch.use_deterministic_algorithms(True)\n\n# Enable CUDNN deterministic mode\ntorch.backends.cudnn.deterministic = True\ntorch.backends.cudnn.benchmark = False\ntorch.backends.cuda.matmul.allow_tf32 = False\n\ndevice = \"cuda\"\n\n\ndef main():\n    args = ArgumentParser()\n    args.add_argument(\"--model_256\", action=\"store_true\")\n    args.add_argument(\"--write_to\", type=str, required=False, default=None)\n    args.add_argument(\"--transformer_path\", type=str, required=False, default=None)\n    args = args.parse_args()\n\n    transformer_path = args.transformer_path\n    subfolder = \"transformer\"\n\n    if transformer_path is None:\n        if args.model_256:\n            transformer_path = \"openMUSE/muse-256\"\n        else:\n            transformer_path = (\n                \"../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/\"\n            )\n            subfolder = None\n\n    old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)\n\n    old_transformer.to(device)\n\n    old_vae = VQGANModel.from_pretrained(\"openMUSE/muse-512\", subfolder=\"vae\")\n    old_vae.to(device)\n\n    vqvae = 
make_vqvae(old_vae)\n\n    tokenizer = CLIPTokenizer.from_pretrained(\"openMUSE/muse-512\", subfolder=\"text_encoder\")\n\n    text_encoder = CLIPTextModelWithProjection.from_pretrained(\"openMUSE/muse-512\", subfolder=\"text_encoder\")\n    text_encoder.to(device)\n\n    transformer = make_transformer(old_transformer, args.model_256)\n\n    scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)\n\n    new_pipe = AmusedPipeline(\n        vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler\n    )\n\n    old_pipe = OldPipelineMuse(\n        vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer\n    )\n    old_pipe.to(device)\n\n    if args.model_256:\n        transformer_seq_len = 256\n        orig_size = (256, 256)\n    else:\n        transformer_seq_len = 1024\n        orig_size = (512, 512)\n\n    old_out = old_pipe(\n        \"dog\",\n        generator=torch.Generator(device).manual_seed(0),\n        transformer_seq_len=transformer_seq_len,\n        orig_size=orig_size,\n        timesteps=12,\n    )[0]\n\n    new_out = new_pipe(\"dog\", generator=torch.Generator(device).manual_seed(0)).images[0]\n\n    old_out = np.array(old_out)\n    new_out = np.array(new_out)\n\n    diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))\n\n    # assert diff diff.sum() == 0\n    print(\"skipping pipeline full equivalence check\")\n\n    print(f\"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}\")\n\n    if args.model_256:\n        assert diff.max() <= 3\n        assert diff.sum() / diff.size < 0.7\n    else:\n        assert diff.max() <= 1\n        assert diff.sum() / diff.size < 0.4\n\n    if args.write_to is not None:\n        new_pipe.save_pretrained(args.write_to)\n\n\ndef make_transformer(old_transformer, model_256):\n    args = dict(old_transformer.config)\n    force_down_up_sample = args[\"force_down_up_sample\"]\n\n    
signature = inspect.signature(UVit2DModel.__init__)\n\n    args_ = {\n        \"downsample\": force_down_up_sample,\n        \"upsample\": force_down_up_sample,\n        \"block_out_channels\": args[\"block_out_channels\"][0],\n        \"sample_size\": 16 if model_256 else 32,\n    }\n\n    for s in list(signature.parameters.keys()):\n        if s in [\"self\", \"downsample\", \"upsample\", \"sample_size\", \"block_out_channels\"]:\n            continue\n\n        args_[s] = args[s]\n\n    new_transformer = UVit2DModel(**args_)\n    new_transformer.to(device)\n\n    new_transformer.set_attn_processor(AttnProcessor())\n\n    state_dict = old_transformer.state_dict()\n\n    state_dict[\"cond_embed.linear_1.weight\"] = state_dict.pop(\"cond_embed.0.weight\")\n    state_dict[\"cond_embed.linear_2.weight\"] = state_dict.pop(\"cond_embed.2.weight\")\n\n    for i in range(22):\n        state_dict[f\"transformer_layers.{i}.norm1.norm.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.attn_layer_norm.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.norm1.linear.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight\"\n        )\n\n        state_dict[f\"transformer_layers.{i}.attn1.to_q.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.attention.query.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.attn1.to_k.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.attention.key.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.attn1.to_v.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.attention.value.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.attn1.to_out.0.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.attention.out.weight\"\n        )\n\n        state_dict[f\"transformer_layers.{i}.norm2.norm.weight\"] = state_dict.pop(\n            
f\"transformer_layers.{i}.crossattn_layer_norm.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.norm2.linear.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight\"\n        )\n\n        state_dict[f\"transformer_layers.{i}.attn2.to_q.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.crossattention.query.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.attn2.to_k.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.crossattention.key.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.attn2.to_v.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.crossattention.value.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.attn2.to_out.0.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.crossattention.out.weight\"\n        )\n\n        state_dict[f\"transformer_layers.{i}.norm3.norm.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight\"\n        )\n        state_dict[f\"transformer_layers.{i}.norm3.linear.weight\"] = state_dict.pop(\n            f\"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight\"\n        )\n\n        wi_0_weight = state_dict.pop(f\"transformer_layers.{i}.ffn.wi_0.weight\")\n        wi_1_weight = state_dict.pop(f\"transformer_layers.{i}.ffn.wi_1.weight\")\n        proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)\n        state_dict[f\"transformer_layers.{i}.ff.net.0.proj.weight\"] = proj_weight\n\n        state_dict[f\"transformer_layers.{i}.ff.net.2.weight\"] = state_dict.pop(f\"transformer_layers.{i}.ffn.wo.weight\")\n\n    if force_down_up_sample:\n        state_dict[\"down_block.downsample.norm.weight\"] = state_dict.pop(\"down_blocks.0.downsample.0.norm.weight\")\n        state_dict[\"down_block.downsample.conv.weight\"] = state_dict.pop(\"down_blocks.0.downsample.1.weight\")\n\n        
state_dict[\"up_block.upsample.norm.weight\"] = state_dict.pop(\"up_blocks.0.upsample.0.norm.weight\")\n        state_dict[\"up_block.upsample.conv.weight\"] = state_dict.pop(\"up_blocks.0.upsample.1.weight\")\n\n    state_dict[\"mlm_layer.layer_norm.weight\"] = state_dict.pop(\"mlm_layer.layer_norm.norm.weight\")\n\n    for i in range(3):\n        state_dict[f\"down_block.res_blocks.{i}.norm.weight\"] = state_dict.pop(\n            f\"down_blocks.0.res_blocks.{i}.norm.norm.weight\"\n        )\n        state_dict[f\"down_block.res_blocks.{i}.channelwise_linear_1.weight\"] = state_dict.pop(\n            f\"down_blocks.0.res_blocks.{i}.channelwise.0.weight\"\n        )\n        state_dict[f\"down_block.res_blocks.{i}.channelwise_norm.gamma\"] = state_dict.pop(\n            f\"down_blocks.0.res_blocks.{i}.channelwise.2.gamma\"\n        )\n        state_dict[f\"down_block.res_blocks.{i}.channelwise_norm.beta\"] = state_dict.pop(\n            f\"down_blocks.0.res_blocks.{i}.channelwise.2.beta\"\n        )\n        state_dict[f\"down_block.res_blocks.{i}.channelwise_linear_2.weight\"] = state_dict.pop(\n            f\"down_blocks.0.res_blocks.{i}.channelwise.4.weight\"\n        )\n        state_dict[f\"down_block.res_blocks.{i}.cond_embeds_mapper.weight\"] = state_dict.pop(\n            f\"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight\"\n        )\n\n        state_dict[f\"down_block.attention_blocks.{i}.norm1.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight\"\n        )\n        state_dict[f\"down_block.attention_blocks.{i}.attn1.to_q.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.attention.query.weight\"\n        )\n        state_dict[f\"down_block.attention_blocks.{i}.attn1.to_k.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.attention.key.weight\"\n        )\n        state_dict[f\"down_block.attention_blocks.{i}.attn1.to_v.weight\"] = 
state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.attention.value.weight\"\n        )\n        state_dict[f\"down_block.attention_blocks.{i}.attn1.to_out.0.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.attention.out.weight\"\n        )\n\n        state_dict[f\"down_block.attention_blocks.{i}.norm2.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight\"\n        )\n        state_dict[f\"down_block.attention_blocks.{i}.attn2.to_q.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.crossattention.query.weight\"\n        )\n        state_dict[f\"down_block.attention_blocks.{i}.attn2.to_k.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.crossattention.key.weight\"\n        )\n        state_dict[f\"down_block.attention_blocks.{i}.attn2.to_v.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.crossattention.value.weight\"\n        )\n        state_dict[f\"down_block.attention_blocks.{i}.attn2.to_out.0.weight\"] = state_dict.pop(\n            f\"down_blocks.0.attention_blocks.{i}.crossattention.out.weight\"\n        )\n\n        state_dict[f\"up_block.res_blocks.{i}.norm.weight\"] = state_dict.pop(\n            f\"up_blocks.0.res_blocks.{i}.norm.norm.weight\"\n        )\n        state_dict[f\"up_block.res_blocks.{i}.channelwise_linear_1.weight\"] = state_dict.pop(\n            f\"up_blocks.0.res_blocks.{i}.channelwise.0.weight\"\n        )\n        state_dict[f\"up_block.res_blocks.{i}.channelwise_norm.gamma\"] = state_dict.pop(\n            f\"up_blocks.0.res_blocks.{i}.channelwise.2.gamma\"\n        )\n        state_dict[f\"up_block.res_blocks.{i}.channelwise_norm.beta\"] = state_dict.pop(\n            f\"up_blocks.0.res_blocks.{i}.channelwise.2.beta\"\n        )\n        state_dict[f\"up_block.res_blocks.{i}.channelwise_linear_2.weight\"] = state_dict.pop(\n            
f\"up_blocks.0.res_blocks.{i}.channelwise.4.weight\"\n        )\n        state_dict[f\"up_block.res_blocks.{i}.cond_embeds_mapper.weight\"] = state_dict.pop(\n            f\"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight\"\n        )\n\n        state_dict[f\"up_block.attention_blocks.{i}.norm1.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight\"\n        )\n        state_dict[f\"up_block.attention_blocks.{i}.attn1.to_q.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.attention.query.weight\"\n        )\n        state_dict[f\"up_block.attention_blocks.{i}.attn1.to_k.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.attention.key.weight\"\n        )\n        state_dict[f\"up_block.attention_blocks.{i}.attn1.to_v.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.attention.value.weight\"\n        )\n        state_dict[f\"up_block.attention_blocks.{i}.attn1.to_out.0.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.attention.out.weight\"\n        )\n\n        state_dict[f\"up_block.attention_blocks.{i}.norm2.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight\"\n        )\n        state_dict[f\"up_block.attention_blocks.{i}.attn2.to_q.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.crossattention.query.weight\"\n        )\n        state_dict[f\"up_block.attention_blocks.{i}.attn2.to_k.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.crossattention.key.weight\"\n        )\n        state_dict[f\"up_block.attention_blocks.{i}.attn2.to_v.weight\"] = state_dict.pop(\n            f\"up_blocks.0.attention_blocks.{i}.crossattention.value.weight\"\n        )\n        state_dict[f\"up_block.attention_blocks.{i}.attn2.to_out.0.weight\"] = state_dict.pop(\n            
f\"up_blocks.0.attention_blocks.{i}.crossattention.out.weight\"\n        )\n\n    for key in list(state_dict.keys()):\n        if key.startswith(\"up_blocks.0\"):\n            key_ = \"up_block.\" + \".\".join(key.split(\".\")[2:])\n            state_dict[key_] = state_dict.pop(key)\n\n        if key.startswith(\"down_blocks.0\"):\n            key_ = \"down_block.\" + \".\".join(key.split(\".\")[2:])\n            state_dict[key_] = state_dict.pop(key)\n\n    new_transformer.load_state_dict(state_dict)\n\n    input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)\n    encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)\n    cond_embeds = torch.randn((1, 768), device=old_transformer.device)\n    micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)\n\n    old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)\n    old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)\n\n    new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)\n\n    # NOTE: these differences are solely due to using the geglu block that has a single linear layer of\n    # double output dimension instead of two different linear layers\n    max_diff = (old_out - new_out).abs().max()\n    total_diff = (old_out - new_out).abs().sum()\n    print(f\"Transformer max_diff: {max_diff} total_diff:  {total_diff}\")\n    assert max_diff < 0.01\n    assert total_diff < 1500\n\n    return new_transformer\n\n\ndef make_vqvae(old_vae):\n    new_vae = VQModel(\n        act_fn=\"silu\",\n        block_out_channels=[128, 256, 256, 512, 768],\n        down_block_types=[\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n            \"DownEncoderBlock2D\",\n        ],\n        in_channels=3,\n        latent_channels=64,\n        
layers_per_block=2,\n        norm_num_groups=32,\n        num_vq_embeddings=8192,\n        out_channels=3,\n        sample_size=32,\n        up_block_types=[\n            \"UpDecoderBlock2D\",\n            \"UpDecoderBlock2D\",\n            \"UpDecoderBlock2D\",\n            \"UpDecoderBlock2D\",\n            \"UpDecoderBlock2D\",\n        ],\n        mid_block_add_attention=False,\n        lookup_from_codebook=True,\n    )\n    new_vae.to(device)\n\n    # fmt: off\n\n    new_state_dict = {}\n\n    old_state_dict = old_vae.state_dict()\n\n    new_state_dict[\"encoder.conv_in.weight\"] = old_state_dict.pop(\"encoder.conv_in.weight\")\n    new_state_dict[\"encoder.conv_in.bias\"]   = old_state_dict.pop(\"encoder.conv_in.bias\")\n\n    convert_vae_block_state_dict(old_state_dict, \"encoder.down.0\", new_state_dict, \"encoder.down_blocks.0\")\n    convert_vae_block_state_dict(old_state_dict, \"encoder.down.1\", new_state_dict, \"encoder.down_blocks.1\")\n    convert_vae_block_state_dict(old_state_dict, \"encoder.down.2\", new_state_dict, \"encoder.down_blocks.2\")\n    convert_vae_block_state_dict(old_state_dict, \"encoder.down.3\", new_state_dict, \"encoder.down_blocks.3\")\n    convert_vae_block_state_dict(old_state_dict, \"encoder.down.4\", new_state_dict, \"encoder.down_blocks.4\")\n\n    new_state_dict[\"encoder.mid_block.resnets.0.norm1.weight\"] = old_state_dict.pop(\"encoder.mid.block_1.norm1.weight\")\n    new_state_dict[\"encoder.mid_block.resnets.0.norm1.bias\"]   = old_state_dict.pop(\"encoder.mid.block_1.norm1.bias\")\n    new_state_dict[\"encoder.mid_block.resnets.0.conv1.weight\"] = old_state_dict.pop(\"encoder.mid.block_1.conv1.weight\")\n    new_state_dict[\"encoder.mid_block.resnets.0.conv1.bias\"]   = old_state_dict.pop(\"encoder.mid.block_1.conv1.bias\")\n    new_state_dict[\"encoder.mid_block.resnets.0.norm2.weight\"] = old_state_dict.pop(\"encoder.mid.block_1.norm2.weight\")\n    new_state_dict[\"encoder.mid_block.resnets.0.norm2.bias\"]   = 
old_state_dict.pop(\"encoder.mid.block_1.norm2.bias\")\n    new_state_dict[\"encoder.mid_block.resnets.0.conv2.weight\"] = old_state_dict.pop(\"encoder.mid.block_1.conv2.weight\")\n    new_state_dict[\"encoder.mid_block.resnets.0.conv2.bias\"]   = old_state_dict.pop(\"encoder.mid.block_1.conv2.bias\")\n    new_state_dict[\"encoder.mid_block.resnets.1.norm1.weight\"] = old_state_dict.pop(\"encoder.mid.block_2.norm1.weight\")\n    new_state_dict[\"encoder.mid_block.resnets.1.norm1.bias\"]   = old_state_dict.pop(\"encoder.mid.block_2.norm1.bias\")\n    new_state_dict[\"encoder.mid_block.resnets.1.conv1.weight\"] = old_state_dict.pop(\"encoder.mid.block_2.conv1.weight\")\n    new_state_dict[\"encoder.mid_block.resnets.1.conv1.bias\"]   = old_state_dict.pop(\"encoder.mid.block_2.conv1.bias\")\n    new_state_dict[\"encoder.mid_block.resnets.1.norm2.weight\"] = old_state_dict.pop(\"encoder.mid.block_2.norm2.weight\")\n    new_state_dict[\"encoder.mid_block.resnets.1.norm2.bias\"]   = old_state_dict.pop(\"encoder.mid.block_2.norm2.bias\")\n    new_state_dict[\"encoder.mid_block.resnets.1.conv2.weight\"] = old_state_dict.pop(\"encoder.mid.block_2.conv2.weight\")\n    new_state_dict[\"encoder.mid_block.resnets.1.conv2.bias\"]   = old_state_dict.pop(\"encoder.mid.block_2.conv2.bias\")\n    new_state_dict[\"encoder.conv_norm_out.weight\"]             = old_state_dict.pop(\"encoder.norm_out.weight\")\n    new_state_dict[\"encoder.conv_norm_out.bias\"]               = old_state_dict.pop(\"encoder.norm_out.bias\")\n    new_state_dict[\"encoder.conv_out.weight\"]                  = old_state_dict.pop(\"encoder.conv_out.weight\")\n    new_state_dict[\"encoder.conv_out.bias\"]                    = old_state_dict.pop(\"encoder.conv_out.bias\")\n    new_state_dict[\"quant_conv.weight\"]                        = old_state_dict.pop(\"quant_conv.weight\")\n    new_state_dict[\"quant_conv.bias\"]                          = old_state_dict.pop(\"quant_conv.bias\")\n    
new_state_dict[\"quantize.embedding.weight\"]                = old_state_dict.pop(\"quantize.embedding.weight\")\n    new_state_dict[\"post_quant_conv.weight\"]                   = old_state_dict.pop(\"post_quant_conv.weight\")\n    new_state_dict[\"post_quant_conv.bias\"]                     = old_state_dict.pop(\"post_quant_conv.bias\")\n    new_state_dict[\"decoder.conv_in.weight\"]                   = old_state_dict.pop(\"decoder.conv_in.weight\")\n    new_state_dict[\"decoder.conv_in.bias\"]                     = old_state_dict.pop(\"decoder.conv_in.bias\")\n    new_state_dict[\"decoder.mid_block.resnets.0.norm1.weight\"] = old_state_dict.pop(\"decoder.mid.block_1.norm1.weight\")\n    new_state_dict[\"decoder.mid_block.resnets.0.norm1.bias\"]   = old_state_dict.pop(\"decoder.mid.block_1.norm1.bias\")\n    new_state_dict[\"decoder.mid_block.resnets.0.conv1.weight\"] = old_state_dict.pop(\"decoder.mid.block_1.conv1.weight\")\n    new_state_dict[\"decoder.mid_block.resnets.0.conv1.bias\"]   = old_state_dict.pop(\"decoder.mid.block_1.conv1.bias\")\n    new_state_dict[\"decoder.mid_block.resnets.0.norm2.weight\"] = old_state_dict.pop(\"decoder.mid.block_1.norm2.weight\")\n    new_state_dict[\"decoder.mid_block.resnets.0.norm2.bias\"]   = old_state_dict.pop(\"decoder.mid.block_1.norm2.bias\")\n    new_state_dict[\"decoder.mid_block.resnets.0.conv2.weight\"] = old_state_dict.pop(\"decoder.mid.block_1.conv2.weight\")\n    new_state_dict[\"decoder.mid_block.resnets.0.conv2.bias\"]   = old_state_dict.pop(\"decoder.mid.block_1.conv2.bias\")\n    new_state_dict[\"decoder.mid_block.resnets.1.norm1.weight\"] = old_state_dict.pop(\"decoder.mid.block_2.norm1.weight\")\n    new_state_dict[\"decoder.mid_block.resnets.1.norm1.bias\"]   = old_state_dict.pop(\"decoder.mid.block_2.norm1.bias\")\n    new_state_dict[\"decoder.mid_block.resnets.1.conv1.weight\"] = old_state_dict.pop(\"decoder.mid.block_2.conv1.weight\")\n    new_state_dict[\"decoder.mid_block.resnets.1.conv1.bias\"]   
= old_state_dict.pop(\"decoder.mid.block_2.conv1.bias\")\n    new_state_dict[\"decoder.mid_block.resnets.1.norm2.weight\"] = old_state_dict.pop(\"decoder.mid.block_2.norm2.weight\")\n    new_state_dict[\"decoder.mid_block.resnets.1.norm2.bias\"]   = old_state_dict.pop(\"decoder.mid.block_2.norm2.bias\")\n    new_state_dict[\"decoder.mid_block.resnets.1.conv2.weight\"] = old_state_dict.pop(\"decoder.mid.block_2.conv2.weight\")\n    new_state_dict[\"decoder.mid_block.resnets.1.conv2.bias\"]   = old_state_dict.pop(\"decoder.mid.block_2.conv2.bias\")\n\n    convert_vae_block_state_dict(old_state_dict, \"decoder.up.0\", new_state_dict, \"decoder.up_blocks.4\")\n    convert_vae_block_state_dict(old_state_dict, \"decoder.up.1\", new_state_dict, \"decoder.up_blocks.3\")\n    convert_vae_block_state_dict(old_state_dict, \"decoder.up.2\", new_state_dict, \"decoder.up_blocks.2\")\n    convert_vae_block_state_dict(old_state_dict, \"decoder.up.3\", new_state_dict, \"decoder.up_blocks.1\")\n    convert_vae_block_state_dict(old_state_dict, \"decoder.up.4\", new_state_dict, \"decoder.up_blocks.0\")\n\n    new_state_dict[\"decoder.conv_norm_out.weight\"] = old_state_dict.pop(\"decoder.norm_out.weight\")\n    new_state_dict[\"decoder.conv_norm_out.bias\"]   = old_state_dict.pop(\"decoder.norm_out.bias\")\n    new_state_dict[\"decoder.conv_out.weight\"]      = old_state_dict.pop(\"decoder.conv_out.weight\")\n    new_state_dict[\"decoder.conv_out.bias\"]        = old_state_dict.pop(\"decoder.conv_out.bias\")\n\n    # fmt: on\n\n    assert len(old_state_dict.keys()) == 0\n\n    new_vae.load_state_dict(new_state_dict)\n\n    input = torch.randn((1, 3, 512, 512), device=device)\n    input = input.clamp(-1, 1)\n\n    old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))\n    new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))\n    assert (old_encoder_output == new_encoder_output).all()\n\n    old_decoder_output = 
old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))\n    new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))\n\n    # assert (old_decoder_output == new_decoder_output).all()\n    print(\"skipping vae decoder equivalence check\")\n    print(f\"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}\")\n\n    old_output = old_vae(input)[0]\n    new_output = new_vae(input)[0]\n\n    # assert (old_output == new_output).all()\n    print(\"skipping full vae equivalence check\")\n    print(f\"vae full diff {(old_output - new_output).float().abs().sum()}\")\n\n    return new_vae\n\n\ndef convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):\n    # fmt: off\n\n    new_state_dict[f\"{prefix_to}.resnets.0.norm1.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.0.norm1.weight\")\n    new_state_dict[f\"{prefix_to}.resnets.0.norm1.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.0.norm1.bias\")\n    new_state_dict[f\"{prefix_to}.resnets.0.conv1.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.0.conv1.weight\")\n    new_state_dict[f\"{prefix_to}.resnets.0.conv1.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.0.conv1.bias\")\n    new_state_dict[f\"{prefix_to}.resnets.0.norm2.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.0.norm2.weight\")\n    new_state_dict[f\"{prefix_to}.resnets.0.norm2.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.0.norm2.bias\")\n    new_state_dict[f\"{prefix_to}.resnets.0.conv2.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.0.conv2.weight\")\n    new_state_dict[f\"{prefix_to}.resnets.0.conv2.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.0.conv2.bias\")\n\n    if f\"{prefix_from}.block.0.nin_shortcut.weight\" in old_state_dict:\n        new_state_dict[f\"{prefix_to}.resnets.0.conv_shortcut.weight\"]     = 
old_state_dict.pop(f\"{prefix_from}.block.0.nin_shortcut.weight\")\n        new_state_dict[f\"{prefix_to}.resnets.0.conv_shortcut.bias\"]       = old_state_dict.pop(f\"{prefix_from}.block.0.nin_shortcut.bias\")\n\n    new_state_dict[f\"{prefix_to}.resnets.1.norm1.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.1.norm1.weight\")\n    new_state_dict[f\"{prefix_to}.resnets.1.norm1.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.1.norm1.bias\")\n    new_state_dict[f\"{prefix_to}.resnets.1.conv1.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.1.conv1.weight\")\n    new_state_dict[f\"{prefix_to}.resnets.1.conv1.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.1.conv1.bias\")\n    new_state_dict[f\"{prefix_to}.resnets.1.norm2.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.1.norm2.weight\")\n    new_state_dict[f\"{prefix_to}.resnets.1.norm2.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.1.norm2.bias\")\n    new_state_dict[f\"{prefix_to}.resnets.1.conv2.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.1.conv2.weight\")\n    new_state_dict[f\"{prefix_to}.resnets.1.conv2.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.1.conv2.bias\")\n\n    if f\"{prefix_from}.downsample.conv.weight\" in old_state_dict:\n        new_state_dict[f\"{prefix_to}.downsamplers.0.conv.weight\"]         = old_state_dict.pop(f\"{prefix_from}.downsample.conv.weight\")\n        new_state_dict[f\"{prefix_to}.downsamplers.0.conv.bias\"]           = old_state_dict.pop(f\"{prefix_from}.downsample.conv.bias\")\n\n    if f\"{prefix_from}.upsample.conv.weight\" in old_state_dict:\n        new_state_dict[f\"{prefix_to}.upsamplers.0.conv.weight\"]         = old_state_dict.pop(f\"{prefix_from}.upsample.conv.weight\")\n        new_state_dict[f\"{prefix_to}.upsamplers.0.conv.bias\"]           = old_state_dict.pop(f\"{prefix_from}.upsample.conv.bias\")\n\n    if 
f\"{prefix_from}.block.2.norm1.weight\" in old_state_dict:\n        new_state_dict[f\"{prefix_to}.resnets.2.norm1.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.2.norm1.weight\")\n        new_state_dict[f\"{prefix_to}.resnets.2.norm1.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.2.norm1.bias\")\n        new_state_dict[f\"{prefix_to}.resnets.2.conv1.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.2.conv1.weight\")\n        new_state_dict[f\"{prefix_to}.resnets.2.conv1.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.2.conv1.bias\")\n        new_state_dict[f\"{prefix_to}.resnets.2.norm2.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.2.norm2.weight\")\n        new_state_dict[f\"{prefix_to}.resnets.2.norm2.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.2.norm2.bias\")\n        new_state_dict[f\"{prefix_to}.resnets.2.conv2.weight\"]             = old_state_dict.pop(f\"{prefix_from}.block.2.conv2.weight\")\n        new_state_dict[f\"{prefix_to}.resnets.2.conv2.bias\"]               = old_state_dict.pop(f\"{prefix_from}.block.2.conv2.bias\")\n\n    # fmt: on\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/convert_animatediff_motion_lora_to_diffusers.py",
    "content": "import argparse\nimport os\n\nimport torch\nfrom huggingface_hub import create_repo, upload_folder\nfrom safetensors.torch import load_file, save_file\n\n\ndef convert_motion_module(original_state_dict):\n    converted_state_dict = {}\n    for k, v in original_state_dict.items():\n        if \"pos_encoder\" in k:\n            continue\n\n        else:\n            converted_state_dict[\n                k.replace(\".norms.0\", \".norm1\")\n                .replace(\".norms.1\", \".norm2\")\n                .replace(\".ff_norm\", \".norm3\")\n                .replace(\".attention_blocks.0\", \".attn1\")\n                .replace(\".attention_blocks.1\", \".attn2\")\n                .replace(\".temporal_transformer\", \"\")\n            ] = v\n\n    return converted_state_dict\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--ckpt_path\", type=str, required=True, help=\"Path to checkpoint\")\n    parser.add_argument(\"--output_path\", type=str, required=True, help=\"Path to output directory\")\n    parser.add_argument(\n        \"--push_to_hub\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether to push the converted model to the HF or not\",\n    )\n\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n\n    if args.ckpt_path.endswith(\".safetensors\"):\n        state_dict = load_file(args.ckpt_path)\n    else:\n        state_dict = torch.load(args.ckpt_path, map_location=\"cpu\")\n\n    if \"state_dict\" in state_dict.keys():\n        state_dict = state_dict[\"state_dict\"]\n\n    conv_state_dict = convert_motion_module(state_dict)\n\n    # convert to new format\n    output_dict = {}\n    for module_name, params in conv_state_dict.items():\n        if type(params) is not torch.Tensor:\n            continue\n        output_dict.update({f\"unet.{module_name}\": params})\n\n    os.makedirs(args.output_path, exist_ok=True)\n\n    filepath = 
os.path.join(args.output_path, \"diffusion_pytorch_model.safetensors\")\n    save_file(output_dict, filepath)\n\n    if args.push_to_hub:\n        repo_id = create_repo(args.output_path, exist_ok=True).repo_id\n        upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type=\"model\")\n"
  },
  {
    "path": "scripts/convert_animatediff_motion_module_to_diffusers.py",
    "content": "import argparse\n\nimport torch\nfrom safetensors.torch import load_file\n\nfrom diffusers import MotionAdapter\n\n\ndef convert_motion_module(original_state_dict):\n    converted_state_dict = {}\n    for k, v in original_state_dict.items():\n        if \"pos_encoder\" in k:\n            continue\n\n        else:\n            converted_state_dict[\n                k.replace(\".norms.0\", \".norm1\")\n                .replace(\".norms.1\", \".norm2\")\n                .replace(\".ff_norm\", \".norm3\")\n                .replace(\".attention_blocks.0\", \".attn1\")\n                .replace(\".attention_blocks.1\", \".attn2\")\n                .replace(\".temporal_transformer\", \"\")\n            ] = v\n\n    return converted_state_dict\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--ckpt_path\", type=str, required=True)\n    parser.add_argument(\"--output_path\", type=str, required=True)\n    parser.add_argument(\"--use_motion_mid_block\", action=\"store_true\")\n    parser.add_argument(\"--motion_max_seq_length\", type=int, default=32)\n    parser.add_argument(\"--block_out_channels\", nargs=\"+\", default=[320, 640, 1280, 1280], type=int)\n    parser.add_argument(\"--save_fp16\", action=\"store_true\")\n\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n\n    if args.ckpt_path.endswith(\".safetensors\"):\n        state_dict = load_file(args.ckpt_path)\n    else:\n        state_dict = torch.load(args.ckpt_path, map_location=\"cpu\")\n\n    if \"state_dict\" in state_dict.keys():\n        state_dict = state_dict[\"state_dict\"]\n\n    conv_state_dict = convert_motion_module(state_dict)\n    adapter = MotionAdapter(\n        block_out_channels=args.block_out_channels,\n        use_motion_mid_block=args.use_motion_mid_block,\n        motion_max_seq_length=args.motion_max_seq_length,\n    )\n    # skip loading position embeddings\n    
adapter.load_state_dict(conv_state_dict, strict=False)\n    adapter.save_pretrained(args.output_path)\n\n    if args.save_fp16:\n        adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant=\"fp16\")\n"
  },
  {
    "path": "scripts/convert_animatediff_sparsectrl_to_diffusers.py",
    "content": "import argparse\nfrom typing import Dict\n\nimport torch\nimport torch.nn as nn\n\nfrom diffusers import SparseControlNetModel\n\n\nKEYS_RENAME_MAPPING = {\n    \".attention_blocks.0\": \".attn1\",\n    \".attention_blocks.1\": \".attn2\",\n    \".attn1.pos_encoder\": \".pos_embed\",\n    \".ff_norm\": \".norm3\",\n    \".norms.0\": \".norm1\",\n    \".norms.1\": \".norm2\",\n    \".temporal_transformer\": \"\",\n}\n\n\ndef convert(original_state_dict: Dict[str, nn.Module]) -> dict[str, nn.Module]:\n    converted_state_dict = {}\n\n    for key in list(original_state_dict.keys()):\n        renamed_key = key\n        for old_name, new_name in KEYS_RENAME_MAPPING.items():\n            renamed_key = renamed_key.replace(old_name, new_name)\n        converted_state_dict[renamed_key] = original_state_dict.pop(key)\n\n    return converted_state_dict\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--ckpt_path\", type=str, required=True, help=\"Path to checkpoint\")\n    parser.add_argument(\"--output_path\", type=str, required=True, help=\"Path to output directory\")\n    parser.add_argument(\n        \"--max_motion_seq_length\",\n        type=int,\n        default=32,\n        help=\"Max motion sequence length supported by the motion adapter\",\n    )\n    parser.add_argument(\n        \"--conditioning_channels\", type=int, default=4, help=\"Number of channels in conditioning input to controlnet\"\n    )\n    parser.add_argument(\n        \"--use_simplified_condition_embedding\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to use simplified condition embedding. When `conditioning_channels==4` i.e. latent inputs, set this to `True`. When `conditioning_channels==3` i.e. 
image inputs, set this to `False`\",\n    )\n    parser.add_argument(\n        \"--save_fp16\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to save model in fp16 precision along with fp32\",\n    )\n    parser.add_argument(\n        \"--push_to_hub\", action=\"store_true\", default=False, help=\"Whether or not to push saved model to the HF hub\"\n    )\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n\n    state_dict = torch.load(args.ckpt_path, map_location=\"cpu\")\n    if \"state_dict\" in state_dict.keys():\n        state_dict: dict = state_dict[\"state_dict\"]\n\n    controlnet = SparseControlNetModel(\n        conditioning_channels=args.conditioning_channels,\n        motion_max_seq_length=args.max_motion_seq_length,\n        use_simplified_condition_embedding=args.use_simplified_condition_embedding,\n    )\n\n    state_dict = convert(state_dict)\n    controlnet.load_state_dict(state_dict, strict=True)\n\n    controlnet.save_pretrained(args.output_path, push_to_hub=args.push_to_hub)\n    if args.save_fp16:\n        controlnet = controlnet.to(dtype=torch.float16)\n        controlnet.save_pretrained(args.output_path, variant=\"fp16\", push_to_hub=args.push_to_hub)\n"
  },
  {
    "path": "scripts/convert_asymmetric_vqgan_to_diffusers.py",
    "content": "import argparse\nimport time\nfrom pathlib import Path\nfrom typing import Any, Dict, Literal\n\nimport torch\n\nfrom diffusers import AsymmetricAutoencoderKL\n\n\nASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {\n    \"in_channels\": 3,\n    \"out_channels\": 3,\n    \"down_block_types\": [\n        \"DownEncoderBlock2D\",\n        \"DownEncoderBlock2D\",\n        \"DownEncoderBlock2D\",\n        \"DownEncoderBlock2D\",\n    ],\n    \"down_block_out_channels\": [128, 256, 512, 512],\n    \"layers_per_down_block\": 2,\n    \"up_block_types\": [\n        \"UpDecoderBlock2D\",\n        \"UpDecoderBlock2D\",\n        \"UpDecoderBlock2D\",\n        \"UpDecoderBlock2D\",\n    ],\n    \"up_block_out_channels\": [192, 384, 768, 768],\n    \"layers_per_up_block\": 3,\n    \"act_fn\": \"silu\",\n    \"latent_channels\": 4,\n    \"norm_num_groups\": 32,\n    \"sample_size\": 256,\n    \"scaling_factor\": 0.18215,\n}\n\nASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {\n    \"in_channels\": 3,\n    \"out_channels\": 3,\n    \"down_block_types\": [\n        \"DownEncoderBlock2D\",\n        \"DownEncoderBlock2D\",\n        \"DownEncoderBlock2D\",\n        \"DownEncoderBlock2D\",\n    ],\n    \"down_block_out_channels\": [128, 256, 512, 512],\n    \"layers_per_down_block\": 2,\n    \"up_block_types\": [\n        \"UpDecoderBlock2D\",\n        \"UpDecoderBlock2D\",\n        \"UpDecoderBlock2D\",\n        \"UpDecoderBlock2D\",\n    ],\n    \"up_block_out_channels\": [256, 512, 1024, 1024],\n    \"layers_per_up_block\": 5,\n    \"act_fn\": \"silu\",\n    \"latent_channels\": 4,\n    \"norm_num_groups\": 32,\n    \"sample_size\": 256,\n    \"scaling_factor\": 0.18215,\n}\n\n\ndef convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> dict[str, Any]:\n    converted_state_dict = {}\n    for k, v in original_state_dict.items():\n        if k.startswith(\"encoder.\"):\n            converted_state_dict[\n                k.replace(\"encoder.down.\", 
\"encoder.down_blocks.\")\n                .replace(\"encoder.mid.\", \"encoder.mid_block.\")\n                .replace(\"encoder.norm_out.\", \"encoder.conv_norm_out.\")\n                .replace(\".downsample.\", \".downsamplers.0.\")\n                .replace(\".nin_shortcut.\", \".conv_shortcut.\")\n                .replace(\".block.\", \".resnets.\")\n                .replace(\".block_1.\", \".resnets.0.\")\n                .replace(\".block_2.\", \".resnets.1.\")\n                .replace(\".attn_1.k.\", \".attentions.0.to_k.\")\n                .replace(\".attn_1.q.\", \".attentions.0.to_q.\")\n                .replace(\".attn_1.v.\", \".attentions.0.to_v.\")\n                .replace(\".attn_1.proj_out.\", \".attentions.0.to_out.0.\")\n                .replace(\".attn_1.norm.\", \".attentions.0.group_norm.\")\n            ] = v\n        elif k.startswith(\"decoder.\") and \"up_layers\" not in k:\n            converted_state_dict[\n                k.replace(\"decoder.encoder.\", \"decoder.condition_encoder.\")\n                .replace(\".norm_out.\", \".conv_norm_out.\")\n                .replace(\".up.0.\", \".up_blocks.3.\")\n                .replace(\".up.1.\", \".up_blocks.2.\")\n                .replace(\".up.2.\", \".up_blocks.1.\")\n                .replace(\".up.3.\", \".up_blocks.0.\")\n                .replace(\".block.\", \".resnets.\")\n                .replace(\"mid\", \"mid_block\")\n                .replace(\".0.upsample.\", \".0.upsamplers.0.\")\n                .replace(\".1.upsample.\", \".1.upsamplers.0.\")\n                .replace(\".2.upsample.\", \".2.upsamplers.0.\")\n                .replace(\".nin_shortcut.\", \".conv_shortcut.\")\n                .replace(\".block_1.\", \".resnets.0.\")\n                .replace(\".block_2.\", \".resnets.1.\")\n                .replace(\".attn_1.k.\", \".attentions.0.to_k.\")\n                .replace(\".attn_1.q.\", \".attentions.0.to_q.\")\n                .replace(\".attn_1.v.\", 
\".attentions.0.to_v.\")\n                .replace(\".attn_1.proj_out.\", \".attentions.0.to_out.0.\")\n                .replace(\".attn_1.norm.\", \".attentions.0.group_norm.\")\n            ] = v\n        elif k.startswith(\"quant_conv.\"):\n            converted_state_dict[k] = v\n        elif k.startswith(\"post_quant_conv.\"):\n            converted_state_dict[k] = v\n        else:\n            print(f\"  skipping key `{k}`\")\n    # fix weights shape\n    for k, v in converted_state_dict.items():\n        if (\n            (k.startswith(\"encoder.mid_block.attentions.0\") or k.startswith(\"decoder.mid_block.attentions.0\"))\n            and k.endswith(\"weight\")\n            and (\"to_q\" in k or \"to_k\" in k or \"to_v\" in k or \"to_out\" in k)\n        ):\n            converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]\n\n    return converted_state_dict\n\n\ndef get_asymmetric_autoencoder_kl_from_original_checkpoint(\n    scale: Literal[\"1.5\", \"2\"], original_checkpoint_path: str, map_location: torch.device\n) -> AsymmetricAutoencoderKL:\n    print(\"Loading original state_dict\")\n    original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)\n    original_state_dict = original_state_dict[\"state_dict\"]\n    print(\"Converting state_dict\")\n    converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)\n    kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == \"1.5\" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG\n    print(\"Initializing AsymmetricAutoencoderKL model\")\n    asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)\n    print(\"Loading weight from converted state_dict\")\n    asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)\n    asymmetric_autoencoder_kl.eval()\n    print(\"AsymmetricAutoencoderKL successfully initialized\")\n    return asymmetric_autoencoder_kl\n\n\nif __name__ == \"__main__\":\n    start = time.time()\n    parser = 
argparse.ArgumentParser()\n    parser.add_argument(\n        \"--scale\",\n        default=None,\n        type=str,\n        required=True,\n        help=\"Asymmetric VQGAN scale: `1.5` or `2`\",\n    )\n    parser.add_argument(\n        \"--original_checkpoint_path\",\n        default=None,\n        type=str,\n        required=True,\n        help=\"Path to the original Asymmetric VQGAN checkpoint\",\n    )\n    parser.add_argument(\n        \"--output_path\",\n        default=None,\n        type=str,\n        required=True,\n        help=\"Path to save pretrained AsymmetricAutoencoderKL model\",\n    )\n    parser.add_argument(\n        \"--map_location\",\n        default=\"cpu\",\n        type=str,\n        required=False,\n        help=\"The device passed to `map_location` when loading the checkpoint\",\n    )\n    args = parser.parse_args()\n\n    assert args.scale in [\"1.5\", \"2\"], f\"{args.scale} should be `1.5` or `2`\"\n    assert Path(args.original_checkpoint_path).is_file()\n\n    asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(\n        scale=args.scale,\n        original_checkpoint_path=args.original_checkpoint_path,\n        map_location=torch.device(args.map_location),\n    )\n    print(\"Saving pretrained AsymmetricAutoencoderKL\")\n    asymmetric_autoencoder_kl.save_pretrained(args.output_path)\n    print(f\"Done in {time.time() - start:.2f} seconds\")\n"
  },
  {
    "path": "scripts/convert_aura_flow_to_diffusers.py",
    "content": "import argparse\n\nimport torch\nfrom huggingface_hub import hf_hub_download\n\nfrom diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel\n\n\ndef load_original_state_dict(args):\n    model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=\"aura_diffusion_pytorch_model.bin\")\n    state_dict = torch.load(model_pt, map_location=\"cpu\")\n    return state_dict\n\n\ndef calculate_layers(state_dict_keys, key_prefix):\n    dit_layers = set()\n    for k in state_dict_keys:\n        if key_prefix in k:\n            dit_layers.add(int(k.split(\".\")[2]))\n    print(f\"{key_prefix}: {len(dit_layers)}\")\n    return len(dit_layers)\n\n\n# similar to SD3 but only for the last norm layer\ndef swap_scale_shift(weight, dim):\n    shift, scale = weight.chunk(2, dim=0)\n    new_weight = torch.cat([scale, shift], dim=0)\n    return new_weight\n\n\ndef convert_transformer(state_dict):\n    converted_state_dict = {}\n    state_dict_keys = list(state_dict.keys())\n\n    converted_state_dict[\"register_tokens\"] = state_dict.pop(\"model.register_tokens\")\n    converted_state_dict[\"pos_embed.pos_embed\"] = state_dict.pop(\"model.positional_encoding\")\n    converted_state_dict[\"pos_embed.proj.weight\"] = state_dict.pop(\"model.init_x_linear.weight\")\n    converted_state_dict[\"pos_embed.proj.bias\"] = state_dict.pop(\"model.init_x_linear.bias\")\n\n    converted_state_dict[\"time_step_proj.linear_1.weight\"] = state_dict.pop(\"model.t_embedder.mlp.0.weight\")\n    converted_state_dict[\"time_step_proj.linear_1.bias\"] = state_dict.pop(\"model.t_embedder.mlp.0.bias\")\n    converted_state_dict[\"time_step_proj.linear_2.weight\"] = state_dict.pop(\"model.t_embedder.mlp.2.weight\")\n    converted_state_dict[\"time_step_proj.linear_2.bias\"] = state_dict.pop(\"model.t_embedder.mlp.2.bias\")\n\n    converted_state_dict[\"context_embedder.weight\"] = state_dict.pop(\"model.cond_seq_linear.weight\")\n\n    
mmdit_layers = calculate_layers(state_dict_keys, key_prefix=\"double_layers\")\n    single_dit_layers = calculate_layers(state_dict_keys, key_prefix=\"single_layers\")\n\n    # MMDiT blocks 🎸.\n    for i in range(mmdit_layers):\n        # feed-forward\n        path_mapping = {\"mlpX\": \"ff\", \"mlpC\": \"ff_context\"}\n        weight_mapping = {\"c_fc1\": \"linear_1\", \"c_fc2\": \"linear_2\", \"c_proj\": \"out_projection\"}\n        for orig_k, diffuser_k in path_mapping.items():\n            for k, v in weight_mapping.items():\n                converted_state_dict[f\"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight\"] = state_dict.pop(\n                    f\"model.double_layers.{i}.{orig_k}.{k}.weight\"\n                )\n\n        # norms\n        path_mapping = {\"modX\": \"norm1\", \"modC\": \"norm1_context\"}\n        for orig_k, diffuser_k in path_mapping.items():\n            converted_state_dict[f\"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight\"] = state_dict.pop(\n                f\"model.double_layers.{i}.{orig_k}.1.weight\"\n            )\n\n        # attns\n        x_attn_mapping = {\"w2q\": \"to_q\", \"w2k\": \"to_k\", \"w2v\": \"to_v\", \"w2o\": \"to_out.0\"}\n        context_attn_mapping = {\"w1q\": \"add_q_proj\", \"w1k\": \"add_k_proj\", \"w1v\": \"add_v_proj\", \"w1o\": \"to_add_out\"}\n        for attn_mapping in [x_attn_mapping, context_attn_mapping]:\n            for k, v in attn_mapping.items():\n                converted_state_dict[f\"joint_transformer_blocks.{i}.attn.{v}.weight\"] = state_dict.pop(\n                    f\"model.double_layers.{i}.attn.{k}.weight\"\n                )\n\n    # Single-DiT blocks.\n    for i in range(single_dit_layers):\n        # feed-forward\n        mapping = {\"c_fc1\": \"linear_1\", \"c_fc2\": \"linear_2\", \"c_proj\": \"out_projection\"}\n        for k, v in mapping.items():\n            converted_state_dict[f\"single_transformer_blocks.{i}.ff.{v}.weight\"] = state_dict.pop(\n              
  f\"model.single_layers.{i}.mlp.{k}.weight\"\n            )\n\n        # norms\n        converted_state_dict[f\"single_transformer_blocks.{i}.norm1.linear.weight\"] = state_dict.pop(\n            f\"model.single_layers.{i}.modCX.1.weight\"\n        )\n\n        # attns\n        x_attn_mapping = {\"w1q\": \"to_q\", \"w1k\": \"to_k\", \"w1v\": \"to_v\", \"w1o\": \"to_out.0\"}\n        for k, v in x_attn_mapping.items():\n            converted_state_dict[f\"single_transformer_blocks.{i}.attn.{v}.weight\"] = state_dict.pop(\n                f\"model.single_layers.{i}.attn.{k}.weight\"\n            )\n\n    # Final blocks.\n    converted_state_dict[\"proj_out.weight\"] = state_dict.pop(\"model.final_linear.weight\")\n    converted_state_dict[\"norm_out.linear.weight\"] = swap_scale_shift(state_dict.pop(\"model.modF.1.weight\"), dim=None)\n\n    return converted_state_dict\n\n\n@torch.no_grad()\ndef populate_state_dict(args):\n    original_state_dict = load_original_state_dict(args)\n    state_dict_keys = list(original_state_dict.keys())\n    mmdit_layers = calculate_layers(state_dict_keys, key_prefix=\"double_layers\")\n    single_dit_layers = calculate_layers(state_dict_keys, key_prefix=\"single_layers\")\n\n    converted_state_dict = convert_transformer(original_state_dict)\n    model_diffusers = AuraFlowTransformer2DModel(\n        num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers\n    )\n    model_diffusers.load_state_dict(converted_state_dict, strict=True)\n\n    return model_diffusers\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--original_state_dict_repo_id\", default=\"AuraDiffusion/auradiffusion-v0.1a0\", type=str)\n    parser.add_argument(\"--dump_path\", default=\"aura-flow\", type=str)\n    parser.add_argument(\"--hub_id\", default=None, type=str)\n    args = parser.parse_args()\n\n    model_diffusers = populate_state_dict(args)\n    
model_diffusers.save_pretrained(args.dump_path)\n    if args.hub_id is not None:\n        model_diffusers.push_to_hub(args.hub_id)\n"
  },
  {
    "path": "scripts/convert_blipdiffusion_to_diffusers.py",
    "content": "\"\"\"\nThis script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.\n\"\"\"\n\nimport argparse\nimport os\nimport tempfile\n\nimport torch\nfrom lavis.models import load_model_and_preprocess\nfrom transformers import CLIPTokenizer\nfrom transformers.models.blip_2.configuration_blip_2 import Blip2Config\n\nfrom diffusers import (\n    AutoencoderKL,\n    PNDMScheduler,\n    UNet2DConditionModel,\n)\nfrom diffusers.pipelines import BlipDiffusionPipeline\nfrom diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor\nfrom diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel\nfrom diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel\n\n\nBLIP2_CONFIG = {\n    \"vision_config\": {\n        \"hidden_size\": 1024,\n        \"num_hidden_layers\": 23,\n        \"num_attention_heads\": 16,\n        \"image_size\": 224,\n        \"patch_size\": 14,\n        \"intermediate_size\": 4096,\n        \"hidden_act\": \"quick_gelu\",\n    },\n    \"qformer_config\": {\n        \"cross_attention_frequency\": 1,\n        \"encoder_hidden_size\": 1024,\n        \"vocab_size\": 30523,\n    },\n    \"num_query_tokens\": 16,\n}\nblip2config = Blip2Config(**BLIP2_CONFIG)\n\n\ndef qformer_model_from_original_config():\n    qformer = Blip2QFormerModel(blip2config)\n    return qformer\n\n\ndef embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):\n    embeddings = {}\n    embeddings.update(\n        {\n            f\"{diffuser_embeddings_prefix}.word_embeddings.weight\": model[\n                f\"{original_embeddings_prefix}.word_embeddings.weight\"\n            ]\n        }\n    )\n    embeddings.update(\n        {\n            f\"{diffuser_embeddings_prefix}.position_embeddings.weight\": model[\n                
f\"{original_embeddings_prefix}.position_embeddings.weight\"\n            ]\n        }\n    )\n    embeddings.update(\n        {f\"{diffuser_embeddings_prefix}.LayerNorm.weight\": model[f\"{original_embeddings_prefix}.LayerNorm.weight\"]}\n    )\n    embeddings.update(\n        {f\"{diffuser_embeddings_prefix}.LayerNorm.bias\": model[f\"{original_embeddings_prefix}.LayerNorm.bias\"]}\n    )\n    return embeddings\n\n\ndef proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):\n    proj_layer = {}\n    proj_layer.update({f\"{diffuser_proj_prefix}.dense1.weight\": model[f\"{original_proj_prefix}.dense1.weight\"]})\n    proj_layer.update({f\"{diffuser_proj_prefix}.dense1.bias\": model[f\"{original_proj_prefix}.dense1.bias\"]})\n    proj_layer.update({f\"{diffuser_proj_prefix}.dense2.weight\": model[f\"{original_proj_prefix}.dense2.weight\"]})\n    proj_layer.update({f\"{diffuser_proj_prefix}.dense2.bias\": model[f\"{original_proj_prefix}.dense2.bias\"]})\n    proj_layer.update({f\"{diffuser_proj_prefix}.LayerNorm.weight\": model[f\"{original_proj_prefix}.LayerNorm.weight\"]})\n    proj_layer.update({f\"{diffuser_proj_prefix}.LayerNorm.bias\": model[f\"{original_proj_prefix}.LayerNorm.bias\"]})\n    return proj_layer\n\n\ndef attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):\n    attention = {}\n    attention.update(\n        {\n            f\"{diffuser_attention_prefix}.attention.query.weight\": model[\n                f\"{original_attention_prefix}.self.query.weight\"\n            ]\n        }\n    )\n    attention.update(\n        {f\"{diffuser_attention_prefix}.attention.query.bias\": model[f\"{original_attention_prefix}.self.query.bias\"]}\n    )\n    attention.update(\n        {f\"{diffuser_attention_prefix}.attention.key.weight\": model[f\"{original_attention_prefix}.self.key.weight\"]}\n    )\n    attention.update(\n        {f\"{diffuser_attention_prefix}.attention.key.bias\": 
model[f\"{original_attention_prefix}.self.key.bias\"]}\n    )\n    attention.update(\n        {\n            f\"{diffuser_attention_prefix}.attention.value.weight\": model[\n                f\"{original_attention_prefix}.self.value.weight\"\n            ]\n        }\n    )\n    attention.update(\n        {f\"{diffuser_attention_prefix}.attention.value.bias\": model[f\"{original_attention_prefix}.self.value.bias\"]}\n    )\n    attention.update(\n        {f\"{diffuser_attention_prefix}.output.dense.weight\": model[f\"{original_attention_prefix}.output.dense.weight\"]}\n    )\n    attention.update(\n        {f\"{diffuser_attention_prefix}.output.dense.bias\": model[f\"{original_attention_prefix}.output.dense.bias\"]}\n    )\n    attention.update(\n        {\n            f\"{diffuser_attention_prefix}.output.LayerNorm.weight\": model[\n                f\"{original_attention_prefix}.output.LayerNorm.weight\"\n            ]\n        }\n    )\n    attention.update(\n        {\n            f\"{diffuser_attention_prefix}.output.LayerNorm.bias\": model[\n                f\"{original_attention_prefix}.output.LayerNorm.bias\"\n            ]\n        }\n    )\n    return attention\n\n\ndef output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):\n    output_layers = {}\n    output_layers.update({f\"{diffuser_output_prefix}.dense.weight\": model[f\"{original_output_prefix}.dense.weight\"]})\n    output_layers.update({f\"{diffuser_output_prefix}.dense.bias\": model[f\"{original_output_prefix}.dense.bias\"]})\n    output_layers.update(\n        {f\"{diffuser_output_prefix}.LayerNorm.weight\": model[f\"{original_output_prefix}.LayerNorm.weight\"]}\n    )\n    output_layers.update(\n        {f\"{diffuser_output_prefix}.LayerNorm.bias\": model[f\"{original_output_prefix}.LayerNorm.bias\"]}\n    )\n    return output_layers\n\n\ndef encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):\n    encoder = {}\n    
for i in range(blip2config.qformer_config.num_hidden_layers):\n        encoder.update(\n            attention_from_original_checkpoint(\n                model, f\"{diffuser_encoder_prefix}.{i}.attention\", f\"{original_encoder_prefix}.{i}.attention\"\n            )\n        )\n        encoder.update(\n            attention_from_original_checkpoint(\n                model, f\"{diffuser_encoder_prefix}.{i}.crossattention\", f\"{original_encoder_prefix}.{i}.crossattention\"\n            )\n        )\n\n        encoder.update(\n            {\n                f\"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight\": model[\n                    f\"{original_encoder_prefix}.{i}.intermediate.dense.weight\"\n                ]\n            }\n        )\n        encoder.update(\n            {\n                f\"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias\": model[\n                    f\"{original_encoder_prefix}.{i}.intermediate.dense.bias\"\n                ]\n            }\n        )\n        encoder.update(\n            {\n                f\"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight\": model[\n                    f\"{original_encoder_prefix}.{i}.intermediate_query.dense.weight\"\n                ]\n            }\n        )\n        encoder.update(\n            {\n                f\"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias\": model[\n                    f\"{original_encoder_prefix}.{i}.intermediate_query.dense.bias\"\n                ]\n            }\n        )\n\n        encoder.update(\n            output_layers_from_original_checkpoint(\n                model, f\"{diffuser_encoder_prefix}.{i}.output\", f\"{original_encoder_prefix}.{i}.output\"\n            )\n        )\n        encoder.update(\n            output_layers_from_original_checkpoint(\n                model, f\"{diffuser_encoder_prefix}.{i}.output_query\", f\"{original_encoder_prefix}.{i}.output_query\"\n            )\n        )\n    return 
encoder\n\n\ndef visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):\n    visual_encoder_layer = {}\n\n    visual_encoder_layer.update({f\"{diffuser_prefix}.layer_norm1.weight\": model[f\"{original_prefix}.ln_1.weight\"]})\n    visual_encoder_layer.update({f\"{diffuser_prefix}.layer_norm1.bias\": model[f\"{original_prefix}.ln_1.bias\"]})\n    visual_encoder_layer.update({f\"{diffuser_prefix}.layer_norm2.weight\": model[f\"{original_prefix}.ln_2.weight\"]})\n    visual_encoder_layer.update({f\"{diffuser_prefix}.layer_norm2.bias\": model[f\"{original_prefix}.ln_2.bias\"]})\n    visual_encoder_layer.update(\n        {f\"{diffuser_prefix}.self_attn.qkv.weight\": model[f\"{original_prefix}.attn.in_proj_weight\"]}\n    )\n    visual_encoder_layer.update(\n        {f\"{diffuser_prefix}.self_attn.qkv.bias\": model[f\"{original_prefix}.attn.in_proj_bias\"]}\n    )\n    visual_encoder_layer.update(\n        {f\"{diffuser_prefix}.self_attn.projection.weight\": model[f\"{original_prefix}.attn.out_proj.weight\"]}\n    )\n    visual_encoder_layer.update(\n        {f\"{diffuser_prefix}.self_attn.projection.bias\": model[f\"{original_prefix}.attn.out_proj.bias\"]}\n    )\n    visual_encoder_layer.update({f\"{diffuser_prefix}.mlp.fc1.weight\": model[f\"{original_prefix}.mlp.c_fc.weight\"]})\n    visual_encoder_layer.update({f\"{diffuser_prefix}.mlp.fc1.bias\": model[f\"{original_prefix}.mlp.c_fc.bias\"]})\n    visual_encoder_layer.update({f\"{diffuser_prefix}.mlp.fc2.weight\": model[f\"{original_prefix}.mlp.c_proj.weight\"]})\n    visual_encoder_layer.update({f\"{diffuser_prefix}.mlp.fc2.bias\": model[f\"{original_prefix}.mlp.c_proj.bias\"]})\n\n    return visual_encoder_layer\n\n\ndef visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):\n    visual_encoder = {}\n\n    visual_encoder.update(\n        {\n            f\"{diffuser_prefix}.embeddings.class_embedding\": model[f\"{original_prefix}.class_embedding\"]\n       
     .unsqueeze(0)\n            .unsqueeze(0)\n        }\n    )\n    visual_encoder.update(\n        {\n            f\"{diffuser_prefix}.embeddings.position_embedding\": model[\n                f\"{original_prefix}.positional_embedding\"\n            ].unsqueeze(0)\n        }\n    )\n    visual_encoder.update(\n        {f\"{diffuser_prefix}.embeddings.patch_embedding.weight\": model[f\"{original_prefix}.conv1.weight\"]}\n    )\n    visual_encoder.update({f\"{diffuser_prefix}.pre_layernorm.weight\": model[f\"{original_prefix}.ln_pre.weight\"]})\n    visual_encoder.update({f\"{diffuser_prefix}.pre_layernorm.bias\": model[f\"{original_prefix}.ln_pre.bias\"]})\n\n    for i in range(blip2config.vision_config.num_hidden_layers):\n        visual_encoder.update(\n            visual_encoder_layer_from_original_checkpoint(\n                model, f\"{diffuser_prefix}.encoder.layers.{i}\", f\"{original_prefix}.transformer.resblocks.{i}\"\n            )\n        )\n\n    visual_encoder.update({f\"{diffuser_prefix}.post_layernorm.weight\": model[\"blip.ln_vision.weight\"]})\n    visual_encoder.update({f\"{diffuser_prefix}.post_layernorm.bias\": model[\"blip.ln_vision.bias\"]})\n\n    return visual_encoder\n\n\ndef qformer_original_checkpoint_to_diffusers_checkpoint(model):\n    qformer_checkpoint = {}\n    qformer_checkpoint.update(embeddings_from_original_checkpoint(model, \"embeddings\", \"blip.Qformer.bert.embeddings\"))\n    qformer_checkpoint.update({\"query_tokens\": model[\"blip.query_tokens\"]})\n    qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, \"proj_layer\", \"proj_layer\"))\n    qformer_checkpoint.update(\n        encoder_from_original_checkpoint(model, \"encoder.layer\", \"blip.Qformer.bert.encoder.layer\")\n    )\n    qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, \"visual_encoder\", \"blip.visual_encoder\"))\n    return qformer_checkpoint\n\n\ndef get_qformer(model):\n    print(\"loading qformer\")\n\n    qformer = 
qformer_model_from_original_config()\n    qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)\n\n    load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)\n\n    print(\"done loading qformer\")\n    return qformer\n\n\ndef load_checkpoint_to_model(checkpoint, model):\n    with tempfile.NamedTemporaryFile(delete=False) as file:\n        torch.save(checkpoint, file.name)\n        del checkpoint\n        model.load_state_dict(torch.load(file.name), strict=False)\n\n    os.remove(file.name)\n\n\ndef save_blip_diffusion_model(model, args):\n    qformer = get_qformer(model)\n    qformer.eval()\n\n    text_encoder = ContextCLIPTextModel.from_pretrained(\n        \"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"text_encoder\"\n    )\n    vae = AutoencoderKL.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"vae\")\n    unet = UNet2DConditionModel.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"unet\")\n    vae.eval()\n    text_encoder.eval()\n    scheduler = PNDMScheduler(\n        beta_start=0.00085,\n        beta_end=0.012,\n        beta_schedule=\"scaled_linear\",\n        set_alpha_to_one=False,\n        skip_prk_steps=True,\n    )\n    tokenizer = CLIPTokenizer.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", subfolder=\"tokenizer\")\n    image_processor = BlipImageProcessor()\n    blip_diffusion = BlipDiffusionPipeline(\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        vae=vae,\n        unet=unet,\n        scheduler=scheduler,\n        qformer=qformer,\n        image_processor=image_processor,\n    )\n    blip_diffusion.save_pretrained(args.checkpoint_path)\n\n\ndef main(args):\n    model, _, _ = load_model_and_preprocess(\"blip_diffusion\", \"base\", device=\"cpu\", is_eval=True)\n    save_blip_diffusion_model(model.state_dict(), args)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    
parser.add_argument(\"--checkpoint_path\", default=None, type=str, required=True, help=\"Path to the output model.\")\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "scripts/convert_cogvideox_to_diffusers.py",
    "content": "import argparse\nfrom typing import Any, Dict\n\nimport torch\nfrom transformers import T5EncoderModel, T5Tokenizer\n\nfrom diffusers import (\n    AutoencoderKLCogVideoX,\n    CogVideoXDDIMScheduler,\n    CogVideoXImageToVideoPipeline,\n    CogVideoXPipeline,\n    CogVideoXTransformer3DModel,\n)\n\n\ndef reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):\n    to_q_key = key.replace(\"query_key_value\", \"to_q\")\n    to_k_key = key.replace(\"query_key_value\", \"to_k\")\n    to_v_key = key.replace(\"query_key_value\", \"to_v\")\n    to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)\n    state_dict[to_q_key] = to_q\n    state_dict[to_k_key] = to_k\n    state_dict[to_v_key] = to_v\n    state_dict.pop(key)\n\n\ndef reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):\n    layer_id, weight_or_bias = key.split(\".\")[-2:]\n\n    if \"query\" in key:\n        new_key = f\"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}\"\n    elif \"key\" in key:\n        new_key = f\"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}\"\n\n    state_dict[new_key] = state_dict.pop(key)\n\n\ndef reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):\n    layer_id, _, weight_or_bias = key.split(\".\")[-3:]\n\n    weights_or_biases = state_dict[key].chunk(12, dim=0)\n    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])\n    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])\n\n    norm1_key = f\"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}\"\n    state_dict[norm1_key] = norm1_weights_or_biases\n\n    norm2_key = f\"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}\"\n    state_dict[norm2_key] = norm2_weights_or_biases\n\n    state_dict.pop(key)\n\n\ndef remove_keys_inplace(key: str, state_dict: Dict[str, Any]):\n    state_dict.pop(key)\n\n\ndef replace_up_keys_inplace(key: str, state_dict: 
Dict[str, Any]):\n    key_split = key.split(\".\")\n    layer_index = int(key_split[2])\n    replace_layer_index = 4 - 1 - layer_index\n\n    key_split[1] = \"up_blocks\"\n    key_split[2] = str(replace_layer_index)\n    new_key = \".\".join(key_split)\n\n    state_dict[new_key] = state_dict.pop(key)\n\n\nTRANSFORMER_KEYS_RENAME_DICT = {\n    \"transformer.final_layernorm\": \"norm_final\",\n    \"transformer\": \"transformer_blocks\",\n    \"attention\": \"attn1\",\n    \"mlp\": \"ff.net\",\n    \"dense_h_to_4h\": \"0.proj\",\n    \"dense_4h_to_h\": \"2\",\n    \".layers\": \"\",\n    \"dense\": \"to_out.0\",\n    \"input_layernorm\": \"norm1.norm\",\n    \"post_attn1_layernorm\": \"norm2.norm\",\n    \"time_embed.0\": \"time_embedding.linear_1\",\n    \"time_embed.2\": \"time_embedding.linear_2\",\n    \"ofs_embed.0\": \"ofs_embedding.linear_1\",\n    \"ofs_embed.2\": \"ofs_embedding.linear_2\",\n    \"mixins.patch_embed\": \"patch_embed\",\n    \"mixins.final_layer.norm_final\": \"norm_out.norm\",\n    \"mixins.final_layer.linear\": \"proj_out\",\n    \"mixins.final_layer.adaLN_modulation.1\": \"norm_out.linear\",\n    \"mixins.pos_embed.pos_embedding\": \"patch_embed.pos_embedding\",  # Specific to CogVideoX-5b-I2V\n}\n\nTRANSFORMER_SPECIAL_KEYS_REMAP = {\n    \"query_key_value\": reassign_query_key_value_inplace,\n    \"query_layernorm_list\": reassign_query_key_layernorm_inplace,\n    \"key_layernorm_list\": reassign_query_key_layernorm_inplace,\n    \"adaln_layer.adaLN_modulations\": reassign_adaln_norm_inplace,\n    \"embed_tokens\": remove_keys_inplace,\n    \"freqs_sin\": remove_keys_inplace,\n    \"freqs_cos\": remove_keys_inplace,\n    \"position_embedding\": remove_keys_inplace,\n}\n\nVAE_KEYS_RENAME_DICT = {\n    \"block.\": \"resnets.\",\n    \"down.\": \"down_blocks.\",\n    \"downsample\": \"downsamplers.0\",\n    \"upsample\": \"upsamplers.0\",\n    \"nin_shortcut\": \"conv_shortcut\",\n    \"encoder.mid.block_1\": 
\"encoder.mid_block.resnets.0\",\n    \"encoder.mid.block_2\": \"encoder.mid_block.resnets.1\",\n    \"decoder.mid.block_1\": \"decoder.mid_block.resnets.0\",\n    \"decoder.mid.block_2\": \"decoder.mid_block.resnets.1\",\n}\n\nVAE_SPECIAL_KEYS_REMAP = {\n    \"loss\": remove_keys_inplace,\n    \"up.\": replace_up_keys_inplace,\n}\n\nTOKENIZER_MAX_LENGTH = 226\n\n\ndef get_state_dict(saved_dict: Dict[str, Any]) -> dict[str, Any]:\n    state_dict = saved_dict\n    if \"model\" in saved_dict.keys():\n        state_dict = state_dict[\"model\"]\n    if \"module\" in saved_dict.keys():\n        state_dict = state_dict[\"module\"]\n    if \"state_dict\" in saved_dict.keys():\n        state_dict = state_dict[\"state_dict\"]\n    return state_dict\n\n\ndef update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> dict[str, Any]:\n    state_dict[new_key] = state_dict.pop(old_key)\n\n\ndef convert_transformer(\n    ckpt_path: str,\n    num_layers: int,\n    num_attention_heads: int,\n    use_rotary_positional_embeddings: bool,\n    i2v: bool,\n    dtype: torch.dtype,\n    init_kwargs: Dict[str, Any],\n):\n    PREFIX_KEY = \"model.diffusion_model.\"\n\n    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location=\"cpu\", mmap=True))\n    transformer = CogVideoXTransformer3DModel(\n        in_channels=32 if i2v else 16,\n        num_layers=num_layers,\n        num_attention_heads=num_attention_heads,\n        use_rotary_positional_embeddings=use_rotary_positional_embeddings,\n        ofs_embed_dim=512 if (i2v and init_kwargs[\"patch_size_t\"] is not None) else None,  # CogVideoX1.5-5B-I2V\n        use_learned_positional_embeddings=i2v and init_kwargs[\"patch_size_t\"] is None,  # CogVideoX-5B-I2V\n        **init_kwargs,\n    ).to(dtype=dtype)\n\n    for key in list(original_state_dict.keys()):\n        new_key = key[len(PREFIX_KEY) :]\n        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():\n            new_key = 
new_key.replace(replace_key, rename_key)\n        update_state_dict_inplace(original_state_dict, key, new_key)\n\n    for key in list(original_state_dict.keys()):\n        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():\n            if special_key not in key:\n                continue\n            handler_fn_inplace(key, original_state_dict)\n\n    transformer.load_state_dict(original_state_dict, strict=True)\n    return transformer\n\n\ndef convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):\n    init_kwargs = {\"scaling_factor\": scaling_factor}\n    if version == \"1.5\":\n        init_kwargs.update({\"invert_scale_latents\": True})\n\n    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location=\"cpu\", mmap=True))\n    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)\n\n    for key in list(original_state_dict.keys()):\n        new_key = key[:]\n        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():\n            new_key = new_key.replace(replace_key, rename_key)\n        update_state_dict_inplace(original_state_dict, key, new_key)\n\n    for key in list(original_state_dict.keys()):\n        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():\n            if special_key not in key:\n                continue\n            handler_fn_inplace(key, original_state_dict)\n\n    vae.load_state_dict(original_state_dict, strict=True)\n    return vae\n\n\ndef get_transformer_init_kwargs(version: str):\n    if version == \"1.0\":\n        vae_scale_factor_spatial = 8\n        init_kwargs = {\n            \"patch_size\": 2,\n            \"patch_size_t\": None,\n            \"patch_bias\": True,\n            \"sample_height\": 480 // vae_scale_factor_spatial,\n            \"sample_width\": 720 // vae_scale_factor_spatial,\n            \"sample_frames\": 49,\n        }\n\n    elif version == \"1.5\":\n        vae_scale_factor_spatial = 8\n        
init_kwargs = {\n            \"patch_size\": 2,\n            \"patch_size_t\": 2,\n            \"patch_bias\": False,\n            \"sample_height\": 300,\n            \"sample_width\": 300,\n            \"sample_frames\": 81,\n        }\n    else:\n        raise ValueError(\"Unsupported version of CogVideoX.\")\n\n    return init_kwargs\n\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--transformer_ckpt_path\", type=str, default=None, help=\"Path to original transformer checkpoint\"\n    )\n    parser.add_argument(\"--vae_ckpt_path\", type=str, default=None, help=\"Path to original vae checkpoint\")\n    parser.add_argument(\"--output_path\", type=str, required=True, help=\"Path where converted model should be saved\")\n    parser.add_argument(\"--fp16\", action=\"store_true\", default=False, help=\"Whether to save the model weights in fp16\")\n    parser.add_argument(\"--bf16\", action=\"store_true\", default=False, help=\"Whether to save the model weights in bf16\")\n    parser.add_argument(\n        \"--push_to_hub\", action=\"store_true\", default=False, help=\"Whether to push to HF Hub after saving\"\n    )\n    parser.add_argument(\n        \"--text_encoder_cache_dir\", type=str, default=None, help=\"Path to text encoder cache directory\"\n    )\n    parser.add_argument(\n        \"--typecast_text_encoder\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether or not to apply fp16/bf16 precision to text_encoder\",\n    )\n    # For CogVideoX-2B, num_layers is 30. For 5B, it is 42\n    parser.add_argument(\"--num_layers\", type=int, default=30, help=\"Number of transformer blocks\")\n    # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48\n    parser.add_argument(\"--num_attention_heads\", type=int, default=30, help=\"Number of attention heads\")\n    # For CogVideoX-2B, use_rotary_positional_embeddings is False. 
For 5B, it is True\n    parser.add_argument(\n        \"--use_rotary_positional_embeddings\", action=\"store_true\", default=False, help=\"Whether to use RoPE or not\"\n    )\n    # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7\n    parser.add_argument(\"--scaling_factor\", type=float, default=1.15258426, help=\"Scaling factor in the VAE\")\n    # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0\n    parser.add_argument(\"--snr_shift_scale\", type=float, default=3.0, help=\"SNR shift scale used by the scheduler\")\n    parser.add_argument(\n        \"--i2v\",\n        action=\"store_true\",\n        default=False,\n        help=\"Whether the model to be converted is the Image-to-Video version of CogVideoX.\",\n    )\n    parser.add_argument(\n        \"--version\",\n        choices=[\"1.0\", \"1.5\"],\n        default=\"1.0\",\n        help=\"Which version of CogVideoX to use for initializing default modeling parameters.\",\n    )\n    return parser.parse_args()\n\n\nif __name__ == \"__main__\":\n    args = get_args()\n\n    transformer = None\n    vae = None\n\n    if args.fp16 and args.bf16:\n        raise ValueError(\"You cannot pass both --fp16 and --bf16 at the same time.\")\n\n    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32\n\n    if args.transformer_ckpt_path is not None:\n        init_kwargs = get_transformer_init_kwargs(args.version)\n        transformer = convert_transformer(\n            args.transformer_ckpt_path,\n            args.num_layers,\n            args.num_attention_heads,\n            args.use_rotary_positional_embeddings,\n            args.i2v,\n            dtype,\n            init_kwargs,\n        )\n    if args.vae_ckpt_path is not None:\n        # Keep VAE in float32 for better quality\n        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)\n\n    text_encoder_id = \"google/t5-v1_1-xxl\"\n    tokenizer = 
T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)\n    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)\n\n    if args.typecast_text_encoder:\n        text_encoder = text_encoder.to(dtype=dtype)\n\n    # Apparently, the conversion does not work anymore without this :shrug:\n    for param in text_encoder.parameters():\n        param.data = param.data.contiguous()\n\n    scheduler = CogVideoXDDIMScheduler.from_config(\n        {\n            \"snr_shift_scale\": args.snr_shift_scale,\n            \"beta_end\": 0.012,\n            \"beta_schedule\": \"scaled_linear\",\n            \"beta_start\": 0.00085,\n            \"clip_sample\": False,\n            \"num_train_timesteps\": 1000,\n            \"prediction_type\": \"v_prediction\",\n            \"rescale_betas_zero_snr\": True,\n            \"set_alpha_to_one\": True,\n            \"timestep_spacing\": \"trailing\",\n        }\n    )\n    if args.i2v:\n        pipeline_cls = CogVideoXImageToVideoPipeline\n    else:\n        pipeline_cls = CogVideoXPipeline\n\n    pipe = pipeline_cls(\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        vae=vae,\n        transformer=transformer,\n        scheduler=scheduler,\n    )\n\n    # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird\n    # for users to specify variant when the default is not fp32 and they want to run with the correct default (which\n    # is either fp16/bf16 here).\n\n    # Sharding the checkpoint is necessary for users with insufficient memory,\n    # such as those using Colab and notebooks, as it can save some memory used for model loading.\n    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size=\"5GB\", push_to_hub=args.push_to_hub)\n"
  },
  {
    "path": "scripts/convert_cogview4_to_diffusers_megatron.py",
    "content": "\"\"\"\nConvert a CogView4 checkpoint from Megatron to the Diffusers format.\n\nExample usage:\n    python scripts/convert_cogview4_to_diffusers.py \\\n        --transformer_checkpoint_path 'your path/cogview4_6b/mp_rank_00/model_optim_rng.pt' \\\n        --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \\\n        --output_path \"THUDM/CogView4-6B\" \\\n        --dtype \"bf16\"\n\nArguments:\n    --transformer_checkpoint_path: Path to Transformer state dict.\n    --vae_checkpoint_path: Path to VAE state dict.\n    --output_path: The path to save the converted model.\n    --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.\n    --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used.\n    --dtype: The dtype to save the model in (default: \"bf16\", options: \"fp16\", \"bf16\", \"fp32\"). If None, the dtype of the state dict is considered.\n\n    Default is \"bf16\" because CogView4 uses bfloat16 for training.\n\nNote: You must provide either --transformer_checkpoint_path or --vae_checkpoint_path.\n\"\"\"\n\nimport argparse\n\nimport torch\nfrom tqdm import tqdm\nfrom transformers import GlmModel, PreTrainedTokenizerFast\n\nfrom diffusers import (\n    AutoencoderKL,\n    CogView4ControlPipeline,\n    CogView4Pipeline,\n    CogView4Transformer2DModel,\n    FlowMatchEulerDiscreteScheduler,\n)\nfrom diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    \"--transformer_checkpoint_path\",\n    default=None,\n    type=str,\n    help=\"Path to Megatron (not SAT) Transformer checkpoint, e.g., 'model_optim_rng.pt'.\",\n)\nparser.add_argument(\n    \"--vae_checkpoint_path\",\n    default=None,\n    type=str,\n    help=\"(Optional) Path to VAE checkpoint, e.g., 'imagekl_ch16.pt'.\",\n)\nparser.add_argument(\n    \"--output_path\",\n    required=True,\n   
 type=str,\n    help=\"Directory to save the final Diffusers format pipeline.\",\n)\nparser.add_argument(\n    \"--push_to_hub\",\n    action=\"store_true\",\n    default=False,\n    help=\"Whether to push the converted model to the HuggingFace Hub.\",\n)\nparser.add_argument(\n    \"--text_encoder_cache_dir\",\n    type=str,\n    default=None,\n    help=\"Specify the cache directory for the text encoder.\",\n)\nparser.add_argument(\n    \"--dtype\",\n    type=str,\n    default=\"bf16\",\n    choices=[\"fp16\", \"bf16\", \"fp32\"],\n    help=\"Data type to save the model in.\",\n)\n\nparser.add_argument(\n    \"--num_layers\",\n    type=int,\n    default=28,\n    help=\"Number of Transformer layers (e.g., 28, 48...).\",\n)\nparser.add_argument(\n    \"--num_heads\",\n    type=int,\n    default=32,\n    help=\"Number of attention heads.\",\n)\nparser.add_argument(\n    \"--hidden_size\",\n    type=int,\n    default=4096,\n    help=\"Transformer hidden dimension size.\",\n)\nparser.add_argument(\n    \"--attention_head_dim\",\n    type=int,\n    default=128,\n    help=\"Dimension of each attention head.\",\n)\nparser.add_argument(\n    \"--time_embed_dim\",\n    type=int,\n    default=512,\n    help=\"Dimension of time embeddings.\",\n)\nparser.add_argument(\n    \"--condition_dim\",\n    type=int,\n    default=256,\n    help=\"Dimension of condition embeddings.\",\n)\nparser.add_argument(\n    \"--pos_embed_max_size\",\n    type=int,\n    default=128,\n    help=\"Maximum size for positional embeddings.\",\n)\nparser.add_argument(\n    \"--control\",\n    action=\"store_true\",\n    default=False,\n    help=\"Whether to use control model.\",\n)\n\nargs = parser.parse_args()\n\n\ndef swap_scale_shift(weight, dim):\n    \"\"\"\n    Swap the scale and shift components in the weight tensor.\n\n    Args:\n        weight (torch.Tensor): The original weight tensor.\n        dim (int): The dimension along which to split.\n\n    Returns:\n        torch.Tensor: The modified 
weight tensor with scale and shift swapped.\n    \"\"\"\n    shift, scale = weight.chunk(2, dim=dim)\n    new_weight = torch.cat([scale, shift], dim=dim)\n    return new_weight\n\n\ndef convert_megatron_transformer_checkpoint_to_diffusers(\n    ckpt_path: str,\n    num_layers: int,\n    num_heads: int,\n    hidden_size: int,\n):\n    \"\"\"\n    Convert a Megatron Transformer checkpoint to Diffusers format.\n\n    Args:\n        ckpt_path (str): Path to the Megatron Transformer checkpoint.\n        num_layers (int): Number of Transformer layers.\n        num_heads (int): Number of attention heads.\n        hidden_size (int): Hidden size of the Transformer.\n\n    Returns:\n        dict: The converted state dictionary compatible with Diffusers.\n    \"\"\"\n    ckpt = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)\n    mega = ckpt[\"model\"]\n\n    new_state_dict = {}\n\n    # Patch Embedding\n    new_state_dict[\"patch_embed.proj.weight\"] = mega[\"encoder_expand_linear.weight\"].reshape(\n        hidden_size, 128 if args.control else 64\n    )\n    new_state_dict[\"patch_embed.proj.bias\"] = mega[\"encoder_expand_linear.bias\"]\n    new_state_dict[\"patch_embed.text_proj.weight\"] = mega[\"text_projector.weight\"]\n    new_state_dict[\"patch_embed.text_proj.bias\"] = mega[\"text_projector.bias\"]\n\n    # Time Condition Embedding\n    new_state_dict[\"time_condition_embed.timestep_embedder.linear_1.weight\"] = mega[\n        \"time_embedding.time_embed.0.weight\"\n    ]\n    new_state_dict[\"time_condition_embed.timestep_embedder.linear_1.bias\"] = mega[\"time_embedding.time_embed.0.bias\"]\n    new_state_dict[\"time_condition_embed.timestep_embedder.linear_2.weight\"] = mega[\n        \"time_embedding.time_embed.2.weight\"\n    ]\n    new_state_dict[\"time_condition_embed.timestep_embedder.linear_2.bias\"] = mega[\"time_embedding.time_embed.2.bias\"]\n\n    new_state_dict[\"time_condition_embed.condition_embedder.linear_1.weight\"] = mega[\n      
  \"label_embedding.label_embed.0.weight\"\n    ]\n    new_state_dict[\"time_condition_embed.condition_embedder.linear_1.bias\"] = mega[\n        \"label_embedding.label_embed.0.bias\"\n    ]\n    new_state_dict[\"time_condition_embed.condition_embedder.linear_2.weight\"] = mega[\n        \"label_embedding.label_embed.2.weight\"\n    ]\n    new_state_dict[\"time_condition_embed.condition_embedder.linear_2.bias\"] = mega[\n        \"label_embedding.label_embed.2.bias\"\n    ]\n\n    # Convert each Transformer layer\n    for i in tqdm(range(num_layers), desc=\"Converting layers (Megatron->Diffusers)\"):\n        block_prefix = f\"transformer_blocks.{i}.\"\n\n        # AdaLayerNorm\n        new_state_dict[block_prefix + \"norm1.linear.weight\"] = mega[f\"decoder.layers.{i}.adaln.weight\"]\n        new_state_dict[block_prefix + \"norm1.linear.bias\"] = mega[f\"decoder.layers.{i}.adaln.bias\"]\n        qkv_weight = mega[f\"decoder.layers.{i}.self_attention.linear_qkv.weight\"]\n        qkv_bias = mega[f\"decoder.layers.{i}.self_attention.linear_qkv.bias\"]\n\n        # Reshape to match SAT logic\n        qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)\n        qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)\n\n        qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)\n        qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)\n\n        # Assign to Diffusers keys\n        q, k, v = torch.chunk(qkv_weight, 3, dim=0)\n        qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)\n\n        new_state_dict[block_prefix + \"attn1.to_q.weight\"] = q\n        new_state_dict[block_prefix + \"attn1.to_q.bias\"] = qb\n        new_state_dict[block_prefix + \"attn1.to_k.weight\"] = k\n        new_state_dict[block_prefix + \"attn1.to_k.bias\"] = kb\n        new_state_dict[block_prefix + \"attn1.to_v.weight\"] = v\n        new_state_dict[block_prefix + \"attn1.to_v.bias\"] = vb\n\n        # Attention 
Output\n        new_state_dict[block_prefix + \"attn1.to_out.0.weight\"] = mega[\n            f\"decoder.layers.{i}.self_attention.linear_proj.weight\"\n        ]\n        new_state_dict[block_prefix + \"attn1.to_out.0.bias\"] = mega[\n            f\"decoder.layers.{i}.self_attention.linear_proj.bias\"\n        ]\n\n        # MLP\n        new_state_dict[block_prefix + \"ff.net.0.proj.weight\"] = mega[f\"decoder.layers.{i}.mlp.linear_fc1.weight\"]\n        new_state_dict[block_prefix + \"ff.net.0.proj.bias\"] = mega[f\"decoder.layers.{i}.mlp.linear_fc1.bias\"]\n        new_state_dict[block_prefix + \"ff.net.2.weight\"] = mega[f\"decoder.layers.{i}.mlp.linear_fc2.weight\"]\n        new_state_dict[block_prefix + \"ff.net.2.bias\"] = mega[f\"decoder.layers.{i}.mlp.linear_fc2.bias\"]\n\n    # Final Layers\n    new_state_dict[\"norm_out.linear.weight\"] = swap_scale_shift(mega[\"adaln_final.weight\"], dim=0)\n    new_state_dict[\"norm_out.linear.bias\"] = swap_scale_shift(mega[\"adaln_final.bias\"], dim=0)\n    new_state_dict[\"proj_out.weight\"] = mega[\"output_projector.weight\"]\n    new_state_dict[\"proj_out.bias\"] = mega[\"output_projector.bias\"]\n\n    return new_state_dict\n\n\ndef convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):\n    \"\"\"\n    Convert a CogView4 VAE checkpoint to Diffusers format.\n\n    Args:\n        ckpt_path (str): Path to the VAE checkpoint.\n        vae_config (dict): Configuration dictionary for the VAE.\n\n    Returns:\n        dict: The converted VAE state dictionary compatible with Diffusers.\n    \"\"\"\n    original_state_dict = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)[\"state_dict\"]\n    return convert_ldm_vae_checkpoint(original_state_dict, vae_config)\n\n\ndef main(args):\n    \"\"\"\n    Main function to convert CogView4 checkpoints to Diffusers format.\n\n    Args:\n        args (argparse.Namespace): Parsed command-line arguments.\n    \"\"\"\n    # Determine the desired data type\n 
   if args.dtype == \"fp16\":\n        dtype = torch.float16\n    elif args.dtype == \"bf16\":\n        dtype = torch.bfloat16\n    elif args.dtype == \"fp32\":\n        dtype = torch.float32\n    else:\n        raise ValueError(f\"Unsupported dtype: {args.dtype}\")\n\n    transformer = None\n    vae = None\n\n    # Convert Transformer checkpoint if provided\n    if args.transformer_checkpoint_path is not None:\n        converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(\n            ckpt_path=args.transformer_checkpoint_path,\n            num_layers=args.num_layers,\n            num_heads=args.num_heads,\n            hidden_size=args.hidden_size,\n        )\n        transformer = CogView4Transformer2DModel(\n            patch_size=2,\n            in_channels=32 if args.control else 16,\n            num_layers=args.num_layers,\n            attention_head_dim=args.attention_head_dim,\n            num_attention_heads=args.num_heads,\n            out_channels=16,\n            text_embed_dim=args.hidden_size,\n            time_embed_dim=args.time_embed_dim,\n            condition_dim=args.condition_dim,\n            pos_embed_max_size=args.pos_embed_max_size,\n        )\n\n        transformer.load_state_dict(converted_transformer_state_dict, strict=True)\n\n        # Convert to the specified dtype\n        if dtype is not None:\n            transformer = transformer.to(dtype=dtype)\n\n    # Convert VAE checkpoint if provided\n    if args.vae_checkpoint_path is not None:\n        vae_config = {\n            \"in_channels\": 3,\n            \"out_channels\": 3,\n            \"down_block_types\": (\"DownEncoderBlock2D\",) * 4,\n            \"up_block_types\": (\"UpDecoderBlock2D\",) * 4,\n            \"block_out_channels\": (128, 512, 1024, 1024),\n            \"layers_per_block\": 3,\n            \"act_fn\": \"silu\",\n            \"latent_channels\": 16,\n            \"norm_num_groups\": 32,\n            \"sample_size\": 1024,\n      
      \"scaling_factor\": 1.0,\n            \"shift_factor\": 0.0,\n            \"force_upcast\": True,\n            \"use_quant_conv\": False,\n            \"use_post_quant_conv\": False,\n            \"mid_block_add_attention\": False,\n        }\n        converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)\n        vae = AutoencoderKL(**vae_config)\n        vae.load_state_dict(converted_vae_state_dict, strict=True)\n        if dtype is not None:\n            vae = vae.to(dtype=dtype)\n\n    # Load the text encoder and tokenizer\n    text_encoder_id = \"THUDM/glm-4-9b-hf\"\n    tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)\n    text_encoder = GlmModel.from_pretrained(\n        text_encoder_id,\n        cache_dir=args.text_encoder_cache_dir,\n        torch_dtype=torch.bfloat16 if args.dtype == \"bf16\" else torch.float32,\n    )\n    for param in text_encoder.parameters():\n        param.data = param.data.contiguous()\n\n    # Initialize the scheduler\n    scheduler = FlowMatchEulerDiscreteScheduler(\n        base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type=\"linear\"\n    )\n\n    # Create the pipeline\n    if args.control:\n        pipe = CogView4ControlPipeline(\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            vae=vae,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n    else:\n        pipe = CogView4Pipeline(\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            vae=vae,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n\n    # Save the converted pipeline\n    pipe.save_pretrained(\n        args.output_path,\n        safe_serialization=True,\n        max_shard_size=\"5GB\",\n        push_to_hub=args.push_to_hub,\n    )\n\n\nif __name__ == \"__main__\":\n    main(args)\n"
  },
  {
    "path": "scripts/convert_dance_diffusion_to_diffusers.py",
    "content": "#!/usr/bin/env python3\nimport argparse\nimport math\nimport os\nfrom copy import deepcopy\n\nimport requests\nimport torch\nfrom audio_diffusion.models import DiffusionAttnUnet1D\nfrom diffusion import sampling\nfrom torch import nn\n\nfrom diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel\nfrom diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT\n\n\nMODELS_MAP = {\n    \"gwf-440k\": {\n        \"url\": \"https://model-server.zqevans2.workers.dev/gwf-440k.ckpt\",\n        \"sample_rate\": 48000,\n        \"sample_size\": 65536,\n    },\n    \"jmann-small-190k\": {\n        \"url\": \"https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt\",\n        \"sample_rate\": 48000,\n        \"sample_size\": 65536,\n    },\n    \"jmann-large-580k\": {\n        \"url\": \"https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt\",\n        \"sample_rate\": 48000,\n        \"sample_size\": 131072,\n    },\n    \"maestro-uncond-150k\": {\n        \"url\": \"https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt\",\n        \"sample_rate\": 16000,\n        \"sample_size\": 65536,\n    },\n    \"unlocked-uncond-250k\": {\n        \"url\": \"https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt\",\n        \"sample_rate\": 16000,\n        \"sample_size\": 65536,\n    },\n    \"honk-140k\": {\n        \"url\": \"https://model-server.zqevans2.workers.dev/honk-140k.ckpt\",\n        \"sample_rate\": 16000,\n        \"sample_size\": 65536,\n    },\n}\n\n\ndef alpha_sigma_to_t(alpha, sigma):\n    \"\"\"Returns a timestep, given the scaling factors for the clean image and for\n    the noise.\"\"\"\n    return torch.atan2(sigma, alpha) / math.pi * 2\n\n\ndef get_crash_schedule(t):\n    sigma = torch.sin(t * math.pi / 2) ** 2\n    alpha = (1 - sigma**2) ** 0.5\n    return alpha_sigma_to_t(alpha, sigma)\n\n\nclass Object(object):\n    pass\n\n\nclass DiffusionUncond(nn.Module):\n    def __init__(self, 
global_args):\n        super().__init__()\n\n        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)\n        self.diffusion_ema = deepcopy(self.diffusion)\n        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)\n\n\ndef download(model_name):\n    url = MODELS_MAP[model_name][\"url\"]\n    r = requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)\n\n    local_filename = f\"./{model_name}.ckpt\"\n    with open(local_filename, \"wb\") as fp:\n        for chunk in r.iter_content(chunk_size=8192):\n            fp.write(chunk)\n\n    return local_filename\n\n\nDOWN_NUM_TO_LAYER = {\n    \"1\": \"resnets.0\",\n    \"2\": \"attentions.0\",\n    \"3\": \"resnets.1\",\n    \"4\": \"attentions.1\",\n    \"5\": \"resnets.2\",\n    \"6\": \"attentions.2\",\n}\nUP_NUM_TO_LAYER = {\n    \"8\": \"resnets.0\",\n    \"9\": \"attentions.0\",\n    \"10\": \"resnets.1\",\n    \"11\": \"attentions.1\",\n    \"12\": \"resnets.2\",\n    \"13\": \"attentions.2\",\n}\nMID_NUM_TO_LAYER = {\n    \"1\": \"resnets.0\",\n    \"2\": \"attentions.0\",\n    \"3\": \"resnets.1\",\n    \"4\": \"attentions.1\",\n    \"5\": \"resnets.2\",\n    \"6\": \"attentions.2\",\n    \"8\": \"resnets.3\",\n    \"9\": \"attentions.3\",\n    \"10\": \"resnets.4\",\n    \"11\": \"attentions.4\",\n    \"12\": \"resnets.5\",\n    \"13\": \"attentions.5\",\n}\nDEPTH_0_TO_LAYER = {\n    \"0\": \"resnets.0\",\n    \"1\": \"resnets.1\",\n    \"2\": \"resnets.2\",\n    \"4\": \"resnets.0\",\n    \"5\": \"resnets.1\",\n    \"6\": \"resnets.2\",\n}\n\nRES_CONV_MAP = {\n    \"skip\": \"conv_skip\",\n    \"main.0\": \"conv_1\",\n    \"main.1\": \"group_norm_1\",\n    \"main.3\": \"conv_2\",\n    \"main.4\": \"group_norm_2\",\n}\n\nATTN_MAP = {\n    \"norm\": \"group_norm\",\n    \"qkv_proj\": [\"query\", \"key\", \"value\"],\n    \"out_proj\": [\"proj_attn\"],\n}\n\n\ndef convert_resconv_naming(name):\n    if name.startswith(\"skip\"):\n        return name.replace(\"skip\", 
RES_CONV_MAP[\"skip\"])\n\n    # name has to be of format main.{digit}\n    if not name.startswith(\"main.\"):\n        raise ValueError(f\"ResConvBlock error with {name}\")\n\n    return name.replace(name[:6], RES_CONV_MAP[name[:6]])\n\n\ndef convert_attn_naming(name):\n    for key, value in ATTN_MAP.items():\n        if name.startswith(key) and not isinstance(value, list):\n            return name.replace(key, value)\n        elif name.startswith(key):\n            return [name.replace(key, v) for v in value]\n    raise ValueError(f\"Attn error with {name}\")\n\n\ndef rename(input_string, max_depth=13):\n    string = input_string\n\n    if string.split(\".\")[0] == \"timestep_embed\":\n        return string.replace(\"timestep_embed\", \"time_proj\")\n\n    depth = 0\n    if string.startswith(\"net.3.\"):\n        depth += 1\n        string = string[6:]\n    elif string.startswith(\"net.\"):\n        string = string[4:]\n\n    while string.startswith(\"main.7.\"):\n        depth += 1\n        string = string[7:]\n\n    if string.startswith(\"main.\"):\n        string = string[5:]\n\n    # mid block\n    if string[:2].isdigit():\n        layer_num = string[:2]\n        string_left = string[2:]\n    else:\n        layer_num = string[0]\n        string_left = string[1:]\n\n    if depth == max_depth:\n        new_layer = MID_NUM_TO_LAYER[layer_num]\n        prefix = \"mid_block\"\n    elif depth > 0 and int(layer_num) < 7:\n        new_layer = DOWN_NUM_TO_LAYER[layer_num]\n        prefix = f\"down_blocks.{depth}\"\n    elif depth > 0 and int(layer_num) > 7:\n        new_layer = UP_NUM_TO_LAYER[layer_num]\n        prefix = f\"up_blocks.{max_depth - depth - 1}\"\n    elif depth == 0:\n        new_layer = DEPTH_0_TO_LAYER[layer_num]\n        prefix = f\"up_blocks.{max_depth - 1}\" if int(layer_num) > 3 else \"down_blocks.0\"\n\n    if not string_left.startswith(\".\"):\n        raise ValueError(f\"Naming error with {input_string} and string_left: {string_left}.\")\n\n    
string_left = string_left[1:]\n\n    if \"resnets\" in new_layer:\n        string_left = convert_resconv_naming(string_left)\n    elif \"attentions\" in new_layer:\n        new_string_left = convert_attn_naming(string_left)\n        string_left = new_string_left\n\n    if not isinstance(string_left, list):\n        new_string = prefix + \".\" + new_layer + \".\" + string_left\n    else:\n        new_string = [prefix + \".\" + new_layer + \".\" + s for s in string_left]\n    return new_string\n\n\ndef rename_orig_weights(state_dict):\n    new_state_dict = {}\n    for k, v in state_dict.items():\n        if k.endswith(\"kernel\"):\n            # up- and downsample layers, don't have trainable weights\n            continue\n\n        new_k = rename(k)\n\n        # check if we need to transform from Conv => Linear for attention\n        if isinstance(new_k, list):\n            new_state_dict = transform_conv_attns(new_state_dict, new_k, v)\n        else:\n            new_state_dict[new_k] = v\n\n    return new_state_dict\n\n\ndef transform_conv_attns(new_state_dict, new_k, v):\n    if len(new_k) == 1:\n        if len(v.shape) == 3:\n            # weight\n            new_state_dict[new_k[0]] = v[:, :, 0]\n        else:\n            # bias\n            new_state_dict[new_k[0]] = v\n    else:\n        # qkv matrices\n        trippled_shape = v.shape[0]\n        single_shape = trippled_shape // 3\n        for i in range(3):\n            if len(v.shape) == 3:\n                new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]\n            else:\n                new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]\n    return new_state_dict\n\n\ndef main(args):\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    model_name = args.model_path.split(\"/\")[-1].split(\".\")[0]\n    if not os.path.isfile(args.model_path):\n        assert model_name == args.model_path, (\n            f\"Make sure to 
provide one of the official model names {MODELS_MAP.keys()}\"\n        )\n        args.model_path = download(model_name)\n\n    sample_rate = MODELS_MAP[model_name][\"sample_rate\"]\n    sample_size = MODELS_MAP[model_name][\"sample_size\"]\n\n    config = Object()\n    config.sample_size = sample_size\n    config.sample_rate = sample_rate\n    config.latent_dim = 0\n\n    diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)\n    diffusers_state_dict = diffusers_model.state_dict()\n\n    orig_model = DiffusionUncond(config)\n    orig_model.load_state_dict(torch.load(args.model_path, map_location=device)[\"state_dict\"])\n    orig_model = orig_model.diffusion_ema.eval()\n    orig_model_state_dict = orig_model.state_dict()\n    renamed_state_dict = rename_orig_weights(orig_model_state_dict)\n\n    renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())\n    diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())\n\n    assert len(renamed_minus_diffusers) == 0, f\"Problem with {renamed_minus_diffusers}\"\n    assert all(k.endswith(\"kernel\") for k in list(diffusers_minus_renamed)), f\"Problem with {diffusers_minus_renamed}\"\n\n    for key, value in renamed_state_dict.items():\n        assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (\n            f\"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. 
{value.shape}\"\n        )\n        if key == \"time_proj.weight\":\n            value = value.squeeze()\n\n        diffusers_state_dict[key] = value\n\n    diffusers_model.load_state_dict(diffusers_state_dict)\n\n    steps = 100\n    seed = 33\n\n    diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)\n\n    generator = torch.manual_seed(seed)\n    noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)\n\n    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]\n    step_list = get_crash_schedule(t)\n\n    pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)\n\n    generator = torch.manual_seed(33)\n    audio = pipe(num_inference_steps=steps, generator=generator).audios\n\n    generated = sampling.iplms_sample(orig_model, noise, step_list, {})\n    generated = generated.clamp(-1, 1)\n\n    diff_sum = (generated - audio).abs().sum()\n    diff_max = (generated - audio).abs().max()\n\n    if args.save:\n        pipe.save_pretrained(args.checkpoint_path)\n\n    print(\"Diff sum\", diff_sum)\n    print(\"Diff max\", diff_max)\n\n    assert diff_max < 1e-3, f\"Diff max: {diff_max} is too much :-/\"\n\n    print(f\"Conversion for {model_name} successful!\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--model_path\", default=None, type=str, required=True, help=\"Path to the model to convert.\")\n    parser.add_argument(\n        \"--save\", default=True, type=bool, required=False, help=\"Whether to save the converted model or not.\"\n    )\n    parser.add_argument(\"--checkpoint_path\", default=None, type=str, required=True, help=\"Path to the output model.\")\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "scripts/convert_diffusers_to_original_stable_diffusion.py",
    "content": "# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.\n# *Only* converts the UNet, VAE, and Text Encoder.\n# Does not convert optimizer state or any other thing.\n\nimport argparse\nimport os.path as osp\nimport re\n\nimport torch\nfrom safetensors.torch import load_file, save_file\n\n\n# =================#\n# UNet Conversion #\n# =================#\n\nunet_conversion_map = [\n    # (stable-diffusion, HF Diffusers)\n    (\"time_embed.0.weight\", \"time_embedding.linear_1.weight\"),\n    (\"time_embed.0.bias\", \"time_embedding.linear_1.bias\"),\n    (\"time_embed.2.weight\", \"time_embedding.linear_2.weight\"),\n    (\"time_embed.2.bias\", \"time_embedding.linear_2.bias\"),\n    (\"input_blocks.0.0.weight\", \"conv_in.weight\"),\n    (\"input_blocks.0.0.bias\", \"conv_in.bias\"),\n    (\"out.0.weight\", \"conv_norm_out.weight\"),\n    (\"out.0.bias\", \"conv_norm_out.bias\"),\n    (\"out.2.weight\", \"conv_out.weight\"),\n    (\"out.2.bias\", \"conv_out.bias\"),\n]\n\nunet_conversion_map_resnet = [\n    # (stable-diffusion, HF Diffusers)\n    (\"in_layers.0\", \"norm1\"),\n    (\"in_layers.2\", \"conv1\"),\n    (\"out_layers.0\", \"norm2\"),\n    (\"out_layers.3\", \"conv2\"),\n    (\"emb_layers.1\", \"time_emb_proj\"),\n    (\"skip_connection\", \"conv_shortcut\"),\n]\n\nunet_conversion_map_layer = []\n# hardcoded number of downblocks and resnets/attentions...\n# would need smarter logic for other networks.\nfor i in range(4):\n    # loop over downblocks/upblocks\n\n    for j in range(2):\n        # loop over resnets/attentions for downblocks\n        hf_down_res_prefix = f\"down_blocks.{i}.resnets.{j}.\"\n        sd_down_res_prefix = f\"input_blocks.{3 * i + j + 1}.0.\"\n        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))\n\n        if i < 3:\n            # no attention layers in down_blocks.3\n            hf_down_atn_prefix = f\"down_blocks.{i}.attentions.{j}.\"\n            
sd_down_atn_prefix = f\"input_blocks.{3 * i + j + 1}.1.\"\n            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))\n\n    for j in range(3):\n        # loop over resnets/attentions for upblocks\n        hf_up_res_prefix = f\"up_blocks.{i}.resnets.{j}.\"\n        sd_up_res_prefix = f\"output_blocks.{3 * i + j}.0.\"\n        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))\n\n        if i > 0:\n            # no attention layers in up_blocks.0\n            hf_up_atn_prefix = f\"up_blocks.{i}.attentions.{j}.\"\n            sd_up_atn_prefix = f\"output_blocks.{3 * i + j}.1.\"\n            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))\n\n    if i < 3:\n        # no downsample in down_blocks.3\n        hf_downsample_prefix = f\"down_blocks.{i}.downsamplers.0.conv.\"\n        sd_downsample_prefix = f\"input_blocks.{3 * (i + 1)}.0.op.\"\n        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))\n\n        # no upsample in up_blocks.3\n        hf_upsample_prefix = f\"up_blocks.{i}.upsamplers.0.\"\n        sd_upsample_prefix = f\"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}.\"\n        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))\n\nhf_mid_atn_prefix = \"mid_block.attentions.0.\"\nsd_mid_atn_prefix = \"middle_block.1.\"\nunet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))\n\nfor j in range(2):\n    hf_mid_res_prefix = f\"mid_block.resnets.{j}.\"\n    sd_mid_res_prefix = f\"middle_block.{2 * j}.\"\n    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))\n\n\ndef convert_unet_state_dict(unet_state_dict):\n    # buyer beware: this is a *brittle* function,\n    # and correct output requires that all of these pieces interact in\n    # the exact order in which I have arranged them.\n    mapping = {k: k for k in unet_state_dict.keys()}\n    for sd_name, hf_name in unet_conversion_map:\n        
mapping[hf_name] = sd_name\n    for k, v in mapping.items():\n        if \"resnets\" in k:\n            for sd_part, hf_part in unet_conversion_map_resnet:\n                v = v.replace(hf_part, sd_part)\n            mapping[k] = v\n    for k, v in mapping.items():\n        for sd_part, hf_part in unet_conversion_map_layer:\n            v = v.replace(hf_part, sd_part)\n        mapping[k] = v\n    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}\n    return new_state_dict\n\n\n# ================#\n# VAE Conversion #\n# ================#\n\nvae_conversion_map = [\n    # (stable-diffusion, HF Diffusers)\n    (\"nin_shortcut\", \"conv_shortcut\"),\n    (\"norm_out\", \"conv_norm_out\"),\n    (\"mid.attn_1.\", \"mid_block.attentions.0.\"),\n]\n\nfor i in range(4):\n    # down_blocks have two resnets\n    for j in range(2):\n        hf_down_prefix = f\"encoder.down_blocks.{i}.resnets.{j}.\"\n        sd_down_prefix = f\"encoder.down.{i}.block.{j}.\"\n        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))\n\n    if i < 3:\n        hf_downsample_prefix = f\"down_blocks.{i}.downsamplers.0.\"\n        sd_downsample_prefix = f\"down.{i}.downsample.\"\n        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))\n\n        hf_upsample_prefix = f\"up_blocks.{i}.upsamplers.0.\"\n        sd_upsample_prefix = f\"up.{3 - i}.upsample.\"\n        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))\n\n    # up_blocks have three resnets\n    # also, up blocks in hf are numbered in reverse from sd\n    for j in range(3):\n        hf_up_prefix = f\"decoder.up_blocks.{i}.resnets.{j}.\"\n        sd_up_prefix = f\"decoder.up.{3 - i}.block.{j}.\"\n        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))\n\n# this part accounts for mid blocks in both the encoder and the decoder\nfor i in range(2):\n    hf_mid_res_prefix = f\"mid_block.resnets.{i}.\"\n    sd_mid_res_prefix = f\"mid.block_{i + 1}.\"\n    
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))\n\n\nvae_conversion_map_attn = [\n    # (stable-diffusion, HF Diffusers)\n    (\"norm.\", \"group_norm.\"),\n    (\"q.\", \"query.\"),\n    (\"k.\", \"key.\"),\n    (\"v.\", \"value.\"),\n    (\"proj_out.\", \"proj_attn.\"),\n]\n\n# This is probably not the most ideal solution, but it does work.\nvae_extra_conversion_map = [\n    (\"to_q\", \"q\"),\n    (\"to_k\", \"k\"),\n    (\"to_v\", \"v\"),\n    (\"to_out.0\", \"proj_out\"),\n]\n\n\ndef reshape_weight_for_sd(w):\n    # convert HF linear weights to SD conv2d weights\n    if not w.ndim == 1:\n        return w.reshape(*w.shape, 1, 1)\n    else:\n        return w\n\n\ndef convert_vae_state_dict(vae_state_dict):\n    mapping = {k: k for k in vae_state_dict.keys()}\n    for k, v in mapping.items():\n        for sd_part, hf_part in vae_conversion_map:\n            v = v.replace(hf_part, sd_part)\n        mapping[k] = v\n    for k, v in mapping.items():\n        if \"attentions\" in k:\n            for sd_part, hf_part in vae_conversion_map_attn:\n                v = v.replace(hf_part, sd_part)\n            mapping[k] = v\n    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}\n    weights_to_convert = [\"q\", \"k\", \"v\", \"proj_out\"]\n    keys_to_rename = {}\n    for k, v in new_state_dict.items():\n        for weight_name in weights_to_convert:\n            if f\"mid.attn_1.{weight_name}.weight\" in k:\n                print(f\"Reshaping {k} for SD format\")\n                new_state_dict[k] = reshape_weight_for_sd(v)\n        for weight_name, real_weight_name in vae_extra_conversion_map:\n            if f\"mid.attn_1.{weight_name}.weight\" in k or f\"mid.attn_1.{weight_name}.bias\" in k:\n                keys_to_rename[k] = k.replace(weight_name, real_weight_name)\n    for k, v in keys_to_rename.items():\n        if k in new_state_dict:\n            print(f\"Renaming {k} to {v}\")\n            new_state_dict[v] = 
reshape_weight_for_sd(new_state_dict[k])\n            del new_state_dict[k]\n    return new_state_dict\n\n\n# =========================#\n# Text Encoder Conversion #\n# =========================#\n\n\ntextenc_conversion_lst = [\n    # (stable-diffusion, HF Diffusers)\n    (\"resblocks.\", \"text_model.encoder.layers.\"),\n    (\"ln_1\", \"layer_norm1\"),\n    (\"ln_2\", \"layer_norm2\"),\n    (\".c_fc.\", \".fc1.\"),\n    (\".c_proj.\", \".fc2.\"),\n    (\".attn\", \".self_attn\"),\n    (\"ln_final.\", \"transformer.text_model.final_layer_norm.\"),\n    (\"token_embedding.weight\", \"transformer.text_model.embeddings.token_embedding.weight\"),\n    (\"positional_embedding\", \"transformer.text_model.embeddings.position_embedding.weight\"),\n]\nprotected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}\ntextenc_pattern = re.compile(\"|\".join(protected.keys()))\n\n# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp\ncode2idx = {\"q\": 0, \"k\": 1, \"v\": 2}\n\n\ndef convert_text_enc_state_dict_v20(text_enc_dict):\n    new_state_dict = {}\n    capture_qkv_weight = {}\n    capture_qkv_bias = {}\n    for k, v in text_enc_dict.items():\n        if (\n            k.endswith(\".self_attn.q_proj.weight\")\n            or k.endswith(\".self_attn.k_proj.weight\")\n            or k.endswith(\".self_attn.v_proj.weight\")\n        ):\n            k_pre = k[: -len(\".q_proj.weight\")]\n            k_code = k[-len(\"q_proj.weight\")]\n            if k_pre not in capture_qkv_weight:\n                capture_qkv_weight[k_pre] = [None, None, None]\n            capture_qkv_weight[k_pre][code2idx[k_code]] = v\n            continue\n\n        if (\n            k.endswith(\".self_attn.q_proj.bias\")\n            or k.endswith(\".self_attn.k_proj.bias\")\n            or k.endswith(\".self_attn.v_proj.bias\")\n        ):\n            k_pre = k[: -len(\".q_proj.bias\")]\n            k_code = k[-len(\"q_proj.bias\")]\n            if k_pre not 
in capture_qkv_bias:\n                capture_qkv_bias[k_pre] = [None, None, None]\n            capture_qkv_bias[k_pre][code2idx[k_code]] = v\n            continue\n\n        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)\n        new_state_dict[relabelled_key] = v\n\n    for k_pre, tensors in capture_qkv_weight.items():\n        if None in tensors:\n            raise Exception(\"CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing\")\n        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)\n        new_state_dict[relabelled_key + \".in_proj_weight\"] = torch.cat(tensors)\n\n    for k_pre, tensors in capture_qkv_bias.items():\n        if None in tensors:\n            raise Exception(\"CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing\")\n        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)\n        new_state_dict[relabelled_key + \".in_proj_bias\"] = torch.cat(tensors)\n\n    return new_state_dict\n\n\ndef convert_text_enc_state_dict(text_enc_dict):\n    return text_enc_dict\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\"--model_path\", default=None, type=str, required=True, help=\"Path to the model to convert.\")\n    parser.add_argument(\"--checkpoint_path\", default=None, type=str, required=True, help=\"Path to the output model.\")\n    parser.add_argument(\"--half\", action=\"store_true\", help=\"Save weights in half precision.\")\n    parser.add_argument(\n        \"--use_safetensors\", action=\"store_true\", help=\"Save weights using safetensors; the default is ckpt.\"\n    )\n\n    args = parser.parse_args()\n\n    assert args.model_path is not None, \"Must provide a model path!\"\n\n    assert args.checkpoint_path is not None, \"Must provide a checkpoint path!\"\n\n    # Paths for the safetensors weights\n    unet_path = osp.join(args.model_path, \"unet\", 
\"diffusion_pytorch_model.safetensors\")\n    vae_path = osp.join(args.model_path, \"vae\", \"diffusion_pytorch_model.safetensors\")\n    text_enc_path = osp.join(args.model_path, \"text_encoder\", \"model.safetensors\")\n\n    # Load models from safetensors if it exists, if it doesn't pytorch\n    if osp.exists(unet_path):\n        unet_state_dict = load_file(unet_path, device=\"cpu\")\n    else:\n        unet_path = osp.join(args.model_path, \"unet\", \"diffusion_pytorch_model.bin\")\n        unet_state_dict = torch.load(unet_path, map_location=\"cpu\")\n\n    if osp.exists(vae_path):\n        vae_state_dict = load_file(vae_path, device=\"cpu\")\n    else:\n        vae_path = osp.join(args.model_path, \"vae\", \"diffusion_pytorch_model.bin\")\n        vae_state_dict = torch.load(vae_path, map_location=\"cpu\")\n\n    if osp.exists(text_enc_path):\n        text_enc_dict = load_file(text_enc_path, device=\"cpu\")\n    else:\n        text_enc_path = osp.join(args.model_path, \"text_encoder\", \"pytorch_model.bin\")\n        text_enc_dict = torch.load(text_enc_path, map_location=\"cpu\")\n\n    # Convert the UNet model\n    unet_state_dict = convert_unet_state_dict(unet_state_dict)\n    unet_state_dict = {\"model.diffusion_model.\" + k: v for k, v in unet_state_dict.items()}\n\n    # Convert the VAE model\n    vae_state_dict = convert_vae_state_dict(vae_state_dict)\n    vae_state_dict = {\"first_stage_model.\" + k: v for k, v in vae_state_dict.items()}\n\n    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper\n    is_v20_model = \"text_model.encoder.layers.22.layer_norm2.bias\" in text_enc_dict\n\n    if is_v20_model:\n        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm\n        text_enc_dict = {\"transformer.\" + k: v for k, v in text_enc_dict.items()}\n        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)\n        text_enc_dict = 
{\"cond_stage_model.model.\" + k: v for k, v in text_enc_dict.items()}\n    else:\n        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)\n        text_enc_dict = {\"cond_stage_model.transformer.\" + k: v for k, v in text_enc_dict.items()}\n\n    # Put together new checkpoint\n    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}\n    if args.half:\n        state_dict = {k: v.half() for k, v in state_dict.items()}\n\n    if args.use_safetensors:\n        save_file(state_dict, args.checkpoint_path)\n    else:\n        state_dict = {\"state_dict\": state_dict}\n        torch.save(state_dict, args.checkpoint_path)\n"
  },
  {
    "path": "src/diffusers/py.typed",
    "content": ""
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/hooks/__init__.py",
    "content": ""
  },
  {
    "path": "tests/lora/__init__.py",
    "content": ""
  },
  {
    "path": "tests/models/__init__.py",
    "content": ""
  },
  {
    "path": "tests/models/autoencoders/__init__.py",
    "content": ""
  },
  {
    "path": "tests/models/controlnets/__init__.py",
    "content": ""
  },
  {
    "path": "tests/models/transformers/__init__.py",
    "content": ""
  },
  {
    "path": "tests/models/unets/__init__.py",
    "content": ""
  },
  {
    "path": "tests/modular_pipelines/__init__.py",
    "content": ""
  },
  {
    "path": "tests/modular_pipelines/flux/__init__.py",
    "content": ""
  },
  {
    "path": "tests/modular_pipelines/flux2/__init__.py",
    "content": ""
  },
  {
    "path": "tests/modular_pipelines/helios/__init__.py",
    "content": ""
  },
  {
    "path": "tests/modular_pipelines/qwen/__init__.py",
    "content": ""
  },
  {
    "path": "tests/modular_pipelines/stable_diffusion_xl/__init__.py",
    "content": ""
  },
  {
    "path": "tests/modular_pipelines/wan/__init__.py",
    "content": ""
  },
  {
    "path": "tests/modular_pipelines/z_image/__init__.py",
    "content": ""
  },
  {
    "path": "tests/others/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/allegro/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/animatediff/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/audioldm2/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/aura_flow/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/bria/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/bria_fibo/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/bria_fibo_edit/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/chronoedit/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/cogvideo/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/cogview3/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/cogview4/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/consisid/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/consistency_models/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/controlnet/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/controlnet_flux/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/controlnet_hunyuandit/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/controlnet_sd3/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/cosmos/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/ddim/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/ddpm/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/dit/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/easyanimate/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/flux/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/flux2/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/glm_image/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/helios/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/hidream_image/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/hunyuan_image_21/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/hunyuan_video/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/hunyuandit/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/ip_adapters/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/kandinsky/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/kandinsky2_2/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/kandinsky3/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/kandinsky5/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/kolors/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/latent_consistency_models/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/latent_diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/latte/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/ledits_pp/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/longcat_image/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/ltx/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/ltx2/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/lumina/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/lumina2/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/marigold/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/mochi/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/omnigen/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/ovis_image/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/pag/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/pixart_alpha/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/pixart_sigma/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/pndm/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/prx/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/qwenimage/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/sana/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/sana_video/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/shap_e/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/skyreels_v2/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_audio/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_cascade/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_diffusion_2/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_diffusion_3/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_diffusion_adapter/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_diffusion_image_variation/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_diffusion_xl/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_unclip/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/stable_video_diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/visualcloze/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/wan/__init__.py",
    "content": ""
  },
  {
    "path": "tests/pipelines/z_image/__init__.py",
    "content": ""
  },
  {
    "path": "tests/quantization/__init__.py",
    "content": ""
  },
  {
    "path": "tests/quantization/bnb/__init__.py",
    "content": ""
  },
  {
    "path": "tests/quantization/gguf/__init__.py",
    "content": ""
  },
  {
    "path": "tests/quantization/modelopt/__init__.py",
    "content": ""
  },
  {
    "path": "tests/quantization/quanto/__init__.py",
    "content": ""
  },
  {
    "path": "tests/quantization/torchao/__init__.py",
    "content": ""
  },
  {
    "path": "tests/remote/__init__.py",
    "content": ""
  },
  {
    "path": "tests/schedulers/__init__.py",
    "content": ""
  },
  {
    "path": "tests/single_file/__init__.py",
    "content": ""
  }
]